# Copyright (c) 2023-2026 Renata Hodovan, Akos Kiss.
#
# Licensed under the BSD 3-Clause License
# <LICENSE.rst or https://opensource.org/licenses/BSD-3-Clause>.
# This file may not be copied, modified, or distributed except
# according to those terms.
import json
import logging
import pickle
import struct
from math import inf
from typing import Any
import flatbuffers
from ..runtime import Rule, RuleSize, UnlexerRule, UnparserRule, UnparserRuleAlternative, UnparserRuleQuantified, UnparserRuleQuantifier
from .fbs import CreateFBRuleSize, FBRule, FBRuleAddAltIdx, FBRuleAddChildren, FBRuleAddIdx, FBRuleAddImmutable, FBRuleAddName, FBRuleAddSize, FBRuleAddSrc, FBRuleAddStart, FBRuleAddStop, FBRuleAddType, FBRuleEnd, FBRuleStart, FBRuleStartChildrenVector, FBRuleType
logger = logging.getLogger(__name__)
[docs]
class TreeCodec:
    """
    Common interface of tree codecs, i.e., of objects that can serialize a
    tree into bytes and rebuild a tree from bytes. Subclasses are expected to
    override both methods.
    """

    def encode(self, root: Rule) -> bytes:
        """
        Serialize a tree into an array of bytes.

        The base implementation always raises :exc:`NotImplementedError`.

        :param root: Root of the tree to be serialized.
        :return: The serialized form of the tree.
        """
        raise NotImplementedError

    def decode(self, data: bytes) -> Rule | None:
        """
        Rebuild a tree from an array of bytes.

        The base implementation always raises :exc:`NotImplementedError`.

        :param data: The serialized form of a tree.
        :return: Root of the rebuilt tree.
        """
        raise NotImplementedError
[docs]
class AnnotatedTreeCodec(TreeCodec):
    """
    Common interface of tree codecs that are able to store extra data (i.e.,
    annotations) next to the tree in the encoded bytes, and to restore that
    data when decoding.
    """

    def encode(self, root: Rule) -> bytes:
        """
        Encode a tree with no annotations attached. Delegates to
        :meth:`encode_annotated`, passing ``None`` as the annotations.

        :param root: Root of the tree to be encoded.
        :return: The encoded form of the tree.
        """
        no_annotations = None
        return self.encode_annotated(root, no_annotations)

    def encode_annotated(self, root: Rule, annotations: Any) -> bytes:
        """
        Encode a tree together with its annotations into an array of bytes.

        The base implementation always raises :exc:`NotImplementedError`.

        :param root: Root of the tree to be encoded.
        :param annotations: Data to be encoded along the tree. It must be
            treated as opaque: no assumption should be made about its
            structure or contents.
        :return: The encoded form of the tree and its annotations.
        """
        raise NotImplementedError

    def decode(self, data: bytes) -> Rule | None:
        """
        Decode the tree only, discarding any annotations stored with it.
        Delegates to :meth:`decode_annotated` and returns the first element
        of its result.

        :param data: The encoded form of a tree (and its annotations).
        :return: Root of the decoded tree.
        """
        decoded_root, _annotations = self.decode_annotated(data)
        return decoded_root

    def decode_annotated(self, data: bytes) -> tuple[Rule | None, Any]:
        """
        Decode a tree together with its annotations from an array of bytes.

        The base implementation always raises :exc:`NotImplementedError`.

        :param data: The encoded form of a tree and its annotations.
        :return: Root of the decoded tree, and the decoded annotations.
        """
        raise NotImplementedError
[docs]
class PickleTreeCodec(AnnotatedTreeCodec):
    """
    Tree codec based on Python's :mod:`pickle` module.

    .. warning::
        Unpickling can execute arbitrary code. Only decode data that
        originates from a trusted source.
    """

    def encode_annotated(self, root: Rule, annotations: Any) -> bytes:
        """
        Pickle a tree and associated annotations into an array of bytes.

        :param root: Root of the tree to be encoded.
        :param annotations: Data to be pickled along the tree.
        :return: The pickled ``(root, annotations)`` pair.
        """
        return pickle.dumps((root, annotations))

    def decode_annotated(self, data: bytes) -> tuple[Rule | None, Any]:
        """
        Unpickle a tree and associated annotations from an array of bytes.

        :param data: The pickled form of a ``(root, annotations)`` pair.
        :return: Root of the decoded tree and the decoded annotations, or
            ``(None, None)`` if the data cannot be unpickled into such a
            pair.
        """
        # NOTE: pickle.loads must never be fed untrusted input.
        try:
            root, annotations = pickle.loads(data)
        except (pickle.UnpicklingError, AttributeError, EOFError, ImportError, IndexError, TypeError, ValueError):
            # Besides UnpicklingError, truncated or corrupted input may raise
            # AttributeError, EOFError, ImportError, or IndexError, while a
            # payload that is not a 2-tuple fails to unpack with TypeError or
            # ValueError. All of them mean the input is not a valid encoding.
            return None, None
        return root, annotations
