Source code for concepts.dm.crow.parsers.cdl_parser

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : cdl_parser.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 02/10/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import os
import os.path as osp
import contextlib
import jacinle

from typing import Any, Optional, Union, Sequence, Tuple, Set, List, Dict
from dataclasses import dataclass, field
from lark import Lark, Tree
from lark.visitors import Transformer, Interpreter, v_args
from lark.indenter import PythonIndenter

import concepts.dsl.expression as E
from concepts.dsl.dsl_types import TypeBase, AutoType, VectorValueType, ListType, BatchedListType, BOOL, Variable, ObjectConstant, UnnamedPlaceholder, QINDEX
from concepts.dsl.dsl_functions import Function, FunctionType
from concepts.dsl.tensor_state import StateObjectList

from concepts.dm.crow.crow_function import CrowFunction
from concepts.dm.crow.controller import CrowControllerApplicationExpression
from concepts.dm.crow.behavior import CrowPrecondition, CrowBehaviorBodyItem, CrowBehaviorBodyPrimitiveBase, CrowBehaviorBodySuiteBase, CrowBehaviorCommit
from concepts.dm.crow.behavior import CrowAchieveExpression, CrowUntrackExpression, CrowBindExpression, CrowRuntimeAssignmentExpression, CrowAssertExpression, CrowFeatureAssignmentExpression, CrowBehaviorApplicationExpression
from concepts.dm.crow.behavior import CrowBehaviorConditionSuite, CrowBehaviorWhileLoopSuite, CrowBehaviorForeachLoopSuite, CrowBehaviorStatementOrdering, CrowBehaviorOrderingSuite
from concepts.dm.crow.crow_generator import CrowGeneratorApplicationExpression
from concepts.dm.crow.crow_domain import CrowDomain, CrowProblem, CrowState
from concepts.dm.crow.crow_expression_utils import crow_replace_expression_variables
from concepts.dm.crow.behavior_utils import execute_effect_statements

# Logger for this module.
logger = jacinle.get_logger(__name__)

# ``v_args(inline=True)``: callbacks decorated with this receive the children of a tree node as
# positional arguments (e.g. ``typename(self, name)``), whereas undecorated callbacks such as
# ``arguments_def(self, args)`` receive the children as a single list.
inline_args = v_args(inline=True)


__all__ = [
    # Parser class and the shared path resolver instance.
    'CDLParser', 'path_resolver',
    # Resolver class and tree transformers / interpreters.
    'CDLPathResolver',
    'CDLDomainTransformer', 'CDLProblemTransformer', 'CDLLiteralTransformer', 'CDLExpressionInterpreter',
    # Module-level helper functions.
    'get_default_parser',
    'load_domain_file', 'load_domain_string', 'load_domain_string_incremental',
    'load_problem_file', 'load_problem_string',
    'parse_expression',
]


class CDLPathResolver(object):
    """Locates CDL files on disk.

    Lookup order used by :meth:`resolve`:

    1. ``filename`` exactly as given (i.e. relative to the process working directory, or absolute);
    2. ``filename`` relative to the directory of ``relative_filename`` (when provided);
    3. ``filename`` joined onto ``os.getcwd()``;
    4. ``filename`` joined onto each registered search path, in registration order.
    """

    def __init__(self, search_paths: Sequence[str] = tuple()):
        """Create a resolver.

        Args:
            search_paths: extra directories consulted last, in order.
        """
        self.search_paths = list(search_paths)

    def resolve(self, filename: str, relative_filename: Optional[str] = None) -> str:
        """Return the first existing path for ``filename``.

        Args:
            filename: the file to look up.
            relative_filename: optional path of the file that references ``filename``; its directory is
                searched as well.

        Returns:
            the first candidate path for which ``os.path.exists`` holds.

        Raises:
            FileNotFoundError: if no candidate exists.
        """
        def _candidates():
            yield filename
            if relative_filename is not None:
                yield osp.join(osp.dirname(relative_filename), filename)
            yield osp.join(os.getcwd(), filename)
            yield from (osp.join(search_dir, filename) for search_dir in self.search_paths)

        for candidate in _candidates():
            if osp.exists(candidate):
                return candidate
        raise FileNotFoundError(f'File not found: {filename}')

    def add_search_path(self, path: str):
        """Append ``path`` to the list of search directories."""
        self.search_paths.append(path)

    def remove_search_path(self, path: str):
        """Remove the first occurrence of ``path`` from the search directories (``ValueError`` if absent)."""
        self.search_paths.remove(path)
# Shared resolver instance; ``CDLParser.parse`` resolves filenames through it, and additional search
# directories can be registered with ``path_resolver.add_search_path(...)``.
path_resolver = CDLPathResolver()

# Module-wide verbosity flag, changed through ``set_parser_verbose``.
g_parser_verbose = False
def set_parser_verbose(verbose: bool = True):
    """Set the module-level ``g_parser_verbose`` flag.

    Args:
        verbose: the new value of the flag (defaults to enabling verbosity).
    """
    global g_parser_verbose
    g_parser_verbose = verbose
