Source code for concepts.dm.crow.executors.crow_executor

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : crow_executor.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/17/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""This file contains the executor classes for the Crow framework.

All the executors are based on :class:`~concepts.dsl.tensor_value.TensorValue` classes.
It supports all expressions defined in :mod:`~concepts.dsl.expression`, including the basic
function application expressions and a few first-order logic expression types. The executors
are designed to be "differentiable", which means we can directly do backpropagation on the
computed values.

The main entry for the executor is the :class:`CrowExecutor` class.
Internally it contains two visitor implementations: the default one (:class:`CrowExecutionDefaultVisitor`),
and a CSP-aware one (:class:`CrowExecutionCSPVisitor`), which handles the case where the value of a variable can be unknown ("optimistic").
"""

import itertools
import contextlib
from typing import Any, Optional, Union, Iterator, Sequence, Tuple, List, Mapping, Dict, Callable, TYPE_CHECKING

import torch
import jactorch
from jacinle.logging import get_logger

import concepts.dsl.expression as E
from concepts.dsl.dsl_types import BOOL, INT64, FLOAT32, STRING, AutoType, TensorValueTypeBase, ScalarValueType, VectorValueType, NamedTensorValueType, PyObjValueType, BatchedListType, QINDEX, Variable
from concepts.dsl.expression import Expression, BoolOpType, QuantificationOpType
from concepts.dsl.expression_visitor import ExpressionVisitor
from concepts.dsl.value import ListValue
from concepts.dsl.tensor_value import TensorizedPyObjValues, TensorValue, MaskedTensorStorage
from concepts.dsl.tensor_state import StateObjectReference, StateObjectList, TensorState, NamedObjectTensorState
from concepts.dsl.tensor_value_utils import expand_argument_values
from concepts.dsl.constraint import (
    OPTIM_MAGIC_NUMBER_MAGIC, is_optimistic_value, OptimisticValue, cvt_opt_value,
    Constraint, EqualityConstraint, ConstraintSatisfactionProblem, SimulationFluentConstraintFunction
)
from concepts.dsl.parsers.parser_base import ParserBase
from concepts.dsl.executors.tensor_value_executor import BoundedVariablesDictCompatible, TensorValueExecutorBase, TensorValueExecutorReturnType

from concepts.dm.crow.crow_function import CrowFeature, CrowFunction, CrowFunctionEvaluationMode
from concepts.dm.crow.crow_domain import CrowDomain, CrowState
from concepts.dm.crow.executors.python_function import CrowPythonFunctionRef, CrowPythonFunctionCrossRef, CrowSGC

if TYPE_CHECKING:
    from concepts.dm.crow.interfaces.controller_interface import CrowSimulationControllerInterface


# Module-level logger, created from this file's path.
logger = get_logger(__file__)

# Public API of this module.
__all__ = ['CrowExecutor', 'CrowExecutionDefaultVisitor', 'CrowExecutionCSPVisitor']


class CrowExecutor(TensorValueExecutorBase):
    """Executor for Crow expressions over :class:`~concepts.dsl.tensor_value.TensorValue` objects.

    Expressions are evaluated by :class:`CrowExecutionCSPVisitor` when a CSP is attached to the executor, and by
    :class:`CrowExecutionDefaultVisitor` otherwise. The executor also holds the Python implementations of the
    domain functions and a few context variables (CSP, SGC, optimistic-execution flag, effect-update mode) that
    can be overridden temporarily with the ``with_*`` / :meth:`update_effect_mode` context managers.
    """

    _domain: CrowDomain
    _function_implementations: Dict[str, Union[CrowPythonFunctionRef, CrowPythonFunctionCrossRef]]

    def __init__(self, domain: CrowDomain, parser: Optional[ParserBase] = None, load_external_function_implementations: bool = True):
        """Initialize a Crow expression executor.

        Args:
            domain: the domain of this executor.
            parser: the parser to be used. This argument is optional. If provided, the execute function can take strings as input.
            load_external_function_implementations: whether to load the external function implementations defined in the domain file.
        """
        super().__init__(domain, parser)

        self._csp = None
        self._optimistic_execution = False
        self._sgc = None
        self._default_visitor = CrowExecutionDefaultVisitor(self)
        self._csp_visitor = CrowExecutionCSPVisitor(self)
        self._register_default_function_implementations()
        if load_external_function_implementations:
            self._register_external_function_implementations_from_domain()

        self._simulation_interface = None
        self._effect_update_from_simulation = False
        self._effect_update_from_execution = False
        self._effect_state_index = None

    @property
    def simulation_interface(self) -> Optional['CrowSimulationControllerInterface']:
        """The simulation controller interface set by :meth:`set_simulation_interface` (None if never set)."""
        return self._simulation_interface

    def set_simulation_interface(self, simulation_interface: 'CrowSimulationControllerInterface'):
        """Attach a simulation controller interface to this executor."""
        self._simulation_interface = simulation_interface

    @property
    def csp(self) -> Optional[ConstraintSatisfactionProblem]:
        """The CSP that describes the constraints in past executions."""
        return self._csp

    @property
    def sgc(self) -> Optional[CrowSGC]:
        """The SGC (state-goal-constraints) context."""
        return self._sgc

    @property
    def optimistic_execution(self) -> bool:
        """Whether to execute the expression optimistically (i.e., treat all CSP constraints True)."""
        return self._optimistic_execution

    @contextlib.contextmanager
    def with_csp(self, csp: Optional[ConstraintSatisfactionProblem]):
        """A context manager to temporarily set the CSP of the executor.

        The previous CSP is restored even if the body of the ``with`` statement raises.
        """
        old_csp = self._csp
        self._csp = csp
        try:
            yield
        finally:
            self._csp = old_csp

    @contextlib.contextmanager
    def with_sgc(self, sgc: Optional[CrowSGC]):
        """A context manager to temporarily set the SGC of the executor.

        The previous SGC is restored even if the body of the ``with`` statement raises.
        """
        old_sgc = self._sgc
        self._sgc = sgc
        try:
            yield
        finally:
            self._sgc = old_sgc

    def _register_default_function_implementations(self):
        """Register built-in operator implementations for the domain types and the builtin BOOL/INT64/FLOAT32/STRING.

        - Scalar types (and named types with a scalar parent): add/sub/mul/div/neg, the six comparisons, and hash.
        - Named types with a vector parent: add/sub/mul/div/neg and the six comparisons (reduced over the last dim).
        - String PyObj types: equal / not_equal.

        The cross references declared in the domain are registered afterwards.
        """
        # NB(Jiayuan Mao @ 2024/07/11): Can't use the lambda function here, because the value of `t` is not captured.
        # (This is why the type is passed eagerly to the _Crow*FunctionImpl objects instead of being closed over.)
        arith_ops = (
            ('add', lambda x, y: x + y),
            ('sub', lambda x, y: x - y),
            ('mul', lambda x, y: x * y),
            ('div', lambda x, y: x / y),
        )
        scalar_comparison_ops = (
            ('equal', torch.eq),
            ('not_equal', torch.ne),
            ('less', torch.lt),
            ('less_equal', torch.le),
            ('greater', torch.gt),
            ('greater_equal', torch.ge),
        )

        def all_components(op):
            # Vector relation that holds only if it holds for every component (last dimension).
            return lambda x, y: op(x, y).all(dim=-1)

        vector_comparison_ops = (
            ('equal', all_components(torch.eq)),
            # Two vectors are "not equal" iff they differ in at least one component, i.e. the negation of `equal`.
            ('not_equal', lambda x, y: torch.ne(x, y).any(dim=-1)),
            ('less', all_components(torch.lt)),
            ('less_equal', all_components(torch.le)),
            ('greater', all_components(torch.gt)),
            ('greater_equal', all_components(torch.ge)),
        )

        def register_batched(typename, opname, impl):
            self.register_function_implementation(
                f'type::{typename}::{opname}',
                CrowPythonFunctionRef(impl, support_batch=True, auto_broadcast=True, unwrap_values=False)
            )

        def compare(x, y):
            return TensorValue(BOOL, x.batch_variables, torch.tensor(x.tensor.values == y.tensor.values, dtype=torch.bool), x.batch_dims)

        def compare_neg(x, y):
            return TensorValue(BOOL, x.batch_variables, torch.tensor(x.tensor.values != y.tensor.values, dtype=torch.bool), x.batch_dims)

        for t in itertools.chain(self.domain.types.values(), [BOOL, INT64, FLOAT32, STRING]):
            is_scalar = isinstance(t, ScalarValueType) or (isinstance(t, NamedTensorValueType) and isinstance(t.parent_type, ScalarValueType))
            is_vector = isinstance(t, NamedTensorValueType) and isinstance(t.parent_type, VectorValueType)

            if is_scalar or is_vector:
                for opname, op in arith_ops:
                    register_batched(t.typename, opname, _CrowArithFunctionImpl(t, op))
                register_batched(t.typename, 'neg', _CrowUnaryFunctionImpl(t, lambda x: -x))
                for opname, op in (scalar_comparison_ops if is_scalar else vector_comparison_ops):
                    register_batched(t.typename, opname, _CrowComparisonFunctionImpl(BOOL, op))
                if is_scalar:
                    self.register_function_implementation(
                        f'type::{t.typename}::hash',
                        CrowPythonFunctionRef(lambda x: x.tensor.item(), support_batch=False, unwrap_values=False)
                    )
            elif isinstance(t, PyObjValueType) and t.base_typename == 'string':
                register_batched(t.typename, 'equal', compare)
                register_batched(t.typename, 'not_equal', compare_neg)

        for fname, cross_ref_name in self.domain.external_function_crossrefs.items():
            self.register_function_implementation(fname, CrowPythonFunctionCrossRef(cross_ref_name))

    def _register_external_function_implementations_from_domain(self):
        """Load the external function implementation files listed by the domain."""
        for filepath in self.domain.external_function_implementation_files:
            self.load_external_function_implementations_from_file(filepath)

    def load_external_function_implementations_from_file(self, filepath: str):
        """Import the Python file at `filepath` and register every module-level CrowPythonFunctionRef under its variable name."""
        from jacinle.utils.imp import load_module_filename

        module = load_module_filename(filepath)
        for name, func in module.__dict__.items():
            if isinstance(func, CrowPythonFunctionRef):
                self.register_function_implementation(name, func)

    @property
    def domain(self) -> CrowDomain:
        """The domain of this executor."""
        return self._domain

    @property
    def state(self) -> CrowState:
        """The state of the executor."""
        return self._state

    @property
    def effect_update_from_simulation(self) -> bool:
        """A context variable indicating whether the current effect should be updated from simulation, instead of the evaluation of expressions."""
        return self._effect_update_from_simulation

    @property
    def effect_update_from_execution(self) -> bool:
        """A context variable indicating whether the current effect should be updated from the execution of the operator."""
        return self._effect_update_from_execution

    @property
    def effect_state_index(self) -> Optional[int]:
        """The state index set by the active :meth:`update_effect_mode` context (None outside of it)."""
        return self._effect_state_index

    def parse(self, string: Union[str, Expression], *, state: Optional[CrowState] = None, variables: Optional[Sequence[Variable]] = None) -> Expression:
        """Parse `string` with the domain parser; an already-built Expression is returned unchanged."""
        if isinstance(string, Expression):
            return string
        return self._domain.parse(string, state=state, variables=variables)

    @property
    def function_implementations(self) -> Dict[str, Union[CrowPythonFunctionRef, CrowPythonFunctionCrossRef]]:
        """The registered function implementations (including unresolved cross references)."""
        return self._function_implementations

    def register_function_implementation(self, name: str, func: Union[Callable, CrowPythonFunctionRef, CrowPythonFunctionCrossRef]):
        """Register an implementation for the function `name`.

        CrowPythonFunctionRef objects are bound to this executor; cross references are stored as-is; any other
        callable is wrapped into a CrowPythonFunctionRef.
        """
        if isinstance(func, CrowPythonFunctionRef):
            self._function_implementations[name] = func.set_executor(self)
        elif isinstance(func, CrowPythonFunctionCrossRef):
            self._function_implementations[name] = func
        else:
            self._function_implementations[name] = CrowPythonFunctionRef(func)

    def get_function_implementation(self, name: str) -> CrowPythonFunctionRef:
        """Resolve `name` to a concrete implementation, following cross references.

        Raises:
            KeyError: if the (possibly cross-referenced) name is not registered, or if the cross references form a cycle.
        """
        visited = set()
        while name in self._function_implementations:
            if name in visited:
                # Without this check a cyclic chain of cross references would loop forever.
                raise KeyError(f'Function {name} has a cyclic cross reference chain: {sorted(visited)}.')
            visited.add(name)
            func = self._function_implementations[name]
            if isinstance(func, CrowPythonFunctionCrossRef):
                name = func.cross_ref_name
            else:
                return func
        raise KeyError(f'Function {name} not found.')

    def execute(
        self, expression: Union[Expression, str], state: Optional[TensorState] = None,
        bounded_variables: Optional[BoundedVariablesDictCompatible] = None,
        csp: Optional[ConstraintSatisfactionProblem] = None,
        sgc: Optional[CrowSGC] = None,
        bypass_bounded_variable_check: bool = False,
        optimistic_execution: bool = False
    ) -> TensorValueExecutorReturnType:
        """Execute an expression.

        Args:
            expression: the expression to execute.
            state: the state to use. If None, the current state of the executor will be used.
            bounded_variables: the bounded variables to use. If None, the current bounded variables of the executor will be used.
            csp: the constraint satisfaction problem to use. If None, the current CSP of the executor will be used.
            sgc: the SGC (state-goal-constraints) context to use. If None, the current SGC context of the executor will be used.
            bypass_bounded_variable_check: whether to bypass the check of the bounded variables.
            optimistic_execution: whether to execute the expression optimistically (i.e., treat all CSP constraints True).

        Returns:
            the TensorValue object.
        """
        if isinstance(expression, str):
            # The keys of `bounded_variables` become the free variables visible to the parser.
            all_variables = list()
            if bounded_variables is not None:
                for k in bounded_variables.keys():
                    if isinstance(k, str):
                        all_variables.append(Variable(k, AutoType))
                    elif isinstance(k, Variable):
                        all_variables.append(k)
            expression = self.parse(expression, state=state, variables=all_variables)

        state = state if state is not None else self._state
        csp = csp if csp is not None else self._csp
        sgc = sgc if sgc is not None else self._sgc
        bounded_variables = bounded_variables if bounded_variables is not None else self._bounded_variables

        with self.with_state(state), self.with_csp(csp), self.with_sgc(sgc), self.with_bounded_variables(bounded_variables, bypass_bounded_variable_check=bypass_bounded_variable_check):
            backup = self._optimistic_execution
            self._optimistic_execution = optimistic_execution
            try:
                return self._execute(expression)
            finally:
                self._optimistic_execution = backup

    def _execute(self, expression: Expression) -> TensorValueExecutorReturnType:
        """Dispatch to the CSP visitor when a CSP is attached, otherwise to the default visitor."""
        if self.csp is not None:
            return self._csp_visitor.visit(expression)
        return self._default_visitor.visit(expression)

    @contextlib.contextmanager
    def update_effect_mode(self, evaluation_mode: CrowFunctionEvaluationMode, state_index: Optional[int] = None):
        """Temporarily set how effects are updated (from simulation or from execution) and the effect state index.

        The previous settings are restored even if the body of the ``with`` statement raises.
        """
        old_from_simulation = self._effect_update_from_simulation
        old_from_execution = self._effect_update_from_execution
        old_state_index = self._effect_state_index
        self._effect_update_from_simulation = evaluation_mode is CrowFunctionEvaluationMode.SIMULATION
        self._effect_update_from_execution = evaluation_mode is CrowFunctionEvaluationMode.EXECUTION
        self._effect_state_index = state_index
        try:
            yield
        finally:
            self._effect_update_from_simulation = old_from_simulation
            self._effect_update_from_execution = old_from_execution
            self._effect_state_index = old_state_index
def _fast_index(value, ind):
    """Index `value` at the batch position `ind`.

    TensorValues are indexed with :meth:`TensorValue.fast_index`. Any other value is accepted only together with
    an empty index, in which case it is returned unchanged.

    Raises:
        ValueError: if `value` is not a TensorValue and `ind` is non-empty.
    """
    if isinstance(value, TensorValue):
        return value.fast_index(ind)
    if len(ind) == 0:
        return value
    raise ValueError(f'Unsupported value type: {value}.')


class _CrowUnaryFunctionImpl(object):
    """Apply a unary tensor operation `op` to a TensorValue, producing a TensorValue of type `dtype`."""

    def __init__(self, dtype, op):
        self.dtype = dtype
        self.op = op

    def __call__(self, x):
        return TensorValue(self.dtype, x.batch_variables, self.op(x.tensor), x.batch_dims)


class _CrowArithFunctionImpl(object):
    """Apply a binary tensor operation `op` to two TensorValues, producing a TensorValue of type `dtype`.

    The result reuses the batch variables / dims of `x`; presumably the inputs are already broadcast against
    each other (these objects are registered with ``auto_broadcast=True``).
    """

    def __init__(self, dtype, op):
        self.dtype = dtype
        self.op = op

    def __call__(self, x, y):
        return TensorValue(self.dtype, x.batch_variables, self.op(x.tensor, y.tensor), x.batch_dims)


class _CrowComparisonFunctionImpl(object):
    """Apply a binary comparison `op` to two TensorValues, producing a TensorValue of type `dtype` (BOOL in practice).

    The result reuses the batch variables / dims of `x`.
    """

    def __init__(self, dtype, op):
        self.dtype = dtype
        self.op = op

    def __call__(self, x, y):
        # Use the configured result type instead of silently ignoring it.
        return TensorValue(self.dtype, x.batch_variables, self.op(x.tensor, y.tensor), x.batch_dims)
[docs] class CrowExecutionDefaultVisitor(ExpressionVisitor): """The underlying default implementation for :class:`CrowExecutor`. This function does not handle CSPs (a.k.a. optimistic execution)."""
def __init__(self, executor: CrowExecutor):
    """Initialize a CrowExecutionDefaultVisitor.

    Args:
        executor: the executor that uses this visitor.
    """
    self.executor = executor
@property
def csp(self) -> Optional[ConstraintSatisfactionProblem]:
    """The CSP attached to the owning executor, or None when no CSP is active."""
    return self.executor.csp
def visit_null_expression(self, expr: E.NullExpression) -> Any:
    """A null expression carries no value: the result is always None."""
    result = None
    return result
def visit_variable_expression(self, expr: E.VariableExpression) -> TensorValueExecutorReturnType:
    """Look up the value bound to a variable.

    Variables with a concrete dtype are looked up in ``bounded_variables[typename][name]``; variables typed
    ``AutoType`` are resolved by name only.
    """
    var = expr.variable
    if var.dtype is not AutoType:
        return self.executor.bounded_variables[var.dtype.typename][var.name]
    return self.executor.retrieve_bounded_variable_by_name(var.name)
def visit_object_constant_expression(self, expr: E.ObjectConstantExpression) -> Union[StateObjectReference, ListValue]:
    """Resolve an object constant into state object reference(s).

    If the constant's ``name`` already is a StateObjectReference / StateObjectList it is returned directly.
    Otherwise the current state (which must be a NamedObjectTensorState) provides the typed index of each object;
    a ListValue constant yields a StateObjectList, any other constant a single StateObjectReference.
    """
    constant = expr.constant
    if isinstance(constant.name, (StateObjectReference, StateObjectList)):
        return constant.name

    state = self.executor.state
    assert isinstance(state, NamedObjectTensorState)

    def make_reference(obj):
        return StateObjectReference(obj.name, state.get_typed_index(obj.name, obj.dtype.typename), obj.dtype)

    if not isinstance(constant, ListValue):
        return make_reference(constant)
    return StateObjectList(constant.dtype, [make_reference(item) for item in constant.values])
def visit_constant_expression(self, expr: E.ConstantExpression) -> TensorValueExecutorReturnType:
    """Return the constant stored in the expression (a TensorValue or a ListValue)."""
    constant_value = expr.constant
    assert isinstance(constant_value, (TensorValue, ListValue))
    return constant_value
def visit_list_creation_expression(self, expr: E.ListCreationExpression) -> Any:
    """Evaluate every element expression and pack the results into a ListValue of the expression's type."""
    elements = self.forward_args(*expr.arguments, force_tuple=True)
    list_type = expr.return_type
    return ListValue(list_type, elements)
