Source code for concepts.dsl.expression

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : expression.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 10/19/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""Data structures for expressions in a DSL.

All classes extend the basic :class:`Expression` class. They can be categorized into the following groups:

- :class:`ObjectOrValueOutputExpression` is the base class for expressions that output objects or values. It is mainly used for type hinting, although :class:`NullExpression` and :class:`VariableExpression` derive from it directly.
- :class:`ObjectOutputExpression` and :class:`ValueOutputExpression` are the expressions that output objects or values.
- :class:`VariableExpression` which is the expression that refers to a variable.
- :class:`VariableAssignmentExpression` which assigns a value to a variable.

Under the :class:`ValueOutputExpression` category, there are a few sub-categories:

- :class:`NullExpression` which is the expression that outputs a null value.
- :class:`ConstantExpression` which is the expression that outputs a constant value.
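- :class:`ListCreationExpression` which is the expression that creates a list (note: it derives directly from :class:`Expression`, not from :class:`ValueOutputExpression`).
- :class:`ListExpansionExpression` which is the expression that expands a list into a sequence of values (e.g., plan steps). Like :class:`ListCreationExpression`, it derives directly from :class:`Expression`.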
- :class:`FunctionApplicationExpression` which represents the application of a function.
- :class:`ListFunctionApplicationExpression` which represents the application of a function to a list of arguments.
- :class:`BoolExpression` which represents Boolean operations (and, or, not).
- :class:`QuantificationExpression` which represents quantification (forall, exists).
- :class:`GeneralizedQuantificationExpression` which represents generalized quantification (iota, all, counting quantifiers).
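- :class:`FindOneExpression` which represents the find-one quantification (it outputs an object and is implemented as an :class:`ObjectOutputExpression`).
- :class:`FindAllExpression` which represents the find-all quantification (it outputs a list of objects and is implemented as an :class:`ObjectOutputExpression`).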
- :class:`PredicateEqualExpression` which represents the equality test between a state variable and a value.
- :class:`ObjectCompareExpression` which represents the comparison between two objects.
- :class:`ValueCompareExpression` which represents the comparison between two values.
- :class:`ConditionExpression` which represents the ternary conditional expression.
- :class:`ConditionalSelectExpression` which represents the conditional selection for some computed value.
- :class:`DeicticSelectExpression` which represents the deictic selection for some computed value (i.e., forall quantifiers).

Under the :class:`ObjectOutputExpression` category, there are a few sub-categories:

- :class:`ObjectConstantExpression` which is the expression that outputs a constant object.

Under the :class:`VariableAssignmentExpression` category, there are a few sub-categories:

- :class:`AssignExpression` which is the expression that assigns a value to a state variable.
- :class:`ConditionalAssignExpression` which is the expression that assigns a value to a state variable conditionally.
- :class:`DeicticAssignExpression` which is the expression that assigns values to state variables with deictic expressions (i.e., forall quantifiers).

The most important classes are: :class:`VariableExpression`, :class:`ObjectConstantExpression`, :class:`ConstantExpression`, and :class:`FunctionApplicationExpression`.
"""

import contextlib
from abc import ABC, abstractmethod
from typing import Any, Optional, Union, Iterable, Tuple, Sequence, List, Dict, Callable
from functools import lru_cache

import torch
import jacinle

from jacinle.utils.enum import JacEnum
from jacinle.utils.printing import indent_text
from jacinle.utils.defaults import wrap_custom_as_default, gen_get_default

from concepts.dsl.dsl_types import FormatContext, get_format_context
from concepts.dsl.dsl_types import TypeBase, ObjectType, ValueType, SequenceType, ListType, AutoType, TensorValueTypeBase, PyObjValueType, BOOL, FLOAT32, INT64, ObjectConstant, Variable
from concepts.dsl.dsl_functions import FunctionType, Function, FunctionArgumentUnset, AnonymousFunctionArgumentGenerator
from concepts.dsl.dsl_domain import DSLDomainBase
from concepts.dsl.value import ValueBase, Value, ListValue
from concepts.dsl.tensor_value import TensorValue

try:
    from typing import TypeGuard
except ImportError:
    # ``typing.TypeGuard`` only exists on Python 3.10+. On older interpreters, install a tiny stand-in whose
    # subscription ``TypeGuard[X]`` simply evaluates to ``bool`` so that annotations keep working.
    class _DummyTypeGuard:
        """Subscriptable placeholder for ``typing.TypeGuard``."""

        def __getitem__(self, _item):
            return bool

    TypeGuard = _DummyTypeGuard()


# Public API of this module. Every name must be unique. Note that ``from ... import *`` fails if a listed name
# is not defined in the module.
__all__ = [
    'DSLExpressionError', 'Expression', 'ExpressionDefinitionContext', 'get_expression_definition_context',
    'ObjectOrValueOutputExpression', 'ObjectOutputExpression', 'ValueOutputExpression', 'NullExpression', 'VariableExpression', 'VariableAssignmentExpression',
    'ObjectConstantExpression', 'ConstantExpression', 'ListCreationExpression',
    'FunctionApplicationError', 'FunctionApplicationExpression', 'ListFunctionApplicationExpression', 'ListExpansionExpression',
    'ConditionalSelectExpression', 'DeicticSelectExpression',
    'BoolOpType', 'BoolExpression', 'AndExpression', 'OrExpression', 'NotExpression', 'XorExpression', 'ImpliesExpression',
    'QuantificationOpType', 'QuantificationExpression', 'GeneralizedQuantificationExpression', 'ForallExpression', 'ExistsExpression', 'FindOneExpression', 'FindAllExpression',
    'CompareOpType', 'ObjectCompareExpression', 'ValueCompareExpression', 'ConditionExpression',
    'PredicateEqualExpression', 'AssignExpression', 'ConditionalAssignExpression', 'DeicticAssignExpression',
    'cvt_expression', 'cvt_expression_list', 'get_type', 'get_types',
    # NOTE(review): check whether an ``is_value_output_expression`` helper is defined further down in this module;
    # if so, it should be exported here as well.
    'is_object_output_expression', 'is_variable_assignment_expression',
    'is_and_expr', 'is_or_expr', 'is_not_expr', 'is_xor_expr', 'is_implies_expr', 'is_forall_expr', 'is_exists_expr',
    'iter_exprs', 'find_free_variables'
]


class DSLExpressionError(Exception):
    """Generic exception type for errors related to DSL expressions."""
class Expression(ABC):
    """Abstract base class of every expression in the DSL.

    An expression only describes *what* is computed, never *how*. For example, the expression `and(x, y, z)` carries
    no information about how the conjunction is evaluated (e.g., as the product of the three values). Providing the
    actual implementation is the job of the `Executor` classes.
    """

    @property
    @abstractmethod
    def return_type(self) -> Optional[Union[ObjectType, ValueType, FunctionType, SequenceType]]:
        """The type produced by this expression (may be ``None``, e.g. for variable assignments)."""
        raise NotImplementedError()

    def check_arguments(self):
        """Validate the arguments by calling :meth:`_check_arguments`.

        The check is skipped only when an expression definition context is active and its ``check_arguments``
        option is turned off.
        """
        ctx = get_expression_definition_context()
        if ctx is not None and not ctx.check_arguments:
            return
        self._check_arguments()

    def _check_arguments(self):
        """Hook for subclasses to validate their arguments. The base implementation accepts everything."""

    @abstractmethod
    def __str__(self) -> str:
        raise NotImplementedError()

    __repr__ = jacinle.repr_from_str

    @lru_cache(maxsize=10)
    def cached_string(self, max_length: Optional[int] = None):
        """Return ``str(self)``, rendered under ``FormatContext(expr_max_length=max_length)`` when a length is given.

        NOTE(review): the ``lru_cache`` decorates the function itself, so its 10 slots are shared by *all* expressions.
        It keeps strong references to the cached ``self`` objects. A cached result may have been rendered under a
        different ambient format context than the one active at call time.
        """
        if max_length is not None:
            with FormatContext(expr_max_length=max_length).as_default():
                return str(self)
        return str(self)
class ExpressionDefinitionContext(object):
    """Context object used while constructing expressions.

    It tracks the following:

    - the variables currently in scope (by name);
    - the domain and the definition scope;
    - a generator for anonymous argument names;
    - a stack of "is effect definition" flags;
    - a couple of construction options.

    Install it with :meth:`as_default`. :meth:`Expression.check_arguments` consults the active context (if any) to
    decide whether argument checking is enabled.

    All context managers of this class restore the previous state on exit, including when the body raises.
    """

    # Names of the options that can be temporarily overridden with :meth:`options`.
    OPTION_NAMES = ['allow_auto_predicate_def', 'check_arguments']

    variables: List[Variable]
    """The list of variables."""

    variable_name2obj: Dict[str, Variable]
    """The mapping from variable names to variables."""

    domain: Optional['DSLDomainBase']
    """The domain of the expression."""

    scope: Optional[str]
    """The current definition scope (e.g., in a function). This variable will be used to generate unique names for the functions."""

    anonymous_argument_generator: AnonymousFunctionArgumentGenerator
    """The anonymous argument generator."""

    is_effect_definition_stack: List[bool]
    """Whether the expression is defined in an effect of an operator."""

    slot_functions_are_sgc: bool
    """Whether the slot functions are SGC functions (state-goal-constraints functions)."""

    allow_auto_predicate_def: bool
    """Whether to enable automatic predicate definition."""

    check_arguments: bool
    """Whether to check the arguments of the functions."""

    def __init__(
        self, *variables: Variable,
        domain: Optional['DSLDomainBase'] = None,
        scope: Optional[str] = None,
        is_effect_definition: bool = False,
        slot_functions_are_sgc: bool = False,
        allow_auto_predicate_def: bool = True,
        check_arguments: bool = True,
    ):
        """Initialize the context.

        Args:
            variables: The variables that are available in the expression.
            domain: the domain of the expression.
            scope: the current definition scope (e.g., in a function). This variable will be used to generate unique names for the functions.
            is_effect_definition: whether the expression is defined in an effect of an operator.
            slot_functions_are_sgc: whether the slot functions are SGC functions (state-goal-constraints functions).
            allow_auto_predicate_def: whether to enable automatic predicate definition.
            check_arguments: whether to check the arguments of the functions.
        """
        self.variables = list(variables)
        self.variable_name2obj = {v.name: v for v in self.variables}
        self.domain = domain
        self.scope = scope
        self.anonymous_argument_generator = AnonymousFunctionArgumentGenerator()
        self.is_effect_definition_stack = [is_effect_definition]
        self.slot_functions_are_sgc = slot_functions_are_sgc
        self.allow_auto_predicate_def = allow_auto_predicate_def
        self.check_arguments = check_arguments

    @wrap_custom_as_default
    def as_default(self):
        yield self

    def has_variable(self: 'ExpressionDefinitionContext', variable: Union[str, Variable]) -> bool:
        """Whether a variable (given by object or by name) is registered by name in this context."""
        if isinstance(variable, Variable):
            return variable.name in self.variable_name2obj
        return variable in self.variable_name2obj

    def get_variable(self, variable: Union[str, Variable]) -> Variable:
        """Resolve a variable name into the registered :class:`Variable`; a :class:`Variable` is returned unchanged.

        Raises:
            ValueError: if the name is not registered in this context.
        """
        if isinstance(variable, Variable):
            return variable
        if variable not in self.variable_name2obj:
            raise ValueError(f'Unknown variable: {variable}; available variables: {self.variables}.')
        return self.variable_name2obj[variable]

    def __getitem__(self, variable: Union[str, Variable]) -> 'VariableExpression':
        return self.wrap_variable(variable)

    def wrap_variable(self, variable: Union[str, Variable]) -> 'VariableExpression':
        """Wrap a variable (object or registered name) into a :class:`VariableExpression`.

        The special name ``'??'`` produces a fresh variable of type ``AutoType``.

        Raises:
            ValueError: if the name is not registered in this context.
        """
        if isinstance(variable, Variable):
            return VariableExpression(variable)
        variable_name = variable
        if variable_name == '??':
            return VariableExpression(Variable('??', AutoType))
        if variable_name not in self.variable_name2obj:
            raise ValueError('Unknown variable: {}; available variables: {}.'.format(variable_name, self.variables))
        return VariableExpression(self.variable_name2obj[variable_name])

    def gen_random_named_variable(self, dtype) -> Variable:
        """Generate a variable expression with a random name. This utility is useful in "flatten_expression". See the doc for that function for details."""
        name = self.anonymous_argument_generator.gen()
        return Variable(name, dtype)

    @contextlib.contextmanager
    def with_variables(self, *args: Variable):
        """Temporarily replace the whole list of variables; the previous list is restored on exit."""
        old_variables = self.variables.copy()
        self.variables = list(args)
        self.variable_name2obj = {v.name: v for v in self.variables}
        try:
            yield
        finally:
            self.variables = old_variables
            self.variable_name2obj = {v.name: v for v in self.variables}

    def _check_new_variables(self, args: Sequence[Variable]) -> None:
        """Raise ValueError if a variable in ``args`` clashes with a registered one or with another entry of ``args``.

        Validation happens before any mutation, so a clash never leaves a partial set of variables registered.
        """
        seen = set()
        for arg in args:
            if arg.name in self.variable_name2obj or arg.name in seen:
                raise ValueError(f'Variable {arg.name} already exists.')
            seen.add(arg.name)

    @contextlib.contextmanager
    def new_variables(self, *args: Variable):
        """Adding a list of new variables. This function is a context manager, and the variables will be removed after the context is closed.

        Raises:
            ValueError: if any of the variable names already exists (nothing is added in that case).
        """
        self._check_new_variables(args)
        for arg in args:
            self.variables.append(arg)
            self.variable_name2obj[arg.name] = arg
        try:
            yield self
        finally:
            # The new variables sit at the end of ``self.variables``, so pop them in reverse order.
            for arg in reversed(args):
                self.variables.pop()
                del self.variable_name2obj[arg.name]

    def add_variables(self, *args: Variable):
        """Adding a list of new variables. Unlike :meth:`new_variables`, the variables will be directly added to the current context.

        Raises:
            ValueError: if any of the variable names already exists (nothing is added in that case).
        """
        self._check_new_variables(args)
        for arg in args:
            self.variables.append(arg)
            self.variable_name2obj[arg.name] = arg

    @contextlib.contextmanager
    def mark_is_effect_definition(self, is_effect_definition: bool):
        """Temporarily push a value onto the "is effect definition" stack."""
        self.is_effect_definition_stack.append(is_effect_definition)
        try:
            yield self
        finally:
            self.is_effect_definition_stack.pop()

    @property
    def is_effect_definition(self) -> bool:
        """The innermost (most recently pushed) "is effect definition" flag."""
        return self.is_effect_definition_stack[-1]

    @contextlib.contextmanager
    def options(self, **kwargs):
        """Temporarily override options listed in :attr:`OPTION_NAMES`.

        Raises:
            ValueError: if an unknown option name is given (nothing is changed in that case).
        """
        for k in kwargs:
            if k not in self.OPTION_NAMES:
                raise ValueError(f'Unknown option {k}.')
        old_options = {k: getattr(self, k) for k in kwargs}
        for k, v in kwargs.items():
            setattr(self, k, v)
        try:
            yield self
        finally:
            for k, v in old_options.items():
                setattr(self, k, v)
# Accessor for the ExpressionDefinitionContext installed via ``ExpressionDefinitionContext.as_default()``.
# It may return None when no context is active: ``Expression.check_arguments`` explicitly handles a None result.
# NOTE(review): exact semantics come from ``jacinle.utils.defaults.gen_get_default``; confirm there.
get_expression_definition_context: Callable[[], Optional[ExpressionDefinitionContext]] = gen_get_default(ExpressionDefinitionContext)
class ObjectOrValueOutputExpression(Expression, ABC):
    """Common base of expressions producing an object or a value; mainly used for type hinting."""

    @property
    def return_type(self) -> Union[ObjectType, ValueType, SequenceType]:
        """Must be provided by subclasses."""
        raise NotImplementedError

    def __str__(self) -> str:
        raise NotImplementedError
class ObjectOutputExpression(ObjectOrValueOutputExpression, ABC):
    """Base class for expressions whose result is an object (or a sequence of objects)."""

    @property
    def return_type(self) -> Union[ObjectType, SequenceType]:
        """Must be provided by subclasses."""
        raise NotImplementedError

    def __str__(self) -> str:
        raise NotImplementedError
class ValueOutputExpression(ObjectOrValueOutputExpression, ABC):
    """Base class for expressions whose result is a value (or a sequence of values)."""

    @property
    def return_type(self) -> Union[ValueType, SequenceType]:
        """Must be provided by subclasses."""
        raise NotImplementedError

    def __str__(self) -> str:
        raise NotImplementedError
class NullExpression(ObjectOrValueOutputExpression):
    """An expression that outputs a null value of a declared type."""

    dtype: Union[ObjectType, ValueType, SequenceType]
    """The type of the null expression."""

    def __init__(self, dtype: Union[ObjectType, ValueType, SequenceType]):
        self.dtype = dtype

    @property
    def return_type(self) -> Union[ObjectType, ValueType, SequenceType]:
        """The declared type, i.e. :attr:`dtype`."""
        return self.dtype

    def __str__(self) -> str:
        return 'null'
class VariableExpression(ObjectOrValueOutputExpression):
    """A reference to a :class:`Variable`; its type is the variable's type."""

    variable: Variable
    """The variable."""

    def __init__(self, variable: Variable):
        self.variable = variable

    @property
    def name(self) -> str:
        """Name of the referenced variable."""
        return self.variable.name

    @property
    def dtype(self) -> Union[ObjectType, ValueType, FunctionType, ListType]:
        """Type of the referenced variable."""
        return self.variable.dtype

    @property
    def return_type(self) -> Union[ObjectType, ValueType, FunctionType, ListType]:
        return self.variable.dtype

    def __str__(self) -> str:
        return 'V::{}'.format(self.name)
class VariableAssignmentExpression(Expression):
    """Base class for expressions that assign values to variables; they produce no value themselves."""

    @property
    def return_type(self):
        # Assignments have no result, hence no return type.
        return None

    def __str__(self):
        raise NotImplementedError
class ObjectConstantExpression(ObjectOutputExpression):
    """An expression that outputs a constant object, or a constant list of objects."""

    constant: Union[ObjectConstant, ListValue]
    """The object constant."""

    def __init__(self, constant: Union[ObjectConstant, ListValue]):
        self.constant = constant

    @property
    def name(self) -> str:
        """The name of the object.

        Raises:
            TypeError: if the constant is a list of objects rather than a single :class:`ObjectConstant`.
        """
        if not isinstance(self.constant, ObjectConstant):
            raise TypeError('ObjectConstantExpression.name is only available for ObjectConstant.')
        return self.constant.name

    @property
    def dtype(self) -> Union[ObjectType, ListType]:
        """The type of the object (a list type when the constant is a list)."""
        # Both ObjectConstant and ListValue expose ``dtype``, so no dispatch on the constant's kind is needed.
        return self.constant.dtype

    @property
    def is_constant_list(self):
        """Whether the object is a constant list."""
        return isinstance(self.constant, ListValue)

    @property
    def return_type(self) -> Union[ObjectType, ListType]:
        return self.constant.dtype

    def __str__(self) -> str:
        if self.is_constant_list:
            constant_str = ' '.join(x.name for x in self.constant.values)
            return f'O::{{{constant_str}}}'
        return f'O::{self.name}'
class ConstantExpression(ValueOutputExpression):
    """Constant expression always returns a constant value."""

    constant: ValueBase
    """The constant."""

    def __init__(self, value: Union[bool, int, float, str, torch.Tensor, Any, ValueBase], dtype: Optional[ValueType] = None):
        """Wrap ``value`` as a constant.

        Args:
            value: an existing :class:`ValueBase` (used as is), or a raw value that is wrapped using ``dtype``.
            dtype: required when ``value`` is not a :class:`ValueBase`. Tensor and Python-object types produce a
                :class:`TensorValue`; any other type produces a :class:`Value`.
        """
        if not isinstance(value, ValueBase):
            assert dtype is not None
            if isinstance(dtype, (TensorValueTypeBase, PyObjValueType)):
                value = TensorValue.from_scalar(value, dtype)
            else:
                value = Value(dtype, value)
        self.constant = value

    @property
    def return_type(self) -> Union[ValueType, ListType]:
        """The constant's type: a value type, or a list type whose elements are value types."""
        dtype = self.constant.dtype
        if isinstance(self.constant, ListValue):
            assert isinstance(dtype.element_type, ValueType)
        else:
            assert isinstance(dtype, ValueType)
        return dtype

    @classmethod
    def true(cls):
        """The Boolean constant True (stored as an int64 tensor)."""
        return cls(torch.tensor(1, dtype=torch.int64), BOOL)

    @classmethod
    def false(cls):
        """The Boolean constant False (stored as an int64 tensor)."""
        return cls(torch.tensor(0, dtype=torch.int64), BOOL)

    @classmethod
    def int64(cls, value):
        return cls(torch.tensor(value, dtype=torch.int64), INT64)

    @classmethod
    def float32(cls, value):
        return cls(torch.tensor(value, dtype=torch.float32), FLOAT32)

    @classmethod
    def from_value(cls, value, dtype: Optional[ValueType] = None):
        """Build a constant from a Python ``bool``, ``int`` or ``float``, inferring the DSL type unless ``dtype`` is given.

        Raises:
            ValueError: for any other Python type.
        """
        # ``bool`` must be tested before ``int`` because ``bool`` is a subclass of ``int``.
        if isinstance(value, bool):
            tensor, default_dtype = torch.tensor(bool(value), dtype=torch.int64), BOOL
        elif isinstance(value, int):
            tensor, default_dtype = torch.tensor(value, dtype=torch.int64), INT64
        elif isinstance(value, float):
            tensor, default_dtype = torch.tensor(value, dtype=torch.float32), FLOAT32
        else:
            raise ValueError(f'Unknown value type: {type(value)}.')
        return cls(tensor, dtype if dtype is not None else default_dtype)

    def __str__(self):
        constant = self.constant
        if isinstance(constant, TensorValue) and constant.is_single_elem:
            return f'C::{constant.single_elem()}'
        return f'C::{constant}'


# Shared Boolean constants.
ConstantExpression.TRUE = ConstantExpression.true()
ConstantExpression.FALSE = ConstantExpression.false()
class ListCreationExpression(Expression):
    """Builds a list from a sequence of argument expressions that all share one element type."""

    def __init__(self, arguments: Sequence[ValueOutputExpression], element_type: Optional[TypeBase] = None):
        """Create the list expression.

        Args:
            arguments: the element expressions.
            element_type: the element type. Defaults to the first argument's return type; it is mandatory when
                ``arguments`` is empty.
        """
        self.arguments = tuple(arguments)
        if len(self.arguments) == 0:
            assert element_type is not None, 'Must specify the element type if the list is empty.'
        elif element_type is None:
            element_type = self.arguments[0].return_type
        self.element_type = element_type
        self.check_arguments()

    @property
    def return_type(self) -> ListType:
        return ListType(self.element_type)

    def _check_arguments(self):
        # Every element must have exactly the list's element type (strict equality).
        for index, argument in enumerate(self.arguments):
            argument_type = argument.return_type
            if argument_type != self.element_type:
                raise TypeError(f'Argument #{index} has type {argument_type}, which does not match the list type {self.element_type}.')

    def __str__(self) -> str:
        return '{' + ', '.join(str(arg) for arg in self.arguments) + '}'
class ListExpansionExpression(Expression):
    """Expands a list-valued expression into a sequence of values (e.g., plan steps)."""

    expression: ValueOutputExpression
    """The expression."""

    element_type: TypeBase
    """The element type."""

    def __init__(self, expression: ValueOutputExpression):
        self.expression = expression
        self.check_arguments()
        # Read only after validation, so a non-list expression fails with the assertion message rather than here.
        self.element_type = expression.return_type.element_type

    def _check_arguments(self):
        expr = self.expression
        assert isinstance(expr, ValueOutputExpression) and expr.return_type.is_list_type, \
            f'ListExpansionExpression only accepts ValueOutputExpressions with list-typed return, got {expr} which returns {expr.return_type}.'

    @property
    def return_type(self) -> ListType:
        """The (list) type of the expanded expression."""
        return self.expression.return_type

    def __str__(self) -> str:
        return '... ' + str(self.expression)
class FunctionApplicationError(Exception):
    """Raised when an argument passed to a function does not match the function's argument definition.

    Attributes:
        index: position of the offending argument.
        expect: what was expected (a type, or a description of acceptable expression kinds).
        got: what was actually provided.
    """

    def __init__(self, index: int, expect, got):
        self.index, self.expect, self.got = index, expect, got
        super().__init__(f'Argument #{index} type does not match: expect {expect}, got {got}.')
class FunctionApplicationExpression(ValueOutputExpression):
    """Function application expression represents the application of a function over a list of arguments."""

    function: Function
    """The function to be applied."""

    arguments: Tuple[Expression, ...]
    """The list of arguments to the function."""

    def __init__(self, function: Function, arguments: Iterable[Expression]):
        self.function = function
        self.arguments = tuple(arguments)
        self.check_arguments()

    def _check_arguments(self):
        """Check the number and the types of the arguments against ``self.function.arguments``.

        Raises:
            TypeError: on any mismatch. The message starts with a header naming the function and (best effort) the
                arguments.
        """
        try:
            expected_count, actual_count = len(self.function.arguments), len(self.arguments)
            if expected_count != actual_count:
                raise TypeError('Argument number mismatch: expect {}, got {}.'.format(expected_count, actual_count))
            for index, (arg_def, arg) in enumerate(zip(self.function.arguments, self.arguments)):
                if isinstance(arg_def, Variable):
                    # Variable-typed argument: the argument's type must be downcast-compatible with the variable's.
                    if isinstance(arg, (VariableExpression, ObjectConstantExpression)):
                        actual_type = arg.dtype
                    elif isinstance(arg, (ConstantExpression, FunctionApplicationExpression, GeneralizedQuantificationExpression)):
                        actual_type = arg.return_type
                    else:
                        raise FunctionApplicationError(index, 'VariableExpression or ObjectConstantExpression or ConstantExpression or FunctionApplication', type(arg))
                    if not actual_type.downcast_compatible(arg_def.dtype):
                        raise FunctionApplicationError(index, arg_def.dtype, actual_type)
                elif isinstance(arg_def, ValueType):
                    # Bare value-type argument: a value-producing expression of exactly that type is required.
                    is_value_argument = isinstance(arg, ValueOutputExpression) or (
                        isinstance(arg, VariableExpression) and isinstance(arg.return_type, ValueType)
                    )
                    if not is_value_argument:
                        raise FunctionApplicationError(index, 'ValueOutputExpression', type(arg))
                    if arg_def != arg.return_type:
                        raise FunctionApplicationError(index, arg_def, arg.return_type)
                else:
                    raise TypeError('Unknown argument definition type: {}.'.format(type(arg_def)))
        except (TypeError, FunctionApplicationError) as e:
            error_header = 'Error during applying {}.\n'.format(str(self.function))
            try:
                error_header += ' Arguments: {}\n'.format(', '.join(str(arg) for arg in self.arguments))
            except Exception:  # noqa
                # Best effort only: failing to print the arguments must not hide the original error.
                pass
            raise TypeError(error_header + str(e)) from e

    @property
    def return_type(self) -> ValueType:
        return self.function.return_type

    def __str__(self) -> str:
        head = self.function.name + '('
        pieces = [str(x) for x in self.arguments]
        total_length = sum(len(piece) for piece in pieces)
        max_length = get_format_context().expr_max_length
        # The following criterion is just an approximation. A more principled way is to pass the current indent level
        # to the recursive calls to str(x).
        if max_length > 0 and total_length + len(head) + 1 > max_length:
            if total_length > max_length:
                body = '\n' + ',\n'.join([indent_text(piece) for piece in pieces]) + '\n'
            else:
                body = '\n' + ', '.join(pieces) + '\n'
        else:
            body = ', '.join(pieces)
        return head + body + ')'
class ListFunctionApplicationExpression(ValueOutputExpression):
    """Function application whose arguments may also be *lists* of what the function expects.

    Argument types are checked with ``allow_self_list=True``. :attr:`return_type` becomes a :class:`ListType` over the
    function's return type as soon as one of the arguments is list-typed (presumably the function is applied
    element-wise by the executor).
    """

    function: Function
    """The function to be applied."""

    arguments: Tuple[Expression, ...]
    """The list of arguments to the function."""

    def __init__(self, function: Function, arguments: Iterable[Expression]):
        self.function = function
        self.arguments = tuple(arguments)
        self.check_arguments()

    def _check_arguments(self):
        """Check the number and the (list-lifted) types of the arguments against ``self.function.arguments``.

        Raises:
            TypeError: on any mismatch. The message starts with a header naming the function and (best effort) the
                arguments.
        """
        try:
            if len(self.function.arguments) != len(self.arguments):
                raise TypeError('Argument number mismatch: expect {}, got {}.'.format(len(self.function.arguments), len(self.arguments)))
            for i, (arg_def, arg) in enumerate(zip(self.function.arguments, self.arguments)):
                if isinstance(arg_def, Variable):
                    if isinstance(arg, (VariableExpression, ObjectConstantExpression)):
                        arg_type = arg.dtype
                    elif isinstance(arg, (ConstantExpression, FunctionApplicationExpression, ListFunctionApplicationExpression, GeneralizedQuantificationExpression, ListCreationExpression)):
                        arg_type = arg.return_type
                    else:
                        raise FunctionApplicationError(i, 'VariableExpression or ObjectConstantExpression or ConstantExpression or FunctionApplication', type(arg))
                    if not arg_type.downcast_compatible(arg_def.dtype, allow_self_list=True):
                        raise FunctionApplicationError(i, arg_def.dtype, arg_type)
                elif isinstance(arg_def, ValueType):
                    # ListCreationExpression derives from Expression (not ValueOutputExpression) but still produces
                    # values, so it is accepted here as it is in the Variable branch above.
                    is_value_argument = isinstance(arg, (ValueOutputExpression, ListCreationExpression)) or (
                        isinstance(arg, VariableExpression) and isinstance(arg.return_type, ValueType)
                    )
                    if not is_value_argument:
                        raise FunctionApplicationError(i, 'ValueOutputExpression', type(arg))
                    arg_type = arg.return_type
                    # Consistent with ``allow_self_list=True`` above: a list whose element type is exactly the
                    # expected value type is also accepted, because the function is lifted over lists.
                    if arg_def != arg_type and not (arg_type.is_list_type and arg_type.element_type == arg_def):
                        raise FunctionApplicationError(i, arg_def, arg_type)
                else:
                    raise TypeError('Unknown argument definition type: {}.'.format(type(arg_def)))
        except (TypeError, FunctionApplicationError) as e:
            error_header = 'Error during applying {}.\n'.format(str(self.function))
            try:
                arguments_str = ', '.join(str(arg) for arg in self.arguments)
                error_header += ' Arguments: {}\n'.format(arguments_str)
            except Exception:  # noqa
                # Best effort only: failing to print the arguments must not hide the original error.
                pass
            raise TypeError(error_header + str(e)) from e

    @property
    def return_type(self) -> Union[ValueType, ListType]:
        """A list of the function's return type if any argument is list-typed, otherwise the return type itself."""
        for arg in self.arguments:
            if arg.return_type.is_list_type:
                return ListType(self.function.return_type)
        return self.function.return_type

    def __str__(self) -> str:
        fmt = self.function.name + '[list]('
        arg_fmt = [str(x) for x in self.arguments]
        arg_fmt_len = [len(x) for x in arg_fmt]
        ctx = get_format_context()
        # The following criterion is just an approximation. A more principled way is to pass the current indent level
        # to the recursive calls to str(x).
        if ctx.expr_max_length > 0 and (sum(arg_fmt_len) + len(fmt) + 1 > ctx.expr_max_length):
            if sum(arg_fmt_len) > ctx.expr_max_length:
                fmt += '\n' + ',\n'.join([indent_text(x) for x in arg_fmt]) + '\n'
            else:
                fmt += '\n' + ', '.join(arg_fmt) + '\n'
        else:
            fmt += ', '.join(arg_fmt)
        fmt += ')'
        return fmt
class ConditionalSelectExpression(ValueOutputExpression):
    """Conditional select expression represents the selection of a value based on a condition."""

    predicate: ValueOutputExpression
    """The predicate expression."""

    condition: ValueOutputExpression
    """The condition expression."""

    def __init__(self, predicate: ValueOutputExpression, condition: ValueOutputExpression):
        self.predicate = predicate
        self.condition = condition
        self.check_arguments()

    def _check_arguments(self):
        cond = self.condition
        # Accept a value expression of exactly BOOL type, or a variable whose type is downcast-compatible with BOOL.
        if not (
            (isinstance(cond, ValueOutputExpression) and cond.return_type == BOOL)
            or (isinstance(cond, VariableExpression) and cond.return_type.downcast_compatible(BOOL))
        ):
            raise TypeError('Condition must be a boolean expression.')

    @property
    def return_type(self) -> ValueType:
        """The type of the selected value, i.e. of the predicate."""
        return self.predicate.return_type

    def __str__(self):
        predicate_text, condition_text = str(self.predicate), str(self.condition)
        if len(predicate_text) + len(condition_text) + 2 >= 80:
            return f'cond-select({predicate_text} if\n{indent_text(condition_text)})'
        return f'cond-select({predicate_text} if {condition_text})'
class DeicticSelectExpression(ValueOutputExpression):
    """Deictic selection: evaluates an expression for every object bound to a new variable (forall-like)."""

    variable: Variable
    """The new quantified variable."""

    expression: ValueOutputExpression
    """The internal expression."""

    def __init__(self, variable: Variable, expr: ValueOutputExpression):
        self.variable, self.expression = variable, expr
        self.check_arguments()

    def _check_arguments(self):
        # Only object-typed variables can be quantified over.
        assert isinstance(self.variable.dtype, ObjectType)

    @property
    def return_type(self) -> ValueType:
        """Same as the type of the internal expression."""
        return self.expression.return_type

    def __str__(self):
        return 'deictic-select({}: {})'.format(self.variable, self.expression)
class BoolOpType(JacEnum):
    """Boolean connectives used by :class:`BoolExpression`; each value is the keyword printed by ``BoolExpression.__str__``."""

    AND = 'and'
    OR = 'or'
    NOT = 'not'  # BoolExpression requires exactly one argument for NOT.
    XOR = 'xor'
    IMPLIES = 'implies'  # BoolExpression requires exactly two arguments for IMPLIES.
class BoolExpression(ValueOutputExpression):
    """A Boolean operation (and, or, not, xor, implies) over value-producing arguments."""

    bool_op: BoolOpType
    """The boolean operation. Can be AND, OR, NOT, XOR, IMPLIES."""

    arguments: Tuple[ValueOutputExpression, ...]
    """The list of arguments."""

    def __init__(self, bool_op_type: BoolOpType, arguments: Sequence[ValueOutputExpression]):
        self.bool_op = bool_op_type
        self.arguments = tuple(arguments)
        self.check_arguments()

    def _check_arguments(self):
        if self.bool_op is BoolOpType.NOT:
            assert len(self.arguments) == 1, f'Number of arguments for NotOp should be 1, got: {len(self.arguments)}.'
        if self.bool_op is BoolOpType.IMPLIES:
            assert len(self.arguments) == 2, f'Number of arguments for ImpliesOp should be 2, got: {len(self.arguments)}.'
        for i, arg in enumerate(self.arguments):
            assert isinstance(arg, (VariableExpression, ValueOutputExpression)), f'BoolOp only accepts ValueOutputExpressions, got argument #{i} of type {type(arg)}.'

    @property
    def return_type(self) -> ValueType:
        """The type of the first argument, or BOOL when there are no arguments.

        Zero-argument AND/OR/XOR expressions (e.g. ``AndExpression()``) pass the argument check, so they must not fail
        here with an IndexError.
        """
        if len(self.arguments) == 0:
            return BOOL
        return self.arguments[0].return_type

    def __str__(self):
        argument_strings = [str(arg) for arg in self.arguments]
        if sum(len(x) for x in argument_strings) < 80:
            return f'{self.bool_op.value}({", ".join(argument_strings)})'
        arguments = ',\n'.join([indent_text(x) for x in argument_strings])
        return f'{self.bool_op.value}(\n{arguments}\n)'
class AndExpression(BoolExpression):
    """Conjunction of the given arguments."""

    bool_op: BoolOpType
    """The boolean operation. Must be :py:attr:`BoolOpType.AND`."""

    arguments: Tuple[ValueOutputExpression, ...]

    def __init__(self, *arguments: ValueOutputExpression):
        super(AndExpression, self).__init__(BoolOpType.AND, arguments)
class OrExpression(BoolExpression):
    """Disjunction of the given arguments."""

    bool_op: BoolOpType
    """The boolean operation. Must be :py:attr:`BoolOpType.OR`."""

    arguments: Tuple[ValueOutputExpression, ...]

    def __init__(self, *arguments: ValueOutputExpression):
        super(OrExpression, self).__init__(BoolOpType.OR, arguments)
class NotExpression(BoolExpression):
    """Negation of a single argument."""

    bool_op: BoolOpType
    """The boolean operation. Must be :py:attr:`BoolOpType.NOT`."""

    arguments: Tuple[ValueOutputExpression]
    """The list of arguments. Must contain exactly one argument."""

    def __init__(self, arg: ValueOutputExpression):
        super(NotExpression, self).__init__(BoolOpType.NOT, (arg,))
class XorExpression(BoolExpression):
    """Exclusive-or of the given arguments."""

    bool_op: BoolOpType
    """The boolean operation. Must be :py:attr:`BoolOpType.XOR`."""

    arguments: Tuple[ValueOutputExpression, ...]

    def __init__(self, *arguments: ValueOutputExpression):
        super(XorExpression, self).__init__(BoolOpType.XOR, arguments)
class ImpliesExpression(BoolExpression):
    """Implication ``lhs -> rhs``."""

    bool_op: BoolOpType
    """The boolean operation. Must be :py:attr:`BoolOpType.IMPLIES`."""

    arguments: Tuple[ValueOutputExpression, ValueOutputExpression]

    def __init__(self, lhs: ValueOutputExpression, rhs: ValueOutputExpression):
        super(ImpliesExpression, self).__init__(BoolOpType.IMPLIES, (lhs, rhs))
class QuantificationOpType(JacEnum):
    """Quantifiers used by :class:`QuantificationExpression`; each value is the keyword printed by its ``__str__``."""

    FORALL = 'forall'
    EXISTS = 'exists'
class QuantificationExpression(ValueOutputExpression):
    """First-order quantification (forall / exists) of an object-typed variable over a value expression."""

    quantification_op: QuantificationOpType
    """The quantification operation. Can be FORALL or EXISTS."""

    variable: Variable
    """The quantified variable."""

    expression: ValueOutputExpression
    """The internal expression."""

    def __init__(self, quantification_op: QuantificationOpType, variable: Variable, expr: ValueOutputExpression):
        self.quantification_op, self.variable, self.expression = quantification_op, variable, expr
        self.check_arguments()

    def _check_arguments(self):
        body = self.expression
        assert isinstance(body, ValueOutputExpression), f'QuantificationOp only accepts ValueOutputExpressions, got type {type(body)}.'
        assert isinstance(self.variable.dtype, ObjectType)

    @property
    def return_type(self) -> ValueType:
        """Same as the type of the quantified body."""
        return self.expression.return_type

    def __str__(self):
        return '{}({}: {})'.format(self.quantification_op.value, self.variable, self.expression)
class GeneralizedQuantificationExpression(ValueOutputExpression):
    """Quantification with an arbitrary operator (e.g. iota, all, counting quantifiers)."""

    quantification_op: Any
    """The quantification operation. It can be any data type."""

    variable: Variable
    """The quantified variable."""

    expression: ValueOutputExpression
    """The internal expression."""

    def __init__(self, quantification_op: Any, variable: Variable, expr: ValueOutputExpression, return_type: Optional[ValueType] = None):
        """Create the expression; ``return_type`` defaults to the return type of ``expr``."""
        self.quantification_op = quantification_op
        self.variable = variable
        self.expression = expr
        self._return_type = self.expression.return_type if return_type is None else return_type
        self.check_arguments()

    def _check_arguments(self):
        body = self.expression
        assert isinstance(body, ValueOutputExpression), f'QuantificationOp only accepts ValueOutputExpressions, got type {type(body)}.'
        assert isinstance(self.variable.dtype, ObjectType)

    @property
    def return_type(self) -> ValueType:
        return self._return_type

    def __str__(self):
        return '{}({}: {})'.format(self.quantification_op, self.variable, self.expression)
class ForallExpression(QuantificationExpression):
    """Universal quantification: ``forall(variable: expression)``."""

    quantification_op: QuantificationOpType
    """The quantification operation. Must be :py:attr:`QuantificationOpType.FORALL`."""

    variable: Variable
    expression: ValueOutputExpression

    def __init__(self, variable: Variable, expr: ValueOutputExpression):
        super(ForallExpression, self).__init__(QuantificationOpType.FORALL, variable, expr)
class ExistsExpression(QuantificationExpression):
    """Existential quantification: ``exists(variable: expression)``."""

    quantification_op: QuantificationOpType
    """The quantification operation. Must be :py:attr:`QuantificationOpType.EXISTS`."""

    variable: Variable
    expression: ValueOutputExpression

    def __init__(self, variable: Variable, expr: ValueOutputExpression):
        super(ExistsExpression, self).__init__(QuantificationOpType.EXISTS, variable, expr)
class FindOneExpression(ObjectOutputExpression):
    """Find-one quantification: outputs an object bound to ``variable`` for which ``expression`` holds."""

    variable: Variable
    """The quantified variable."""

    expression: ValueOutputExpression
    """The internal expression."""

    def __init__(self, variable: Variable, expr: ValueOutputExpression):
        self.variable = variable
        self.expression = expr
        self.check_arguments()

    def _check_arguments(self):
        # The message previously said "FindAllOp" (copy-paste from FindAllExpression); it now names this operation.
        assert isinstance(self.expression, ValueOutputExpression), f'FindOneOp only accepts ValueOutputExpressions, got type {type(self.expression)}.'
        assert isinstance(self.variable.dtype, ObjectType)

    @property
    def return_type(self) -> ObjectType:
        """The type of the quantified variable."""
        return self.variable.dtype

    def __str__(self):
        return f'findone({self.variable}: {self.expression})'
class FindAllExpression(ObjectOutputExpression):
    """Find-all quantification: outputs the list of objects bound to ``variable`` for which ``expression`` holds."""

    variable: Variable
    """The quantified variable."""

    expression: ValueOutputExpression
    """The internal expression."""

    def __init__(self, variable: Variable, expr: ValueOutputExpression):
        self.variable, self.expression = variable, expr
        self.check_arguments()

    def _check_arguments(self):
        body = self.expression
        assert isinstance(body, ValueOutputExpression), f'FindAllOp only accepts ValueOutputExpressions, got type {type(body)}.'
        assert isinstance(self.variable.dtype, ObjectType)

    @property
    def return_type(self) -> ListType:
        """A list of the quantified variable's type."""
        return ListType(self.variable.dtype)

    def __str__(self):
        return 'findall({}: {})'.format(self.variable, self.expression)
class CompareOpType(JacEnum):
    """Comparison operators; each value is the symbol printed by the comparison expressions' ``__str__``.

    ObjectCompareExpression only accepts EQ and NEQ.
    """

    EQ = '=='
    NEQ = '!='
    LT = '<'
    LEQ = '<='
    GT = '>'
    GEQ = '>='
class _CompareExpressionBase(ValueOutputExpression, ABC):
    """Shared implementation of binary comparisons ``(lhs <op> rhs)`` whose return type is BOOL.

    Concrete subclasses override ``_check_arguments`` to validate the operator and operand types.
    That hook is presumably dispatched by ``check_arguments``, which the constructor calls.
    """

    def __init__(self, compare_op: CompareOpType, lhs: Expression, rhs: Expression):
        self.compare_op, self.lhs, self.rhs = compare_op, lhs, rhs
        self.check_arguments()

    @property
    def arguments(self) -> Tuple[Expression, Expression]:
        operands = (self.lhs, self.rhs)
        return operands

    @property
    def return_type(self) -> ValueType:
        return BOOL

    def __str__(self):
        symbol = self.compare_op.value
        return f'({self.lhs} {symbol} {self.rhs})'
class ObjectCompareExpression(_CompareExpressionBase):
    """Equality (``==``) or inequality (``!=``) comparison between two object-typed expressions."""

    compare_op: CompareOpType
    """The comparison operation."""

    lhs: Union[ObjectOutputExpression, VariableExpression]
    """The left-hand side of the comparison."""

    rhs: Union[ObjectOutputExpression, VariableExpression]
    """The right-hand side of the comparison."""

    def _check_arguments(self):
        allowed_ops = (CompareOpType.EQ, CompareOpType.NEQ)
        assert self.compare_op in allowed_ops, f'ObjectCompareExpression only accepts EQ and NEQ, got {self.compare_op}.'
        # Both operands must produce objects.
        for side, operand in (('lhs', self.lhs), ('rhs', self.rhs)):
            assert isinstance(operand.return_type, ObjectType), f'{side} of ObjectCompareExpression must be of type ObjectType, got {operand.return_type}.'
class ValueCompareExpression(_CompareExpressionBase):
    """Comparison between two value-typed expressions; any :class:`CompareOpType` is accepted."""

    compare_op: CompareOpType
    """The comparison operation."""

    lhs: ValueOutputExpression
    """The left-hand side of the comparison."""

    rhs: ValueOutputExpression
    """The right-hand side of the comparison."""

    def _check_arguments(self):
        # Only operand types are validated; the operator itself is not restricted here.
        for side, operand in (('lhs', self.lhs), ('rhs', self.rhs)):
            assert isinstance(operand.return_type, ValueType), f'{side} of ValueCompareExpression must be of type ValueType, got {operand.return_type}.'
class ConditionExpression(ValueOutputExpression):
    """Ternary-style conditional value, printed as ``cond(condition ? true_value : false_value)``."""

    condition: ValueOutputExpression
    """The condition expression."""

    true_value: ValueOutputExpression
    """The true value expression."""

    false_value: ValueOutputExpression
    """The false value expression."""

    def __init__(self, condition: ValueOutputExpression, true_value: ValueOutputExpression, false_value: ValueOutputExpression):
        self.condition = condition
        self.true_value = true_value
        self.false_value = false_value
        self.check_arguments()

    def _check_arguments(self):
        # The condition must be a value expression whose type is BOOL.
        assert isinstance(self.condition, ValueOutputExpression), f'Condition must be a boolean expression, got {self.condition}.'
        assert self.condition.return_type == BOOL, f'Condition must be a boolean expression, got {self.condition}.'
        # Both branches must have the same type.
        assert self.true_value.return_type == self.false_value.return_type, (
            f'True value and false value must have the same type, got {self.true_value.return_type} and {self.false_value.return_type}.'
        )

    @property
    def return_type(self) -> ValueType:
        # The constructor asserts that both branches share one type, so the true branch's type represents both.
        return self.true_value.return_type

    def __str__(self):
        parts = [str(self.condition), str(self.true_value), str(self.false_value)]
        cond_s, true_s, false_s = parts
        # Short expressions fit on one line; long ones put each branch on its own indented line.
        if sum(map(len, parts)) + 4 < 80:
            return f'cond({cond_s} ? {true_s} : {false_s})'
        return f'cond({cond_s} ?\n{indent_text(true_s)}\n{indent_text(false_s)})'
class _PredicateValueExpression(Expression, ABC):
    """Common base for expressions that pair a ``predicate`` with a ``value``.

    The constructor checks that the value's return type equals the predicate type's ``assignment_type()``.
    A mismatch is reported as a :class:`TypeError`.
    """

    def __init__(self, predicate: Union[VariableExpression, FunctionApplicationExpression], value: ValueOutputExpression):
        self.predicate = predicate
        self.value = value
        self.check_arguments()

    def _check_arguments(self):
        try:
            predicate_type = self.predicate.return_type
            if predicate_type.assignment_type() != self.value.return_type:
                raise FunctionApplicationError(
                    0,
                    f'{self.predicate.return_type}(assignment type is {predicate_type.assignment_type()})',
                    self.value.return_type,
                )
        except TypeError:
            # TypeErrors raised while inspecting the types propagate unchanged.
            raise
        except FunctionApplicationError as mismatch:
            # Convert the internal mismatch signal into a TypeError that names both sides.
            header = f'Error during _PredicateValueExpression checking: feature = {self.predicate!s} value = {self.value!s}.\n'
            detail = f'Value type does not match: expect: {mismatch.expect}, got {mismatch.got}.'
            raise TypeError(header + detail) from mismatch
class PredicateEqualExpression(ValueOutputExpression, _PredicateValueExpression):
    """Equality test between a predicate (variable or function application) and a value; its return type is BOOL."""

    predicate: Union[VariableExpression, FunctionApplicationExpression]
    """The predicate expression."""

    value: ValueOutputExpression
    """The value expression."""

    def _check_arguments(self):
        """Validate the predicate kind, then run the shared predicate/value type check.

        Raises:
            TypeError: if the predicate is not a :class:`VariableExpression` or :class:`FunctionApplicationExpression`,
                or if the value type does not match the predicate's assignment type.
        """
        # The kind check runs first because the shared check reads ``self.predicate.return_type``.
        # Otherwise an invalid predicate would surface as an AttributeError or a misleading
        # type-mismatch message instead of this explicit error.
        if not isinstance(self.predicate, (VariableExpression, FunctionApplicationExpression)):
            raise TypeError(f'PredicateEqualOp only support dest type VariableExpression or FunctionApplication, got {type(self.predicate)}.')
        super()._check_arguments()

    @property
    def return_type(self):
        return BOOL

    def __str__(self):
        return f'equal({self.predicate}, {self.value})'
class AssignExpression(_PredicateValueExpression, VariableAssignmentExpression):
    """Assignment of ``value`` to the state variable referenced by ``predicate``; printed as ``assign{predicate: value}``."""

    predicate: FunctionApplicationExpression
    """The predicate expression, must be a :class:`FunctionApplicationExpression` which refers to a state variable."""

    value: ValueOutputExpression
    """The expression for the value to assign to the state variable."""

    def __init__(self, predicate: FunctionApplicationExpression, value: ValueOutputExpression):
        # Call the predicate/value base initializer explicitly; it stores both fields and triggers the checks.
        _PredicateValueExpression.__init__(self, predicate, value)

    def _check_arguments(self):
        super()._check_arguments()
        target = self.predicate
        assert isinstance(target, FunctionApplicationExpression), f'AssignOp only support dest type FunctionApplication, got {type(target)}.'

    def __str__(self):
        return 'assign{' + f'{self.predicate}: {self.value}' + '}'
class ConditionalAssignExpression(_PredicateValueExpression, VariableAssignmentExpression):
    """Assignment of ``value`` to a state variable guarded by a BOOL ``condition``.

    Printed as ``cond-assign{predicate: value if condition}``.
    """

    predicate: FunctionApplicationExpression
    """The predicate expression, must be a :class:`FunctionApplicationExpression` which refers to a state variable."""

    value: ValueOutputExpression
    """The expression for the value to assign to the state variable."""

    condition: ValueOutputExpression
    """The condition expression."""

    def __init__(self, feature: FunctionApplicationExpression, value: ValueOutputExpression, condition: ValueOutputExpression):
        # ``condition`` must be stored before the base initializer runs. The base initializer calls
        # ``check_arguments``, which presumably reaches this class's ``_check_arguments``, and that reads ``self.condition``.
        self.condition = condition
        _PredicateValueExpression.__init__(self, feature, value)

    def _check_arguments(self):
        super()._check_arguments()
        guard = self.condition
        assert isinstance(guard, ValueOutputExpression)
        assert guard.return_type == BOOL

    def __str__(self):
        return 'cond-assign{' + f'{self.predicate}: {self.value} if {self.condition}' + '}'
class DeicticAssignExpression(VariableAssignmentExpression):
    """An object-typed ``variable`` scoped over an inner assignment expression.

    Printed as ``deictic-assign{variable: expression}``.
    """

    variable: Variable
    """The quantified variable."""

    expression: VariableAssignmentExpression
    """The internal expression."""

    def __init__(self, variable: Variable, expr: VariableAssignmentExpression):
        self.variable, self.expression = variable, expr
        self.check_arguments()

    def _check_arguments(self):
        # Only the variable's type is validated here; the inner expression is not type-checked.
        assert isinstance(self.variable.dtype, ObjectType)

    def __str__(self):
        return 'deictic-assign{{{}: {}}}'.format(self.variable, self.expression)
# Everything that `cvt_expression` knows how to turn into an `Expression`.
ExpressionCompatible = Union[
    Expression,
    Variable,
    str,
    ObjectConstant,
    bool,
    int,
    float,
    torch.Tensor,
    ValueBase,
]
def cvt_expression(expr: ExpressionCompatible, dtype: Optional[Union[ObjectType, ValueType]] = None) -> Expression:
    """Convert an expression compatible object to an expression.

    Acceptable types are:

    * :class:`Expression`: returned unchanged.
    * :class:`Variable`: return a :class:`VariableExpression`.
    * :class:`str`: return a :class:`ObjectConstantExpression` if the dtype is a :class:`ObjectType`. Otherwise
      (the dtype is a :class:`ValueType` or not given), return a :class:`ConstantExpression` wrapping the string,
      typed with ``dtype``, or with ``AutoType`` when no dtype is given.
    * :class:`ObjectConstant`: return a :class:`ObjectConstantExpression`.
    * :class:`bool`, :class:`int`, :class:`float`, :class:`torch.Tensor`: return a :class:`ConstantExpression`.
    * :class:`ValueBase`: return a :class:`ConstantExpression`.

    Args:
        expr: the expression compatible object.
        dtype: the expected data type of the expression. If not given, the dtype will be inferred from the given
            object. It is ignored for :class:`Expression`, :class:`Variable`, :class:`ObjectConstant` and
            :class:`ValueBase` inputs.

    Returns:
        the converted expression.

    Raises:
        TypeError: if the given object is not an expression compatible object; if a string is given together with a
            dtype that is neither an :class:`ObjectType` nor a :class:`ValueType`; or if a tensor / value has an
            unsupported dtype.
    """
    if isinstance(expr, Expression):
        return expr
    elif isinstance(expr, Variable):
        return VariableExpression(expr)
    elif isinstance(expr, str):
        if isinstance(dtype, ObjectType):
            return ObjectConstantExpression(ObjectConstant(expr, dtype))
        elif dtype is None or isinstance(dtype, ValueType):
            # With no dtype, fall back to AutoType rather than falling through and implicitly returning None.
            return ConstantExpression(Value(dtype or AutoType, expr))
        raise TypeError(f'Unsupported dtype {dtype} for string expression "{expr}".')
    elif isinstance(expr, ObjectConstant):
        return ObjectConstantExpression(expr)
    elif isinstance(expr, bool):
        # Must be tested before ``int``: ``bool`` is a subclass of ``int``.
        return ConstantExpression(torch.tensor(int(expr), dtype=torch.int64), dtype or BOOL)
    elif isinstance(expr, int):
        return ConstantExpression(torch.tensor(expr, dtype=torch.int64), dtype or INT64)
    elif isinstance(expr, float):
        return ConstantExpression(torch.tensor(expr, dtype=torch.float32), dtype or FLOAT32)
    elif isinstance(expr, torch.Tensor):
        if expr.dtype == torch.int64:
            return ConstantExpression(expr, dtype or INT64)
        elif expr.dtype == torch.float32:
            return ConstantExpression(expr, dtype or FLOAT32)
        else:
            raise TypeError(f'Unsupported tensor type: {expr.dtype}.')
    elif isinstance(expr, ValueBase):
        if isinstance(expr.dtype, ValueType):
            return ConstantExpression(expr, expr.dtype)
        else:
            raise TypeError(f'Unsupported value type: {expr.dtype}.')
    else:
        raise TypeError(f'Non-compatible expression type {type(expr)} for expression "{expr}".')
def cvt_expression_list(arguments: Sequence[ExpressionCompatible], dtypes: Optional[Union[ObjectType, ValueType, Sequence[Union[ObjectType, ValueType]]]] = None) -> List[Expression]:
    """Convert a list of expression compatible objects to a list of expressions.

    Args:
        arguments: the list of expression compatible objects.
        dtypes: the list of expected data types of the expressions. If not given, the dtypes will be inferred from the
            given objects. It can be a single data type, in which case all the expressions will be converted to this
            data type. When a sequence is given, it is paired with ``arguments`` position by position (via ``zip``).

    Returns:
        the list of converted expressions.
    """
    if dtypes is None:
        return [cvt_expression(arg) for arg in arguments]
    if isinstance(dtypes, (ObjectType, ValueType)):
        # A single dtype is broadcast to every argument, as documented above.
        return [cvt_expression(arg, dtypes) for arg in arguments]
    return [cvt_expression(arg, dtype) for arg, dtype in zip(arguments, dtypes)]
def get_type(value: Any) -> Union[TypeBase, Tuple[TypeBase, ...]]:
    """Return the DSL type of ``value``.

    Returns:
        ``FunctionArgumentUnset`` for the unset sentinel, the ``ftype`` of a :class:`Function`, the ``return_type`` of
        an :class:`Expression`, the ``dtype`` of a :class:`ValueBase`, or ``AutoType`` for a plain Python
        ``bool``/``int``/``float``/``str``.

    Raises:
        ValueError: if ``value`` is none of the above.
    """
    if value is FunctionArgumentUnset:
        return FunctionArgumentUnset
    if isinstance(value, Function):
        return value.ftype
    if isinstance(value, Expression):
        return value.return_type
    if isinstance(value, ValueBase):
        return value.dtype
    if isinstance(value, (bool, int, float, str)):
        return AutoType
    raise ValueError(f'Unknown value type: {type(value)}.')
def get_types(args=None, kwargs=None):
    """Return the types of positional and/or keyword arguments.

    Returns:
        the tuple of positional types if only ``args`` is given; the dict of keyword types if only ``kwargs`` is given;
        a ``(positional_types, keyword_types)`` pair if both are given; and ``()`` if neither is given.
    """
    arg_types = None if args is None else tuple(get_type(item) for item in args)
    kwarg_types = None if kwargs is None else {name: get_type(item) for name, item in kwargs.items()}
    present = tuple(part for part in (arg_types, kwarg_types) if part is not None)
    return present[0] if len(present) == 1 else present
def is_object_output_expression(expr: Expression) -> TypeGuard[ObjectOutputExpression]:
    """Return True if ``expr`` produces an object.

    That is, it is an :class:`ObjectOutputExpression`, or a :class:`VariableExpression` whose variable has an object type.
    """
    if isinstance(expr, ObjectOutputExpression):
        return True
    if not isinstance(expr, VariableExpression):
        return False
    return isinstance(expr.variable.dtype, ObjectType)
def is_value_output_expression(expr: Expression) -> TypeGuard[ValueOutputExpression]:
    """Return True if ``expr`` is a :class:`ValueOutputExpression`."""
    produces_value = isinstance(expr, ValueOutputExpression)
    return produces_value
def is_variable_assignment_expression(expr: Expression) -> bool:
    """Return True if ``expr`` is a :class:`VariableAssignmentExpression`."""
    assigns = isinstance(expr, VariableAssignmentExpression)
    return assigns
def is_and_expr(expr: Expression) -> TypeGuard[AndExpression]:
    """Return True if ``expr`` is a :class:`BoolExpression` whose operator is ``AND``."""
    if not isinstance(expr, BoolExpression):
        return False
    return expr.bool_op is BoolOpType.AND
def is_or_expr(expr: Expression) -> TypeGuard[OrExpression]:
    """Return True if ``expr`` is a :class:`BoolExpression` whose operator is ``OR``."""
    if not isinstance(expr, BoolExpression):
        return False
    return expr.bool_op is BoolOpType.OR
def is_not_expr(expr: Expression) -> TypeGuard[NotExpression]:
    """Return True if ``expr`` is a :class:`BoolExpression` whose operator is ``NOT``."""
    if not isinstance(expr, BoolExpression):
        return False
    return expr.bool_op is BoolOpType.NOT
def is_xor_expr(expr: Expression) -> TypeGuard[XorExpression]:
    """Return True if ``expr`` is a :class:`BoolExpression` whose operator is ``XOR``."""
    if not isinstance(expr, BoolExpression):
        return False
    return expr.bool_op is BoolOpType.XOR
def is_implies_expr(expr: Expression) -> TypeGuard[ImpliesExpression]:
    """Return True if ``expr`` is a :class:`BoolExpression` whose operator is ``IMPLIES``."""
    if not isinstance(expr, BoolExpression):
        return False
    return expr.bool_op is BoolOpType.IMPLIES
def is_constant_bool_expr(expr: Expression) -> TypeGuard[ConstantExpression]:
    """Return True if ``expr`` is a :class:`ConstantExpression` whose return type is BOOL."""
    return bool(isinstance(expr, ConstantExpression) and expr.return_type == BOOL)
def is_forall_expr(expr: Expression) -> TypeGuard[ForallExpression]:
    """Return True if ``expr`` is a :class:`QuantificationExpression` using ``FORALL``."""
    if not isinstance(expr, QuantificationExpression):
        return False
    return expr.quantification_op is QuantificationOpType.FORALL
def is_exists_expr(expr: Expression) -> TypeGuard[ExistsExpression]:
    """Return True if ``expr`` is a :class:`QuantificationExpression` using ``EXISTS``."""
    if not isinstance(expr, QuantificationExpression):
        return False
    return expr.quantification_op is QuantificationOpType.EXISTS
def iter_exprs(expr: Expression) -> Iterable[Expression]:
    """Iterate over all sub-expressions of the input in pre-order (the input itself is yielded first).

    For :class:`AssignExpression` and :class:`ConditionalAssignExpression`, only the assigned value (and the condition)
    are visited, not the target predicate.
    NOTE(review): confirm that skipping the target's arguments is intended.

    Raises:
        TypeError: if an expression of an unsupported type is encountered.
    """
    yield expr
    if isinstance(expr, (FunctionApplicationExpression, ListFunctionApplicationExpression)):
        for arg in expr.arguments:
            yield from iter_exprs(arg)
    elif isinstance(expr, ListCreationExpression):
        for arg in expr.arguments:
            yield from iter_exprs(arg)
    elif isinstance(expr, ListExpansionExpression):
        yield from iter_exprs(expr.expression)
    elif isinstance(expr, BoolExpression):
        for arg in expr.arguments:
            yield from iter_exprs(arg)
    elif isinstance(expr, QuantificationExpression):
        yield from iter_exprs(expr.expression)
    elif isinstance(expr, GeneralizedQuantificationExpression):
        yield from iter_exprs(expr.expression)
    elif isinstance(expr, PredicateEqualExpression):
        yield from iter_exprs(expr.predicate)
        yield from iter_exprs(expr.value)
    elif isinstance(expr, AssignExpression):
        yield from iter_exprs(expr.value)
    elif isinstance(expr, ConditionalSelectExpression):
        yield from iter_exprs(expr.predicate)
        yield from iter_exprs(expr.condition)
    elif isinstance(expr, ConditionalAssignExpression):
        yield from iter_exprs(expr.value)
        yield from iter_exprs(expr.condition)
    elif isinstance(expr, (DeicticSelectExpression, DeicticAssignExpression)):
        yield from iter_exprs(expr.expression)
    elif isinstance(expr, VariableAssignmentExpression):
        # Generic fallback for other assignment kinds. It must come after the Assign / ConditionalAssign /
        # DeicticAssign branches: those classes subclass VariableAssignmentExpression but have no ``lhs``/``rhs``.
        yield from iter_exprs(expr.lhs)
        yield from iter_exprs(expr.rhs)
    elif isinstance(expr, (ObjectCompareExpression, ValueCompareExpression)):
        yield from iter_exprs(expr.lhs)
        yield from iter_exprs(expr.rhs)
    elif isinstance(expr, ConditionExpression):
        yield from iter_exprs(expr.condition)
        yield from iter_exprs(expr.true_value)
        yield from iter_exprs(expr.false_value)
    elif isinstance(expr, (FindOneExpression, FindAllExpression)):
        yield from iter_exprs(expr.expression)
    elif isinstance(expr, (VariableExpression, ConstantExpression, ObjectConstantExpression)):
        pass
    else:
        raise TypeError('Unknown expression type: {}.'.format(type(expr)))
def find_free_variables(expr: Expression) -> Tuple[Variable, ...]:
    """Find the variables that occur free in ``expr``.

    A :class:`VariableExpression` is free unless an enclosing binder introduces a variable with the same name. The
    binders are :class:`QuantificationExpression`, :class:`GeneralizedQuantificationExpression`,
    :class:`FindOneExpression`, :class:`FindAllExpression`, :class:`DeicticSelectExpression` and
    :class:`DeicticAssignExpression`.

    For :class:`AssignExpression` and :class:`ConditionalAssignExpression`, only the assigned value (and the condition)
    are inspected, not the target predicate.
    NOTE(review): confirm that skipping the target's arguments is intended.

    Args:
        expr: the expression to inspect.

    Returns:
        the free variables, one per name, ordered by first occurrence. If a name occurs free several times, the
        :class:`Variable` object from its last occurrence is returned.

    Raises:
        TypeError: if an expression of an unsupported type is encountered.
    """
    free_variables = dict()
    bounded_variables = dict()
    missing = object()

    def visit_bound(variable: Variable, body: Expression):
        # Bind ``variable`` while visiting ``body``, then restore the previous binding of the same name, if any.
        # Plainly deleting the name would also un-bind an outer binder with the same name. Later occurrences
        # in the outer body would then look free.
        previous = bounded_variables.get(variable.name, missing)
        bounded_variables[variable.name] = variable
        dfs(body)
        if previous is missing:
            del bounded_variables[variable.name]
        else:
            bounded_variables[variable.name] = previous

    def dfs(e: Expression):
        if isinstance(e, VariableExpression):
            if e.variable.name not in bounded_variables:
                free_variables[e.variable.name] = e.variable
        elif isinstance(e, ListCreationExpression):
            for arg in e.arguments:
                dfs(arg)
        elif isinstance(e, ListExpansionExpression):
            dfs(e.expression)
        elif isinstance(e, (QuantificationExpression, GeneralizedQuantificationExpression)):
            visit_bound(e.variable, e.expression)
        elif isinstance(e, (FunctionApplicationExpression, ListFunctionApplicationExpression)):
            for arg in e.arguments:
                dfs(arg)
        elif isinstance(e, BoolExpression):
            for arg in e.arguments:
                dfs(arg)
        elif isinstance(e, (ObjectCompareExpression, ValueCompareExpression)):
            dfs(e.lhs)
            dfs(e.rhs)
        elif isinstance(e, PredicateEqualExpression):
            dfs(e.predicate)
            dfs(e.value)
        elif isinstance(e, AssignExpression):
            dfs(e.value)
        elif isinstance(e, ConditionalSelectExpression):
            dfs(e.predicate)
            dfs(e.condition)
        elif isinstance(e, ConditionalAssignExpression):
            dfs(e.value)
            dfs(e.condition)
        elif isinstance(e, (DeicticSelectExpression, DeicticAssignExpression)):
            visit_bound(e.variable, e.expression)
        elif isinstance(e, (FindOneExpression, FindAllExpression)):
            visit_bound(e.variable, e.expression)
        elif isinstance(e, ConditionExpression):
            dfs(e.condition)
            dfs(e.true_value)
            dfs(e.false_value)
        elif isinstance(e, (ConstantExpression, ObjectConstantExpression)):
            pass
        else:
            raise TypeError('Unknown expression type: {}.'.format(type(e)))

    dfs(expr)
    return tuple(free_variables.values())