[docs]
class JsonTreeCodec(TreeCodec):
    """
    JSON-based tree codec.

    Every tree node is represented by a JSON object whose ``'t'`` member
    identifies the node type (``'l'``: unlexer rule, ``'p'``: unparser rule,
    ``'a'``: alternative, ``'qd'``: quantified, ``'q'``: quantifier). Child
    nodes are stored as a list under ``'c'``.
    """

    def __init__(self, encoding: str = 'utf-8', encoding_errors: str = 'surrogatepass') -> None:
        """
        :param encoding: The encoding to use when converting between
            json-formatted text and bytes (default: utf-8).
        :param encoding_errors: The error handling scheme to use when
            converting between json-formatted text and bytes (default:
            surrogatepass).
        """
        self._encoding: str = encoding
        self._encoding_errors: str = encoding_errors

    def encode(self, root: Rule) -> bytes:
        """
        Create the JSON representation of a tree and convert it to an array of
        bytes using the specified encoding.

        :param root: Root of the tree to be encoded.
        :return: The encoded JSON text.
        """
        def _rule_to_dict(node):
            # Children are returned as they are; json.dumps invokes this hook
            # again for every object it cannot serialize natively.
            if isinstance(node, UnlexerRule):
                return {'t': 'l', 'n': node.name, 's': node.src, 'z': [node.size.depth, node.size.tokens], 'i': node.immutable}
            if isinstance(node, UnparserRule):
                return {'t': 'p', 'n': node.name, 'c': node.children}
            if isinstance(node, UnparserRuleAlternative):
                return {'t': 'a', 'ai': node.alt_idx, 'i': node.idx, 'c': node.children}
            if isinstance(node, UnparserRuleQuantified):
                return {'t': 'qd', 'c': node.children}
            if isinstance(node, UnparserRuleQuantifier):
                # An unbounded quantifier (stop == inf) is stored as -1.
                return {'t': 'q', 'i': node.idx, 'b': node.start, 'e': node.stop if node.stop != inf else -1, 'c': node.children}
            raise AssertionError

        return json.dumps(root, default=_rule_to_dict).encode(encoding=self._encoding, errors=self._encoding_errors)

    def decode(self, data: bytes) -> Rule | None:
        """
        Reconstruct a tree from a JSON representation stored in an array of
        bytes using the specified encoding.

        :param data: The encoded JSON text.
        :return: Root of the decoded tree, or ``None`` if the input cannot be
            decoded as text or parsed as JSON. JSON objects that do not
            describe a valid tree node are decoded as ``None`` (with a
            warning logged).
        """
        def _dict_to_rule(dct):
            if not isinstance(dct, dict) or 't' not in dct:
                logger.warning('Invalid JSON tree node.')
                return None

            node_type = dct['t']
            try:
                if node_type == 'l':
                    return UnlexerRule(name=dct['n'], src=dct['s'], size=RuleSize(depth=dct['z'][0], tokens=dct['z'][1]), immutable=dct['i'])
                if node_type == 'p':
                    return UnparserRule(name=dct['n'], children=dct['c'])
                if node_type == 'a':
                    return UnparserRuleAlternative(alt_idx=dct['ai'], idx=dct['i'], children=dct['c'])
                if node_type == 'qd':
                    return UnparserRuleQuantified(children=dct['c'])
                if node_type == 'q':
                    # -1 is the stored form of an unbounded quantifier.
                    return UnparserRuleQuantifier(idx=dct['i'], start=dct['b'], stop=dct['e'] if dct['e'] != -1 else inf, children=dct['c'])
            except (KeyError, IndexError, TypeError):
                # A known node type with missing or ill-typed members is as
                # invalid as a node without a type tag.
                logger.warning('Invalid JSON tree node.')
                return None

            logger.warning('Unknown JSON tree node type.')
            return None

        try:
            text = data.decode(encoding=self._encoding, errors=self._encoding_errors)
            return json.loads(text, object_hook=_dict_to_rule)
        except (UnicodeDecodeError, json.JSONDecodeError):
            # Bytes that are not valid text in the configured encoding are
            # rejected the same way as text that is not valid JSON.
            logger.warning('Invalid JSON input.')
            return None