def visit_list_expansion_expression(self, expr: E.ListExpansionExpression) -> Any:
    """List expansion cannot be evaluated on its own.

    Raises:
        RuntimeError: always.
    """
    message = 'List expansion is not supported in the expression evaluation.'
    raise RuntimeError(message)
def visit_function_application_expression(
    self, expr: E.FunctionApplicationExpression,
    argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None
) -> TensorValueExecutorReturnType:
    """Evaluate a function application.

    The cases are checked in this order:

    1. Generator placeholders (``CrowFunction.is_generator_placeholder``): the result is an all-True boolean
       tensor. When a CSP is attached and execution is not optimistic, one constraint per batch index is added
       to the CSP asserting that the function evaluates to True.
    2. Cacheable functions stored in ``state.features``: the value is read from the state. If the feature is
       listed in ``state.internals['dirty_features']`` and some optimistic value is negative, it is recomputed
       (via the derived expression or the external implementation).
    3. Derived functions: the derived expression is evaluated with the function arguments bound.
    4. Otherwise: the external (Python) implementation is called.

    Args:
        expr: the function application expression.
        argument_values: pre-computed argument values; if None, they are computed from ``expr.arguments``.

    Returns:
        the value of the function application (a TensorValue in cases 1-2; cases 3-4 return whatever the
        derived expression / external function produces, with batch variables renamed when it is a TensorValue).
    """
    function = expr.function
    return_type = function.return_type
    state = self.executor.state
    assert isinstance(function, (CrowFeature, CrowFunction))

    if argument_values is None:
        argument_values = self.forward_args(*expr.arguments, force_tuple=True)
    has_list_values = any(isinstance(v, ListValue) for v in argument_values)

    if isinstance(function, CrowFunction) and function.is_generator_placeholder:  # always-true branch: the result below is all True.
        if has_list_values:
            raise NotImplementedError('List values are not supported in the generator placeholder function.')
        argument_values = expand_argument_values(argument_values)
        # Use the first TensorValue argument to determine the batch shape / device of the result.
        batched_value = None
        for argv in argument_values:
            if isinstance(argv, TensorValue):
                batched_value = argv
                break
        assert batched_value is not None

        rv = torch.ones(
            batched_value.tensor.shape[:batched_value.total_batch_dims],
            dtype=torch.bool, device=batched_value.tensor.device if isinstance(batched_value.tensor, torch.Tensor) else None
        )
        assert return_type == BOOL

        # Add "true" asserts to the csp.
        if self.csp is not None and not self.executor.optimistic_execution:
            expr_string = expr.cached_string(-1)
            for ind in _iter_tensor_indices(rv):
                self.csp.add_constraint(Constraint.from_function(
                    function,
                    # I think there is some bug here... for "StateObjectReference"
                    # (non-TensorValue arguments, e.g. StateObjectReference, are passed through without indexing.)
                    [argv.fast_index(tuple(ind)) if isinstance(argv, TensorValue) else argv for argv in argument_values],
                    True
                ), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)

        return TensorValue(
            BOOL, batched_value.batch_variables,
            rv, batch_dims=state.batch_dims
        )
    elif function.is_cacheable and state is not None and function.name in state.features:  # state maybe None if we are evaluating a "constraint" function.
        # Replace object references by their integer indices so they can be used to index the feature tensor.
        argument_values = [v.index if isinstance(v, StateObjectReference) else v for v in argument_values]

        # Build the tensor accessor: QINDEX arguments become named batch variables, StateObjectList arguments
        # become anonymous batch variables ("@0", "@1", ...) indexed through their array accessor.
        batch_variables = list()
        anonymous_index = 0
        accessors = list()
        index_types = list()
        for i, (arg, value) in enumerate(zip(expr.arguments, argument_values)):
            if value == QINDEX:
                batch_variables.append(arg.variable.name)
                accessors.append(value)
            elif isinstance(value, StateObjectList):
                batch_variables.append(f'@{anonymous_index}')
                index_types.append(value.element_type)
                anonymous_index += 1
                accessors.append(value.array_accessor)
            else:
                accessors.append(value)

        value = state.features[function.name][tuple(accessors)]

        if 'dirty_features' in state.internals and function.name in state.internals['dirty_features']:
            # NOTE(review): this indexes with `argument_values`, while the lookup above uses `accessors`; the two
            # differ when an argument is a StateObjectList -- confirm this is intended.
            value_opt = state.features[function.name].tensor_optimistic_values[tuple(argument_values)]
            # Negative optimistic values mark entries whose cached value cannot be trusted: recompute.
            if (value_opt < 0).any().item():
                if function.is_derived:
                    # Case 1: derived function.
                    with self.executor.with_bounded_variables({k: v for k, v in zip(function.arguments, argument_values)}):
                        return self._rename_derived_function_application(self.visit(function.derived_expression), function.arguments, expr.arguments, argument_values)
                else:
                    # Case 2: external direct function.
                    return self._rename_derived_function_application(self.forward_external_function(function.name, argument_values, return_type, expression=expr), function.arguments, expr.arguments, argument_values)

        if len(index_types) == 0:
            return value.rename_batch_variables(batch_variables, force=True)
        else:
            # StateObjectList arguments turn the result into a batched list over the listed object types.
            rv_dtype = BatchedListType(value.dtype, index_types)
            return value.rename_batch_variables(batch_variables, dtype=rv_dtype, force=True)
    elif function.is_derived:
        with self.executor.with_bounded_variables({k: v for k, v in zip(function.arguments, argument_values)}):
            return self._rename_derived_function_application(self.visit(function.derived_expression), function.arguments, expr.arguments, argument_values)
    else:
        # dynamic predicate is exactly the same thing as a pre-defined external function.
        # only supports external functions with a single return value.
        return self._rename_derived_function_application(self.forward_external_function(function.name, argument_values, return_type, expression=expr), function.arguments, expr.arguments, argument_values)
