Source code for concepts.pdsketch.parsers.pdsketch_v3_parser

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : pdsketch_v3_parser.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 02/10/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import os
import os.path as osp
import itertools
import contextlib
import collections
import jacinle

from typing import Any, Optional, Union, Sequence, Tuple, Set, List, Dict
from dataclasses import dataclass, field
from lark import Lark, Tree
from lark.visitors import Transformer, Interpreter, v_args
from lark.indenter import PythonIndenter

import concepts.dsl.expression as E
from concepts.dsl.dsl_types import TypeBase, AutoType, VectorValueType, TupleType, Variable, ObjectConstant, UnnamedPlaceholder

from concepts.pdsketch.domain import Domain, Problem3, State
from concepts.pdsketch.operator import Precondition, Effect, Implementation, OperatorApplicationExpression
from concepts.pdsketch.generator import FancyGenerator, GeneratorApplicationExpression
from concepts.pdsketch.regression_rule import RegressionRuleBodyItemType, FindExpression, AchieveExpression, RuntimeAssignExpression, RegressionCommitFlag
from concepts.pdsketch.regression_rule import ConditionalRegressionRuleBodyExpression, LoopRegressionRuleBodyExpression
from concepts.pdsketch.regression_rule import RegressionRuleApplicationExpression
from concepts.pdsketch.executor import PDSketchExecutor

# Module-level logger, obtained from jacinle and keyed by this module's name.
logger = jacinle.get_logger(__name__)
# Shorthand decorator for lark's ``v_args(inline=True)``. The transformer methods below that
# are decorated with it receive the children of a rule as separate positional arguments.
inline_args = v_args(inline=True)


# Names exported by ``from <this module> import *``.
# NOTE(review): some of the listed names (e.g., the ``load_*`` helpers and the domain/problem
# transformers) are presumably defined further down in this file -- confirm when editing this list.
__all__ = [
    'PDSketchV3Parser', 'path_resolver',
    'PDSketchV3PathResolver', 'PDSketchV3DomainTransformer', 'PDSketchV3ProblemTransformer', 'PDSketch3LiteralTransformer', 'PDSketch3ExpressionInterpreter',
    'load_domain_file3', 'load_domain_string3', 'load_domain_string3_incremental',
    'load_problem_file3', 'load_problem_string3'
]


class PDSketchV3PathResolver(object):
    """Resolve PDSketch file names against a configurable list of search directories."""

    def __init__(self, search_paths: Sequence[str] = tuple()):
        """Create a resolver.

        Args:
            search_paths: the initial directories to search. They are copied into a list.
        """
        self.search_paths = [*search_paths]

    def resolve(self, filename: str, relative_filename: Optional[str] = None) -> str:
        """Find an existing path for ``filename``.

        The lookup order is: ``filename`` itself, the directory of ``relative_filename`` (when given),
        the current working directory, and finally each registered search path in order.

        Args:
            filename: the file name to look up.
            relative_filename: an optional file whose directory is tried before the other fallbacks.

        Returns:
            the first candidate path that exists.

        Raises:
            FileNotFoundError: if none of the candidates exists.
        """
        if osp.exists(filename):
            return filename

        def _fallback_candidates():
            # Candidates are produced lazily so that later ones are only computed when needed.
            if relative_filename is not None:
                yield osp.join(osp.dirname(relative_filename), filename)
            yield osp.join(os.getcwd(), filename)
            for directory in self.search_paths:
                yield osp.join(directory, filename)

        for candidate in _fallback_candidates():
            if osp.exists(candidate):
                return candidate

        raise FileNotFoundError(f'File not found: {filename}')

    def add_search_path(self, path: str):
        """Append ``path`` to the end of the search path list."""
        self.search_paths.append(path)

    def remove_search_path(self, path: str):
        """Remove the first occurrence of ``path`` from the search path list."""
        self.search_paths.remove(path)