[docs]
class FlatBuffersTreeCodec(TreeCodec):
    """
    FlatBuffers-based tree codec.
    """

    def __init__(self, encoding: str = 'utf-8', encoding_errors: str = 'ignore') -> None:
        """
        :param encoding: The encoding to use when converting between
            flatbuffers-encoded text and bytes (default: utf-8).
        :param encoding_errors: The error handling scheme to use when
            converting between flatbuffers-encoded text and bytes (default:
            ignore).
        """
        self._encoding: str = encoding
        self._encoding_errors: str = encoding_errors

    def encode(self, root: Rule) -> bytes:
        """
        Create the FlatBuffers representation of a tree.

        :param root: Root of the tree to be encoded.
        :return: The FlatBuffers-encoded form of the tree.
        """
        def buildFBRule(rule):
            # FlatBuffers does not allow nested object construction, so all
            # strings, child tables, and vectors of a node are finished
            # before its own table is started with FBRuleStart.
            if isinstance(rule, UnlexerRule):
                fb_name = builder.CreateString(rule.name, encoding=self._encoding, errors=self._encoding_errors)
                fb_src = builder.CreateString(rule.src, encoding=self._encoding, errors=self._encoding_errors)
                FBRuleStart(builder)
                FBRuleAddType(builder, FBRuleType.UnlexerRuleType)
                FBRuleAddName(builder, fb_name)
                FBRuleAddSrc(builder, fb_src)
                FBRuleAddSize(builder, CreateFBRuleSize(builder, rule.size.depth, rule.size.tokens))
                FBRuleAddImmutable(builder, rule.immutable)
            else:
                children = [buildFBRule(child) for child in rule.children]
                FBRuleStartChildrenVector(builder, len(children))
                # Elements are prepended, hence the reversed iteration to
                # keep the original child order in the vector.
                for fb_child in reversed(children):
                    builder.PrependUOffsetTRelative(fb_child)
                fb_children = builder.EndVector()
                if isinstance(rule, UnparserRule):
                    fb_name = builder.CreateString(rule.name, encoding=self._encoding, errors=self._encoding_errors)
                FBRuleStart(builder)
                FBRuleAddChildren(builder, fb_children)
                if isinstance(rule, UnparserRule):
                    FBRuleAddName(builder, fb_name)
                    FBRuleAddType(builder, FBRuleType.UnparserRuleType)
                elif isinstance(rule, UnparserRuleQuantifier):
                    FBRuleAddType(builder, FBRuleType.UnparserRuleQuantifierType)
                    FBRuleAddIdx(builder, rule.idx)
                    FBRuleAddStart(builder, rule.start)
                    # An unbounded quantifier (stop == inf) is stored as -1.
                    FBRuleAddStop(builder, rule.stop if rule.stop != inf else -1)
                elif isinstance(rule, UnparserRuleQuantified):
                    FBRuleAddType(builder, FBRuleType.UnparserRuleQuantifiedType)
                elif isinstance(rule, UnparserRuleAlternative):
                    FBRuleAddType(builder, FBRuleType.UnparserRuleAlternativeType)
                    FBRuleAddAltIdx(builder, rule.alt_idx)
                    FBRuleAddIdx(builder, rule.idx)
            return FBRuleEnd(builder)

        builder = flatbuffers.Builder()
        builder.Finish(buildFBRule(root))
        return bytes(builder.Output())

    def decode(self, data: bytes) -> Rule | None:
        """
        Reconstruct a tree from a FlatBuffers representation.

        :param data: The FlatBuffers-encoded form of a tree.
        :return: Root of the decoded tree, or ``None`` if reading the buffer
            fails with :exc:`struct.error`.
        :raises AssertionError: If a node with an unexpected type is found in
            the buffer.
        """
        def readFBRule(fb_rule):
            rule_type = fb_rule.Type()
            if rule_type == FBRuleType.UnlexerRuleType:
                fb_size = fb_rule.Size()
                rule = UnlexerRule(name=fb_rule.Name().decode(self._encoding, self._encoding_errors),
                                   src=fb_rule.Src().decode(self._encoding, self._encoding_errors),
                                   size=RuleSize(depth=fb_size.Depth(), tokens=fb_size.Tokens()),
                                   immutable=fb_rule.Immutable())
            else:
                children = [readFBRule(fb_rule.Children(i)) for i in range(fb_rule.ChildrenLength())]
                if rule_type == FBRuleType.UnparserRuleType:
                    rule = UnparserRule(name=fb_rule.Name().decode(self._encoding, self._encoding_errors), children=children)
                elif rule_type == FBRuleType.UnparserRuleQuantifierType:
                    stop = fb_rule.Stop()
                    # -1 is the stored form of an unbounded quantifier.
                    rule = UnparserRuleQuantifier(idx=fb_rule.Idx(), start=fb_rule.Start(), stop=stop if stop != -1 else inf, children=children)
                elif rule_type == FBRuleType.UnparserRuleQuantifiedType:
                    rule = UnparserRuleQuantified(children=children)
                elif rule_type == FBRuleType.UnparserRuleAlternativeType:
                    rule = UnparserRuleAlternative(alt_idx=fb_rule.AltIdx(), idx=fb_rule.Idx(), children=children)
                else:
                    # Raised explicitly (instead of using an assert statement)
                    # so that the check is not stripped when running with -O,
                    # which would let the unbound ``rule`` be returned below.
                    raise AssertionError(f'Unexpected type {rule_type}')
            return rule

        try:
            return readFBRule(FBRule.GetRootAs(bytearray(data)))
        except struct.error:
            return None