class CDLParser(object):
    """The parser for PDSketch v3 (CDL) files."""

    grammar_file = osp.join(osp.dirname(osp.abspath(__file__)), 'cdl.grammar')
    """The grammar definition v3 for PDSketch."""

    def __init__(self):
        """Initialize the parser: read the grammar file and build an LALR Lark parser with a Python-style indenter."""
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(self.grammar_file, 'r', encoding='utf-8') as f:
            self.grammar = f.read()
        self.parser = Lark(self.grammar, start='start', postlex=PythonIndenter(), parser='lalr')

    def parse(self, filename: str) -> Tree:
        """Parse a PDSketch v3 file.

        Args:
            filename: the filename to parse. It is located through the module-level ``path_resolver``.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.
        """
        filename = path_resolver.resolve(filename)
        with open(filename, 'r', encoding='utf-8') as f:
            return self.parse_str(f.read())

    def parse_str(self, s: str) -> Tree:
        """Parse a PDSketch v3 string.

        Args:
            s: the string to parse.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.
        """
        # NB(Jiayuan Mao @ 2024/03/13): for reasons, the pdsketch-v3 grammar file requires that the string ends with a newline.
        # In particular, the suite definition requires that the file ends with a _DEDENT token, which seems to be only triggered by a newline.
        s = s.strip() + '\n'
        return self.parser.parse(s)

    def parse_domain(self, filename: str) -> CrowDomain:
        """Parse a PDSketch v3 domain file.

        Args:
            filename: the filename to parse.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse(filename))

    def parse_domain_str(self, s: str, domain: Optional[CrowDomain] = None) -> CrowDomain:
        """Parse a PDSketch v3 domain string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse_str(s), domain=domain)

    def parse_problem(self, filename: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
        """Parse a PDSketch v3 problem file.

        Args:
            filename: the filename to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        return self.transform_problem(self.parse(filename), domain=domain)

    def parse_problem_str(self, s: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
        """Parse a PDSketch v3 problem string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        return self.transform_problem(self.parse_str(s), domain=domain)

    def parse_expression(
        self, s: str, domain: CrowDomain, state: Optional[CrowState] = None,
        variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True
    ) -> E.Expression:
        """Parse a PDSketch v3 expression string.

        Args:
            s: the string to parse.
            domain: the domain to use.
            state: the current state, containing objects.
            variables: variables from the outer scope.
            auto_constant_guess: whether to automatically guess whether a variable is a constant.

        Returns:
            the parsed expression.
        """
        return self.transform_expression(self.parse_str(s), domain, state=state, variables=variables, auto_constant_guess=auto_constant_guess)

    @staticmethod
    def transform_domain(tree: Tree, domain: Optional[CrowDomain] = None) -> CrowDomain:
        """Transform a parse tree into a domain.

        Args:
            tree: the parse tree.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain.
        """
        transformer = CDLDomainTransformer(domain)
        transformer.transform(tree)
        return transformer.domain

    @staticmethod
    def transform_problem(tree: Tree, domain: Optional[CrowDomain] = None) -> CrowProblem:
        """Transform a parse tree into a problem.

        Args:
            tree: the parse tree.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        transformer = CDLProblemTransformer(domain)
        transformer.transform(tree)
        return transformer.problem

    @staticmethod
    def transform_expression(
        tree: Tree, domain: CrowDomain, state: Optional[CrowState] = None,
        variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True
    ) -> E.Expression:
        """Transform a parse tree into an expression.

        Args:
            tree: the parse tree.
            domain: the domain to use.
            state: the current state, containing objects.
            variables: variables from the outer scope.
            auto_constant_guess: whether to automatically guess whether a variable is a constant.

        Returns:
            the parsed expression.
        """
        transformer = CDLProblemTransformer(domain, state, auto_constant_guess=auto_constant_guess)
        interpreter = transformer.expression_interpreter
        expression_def_ctx = transformer.expression_def_ctx
        # the root of the tree is the `start` rule.
        tree = transformer.transform(tree).children[0]
        if variables is None:
            variables = tuple()
        # The outer-scope variables are only visible while the interpreter walks the tree.
        with expression_def_ctx.with_variables(*variables):
            return interpreter.visit(tree)
@dataclass
class LiteralValue:
    """Wraps a single scalar literal (bool, number, or string) parsed from CDL source."""

    value: Union[bool, int, float, complex, str]
@dataclass
class LiteralList:
    """An ordered sequence of literals, e.g. ``[1, 2, 3]``; stored as a tuple."""

    items: Tuple[Union[bool, int, float, complex, str], ...]
@dataclass
class LiteralSet:
    """An unordered collection of literals, e.g. ``{1, 2, 3}``; stored as a Python ``set``."""

    items: Set[Union[bool, int, float, complex, str]]
@dataclass
class InTypedArgument:
    """A binder of the form ``name in value`` (as used by forall/exists statements).

    ``name`` is the bound variable's name; ``value`` is the parsed expression it ranges over.
    """

    name: str
    value: Any
class CDLLiteralTransformer(Transformer):
    """The transformer for literal types.

    Including:

    - VARNAME, CONSTNAME, BASIC_TYPENAME
    - number, DEC_NUMBER, HEX_NUMBER, BIN_NUMBER, OCT_NUMBER, FLOAT_NUMBER, IMAG_NUMBER
    - boolean, TRUE, FALSE
    - string
    - literal_list
    - literal_set
    - decorator_k, decorator_kwarg, decorator_kwargs
    """

    # NOTE(review): only annotated here, never assigned in this class -- presumably set by subclasses; confirm.
    domain: CrowDomain

    @inline_args
    def typename(self, name: Union[str, TypeBase]) -> TypeBase:
        """Captures typenames including basic types and vector types (passed through unchanged)."""
        return name

    @inline_args
    def sized_vector_typename(self, name: Union[str, TypeBase], size: int) -> VectorValueType:
        """Captures sized vector typenames defined as `vector[typename, size]`."""
        return VectorValueType(self.domain.get_type(name), size)

    @inline_args
    def unsized_vector_typename(self, name: Union[str, TypeBase]) -> VectorValueType:
        """Captures unsized vector typenames defined as `vector[typename]`."""
        return VectorValueType(self.domain.get_type(name))

    @inline_args
    def batched_typename(self, element_dtype: Union[str, TypeBase], indices: Tree) -> BatchedListType:
        """Captures batched typenames defined as `typename[indices]`; every index name is looked up in the domain."""
        element_dtype = self.domain.get_type(element_dtype)
        return BatchedListType(element_dtype, [self.domain.get_type(name) for name in indices.children])

    @inline_args
    def typed_argument(self, name: str, typename: Union[str, TypeBase]) -> Variable:
        """Captures typed arguments defined as `name: typename`."""
        if isinstance(typename, str):
            typename = self.domain.get_type(typename)
        return Variable(name, typename)

    @inline_args
    def multi_typed_arguments(self, name: Tree, typename: Union[str, TypeBase]) -> 'CSList':
        """Captures multiple typed arguments defined as `name1, name2: typename`; all share the same type."""
        if isinstance(typename, str):
            typename = self.domain.get_type(typename)
        return CSList(tuple(Variable(n, typename) for n in name.children))

    @inline_args
    def is_typed_argument(self, name: str, typename: Union[str, TypeBase]) -> Variable:
        """Captures typed arguments defined as `name is typename`. This is used in forall/exists statements."""
        if isinstance(typename, str):
            typename = self.domain.get_type(typename)
        return Variable(name, typename)

    @inline_args
    def in_typed_argument(self, name: str, value: Any) -> InTypedArgument:
        """Captures typed arguments defined as `name in value`. This is used in forall/exists statements."""
        return InTypedArgument(name, value)

    def arguments_def(self, args):
        """Captures the arguments definition (receives all children as one list). This is used in function definitions."""
        return ArgumentsDef(tuple(args))

    def VARNAME(self, token):
        """Captures variable names, such as `var_name`."""
        return token.value

    def CONSTNAME(self, token):
        """Captures constant names, such as `CONST_NAME`."""
        return token.value

    def BASIC_TYPENAME(self, token):
        """Captures basic type names (non-vector types), such as `int`, `float`, `bool`, `object`, etc."""
        return token.value

    @inline_args
    def number(self, value: Union[int, float, complex]) -> Union[int, float, complex]:
        """Captures number literals, including integers, floats, and complex numbers."""
        return value

    @inline_args
    def BIN_NUMBER(self, value: str) -> int:
        """Captures binary number literals."""
        return int(value, 2)

    @inline_args
    def OCT_NUMBER(self, value: str) -> int:
        """Captures octal number literals."""
        return int(value, 8)

    @inline_args
    def DEC_NUMBER(self, value: str) -> int:
        """Captures decimal number literals."""
        return int(value)

    @inline_args
    def HEX_NUMBER(self, value: str) -> int:
        """Captures hexadecimal number literals."""
        return int(value, 16)

    @inline_args
    def FLOAT_NUMBER(self, value: str) -> float:
        """Captures floating point number literals."""
        return float(value)

    @inline_args
    def IMAG_NUMBER(self, value: str) -> complex:
        """Captures complex number literals."""
        return complex(value)

    @inline_args
    def boolean(self, value: bool) -> bool:
        """Captures boolean literals."""
        return value

    @inline_args
    def TRUE(self, _) -> bool:
        """Captures the `True` literal."""
        return True

    @inline_args
    def FALSE(self, _) -> bool:
        """Captures the `False` literal."""
        return False

    @inline_args
    def ELLIPSIS(self, _) -> type(Ellipsis):
        """Captures the `...` literal and returns the Python ``Ellipsis`` singleton."""
        return Ellipsis

    @inline_args
    def string(self, value: str) -> str:
        """Captures string literals, stripping one pair of matching single or double quotes.

        Escape sequences inside the literal are not interpreted.
        """
        if value[0] == value[-1] and value[0] in ('"', "'"):
            value = value[1:-1]
        return str(value)

    @inline_args
    def literal_list(self, *items: Any) -> LiteralList:
        """Captures literal lists, such as `[1, 2, 3, 4]`."""
        return LiteralList(tuple(items))

    @inline_args
    def literal_set(self, *items: Any) -> LiteralSet:
        """Captures literal sets, such as `{1, 2, 3, 4}`."""
        return LiteralSet(set(items))

    @inline_args
    def literal(self, value: Union[bool, int, float, complex, str, LiteralList, LiteralSet]) -> Union[LiteralValue, LiteralList, LiteralSet]:
        """Captures literal values: scalars are wrapped in :class:`LiteralValue`; lists/sets pass through.

        Raises:
            ValueError: for any other value type.
        """
        if isinstance(value, (bool, int, float, complex, str)):
            return LiteralValue(value)
        elif isinstance(value, (LiteralList, LiteralSet)):
            return value
        else:
            raise ValueError(f'Invalid literal value: {value}')

    @inline_args
    def decorator_kwarg(self, k, v: Union[LiteralValue, LiteralList, LiteralSet] = True) -> Tuple[str, Union[bool, int, float, complex, str, LiteralList, LiteralSet]]:
        """Captures the key-value pair of a decorator. This is used in the decorator syntax, such as [[k=True]].

        A bare key (no value) defaults to ``True``; :class:`LiteralValue` wrappers are unwrapped.
        """
        return k, v.value if isinstance(v, LiteralValue) else v

    def decorator_kwargs(self, args) -> Dict[str, Union[bool, int, float, complex, str, LiteralList, LiteralSet]]:
        """Captures the key-value pairs of a decorator. This is used in the decorator syntax, such as [[k=True, k2=123, k3=[1, 2, 3]]].

        Later duplicates of a key overwrite earlier ones.
        """
        return dict(args)
@dataclass
class ArgumentsList:
    """The argument values of a call, in positional order.

    Each entry may be a nested :class:`Suite`, an expression (value output, list expansion, or variable),
    or a plain Python scalar.
    """

    arguments: Tuple[Union['Suite', E.ValueOutputExpression, E.ListExpansionExpression, E.VariableExpression, bool, int, float, complex, str], ...]
@dataclass
class FunctionCall:
    """Intermediate representation of one parsed statement or call.

    Besides ordinary function calls, this also represents primitive operators and control-flow statements
    (``name`` is e.g. ``'assign'``, ``'if'``, ``'foreach'``), with optional ``[[k=v]]`` annotations.
    """

    name: str
    args: ArgumentsList
    annotations: Optional[Dict[str, Any]] = None

    def __str__(self):
        prefix = ''
        if self.annotations is not None:
            pairs = ', '.join(f'{key}={val}' for key, val in self.annotations.items())
            prefix = '[[' + pairs + ']] '

        rendered = list(map(str, self.args.arguments))
        # Long argument lists are printed one argument per (indented) line.
        if sum(map(len, rendered)) > 80:
            body = '\n'.join(jacinle.indent_text(text) for text in rendered)
            return prefix + self.name + ':\n' + body
        return prefix + self.name + '(' + ', '.join(rendered) + ')'

    def __repr__(self):
        return 'FunctionCall{' + str(self) + '}'
@dataclass
class CSList:
    """A parsed comma-separated list; ``items`` holds the elements in source order."""

    items: Tuple[Any, ...]

    def __str__(self):
        joined = ', '.join(map(str, self.items))
        return 'CSList(' + joined + ')'

    def __repr__(self):
        return str(self)
@dataclass
class Suite(object):
    """A suite of statements. This is used as the intermediate representation of the parsed expressions.

    The getters lazily run a :class:`FunctionCallTracker` over ``items`` on first use and cache it in ``tracker``.
    """

    items: Tuple[Any, ...]
    local_variables: Dict[str, Any] = field(default_factory=dict)
    tracker: Optional['FunctionCallTracker'] = None

    def _init_tracker(self, use_runtime_assign: bool = False):
        """Run a fresh tracker over this suite (with no initial local variables) and cache it."""
        self.tracker = FunctionCallTracker(self, dict(), use_runtime_assign=use_runtime_assign).run()

    def get_all_assign_expressions(self) -> List[Union[CrowFeatureAssignmentExpression, CrowBehaviorForeachLoopSuite, CrowBehaviorConditionSuite]]:
        """Return the feature-assignment statements collected by the tracker."""
        if self.tracker is None:
            self._init_tracker()
        return self.tracker.assign_expressions

    def get_all_check_expressions(self) -> List[E.ValueOutputExpression]:
        """Return the check expressions collected by the tracker."""
        if self.tracker is None:
            self._init_tracker()
        return self.tracker.check_expressions

    def get_all_behavior_expressions(self, use_runtime_assign=True) -> List[CrowBehaviorBodyItem]:
        """Return the behavior-body statements collected by the tracker.

        Args:
            use_runtime_assign: passed to the tracker only if it has not been created yet.
        """
        # NOTE(review): if another getter already built the tracker (with use_runtime_assign=False), this
        # argument is silently ignored and the cached tracker is reused -- confirm this is intended.
        if self.tracker is None:
            self._init_tracker(use_runtime_assign=use_runtime_assign)
        return self.tracker.behavior_expressions

    def get_all_expr_expression(self, allow_multiple_expressions: bool = False) -> Optional[Union[E.ValueOutputExpression, Tuple[E.ValueOutputExpression, ...]]]:
        """Return the bare expression(s) of the suite.

        Args:
            allow_multiple_expressions: if True, zero or several expressions are returned as a tuple.

        Returns:
            the single expression if there is exactly one; otherwise a tuple (only when allowed).

        Raises:
            ValueError: if the number of expressions is not one and ``allow_multiple_expressions`` is False.
        """
        if self.tracker is None:
            self._init_tracker()
        expressions = self.tracker.expr_expressions
        if len(expressions) == 1:
            return expressions[0]
        if not allow_multiple_expressions:
            if len(expressions) == 0:
                raise ValueError('No expressions found in a single suite; expected exactly one.')
            raise ValueError(f'Multiple expressions found in a single suite: {expressions}')
        if len(expressions) == 0:
            return tuple()
        return tuple(expressions)

    def get_combined_return_expression(self, allow_expr_expressions: bool = False, allow_multiple_expressions: bool = False) -> Optional[E.ValueOutputExpression]:
        """Return the combined return expression, optionally falling back to the bare expression(s).

        Args:
            allow_expr_expressions: if True and there is no return statement, fall back to
                :meth:`get_all_expr_expression`.
            allow_multiple_expressions: forwarded to :meth:`get_all_expr_expression`.

        Returns:
            the return expression, the fallback result, or None.
        """
        if self.tracker is None:
            self._init_tracker()
        if self.tracker.return_expression is not None:
            return self.tracker.return_expression
        if allow_expr_expressions:
            return self.get_all_expr_expression(allow_multiple_expressions=allow_multiple_expressions)
        return None

    def __str__(self):
        if len(self.items) == 0:
            return 'Suite{}'
        if len(self.items) == 1:
            return f'Suite{{{self.items[0]}}}'
        return 'Suite{\n' + '\n'.join(jacinle.indent_text(str(item)) for item in self.items) + '\n}'

    def __repr__(self):
        return self.__str__()
class FunctionCallTracker(object):
    """This class is used to track the function calls and other statements in a suite.

    It supports simulating the execution of the program and generating the post-condition of the program.
    """

    def __init__(self, suite: Suite, init_local_variables: Optional[Dict[str, Any]] = None, use_runtime_assign: bool = False):
        self.suite = suite
        self.local_variables = dict() if init_local_variables is None else init_local_variables
        self.assign_expressions = list()
        self.assign_expression_signatures = dict()
        self.check_expressions = list()
        self.expr_expressions = list()
        self.behavior_expressions = list()
        self.return_expression = None
        self.use_runtime_assign = use_runtime_assign

    local_variables: Dict[str, Any]
    """The assignments of local variables: variable name -> the expression it currently stands for."""

    assign_expressions: List[Union[CrowFeatureAssignmentExpression, CrowBehaviorForeachLoopSuite, CrowBehaviorConditionSuite]]
    """A list of assign expressions."""

    assign_expression_signatures: Dict[Tuple[str, ...], E.VariableAssignmentExpression]
    """A dictionary of assign expressions, indexed by their signatures."""

    check_expressions: List[E.ValueOutputExpression]
    """A list of check expressions."""

    expr_expressions: List[E.ValueOutputExpression]
    """A list of expr expressions (i.e. bare expressions in the body)."""

    return_expression: Optional[E.ValueOutputExpression]
    """The return expression. This is either None or a single expression."""

    use_runtime_assign: bool
    """Whether to use runtime assign expressions. If this is set to True, assignment expressions for local variables will be converted into runtime assign expressions."""

    behavior_expressions: List[CrowBehaviorBodyItem]
    """A list of behavior-body statements (achieve, bind, runtime assignments, control-flow suites, ...)."""

    def _g(
        self, expr: Union[E.Expression, UnnamedPlaceholder, CrowBehaviorBodyItem]
    ) -> Union[E.Expression, UnnamedPlaceholder, CrowBehaviorBodyItem]:
        """Substitute the current local variables into ``expr``.

        Controller/behavior/generator applications and behavior-body items are returned unchanged.

        Raises:
            ValueError: if ``expr`` is none of the supported kinds.
        """
        if isinstance(expr, (CrowControllerApplicationExpression, CrowBehaviorApplicationExpression, CrowGeneratorApplicationExpression)):
            return expr
        if isinstance(expr, (CrowBehaviorBodyPrimitiveBase, CrowBehaviorBodySuiteBase)):
            return expr
        if not isinstance(expr, E.Expression):
            raise ValueError(f'Invalid expression: {expr}')
        return crow_replace_expression_variables(expr, {
            E.VariableExpression(Variable(k, AutoType)): v for k, v in self.local_variables.items()
        })

    def _get_deictic_signature(self, e, known_deictic_vars=tuple()) -> Optional[Tuple[str, ...]]:
        """Return ``(function_name, *arg_names)`` for an assignment, with deictic (quantified) variables replaced by '?'."""
        if isinstance(e, E.DeicticAssignExpression):
            known_deictic_vars = known_deictic_vars + (e.variable.name,)
            return self._get_deictic_signature(e.expression, known_deictic_vars)
        elif isinstance(e, E.AssignExpression):
            args = [x.name if x.name not in known_deictic_vars else '?' for x in e.predicate.arguments]
            return tuple((e.predicate.function.name, *args))
        else:
            return None

    def _mark_assign(self, *exprs: E.VariableAssignmentExpression, annotations: Optional[Dict[str, Any]] = None):
        """Record assignment expressions, rejecting two assignments with the same deictic signature.

        Raises:
            ValueError: on a duplicate signature.
        """
        if annotations is None:
            annotations = dict()
        for expr in exprs:
            signature = self._get_deictic_signature(expr)
            if signature is not None:
                if signature in self.assign_expression_signatures:
                    raise ValueError(f'Duplicate assign expression: {expr} vs {self.assign_expression_signatures[signature]}')
                self.assign_expressions.append((expr, annotations))
                self.assign_expression_signatures[signature] = expr
            else:
                self.assign_expressions.append((expr, annotations))

    def _run_nested(self, suite: Suite) -> 'FunctionCallTracker':
        """Track a nested suite with a copy of the current local variables (same runtime-assign mode)."""
        return FunctionCallTracker(suite, self.local_variables.copy(), use_runtime_assign=self.use_runtime_assign).run()

    def _check_loop_local_variables(self, tracker: 'FunctionCallTracker', statement_name: str):
        """Reject loop bodies that re-bind a local variable that already had a different value before the loop."""
        for k in tracker.local_variables:
            if k in self.local_variables and self.local_variables[k] != tracker.local_variables[k]:
                raise ValueError(f'Local variable {k} is assigned in the {statement_name} statement but has been assigned before: {self.local_variables[k]} vs {tracker.local_variables[k]}')

    @staticmethod
    def _conjoin(lhs: Optional[E.ValueOutputExpression], rhs: E.ValueOutputExpression) -> E.ValueOutputExpression:
        """Return ``lhs and rhs``, or ``rhs`` alone when ``lhs`` is None."""
        if lhs is None:
            return rhs
        return E.AndExpression(lhs, rhs)

    @jacinle.log_function(verbose=False)
    def run(self):
        """Simulate the execution of the program and generate an equivalent return statement.

        If-else statements are merged into conditional expressions. Loops (``while``, ``foreach``,
        ``foreach_in``) are supported with restrictions; return statements inside loops are rejected.
        Statement names not handled below are ignored.

        Returns:
            ``self``, with the tracked expression lists populated.
        """
        # ``current_return_statement`` is the value produced by earlier one-branch returns, and
        # ``current_return_statement_condition_neg`` is the condition under which none of them was taken
        # (i.e. execution reached the current statement).
        current_return_statement = None
        current_return_statement_condition_neg = None
        for item in self.suite.items:
            assert isinstance(item, FunctionCall), f'Invalid item in suite: {item}'
            annotations = item.annotations if item.annotations is not None else dict()

            if item.name == 'assign':
                target, value = item.args.arguments[0], item.args.arguments[1]
                if isinstance(target, E.FunctionApplicationExpression) and value is Ellipsis:
                    # `f(...) = ...`: assign a null value of the feature's return type.
                    # TODO(Jiayuan Mao @ 2024/07/17): implement this for the new CrowFeatureAssignExpression.
                    self.assign_expressions.append(CrowFeatureAssignmentExpression(
                        self._g(target), E.NullExpression(target.return_type), **annotations
                    ))
                elif isinstance(target, (E.ListFunctionApplicationExpression, E.FunctionApplicationExpression)) and isinstance(value, (E.ValueOutputExpression, E.VariableExpression)):
                    self.assign_expressions.append(CrowFeatureAssignmentExpression(self._g(target), self._g(value), **annotations))
                elif isinstance(target, E.VariableExpression) and isinstance(value, (E.ValueOutputExpression, E.VariableExpression, E.FindAllExpression, E.FindOneExpression)):
                    if self.use_runtime_assign and 'symbol' not in annotations:
                        # Runtime assign / assignment to feature variables: later references keep the variable itself.
                        if target.name not in self.local_variables:
                            self.local_variables[target.name] = E.VariableExpression(target.variable)
                        self.behavior_expressions.append(CrowRuntimeAssignmentExpression(target.variable, self._g(value)))
                    else:
                        # Symbolic assignment: later references to the variable are substituted by the value.
                        self.local_variables[target.name] = self._g(value)
                else:
                    raise ValueError(f'Invalid assignment: {item}. Types: {type(target)}, {type(value)}.')
            elif item.name == 'check':
                assert isinstance(item.args.arguments[0], (E.ValueOutputExpression, E.VariableExpression)), f'Invalid check expression: {item.args.arguments[0]}'
                self.check_expressions.append(self._g(item.args.arguments[0]))
            elif item.name == 'expr':
                expr = item.args.arguments[0]
                if isinstance(expr, (CrowControllerApplicationExpression, CrowBehaviorApplicationExpression, E.ListExpansionExpression)):
                    self.behavior_expressions.append(self._g(expr))
                else:
                    assert isinstance(
                        expr, (E.ValueOutputExpression, E.VariableExpression, CrowGeneratorApplicationExpression)
                    ), f'Invalid expr expression: {expr}'
                    self.expr_expressions.append(self._g(expr))
            elif item.name == 'bind':
                arguments = item.args.arguments[0].items
                body = item.args.arguments[1]
                self.behavior_expressions.append(CrowBindExpression(arguments, body))
            elif item.name == 'achieve':
                self.behavior_expressions.append(CrowAchieveExpression(item.args.arguments[0], **annotations))
            elif item.name == 'pachieve':
                self.behavior_expressions.append(CrowAchieveExpression(item.args.arguments[0], is_policy_achieve=True, **annotations))
            elif item.name == 'untrack':
                self.behavior_expressions.append(CrowUntrackExpression(item.args.arguments[0]))
            elif item.name == 'assert':
                self.behavior_expressions.append(CrowAssertExpression(item.args.arguments[0], **annotations))
            elif item.name == 'commit':
                self.behavior_expressions.append(CrowBehaviorCommit(**annotations))
            elif item.name == 'return':
                assert isinstance(item.args.arguments[0], (E.ValueOutputExpression, E.VariableExpression)), f'Invalid return expression: {item.args.arguments[0]}'
                self.return_expression = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, self._g(item.args.arguments[0]))
                break
            elif item.name == 'ordering':
                ordering_type, suite = item.args.arguments[0], item.args.arguments[1]
                tracker = self._run_nested(suite)
                self.local_variables = tracker.local_variables
                # NOTE(review): only behavior statements of the ordering body are kept; its assign/check/expr
                # statements are not propagated -- confirm this is intended.
                behavior_expressions = tracker.behavior_expressions
                if ordering_type == 'promotable unordered':
                    prog = CrowBehaviorOrderingSuite('promotable', (CrowBehaviorOrderingSuite('unordered', behavior_expressions),))
                elif ordering_type == 'promotable sequential':
                    prog = CrowBehaviorOrderingSuite('promotable', behavior_expressions)
                elif ordering_type == 'critical unordered':
                    prog = CrowBehaviorOrderingSuite('critical', (CrowBehaviorOrderingSuite('unordered', behavior_expressions),))
                elif ordering_type == 'critical sequential':
                    prog = CrowBehaviorOrderingSuite('critical', behavior_expressions)
                else:
                    assert ' ' not in ordering_type, f'Invalid ordering type: {ordering_type}'
                    prog = CrowBehaviorOrderingSuite(ordering_type, behavior_expressions)
                self.behavior_expressions.append(prog)
            elif item.name == 'if':
                condition = self._g(item.args.arguments[0])
                assert isinstance(condition, E.ValueOutputExpression), f'Invalid condition: {condition}. Type: {type(condition)}.'
                t_tracker = self._run_nested(item.args.arguments[1])
                f_tracker = self._run_nested(item.args.arguments[2])

                # Local variables: both branches must bind the same names; differing values become conditionals.
                assert set(t_tracker.local_variables.keys()) == set(f_tracker.local_variables.keys()), f'Local variables in the true and false branches are not consistent: {t_tracker.local_variables.keys()} vs {f_tracker.local_variables.keys()}'
                new_local_variables = t_tracker.local_variables
                for k, v in t_tracker.local_variables.items():
                    if f_tracker.local_variables[k] != v:
                        new_local_variables[k] = E.ConditionExpression(condition, v, f_tracker.local_variables[k])
                self.local_variables = new_local_variables

                if len(t_tracker.assign_expressions) > 0 or len(f_tracker.assign_expressions) > 0:
                    self.assign_expressions.append(CrowBehaviorConditionSuite(condition, t_tracker.assign_expressions, f_tracker.assign_expressions if len(f_tracker.assign_expressions) > 0 else None))

                for expr in t_tracker.check_expressions:
                    self.check_expressions.append(_make_conditional_implies(condition, expr))
                if len(f_tracker.check_expressions) > 0:
                    # TODO(Jiayuan Mao @ 2024/07/17): implement false-branch check statements.
                    raise RuntimeError(f'Check statements in the false branch are not supported: {f_tracker.check_expressions[0]}')

                if len(t_tracker.expr_expressions) != len(f_tracker.expr_expressions):
                    raise ValueError(f'Number of bare expressions in the true and false branches are not consistent: {len(t_tracker.expr_expressions)} vs {len(f_tracker.expr_expressions)}')
                if len(t_tracker.expr_expressions) == 1:
                    self.expr_expressions.append(E.ConditionExpression(condition, t_tracker.expr_expressions[0], f_tracker.expr_expressions[0]))
                elif len(t_tracker.expr_expressions) > 1:
                    raise ValueError(f'Multiple bare expressions in the true and false branches are not supported: {t_tracker.expr_expressions} vs {f_tracker.expr_expressions}')

                # Keep the conditional behavior whenever *either* branch has behavior statements, so an else-only
                # body (e.g. `if c: pass else: achieve ...`) is not dropped.
                if len(t_tracker.behavior_expressions) > 0 or len(f_tracker.behavior_expressions) > 0:
                    self.behavior_expressions.append(CrowBehaviorConditionSuite(condition, t_tracker.behavior_expressions, f_tracker.behavior_expressions))

                if t_tracker.return_expression is not None and f_tracker.return_expression is not None:
                    # Both branches have return statements.
                    statement = E.ConditionExpression(condition, t_tracker.return_expression, f_tracker.return_expression)
                    self.return_expression = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, statement)
                    break
                elif t_tracker.return_expression is not None:
                    current_return_statement = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, t_tracker.return_expression)
                    # Execution continues past this `if` only if no earlier return fired AND the condition is false.
                    current_return_statement_condition_neg = self._conjoin(current_return_statement_condition_neg, E.NotExpression(condition))
                elif f_tracker.return_expression is not None:
                    current_return_statement = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, f_tracker.return_expression)
                    # Execution continues past this `if` only if no earlier return fired AND the condition is true.
                    current_return_statement_condition_neg = self._conjoin(current_return_statement_condition_neg, condition)
            elif item.name == 'while':
                condition = self._g(item.args.arguments[0])
                tracker = self._run_nested(item.args.arguments[1])
                self._check_loop_local_variables(tracker, 'while')
                if len(tracker.assign_expressions) > 0:
                    raise ValueError(f'Assign statements are not allowed in a while statement: {tracker.assign_expressions[0]}')
                if len(tracker.check_expressions) > 0:
                    raise ValueError(f'Check statements are not allowed in a while statement: {tracker.check_expressions[0]}')
                if len(tracker.expr_expressions) > 0:
                    raise ValueError(f'Expr statements are not allowed in a while statement: {tracker.expr_expressions}')
                if len(tracker.behavior_expressions) > 0:
                    self.behavior_expressions.append(CrowBehaviorWhileLoopSuite(condition, tracker.behavior_expressions, **annotations))
                if tracker.return_expression is not None:
                    raise ValueError(f'Return statement is not allowed in a while statement: {tracker.return_expression}')
            elif item.name == 'foreach':
                loop_variables = item.args.arguments[0].items
                tracker = self._run_nested(item.args.arguments[1])
                self._check_loop_local_variables(tracker, 'foreach')

                statements = tracker.assign_expressions
                if len(statements) > 0:
                    for var in loop_variables:
                        statements = [CrowBehaviorForeachLoopSuite(var, statements)]
                    self.assign_expressions.extend(statements)

                for expr in tracker.check_expressions:
                    for var in loop_variables:
                        expr = E.ForallExpression(var, expr)
                    self.check_expressions.append(expr)

                if len(tracker.expr_expressions) > 0:
                    if len(tracker.expr_expressions) == 1:
                        merged = tracker.expr_expressions[0]
                    else:
                        merged = E.AndExpression(*tracker.expr_expressions)
                    for var in loop_variables:
                        merged = E.ForallExpression(var, merged)
                    self.expr_expressions.append(merged)

                if len(tracker.behavior_expressions) > 0:
                    assert len(loop_variables) == 1, f'Invalid number of variables in the foreach statement: {loop_variables}. Currently only one variable is supported.'
                    self.behavior_expressions.append(CrowBehaviorForeachLoopSuite(loop_variables[0], tracker.behavior_expressions))
                if tracker.return_expression is not None:
                    raise ValueError(f'Return statement is not allowed in a foreach statement: {tracker.return_expression}')
            elif item.name == 'foreach_in':
                variables = item.args.arguments[0].items
                values = item.args.arguments[1].items
                suite = item.args.arguments[2]
                if len(variables) != 1 or len(values) != 1:
                    raise NotImplementedError(f'Currently only one variable and one value are supported in a foreach_in statement: {variables} vs {values}')
                tracker = self._run_nested(suite)
                self._check_loop_local_variables(tracker, 'foreach_in')

                statements = tracker.assign_expressions
                if len(statements) > 0:
                    # Loop over the *variables* (with the values as the iterated expressions), mirroring the
                    # behavior-statement handling below.
                    for var, value in reversed(list(zip(variables, values))):
                        statements = [CrowBehaviorForeachLoopSuite(var, statements, loop_in_expression=value)]
                    self.assign_expressions.extend(statements)
                if len(tracker.check_expressions) > 0:
                    raise NotImplementedError(f'Check statements are not allowed in a foreach_in statement: {tracker.check_expressions}')
                if len(tracker.expr_expressions) > 0:
                    raise NotImplementedError(f'Expr statements are not allowed in a foreach_in statement: {tracker.expr_expressions}')
                if len(tracker.behavior_expressions) > 0:
                    # TODO(Jiayuan Mao @ 2024/03/12): implement the rest parts of action statements.
                    expressions = tracker.behavior_expressions
                    for var, value in reversed(list(zip(variables, values))):
                        expressions = [CrowBehaviorForeachLoopSuite(var, expressions, loop_in_expression=value)]
                    self.behavior_expressions.extend(expressions)
                if tracker.return_expression is not None:
                    raise ValueError(f'Return statement is not allowed in a foreach_in statement: {tracker.return_expression}')
            elif item.name == 'pass':
                pass

        return self
def _make_conditional_implies(condition: E.ValueOutputExpression, test: E.ValueOutputExpression):
    """Build the implication ``condition => test``.

    When ``test`` is itself an implication ``premise => conclusion``, the condition is merged into its premise,
    giving ``(condition and premise) => conclusion``. A conjunctive premise is spliced into the new conjunction
    instead of being nested.
    """
    is_implication = isinstance(test, E.BoolExpression) and test.op == E.BoolOpType.IMPLIES
    if not is_implication:
        return E.ImpliesExpression(condition, test)

    premise, conclusion = test.arguments[0], test.arguments[1]
    if isinstance(premise, E.BoolExpression) and premise.op == E.BoolOpType.AND:
        conjuncts = premise.arguments
    else:
        conjuncts = (premise, )
    return E.ImpliesExpression(E.AndExpression(condition, *conjuncts), conclusion)


def _make_conditional_return(current_stmt: Optional[E.ValueOutputExpression], current_condition_neg: Optional[E.ValueOutputExpression], new_stmt: E.ValueOutputExpression):
    """Merge a newly encountered return value into the accumulated return expression.

    Returns ``new_stmt`` when nothing has been accumulated yet. Otherwise it returns a
    ``ConditionExpression(current_condition_neg, new_stmt, current_stmt)``.
    """
    return new_stmt if current_stmt is None else E.ConditionExpression(current_condition_neg, new_stmt, current_stmt)


def _make_conditional_assign(assign_stmt: E.VariableAssignmentExpression, condition: E.ValueOutputExpression):
    """Guard an assignment statement with an additional condition.

    - Plain assignments become conditional assignments guarded by ``condition``.
    - Conditional assignments get ``condition`` conjoined with their existing condition. An existing AND is
      flattened into the new conjunction.
    - Deictic assignments are rebuilt with their inner assignment conditioned recursively.

    Raises:
        ValueError: if ``assign_stmt`` is not one of the supported assignment kinds.
    """
    # NB: the isinstance order is significant in case the expression classes share a hierarchy.
    if isinstance(assign_stmt, E.AssignExpression):
        return E.ConditionalAssignExpression(assign_stmt.predicate, assign_stmt.value, condition)

    if isinstance(assign_stmt, E.ConditionalAssignExpression):
        old_condition = assign_stmt.condition
        if isinstance(old_condition, E.BoolExpression) and old_condition.op == E.BoolOpType.AND:
            merged_condition = E.AndExpression(condition, *old_condition.arguments)
        else:
            merged_condition = E.AndExpression(condition, old_condition)
        return E.ConditionalAssignExpression(assign_stmt.predicate, assign_stmt.value, merged_condition)

    if isinstance(assign_stmt, E.DeicticAssignExpression):
        inner_assign = _make_conditional_assign(assign_stmt.expression, condition)
        return E.DeicticAssignExpression(assign_stmt.variable, inner_assign)

    raise ValueError(f'Invalid assign statement: {assign_stmt}')


# Maps a binary arithmetic operator token to the suffix of the per-type function name
# (``type::<typename>::<suffix>``) that implements it.
g_term_op_mapping = {
    '*': 'mul',
    '/': 'div',
    '+': 'add',
    '-': 'sub',
}
def gen_term_expr(expr_typename: str):
    """Generate a term expression function.

    This function is used to generate the term expression functions for the transformer. It is used to generate
    the following functions:

    - mul_expr
    - arith_expr
    - shift_expr

    The generated function folds ``v0 op1 v1 op2 v2 ...`` from left to right. Each binary step becomes an
    application of the per-type function ``type::<typename>::<opname>``. For variable-sized sequence types, the
    element type is used.

    Args:
        expr_typename: the name of the expression type. This is only used for debug / error messages.

    Returns:
        the generated term expression function.

    Raises (from the generated function):
        NotImplementedError: if an operator has no known function name (e.g. ``<<``).
    """

    @inline_args
    def term(self, *values: Any):
        values = [self.visit(value) for value in values]
        if len(values) == 1:
            return values[0]

        assert len(values) % 2 == 1, f'[{expr_typename}] expressions expected an odd number of values, got {len(values)}. Values: {values}.'

        # `%` is not part of the shared mapping. Use the same `mod` function name as the `%=` assignment
        # operator, so both spellings resolve identically. Entries in `g_term_op_mapping` take precedence.
        op_mapping = {'%': 'mod', **g_term_op_mapping}

        result = values[0]
        for i in range(1, len(values), 2):
            op = values[i]
            if op not in op_mapping:
                # Previously a bare KeyError; report unsupported operators the same way unary operators do.
                raise NotImplementedError(f'[{expr_typename}] Operator {op} is not supported in the current version.')
            v1, v2 = _canonize_arguments_same_dtype([result, values[i + 1]])
            t = v1.return_type
            if t.is_variable_sized_sequence_type:
                t = t.element_type
            fname = f'type::{t.typename}::{op_mapping[op]}'
            result = E.FunctionApplicationExpression(CrowFunction(fname, FunctionType([t, t], t)), [v1, v2])
        return result

    return term
def gen_term_expr_noop(expr_typename: str):
    """Generate a boolean term expression function whose operator is implied by the rule name.

    It is named `_noop` because the operator token is not part of the rule's children. The operator is therefore
    selected from ``expr_typename`` instead. Used for:

    - bitand_expr (AND)
    - bitxor_expr (XOR)
    - bitor_expr (OR)

    Args:
        expr_typename: one of ``'bitand'``, ``'bitxor'``, ``'bitor'``. It is looked up when the generated function
            combines two or more operands.

    Returns:
        the generated term expression function. A single operand is returned unchanged; otherwise a
        :class:`E.BoolExpression` over the canonized operands is returned.
    """
    op_mapping = dict(
        bitand=E.BoolOpType.AND,
        bitxor=E.BoolOpType.XOR,
        bitor=E.BoolOpType.OR,
    )

    @inline_args
    def term(self, *values: Any):
        operands = [self.visit(v) for v in values]
        if len(operands) != 1:
            return E.BoolExpression(op_mapping[expr_typename], _canonize_arguments(operands))
        return operands[0]

    return term
def _canonize_single_argument(arg: Any, dtype: Optional[TypeBase] = None) -> Union[E.ObjectOrValueOutputExpression, E.VariableExpression, type(Ellipsis)]:
    """Convert one parsed argument into an expression.

    - Expressions and controller / behavior / generator applications are returned unchanged.
    - :class:`Variable` objects are wrapped into :class:`E.VariableExpression`.
    - Python literals (bool, int, float, complex, str) become constant expressions of type ``dtype``.
    - ``QINDEX`` becomes an object constant holding a ``StateObjectList``.
    - ``Ellipsis`` is passed through.

    Raises:
        ValueError: for any other kind of argument.
    """
    if isinstance(arg, E.ObjectOrValueOutputExpression):
        return arg
    if isinstance(arg, (CrowControllerApplicationExpression, CrowBehaviorApplicationExpression, CrowGeneratorApplicationExpression)):
        return arg
    if isinstance(arg, Variable):
        return E.VariableExpression(arg)
    if isinstance(arg, (bool, int, float, complex, str)):
        return E.ConstantExpression.from_value(arg, dtype=dtype)
    if arg is QINDEX:
        return E.ObjectConstantExpression(ObjectConstant(StateObjectList(ListType(AutoType), QINDEX), ListType(AutoType)))
    if arg is Ellipsis:
        return Ellipsis
    raise ValueError(f'Invalid argument: {arg}. Type: {type(arg)}.')


def _canonize_arguments_same_dtype(args: Optional[Union[ArgumentsList, tuple, list]] = None, dtype: Optional[TypeBase] = None) -> Tuple[Union[E.ValueOutputExpression, E.VariableExpression], ...]:
    """Canonize arguments that should share one dtype.

    If ``dtype`` is not given, it is taken from the first argument that is already an
    :class:`E.ObjectOrValueOutputExpression`. It is then used for every literal argument.
    """
    if args is None:
        return tuple()
    args = args.arguments if isinstance(args, ArgumentsList) else args
    if dtype is None:
        # Guess the dtype from the list.
        for arg in args:
            if isinstance(arg, E.ObjectOrValueOutputExpression):
                dtype = arg.return_type
                break
    return tuple(_canonize_single_argument(arg, dtype=dtype) for arg in args)


def _canonize_arguments(args: Optional[Union[ArgumentsList, tuple, list]] = None, dtypes: Optional[Sequence[TypeBase]] = None) -> Tuple[Union[E.ValueOutputExpression, E.VariableExpression], ...]:
    """Canonize a list of arguments, optionally checking it against a list of expected types.

    Args:
        args: the arguments (an :class:`ArgumentsList`, tuple, or list). ``None`` yields an empty tuple.
        dtypes: the expected argument types. When given, ``dtypes[i]`` is used to build constants for the
            ``i``-th argument. A trailing ``...`` stands for "all remaining arguments": it is kept as ``Ellipsis``
            for the caller to expand, and only the explicitly written arguments must fit the signature.

    Raises:
        ValueError: if the number of arguments does not fit ``dtypes``.
    """
    if args is None:
        return tuple()

    # TODO(Jiayuan Mao @ 2024/03/2): Strictly check the allowability of "Ellipsis" in the arguments.
    arguments = args.arguments if isinstance(args, ArgumentsList) else args
    if dtypes is not None:
        has_trailing_ellipsis = len(arguments) > 0 and arguments[-1] is Ellipsis
        if has_trailing_ellipsis:
            # `f(a, ...)`: the explicit prefix must not be longer than the signature. Previously the check
            # required an exact match, which only allowed `...` to stand for exactly one argument.
            mismatch = len(arguments) - 1 > len(dtypes)
        else:
            mismatch = len(arguments) != len(dtypes)
        if mismatch:
            raise ValueError(f'Number of arguments does not match the number of types: {len(arguments)} vs {len(dtypes)}. Args: {arguments}, Types: {dtypes}')

    canonized_args = list()
    for i, arg in enumerate(arguments):
        if arg is Ellipsis:
            canonized_args.append(Ellipsis)
        else:
            canonized_args.append(_canonize_single_argument(arg, dtype=dtypes[i] if dtypes is not None else None))
    return tuple(canonized_args)


def _safe_is_value_type(arg: Any) -> bool:
    """Return whether ``arg`` is value-typed; Python literals count as value-typed.

    Raises:
        ValueError: if ``arg`` is neither an expression nor a Python literal.
    """
    if isinstance(arg, E.ObjectOrValueOutputExpression):
        return arg.return_type.is_value_type
    if isinstance(arg, (bool, int, float, complex, str)):
        return True
    raise ValueError(f'Invalid argument: {arg}. Type: {type(arg)}.')


def _has_list_arguments(args: Tuple[E.ObjectOrValueOutputExpression, ...]) -> bool:
    """Return whether any argument has a list return type."""
    return any(arg.return_type.is_list_type for arg in args)
[docs] class CDLExpressionInterpreter(Interpreter): """The transformer for expressions. Including: - typename - sized_vector_typename - unsized_vector_typename - typed_argument - is_typed_argument - in_typed_argument - arguments_def - atom_expr_funccall - atom_varname - atom - power - factor - unary_op_expr - mul_expr - arith_expr - shift_expr - bitand_expr - bitxor_expr - bitor_expr - comparison_expr - not_test - and_test - or_test - cond_test - test - test_nocond - tuple - list - cs_list - suite - expr_stmt - expr_list_expansion_stmt - assign_stmt - annotated_assign_stmt - local_assign_stmt - pass_stmt - check_stmt - return_stmt - achieve_stmt - if_stmt - foreach_stmt - foreach_in_stmt - while_stmt - forall_test - exists_test - findall_test - forall_in_test - exists_in_test """
[docs] def __init__(self, domain: Optional[CrowDomain], state: Optional[CrowState], expression_def_ctx: E.ExpressionDefinitionContext, auto_constant_guess: bool = False): super().__init__() self.domain = domain self.state = state self.expression_def_ctx = expression_def_ctx self.auto_constant_guess = auto_constant_guess self.generator_impl_outputs = None self.local_variables = dict()
[docs] def set_domain(self, domain: CrowDomain): self.domain = domain
[docs] @contextlib.contextmanager def local_variable_guard(self): backup = self.local_variables.copy() yield self.local_variables = backup
[docs] @contextlib.contextmanager def set_generator_impl_outputs(self, outputs: List[Variable]): backup = self.generator_impl_outputs self.generator_impl_outputs = outputs yield self.generator_impl_outputs = backup
domain: CrowDomain expression_def_ctx: E.ExpressionDefinitionContext
[docs] def visit(self, tree: Any) -> Any: if isinstance(tree, Tree): return super().visit(tree) return tree
[docs] @inline_args def atom_colon(self): return QINDEX
[docs] @inline_args def atom_varname(self, name: str) -> Union[E.VariableExpression, E.ObjectConstantExpression, E.ValueOutputExpression]: """Captures variable names such as `var_name`.""" if name in self.local_variables: return self.local_variables[name] if self.state is not None and name in self.state.object_name2defaultindex: return E.ObjectConstantExpression(ObjectConstant(name, self.domain.types[self.state.get_default_typename(name)])) if name in self.domain.constants: constant = self.domain.constants[name] if isinstance(constant, ObjectConstant): return E.ObjectConstantExpression(constant) return E.ConstantExpression(constant) if name in self.domain.features and self.domain.features[name].nr_arguments == 0: return E.FunctionApplicationExpression(self.domain.features[name], tuple()) # TODO(Jiayuan Mao @ 2024/03/12): smartly guess the type of the variable. if not self.expression_def_ctx.has_variable(name): if self.auto_constant_guess: return E.ObjectConstantExpression(ObjectConstant(name, AutoType)) variable = self.expression_def_ctx.wrap_variable(name) return variable
[docs] @inline_args def atom_expr_do_funccall(self, name: str, annotations: dict, args: Tree) -> CrowBehaviorBodyItem: if self.domain.has_controller(name): controller = self.domain.get_controller(name) args: Optional[ArgumentsList] = self.visit(args) args_c = _canonize_arguments(args, controller.argument_types) return CrowControllerApplicationExpression(controller, args_c) else: raise KeyError(f'Controller {name} not found.')
[docs] @inline_args def atom_expr_funccall(self, name: str, annotations: dict, args: Tree) -> Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression, CrowBehaviorBodyItem]: """Captures function calls, such as `func_name(arg1, arg2, ...)`.""" annotations: Optional[dict] = self.visit(annotations) args: Optional[ArgumentsList] = self.visit(args) if annotations is None: annotations = dict() if args is None: args = ArgumentsList(tuple()) if self.domain.has_feature(name): function = self.domain.get_feature(name) args_c = _canonize_arguments(args, function.ftype.argument_types) return E.FunctionApplicationExpression(function, args_c, **annotations) elif self.domain.has_function(name): function = self.domain.get_function(name) args_c = _canonize_arguments(args, function.ftype.argument_types) return E.FunctionApplicationExpression(function, args_c, **annotations) elif self.domain.has_controller(name): controller = self.domain.get_controller(name) args_c = _canonize_arguments(args, controller.argument_types) return CrowControllerApplicationExpression(controller, args_c) elif self.domain.has_behavior(name): behavior = self.domain.get_behavior(name) args_c = _canonize_arguments(args, behavior.argument_types) if len(args_c) > 0 and args_c[-1] is Ellipsis: args_c = args_c[:-1] + tuple([UnnamedPlaceholder(t) for t in behavior.argument_types[len(args_c) - 1:]]) return CrowBehaviorApplicationExpression(behavior, args_c) elif self.domain.has_generator(name): generator = self.domain.get_generator(name) args_c = _canonize_arguments(args, generator.argument_types) return CrowGeneratorApplicationExpression(generator, args_c, list()) else: if 'inplace_behavior_body' in annotations and annotations['inplace_behavior_body']: """Inplace definition of an function, used for define a __totally_ordered_plan__ function.""" args_c = _canonize_arguments(args) argument_types = [arg.return_type for arg in args_c] # logger.warning(f'Behavior {name} not found, creating a new one 
with argument types {argument_types}.') predicate = self.domain.define_crow_function(name, argument_types, self.domain.get_type('__behavior_body__')) return E.FunctionApplicationExpression(predicate, args_c) elif 'inplace_generator' in annotations and annotations['inplace_generator']: """Inplace definition of an generator function. Typically this function is used together with the generator_placeholder annotation.""" args_c = _canonize_arguments(args) argument_types = [arg.return_type for arg in args_c] # logger.warning(f'Generator placeholder function {name} not found, creating a new one with argument types {argument_types}.') predicate = self.domain.define_crow_function( name, argument_types, BOOL, generator_placeholder=annotations.get('generator_placeholder', True) ) assert 'inplace_generator_targets' in annotations, f'Inplace generator {name} requires inplace generator targets to be set.' inplace_generator_targets = annotations['inplace_generator_targets'] generator_name = 'gen_' + name generator_arguments = predicate.arguments generator_goal = E.FunctionApplicationExpression(predicate, [E.VariableExpression(arg) for arg in generator_arguments]) output_argument_names = [x.value for x in inplace_generator_targets.items] output_indices = list() for i, arg in enumerate(args_c): if isinstance(arg, E.VariableExpression) and arg.name in output_argument_names: output_argument_names.remove(arg.name) output_indices.append(i) assert len(output_argument_names) == 0, f'Mismatched output arguments for inplace generator {name}: {output_argument_names}' inputs = [arg for i, arg in enumerate(generator_arguments) if i not in output_indices] outputs = [arg for i, arg in enumerate(generator_arguments) if i in output_indices] self.domain.define_generator(generator_name, generator_arguments, generator_goal, inputs, outputs) return E.FunctionApplicationExpression(predicate, args_c) else: raise KeyError(f'Function {name} not found. 
Note that recursive function calls are not supported in the current version.')
[docs] @inline_args def atom_subscript(self, name: str, annotations: dict, index: Tree) -> Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression]: """Captures subscript expressions such as `name[index1, index2, ...]`.""" feature = self.domain.get_feature(name) index: CSList = self.visit(index) annotations: Optional[dict] = self.visit(annotations) if not feature.is_state_variable: raise ValueError(f'Invalid subscript expression: {name} is not a state variable. Expression: {name}[{index.items}]') if annotations is None: annotations = dict() items = index.items if len(items) == 1 and items[0] is Ellipsis: return E.FunctionApplicationExpression(feature, tuple()) arguments = _canonize_arguments(index.items, dtypes=feature.ftype.argument_types) return E.FunctionApplicationExpression(feature, arguments, **annotations)
[docs] @inline_args def atom(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]: """Captures the atom. This is used in the base case of the expression, including literal constants, variables, and subscript expressions.""" return value
[docs] def arguments(self, args: Tree) -> ArgumentsList: """Captures the argument list. This is used in function calls.""" args = self.visit_children(args) return ArgumentsList(tuple(args))
[docs] @inline_args def power(self, base: Union[FunctionCall, Variable], exp: Optional[Union[FunctionCall, Variable]] = None) -> Union[FunctionCall, Variable]: """The highest-priority expression. This is used to capture the power expression, such as `base ** exp`. If `exp` is None, it is treated as `base ** 1`.""" if exp is None: return base raise NotImplementedError('Power expression is not supported in the current version.')
[docs] @inline_args def factor(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]: return value
[docs] @inline_args def unary_op_expr(self, op: str, value: Union[FunctionCall, Variable]): value = self.visit(value) if op == '+': return value if op == '-': t = value.return_type if t.is_variable_sized_sequence_type: t = t.element_type fname = f'type::{t.typename}::neg' return E.FunctionApplicationExpression(CrowFunction(fname, FunctionType([t], t)), [value]) raise NotImplementedError(f'Unary operator {op} is not supported in the current version.')
mul_expr = gen_term_expr('mul') arith_expr = gen_term_expr('add') shift_expr = gen_term_expr('shift') bitand_expr = gen_term_expr_noop('bitand') bitxor_expr = gen_term_expr_noop('bitxor') bitor_expr = gen_term_expr_noop('bitor')
[docs] @inline_args def comparison_expr(self, *values: Union[E.ValueOutputExpression, E.VariableExpression]) -> E.ValueOutputExpression: if len(values) == 1: return self.visit(values[0]) assert len(values) % 2 == 1, f'[compare] expressions expected an odd number of values, got {len(values)}. Values: {values}.' values = [self.visit(value) for value in values] results = list() for i in range(1, len(values), 2): if not _safe_is_value_type(values[i - 1]) and not _safe_is_value_type(values[i + 1]): results.append(E.ObjectCompareExpression(E.CompareOpType.from_string(values[i][0].value), values[i - 1], values[i + 1])) elif _safe_is_value_type(values[i - 1]) and _safe_is_value_type(values[i + 1]): v1, v2 = _canonize_arguments_same_dtype([values[i - 1], values[i + 1]]) results.append(E.ValueCompareExpression(E.CompareOpType.from_string(values[i][0].value), v1, v2)) else: raise ValueError(f'Invalid comparison: {values[i - 1]} vs {values[i + 1]}') if len(results) == 1: return results[0] result = E.AndExpression(*results) return result
[docs] @inline_args def not_test(self, value: Any) -> E.NotExpression: return E.NotExpression(*_canonize_arguments([self.visit(value)]))
[docs] @inline_args def and_test(self, *values: Any) -> E.AndExpression: values = [self.visit(value) for value in values] if len(values) == 1: return values[0] result = E.AndExpression(*_canonize_arguments(values)) return result
[docs] @inline_args def or_test(self, *values: Any) -> E.OrExpression: values = [self.visit(value) for value in values] if len(values) == 1: return values[0] result = E.OrExpression(*_canonize_arguments(values)) return result
[docs] @inline_args def cond_test(self, value1: Any, cond: Any, value2: Any) -> E.ConditionExpression: return E.ConditionExpression(*_canonize_arguments([self.visit(cond), self.visit(value1), self.visit(value2)]))
[docs] @inline_args def test(self, value: Any): return self.visit(value)
[docs] @inline_args def test_nocond(self, value: Any): return self.visit(value)
[docs] @inline_args def tuple(self, *values: Any): return tuple(self.visit(v) for v in values)
[docs] @inline_args def list(self, *values: Any): return E.ListCreationExpression(_canonize_arguments([self.visit(v) for v in values]))
[docs] @inline_args def cs_list(self, *values: Any): return CSList(tuple(self.visit(v) for v in values))
[docs] @inline_args def suite(self, *values: Tree, activate_variable_guard: bool = True) -> Suite: if activate_variable_guard: with self.local_variable_guard(): values = [self.visit(value) for value in values] local_variables = self.local_variables.copy() else: values = [self.visit(value) for value in values] local_variables = self.local_variables.copy() return Suite(tuple(v for v in values if v is not None), local_variables)
[docs] @inline_args def expr_stmt(self, value: Tree): value = self.visit(value) if value is Ellipsis: return None # NB(Jiayuan Mao @ 2024/06/21): for handling string literals as docs. if isinstance(value, str): return None return FunctionCall('expr', ArgumentsList((_canonize_single_argument(value),)))
[docs] @inline_args def expr_list_expansion_stmt(self, value: Any): value = _canonize_single_argument(self.visit(value)) return FunctionCall('expr', ArgumentsList((E.ListExpansionExpression(value), )))
[docs] @inline_args def compound_expr_stmt(self, value: Tree): value = self.visit(value) if value is Ellipsis: return None if isinstance(value, str): return None return FunctionCall('expr', ArgumentsList((_canonize_single_argument(value),)))
def _make_additive_assign_stmt(self, lv, rv, op: str, annotations: dict): if op == '=': return FunctionCall('assign', ArgumentsList((lv, rv)), annotations) if op in ('+=', '-=', '*=', '/='): t = lv.return_type if t.is_variable_sized_sequence_type: t = t.element_type fname = f'type::{t.typename}::{g_term_op_mapping[op[0]]}' result = E.FunctionApplicationExpression(CrowFunction(fname, FunctionType([t, t], t)), [lv, rv]) return FunctionCall('assign', ArgumentsList((lv, result)), annotations) if op == '%=': t = lv.return_type if t.is_variable_sized_sequence_type: t = t.element_type fname = f'type::{t.typename}::mod' result = E.FunctionApplicationExpression(CrowFunction(fname, FunctionType([t, t], t)), [lv, rv]) return FunctionCall('assign', ArgumentsList((lv, result)), annotations) if op in ('&=', '|=', '^='): mapping = { '&': E.BoolOpType.AND, '|': E.BoolOpType.OR, '^': E.BoolOpType.XOR, } result = E.BoolExpression(mapping[op[0]], (lv, rv)) return FunctionCall('assign', ArgumentsList((lv, result)), annotations) raise ValueError(f'Invalid assignment operator: {op}')
[docs] def assign_stmt_inner(self, op: str, target: Any, value: Any, annotations: dict): if target.data == 'atom_varname': target_lv = target.children[0] else: target_lv = _canonize_single_argument(self.visit(target)) # left value if isinstance(target_lv, str): if target_lv in self.local_variables: annotations.setdefault('local', True) target_lv = self.local_variables[target_lv] target_rv = _canonize_single_argument(self.visit(value)) # return FunctionCall('assign', ArgumentsList((target_lv, target_rv)), annotations) return self._make_additive_assign_stmt(target_lv, target_rv, op, annotations) else: if target_lv in self.domain.features and self.domain.get_feature(target_lv).nr_arguments == 0: target_lv = E.FunctionApplicationExpression(self.domain.get_feature(target_lv), tuple()) # return FunctionCall('assign', ArgumentsList((target_lv, _canonize_single_argument(self.visit(value)))), annotations) return self._make_additive_assign_stmt(target_lv, _canonize_single_argument(self.visit(value)), op, annotations) else: raise NameError(f'Invalid assignment target: it is not a local variable and not a feature with 0 arguments: {target_lv}') # return FunctionCall('assign', ArgumentsList((target_lv, _canonize_single_argument(self.visit(value)))), annotations) return self._make_additive_assign_stmt(target_lv, _canonize_single_argument(self.visit(value)), op, annotations)
[docs] @inline_args def assign_stmt(self, target: Any, op: Any, value: Any): return self.assign_stmt_inner(op.value, target, value, dict())
[docs] @inline_args def annotated_assign_stmt(self, annotations: dict, target: Any, op: Any, value: Any): return self.assign_stmt_inner(op.value, target, value, annotations)
[docs] @inline_args def let_assign_stmt(self, target: str, value: Any = None): assert isinstance(target, str), f'Invalid local variable name: {target}' value = _canonize_single_argument(self.visit(value)) self.local_variables[target] = E.VariableExpression(Variable(target, value.return_type if value is not None else AutoType)) if value is not None: return FunctionCall('assign', ArgumentsList((self.local_variables[target], value)), {'local': True}) return None
[docs] @inline_args def symbol_assign_stmt(self, target: str, value: Any): assert isinstance(target, str), f'Invalid symbol variable name: {target}' value = _canonize_single_argument(self.visit(value)) if target in self.local_variables: raise RuntimeError(f'Local symbol variable {target} has been assigned before.') self.local_variables[target] = E.VariableExpression(Variable(target, value.return_type)) if value is not None: return FunctionCall('assign', ArgumentsList((self.local_variables[target], value)), {'symbol': True}) return None
[docs] @inline_args def pass_stmt(self): return FunctionCall('pass', ArgumentsList(tuple()))
[docs] @inline_args def commit_stmt(self, kwargs: dict): return FunctionCall('commit', ArgumentsList(tuple()), kwargs)
[docs] @inline_args def check_stmt(self, value: Any): # Todo: run check statements. return FunctionCall('check', ArgumentsList((_canonize_single_argument(self.visit(value)), )))
[docs] @inline_args def achieve_once_stmt(self, value: CSList): return FunctionCall('achieve', ArgumentsList(_canonize_arguments(self.visit(value).items)), {'once': True})
[docs] @inline_args def achieve_hold_stmt(self, value: CSList): return FunctionCall('achieve', ArgumentsList(_canonize_arguments(self.visit(value).items)), {'once': False})
[docs] @inline_args def pachieve_once_stmt(self, value: CSList): return FunctionCall('pachieve', ArgumentsList(_canonize_arguments(self.visit(value).items)), {'once': True})
[docs] @inline_args def pachieve_hold_stmt(self, value: CSList): return FunctionCall('pachieve', ArgumentsList(_canonize_arguments(self.visit(value).items)), {'once': False})
[docs] @inline_args def untrack_stmt(self, value: CSList): return FunctionCall('untrack', ArgumentsList(_canonize_arguments(self.visit(value).items)))
[docs] @inline_args def assert_once_stmt(self, value: Any): return FunctionCall('assert', ArgumentsList((_canonize_single_argument(self.visit(value)), )), {'once': True})
[docs] @inline_args def assert_hold_stmt(self, value: Any): return FunctionCall('assert', ArgumentsList((_canonize_single_argument(self.visit(value)), )), {'once': False})
[docs] @inline_args def return_stmt(self, value: Any): return FunctionCall('return', ArgumentsList((_canonize_single_argument(self.visit(value)), )))
[docs] @inline_args def compound_assign_stmt(self, target: Variable, value: Any): return FunctionCall('assign', ArgumentsList((self.visit(target), _canonize_single_argument(self.visit(value)))))
[docs] @inline_args def compound_check_stmt(self, value: Any): return FunctionCall('check', ArgumentsList((_canonize_single_argument(self.visit(value)), )))
[docs] @inline_args def compound_achieve_once_stmt(self, value: Any): return FunctionCall('achieve', ArgumentsList(_canonize_arguments([self.visit(value)])), {'once': True})
[docs] @inline_args def compound_achieve_hold_stmt(self, value: Any): return FunctionCall('achieve', ArgumentsList(_canonize_arguments([self.visit(value)])), {'once': False})
[docs] @inline_args def compound_untrack_stmt(self, value: Any): return FunctionCall('untrack', ArgumentsList(_canonize_arguments([self.visit(value)])))
[docs] @inline_args def compound_assert_once_stmt(self, value: Any): return FunctionCall('assert', ArgumentsList((_canonize_single_argument(self.visit(value)), )), {'once': True})
[docs] @inline_args def compound_assert_hold_stmt(self, value: Any): return FunctionCall('assert', ArgumentsList((_canonize_single_argument(self.visit(value)), )), {'once': False})
[docs] @inline_args def compound_return_stmt(self, value: Any): return FunctionCall('return', ArgumentsList((_canonize_single_argument(self.visit(value)), )))
[docs] @inline_args def ordered_suite(self, ordering_op: Any, body: Any): ordering_op = self.visit(ordering_op) assert body.data.value == 'suite', f'Invalid body type: {body}' if ordering_op in ('promotable', 'unordered', 'promotable unordered', 'promotable sequential'): with self.local_variable_guard(): body = self.visit(body) return FunctionCall('ordering', ArgumentsList((ordering_op, body))) else: body = self.visit_children(body) body = Suite(tuple(body), self.local_variables.copy()) return FunctionCall('ordering', ArgumentsList((ordering_op, body)))
[docs] @inline_args def ordering_op(self, *ordering_op: Any): return ' '.join([x.value for x in ordering_op])
[docs] @inline_args def if_stmt(self, cond: Any, suite: Any, else_suite: Optional[Any] = None): cond = _canonize_single_argument(self.visit(cond)) with self.local_variable_guard(): suite = self.visit(suite) if else_suite is None: else_suite = Suite((FunctionCall('pass', ArgumentsList(tuple())), )) else: with self.local_variable_guard(): else_suite = self.visit(else_suite) return FunctionCall('if', ArgumentsList((cond, suite, else_suite)))
[docs] @inline_args def foreach_stmt(self, cs_list: Any, suite: Any): cs_list = self.visit(cs_list) with self.expression_def_ctx.new_variables(*cs_list.items), self.local_variable_guard(): suite = self.visit(suite) return FunctionCall('foreach', ArgumentsList((cs_list, suite)))
[docs] @inline_args def foreach_in_stmt(self, variables_cs_list: Any, values_cs_list: Any, suite: Any) -> object: values_cs_list = self.visit(values_cs_list) variables_cs_list = self.visit(variables_cs_list) if len(variables_cs_list.items) != len(values_cs_list.items): raise ValueError(f'Number of variables does not match the number of values: {len(variables_cs_list.items)} vs {len(values_cs_list.items)}. Variables: {variables_cs_list.items}, Values: {values_cs_list.items}') variable_items = list() for i in range(len(variables_cs_list.items)): # Variables are just names, not typed Variables. So we need to wrap them. return_type = values_cs_list.items[i].return_type if not return_type.is_list_type: raise ValueError(f'Invalid foreach_in statement: {values_cs_list.items[i]} is not a list.') variable_items.append(Variable(variables_cs_list.items[i], return_type.element_type)) variables_cs_list = CSList(tuple(variable_items)) with self.expression_def_ctx.new_variables(*variables_cs_list.items), self.local_variable_guard(): suite = self.visit(suite) return FunctionCall('foreach_in', ArgumentsList((variables_cs_list, values_cs_list, suite, )))
[docs] @inline_args def while_stmt(self, cond: Any, suite: Any): cond = _canonize_single_argument(self.visit(cond)) with self.local_variable_guard(): suite = self.visit(suite) return FunctionCall('while', ArgumentsList((cond, suite)))
def _quantification_expression(self, cs_list: Any, suite: Any, quantification_cls): cs_list = self.visit(cs_list) with self.expression_def_ctx.new_variables(*cs_list.items): body = self.visit(suite) for item in reversed(cs_list.items): body = quantification_cls(item, body) return body
[docs] @inline_args def forall_test(self, cs_list: Any, suite: Any): return self._quantification_expression(cs_list, suite, E.ForallExpression)
[docs] @inline_args def exists_test(self, cs_list: Any, suite: Any): return self._quantification_expression(cs_list, suite, E.ExistsExpression)
[docs] @inline_args def findall_test(self, variable: Variable, suite: Any): with self.expression_def_ctx.new_variables(variable), self.local_variable_guard(): body = self.visit(suite) return E.FindAllExpression(variable, body)
[docs] @inline_args def findone_test(self, variable: Variable, suite: Any): with self.expression_def_ctx.new_variables(variable), self.local_variable_guard(): body = self.visit(suite) return E.FindOneExpression(variable, body)
def _quantification_in_expression(self, cs_list: Any, suite: Any, quantification_cls):
    """Interpret ``<quantifier> <name> in <value>: <suite>`` by binding names to local variables.

    Inside a local-variable guard, each ``InTypedArgument`` in ``cs_list`` binds its name to its visited value.
    The suite is then visited once, and the result is wrapped as ``quantification_cls(body)``.
    """
    in_arguments = self.visit(cs_list).items
    with self.local_variable_guard():
        for in_argument in in_arguments:
            self.local_variables[in_argument.name] = self.visit(in_argument.value)
        body = self.visit(suite)
    return quantification_cls(body)
@inline_args
def forall_in_test(self, cs_list: Any, suite: Any):
    """Interpret ``forall <name> in <value>: <suite>``; the visited body is wrapped in :class:`E.AndExpression`."""
    return self._quantification_in_expression(cs_list=cs_list, suite=suite, quantification_cls=E.AndExpression)
@inline_args
def exists_in_test(self, cs_list: Any, suite: Any):
    """Interpret ``exists <name> in <value>: <suite>``; the visited body is wrapped in :class:`E.OrExpression`."""
    return self._quantification_in_expression(cs_list=cs_list, suite=suite, quantification_cls=E.OrExpression)
@inline_args
def bind_stmt(self, cs_list: Any, suite: Any):
    """Interpret ``bind <vars> where: <suite>`` as ``FunctionCall('bind', ...)``.

    The where-suite is visited with the bound variables declared and inside a local-variable guard. Its combined
    return expression becomes the binding condition; a list of conditions is conjoined with
    :class:`E.AndExpression`. Afterwards every bound variable is registered as a local variable.
    """
    cs_list = self.visit(cs_list)
    bound_variables = cs_list.items
    with self.expression_def_ctx.new_variables(*bound_variables), self.local_variable_guard():
        where_suite = self.visit(suite)
    condition = where_suite.get_combined_return_expression(allow_expr_expressions=True, allow_multiple_expressions=True)
    if isinstance(condition, list):
        condition = E.AndExpression(*condition)
    # Registered after the guard has exited; presumably so the bindings outlive local_variable_guard (confirm).
    for bound in bound_variables:
        self.local_variables[bound.name] = E.VariableExpression(bound)
    return FunctionCall('bind', ArgumentsList((cs_list, condition)))
@inline_args
def bind_stmt_no_where(self, cs_list: Any):
    """Interpret a ``bind`` statement that has no ``where`` clause, for example:

    .. code-block:: python

        bind x: int, y: int

    Each bound variable is registered as a local variable. The binding condition is ``E.NullExpression(BOOL)``.
    """
    cs_list = self.visit(cs_list)
    for bound in cs_list.items:
        self.local_variables[bound.name] = E.VariableExpression(bound)
    no_condition = E.NullExpression(BOOL)
    return FunctionCall('bind', ArgumentsList((cs_list, no_condition)))
@inline_args
def annotated_compound_stmt(self, annotations: dict, stmt: Any):
    """Visit ``stmt`` and attach ``annotations`` to the result.

    Existing annotations on the visited statement are updated in place; if it has none, ``annotations`` is
    assigned directly.
    """
    visited = self.visit(stmt)
    if visited.annotations is not None:
        visited.annotations.update(annotations)
    else:
        visited.annotations = annotations
    return visited
@dataclass
class ArgumentsDef:
    """Parsed argument list of a definition."""

    arguments: Tuple[Variable, ...]  # the declared arguments, in declaration order
@dataclass
class PreconditionPart:
    """Precondition section of a behavior or generator definition."""

    suite: Tree  # un-interpreted parse tree; visited later by the expression interpreter
@dataclass
class EffectPart:
    """Effect section of a behavior or controller definition."""

    suite: Tree  # un-interpreted parse tree; visited later by the expression interpreter
@dataclass
class GoalPart:
    """Goal section of a behavior or generator definition."""

    suite: Tree  # un-interpreted parse tree; visited later by the expression interpreter
@dataclass
class BodyPart:
    """Body section of a behavior definition."""

    suite: Tree  # un-interpreted parse tree; visited later by the expression interpreter
@dataclass
class SideEffectPart:
    """Container for the parse tree of a side-effect section."""

    suite: Tree  # un-interpreted parse tree
@dataclass
class InPart:
    """``in`` (inputs) section of a generator definition."""

    suite: Tree  # parse tree of the input values; visited later by the expression interpreter
@dataclass
class OutPart:
    """``out`` (outputs) section of a generator definition."""

    suite: Tree  # parse tree of the output values; visited later by the expression interpreter
class CDLDomainTransformer(CDLLiteralTransformer):
    """Transformer that registers the definitions of a CDL domain file on a :class:`CrowDomain`.

    Each ``*_definition`` callback handles one definition kind: type, object constant, feature, function,
    controller, behavior or generator. It normalizes the optional parts, interprets bodies with
    :attr:`expression_interpreter` inside :attr:`expression_def_ctx`, and calls the matching ``define_*`` method
    of :attr:`domain`.
    """

    def __init__(self, domain: Optional[CrowDomain] = None, auto_init_domain: bool = True):
        """Initialize the transformer.

        Args:
            domain: an existing domain to extend. If None and ``auto_init_domain`` is True, a new
                :class:`CrowDomain` is created.
            auto_init_domain: whether to create the domain, definition context and interpreter now. When False
                and ``domain`` is None, all three stay ``None`` until a subclass initializes them.
        """
        super().__init__()
        if auto_init_domain or domain is not None:
            self._domain = CrowDomain() if domain is None else domain
            self._expression_def_ctx = E.ExpressionDefinitionContext(domain=self.domain)
            self._expression_interpreter = CDLExpressionInterpreter(domain=self.domain, state=None, expression_def_ctx=self.expression_def_ctx, auto_constant_guess=False)
        else:
            self._domain = None
            self._expression_def_ctx = None
            self._expression_interpreter = None

    @property
    def domain(self) -> CrowDomain:
        """The domain being populated. Raises ValueError if it has not been initialized."""
        if self._domain is None:
            raise ValueError('Domain is not initialized.')
        return self._domain

    @property
    def expression_def_ctx(self) -> E.ExpressionDefinitionContext:
        """The expression definition context. Raises ValueError if it has not been initialized."""
        if self._expression_def_ctx is None:
            raise ValueError('Expression definition context is not initialized.')
        return self._expression_def_ctx

    @property
    def expression_interpreter(self) -> CDLExpressionInterpreter:
        """The interpreter used for definition bodies. Raises ValueError if it has not been initialized."""
        if self._expression_interpreter is None:
            raise ValueError('Expression interpreter is not initialized.')
        return self._expression_interpreter

    @inline_args
    def pragma_definition(self, pragma: Dict[str, Any]):
        """Handle a pragma by delegating to :meth:`_handle_pragma`."""
        if g_parser_verbose:
            print('pragma_definition', pragma)
        self._handle_pragma(pragma)

    def _handle_pragma(self, pragma: Dict[str, Any]):
        """Hook for subclasses to react to pragmas. The domain transformer ignores them."""
        pass

    @inline_args
    def type_definition(self, typename, basetype: Optional[Union[str, TypeBase]]):
        """Define a new type ``typename`` with the given base type."""
        if g_parser_verbose:
            print(f'type_definition:: {typename=} {basetype=}')
        self.domain.define_type(typename, basetype)

    @inline_args
    def object_constant_definition(self, name: str, typename: str):
        """Define a domain-level object constant ``name`` of type ``typename``."""
        if g_parser_verbose:
            print(f'object_constant_definition:: {name=} {typename=}')
        self.domain.define_object_constant(name, typename)

    def _prepare_derived_definition(self, annotations: Optional[dict], args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Optional[Tree]):
        """Normalize the optional parts of a feature/function definition and interpret its body.

        Missing annotations default to an empty dict and missing arguments to an empty :class:`ArgumentsDef`.
        A missing return type defaults to ``BOOL``; a string return type is resolved through the domain.

        Returns:
            ``(annotations, args, ret, suite, return_stmt)``, where ``suite`` is the visited body (or None) and
            ``return_stmt`` is its combined return expression (or None when there is no body).
        """
        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsDef(tuple())
        if ret is None:
            ret = BOOL
        elif isinstance(ret, str):
            ret = self.domain.get_type(ret)

        return_stmt = None
        if suite is not None:
            with self.expression_def_ctx.with_variables(*args.arguments):
                suite = self.expression_interpreter.visit(suite)
                return_stmt = suite.get_combined_return_expression(allow_expr_expressions=False)
        return annotations, args, ret, suite, return_stmt

    @inline_args
    def feature_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Optional[Tree]):
        """Define a feature. If a body is given, its return expression becomes the derived expression."""
        annotations, args, ret, suite, return_stmt = self._prepare_derived_definition(annotations, args, ret, suite)
        self.domain.define_feature(name, args.arguments, ret, derived_expression=return_stmt, **annotations)
        if g_parser_verbose:
            print(f'feature_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}')
            if return_stmt is not None:
                print(jacinle.indent_text(f'Return statement: {return_stmt}'))

    @inline_args
    def function_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Optional[Tree]):
        """Define a Crow function. If a body is given, its return expression becomes the derived expression."""
        annotations, args, ret, suite, return_stmt = self._prepare_derived_definition(annotations, args, ret, suite)
        self.domain.define_crow_function(name, args.arguments, ret, derived_expression=return_stmt, **annotations)
        if g_parser_verbose:
            print(f'function_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}')
            if return_stmt is not None:
                print(jacinle.indent_text(f'Return statement: {return_stmt}'))

    @inline_args
    def controller_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], effect: Optional[EffectPart]):
        """Define a controller.

        If an effect part is present, its assignment statements become a sequential ordering suite. Otherwise
        ``None`` is passed as the effect.
        """
        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsDef(tuple())
        if effect is not None:
            with self.expression_def_ctx.with_variables(*args.arguments):
                suite = self.expression_interpreter.visit(effect.suite)
                suite = suite.get_all_assign_expressions()
                effect = CrowBehaviorOrderingSuite.make_sequential(*suite)
        self.domain.define_controller(name, args.arguments, effect, **annotations)
        if g_parser_verbose:
            print(f'controller_definition:: {name=} {args.arguments=}')

    @inline_args
    def behavior_precondition_definition(self, suite: Tree) -> PreconditionPart:
        return PreconditionPart(suite)

    @inline_args
    def behavior_effect_definition(self, suite: Tree) -> EffectPart:
        return EffectPart(suite)

    @inline_args
    def behavior_goal_definition(self, suite: Tree) -> GoalPart:
        return GoalPart(suite)

    @inline_args
    def behavior_body_definition(self, suite: Tree) -> BodyPart:
        return BodyPart(suite)

    def _make_sequential_body(self, items) -> CrowBehaviorOrderingSuite:
        """Normalize a list of behavior body items into a single sequential ordering suite."""
        if len(items) == 0:
            return CrowBehaviorOrderingSuite('sequential', tuple())
        if len(items) == 1:
            if isinstance(items[0], CrowBehaviorOrderingSuite) and items[0].order.value == 'sequential':
                # Already a sequential suite: use it directly. Previously this case left a one-element list.
                return items[0]
            return CrowBehaviorOrderingSuite('sequential', (items[0],), _skip_simplify=True)
        return CrowBehaviorOrderingSuite('sequential', tuple(items), _skip_simplify=True)

    @inline_args
    def behavior_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart, EffectPart, BodyPart]):
        """Define a behavior from its precondition/goal/body/effect parts.

        Precondition parts are accumulated. With zero goals, or more than one, the behavior ``name`` is defined
        with a null goal; with several goals, a variant ``{name}_{i}`` is additionally defined per goal.

        Raises:
            ValueError: if an unknown part type is encountered.
        """
        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsDef(tuple())
        if g_parser_verbose:
            print(f'behavior_definition:: {name=} {args.arguments=} {annotations=}')

        precondition = list()
        goals = list()
        body = list()
        effect = None
        local_variables = None  # locals of the body part; only known if the body precedes the effect
        for part in parts:
            with self.expression_def_ctx.with_variables(*args.arguments):
                if isinstance(part, EffectPart):
                    # Expose the body's local variables (e.g. names introduced by ``bind``) while interpreting the
                    # effect. Reset them afterwards even if interpretation fails.
                    self.expression_interpreter.local_variables = local_variables if local_variables is not None else dict()
                    try:
                        suite = self.expression_interpreter.visit(part.suite)
                    finally:
                        self.expression_interpreter.local_variables = dict()
                else:
                    suite = self.expression_interpreter.visit(part.suite)

                if isinstance(part, PreconditionPart):
                    suite = suite.get_all_check_expressions()
                    precondition.extend(CrowPrecondition(x) for x in suite)
                    if g_parser_verbose:
                        print(jacinle.indent_text(f'Precondition: {precondition}'))
                elif isinstance(part, GoalPart):
                    goal_expr = suite.get_all_expr_expression(allow_multiple_expressions=False)
                    goals.append(goal_expr)
                    if g_parser_verbose:
                        print(jacinle.indent_text(f'Goal: {goal_expr}'))
                elif isinstance(part, BodyPart):
                    local_variables = suite.local_variables
                    body = self._make_sequential_body(suite.get_all_behavior_expressions(use_runtime_assign=True))
                    if g_parser_verbose:
                        print(jacinle.indent_text(f'Body:'))
                        print(jacinle.indent_text(str(body), level=2))
                elif isinstance(part, EffectPart):
                    suite = suite.get_all_assign_expressions()
                    effect = CrowBehaviorOrderingSuite.make_sequential(*suite)
                    if g_parser_verbose:
                        print(jacinle.indent_text(f'Effect: {effect}'))
                else:
                    raise ValueError(f'Invalid part: {part}')

        if effect is None:
            effect = CrowBehaviorOrderingSuite.make_sequential()

        if len(goals) == 1:
            self.domain.define_behavior(name, args.arguments, goals[0], body, precondition, effect, **annotations)
        else:
            self.domain.define_behavior(name, args.arguments, E.NullExpression(BOOL), body, precondition, effect, **annotations)
            for i, goal in enumerate(goals):
                self.domain.define_behavior(f'{name}_{i}', args.arguments, goal, body, precondition, effect, **annotations)

    @inline_args
    def generator_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart, InPart, OutPart]):
        """Define a generator from its goal, input and output parts.

        Raises:
            ValueError: for any part that is not a goal, ``in`` or ``out`` part.
        """
        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsDef(tuple())
        if g_parser_verbose:
            print(f'generator_definition:: {name=} {args.arguments=} {annotations=}')

        inputs = None
        outputs = None
        goal = None
        with self.expression_def_ctx.with_variables(*args.arguments):
            for part in parts:
                suite = self.expression_interpreter.visit(part.suite)
                if isinstance(part, GoalPart):
                    goal = suite.get_all_expr_expression(allow_multiple_expressions=False)
                    if g_parser_verbose:
                        print(jacinle.indent_text(f'Goal: {goal}'))
                elif isinstance(part, InPart):
                    inputs = [E.VariableExpression(x) for x in suite.items]
                    if g_parser_verbose:
                        print(jacinle.indent_text(f'Inputs: {inputs}'))
                elif isinstance(part, OutPart):
                    outputs = [E.VariableExpression(x) for x in suite.items]
                    if g_parser_verbose:
                        print(jacinle.indent_text(f'Outputs: {outputs}'))
                else:
                    # NOTE(review): generator_precondition_definition produces PreconditionPart objects, but they
                    # are rejected here. Confirm whether generator preconditions should be supported.
                    raise ValueError(f'Invalid part: {part}')
        self.domain.define_generator(name, args.arguments, goal, inputs, outputs, **annotations)

    @inline_args
    def generator_precondition_definition(self, suite: Tree) -> PreconditionPart:
        return PreconditionPart(suite)

    @inline_args
    def generator_goal_definition(self, suite: Tree) -> GoalPart:
        return GoalPart(suite)

    @inline_args
    def generator_in_definition(self, values: Tree) -> InPart:
        return InPart(values)

    @inline_args
    def generator_out_definition(self, values: Tree) -> OutPart:
        return OutPart(values)
class CDLProblemTransformer(CDLDomainTransformer):
    """Transformer that builds a :class:`CrowProblem` from a CDL problem file.

    The domain is either passed to the constructor or loaded by a ``domain_def`` statement. The problem is
    created together with the domain, so statements that touch the problem require the domain to be initialized.
    """

    def __init__(self, domain: Optional[CrowDomain] = None, state: Optional[CrowState] = None, auto_constant_guess: bool = False):
        """Initialize the transformer.

        Args:
            domain: the domain to use. If None, the domain must come from a ``domain_def`` statement.
            state: an optional state passed to the expression interpreter.
            auto_constant_guess: forwarded to the expression interpreter.
        """
        super().__init__(None, auto_init_domain=False)
        self._problem = None
        self.auto_constant_guess = auto_constant_guess
        self._domain_is_provided = False
        if domain is not None:
            self._domain_is_provided = True
            self._init_domain(domain, state)

    def _handle_pragma(self, pragma: Dict[str, Any]):
        """Forward ``planner_<option>`` pragmas to the problem as planner options; other keys are ignored."""
        for key, value in pragma.items():
            if key.startswith('planner_'):
                self._require_problem().set_planner_option(key[len('planner_'):], value)

    @property
    def problem(self) -> CrowProblem:
        """The problem being built, or None if the domain has not been initialized yet."""
        return self._problem

    def _require_problem(self) -> CrowProblem:
        """Return the problem, raising a clear error instead of failing on ``None``."""
        if self._problem is None:
            raise ValueError('Problem is not initialized. The domain must be specified before this statement.')
        return self._problem

    def _register_object(self, name: str, dtype) -> None:
        """Add object ``name`` of type ``dtype`` to the problem and expose it to expressions as a constant."""
        self._require_problem().add_object(name, dtype.typename)
        self.expression_interpreter.local_variables[name] = E.ObjectConstantExpression(ObjectConstant(name, dtype))

    def _init_domain(self, domain: CrowDomain, state: Optional[CrowState] = None):
        """Set up the domain (a shallow clone), the problem, and the interpreter; register domain object constants.

        Raises:
            ValueError: if a domain has already been initialized.
        """
        if self._domain is not None:
            raise ValueError('Domain is already initialized. Cannot overwrite the domain.')
        self._domain = domain.clone(deep=False)
        self._problem = CrowProblem(domain=self.domain)
        self._expression_def_ctx = E.ExpressionDefinitionContext(domain=self.domain)
        self._expression_interpreter = CDLExpressionInterpreter(domain=self.domain, state=state, expression_def_ctx=self.expression_def_ctx, auto_constant_guess=self.auto_constant_guess)
        for o in self.domain.constants.values():
            if isinstance(o, ObjectConstant):
                self._register_object(o.name, o.dtype)

    @inline_args
    def domain_def(self, filename: str):
        """Load the domain from ``filename`` unless a domain is already initialized."""
        if self._domain is not None:
            if not self._domain_is_provided:
                logger.warning('Domain is already initialized. Skip the in-place domain loading.')
            return
        domain = get_default_parser().parse_domain(filename)
        self._init_domain(domain)

    @inline_args
    def problem_name(self, name: str):
        """Set the name of the problem."""
        self._require_problem().name = name

    @inline_args
    def objects_definition(self, *objects):
        """Register declared objects; each argument is either a single object or a ``CSList`` of objects."""
        for o in objects:
            items = o.items if isinstance(o, CSList) else (o, )
            for item in items:
                self._register_object(item.name, item.dtype)

    @inline_args
    def init_definition(self, suite: Tree):
        """Initialize the problem state and apply the ``init`` statements.

        Assignment statements are executed as effects. Bare function-application expressions are asserted true.
        """
        problem = self._require_problem()
        problem.init_state()
        suite = self.expression_interpreter.visit(suite)
        executor = self.domain.make_executor()
        execute_effect_statements(executor, suite.get_all_assign_expressions(), state=problem.state, scope=dict())
        for stmt in suite.get_all_expr_expression(allow_multiple_expressions=True):
            if isinstance(stmt, E.FunctionApplicationExpression):
                executor.execute(E.AssignExpression(stmt, E.ConstantExpression.TRUE), state=problem.state)

    @inline_args
    def goal_definition(self, suite: Tree):
        """Set the problem goal from the single expression in ``suite``."""
        suite = self.expression_interpreter.visit(suite)
        suite = suite.get_all_expr_expression(allow_multiple_expressions=False)
        self._require_problem().set_goal(suite)
_parser = None
def get_default_parser() -> CDLParser:
    """Return the shared module-level :class:`CDLParser`, creating it on first use."""
    global _parser
    if _parser is not None:
        return _parser
    _parser = CDLParser()
    return _parser
def load_domain_file(filename: str) -> CrowDomain:
    """Load a domain file using the default parser.

    Args:
        filename: path to the domain file.

    Returns:
        the parsed domain object.
    """
    parser = get_default_parser()
    return parser.parse_domain(filename)
def load_domain_string(string: str) -> CrowDomain:
    """Load a domain from source text using the default parser.

    Args:
        string: the domain definition text.

    Returns:
        the parsed domain object.
    """
    parser = get_default_parser()
    return parser.parse_domain_str(string)
def load_domain_string_incremental(domain: CrowDomain, string: str) -> CrowDomain:
    """Extend an existing domain with definitions parsed from source text.

    Args:
        domain: the domain object to be updated.
        string: the additional domain definition text.

    Returns:
        the domain object returned by the parser.
    """
    parser = get_default_parser()
    return parser.parse_domain_str(string, domain=domain)
def load_problem_file(filename: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
    """Load a problem file using the default parser.

    Args:
        filename: path to the problem file.
        domain: the domain to use. If omitted, the domain file referenced by the problem file is loaded.

    Returns:
        the parsed problem object.
    """
    parser = get_default_parser()
    return parser.parse_problem(filename, domain=domain)
def load_problem_string(string: str, domain: Optional[CrowDomain] = None) -> CrowProblem:
    """Load a problem from source text using the default parser.

    Args:
        string: the problem definition text.
        domain: the domain to use. If omitted, the domain file referenced by the problem text is loaded.

    Returns:
        the parsed problem object.
    """
    parser = get_default_parser()
    return parser.parse_problem_str(string, domain=domain)
def parse_expression(domain: CrowDomain, string: str, state: Optional[CrowState] = None, variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True) -> E.Expression:
    """Parse a single expression in the context of a domain using the default parser.

    Args:
        domain: the domain object.
        string: the expression text.
        state: the current state, containing objects.
        variables: free variables that may appear in the expression.
        auto_constant_guess: whether to guess whether a variable is a constant.

    Returns:
        the parsed expression.
    """
    parser = get_default_parser()
    return parser.parse_expression(string, domain, state=state, variables=variables, auto_constant_guess=auto_constant_guess)