def _rename_derived_function_application(self, rv: TensorValue, function_argument_variables, outer_arguments, argument_values):
    """Rename the batch variables of a function-body result to the caller's variable names.

    When an argument is passed as ``QINDEX``, the function body produces a value batched over the
    *formal* argument's name. This maps each such batch variable to the name of the variable
    expression used at the call site.

    Args:
        rv: the value computed by the function body. Non-:class:`TensorValue` results are returned as-is.
        function_argument_variables: the formal argument variables of the function.
        outer_arguments: the argument expressions at the call site.
        argument_values: the evaluated argument values at the call site.

    Returns:
        a shallow clone (tensor not cloned) of ``rv`` with renamed batch variables, or ``rv`` itself
        when it is not a :class:`TensorValue`.
    """
    if not isinstance(rv, TensorValue):
        return rv

    renamed = list(rv.batch_variables)
    for formal, actual_expr, actual_value in zip(function_argument_variables, outer_arguments, argument_values):
        if actual_value is not QINDEX:
            continue  # only quantified arguments show up as batch variables.
        formal_name = formal.name
        assert formal_name in renamed, f'Variable {formal_name} not found in the output batch variables {renamed}. Report this as a bug.'
        assert isinstance(actual_expr, E.VariableExpression), 'Only variable arguments can be QINDEX. Report this as a bug.'
        renamed[renamed.index(formal_name)] = actual_expr.variable.name

    return rv.clone(clone_tensor=False).rename_batch_variables(renamed, clone=False)