# Module-level resolver instance (starts with no search paths). `PDSketchV3Parser.parse` uses it
# to locate the file to read.
path_resolver = PDSketchV3PathResolver()
class PDSketchV3Parser(object):
    """The parser for PDSketch v3."""

    grammar_file = osp.join(osp.dirname(osp.abspath(__file__)), 'pdsketch-v3.grammar')
    """The grammar definition v3 for PDSketch."""

    def __init__(self):
        """Initialize the parser by reading the grammar file and building the LALR parser."""
        # Read with an explicit encoding so that the result does not depend on the platform default.
        with open(self.grammar_file, 'r', encoding='utf-8') as f:
            self.grammar = f.read()
        self.parser = Lark(self.grammar, start='start', postlex=PythonIndenter(), parser='lalr')

    def parse(self, filename: str) -> Tree:
        """Parse a PDSketch v3 file.

        Args:
            filename: the filename to parse. It is resolved through the module-level ``path_resolver``.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.

        Raises:
            FileNotFoundError: if the file cannot be resolved.
        """
        filename = path_resolver.resolve(filename)
        with open(filename, 'r', encoding='utf-8') as f:
            return self.parse_str(f.read())

    def parse_str(self, s: str) -> Tree:
        """Parse a PDSketch v3 string.

        Args:
            s: the string to parse.

        Returns:
            the parse tree. It is a :class:`lark.Tree` object.
        """
        # NB(Jiayuan Mao @ 2024/03/13): for reasons, the pdsketch-v3 grammar file requires that the string ends with a newline.
        # In particular, the suite definition requires that the file ends with a _DEDENT token, which seems to be only triggered by a newline.
        s = s.strip() + '\n'
        return self.parser.parse(s)

    def parse_domain(self, filename: str) -> Domain:
        """Parse a PDSketch v3 domain file.

        Args:
            filename: the filename to parse.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse(filename))

    def parse_domain_str(self, s: str, domain: Optional[Domain] = None) -> Domain:
        """Parse a PDSketch v3 domain string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain.
        """
        return self.transform_domain(self.parse_str(s), domain=domain)

    def parse_problem(self, filename: str, domain: Optional[Domain] = None) -> Problem3:
        """Parse a PDSketch v3 problem file.

        Args:
            filename: the filename to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        return self.transform_problem(self.parse(filename), domain=domain)

    def parse_problem_str(self, s: str, domain: Optional[Domain] = None) -> Problem3:
        """Parse a PDSketch v3 problem string.

        Args:
            s: the string to parse.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        return self.transform_problem(self.parse_str(s), domain=domain)

    def parse_expression(self, s: str, domain: Domain, state: Optional[State] = None, variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True) -> E.Expression:
        """Parse a PDSketch v3 expression string.

        Args:
            s: the string to parse.
            domain: the domain to use.
            state: the current state, containing objects.
            variables: variables from the outer scope.
            auto_constant_guess: whether to automatically guess whether a variable is a constant.

        Returns:
            the parsed expression.
        """
        return self.transform_expression(self.parse_str(s), domain, state=state, variables=variables, auto_constant_guess=auto_constant_guess)

    @staticmethod
    def transform_domain(tree: Tree, domain: Optional[Domain] = None) -> Domain:
        """Transform a parse tree into a domain.

        Args:
            tree: the parse tree.
            domain: the domain to use. If not provided, a new domain will be created.

        Returns:
            the parsed domain.
        """
        transformer = PDSketchV3DomainTransformer(domain)
        transformer.transform(tree)
        return transformer.domain

    @staticmethod
    def transform_problem(tree: Tree, domain: Optional[Domain] = None) -> Problem3:
        """Transform a parse tree into a problem.

        Args:
            tree: the parse tree.
            domain: the domain to use. If not provided, the domain will be parsed from the problem file.

        Returns:
            the parsed problem.
        """
        transformer = PDSketchV3ProblemTransformer(domain)
        transformer.transform(tree)
        return transformer.problem

    @staticmethod
    def transform_expression(tree: Tree, domain: Domain, state: Optional[State] = None, variables: Optional[Sequence[Variable]] = None, auto_constant_guess: bool = True) -> E.Expression:
        """Transform a parse tree into an expression.

        Args:
            tree: the parse tree.
            domain: the domain to use.
            state: the current state, containing objects.
            variables: variables from the outer scope.
            auto_constant_guess: whether to automatically guess whether a variable is a constant.

        Returns:
            the parsed expression.
        """
        transformer = PDSketchV3ProblemTransformer(domain, state, auto_constant_guess=auto_constant_guess)
        interpreter = transformer.expression_interpreter
        expression_def_ctx = transformer.expression_def_ctx

        # The root of the tree is the `start` rule; the expression is its first child.
        tree = transformer.transform(tree).children[0]
        if variables is None:
            variables = tuple()
        # The outer-scope variables are only visible while the expression is being interpreted.
        with expression_def_ctx.with_variables(*variables):
            return interpreter.visit(tree)
@dataclass
class LiteralValue:
    """Wrapper around a single scalar literal (boolean, number, or string)."""

    value: Union[bool, int, float, complex, str]
@dataclass
class LiteralList:
    """An ordered collection of scalar literals, stored as a tuple."""

    items: Tuple[Union[bool, int, float, complex, str], ...]
@dataclass
class LiteralSet:
    """An unordered collection of scalar literals, stored as a set."""

    items: Set[Union[bool, int, float, complex, str]]
@dataclass
class InTypedArgument:
    """An argument written as `name in value`, as it appears in forall/exists statements."""

    name: str
    value: Any
class PDSketch3LiteralTransformer(Transformer):
    """The transformer for literal types. Including:

    - VARNAME, CONSTNAME, BASIC_TYPENAME
    - number, DEC_NUMBER, HEX_NUMBER, BIN_NUMBER, OCT_NUMBER, FLOAT_NUMBER, IMAG_NUMBER
    - boolean, TRUE, FALSE
    - string
    - literal_list
    - literal_set
    - decorator_k, decorator_kwarg, decorator_kwargs
    """

    domain: Domain

    def _resolve_type(self, name: Union[str, TypeBase]) -> TypeBase:
        """Look up string typenames in the domain; type objects are passed through unchanged."""
        if isinstance(name, str):
            return self.domain.get_type(name)
        return name

    @inline_args
    def typename(self, name: Union[str, TypeBase]) -> TypeBase:
        """Captures typenames including basic types and vector types."""
        return name

    @inline_args
    def sized_vector_typename(self, name: Union[str, TypeBase], size: int) -> VectorValueType:
        """Captures sized vector typenames defined as `vector[typename, size]`."""
        # The element type may already be a type object; only strings are looked up, as in `typed_argument`.
        return VectorValueType(self._resolve_type(name), size)

    @inline_args
    def unsized_vector_typename(self, name: Union[str, TypeBase]) -> VectorValueType:
        """Captures unsized vector typenames defined as `vector[typename]`."""
        return VectorValueType(self._resolve_type(name))

    @inline_args
    def typed_argument(self, name: str, typename: Union[str, TypeBase]) -> Variable:
        """Captures typed arguments defined as `name: typename`."""
        return Variable(name, self._resolve_type(typename))

    @inline_args
    def is_typed_argument(self, name: str, typename: Union[str, TypeBase]) -> Variable:
        """Captures typed arguments defined as `name is typename`. This is used in forall/exists statements."""
        return Variable(name, self._resolve_type(typename))

    @inline_args
    def in_typed_argument(self, name: str, value: Any) -> InTypedArgument:
        """Captures typed arguments defined as `name in value`. This is used in forall/exists statements."""
        return InTypedArgument(name, value)

    def arguments_def(self, args):
        """Captures the arguments definition. This is used in function definitions."""
        return ArgumentsDef(tuple(args))

    def VARNAME(self, token):
        """Captures variable names, such as `var_name`."""
        return token.value

    def CONSTNAME(self, token):
        """Captures constant names, such as `CONST_NAME`."""
        return token.value

    def BASIC_TYPENAME(self, token):
        """Captures basic type names (non-vector types), such as `int`, `float`, `bool`, `object`, etc."""
        return token.value

    @inline_args
    def number(self, value: Union[int, float, complex]) -> Union[int, float, complex]:
        """Captures number literals, including integers, floats, and complex numbers."""
        return value

    @inline_args
    def BIN_NUMBER(self, value: str) -> int:
        """Captures binary number literals."""
        return int(value, 2)

    @inline_args
    def OCT_NUMBER(self, value: str) -> int:
        """Captures octal number literals."""
        return int(value, 8)

    @inline_args
    def DEC_NUMBER(self, value: str) -> int:
        """Captures decimal number literals."""
        return int(value)

    @inline_args
    def HEX_NUMBER(self, value: str) -> int:
        """Captures hexadecimal number literals."""
        return int(value, 16)

    @inline_args
    def FLOAT_NUMBER(self, value: str) -> float:
        """Captures floating point number literals."""
        return float(value)

    @inline_args
    def IMAG_NUMBER(self, value: str) -> complex:
        """Captures complex number literals."""
        return complex(value)

    @inline_args
    def boolean(self, value: bool) -> bool:
        """Captures boolean literals."""
        return value

    @inline_args
    def TRUE(self, _) -> bool:
        """Captures the `True` literal."""
        return True

    @inline_args
    def FALSE(self, _) -> bool:
        """Captures the `False` literal."""
        return False

    @inline_args
    def ELLIPSIS(self, _) -> type(Ellipsis):
        """Captures the `...` literal and returns the built-in ``Ellipsis`` object."""
        return Ellipsis

    @inline_args
    def string(self, value: str) -> str:
        """Captures string literals and removes one pair of matching surrounding quotes, if present."""
        # The length guard keeps an empty string from raising IndexError and prevents a lone quote
        # character from being treated as an opening/closing pair.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            value = value[1:-1]
        return str(value)

    @inline_args
    def literal_list(self, *items: Any) -> LiteralList:
        """Captures literal lists, such as `[1, 2, 3, 4]`."""
        return LiteralList(tuple(items))

    @inline_args
    def literal_set(self, *items: Any) -> LiteralSet:
        """Captures literal sets, such as `{1, 2, 3, 4}`."""
        return LiteralSet(set(items))

    @inline_args
    def literal(self, value: Union[bool, int, float, complex, str, LiteralList, LiteralSet]) -> Union[LiteralValue, LiteralList, LiteralSet]:
        """Captures literal values. Scalars are wrapped into :class:`LiteralValue`; lists and sets are returned as-is.

        Raises:
            ValueError: if the value is neither a scalar literal nor a literal list/set.
        """
        if isinstance(value, (bool, int, float, complex, str)):
            return LiteralValue(value)
        elif isinstance(value, (LiteralList, LiteralSet)):
            return value
        else:
            raise ValueError(f'Invalid literal value: {value}')

    @inline_args
    def decorator_kwarg(self, k, v: Union[LiteralValue, LiteralList, LiteralSet] = True) -> Tuple[str, Union[bool, int, float, complex, str, LiteralList, LiteralSet]]:
        """Captures the key-value pair of a decorator. This is used in the decorator syntax, such as [[k=True]].

        When the value is omitted, it defaults to True. :class:`LiteralValue` wrappers are unwrapped.
        """
        return k, v.value if isinstance(v, LiteralValue) else v

    def decorator_kwargs(self, args) -> Dict[str, Union[bool, int, float, complex, str, LiteralList, LiteralSet]]:
        """Captures the key-value pairs of a decorator. This is used in the decorator syntax, such as [[k=True, k2=123, k3=[1, 2, 3]]]."""
        return {k: v for k, v in args}
@dataclass
class ArgumentsList:
    """The argument values of a call: variables, function calls, other expressions, or plain Python literals."""

    arguments: Tuple[
        Union[
            E.ValueOutputExpression, E.ListExpansionExpression, E.VariableExpression,
            bool, int, float, complex, str
        ],
        ...
    ]
@dataclass
class FunctionCall:
    """Intermediate representation of a parsed call.

    Besides ordinary function calls, the same structure also represents primitive operators and
    control flow statements.
    """

    name: str
    args: ArgumentsList
    annotations: Optional[Dict[str, Any]] = None

    def __str__(self):
        # Optional `[[k=v, ...]] ` prefix for the annotations.
        if self.annotations is None:
            prefix = ''
        else:
            rendered_annotations = ', '.join(f'{k}={v}' for k, v in self.annotations.items())
            prefix = '[[' + rendered_annotations + ']] '

        rendered_args = list(map(str, self.args.arguments))
        total_length = sum(map(len, rendered_args))
        if total_length <= 80:
            return f'{prefix}{self.name}(' + ', '.join(rendered_args) + ')'

        # Long argument lists are printed one argument per (indented) block.
        indented = [jacinle.indent_text(text) for text in rendered_args]
        return f'{prefix}{self.name}:\n' + '\n'.join(indented)

    def __repr__(self):
        return f'FunctionCall{{{str(self)}}}'
@dataclass
class CSList:
    """A comma-separated list of arbitrary items."""

    items: Tuple[Any, ...]

    def __str__(self):
        joined = ', '.join(map(str, self.items))
        return f'CSList({joined})'

    def __repr__(self):
        return str(self)
@dataclass
class Suite(object):
    """A suite of statements. This is used as the intermediate representation of the parsed expressions.

    The statements are analyzed lazily: the first ``get_*`` call runs a :class:`FunctionCallTracker`
    over the suite and caches it in :attr:`tracker`; later calls reuse the cached tracker.
    """

    items: Tuple[Any, ...]
    tracker: Optional['FunctionCallTracker'] = None

    def _init_tracker(self, use_runtime_assign: bool = False):
        """Run a fresh tracker (with no initial local variables) over this suite and cache it."""
        self.tracker = FunctionCallTracker(self, dict(), use_runtime_assign=use_runtime_assign).run()

    def get_all_assign_expressions(self) -> List[Tuple[E.VariableAssignmentExpression, Dict[str, Any]]]:
        """Return the tracked assign expressions, each paired with its annotation dictionary."""
        if self.tracker is None:
            self._init_tracker()
        return self.tracker.assign_expressions

    def get_all_check_expressions(self) -> List[E.ValueOutputExpression]:
        """Return the tracked check expressions."""
        if self.tracker is None:
            self._init_tracker()
        return self.tracker.check_expressions

    def get_all_regression_expressions(self, use_runtime_assign=True) -> List[RegressionRuleBodyItemType]:
        """Return the tracked regression rule body items.

        Args:
            use_runtime_assign: forwarded to the tracker when it is created by this call. It has no
                effect if a tracker has already been cached by an earlier ``get_*`` call.
        """
        if self.tracker is None:
            self._init_tracker(use_runtime_assign=use_runtime_assign)
        return self.tracker.regression_expressions

    def get_all_expr_expression(self, allow_multiple_expressions: bool = False) -> Optional[Union[E.ValueOutputExpression, Tuple[E.ValueOutputExpression, ...]]]:
        """Return the bare expression(s) found in the suite.

        Args:
            allow_multiple_expressions: whether zero or more than one bare expression is acceptable.

        Returns:
            the single expression if there is exactly one. Otherwise, when ``allow_multiple_expressions``
            is True, a (possibly empty) tuple of all bare expressions.

        Raises:
            ValueError: if ``allow_multiple_expressions`` is False and the suite does not contain
                exactly one bare expression.
        """
        if self.tracker is None:
            self._init_tracker()

        expressions = self.tracker.expr_expressions
        if len(expressions) == 1:
            return expressions[0]
        if not allow_multiple_expressions:
            if len(expressions) == 0:
                # Previously this case was reported as "Multiple expressions found", which was misleading.
                raise ValueError('No expression found in the suite, but exactly one expression is required.')
            raise ValueError(f'Multiple expressions found in a single suite: {expressions}')
        return tuple(expressions)

    def get_combined_return_expression(self, allow_expr_expressions: bool = False) -> Optional[E.ValueOutputExpression]:
        """Return the combined return expression of the suite.

        Args:
            allow_expr_expressions: when the suite has no return expression, fall back to its single
                bare expression instead of returning None.

        Returns:
            the tracker's return expression if there is one; otherwise the single bare expression
            (when ``allow_expr_expressions`` is True) or None.

        Raises:
            ValueError: if the fallback is requested and the suite does not contain exactly one bare expression.
        """
        if self.tracker is None:
            self._init_tracker()
        if self.tracker.return_expression is not None:
            return self.tracker.return_expression
        if allow_expr_expressions:
            return self.get_all_expr_expression(allow_multiple_expressions=False)
        return None

    def __str__(self):
        if len(self.items) == 0:
            return 'Suite{}'
        if len(self.items) == 1:
            return f'Suite{{{self.items[0]}}}'
        return 'Suite{\n' + '\n'.join(jacinle.indent_text(str(item)) for item in self.items) + '\n}'

    def __repr__(self):
        return self.__str__()
class FunctionCallTracker(object):
    """This class is used to track the function calls and other statements in a suite.
    It supports simulating the execution of the program and generating the post-condition of the program.
    """

    def __init__(self, suite: Suite, init_local_variables: Optional[Dict[str, Any]] = None, use_runtime_assign: bool = False):
        """Initialize the tracker.

        Args:
            suite: the suite to be tracked.
            init_local_variables: the initial local variable assignments. The dictionary is used
                directly (not copied) and will be updated by :meth:`run`.
            use_runtime_assign: whether local variable assignments are recorded as runtime assign expressions.
        """
        self.suite = suite
        self.local_variables = dict() if init_local_variables is None else init_local_variables
        self.assign_expressions = list()
        self.assign_expression_signatures = dict()
        self.check_expressions = list()
        self.expr_expressions = list()
        self.regression_expressions = list()
        self.return_expression = None
        self.use_runtime_assign = use_runtime_assign

    local_variables: Dict[str, Any]
    """The assignments of local variables."""

    assign_expressions: List[Tuple[E.VariableAssignmentExpression, Dict[str, Any]]]
    """A list of assign expressions."""

    assign_expression_signatures: Dict[Tuple[str, ...], E.VariableAssignmentExpression]
    """A dictionary of assign expressions, indexed by their signatures."""

    check_expressions: List[E.ValueOutputExpression]
    """A list of check expressions."""

    expr_expressions: List[E.ValueOutputExpression]
    """A list of expr expressions (i.e. bare expressions in the body)."""

    return_expression: Optional[E.ValueOutputExpression]
    """The return expression. This is either None or a single expression."""

    use_runtime_assign: bool
    """Whether to use runtime assign expressions. If this is set to True, assignment expressions for local variables will be converted into runtime assign expressions."""

    regression_expressions: List[Union[
        OperatorApplicationExpression, FindExpression, AchieveExpression, RuntimeAssignExpression, E.ListExpansionExpression,
        RegressionCommitFlag, RegressionRuleApplicationExpression, ConditionalRegressionRuleBodyExpression, LoopRegressionRuleBodyExpression
    ]]
    """A list of regression rule body items collected from the suite."""

    def _g(
        self, expr: Union[E.Expression, UnnamedPlaceholder, OperatorApplicationExpression, GeneratorApplicationExpression, RegressionRuleApplicationExpression, RuntimeAssignExpression, Implementation]
    ) -> Union[E.Expression, UnnamedPlaceholder, OperatorApplicationExpression, GeneratorApplicationExpression, RegressionRuleApplicationExpression, RuntimeAssignExpression, Implementation]:
        """Ground an expression by substituting the currently known local variables.

        Application-like wrappers (operator, generator, regression rule, runtime assign, implementation)
        are rebuilt with their arguments grounded recursively; plain expressions are flattened with
        the current local variable assignments.

        Raises:
            ValueError: if ``expr`` is not one of the supported kinds.
        """
        # Imported locally, as in the original implementation.
        from concepts.pdsketch.predicate import flatten_expression

        if isinstance(expr, OperatorApplicationExpression):
            return OperatorApplicationExpression(expr.operator, [self._g(arg) for arg in expr.arguments])
        if isinstance(expr, GeneratorApplicationExpression):
            return GeneratorApplicationExpression(expr.generator, [self._g(arg) for arg in expr.arguments])
        if isinstance(expr, RegressionRuleApplicationExpression):
            return RegressionRuleApplicationExpression(expr.rule, [self._g(arg) for arg in expr.arguments])
        if isinstance(expr, RuntimeAssignExpression):
            return RuntimeAssignExpression(expr.variable, self._g(expr.expression))
        if isinstance(expr, Implementation):
            return Implementation(expr.name, [self._g(arg) for arg in expr.arguments])
        if isinstance(expr, UnnamedPlaceholder):
            return expr

        if not isinstance(expr, E.Expression):
            raise ValueError(f'Invalid expression: {expr}')

        return flatten_expression(expr, {
            E.VariableExpression(Variable(k, None)): v for k, v in self.local_variables.items()
        })

    def _get_deictic_signature(self, e, known_deictic_vars=tuple()) -> Optional[Tuple[str, ...]]:
        """Compute the signature `(function_name, *argument_names)` of an assign expression.

        Arguments bound by enclosing deictic assign expressions are replaced by `'?'`.
        Returns None for expressions that are neither deictic assigns nor plain assigns.
        """
        if isinstance(e, E.DeicticAssignExpression):
            known_deictic_vars = known_deictic_vars + (e.variable.name,)
            return self._get_deictic_signature(e.expression, known_deictic_vars)
        elif isinstance(e, E.AssignExpression):
            args = [x.name if x.name not in known_deictic_vars else '?' for x in e.predicate.arguments]
            return tuple((e.predicate.function.name, *args))
        else:
            return None

    def _mark_assign(self, *exprs: E.VariableAssignmentExpression, annotations: Optional[Dict[str, Any]] = None):
        """Record assign expressions, rejecting two assignments that share the same signature.

        Raises:
            ValueError: if an expression has the same signature as a previously recorded one.
        """
        if annotations is None:
            annotations = dict()
        for expr in exprs:
            signature = self._get_deictic_signature(expr)
            if signature is not None:
                if signature in self.assign_expression_signatures:
                    raise ValueError(f'Duplicate assign expression: {expr} vs {self.assign_expression_signatures[signature]}')
                self.assign_expressions.append((expr, annotations))
                self.assign_expression_signatures[signature] = expr
            else:
                # Expressions without a signature (e.g., conditional assigns) are not de-duplicated.
                self.assign_expressions.append((expr, annotations))

    @staticmethod
    def _and_conditions(accumulated: Optional[E.ValueOutputExpression], new_condition: E.ValueOutputExpression) -> E.ValueOutputExpression:
        """Conjoin ``new_condition`` with the accumulated condition (if any), flattening nested ANDs."""
        if accumulated is None:
            return new_condition
        if isinstance(accumulated, E.BoolExpression) and accumulated.op == E.BoolOpType.AND:
            return E.AndExpression(*accumulated.arguments, new_condition)
        return E.AndExpression(accumulated, new_condition)

    @jacinle.log_function(verbose=False)
    def run(self):
        """Simulate the execution of the program and generates an equivalent return statement.
        This function handles if-else conditions. However, loops are not allowed.

        Early returns inside `if` statements are merged into a single conditional return expression.
        The condition under which execution continues past all earlier returns is accumulated
        (conjoined) across `if` statements, so that chains such as
        ``if c1: return a; if c2: return b; return c`` are translated correctly.

        Returns:
            the tracker itself.
        """
        current_return_statement = None
        # The condition under which none of the early returns seen so far has been taken.
        current_return_statement_condition_neg = None

        for item in self.suite.items:
            assert isinstance(item, FunctionCall), f'Invalid item in suite: {item}'
            args = item.args.arguments

            if item.name == 'assign':
                if isinstance(args[0], E.FunctionApplicationExpression) and args[1] is Ellipsis:
                    # `f(x) = ...`: assign a null value of the function's return type.
                    self._mark_assign(self._g(E.AssignExpression(args[0], E.NullExpression(args[0].return_type))), annotations=item.annotations)
                elif isinstance(args[0], E.FunctionApplicationExpression) and isinstance(args[1], (E.ValueOutputExpression, E.VariableExpression)):
                    self._mark_assign(self._g(E.AssignExpression(args[0], args[1])), annotations=item.annotations)
                elif isinstance(args[0], E.VariableExpression) and isinstance(args[1], (E.ValueOutputExpression, E.VariableExpression, E.FindAllExpression)):
                    if self.use_runtime_assign:
                        if args[0].name not in self.local_variables:
                            # NOTE(review): this wraps a VariableExpression (not a Variable) in a new
                            # VariableExpression; kept as in the original -- confirm this is intended.
                            self.local_variables[args[0].name] = E.VariableExpression(args[0])
                        self.regression_expressions.append(RuntimeAssignExpression(
                            args[0].variable,
                            self._g(args[1])
                        ))
                    else:
                        # Local variables are inlined: remember the grounded value for later substitution.
                        self.local_variables[args[0].name] = self._g(args[1])
                else:
                    raise ValueError(f'Invalid assignment: {item}. Types: {type(item.args.arguments[0])}, {type(item.args.arguments[1])}.')
            elif item.name == 'check':
                assert isinstance(args[0], (E.ValueOutputExpression, E.VariableExpression)), f'Invalid check expression: {item.args.arguments[0]}'
                self.check_expressions.append(self._g(args[0]))
            elif item.name == 'expr':
                if isinstance(args[0], (OperatorApplicationExpression, RegressionRuleApplicationExpression, E.ListExpansionExpression)):
                    self.regression_expressions.append(self._g(args[0]))
                else:
                    assert isinstance(
                        args[0],
                        (E.ValueOutputExpression, E.VariableExpression, GeneratorApplicationExpression, Implementation)
                    ), f'Invalid expr expression: {item.args.arguments[0]}'
                    self.expr_expressions.append(self._g(args[0]))
            elif item.name == 'find':
                arguments = args[0].items
                body: Suite = args[1]
                self.regression_expressions.append(FindExpression(arguments, body.get_combined_return_expression(allow_expr_expressions=True), **item.annotations if item.annotations is not None else dict()))
            elif item.name == 'achieve':
                term = args[0]
                self.regression_expressions.append(AchieveExpression(term, tuple(), **item.annotations if item.annotations is not None else dict()))
                # TODO(Jiayuan Mao @ 2024/03/2): implement achieve-maintain
            elif item.name == 'return':
                assert isinstance(args[0], (E.ValueOutputExpression, E.VariableExpression)), f'Invalid return expression: {item.args.arguments[0]}'
                self.return_expression = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, self._g(args[0]))
                # Statements after an unconditional return are unreachable.
                break
            elif item.name == 'if':
                condition = self._g(args[0])
                neg_condition = E.NotExpression(condition)
                assert isinstance(condition, E.ValueOutputExpression), f'Invalid condition: {condition}. Type: {type(condition)}.'

                t_suite = args[1]
                f_suite = args[2]
                # NOTE(review): the nested trackers do not inherit `use_runtime_assign`; kept as in
                # the original -- confirm whether this is intended.
                t_tracker = FunctionCallTracker(t_suite, self.local_variables.copy()).run()
                f_tracker = FunctionCallTracker(f_suite, self.local_variables.copy()).run()

                assert set(t_tracker.local_variables.keys()) == set(f_tracker.local_variables.keys()), f'Local variables in the true and false branches are not consistent: {t_tracker.local_variables.keys()} vs {f_tracker.local_variables.keys()}'
                # Merge local variables: values that differ between branches become conditional expressions.
                # Only existing keys are overwritten, so mutating the dictionary during iteration is safe.
                new_local_variables = t_tracker.local_variables
                for k, v in t_tracker.local_variables.items():
                    if f_tracker.local_variables[k] != v:
                        new_local_variables[k] = E.ConditionExpression(condition, v, f_tracker.local_variables[k])
                self.local_variables = new_local_variables

                # TODO(Jiayuan Mao @ 2024/03/2): optimize the implementation for this by merging the conditions.
                for expr, annotations in t_tracker.assign_expressions:
                    self._mark_assign(_make_conditional_assign(expr, condition), annotations=annotations)
                for expr, annotations in f_tracker.assign_expressions:
                    self._mark_assign(_make_conditional_assign(expr, neg_condition), annotations=annotations)

                for expr in t_tracker.check_expressions:
                    self.check_expressions.append(_make_conditional_implies(condition, expr))
                for expr in f_tracker.check_expressions:
                    self.check_expressions.append(_make_conditional_implies(neg_condition, expr))

                if len(t_tracker.expr_expressions) != len(f_tracker.expr_expressions):
                    raise ValueError(f'Number of bare expressions in the true and false branches are not consistent: {len(t_tracker.expr_expressions)} vs {len(f_tracker.expr_expressions)}')
                if len(t_tracker.expr_expressions) == 0:
                    pass
                elif len(t_tracker.expr_expressions) == 1:
                    self.expr_expressions.append(E.ConditionExpression(condition, t_tracker.expr_expressions[0], f_tracker.expr_expressions[0]))
                else:
                    raise ValueError(f'Multiple bare expressions in the true and false branches are not supported: {t_tracker.expr_expressions} vs {f_tracker.expr_expressions}')

                if len(t_tracker.regression_expressions) > 0:
                    self.regression_expressions.append(ConditionalRegressionRuleBodyExpression(condition, t_tracker.regression_expressions))
                if len(f_tracker.regression_expressions) > 0:
                    self.regression_expressions.append(ConditionalRegressionRuleBodyExpression(neg_condition, f_tracker.regression_expressions))

                if t_tracker.return_expression is not None and f_tracker.return_expression is not None:
                    # Both branches have return statements.
                    statement = E.ConditionExpression(condition, t_tracker.return_expression, f_tracker.return_expression)
                    self.return_expression = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, statement)
                    break
                elif t_tracker.return_expression is not None:
                    current_return_statement = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, t_tracker.return_expression)
                    # Execution continues only if all earlier returns were skipped AND this condition is false.
                    current_return_statement_condition_neg = self._and_conditions(current_return_statement_condition_neg, neg_condition)
                elif f_tracker.return_expression is not None:
                    current_return_statement = _make_conditional_return(current_return_statement, current_return_statement_condition_neg, f_tracker.return_expression)
                    # Execution continues only if all earlier returns were skipped AND this condition is true.
                    current_return_statement_condition_neg = self._and_conditions(current_return_statement_condition_neg, condition)
                else:
                    pass
            elif item.name == 'forall':
                suite = args[1]
                tracker = FunctionCallTracker(suite, self.local_variables.copy()).run()

                for k in tracker.local_variables:
                    if k in self.local_variables and self.local_variables[k] != tracker.local_variables[k]:
                        raise ValueError(f'Local variable {k} is assigned in the forall statement but has been assigned before: {self.local_variables[k]} vs {tracker.local_variables[k]}')

                # Wrap each assignment/check in one quantifier per forall variable.
                for expr, annotations in tracker.assign_expressions:
                    for var in args[0].items:
                        expr = E.DeicticAssignExpression(var, expr)
                    self._mark_assign(expr, annotations=annotations)
                for expr in tracker.check_expressions:
                    for var in args[0].items:
                        expr = E.ForallExpression(var, expr)
                    self.check_expressions.append(expr)

                if len(tracker.expr_expressions) == 0:
                    pass
                else:
                    if len(tracker.expr_expressions) == 1:
                        merged = tracker.expr_expressions[0]
                    else:
                        merged = E.AndExpression(*tracker.expr_expressions)
                    for var in args[0].items:
                        merged = E.ForallExpression(var, merged)
                    self.expr_expressions.append(merged)

                if len(tracker.regression_expressions) > 0:
                    raise ValueError(f'Regression rules are not allowed in a forall statement: {tracker.regression_expressions}')
                if tracker.return_expression is not None:
                    raise ValueError(f'Return statement is not allowed in a forall statement: {tracker.return_expression}')
            elif item.name == 'forall_in':
                suite = args[0]
                tracker = FunctionCallTracker(suite, self.local_variables.copy()).run()

                for k in tracker.local_variables:
                    if k in self.local_variables and self.local_variables[k] != tracker.local_variables[k]:
                        raise ValueError(f'Local variable {k} is assigned in the forall statement but has been assigned before: {self.local_variables[k]} vs {tracker.local_variables[k]}')

                # Any assign expression in the body is an error.
                for expr, annotations in tracker.assign_expressions:
                    raise ValueError(f'Assign statements are not allowed in a forall_in statement: {expr}')
                for expr in tracker.check_expressions:
                    self.check_expressions.append(E.AndExpression(expr))

                if len(tracker.expr_expressions) == 0:
                    pass
                else:
                    if len(tracker.expr_expressions) == 1:
                        merged = tracker.expr_expressions[0]
                    else:
                        merged = E.AndExpression(*tracker.expr_expressions)
                    self.expr_expressions.append(merged)

                if len(tracker.regression_expressions) > 0:
                    # TODO(Jiayuan Mao @ 2024/03/12): implement the rest parts of regression statements.
                    for expr in tracker.regression_expressions:
                        if isinstance(expr, AchieveExpression):
                            self.regression_expressions.append(E.ListExpansionExpression(E.AndExpression(expr.goal)))
                        else:
                            raise ValueError(f'Regression rule items except for achieve statements are not allowed in a forall_in statement: {expr}')
                if tracker.return_expression is not None:
                    raise ValueError(f'Return statement is not allowed in a forall_in statement: {tracker.return_expression}')
            elif item.name == 'pass':
                pass

        return self
def _make_conditional_implies(condition: E.ValueOutputExpression, test: E.ValueOutputExpression): if isinstance(test, E.BoolExpression) and test.op == E.BoolOpType.IMPLIES: if isinstance(test.arguments[0], E.BoolExpression) and test.arguments[0].op == E.BoolOpType.AND: return E.ImpliesExpression(E.AndExpression(condition, *test.arguments[0].arguments), test.arguments[1]) else: return E.ImpliesExpression(E.AndExpression(condition, test.arguments[0]), test.arguments[1]) else: return E.ImpliesExpression(condition, test) def _make_conditional_return(current_stmt: Optional[E.ValueOutputExpression], current_condition_neg: Optional[E.ValueOutputExpression], new_stmt: E.ValueOutputExpression): if current_stmt is None: return new_stmt return E.ConditionExpression(current_condition_neg, new_stmt, current_stmt) def _make_conditional_assign(assign_stmt: E.VariableAssignmentExpression, condition: E.ValueOutputExpression): if isinstance(assign_stmt, E.AssignExpression): return E.ConditionalAssignExpression(assign_stmt.predicate, assign_stmt.value, condition) elif isinstance(assign_stmt, E.ConditionalAssignExpression): if isinstance(assign_stmt.condition, E.BoolExpression) and assign_stmt.condition.op == E.BoolOpType.AND: return E.ConditionalAssignExpression(assign_stmt.predicate, assign_stmt.value, E.AndExpression(condition, *assign_stmt.condition.arguments)) else: return E.ConditionalAssignExpression(assign_stmt.predicate, assign_stmt.value, E.AndExpression(condition, assign_stmt.condition)) elif isinstance(assign_stmt, E.DeicticAssignExpression): return E.DeicticAssignExpression(assign_stmt.variable, _make_conditional_assign(assign_stmt.expression, condition)) else: raise ValueError(f'Invalid assign statement: {assign_stmt}') def _make_conditional_regression(regression_stmt: Union[OperatorApplicationExpression, FindExpression, AchieveExpression, RegressionCommitFlag, RegressionRuleApplicationExpression, ConditionalRegressionRuleBodyExpression], condition: 
E.ValueOutputExpression): # NB(Jiayuan Mao @ 2024/03/2): Not used. Because the original ConditionalRegressionRuleBodyExpression already supports recursive conditions. if isinstance(regression_stmt, ConditionalRegressionRuleBodyExpression): if isinstance(regression_stmt.condition, E.BoolExpression) and regression_stmt.condition.op == E.BoolOpType.AND: return ConditionalRegressionRuleBodyExpression(E.AndExpression(condition, *regression_stmt.condition.arguments), regression_stmt.body) else: return ConditionalRegressionRuleBodyExpression(E.AndExpression(condition, regression_stmt.condition), regression_stmt.body) else: return ConditionalRegressionRuleBodyExpression(condition, (regression_stmt, ))
def gen_term_expr(expr_typename: str):
    """Generate a term expression function.

    This function is used to generate the term expression functions for the transformer.
    It is used to generate the following functions:

    - mul_expr
    - arith_expr
    - shift_expr

    Only the trivial single-operand case is handled: the operand is visited and
    returned. Any expression that actually contains an operator raises
    `NotImplementedError`.

    Args:
        expr_typename: the name of the expression type. This is only used in the error message.

    Returns:
        the generated term expression function.
    """

    @inline_args
    def term(self, *values: Any):
        values = [self.visit(value) for value in values]
        if len(values) == 1:
            return values[0]
        # NB: a left-fold over (operand, operator, operand, ...) into nested
        # function calls used to follow this raise but was unreachable; it was
        # removed. Re-introduce it here once binary operators are supported.
        raise NotImplementedError(f'{expr_typename} expression is not supported in the current version.')

    return term
def gen_term_expr_noop(expr_typename: str):
    """Generate a term expression function for operators that do not appear in the parse arguments.

    The `_noop` suffix means "no operator": the values handed to the generated
    function contain only the operands, so the operator has to be supplied
    manually through `expr_typename`. This is used for the following functions:

    - bitand_expr
    - bitxor_expr
    - bitor_expr

    Args:
        expr_typename: one of ``'bitand'``, ``'bitxor'``, ``'bitor'``.

    Returns:
        the generated term expression function. With a single operand it returns
        the visited operand; otherwise it builds a `BoolExpression` over all operands.
    """
    bool_ops = {
        'bitand': E.BoolOpType.AND,
        'bitxor': E.BoolOpType.XOR,
        'bitor': E.BoolOpType.OR,
    }

    @inline_args
    def term(self, *values: Any):
        operands = [self.visit(operand) for operand in values]
        if len(operands) != 1:
            # The operator lookup happens here, at call time, exactly as before.
            return E.BoolExpression(bool_ops[expr_typename], _canonize_arguments(operands))
        return operands[0]

    return term
def _canonize_single_argument(arg: Any, dtype: Optional[TypeBase] = None) -> Union[E.ValueOutputExpression, E.VariableExpression, type(Ellipsis)]: if isinstance(arg, (E.ObjectOrValueOutputExpression, OperatorApplicationExpression, GeneratorApplicationExpression, RegressionRuleApplicationExpression, Implementation)): return arg if isinstance(arg, Variable): return E.VariableExpression(arg) if isinstance(arg, (bool, int, float, complex)): return E.ConstantExpression.from_value(arg, dtype=dtype) if arg is Ellipsis: return Ellipsis raise ValueError(f'Invalid argument: {arg}') def _canonize_arguments(args: Optional[Union[ArgumentsList, tuple, list]] = None, dtypes: Optional[Sequence[TypeBase]] = None) -> Tuple[Union[E.ValueOutputExpression, E.VariableExpression], ...]: if args is None: return tuple() canonized_args = list() # TODO(Jiayuan Mao @ 2024/03/2): Strictly check the allowability of "Ellipsis" in the arguments. for i, arg in enumerate(args.arguments if isinstance(args, ArgumentsList) else args): if arg is Ellipsis: canonized_args.append(Ellipsis) else: canonized_args.append(_canonize_single_argument(arg, dtype=dtypes[i] if dtypes is not None else None)) return tuple(canonized_args) def _has_list_arguments(args: Tuple[E.ObjectOrValueOutputExpression, ...]) -> bool: for arg in args: if arg.return_type.is_list_type: return True return False
class PDSketch3ExpressionInterpreter(Interpreter):
    """The interpreter for expressions and statements.

    Methods are named after the grammar rules they handle. The original rule
    list covered by the expression layer is:

    - typename, sized_vector_typename, unsized_vector_typename
    - typed_argument, is_typed_argument, in_typed_argument, arguments_def
    - atom_expr_funccall, atom_varname, atom, power, factor, unary_op_expr
    - mul_expr, arith_expr, shift_expr, bitand_expr, bitxor_expr, bitor_expr
    - comparison_expr, not_test, and_test, or_test, cond_test, test, test_nocond
    - tuple, list, cs_list, suite
    - expr_stmt, expr_list_expansion_stmt, assign_stmt, annotated_assign_stmt, local_assign_stmt
    - pass_stmt, check_stmt, return_stmt, achieve_stmt
    - compound_assign_stmt, compound_check_stmt, compound_return_stmt, compound_achieve_stmt
    - if_stmt, forall_stmt
    - forall_test, exists_test, findall_test, forall_in_test, exists_in_test

    Statements are returned as `FunctionCall` records (e.g. ``'assign'``,
    ``'check'``, ``'if'``) and grouped into `Suite` objects; expressions are
    returned as objects from `concepts.dsl.expression`.
    """

    domain: Domain
    expression_def_ctx: E.ExpressionDefinitionContext

    def __init__(self, domain: Optional[Domain], state: Optional[State], expression_def_ctx: E.ExpressionDefinitionContext, auto_constant_guess: bool = False):
        """Initialize the interpreter.

        Args:
            domain: the domain used to resolve functions, operators, generators and regression rules.
            state: an optional state; names of its objects are resolved to object constants.
            expression_def_ctx: the context used to check for and wrap named variables.
            auto_constant_guess: if True, names unknown to the context become object constants of `AutoType`.
        """
        super().__init__()
        self.domain = domain
        self.state = state
        self.expression_def_ctx = expression_def_ctx
        self.auto_constant_guess = auto_constant_guess
        # Output variables of the generator whose implementation is being parsed (None outside of one).
        self.generator_impl_outputs = None
        # Maps local names to the expression they currently stand for.
        self.local_variables = dict()

    def set_domain(self, domain: Domain):
        """Replace the domain used for name resolution."""
        self.domain = domain

    @contextlib.contextmanager
    def local_variable_guard(self):
        """Scope guard: local variables defined inside the `with` body are dropped on exit.

        The snapshot is restored even if the body raises, so a failed parse does
        not leak local variables into later use of the interpreter.
        """
        backup = self.local_variables.copy()
        try:
            yield
        finally:
            self.local_variables = backup

    @contextlib.contextmanager
    def set_generator_impl_outputs(self, outputs: List[Variable]):
        """Temporarily set the generator output variables; the previous value is restored on exit (also on error)."""
        backup = self.generator_impl_outputs
        self.generator_impl_outputs = outputs
        try:
            yield
        finally:
            self.generator_impl_outputs = backup

    def visit(self, tree: Any) -> Any:
        """Visit `tree` if it is a `Tree`; any other value is returned unchanged."""
        if isinstance(tree, Tree):
            return super().visit(tree)
        return tree

    @inline_args
    def atom_varname(self, name: str) -> Union[E.VariableExpression, E.ObjectConstantExpression, E.ValueOutputExpression]:
        """Captures variable names such as `var_name`.

        Resolution order: local variables, objects of the current state (if any),
        an `AutoType` object constant when `auto_constant_guess` is on and the
        context does not know the name, and finally a variable wrapped by the
        expression definition context.
        """
        if name in self.local_variables:
            return self.local_variables[name]
        if self.state is not None and name in self.state.object_name2defaultindex:
            return E.ObjectConstantExpression(ObjectConstant(name, self.domain.types[self.state.get_typename(name)]))
        # TODO(Jiayuan Mao @ 2024/03/12): smartly guess the type of the variable.
        if self.auto_constant_guess and not self.expression_def_ctx.has_variable(name):
            return E.ObjectConstantExpression(ObjectConstant(name, AutoType))
        return self.expression_def_ctx.wrap_variable(name)

    @inline_args
    def atom_expr_funccall(self, annotations: dict, name: str, args: Tree) -> Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression, Implementation, OperatorApplicationExpression, GeneratorApplicationExpression, RegressionRuleApplicationExpression]:
        """Captures function calls, such as `func_name(arg1, arg2, ...)`.

        The name is looked up, in order, as a domain function, operator,
        generator and regression rule. If none matches, the annotations decide
        whether a new predicate is declared on the fly (`action_impl`,
        `generator_impl`, `regression_impl`, `inplace_generator`).

        Raises:
            NotImplementedError: if the name refers to a `FancyGenerator`.
            KeyError: if the name is unknown and no creating annotation is set.
        """
        annotations = self.visit(annotations)
        args = self.visit(args)

        if annotations is None:
            annotations = dict()
        if args is None:
            args = ArgumentsList(tuple())

        if self.domain.has_function(name):
            function = self.domain.get_function(name)
            args_c = _canonize_arguments(args, function.ftype.argument_types)
            if _has_list_arguments(args_c):
                return E.ListFunctionApplicationExpression(function, args_c)
            return E.FunctionApplicationExpression(function, args_c)

        if self.domain.has_operator(name):
            operator = self.domain.get_operator(name)
            args_c = _canonize_arguments(args, operator.argument_types)
            if len(args_c) > 0 and args_c[-1] is Ellipsis:
                # A trailing `...` stands for all remaining operator arguments.
                args_c = args_c[:-1] + tuple([UnnamedPlaceholder(t) for t in operator.argument_types[len(args_c) - 1:]])
            return OperatorApplicationExpression(operator, args_c)

        if self.domain.has_generator(name):
            generator = self.domain.get_generator(name)
            args_c = _canonize_arguments(args, generator.argument_types)
            if isinstance(generator, FancyGenerator):
                raise NotImplementedError('Fancy generators are not supported in the current version.')
            return GeneratorApplicationExpression(generator, args_c)

        if self.domain.has_regression_rule(name):
            rule = self.domain.get_regression_rule(name)
            args_c = _canonize_arguments(args, rule.argument_types)
            return RegressionRuleApplicationExpression(rule, args_c)

        # Unknown name: declare a new predicate if an annotation asks for it.
        if annotations.get('action_impl'):
            args_c = _canonize_arguments(args)
            argument_types = [arg.return_type for arg in args_c]
            predicate = self.domain.define_predicate(name, argument_types, self.domain.get_type('__control__'), observation=False, state=False)
            return Implementation(predicate.name, args_c)

        if annotations.get('generator_impl'):
            args_c = _canonize_arguments(args)
            argument_types = [arg.return_type for arg in args_c]
            assert self.generator_impl_outputs is not None, f'Generator implementation {name} requires generator outputs to be set.'
            generator_outputs = TupleType([value.dtype for value in self.generator_impl_outputs])
            predicate = self.domain.define_predicate(name, argument_types, generator_outputs, observation=False, state=False, is_generator_function=True)
            return Implementation(predicate.name, args_c)

        if annotations.get('regression_impl'):
            args_c = _canonize_arguments(args)
            argument_types = [arg.return_type for arg in args_c]
            predicate = self.domain.define_predicate(name, argument_types, self.domain.get_type('__totally_ordered_plan__'), observation=False, state=False)
            return E.FunctionApplicationExpression(predicate, args_c)

        if annotations.get('inplace_generator'):
            args_c = _canonize_arguments(args)
            argument_types = [arg.return_type for arg in args_c]
            predicate = self.domain.define_predicate(
                name, argument_types, self.domain.get_type('bool'), observation=False, state=False,
                generator_placeholder=annotations.get('generator_placeholder', True)
            )
            assert 'inplace_generator_targets' in annotations, f'Inplace generator {name} requires inplace generator targets to be set.'
            inplace_generator_targets = annotations['inplace_generator_targets']

            generator_name = 'gen_' + name
            generator_arguments = predicate.arguments
            generator_goal = E.FunctionApplicationExpression(predicate, [E.VariableExpression(arg) for arg in generator_arguments])

            # Positions of the call arguments that are named as generator targets become outputs.
            output_argument_names = [x.value for x in inplace_generator_targets.items]
            output_indices = list()
            for i, arg in enumerate(args_c):
                if isinstance(arg, E.VariableExpression) and arg.name in output_argument_names:
                    output_argument_names.remove(arg.name)
                    output_indices.append(i)
            assert len(output_argument_names) == 0, f'Mismatched output arguments for inplace generator {name}: {output_argument_names}'

            inputs = [arg for i, arg in enumerate(generator_arguments) if i not in output_indices]
            outputs = [arg for i, arg in enumerate(generator_arguments) if i in output_indices]
            impl = Implementation(generator_name + '_impl', inputs)
            self.domain.define_generator(generator_name, generator_arguments, generator_goal, inputs, outputs, impl)
            return E.FunctionApplicationExpression(predicate, args_c)

        raise KeyError(f'Function {name} not found. Note that recursive function calls are not supported in the current version.')

    @inline_args
    def atom_subscript(self, name: str, index: Tree) -> Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression]:
        """Captures subscript expressions such as `name[index1, index2, ...]`.

        `name` must be a state-variable predicate. A single `...` index applies
        the predicate with no arguments.

        Raises:
            ValueError: if the predicate is not a state variable.
        """
        predicate = self.domain.get_predicate(name)
        index = self.visit(index)
        if not predicate.is_state_variable:
            raise ValueError(f'Invalid subscript expression: {name} is not a state variable. Expression: {name}[{index.items}]')

        items = index.items
        if len(items) == 1 and items[0] is Ellipsis:
            return E.FunctionApplicationExpression(predicate, tuple())

        arguments = _canonize_arguments(index.items, dtypes=predicate.ftype.argument_types)
        if _has_list_arguments(arguments):
            return E.ListFunctionApplicationExpression(predicate, arguments)
        return E.FunctionApplicationExpression(predicate, arguments)

    @inline_args
    def atom(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]:
        """Captures atoms: the base case of an expression (literal constants, variables, subscripts). Returned as-is."""
        return value

    def arguments(self, args: Tree) -> ArgumentsList:
        """Captures the argument list. This is used in function calls."""
        args = self.visit_children(args)
        return ArgumentsList(tuple(args))

    @inline_args
    def power(self, base: Union[FunctionCall, Variable], exp: Optional[float] = None) -> Union[FunctionCall, Variable]:
        """The highest-priority expression, `base ** exp`. Only the form without an exponent is supported."""
        if exp is None:
            return base
        raise NotImplementedError('Power expression is not supported in the current version.')

    @inline_args
    def factor(self, value: Union[FunctionCall, Variable]) -> Union[FunctionCall, Variable]:
        """Pass-through for a factor without a unary operator."""
        return value

    @inline_args
    def unary_op_expr(self, op: str, value: Union[FunctionCall, Variable]) -> FunctionCall:
        """Unary operators are rejected."""
        raise NotImplementedError('Unary operators are not supported in the current version.')

    mul_expr = gen_term_expr('mul')
    arith_expr = gen_term_expr('add')
    shift_expr = gen_term_expr('shift')
    bitand_expr = gen_term_expr_noop('bitand')
    bitxor_expr = gen_term_expr_noop('bitxor')
    bitor_expr = gen_term_expr_noop('bitor')

    @inline_args
    def comparison_expr(self, *values: Union[E.ValueOutputExpression, E.VariableExpression]) -> E.ValueOutputExpression:
        """Captures (possibly chained) comparisons ``a op b op c ...``.

        `values` alternates operands and operators. Each adjacent pair of operands
        becomes one `ObjectCompareExpression`; a chain is the conjunction of its pairs.

        Raises:
            NotImplementedError: if a compared pair is not object-typed on both
                sides. Such comparisons used to be dropped silently, which could
                produce an empty conjunction.
        """
        if len(values) == 1:
            return self.visit(values[0])

        assert len(values) % 2 == 1, f'[compare] expressions expected an odd number of values, got {len(values)}. Values: {values}.'
        values = [self.visit(value) for value in values]
        results = list()
        for i in range(1, len(values), 2):
            lhs, op, rhs = values[i - 1], values[i], values[i + 1]
            if not (lhs.return_type.is_object_type and rhs.return_type.is_object_type):
                raise NotImplementedError(f'Comparison between non-object values is not supported in the current version: {lhs} vs {rhs}.')
            results.append(E.ObjectCompareExpression(E.CompareOpType.from_string(op[0].value), lhs, rhs))

        if len(results) == 1:
            return results[0]
        return E.AndExpression(*results)

    @inline_args
    def not_test(self, value: Any) -> E.NotExpression:
        """Captures a negation."""
        return E.NotExpression(*_canonize_arguments([self.visit(value)]))

    @inline_args
    def and_test(self, *values: Any) -> E.AndExpression:
        """Captures a conjunction; a single operand is returned unchanged."""
        values = [self.visit(value) for value in values]
        if len(values) == 1:
            return values[0]
        return E.AndExpression(*_canonize_arguments(values))

    @inline_args
    def or_test(self, *values: Any) -> E.OrExpression:
        """Captures a disjunction; a single operand is returned unchanged."""
        values = [self.visit(value) for value in values]
        if len(values) == 1:
            return values[0]
        return E.OrExpression(*_canonize_arguments(values))

    @inline_args
    def cond_test(self, value1: Any, cond: Any, value2: Any) -> E.ConditionExpression:
        """Captures `value1 if cond else value2`. The condition is visited first."""
        return E.ConditionExpression(*_canonize_arguments([self.visit(cond), self.visit(value1), self.visit(value2)]))

    @inline_args
    def test(self, value: Any):
        """Pass-through wrapper rule."""
        return self.visit(value)

    @inline_args
    def test_nocond(self, value: Any):
        """Pass-through wrapper rule."""
        return self.visit(value)

    @inline_args
    def tuple(self, *values: Any):
        """Captures a tuple literal; returns a Python tuple of the visited items."""
        return tuple(self.visit(v) for v in values)

    @inline_args
    def list(self, *values: Any):
        """Captures a list literal as a `ListCreationExpression`."""
        return E.ListCreationExpression(_canonize_arguments([self.visit(v) for v in values]))

    @inline_args
    def cs_list(self, *values: Any):
        """Captures a comma-separated list as a `CSList` of visited items."""
        return CSList(tuple(self.visit(v) for v in values))

    @inline_args
    def suite(self, *values: Tree) -> Suite:
        """Captures a block of statements in its own local-variable scope; `None` results are dropped."""
        with self.local_variable_guard():
            values = [self.visit(value) for value in values]
        return Suite(tuple(v for v in values if v is not None))

    @inline_args
    def expr_stmt(self, value: Tree):
        """Captures an expression statement. A bare `...` yields no statement."""
        value = self.visit(value)
        if value is Ellipsis:
            return None
        return FunctionCall('expr', ArgumentsList((_canonize_single_argument(value),)))

    @inline_args
    def expr_list_expansion_stmt(self, value: Any):
        """Captures an expression statement wrapped in a `ListExpansionExpression`."""
        value = _canonize_single_argument(self.visit(value))
        return FunctionCall('expr', ArgumentsList((E.ListExpansionExpression(value), )))

    @inline_args
    def assign_stmt(self, target: Any, value: Any):
        """Captures `target = value`. The target is visited before the value."""
        target = _canonize_single_argument(self.visit(target))
        value = _canonize_single_argument(self.visit(value))
        return FunctionCall('assign', ArgumentsList((target, value)))

    @inline_args
    def annotated_assign_stmt(self, annotations: dict, target: Any, value: Any):
        """Captures an assignment carrying annotations; they are attached to the resulting `FunctionCall`."""
        target = _canonize_single_argument(self.visit(target))
        value = _canonize_single_argument(self.visit(value))
        return FunctionCall('assign', ArgumentsList((target, value)), annotations)

    @inline_args
    def local_assign_stmt(self, target: str, value: Any = None):
        """Captures the definition of a local variable with an initial value.

        The variable's type is taken from the value's return type, and the
        variable is registered in `local_variables` for the current scope.

        Raises:
            ValueError: if no initial value is given; the type cannot be inferred then.
        """
        assert isinstance(target, str), f'Invalid local variable name: {target}'
        if value is None:
            # Previously this failed inside `_canonize_single_argument` with the
            # less helpful message "Invalid argument: None".
            raise ValueError(f'Local variable {target} is declared without an initial value; its type cannot be inferred.')
        value = _canonize_single_argument(self.visit(value))
        self.local_variables[target] = E.VariableExpression(Variable(target, value.return_type))
        return FunctionCall('assign', ArgumentsList((self.local_variables[target], value)))

    @inline_args
    def pass_stmt(self):
        """Captures `pass`."""
        return FunctionCall('pass', ArgumentsList(tuple()))

    @inline_args
    def check_stmt(self, value: Any):
        """Captures a check statement."""
        return FunctionCall('check', ArgumentsList((_canonize_single_argument(self.visit(value)), )))

    @inline_args
    def return_stmt(self, value: Any):
        """Captures a return statement."""
        return FunctionCall('return', ArgumentsList((_canonize_single_argument(self.visit(value)), )))

    @inline_args
    def achieve_stmt(self, value: CSList):
        """Captures an achieve statement over a comma-separated list of goals."""
        return FunctionCall('achieve', ArgumentsList(_canonize_arguments(self.visit(value).items)))

    @inline_args
    def compound_assign_stmt(self, target: Variable, value: Any):
        """Compound form of an assignment; unlike `assign_stmt`, the visited target is not canonized."""
        return FunctionCall('assign', ArgumentsList((self.visit(target), _canonize_single_argument(self.visit(value)))))

    @inline_args
    def compound_check_stmt(self, value: Any):
        """Compound form of a check statement."""
        return FunctionCall('check', ArgumentsList((_canonize_single_argument(self.visit(value)), )))

    @inline_args
    def compound_return_stmt(self, value: Any):
        """Compound form of a return statement."""
        return FunctionCall('return', ArgumentsList((_canonize_single_argument(self.visit(value)), )))

    @inline_args
    def compound_achieve_stmt(self, value: Any):
        """Compound form of an achieve statement with a single goal."""
        return FunctionCall('achieve', ArgumentsList(_canonize_arguments([self.visit(value)])))

    @inline_args
    def if_stmt(self, cond: Any, suite: Any, else_suite: Optional[Any] = None):
        """Captures `if cond: suite [else: else_suite]`.

        Both branches are visited in their own local-variable scope. A missing
        else-branch is represented by a suite holding a single `pass`.
        """
        cond = _canonize_single_argument(self.visit(cond))
        with self.local_variable_guard():
            suite = self.visit(suite)
        if else_suite is None:
            else_suite = Suite((FunctionCall('pass', ArgumentsList(tuple())), ))
        else:
            with self.local_variable_guard():
                else_suite = self.visit(else_suite)
        return FunctionCall('if', ArgumentsList((cond, suite, else_suite)))

    @inline_args
    def forall_stmt(self, cs_list: Any, suite: Any):
        """Captures a forall statement: the body is visited with the quantified variables in scope."""
        cs_list = self.visit(cs_list)
        with self.expression_def_ctx.new_variables(*cs_list.items), self.local_variable_guard():
            suite = self.visit(suite)
        return FunctionCall('forall', ArgumentsList((cs_list, suite)))

    @inline_args
    def forall_in_stmt(self, variables_cs_list: Any, values_cs_list: Any, suite: Any):
        """Captures a forall-in statement.

        Each variable is bound, as a local variable, to the value at the same
        position while the body is visited. Only the body is kept in the result.
        """
        variables_cs_list = self.visit(variables_cs_list)
        values_cs_list = self.visit(values_cs_list)
        with self.local_variable_guard():
            for variable_item, value_item in zip(variables_cs_list.items, values_cs_list.items):
                self.local_variables[variable_item] = value_item
            suite = self.visit(suite)
        return FunctionCall('forall_in', ArgumentsList((suite, )))

    def _quantification_expression(self, cs_list: Any, suite: Any, quantification_cls):
        """Build a nested quantified expression; the first variable in `cs_list` ends up outermost."""
        cs_list = self.visit(cs_list)
        with self.expression_def_ctx.new_variables(*cs_list.items), self.local_variable_guard():
            suite = self.visit(suite)
        body = suite.get_combined_return_expression(allow_expr_expressions=True)
        for item in reversed(cs_list.items):
            body = quantification_cls(item, body)
        return body

    @inline_args
    def forall_test(self, cs_list: Any, suite: Any):
        """Captures a forall expression."""
        return self._quantification_expression(cs_list, suite, E.ForallExpression)

    @inline_args
    def exists_test(self, cs_list: Any, suite: Any):
        """Captures an exists expression."""
        return self._quantification_expression(cs_list, suite, E.ExistsExpression)

    @inline_args
    def findall_test(self, variable: Variable, suite: Any):
        """Captures a findall expression over a single variable."""
        with self.expression_def_ctx.new_variables(variable), self.local_variable_guard():
            suite = self.visit(suite)
        body = suite.get_combined_return_expression(allow_expr_expressions=True)
        return E.FindAllExpression(variable, body)

    def _quantification_in_expression(self, cs_list: Any, suite: Any, quantification_cls):
        """Shared implementation of the `*_in` tests.

        Each item's name is bound, as a local variable, to its visited value
        while the body is visited; the combined body is wrapped in `quantification_cls`.
        """
        cs_list = self.visit(cs_list)
        with self.local_variable_guard():
            item: InTypedArgument
            for item in cs_list.items:
                self.local_variables[item.name] = self.visit(item.value)
            suite = self.visit(suite)
        body = suite.get_combined_return_expression(allow_expr_expressions=True)
        return quantification_cls(body)

    @inline_args
    def forall_in_test(self, cs_list: Any, suite: Any):
        """Captures a forall-in expression (wrapped in an AND expression)."""
        return self._quantification_in_expression(cs_list, suite, E.AndExpression)

    @inline_args
    def exists_in_test(self, cs_list: Any, suite: Any):
        """Captures an exists-in expression (wrapped in an OR expression)."""
        return self._quantification_in_expression(cs_list, suite, E.OrExpression)

    @inline_args
    def find_stmt(self, cs_list: Any, suite: Any):
        """Captures a find statement.

        After the body has been visited in its own scope, the found variables are
        registered as local variables so that following statements can refer to them.
        """
        cs_list = self.visit(cs_list)
        with self.expression_def_ctx.new_variables(*cs_list.items), self.local_variable_guard():
            suite = self.visit(suite)
            # The result is not used here; the call is kept in case it validates the suite.
            suite.get_combined_return_expression()
        for item in cs_list.items:
            self.local_variables[item.name] = E.VariableExpression(item)
        return FunctionCall('find', ArgumentsList((cs_list, suite)))

    @inline_args
    def annotated_compound_stmt(self, annotations: dict, stmt: Any):
        """Visit a compound statement and attach `annotations` to the result."""
        stmt = self.visit(stmt)
        stmt.annotations = annotations
        return stmt
@dataclass
class ArgumentsDef:
    """The declared arguments of a definition (function, action, generator, regression rule)."""

    # The argument variables, in declaration order.
    arguments: Tuple[Variable, ...]
@dataclass
class PreconditionPart:
    """The precondition section of a definition, kept as an un-interpreted parse tree."""

    # The parse tree of the section body; interpreted later by the enclosing definition.
    suite: Tree
@dataclass
class EffectPart:
    """The effect section of an action definition, kept as an un-interpreted parse tree."""

    # The parse tree of the section body; interpreted later by the enclosing definition.
    suite: Tree
@dataclass
class GoalPart:
    """The goal section of a generator or regression rule, kept as an un-interpreted parse tree."""

    # The parse tree of the section body; interpreted later by the enclosing definition.
    suite: Tree
@dataclass
class BodyPart:
    """The body section of a regression rule, kept as an un-interpreted parse tree."""

    # The parse tree of the section body; interpreted later by the enclosing definition.
    suite: Tree
@dataclass
class SideEffectPart:
    """The side-effect section of a regression rule, kept as an un-interpreted parse tree."""

    # The parse tree of the section body; interpreted later by the enclosing definition.
    suite: Tree
@dataclass
class ImplPart:
    """The implementation section of an action or generator, kept as an un-interpreted parse tree."""

    # The parse tree of the section body; interpreted later by the enclosing definition.
    suite: Tree
@dataclass
class InPart:
    """The input section of a generator definition, kept as an un-interpreted parse tree."""

    # The parse tree of the declared input values; interpreted later by the generator definition.
    suite: Tree
@dataclass
class OutPart:
    """The output section of a generator definition, kept as an un-interpreted parse tree."""

    # The parse tree of the declared output values; interpreted later by the generator definition.
    suite: Tree
[docs]class PDSketchV3DomainTransformer(PDSketch3LiteralTransformer):
[docs] def __init__(self, domain: Optional[Domain] = None): super().__init__() self.domain = Domain(pdsketch_version=3) if domain is None else domain self.expression_def_ctx = E.ExpressionDefinitionContext(domain=self.domain) self.expression_interpreter = PDSketch3ExpressionInterpreter(domain=self.domain, state=None, expression_def_ctx=self.expression_def_ctx, auto_constant_guess=False)
domain: Domain expression_def_ctx: E.ExpressionDefinitionContext expression_interpreter: PDSketch3ExpressionInterpreter
[docs] @inline_args def pragma_definition(self, pragma: Dict[str, Any]): print('pragma_definition', pragma)
[docs] @inline_args def type_definition(self, typename, basetype: Optional[Union[str, TypeBase]]): print(f'type_definition:: {typename=} {basetype=}') self.domain.define_type(typename, basetype)
[docs] @inline_args def derived_feature_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Optional[Tree]): if annotations is None: annotations = dict() if args is None: args = ArgumentsDef(tuple()) if ret is None: ret = self.domain.get_type('bool') elif isinstance(ret, str): ret = self.domain.get_type(ret) return_stmt = None if suite is not None: with self.expression_def_ctx.with_variables(*args.arguments): suite = self.expression_interpreter.visit(suite) return_stmt = suite.get_combined_return_expression(allow_expr_expressions=False) if return_stmt is None: self.domain.define_predicate(name, args.arguments, ret, **annotations) else: self.domain.define_derived(name, args.arguments, ret, return_stmt, **annotations) print(f'derived_feature_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}') if return_stmt is not None: print(jacinle.indent_text(f'Return statement: {return_stmt}'))
[docs] @inline_args def derived_function_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], ret: Optional[Union[str, TypeBase]], suite: Tree): if annotations is None: annotations = dict() if args is None: args = ArgumentsDef(tuple()) if ret is None: ret = self.domain.get_type('bool') elif isinstance(ret, str): ret = self.domain.get_type(ret) with self.expression_def_ctx.with_variables(*args.arguments): suite = self.expression_interpreter.visit(suite) return_stmt = suite.get_combined_return_expression(allow_expr_expressions=False) if return_stmt is None: self.domain.define_predicate(name, args.arguments, ret, state=False, observation=False, **annotations) else: self.domain.define_derived(name, args.arguments, ret, return_stmt, state=False, **annotations) print(f'derived_function_definition:: {name=} {args.arguments=} {ret=} {annotations=} {suite=}') if return_stmt is not None: print(jacinle.indent_text(f'Return statement: {return_stmt}'))
[docs] @inline_args def action_precondition_definition(self, suite: Tree) -> PreconditionPart: return PreconditionPart(suite)
[docs] @inline_args def action_effect_definition(self, suite: Tree) -> EffectPart: return EffectPart(suite)
[docs] @inline_args def action_impl_definition(self, suite: Tree) -> ImplPart: return ImplPart(suite)
[docs] @inline_args def action_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, EffectPart, ImplPart]): if annotations is None: annotations = dict() if args is None: args = ArgumentsDef(tuple()) print(f'action_definition:: {name=} {args.arguments=} {annotations=}') precondition = list() effect = list() implementation = None for part in parts: with self.expression_def_ctx.with_variables(*args.arguments): suite = self.expression_interpreter.visit(part.suite) if isinstance(part, PreconditionPart): suite = suite.get_all_check_expressions() precondition = [Precondition(x) for x in suite] print(jacinle.indent_text(f'Precondition: {precondition}')) elif isinstance(part, EffectPart): suite = suite.get_all_assign_expressions() effect = [Effect(x, **a) for x, a in suite] print(jacinle.indent_text(f'Effect: {effect}')) elif isinstance(part, ImplPart): # TODO(Jiayuan Mao @ 2024/03/2): For now we just allow a single expression. suite = suite.get_all_expr_expression(allow_multiple_expressions=False) implementation = suite print(jacinle.indent_text(f'Implementation: {implementation}')) self.domain.define_operator(name, args.arguments, precondition, effect, controller=implementation, **annotations)
[docs] @inline_args def generator_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart, InPart, OutPart]): if annotations is None: annotations = dict() if args is None: args = ArgumentsDef(tuple()) print(f'generator_definition:: {name=} {args.arguments=} {annotations=}') inputs = list() outputs = list() goal = None implementation = None with self.expression_def_ctx.with_variables(*args.arguments): for part in parts: if isinstance(part, ImplPart): continue suite = self.expression_interpreter.visit(part.suite) if isinstance(part, GoalPart): goal = suite.get_all_expr_expression(allow_multiple_expressions=False) print(jacinle.indent_text(f'Goal: {goal}')) elif isinstance(part, InPart): inputs = [E.VariableExpression(x) for x in suite.items] print(jacinle.indent_text(f'Inputs: {inputs}')) elif isinstance(part, OutPart): outputs = [E.VariableExpression(x) for x in suite.items] print(jacinle.indent_text(f'Outputs: {outputs}')) else: raise ValueError(f'Invalid part: {part}') for part in parts: if isinstance(part, ImplPart): with self.expression_interpreter.set_generator_impl_outputs(outputs): suite = self.expression_interpreter.visit(part.suite) implementation = suite.get_all_expr_expression(allow_multiple_expressions=False) print(jacinle.indent_text(f'Implementation: {implementation}')) self.domain.define_generator(name, args.arguments, goal, inputs, outputs, implementation=implementation, **annotations)
[docs] @inline_args def undirected_generator_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart]): if annotations is None: annotations = dict() if args is None: args = ArgumentsDef(tuple()) print(f'undirected_generator_definition:: {name=} {args.arguments=} {annotations=}') # TODO(Jiayuan Mao @ 2024/03/10): Implement the undirected generator definition. for part in parts: with self.expression_def_ctx.with_variables(*args.arguments), self.expression_interpreter.set_generator_impl_outputs([]): suite = self.expression_interpreter.visit(part.suite) print(jacinle.indent_text(f'{part.__class__.__name__}: ' + str(suite)))
[docs] @inline_args def generator_precondition_definition(self, suite: Tree) -> PreconditionPart: return PreconditionPart(suite)
[docs] @inline_args def generator_goal_definition(self, suite: Tree) -> GoalPart: return GoalPart(suite)
[docs] @inline_args def generator_in_definition(self, values: Tree) -> InPart: return InPart(values)
[docs] @inline_args def generator_out_definition(self, values: Tree) -> OutPart: return OutPart(values)
[docs] @inline_args def generator_impl_definition(self, suite: Tree) -> ImplPart: return ImplPart(suite)
[docs] @inline_args def regression_rule_definition(self, annotations: Optional[dict], name: str, args: Optional[ArgumentsDef], *parts: Union[PreconditionPart, GoalPart, BodyPart]): if annotations is None: annotations = dict() if args is None: args = ArgumentsDef(tuple()) print(f'regression_rule_definition:: {name=} {args.arguments=} {annotations=}') precondition = list() goal = None body = list() side_effect = list() for part in parts: with self.expression_def_ctx.with_variables(*args.arguments): suite = self.expression_interpreter.visit(part.suite) if isinstance(part, BodyPart): print(jacinle.indent_text(f'Body: {suite}')) if isinstance(part, PreconditionPart): suite = suite.get_all_check_expressions() precondition = [Precondition(x) for x in suite] print(jacinle.indent_text(f'Precondition: {precondition}')) elif isinstance(part, GoalPart): if len(suite.items) != 1: raise NotImplementedError('Multiple goals are not supported in the current version.') goal = suite.items[0] print(jacinle.indent_text(f'Goal: {goal}')) elif isinstance(part, BodyPart): body = suite.get_all_regression_expressions(use_runtime_assign=True) print(jacinle.indent_text(f'Body:')) for x in body: print(jacinle.indent_text(str(x), level=2)) elif isinstance(part, SideEffectPart): suite = suite.get_all_assign_expressions() side_effect = [Effect(x, **a) for x, a in suite] print(jacinle.indent_text(f'SideEffect: {side_effect}')) else: raise ValueError(f'Invalid part: {part}') self.domain.define_regression_rule(name, args.arguments, precondition, goal, side_effect, body, **annotations)
@inline_args
def regression_rule_precondition_definition(self, suite: Tree) -> PreconditionPart:
    """Wrap the suite of a regression rule precondition definition.

    Args:
        suite: the suite tree of the definition; it is stored in the wrapper unchanged.

    Returns:
        a ``PreconditionPart`` constructed from ``suite``.
    """
    precondition_part = PreconditionPart(suite)
    return precondition_part
@inline_args
def regression_rule_goal_definition(self, suite: Tree) -> GoalPart:
    """Wrap the suite of a regression rule goal definition.

    Args:
        suite: the suite tree of the definition; it is stored in the wrapper unchanged.

    Returns:
        a ``GoalPart`` constructed from ``suite``.
    """
    goal_part = GoalPart(suite)
    return goal_part
@inline_args
def regression_rule_body_definition(self, suite: Tree) -> BodyPart:
    """Wrap the suite of a regression rule body definition.

    Args:
        suite: the suite tree of the definition; it is stored in the wrapper unchanged.

    Returns:
        a ``BodyPart`` constructed from ``suite``.
    """
    body_part = BodyPart(suite)
    return body_part
@inline_args
def regression_rule_side_effect_definition(self, suite: Tree) -> SideEffectPart:
    """Wrap the suite of a regression rule side-effect definition.

    Args:
        suite: the suite tree of the definition; it is stored in the wrapper unchanged.

    Returns:
        a ``SideEffectPart`` constructed from ``suite``.
    """
    side_effect_part = SideEffectPart(suite)
    return side_effect_part
class PDSketchV3ProblemTransformer(PDSketch3LiteralTransformer):
    """Transformer that fills a :class:`Problem3` from the definitions of a problem file.

    The domain can be supplied at construction time or loaded in place by :meth:`domain_def`.
    Once a domain is set, the transformer owns a ``Problem3``, an expression definition
    context and an expression interpreter bound to that domain.
    """

    # The domain the problem is defined in; None until a domain has been initialized.
    domain: Optional[Domain]
    # The state handed to the expression interpreter together with the domain, if any.
    state: Optional[State]

    def __init__(self, domain: Optional[Domain] = None, state: Optional[State] = None, auto_constant_guess: bool = False):
        """Initialize the transformer.

        Args:
            domain: the domain object. If None, the domain has to be loaded later by :meth:`domain_def`.
            state: an optional state that is passed to the expression interpreter. It is only used when
                ``domain`` is provided.
            auto_constant_guess: forwarded to the expression interpreter.
        """
        super().__init__()

        self.domain = None
        self.state = None
        self.problem = None
        self.expression_def_ctx = None
        self.expression_interpreter = None
        self.auto_constant_guess = auto_constant_guess

        if domain is not None:
            self._init_domain(domain, state)

    def _init_domain(self, domain: Domain, state: Optional[State] = None):
        """Bind the transformer to a domain and create the problem, context and interpreter.

        Args:
            domain: the domain object.
            state: an optional state that is passed to the expression interpreter.

        Raises:
            ValueError: if a domain has already been initialized.
        """
        if self.domain is not None:
            raise ValueError('Domain is already initialized. Cannot overwrite the domain.')

        self.domain = domain
        # Record the state next to the domain; previously the attribute was declared but never assigned.
        self.state = state
        self.problem = Problem3(domain=self.domain)
        self.expression_def_ctx = E.ExpressionDefinitionContext(domain=self.domain)
        self.expression_interpreter = PDSketch3ExpressionInterpreter(
            domain=self.domain, state=state,
            expression_def_ctx=self.expression_def_ctx,
            auto_constant_guess=self.auto_constant_guess
        )

    @inline_args
    def domain_def(self, filename: str):
        """Load the domain named in the problem definition, unless a domain is already set.

        Args:
            filename: the filename of the domain file, parsed with the module-level parser.
        """
        if self.domain is not None:
            logger.warning('Domain is already initialized. Skip the in-place domain loading.')
            return

        domain = _parser.parse_domain(filename)
        self._init_domain(domain)

    @inline_args
    def objects_definition(self, *objects):
        """Add the declared objects to the problem and expose them to the expression interpreter.

        Args:
            *objects: the declared objects; each one is read through its ``name`` and ``dtype`` attributes.
        """
        for o in objects:
            self.problem.add_object(o.name, o.dtype.typename)
            # Register the object as a local name so that later suites can refer to it.
            self.expression_interpreter.local_variables[o.name] = E.ObjectConstantExpression(ObjectConstant(o.name, o.dtype))

    @inline_args
    def init_definition(self, suite: Tree):
        """Initialize the problem state and execute the statements of the init suite on it.

        Args:
            suite: the suite tree of the init definition.
        """
        self.problem.init_state()
        suite = self.expression_interpreter.visit(suite)

        executor = PDSketchExecutor(self.domain)
        for stmt, _ in suite.get_all_assign_expressions():
            executor.execute(stmt, state=self.problem.state)
        for stmt in suite.get_all_expr_expression(allow_multiple_expressions=True):
            # A bare function application is executed as an assignment of TRUE;
            # other bare expressions are skipped.
            if isinstance(stmt, E.FunctionApplicationExpression):
                executor.execute(E.AssignExpression(stmt, E.ConstantExpression.TRUE), state=self.problem.state)

    @inline_args
    def goal_definition(self, suite: Tree):
        """Interpret the goal suite and set its expression as the goal of the problem.

        Args:
            suite: the suite tree of the goal definition.
        """
        suite = self.expression_interpreter.visit(suite)
        suite = suite.get_all_expr_expression(allow_multiple_expressions=False)
        self.problem.set_goal(suite)
# Module-level parser instance shared by the loader functions below and by
# PDSketchV3ProblemTransformer.domain_def.
_parser = PDSketchV3Parser()
def load_domain_file3(filename: str) -> Domain:
    """Parse a domain file with the module-level parser.

    Args:
        filename: the filename of the domain file.

    Returns:
        the domain object produced by the parser.
    """
    domain = _parser.parse_domain(filename)
    return domain
def load_domain_string3(string: str) -> Domain:
    """Parse a domain definition given as a string with the module-level parser.

    Args:
        string: the string containing the domain definition.

    Returns:
        the domain object produced by the parser.
    """
    domain = _parser.parse_domain_str(string)
    return domain
def load_domain_string3_incremental(domain: Domain, string: str) -> Domain:
    """Parse a domain definition string against an existing domain object.

    Args:
        domain: the domain object to be updated; it is handed to the parser as ``domain``.
        string: the string containing the domain definition.

    Returns:
        the domain object returned by the parser.
    """
    updated_domain = _parser.parse_domain_str(string, domain=domain)
    return updated_domain
def load_problem_file3(filename: str, domain: Optional[Domain] = None) -> Problem3:
    """Parse a problem file with the module-level parser.

    Args:
        filename: the filename of the problem file.
        domain: the domain object. If not provided, the domain will be loaded from the domain file
            specified in the problem file.

    Returns:
        the problem object produced by the parser.
    """
    problem = _parser.parse_problem(filename, domain=domain)
    return problem
def load_problem_string3(string: str, domain: Optional[Domain] = None) -> Problem3:
    """Parse a problem definition given as a string with the module-level parser.

    Args:
        string: the string containing the problem definition.
        domain: the domain object. If not provided, the domain will be loaded from the domain file
            specified in the problem file.

    Returns:
        the problem object produced by the parser.
    """
    problem = _parser.parse_problem_str(string, domain=domain)
    return problem
def parse_expression3(
    domain: Domain,
    string: str,
    state: Optional[State] = None,
    variables: Optional[Sequence[Variable]] = None,
    auto_constant_guess: bool = True
) -> E.Expression:
    """Parse a single expression string with the module-level parser.

    Args:
        domain: the domain object.
        string: the string containing the expression.
        state: the current state, containing objects.
        variables: the variables.
        auto_constant_guess: whether to guess whether a variable is a constant.

    Returns:
        the parsed expression.
    """
    expression = _parser.parse_expression(
        string, domain,
        state=state,
        variables=variables,
        auto_constant_guess=auto_constant_guess
    )
    return expression