def visit_list_function_application_expression(self, expr: E.ListFunctionApplicationExpression) -> Any:
    """List function applications are no longer supported by this visitor.

    Raises:
        DeprecationWarning: always; nothing is evaluated.
    """
    # The former list-mapping implementation sat after this unconditional raise and could never
    # run, so it has been removed.
    message = 'List function application is deprecated.'
    raise DeprecationWarning(message)
def visit_bool_expression(self, expr: E.BoolExpression, argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None) -> TensorValueExecutorReturnType:
    """Evaluate a boolean connective element-wise over tensor values.

    NOT negates; AND / OR take the element-wise minimum / maximum; XOR sums modulo 2; IMPLIES computes
    ``max(not a, b)``. Negation uses ``torch.logical_not`` for bool tensors (PyTorch rejects ``1 - t``
    on bool tensors) and ``1 - t`` otherwise, so "soft" float truth values keep working.

    The result takes its dtype, batch variables and mask from the first argument.

    Args:
        expr: the boolean expression.
        argument_values: pre-computed argument values; evaluated from ``expr.arguments`` when None.

    Returns:
        the combined value. For AND / OR / XOR with a single argument, that argument is returned unchanged.

    Raises:
        RuntimeError: for XOR on inputs that require gradients.
        ValueError: for an unknown boolean operator.
    """
    if argument_values is None:
        argument_values = self.forward_args(*expr.arguments, force_tuple=True, expand_list_arguments=True)
    argument_values = expand_argument_values(argument_values)

    assert len(argument_values) > 0
    assert all(isinstance(v, TensorValue) for v in argument_values)

    dtype = argument_values[0].dtype
    batch_variables = argument_values[0].batch_variables
    tensor_mask = argument_values[0].tensor_mask

    def negate(tensor: torch.Tensor) -> torch.Tensor:
        # `1 - t` raises for bool tensors in PyTorch.
        return torch.logical_not(tensor) if tensor.dtype == torch.bool else 1 - tensor

    def make_result(tensor: torch.Tensor) -> TensorValue:
        return TensorValue(
            dtype, batch_variables,
            MaskedTensorStorage(tensor, None, tensor_mask),
            batch_dims=self.executor.state.batch_dims
        )

    bool_op = expr.bool_op
    if bool_op is BoolOpType.NOT:
        assert len(argument_values) == 1
        # The mask is carried over, consistent with the other connectives.
        return make_result(negate(argument_values[0].tensor))

    if bool_op is BoolOpType.AND:
        if len(argument_values) == 1:
            return argument_values[0]
        return make_result(torch.stack([argv.tensor for argv in argument_values], dim=-1).amin(dim=-1))

    if bool_op is BoolOpType.OR:
        if len(argument_values) == 1:
            return argument_values[0]
        return make_result(torch.stack([argv.tensor for argv in argument_values], dim=-1).amax(dim=-1))

    if bool_op is BoolOpType.XOR:
        if len(argument_values) == 1:
            return argument_values[0]
        if any(argv.tensor.requires_grad for argv in argument_values):
            raise RuntimeError('XOR does not support gradients.')
        return make_result(torch.stack([argv.tensor for argv in argument_values], dim=-1).sum(dim=-1) % 2)

    if bool_op is BoolOpType.IMPLIES:
        assert len(argument_values) == 2
        not_premise = negate(argument_values[0].tensor)
        conclusion = argument_values[1].tensor
        if not_premise.dtype == torch.bool and conclusion.dtype == torch.bool:
            implied = torch.logical_or(not_premise, conclusion)
        else:
            # Mixed / numeric truth values: a <= b in the soft sense is max(1 - a, b).
            common_dtype = torch.promote_types(not_premise.dtype, conclusion.dtype)
            implied = torch.max(not_premise.to(common_dtype), conclusion.to(common_dtype))
        return make_result(implied)

    raise ValueError(f'Unknown bool op type: {expr.bool_op}.')
def visit_quantification_expression(self, expr: E.QuantificationExpression, value: Optional[TensorValue] = None) -> TensorValueExecutorReturnType:
    """Reduce a value over the quantified variable: FORALL via ``amin``, EXISTS via ``amax``.

    When ``value`` is not supplied, ``expr.expression`` is evaluated with ``expr.variable`` bound to QINDEX.
    A BATCHED quantifier returns the value unchanged. When the value carries a mask, masked-out entries
    are neutralized before reducing (``t * m + (1 - m)`` for FORALL, ``t * m`` for EXISTS, which assumes a
    0/1 mask).

    Raises:
        ValueError: for an unknown quantifier.
    """
    if value is None:
        with self.executor.new_bounded_variables({expr.variable: QINDEX}):
            value = self.forward_args(expr.expression)
    assert isinstance(value, TensorValue)

    op = expr.quantification_op
    if op is QuantificationOpType.BATCHED:
        return value

    all_variables = value.batch_variables
    position = all_variables.index(expr.variable.name)
    remaining_variables = all_variables[:position] + all_variables[position + 1:]
    reduce_dim = position + value.batch_dims

    tensor, mask = value.tensor, value.tensor_mask
    if mask is not None:
        if op is QuantificationOpType.FORALL:
            tensor = (tensor * mask + (1 - mask)).to(tensor.dtype)
        elif op is QuantificationOpType.EXISTS:
            tensor = (tensor * mask).to(tensor.dtype)
        else:
            raise ValueError(f'Unknown quantification op type: {expr.quantification_op}.')

    if op is QuantificationOpType.FORALL:
        reduced = tensor.amin(dim=reduce_dim)
    elif op is QuantificationOpType.EXISTS:
        reduced = tensor.amax(dim=reduce_dim)
    else:
        raise ValueError(f'Unknown quantifier type: {expr.quantification_op}.')

    # NOTE(review): `mask` still has the quantified dimension while `reduced` does not -- confirm whether
    # the mask should be reduced (or dropped) here.
    return TensorValue(
        value.dtype, remaining_variables,
        MaskedTensorStorage(reduced, None, mask),
        batch_dims=self.executor.state.batch_dims
    )
def visit_object_compare_expression(self, expr: E.ObjectCompareExpression) -> Any:
    """Compare two object-typed arguments for equality (EQ) or inequality (NEQ).

    Each argument is a concrete object (:class:`StateObjectReference`) or a quantified variable (QINDEX).
    A quantified variable becomes a batch dimension enumerating every object of the variable's type; if
    both sides are quantified the result is the 2-D outer comparison.

    Raises:
        ValueError: for an unsupported argument value or comparison operator.
    """
    lhs, rhs = self.forward_args(expr.arguments[0], expr.arguments[1])

    indices = []
    batch_variable_names = []
    for position, arg_value in enumerate((lhs, rhs)):
        if arg_value is QINDEX:
            arg_expr = expr.arguments[position]
            assert isinstance(arg_expr, E.VariableExpression), 'Quantified object comparison only supports variable arguments.'
            batch_variable_names.append(arg_expr.variable.name)
            object_names = self.executor.state.object_type2name[arg_expr.variable.dtype.typename]
            indices.append(torch.arange(0, len(object_names), dtype=torch.int64))
        elif isinstance(arg_value, StateObjectReference):
            indices.append(torch.tensor(arg_value.index, dtype=torch.int64))
        else:
            raise ValueError(f'Unsupported value type: {arg_value}.')

    if lhs is QINDEX and rhs is QINDEX:
        result = indices[0].unsqueeze(1).eq(indices[1].unsqueeze(0))  # result[i, j] = (i == j)
    else:
        result = indices[0].eq(indices[1])

    compare_op = expr.compare_op
    if compare_op is E.CompareOpType.NEQ:
        result = torch.logical_not(result)
    elif compare_op is not E.CompareOpType.EQ:
        raise ValueError(f'Unknown compare op type for object types: {expr.compare_op}.')

    return TensorValue(BOOL, batch_variable_names, result, batch_dims=self.executor.state.batch_dims)
def visit_value_compare_expression(self, expr: E.ValueCompareExpression) -> Any:
    """Compare two values by dispatching to the type's external comparison function.

    The implementation is looked up as ``type::<typename>::<op>``, where ``<op>`` is one of ``equal``,
    ``not_equal``, ``less``, ``less_equal``, ``greater`` or ``greater_equal``.

    Raises:
        NotImplementedError: if the left operand's dtype is not a PyObj, named-tensor or scalar type.
    """
    v1, v2 = self.forward_args(expr.arguments[0], expr.arguments[1])
    op_suffix = {
        E.CompareOpType.EQ: 'equal',
        E.CompareOpType.NEQ: 'not_equal',
        E.CompareOpType.LT: 'less',
        E.CompareOpType.LEQ: 'less_equal',
        E.CompareOpType.GT: 'greater',
        E.CompareOpType.GEQ: 'greater_equal',
    }[expr.compare_op]

    if not isinstance(v1.dtype, (PyObjValueType, NamedTensorValueType, ScalarValueType)):
        raise NotImplementedError(f'Unsupported FeatureEqual computation for dtype {v1.dtype} and {v2.dtype}.')
    return self.forward_external_function(f'type::{v1.dtype.typename}::{op_suffix}', [v1, v2], BOOL, expression=expr)
def visit_condition_expression(self, expr: E.ConditionExpression) -> Any:
    """Condition expressions cannot be evaluated by this visitor.

    Raises:
        NotImplementedError: always.
    """
    message = 'Condition expression is not supported in the expression evaluation.'
    raise NotImplementedError(message)
def visit_find_one_expression(self, expr: E.FindOneExpression) -> Any:
    """Return the lowest-index object of ``expr.variable``'s type satisfying ``expr.expression``.

    The body is evaluated with the variable bound to QINDEX; entries greater than 0.5 count as satisfied.

    Raises:
        RuntimeError: if no object satisfies the expression.
    """
    with self.executor.new_bounded_variables({expr.variable: QINDEX}):
        scores = self.visit(expr.expression)

    assert scores.batch_dims == 0
    assert len(scores.batch_variables) == 1

    satisfied = (scores.tensor > 0.5).nonzero().squeeze(-1).detach().cpu().tolist()
    if not satisfied:
        raise RuntimeError('No object found. Currently the executor does not support this case.')

    first = satisfied[0]
    object_names = self.executor.state.object_type2name[expr.variable.dtype.typename]
    return StateObjectReference(object_names[first], first, expr.variable.dtype)
def visit_find_all_expression(self, expr: E.FindAllExpression) -> Any:
    """Return every object of ``expr.variable``'s type satisfying ``expr.expression``, in index order.

    The body is evaluated with the variable bound to QINDEX; entries greater than 0.5 count as satisfied.
    An empty list is returned when nothing matches.
    """
    with self.executor.new_bounded_variables({expr.variable: QINDEX}):
        scores = self.visit(expr.expression)

    assert scores.batch_dims == 0
    assert len(scores.batch_variables) == 1

    matching = (scores.tensor > 0.5).nonzero().squeeze(-1).detach().cpu().tolist()
    object_names = self.executor.state.object_type2name[expr.variable.dtype.typename]
    references = [StateObjectReference(object_names[index], index, expr.variable.dtype) for index in matching]
    return StateObjectList(expr.return_type, references)
def visit_predicate_equal_expression(self, expr: E.PredicateEqualExpression, feature: Optional[TensorValue] = None, value: Optional[TensorValue] = None) -> TensorValueExecutorReturnType:
    """Test a predicate's value for equality with another value via ``type::<typename>::equal``.

    ``feature`` and ``value`` are evaluated from ``expr.predicate`` / ``expr.value`` unless both are given.

    Raises:
        NotImplementedError: if the feature dtype is neither a PyObj nor a named-tensor type.
    """
    if feature is None or value is None:
        feature, value = self.forward_args(expr.predicate, expr.value)
    feature, value = expand_argument_values([feature, value])

    feature_dtype = feature.dtype
    if isinstance(feature_dtype, (PyObjValueType, NamedTensorValueType)):
        return self.forward_external_function(f'type::{feature_dtype.typename}::equal', [feature, value], BOOL, expression=expr)
    raise NotImplementedError(f'Unsupported FeatureEqual computation for dtype {feature.dtype} and {value.dtype}.')
def visit_assign_expression(self, expr: E.AssignExpression):
    """Write the value of ``expr.value`` into the state feature addressed by ``expr.predicate``.

    Nothing is written when the executor is flagged to take effect values from simulation or execution.
    Object references are converted to integer indices and object lists to their array accessors; QINDEX
    arguments are kept as-is. A feature absent from the state is initialized as a dirty feature first. A
    :class:`ListValue` target is stacked into a single tensor (an empty list writes nothing). For dirty
    features, the written entries have their optimistic values reset to 0.

    Raises:
        ValueError: if the function is unknown to both the state and the domain.
    """
    executor = self.executor
    # TODO(Jiayuan Mao @ 2024/01/22): is this really the right thing to do?
    if executor.effect_update_from_simulation or executor.effect_update_from_execution:
        return

    state: CrowState = executor.state
    predicate = expr.predicate
    index = list(self.forward_args(*predicate.arguments, force_tuple=True))
    target_value = self.forward_args(expr.value)

    for position, (arg_expr, arg_value) in enumerate(zip(predicate.arguments, index)):
        if arg_value == QINDEX:
            # NOTE(review): the message mentions "comparison" although this is an assignment.
            assert isinstance(arg_expr, E.VariableExpression), 'Quantified object comparison only supports variable arguments.'
            # TODO(Jiayuan Mao @ 2024/08/3): I think we need to think about how to align the batch variable dimensions...
        elif isinstance(arg_value, StateObjectList):
            index[position] = arg_value.array_accessor
        elif isinstance(arg_value, StateObjectReference):
            index[position] = arg_value.index

    function_name = predicate.function.name
    if function_name not in state.features:
        domain = executor.domain
        if function_name not in domain.features and function_name not in domain.functions:
            raise ValueError(f'Function {function_name} not found in the state.')
        # NOTE(review): raises KeyError if the name is only in `domain.features` -- confirm that cannot happen.
        state.init_dirty_feature(domain.functions[function_name])

    if isinstance(target_value, ListValue):
        items = target_value.values
        if len(items) == 0:
            return
        target_value = TensorValue(items[0].dtype, 1 + len(items[0].batch_variables), torch.stack([item.tensor for item in items]))

    key = tuple(index)
    state.features[function_name][key] = target_value
    if 'dirty_features' in state.internals and function_name in state.internals['dirty_features']:
        state.features[function_name].tensor_optimistic_values[key] = 0
def visit_conditional_select_expression(self, expr: E.ConditionalSelectExpression) -> TensorValueExecutorReturnType:
    """Restrict ``expr.predicate`` to the entries where ``expr.condition`` holds.

    The predicate value is cloned. The clone's ``tensor_mask`` becomes the condition tensor, or the
    element-wise minimum of the existing mask and the condition tensor.
    """
    selected, condition = self.forward_args(expr.predicate, expr.condition)
    selected = selected.clone()
    existing_mask = selected.tensor_mask
    selected.tensor_mask = condition.tensor if existing_mask is None else torch.min(existing_mask, condition.tensor)
    return selected
def visit_deictic_select_expression(self, expr: E.DeicticSelectExpression) -> TensorValueExecutorReturnType:
    """Evaluate ``expr.expression`` with ``expr.variable`` bound to QINDEX (batched over all objects)."""
    bindings = {expr.variable: QINDEX}
    with self.executor.new_bounded_variables(bindings):
        return self.visit(expr.expression)
# If True, the condition is passed through `jactorch.quantize` before blending in conditional assignments.
CONDITIONAL_ASSIGN_QUANTIZE = False

def visit_conditional_assign_expression(self, expr: E.ConditionalAssignExpression):
    """Blend ``expr.value`` into the feature entry addressed by ``expr.predicate`` according to ``expr.condition``.

    The new entry is ``c * value + (1 - c) * old``, where ``c`` is the condition tensor (quantized when
    :attr:`CONDITIONAL_ASSIGN_QUANTIZE` is set). Both bool and numeric condition tensors are supported.

    Raises:
        NotImplementedError: if the assigned value is a tensorized Python-object value.
    """
    state = self.executor.state
    argument_values = self.forward_args(*expr.predicate.arguments, force_tuple=True)
    # Use a *tuple* index, as `visit_assign_expression` does: indexing a torch tensor with a Python list
    # of ints triggers advanced (row-gather) indexing instead of selecting a single entry.
    argument_values = tuple(v.index if isinstance(v, StateObjectReference) else v for v in argument_values)
    value = self.forward_args(expr.value)
    condition = self.forward_args(expr.condition)

    condition_tensor = jactorch.quantize(condition.tensor) if self.CONDITIONAL_ASSIGN_QUANTIZE else condition.tensor

    feature = state.features[expr.predicate.function.name]
    # I am not using feature.tensor[argument_values] because the current code will handle TensorizedPyObjValues too.
    origin_tensor = feature[argument_values].tensor

    if value.is_tensorized_pyobj:
        raise NotImplementedError('Cannot make conditional assignments for tensorized pyobj.')

    if condition_tensor.dim() < value.tensor.dim():
        condition_tensor = condition_tensor.unsqueeze(-1)

    # `1 - t` raises for bool tensors in PyTorch, so negate bool conditions with logical_not.
    if condition_tensor.dtype == torch.bool:
        keep_tensor = torch.logical_not(condition_tensor)
    else:
        keep_tensor = 1 - condition_tensor

    feature.tensor[argument_values] = (
        condition_tensor.to(origin_tensor.dtype) * value.tensor
        + keep_tensor.to(origin_tensor.dtype) * origin_tensor
    )
def visit_deictic_assign_expression(self, expr: E.DeicticAssignExpression):
    """Execute the inner effect ``expr.expression`` with ``expr.variable`` bound to QINDEX."""
    bindings = {expr.variable: QINDEX}
    with self.executor.new_bounded_variables(bindings):
        self.visit(expr.expression)
def forward_args(self, *args, force_tuple: bool = False, expand_list_arguments: bool = False) -> Union[TensorValueExecutorReturnType, Tuple[TensorValueExecutorReturnType, ...]]:
    """Evaluate each argument expression with :meth:`visit`.

    Args:
        *args: the expressions to evaluate.
        force_tuple: always return a tuple, even when there is a single result.
        expand_list_arguments: splice the elements of every :class:`ListValue` result into the output
            in place of the list itself.

    Returns:
        the single result when exactly one value remains (after optional list expansion) and
        ``force_tuple`` is False; otherwise a tuple of results.
    """
    # Always collect into a tuple first. Previously a single argument without `force_tuple` was left
    # un-wrapped, and list expansion then iterated over that value itself.
    rvs = tuple(self.visit(arg) for arg in args)

    if expand_list_arguments:
        expanded = []
        for rv in rvs:
            if isinstance(rv, ListValue):
                expanded.extend(rv.values)
            else:
                expanded.append(rv)
        rvs = tuple(expanded)

    if len(rvs) == 1 and not force_tuple:
        return rvs[0]
    return rvs
def forward_external_function(
    self, function_name: str, argument_values: Sequence[TensorValueExecutorReturnType],
    return_type: Union[TensorValueTypeBase, PyObjValueType], auto_broadcast: bool = True, expression: Optional[Expression] = None
) -> TensorValue:
    """Call the Python implementation registered for ``function_name``.

    Args:
        function_name: the name to look up with ``executor.get_function_implementation``.
        argument_values: the evaluated arguments.
        return_type: the expected return type, forwarded to the implementation.
        auto_broadcast: forwarded to the implementation.
        expression: the expression being evaluated. Its function definition is forwarded only when it is
            a :class:`FunctionApplicationExpression`.

    Returns:
        the result of the implementation's ``forward``.
    """
    implementation = self.executor.get_function_implementation(function_name)
    assert isinstance(implementation, CrowPythonFunctionRef)

    function_def = None
    if isinstance(expression, E.FunctionApplicationExpression):
        function_def = expression.function

    return implementation.forward(argument_values, return_type=return_type, auto_broadcast=auto_broadcast, function_def=function_def)
class CrowExecutionCSPVisitor(CrowExecutionDefaultVisitor):
    """Expression visitor that tracks optimistic (undetermined) values and emits CSP constraints.

    Compared with :class:`CrowExecutionDefaultVisitor`, methods that may receive optimistic inputs inspect
    ``tensor_optimistic_values`` (via :func:`is_optimistic_value`). For each affected output entry they
    either write ``OPTIM_MAGIC_NUMBER_MAGIC`` (when ``executor.optimistic_execution`` is set) or allocate a
    new CSP variable and add a constraint relating it to the inputs.
    """

    def __init__(self, executor: CrowExecutor):
        super().__init__(executor)

    def _call_default_external_function(
        self, need_reset_simulation_state: bool, function_name: str, argument_values, return_type,
        auto_broadcast: bool, expression
    ) -> TensorValue:
        """Call the default implementation, first restoring the simulator to the current state if requested."""
        if need_reset_simulation_state:
            with self.executor.simulation_interface.restore_context():
                self.executor.simulation_interface.restore_state_keep(self.executor.state.simulation_state, self.executor.state.simulation_state_index)
                return super().forward_external_function(function_name, argument_values, return_type=return_type, auto_broadcast=auto_broadcast, expression=expression)
        return super().forward_external_function(function_name, argument_values, return_type=return_type, auto_broadcast=auto_broadcast, expression=expression)

    def forward_external_function(
        self, function_name: str, argument_values: Sequence[TensorValueExecutorReturnType],
        return_type: Union[TensorValueTypeBase, PyObjValueType], auto_broadcast: bool = True, expression: Optional[E.FunctionApplicationExpression] = None
    ) -> TensorValue:
        """Call an external function, creating CSP variables for outputs that depend on optimistic inputs.

        For simulation/execution-dependent functions there are two cases:

        - If the CSP timestamp differs from the state's simulation index, the function is not called. Every
          output entry becomes optimistic and a zero placeholder tensor carries the result.
        - Otherwise, when the state's simulation index is positive, the simulator is restored to that
          state before the call.
        """
        need_reset_simulation_state = False
        need_simulation_optimistic_execution = False

        # This is a hack to handle the case where the expression is ValueCompareExpression
        if isinstance(expression, E.FunctionApplicationExpression):
            function = expression.function
            # Parenthesized: the former `A and B or C` parsed as `(A and B) or C`, so the attribute
            # `is_execution_dependent` was accessed even on non-Crow functions.
            if isinstance(function, CrowFunction) and (function.is_simulation_dependent or function.is_execution_dependent):
                assert self.executor.state is not None
                assert self.executor.csp is not None
                if self.executor.csp.get_state_timestamp() != self.executor.state.simulation_state_index:
                    need_simulation_optimistic_execution = True
                elif self.executor.state.simulation_state_index > 0:
                    need_reset_simulation_state = True

        argument_values = expand_argument_values(argument_values)
        tensor_values = [argv for argv in argument_values if isinstance(argv, TensorValue)]
        optimistic_masks = [is_optimistic_value(argv.tensor_optimistic_values) for argv in tensor_values if argv.tensor_optimistic_values is not None]

        if len(optimistic_masks) == 0 and not need_simulation_optimistic_execution:
            return self._call_default_external_function(need_reset_simulation_state, function_name, argument_values, return_type, auto_broadcast, expression)

        if need_simulation_optimistic_execution:
            if len(optimistic_masks) > 0:
                optimistic_mask = optimistic_masks[0].clone()
                optimistic_mask[...] = True
            elif len(tensor_values) > 0:
                optimistic_mask = torch.ones(size=tensor_values[0].tensor.shape[:tensor_values[0].total_batch_dims], dtype=torch.bool, device=tensor_values[0].tensor.device)
            elif all(isinstance(argv, StateObjectReference) for argv in argument_values):
                optimistic_mask = torch.ones(size=tuple(), dtype=torch.bool, device=None)
            else:
                raise NotImplementedError('Unsupported case for simulation dependent functions. Either all arguments are StateObjectReference or at least one of the values are TensorValue.')
            rv = TensorValue(
                return_type,
                batch_variables=tensor_values[0].batch_variables if len(tensor_values) > 0 else [],
                tensor=torch.zeros(optimistic_mask.shape),
                batch_dims=tensor_values[0].batch_dims if len(tensor_values) > 0 else 0
            )
        else:
            optimistic_mask = torch.stack(optimistic_masks, dim=-1).any(dim=-1)
            rv = self._call_default_external_function(need_reset_simulation_state, function_name, argument_values, return_type, auto_broadcast, expression)

        if optimistic_mask.sum().item() == 0:
            return rv

        rv.init_tensor_optimistic_values()
        if self.executor.optimistic_execution:
            rv.tensor_optimistic_values[optimistic_mask.nonzero(as_tuple=True)] = OPTIM_MAGIC_NUMBER_MAGIC
            return rv

        expr_string = expression.cached_string(-1)
        if isinstance(expression, E.ValueCompareExpression):
            if expression.compare_op is E.CompareOpType.EQ:
                constraint_function = Constraint.EQUAL
            else:
                raise NotImplementedError(f'Unsupported compare op type: {expression.compare_op} for optimistic execution.')
        elif isinstance(expression, E.PredicateEqualExpression):
            constraint_function = Constraint.EQUAL
        elif isinstance(expression, E.FunctionApplicationExpression):
            constraint_function = expression.function
        else:
            raise NotImplementedError(f'Unsupported expression type: {expression} for optimistic execution.')

        for ind in optimistic_mask.nonzero().tolist():
            ind = tuple(ind)
            new_identifier = self.executor.csp.new_var(return_type, wrap=True)
            rv.tensor_optimistic_values[ind] = new_identifier.identifier
            self.csp.add_constraint(Constraint.from_function(
                constraint_function,
                [_fast_index(argv, ind) for argv in argument_values],
                new_identifier
            ), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
        return rv

    def visit_function_application_expression(self, expr: E.FunctionApplicationExpression, argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None) -> TensorValueExecutorReturnType:
        return super().visit_function_application_expression(expr, argument_values)

    def visit_bool_expression(self, expr: E.BoolExpression, argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None) -> TensorValueExecutorReturnType:
        """Evaluate a boolean connective, adding a CSP variable per optimistic output entry.

        For AND (OR), an entry is determined without a CSP variable when some concrete argument is
        false (true).
        """
        if argument_values is None:
            argument_values = self.forward_args(*expr.arguments, force_tuple=True, expand_list_arguments=True)
        argument_values = list(expand_argument_values(argument_values))
        for argv in argument_values:
            assert argv.dtype == BOOL, 'Boolean expression only supports boolean values in CSP mode.'

        optimistic_masks = [is_optimistic_value(argv.tensor_optimistic_values) for argv in argument_values if isinstance(argv, TensorValue) and argv.tensor_optimistic_values is not None]
        if len(optimistic_masks) == 0:
            return super().visit_bool_expression(expr, argument_values)
        optimistic_mask = torch.stack(optimistic_masks, dim=-1).any(dim=-1)
        if optimistic_mask.sum().item() == 0:
            return super().visit_bool_expression(expr, argument_values)

        rv = super().visit_bool_expression(expr, argument_values)
        rv.init_tensor_optimistic_values()
        if self.executor.optimistic_execution:
            rv.tensor_optimistic_values[optimistic_mask.nonzero(as_tuple=True)] = OPTIM_MAGIC_NUMBER_MAGIC
            return rv

        expr_string = expr.cached_string(-1)
        for ind in optimistic_mask.nonzero().tolist():
            ind = tuple(ind)
            this_argv = [argv.fast_index(ind, wrap=False) for argv in argument_values]
            determined = None
            if expr.return_type == BOOL:
                if expr.bool_op is BoolOpType.NOT:
                    pass  # nothing we can do.
                elif expr.bool_op is BoolOpType.AND:
                    if 0 in this_argv or False in this_argv:
                        determined = False
                elif expr.bool_op is BoolOpType.OR:
                    if 1 in this_argv or True in this_argv:
                        determined = True
                # NOTE(review): dropping concrete arguments is only sound for AND/OR (where they are identity
                # elements here); for XOR/IMPLIES it changes the constraint's meaning -- confirm.
                this_argv = [v for v in this_argv if isinstance(v, OptimisticValue)]
            # else: generalized boolean operations -- no simplification.

            if determined is None:
                new_identifier = self.csp.new_var(BOOL)
                rv.tensor_optimistic_values[ind] = new_identifier
                self.csp.add_constraint(Constraint(
                    expr.bool_op,
                    this_argv,
                    cvt_opt_value(new_identifier, BOOL),
                ), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
            else:
                rv[ind] = determined
        return rv

    def visit_quantification_expression(self, expr: E.QuantificationExpression, value: Optional[TensorValue] = None) -> Any:
        """Evaluate FORALL / EXISTS, adding a CSP variable per output entry that depends on optimistic inputs.

        FORALL (EXISTS) is determined without a CSP variable when some concrete entry is false (true).
        """
        if value is None:
            with self.executor.new_bounded_variables({expr.variable: QINDEX}):
                value = self.forward_args(expr.expression)
        assert isinstance(value, TensorValue)
        assert value.dtype == BOOL, 'Quantification expression only supports boolean values in CSP mode.'

        value.init_tensor_optimistic_values()
        rv = super().visit_quantification_expression(expr, value)

        dim = value.batch_variables.index(expr.variable.name) + value.batch_dims
        # Move the quantified dimension last while keeping the other dimensions in order, so that indices
        # into the reduced mask line up with `rv`. (`transpose(dim, -1)` swapped the former last dimension
        # into position `dim`, misaligning indices for tensors with 3+ dimensions.)
        value_moved = value.tensor.movedim(dim, -1)
        optimistic_values_moved = value.tensor_optimistic_values.movedim(dim, -1)
        optimistic_mask_moved = is_optimistic_value(optimistic_values_moved)
        # Merge concrete truth values and optimistic identifiers into a single integer tensor.
        merged_moved = torch.where(
            optimistic_mask_moved, optimistic_values_moved,
            value_moved.to(optimistic_values_moved.dtype)
        )
        optimistic_mask = optimistic_mask_moved.any(dim=-1)
        if optimistic_mask.sum().item() == 0:
            return rv

        rv.init_tensor_optimistic_values()
        if self.executor.optimistic_execution:
            rv.tensor_optimistic_values[optimistic_mask.nonzero(as_tuple=True)] = OPTIM_MAGIC_NUMBER_MAGIC
            return rv

        expr_string = expr.cached_string(-1)
        for ind in optimistic_mask.nonzero().tolist():
            ind = tuple(ind)
            this_argv = merged_moved[ind].tolist()
            determined = None
            if expr.quantification_op is QuantificationOpType.FORALL:
                if 0 in this_argv or False in this_argv:
                    determined = False
            else:
                if 1 in this_argv or True in this_argv:
                    determined = True
            this_argv = list(filter(is_optimistic_value, this_argv))

            if determined is None:
                new_identifier = self.csp.new_var(BOOL)
                rv.tensor_optimistic_values[ind] = new_identifier
                self.csp.add_constraint(Constraint(
                    expr.quantification_op,
                    [OptimisticValue(value.dtype, int(v)) for v in this_argv],
                    OptimisticValue(value.dtype, new_identifier),
                ), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
            else:
                rv.tensor[ind] = determined
        return rv

    def visit_predicate_equal_expression(self, expr: E.PredicateEqualExpression, feature: Optional[TensorValue] = None, value: Optional[TensorValue] = None) -> Any:
        """Predicate equality; optimistic operands are not supported and raise NotImplementedError."""
        if feature is None or value is None:
            feature, value = self.forward_args(expr.predicate, expr.value)
        feature, value = expand_argument_values([feature, value])
        feature.init_tensor_optimistic_values()
        value.init_tensor_optimistic_values()
        optimistic_mask = torch.logical_or(is_optimistic_value(feature.tensor_optimistic_values), is_optimistic_value(value.tensor_optimistic_values))
        if optimistic_mask.sum().item() > 0:
            raise NotImplementedError('Optimistic execution is not supported for predicate equal expression.')
        return super().visit_predicate_equal_expression(expr, feature, value)

    def visit_assign_expression(self, expr: E.AssignExpression) -> Any:
        """Assign a feature; when effects come from simulation/execution, each entry gets a fresh CSP variable.

        In that mode every addressed entry is bound to a new variable constrained by a
        :class:`SimulationFluentConstraintFunction`. Otherwise this defers to the default visitor.
        """
        if self.executor.effect_update_from_simulation or self.executor.effect_update_from_execution:
            feature = self.executor.state.features[expr.predicate.function.name]
            feature.init_tensor_optimistic_values()
            argument_values = self.forward_args(*expr.predicate.arguments, force_tuple=True)
            assert self.executor.effect_state_index is not None, 'Effect action index must be set if the target predicate will be updated from simulation.'

            expr_string = expr.cached_string(-1)
            for entry_values in _expand_tensor_indices(feature, argument_values):
                if self.executor.optimistic_execution:
                    raise RuntimeError('Optimistic execution is not supported for effect update from simulation.')
                opt_identifier = self.csp.new_var(feature.dtype, wrap=True)
                feature.tensor_optimistic_values[entry_values] = opt_identifier.identifier
                # NOTE(review): `note` goes to the Constraint constructor here, whereas other call sites pass
                # it to `add_constraint` -- confirm both are accepted.
                self.csp.add_constraint(Constraint(
                    SimulationFluentConstraintFunction(self.executor.effect_state_index, expr.predicate.function, entry_values, is_execution_constraint=self.executor.effect_update_from_execution),
                    [], opt_identifier,
                    note=f'{expr_string}::{entry_values}' if len(entry_values) > 0 else expr_string
                ))
        else:
            return super().visit_assign_expression(expr)

    def visit_conditional_select_expression(self, expr: E.ConditionalSelectExpression) -> TensorValueExecutorReturnType:
        return super().visit_conditional_select_expression(expr)

    def visit_deictic_select_expression(self, expr: E.DeicticSelectExpression) -> Any:
        return super().visit_deictic_select_expression(expr)

    def visit_conditional_assign_expression(self, expr: E.ConditionalAssignExpression) -> Any:
        """Hard (0/1) conditional assignment with CSP encoding of optimistic conditions.

        For every entry whose condition is optimistic, a new variable ``x`` is created and constrained by
        ``cond -> x == value`` and ``not cond -> x == old``.
        """
        if self.executor.effect_update_from_simulation or self.executor.effect_update_from_execution:
            raise NotImplementedError('Conditional assign is not supported in simulation mode.')
        if self.executor.optimistic_execution:
            raise RuntimeError('Optimistic execution is not supported for conditional assign.')

        state = self.executor.state
        argument_values = self.forward_args(*expr.predicate.arguments, force_tuple=True)
        # Convert object references to integer indices (as the default visitor does) so the tuple can
        # index torch tensors directly.
        argument_values = tuple(v.index if isinstance(v, StateObjectReference) else v for v in argument_values)
        value = self.forward_args(expr.value)
        condition = self.forward_args(expr.condition)
        condition_tensor = jactorch.quantize(condition.tensor) if self.CONDITIONAL_ASSIGN_QUANTIZE else condition.tensor
        # NB(Jiayuan Mao @ 2023/08/15): conditional assignment does not support "soft" assignment.
        condition_tensor = (condition_tensor > 0.5).to(torch.bool)
        # `1 - t` raises for bool tensors in PyTorch, so compute the negation explicitly.
        negated_condition_tensor = torch.logical_not(condition_tensor)

        feature = state.features[expr.predicate.function.name]
        # I am not using feature.tensor[argument_values] because the current code will handle TensorizedPyObjValues too.
        origin_tensor = feature[argument_values].tensor

        if value.is_tensorized_pyobj:
            raise NotImplementedError('Cannot make conditional assignments for tensorized pyobj.')

        if condition_tensor.dim() < value.tensor.dim():
            condition_tensor_expanded = condition_tensor.unsqueeze(-1)
            negated_condition_tensor_expanded = negated_condition_tensor.unsqueeze(-1)
        else:
            condition_tensor_expanded = condition_tensor
            negated_condition_tensor_expanded = negated_condition_tensor
        feature.tensor[argument_values] = (
            condition_tensor_expanded.to(origin_tensor.dtype) * value.tensor
            + negated_condition_tensor_expanded.to(origin_tensor.dtype) * origin_tensor
        )

        feature.init_tensor_optimistic_values()
        if value.tensor_optimistic_values is not None:
            feature.tensor_optimistic_values[argument_values] = (
                condition_tensor.to(torch.int64) * value.tensor_optimistic_values
                + negated_condition_tensor.to(torch.int64) * feature.tensor_optimistic_values[argument_values]
            )

        if condition.tensor_optimistic_values is None:
            return
        optimistic_mask = is_optimistic_value(condition.tensor_optimistic_values)
        if optimistic_mask.sum().item() == 0:
            return

        # `ind` below indexes the condition, i.e. the sub-tensor selected by `argument_values`, not the full
        # feature tensor; new identifiers are therefore written into that sub-tensor and stored back.
        target_optimistic_values = feature.tensor_optimistic_values[argument_values]
        expr_string = expr.cached_string(-1)
        dtype = expr.predicate.function.return_type
        for ind in optimistic_mask.nonzero().tolist():
            ind = tuple(ind)
            new_identifier = self.csp.new_var(dtype, wrap=True)
            neg_condition_identifier = self.csp.new_var(BOOL, wrap=True)
            eq_1_identifier = self.csp.new_var(BOOL, wrap=True)
            eq_2_identifier = self.csp.new_var(BOOL, wrap=True)
            condition_identifier = condition.tensor_optimistic_values[ind].item()

            self.csp.add_constraint(EqualityConstraint(
                new_identifier,
                cvt_opt_value(value.fast_index(ind), dtype),
                eq_1_identifier,
            ), note=f'{expr_string}::{ind}::eq-1')
            self.csp.add_constraint(EqualityConstraint(
                new_identifier,
                cvt_opt_value(origin_tensor.fast_index(ind) if isinstance(origin_tensor, TensorizedPyObjValues) else origin_tensor[ind], dtype),
                eq_2_identifier
            ), note=f'{expr_string}::{ind}::eq-2')
            self.csp.add_constraint(Constraint(
                BoolOpType.NOT,
                [OptimisticValue(BOOL, condition_identifier)],
                neg_condition_identifier
            ), note=f'{expr_string}::{ind}::neg-cond')
            self.csp.add_constraint(Constraint(
                BoolOpType.OR,
                [neg_condition_identifier, eq_1_identifier],
                cvt_opt_value(True, BOOL)
            ), note=f'{expr_string}::{ind}::implies-new')
            self.csp.add_constraint(Constraint(
                BoolOpType.OR,
                [OptimisticValue(BOOL, condition_identifier), eq_2_identifier],
                cvt_opt_value(True, BOOL)
            ), note=f'{expr_string}::{ind}::implies-old')
            target_optimistic_values[ind] = new_identifier.identifier
        feature.tensor_optimistic_values[argument_values] = target_optimistic_values

    def visit_deictic_assign_expression(self, expr: E.DeicticAssignExpression) -> Any:
        return super().visit_deictic_assign_expression(expr)
def _iter_tensor_indices(target_tensor: torch.Tensor) -> Iterator[Tuple[int, ...]]:
    """Iterate over every index of a tensor in row-major order.

    A tensor with a zero-sized dimension yields nothing.
    """
    yield from itertools.product(*(range(size) for size in target_tensor.shape))


def _expand_tensor_indices(target_value: TensorValue, input_indices: Tuple[Union[int, slice, StateObjectReference], ...]) -> Iterator[Tuple[int, ...]]:
    """Iterate over the concrete entry indices selected by ``input_indices``, in row-major order.

    Each position may be an int, a full slice ``:`` (QINDEX, expanded over ``target_value.tensor.shape`` at
    that position), or a :class:`StateObjectReference` (its ``.index`` is used). An empty
    ``input_indices`` yields the single empty index ``()``.

    Args:
        target_value: the value whose tensor shape determines the range of each ``:``.
        input_indices: the per-dimension indices.

    Yields:
        tuples of Python ints.

    Raises:
        ValueError: on the first ``next()`` if any index has an unsupported type.
    """
    choices_per_dim = []
    for dim, index in enumerate(input_indices):
        if isinstance(index, int):
            choices_per_dim.append([int(index)])
        elif isinstance(index, slice):
            assert index.step is None and index.start is None and index.stop is None  # == ':'
            choices_per_dim.append(range(target_value.tensor.shape[dim]))
        elif isinstance(index, StateObjectReference):
            choices_per_dim.append([int(index.index)])
        else:
            raise ValueError(f'Invalid index type: {index}')

    yield from itertools.product(*choices_per_dim)