#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : crow_executor.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 03/17/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
"""This file contains the executor classes for the Crow framework.
All the executors are based on :class:`~concepts.dsl.tensor_value.TensorValue` classes.
It supports all expressions defined in :mod:`~concepts.dsl.expressions`, including the basic
function application expressions and a few first-order logic expression types. The executors
are designed to be "differentiable", which means we can directly do backpropagation on the
computed values.
The main entry for the executor is the :class:`CrowExecutor` class.
Internally it contains two executor implementations: the basic one, and an "optimistic" one,
which handles the case where the value of a variable can be unknown and "optimistic".
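
A minimal usage sketch (here ``domain`` and ``state`` are assumed to be an already-loaded
:class:`~concepts.dm.crow.crow_domain.CrowDomain` and a compatible
:class:`~concepts.dm.crow.crow_domain.CrowState`; the expression string is illustrative)::

    executor = CrowExecutor(domain)
    value = executor.execute('on(A, B)', state=state)  # returns a TensorValue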
"""
import itertools
import contextlib
from typing import Any, Optional, Union, Iterator, Sequence, Tuple, List, Mapping, Dict, Callable, TYPE_CHECKING
import torch
import jactorch
from jacinle.logging import get_logger
import concepts.dsl.expression as E
from concepts.dsl.dsl_types import BOOL, INT64, FLOAT32, STRING, AutoType, TensorValueTypeBase, ScalarValueType, VectorValueType, NamedTensorValueType, PyObjValueType, BatchedListType, QINDEX, Variable
from concepts.dsl.expression import Expression, BoolOpType, QuantificationOpType
from concepts.dsl.expression_visitor import ExpressionVisitor
from concepts.dsl.value import ListValue
from concepts.dsl.tensor_value import TensorizedPyObjValues, TensorValue, MaskedTensorStorage
from concepts.dsl.tensor_state import StateObjectReference, StateObjectList, TensorState, NamedObjectTensorState
from concepts.dsl.tensor_value_utils import expand_argument_values
from concepts.dsl.constraint import (
OPTIM_MAGIC_NUMBER_MAGIC, is_optimistic_value, OptimisticValue, cvt_opt_value,
Constraint, EqualityConstraint, ConstraintSatisfactionProblem, SimulationFluentConstraintFunction
)
from concepts.dsl.parsers.parser_base import ParserBase
from concepts.dsl.executors.tensor_value_executor import BoundedVariablesDictCompatible, TensorValueExecutorBase, TensorValueExecutorReturnType
from concepts.dm.crow.crow_function import CrowFeature, CrowFunction, CrowFunctionEvaluationMode
from concepts.dm.crow.crow_domain import CrowDomain, CrowState
from concepts.dm.crow.executors.python_function import CrowPythonFunctionRef, CrowPythonFunctionCrossRef, CrowSGC
if TYPE_CHECKING:
from concepts.dm.crow.interfaces.controller_interface import CrowSimulationControllerInterface
logger = get_logger(__file__)
__all__ = [
'CrowExecutor', 'CrowExecutionDefaultVisitor', 'CrowExecutionCSPVisitor',
]
class CrowExecutor(TensorValueExecutorBase):
def __init__(self, domain: CrowDomain, parser: Optional[ParserBase] = None, load_external_function_implementations: bool = True):
"""Initialize a Crow expression executor.
Args:
domain: the domain of this executor.
parser: the parser to be used. This argument is optional. If provided, the execute function can take strings as input.
            load_external_function_implementations: whether to load the external function implementations defined in the domain file.
        """
"""
super().__init__(domain, parser)
self._csp = None
self._optimistic_execution = False
self._sgc = None
self._default_visitor = CrowExecutionDefaultVisitor(self)
self._csp_visitor = CrowExecutionCSPVisitor(self)
self._register_default_function_implementations()
if load_external_function_implementations:
self._register_external_function_implementations_from_domain()
self._simulation_interface = None
self._effect_update_from_simulation = False
self._effect_update_from_execution = False
self._effect_state_index = None
@property
def simulation_interface(self) -> Optional['CrowSimulationControllerInterface']:
return self._simulation_interface
def set_simulation_interface(self, simulation_interface: 'CrowSimulationControllerInterface'):
self._simulation_interface = simulation_interface
@property
def csp(self) -> Optional[ConstraintSatisfactionProblem]:
"""The CSP that describes the constraints in past executions."""
return self._csp
@property
def sgc(self) -> Optional[CrowSGC]:
"""The SGC (state-goal-constraints) context."""
return self._sgc
@property
def optimistic_execution(self) -> bool:
"""Whether to execute the expression optimistically (i.e., treat all CSP constraints True)."""
return self._optimistic_execution
@contextlib.contextmanager
def with_csp(self, csp: Optional[ConstraintSatisfactionProblem]):
"""A context manager to temporarily set the CSP of the executor."""
        old_csp = self._csp
        self._csp = csp
        try:
            yield
        finally:
            self._csp = old_csp
@contextlib.contextmanager
def with_sgc(self, sgc: Optional[CrowSGC]):
"""A context manager to temporarily set the SGC of the executor."""
        old_sgc = self._sgc
        self._sgc = sgc
        try:
            yield
        finally:
            self._sgc = old_sgc
def _register_default_function_implementations(self):
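        # Default implementations are registered under names of the form
        # 'type::<typename>::<op>' (e.g., 'type::float32::add'); they are later looked up
        # via get_function_implementation / forward_external_function.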
for t in itertools.chain(self.domain.types.values(), [BOOL, INT64, FLOAT32, STRING]):
            if (isinstance(t, NamedTensorValueType) and isinstance(t.parent_type, ScalarValueType)) or isinstance(t, ScalarValueType):
                # NB(Jiayuan Mao @ 2024/07/11): Can't use lambdas that close over `t` here, because `t` is a loop
                # variable and would be captured by reference rather than by value; hence the helper classes below.
self.register_function_implementation(f'type::{t.typename}::add', CrowPythonFunctionRef(_CrowArithFunctionImpl(t, lambda x, y: x + y), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::sub', CrowPythonFunctionRef(_CrowArithFunctionImpl(t, lambda x, y: x - y), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::mul', CrowPythonFunctionRef(_CrowArithFunctionImpl(t, lambda x, y: x * y), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::div', CrowPythonFunctionRef(_CrowArithFunctionImpl(t, lambda x, y: x / y), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::neg', CrowPythonFunctionRef(_CrowUnaryFunctionImpl(t, lambda x: -x), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::equal', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.eq(x, y)), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::not_equal', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.ne(x, y)), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::less', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.lt(x, y)), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::less_equal', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.le(x, y)), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::greater', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.gt(x, y)), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::greater_equal', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.ge(x, y)), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(
f'type::{t.typename}::hash',
CrowPythonFunctionRef(lambda x: x.tensor.item(), support_batch=False, unwrap_values=False)
)
elif isinstance(t, NamedTensorValueType) and isinstance(t.parent_type, VectorValueType):
self.register_function_implementation(f'type::{t.typename}::add', CrowPythonFunctionRef(_CrowArithFunctionImpl(t, lambda x, y: x + y), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::sub', CrowPythonFunctionRef(_CrowArithFunctionImpl(t, lambda x, y: x - y), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::mul', CrowPythonFunctionRef(_CrowArithFunctionImpl(t, lambda x, y: x * y), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::div', CrowPythonFunctionRef(_CrowArithFunctionImpl(t, lambda x, y: x / y), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::neg', CrowPythonFunctionRef(_CrowUnaryFunctionImpl(t, lambda x: -x), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::equal', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.eq(x, y).all(dim=-1)), support_batch=True, auto_broadcast=True, unwrap_values=False))
                self.register_function_implementation(f'type::{t.typename}::not_equal', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.ne(x, y).any(dim=-1)), support_batch=True, auto_broadcast=True, unwrap_values=False))  # two vectors differ if any component differs.
self.register_function_implementation(f'type::{t.typename}::less', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.lt(x, y).all(dim=-1)), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::less_equal', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.le(x, y).all(dim=-1)), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::greater', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.gt(x, y).all(dim=-1)), support_batch=True, auto_broadcast=True, unwrap_values=False))
self.register_function_implementation(f'type::{t.typename}::greater_equal', CrowPythonFunctionRef(_CrowComparisonFunctionImpl(BOOL, lambda x, y: torch.ge(x, y).all(dim=-1)), support_batch=True, auto_broadcast=True, unwrap_values=False))
elif isinstance(t, PyObjValueType):
if t.base_typename == 'string':
def compare(x, y):
return TensorValue(BOOL, x.batch_variables, torch.tensor(x.tensor.values == y.tensor.values, dtype=torch.bool), x.batch_dims)
self.register_function_implementation(f'type::{t.typename}::equal', CrowPythonFunctionRef(compare, support_batch=True, auto_broadcast=True, unwrap_values=False))
def compare_neg(x, y):
return TensorValue(BOOL, x.batch_variables, torch.tensor(x.tensor.values != y.tensor.values, dtype=torch.bool), x.batch_dims)
self.register_function_implementation(f'type::{t.typename}::not_equal', CrowPythonFunctionRef(compare_neg, support_batch=True, auto_broadcast=True, unwrap_values=False))
for fname, cross_ref_name in self.domain.external_function_crossrefs.items():
self.register_function_implementation(fname, CrowPythonFunctionCrossRef(cross_ref_name))
def _register_external_function_implementations_from_domain(self):
for filepath in self.domain.external_function_implementation_files:
self.load_external_function_implementations_from_file(filepath)
def load_external_function_implementations_from_file(self, filepath: str):
from jacinle.utils.imp import load_module_filename
module = load_module_filename(filepath)
for name, func in module.__dict__.items():
if isinstance(func, CrowPythonFunctionRef):
self.register_function_implementation(name, func)
_domain: CrowDomain
@property
def domain(self) -> CrowDomain:
return self._domain
@property
def state(self) -> CrowState:
"""The state of the executor."""
return self._state
@property
def effect_update_from_simulation(self) -> bool:
"""A context variable indicating whether the current effect should be updated from simulation, instead of the evaluation of expressions."""
return self._effect_update_from_simulation
@property
def effect_update_from_execution(self) -> bool:
"""A context variable indicating whether the current effect should be updated from the execution of the operator."""
return self._effect_update_from_execution
@property
def effect_state_index(self) -> Optional[int]:
return self._effect_state_index
def parse(self, string: Union[str, Expression], *, state: Optional[CrowState] = None, variables: Optional[Sequence[Variable]] = None) -> Expression:
if isinstance(string, Expression):
return string
return self._domain.parse(string, state=state, variables=variables)
_function_implementations: Dict[str, Union[CrowPythonFunctionRef, CrowPythonFunctionCrossRef]]
@property
def function_implementations(self) -> Dict[str, Union[CrowPythonFunctionRef, CrowPythonFunctionCrossRef]]:
return self._function_implementations
def register_function_implementation(self, name: str, func: Union[Callable, CrowPythonFunctionRef, CrowPythonFunctionCrossRef]):
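        """Register an implementation for the function with the given name.

        Plain callables are wrapped into a :class:`CrowPythonFunctionRef`; cross references
        are stored as-is and resolved lazily by :meth:`get_function_implementation`.
        A sketch (the function name and the implementation are illustrative)::

            def _clear_impl(x: TensorValue) -> TensorValue:
                ...

            executor.register_function_implementation('predicate::clear', _clear_impl)
        """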
if isinstance(func, CrowPythonFunctionRef):
self._function_implementations[name] = func.set_executor(self)
elif isinstance(func, CrowPythonFunctionCrossRef):
self._function_implementations[name] = func
else:
self._function_implementations[name] = CrowPythonFunctionRef(func)
def get_function_implementation(self, name: str) -> CrowPythonFunctionRef:
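        """Look up the implementation registered under ``name``, following
        :class:`CrowPythonFunctionCrossRef` chains until a concrete
        :class:`CrowPythonFunctionRef` is found. Raises ``KeyError`` if the chain
        ends at an unregistered name."""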
while name in self._function_implementations:
func = self._function_implementations[name]
if isinstance(func, CrowPythonFunctionCrossRef):
name = func.cross_ref_name
else:
return func
raise KeyError(f'Function {name} not found.')
def execute(
self, expression: Union[Expression, str],
state: Optional[TensorState] = None,
bounded_variables: Optional[BoundedVariablesDictCompatible] = None,
csp: Optional[ConstraintSatisfactionProblem] = None,
sgc: Optional[CrowSGC] = None,
bypass_bounded_variable_check: bool = False,
optimistic_execution: bool = False
) -> TensorValueExecutorReturnType:
"""Execute an expression.
Args:
expression: the expression to execute.
state: the state to use. If None, the current state of the executor will be used.
bounded_variables: the bounded variables to use. If None, the current bounded variables of the executor will be used.
csp: the constraint satisfaction problem to use. If None, the current CSP of the executor will be used.
sgc: the SGC (state-goal-constraints) context to use. If None, the current SGC context of the executor will be used.
bypass_bounded_variable_check: whether to bypass the check of the bounded variables.
            optimistic_execution: whether to execute the expression optimistically (i.e., treat all CSP constraints as satisfied).
Returns:
            the result of the execution, typically a :class:`~concepts.dsl.tensor_value.TensorValue`.
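
        Example (a sketch; the expression string and the variable binding are illustrative)::

            value = executor.execute('clear(x)', state=state, bounded_variables={'x': 'A'})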
"""
if isinstance(expression, str):
all_variables = list()
if bounded_variables is not None:
for k, v in bounded_variables.items():
if isinstance(k, str):
all_variables.append(Variable(k, AutoType))
elif isinstance(k, Variable):
all_variables.append(k)
expression = self.parse(expression, state=state, variables=all_variables)
state = state if state is not None else self._state
csp = csp if csp is not None else self._csp
sgc = sgc if sgc is not None else self._sgc
bounded_variables = bounded_variables if bounded_variables is not None else self._bounded_variables
with self.with_state(state), self.with_csp(csp), self.with_sgc(sgc), self.with_bounded_variables(bounded_variables, bypass_bounded_variable_check=bypass_bounded_variable_check):
self._optimistic_execution, backup = optimistic_execution, self._optimistic_execution
try:
return self._execute(expression)
finally:
self._optimistic_execution = backup
def _execute(self, expression: Expression) -> TensorValueExecutorReturnType:
if self.csp is not None:
return self._csp_visitor.visit(expression)
return self._default_visitor.visit(expression)
@contextlib.contextmanager
def update_effect_mode(self, evaluation_mode: CrowFunctionEvaluationMode, state_index: Optional[int] = None):
old_from_simulation = self._effect_update_from_simulation
old_from_execution = self._effect_update_from_execution
old_state_index = self._effect_state_index
        self._effect_update_from_simulation = evaluation_mode is CrowFunctionEvaluationMode.SIMULATION
        self._effect_update_from_execution = evaluation_mode is CrowFunctionEvaluationMode.EXECUTION
        self._effect_state_index = state_index
        try:
            yield
        finally:
            self._effect_update_from_simulation = old_from_simulation
            self._effect_update_from_execution = old_from_execution
            self._effect_state_index = old_state_index
def _fast_index(value, ind):
    """Index into a :class:`TensorValue`; non-tensor values pass through unchanged when ``ind`` is empty."""
if isinstance(value, TensorValue):
return value.fast_index(ind)
if len(ind) == 0:
return value
raise ValueError(f'Unsupported value type: {value}.')
class _CrowUnaryFunctionImpl(object):
def __init__(self, dtype, op):
self.dtype = dtype
self.op = op
def __call__(self, x):
return TensorValue(self.dtype, x.batch_variables, self.op(x.tensor), x.batch_dims)
class _CrowArithFunctionImpl(object):
def __init__(self, dtype, op):
self.dtype = dtype
self.op = op
def __call__(self, x, y):
return TensorValue(self.dtype, x.batch_variables, self.op(x.tensor, y.tensor), x.batch_dims)
class _CrowComparisonFunctionImpl(object):
def __init__(self, dtype, op):
self.dtype = dtype
self.op = op
def __call__(self, x, y):
return TensorValue(BOOL, x.batch_variables, self.op(x.tensor, y.tensor), x.batch_dims)
class CrowExecutionDefaultVisitor(ExpressionVisitor):
"""The underlying default implementation for :class:`CrowExecutor`. This function does not handle CSPs (a.k.a. optimistic execution)."""
def __init__(self, executor: CrowExecutor):
"""Initialize a PDExpressionExecutionDefaultVisitor.
Args:
executor: the executor that uses this visitor.
"""
self.executor = executor
@property
def csp(self) -> ConstraintSatisfactionProblem:
return self.executor.csp
def visit_null_expression(self, expr: E.NullExpression) -> Any:
return None
def visit_variable_expression(self, expr: E.VariableExpression) -> TensorValueExecutorReturnType:
variable = expr.variable
if variable.dtype is AutoType:
return self.executor.retrieve_bounded_variable_by_name(variable.name)
return self.executor.bounded_variables[variable.dtype.typename][variable.name]
def visit_object_constant_expression(self, expr: E.ObjectConstantExpression) -> Union[StateObjectReference, ListValue]:
const = expr.constant
if isinstance(const.name, (StateObjectReference, StateObjectList)):
return const.name
state = self.executor.state
assert isinstance(state, NamedObjectTensorState)
if isinstance(const, ListValue):
return StateObjectList(const.dtype, [StateObjectReference(c.name, state.get_typed_index(c.name, c.dtype.typename), c.dtype) for c in const.values])
return StateObjectReference(
const.name,
state.get_typed_index(const.name, const.dtype.typename),
const.dtype
)
def visit_constant_expression(self, expr: E.ConstantExpression) -> TensorValueExecutorReturnType:
value = expr.constant
assert isinstance(value, (TensorValue, ListValue))
return value
def visit_list_creation_expression(self, expr: E.ListCreationExpression) -> Any:
argument_values = self.forward_args(*expr.arguments, force_tuple=True)
return ListValue(expr.return_type, argument_values)
def visit_list_expansion_expression(self, expr: E.ListExpansionExpression) -> Any:
raise RuntimeError('List expansion is not supported in the expression evaluation.')
def visit_function_application_expression(
self, expr: E.FunctionApplicationExpression,
argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None
) -> TensorValueExecutorReturnType:
function = expr.function
return_type = function.return_type
state = self.executor.state
assert isinstance(function, (CrowFeature, CrowFunction))
if argument_values is None:
argument_values = self.forward_args(*expr.arguments, force_tuple=True)
has_list_values = any(isinstance(v, ListValue) for v in argument_values)
        if isinstance(function, CrowFunction) and function.is_generator_placeholder:  # the result is assumed to be True; the corresponding constraints are recorded in the CSP below.
if has_list_values:
raise NotImplementedError('List values are not supported in the generator placeholder function.')
argument_values = expand_argument_values(argument_values)
batched_value = None
for argv in argument_values:
if isinstance(argv, TensorValue):
batched_value = argv
break
assert batched_value is not None
rv = torch.ones(
batched_value.tensor.shape[:batched_value.total_batch_dims],
dtype=torch.bool, device=batched_value.tensor.device if isinstance(batched_value.tensor, torch.Tensor) else None
)
assert return_type == BOOL
# Add "true" asserts to the csp.
if self.csp is not None and not self.executor.optimistic_execution:
expr_string = expr.cached_string(-1)
for ind in _iter_tensor_indices(rv):
self.csp.add_constraint(Constraint.from_function(
function,
# I think there is some bug here... for "StateObjectReference"
[argv.fast_index(tuple(ind)) if isinstance(argv, TensorValue) else argv for argv in argument_values],
True
), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
return TensorValue(
BOOL, batched_value.batch_variables,
rv, batch_dims=state.batch_dims
)
        elif function.is_cacheable and state is not None and function.name in state.features:  # state may be None if we are evaluating a "constraint" function.
argument_values = [v.index if isinstance(v, StateObjectReference) else v for v in argument_values]
batch_variables = list()
anonymous_index = 0
accessors = list()
index_types = list()
for i, (arg, value) in enumerate(zip(expr.arguments, argument_values)):
if value == QINDEX:
batch_variables.append(arg.variable.name)
accessors.append(value)
elif isinstance(value, StateObjectList):
batch_variables.append(f'@{anonymous_index}')
index_types.append(value.element_type)
anonymous_index += 1
accessors.append(value.array_accessor)
else:
accessors.append(value)
value = state.features[function.name][tuple(accessors)]
if 'dirty_features' in state.internals and function.name in state.internals['dirty_features']:
value_opt = state.features[function.name].tensor_optimistic_values[tuple(argument_values)]
if (value_opt < 0).any().item():
if function.is_derived:
# Case 1: derived function.
with self.executor.with_bounded_variables({k: v for k, v in zip(function.arguments, argument_values)}):
return self._rename_derived_function_application(self.visit(function.derived_expression), function.arguments, expr.arguments, argument_values)
else:
# Case 2: external direct function.
return self._rename_derived_function_application(self.forward_external_function(function.name, argument_values, return_type, expression=expr), function.arguments, expr.arguments, argument_values)
if len(index_types) == 0:
return value.rename_batch_variables(batch_variables, force=True)
else:
rv_dtype = BatchedListType(value.dtype, index_types)
return value.rename_batch_variables(batch_variables, dtype=rv_dtype, force=True)
elif function.is_derived:
with self.executor.with_bounded_variables({k: v for k, v in zip(function.arguments, argument_values)}):
return self._rename_derived_function_application(self.visit(function.derived_expression), function.arguments, expr.arguments, argument_values)
else:
            # A dynamic predicate is handled exactly like a pre-defined external function.
            # Only external functions with a single return value are supported.
return self._rename_derived_function_application(self.forward_external_function(function.name, argument_values, return_type, expression=expr), function.arguments, expr.arguments, argument_values)
def _rename_derived_function_application(self, rv: TensorValue, function_argument_variables, outer_arguments, argument_values):
if not isinstance(rv, TensorValue):
return rv
output_batch_variables = list(rv.batch_variables)
for function_argument_variable, outer_argument_expr, argument_value in zip(function_argument_variables, outer_arguments, argument_values):
if argument_value is QINDEX:
assert function_argument_variable.name in output_batch_variables, f'Variable {function_argument_variable.name} not found in the output batch variables {output_batch_variables}. Report this as a bug.'
assert isinstance(outer_argument_expr, E.VariableExpression), 'Only variable arguments can be QINDEX. Report this as a bug.'
index = output_batch_variables.index(function_argument_variable.name)
output_batch_variables[index] = outer_argument_expr.variable.name
return rv.clone(clone_tensor=False).rename_batch_variables(output_batch_variables, clone=False)
def visit_list_function_application_expression(self, expr: E.ListFunctionApplicationExpression) -> Any:
        raise DeprecationWarning('List function application is deprecated.')
        # Unreachable; kept for reference:
        # argument_values = self.forward_args(*expr.arguments, force_tuple=True)
        # return self.visit_function_application_expression(expr, argument_values)
# if nr_values is None:
# return self.visit_function_application_expression(expr, argument_values)
# else:
# rvs = list()
# for i in range(nr_values):
# this_argv = tuple(argv.values[i] if isinstance(argv, ListValue) else argv for argv in argument_values)
# rv = self.visit_function_application_expression(expr, this_argv)
# rvs.append(rv)
# return ListValue(expr.return_type, rvs)
def visit_bool_expression(self, expr: E.BoolExpression, argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None) -> TensorValueExecutorReturnType:
if argument_values is None:
argument_values = self.forward_args(*expr.arguments, force_tuple=True, expand_list_arguments=True)
argument_values = expand_argument_values(argument_values)
assert len(argument_values) > 0
assert all(isinstance(v, TensorValue) for v in argument_values)
dtype = argument_values[0].dtype
batch_variables = argument_values[0].batch_variables
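        # Boolean connectives are computed in a soft (differentiable) form: NOT as 1 - x
        # (logical_not for hard boolean tensors), AND as an element-wise min, OR as an
        # element-wise max, XOR as sum mod 2, and IMPLIES as max(1 - x, y).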
if expr.bool_op is BoolOpType.NOT:
assert len(argument_values) == 1
return TensorValue(
dtype, batch_variables,
torch.logical_not(argument_values[0].tensor) if argument_values[0].tensor.dtype == torch.bool else 1 - argument_values[0].tensor,
batch_dims=self.executor.state.batch_dims
)
elif expr.bool_op is BoolOpType.AND:
if len(argument_values) == 1:
return argument_values[0]
return TensorValue(
dtype, batch_variables,
MaskedTensorStorage(torch.stack([argv.tensor for argv in argument_values], dim=-1).amin(dim=-1), None, argument_values[0].tensor_mask),
batch_dims=self.executor.state.batch_dims
)
elif expr.bool_op is BoolOpType.OR:
if len(argument_values) == 1:
return argument_values[0]
return TensorValue(
dtype, batch_variables,
MaskedTensorStorage(torch.stack([argv.tensor for argv in argument_values], dim=-1).amax(dim=-1), None, argument_values[0].tensor_mask),
batch_dims=self.executor.state.batch_dims
)
elif expr.bool_op is BoolOpType.XOR:
if len(argument_values) == 1:
return argument_values[0]
for argv in argument_values:
if argv.tensor.requires_grad:
raise RuntimeError('XOR does not support gradients.')
return TensorValue(
dtype, batch_variables,
MaskedTensorStorage(torch.stack([argv.tensor for argv in argument_values], dim=-1).sum(dim=-1) % 2, None, argument_values[0].tensor_mask),
batch_dims=self.executor.state.batch_dims
)
elif expr.bool_op is BoolOpType.IMPLIES:
assert len(argument_values) == 2
return TensorValue(
dtype, batch_variables,
MaskedTensorStorage(torch.max(1 - argument_values[0].tensor, argument_values[1].tensor), None, argument_values[0].tensor_mask),
batch_dims=self.executor.state.batch_dims
)
else:
raise ValueError(f'Unknown bool op type: {expr.bool_op}.')
def visit_quantification_expression(self, expr: E.QuantificationExpression, value: Optional[TensorValue] = None) -> TensorValueExecutorReturnType:
if value is None:
with self.executor.new_bounded_variables({expr.variable: QINDEX}):
value = self.forward_args(expr.expression)
assert isinstance(value, TensorValue)
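        # Quantifiers are evaluated softly: FORALL reduces the quantified dimension with
        # amin and EXISTS with amax; masked-out entries are first padded with 1 (FORALL)
        # or 0 (EXISTS) so they do not affect the reduction.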
if expr.quantification_op is QuantificationOpType.BATCHED:
return value
dtype = value.dtype
batch_variables = value.batch_variables
variable_index = batch_variables.index(expr.variable.name)
batch_variables = batch_variables[:variable_index] + batch_variables[variable_index + 1:]
if value.tensor_mask is None:
tensor = value.tensor
mask = None
else:
tensor, mask = value.tensor, value.tensor_mask
if expr.quantification_op is QuantificationOpType.FORALL:
tensor = (tensor * mask + (1 - mask)).to(tensor.dtype)
elif expr.quantification_op is QuantificationOpType.EXISTS:
tensor = (tensor * mask).to(tensor.dtype)
else:
raise ValueError(f'Unknown quantification op type: {expr.quantification_op}.')
if expr.quantification_op is QuantificationOpType.FORALL:
return TensorValue(
dtype, batch_variables,
MaskedTensorStorage(tensor.amin(dim=variable_index + value.batch_dims), None, mask),
batch_dims=self.executor.state.batch_dims
)
elif expr.quantification_op is QuantificationOpType.EXISTS:
return TensorValue(
dtype, batch_variables,
MaskedTensorStorage(tensor.amax(dim=variable_index + value.batch_dims), None, mask),
batch_dims=self.executor.state.batch_dims
)
else:
raise ValueError(f'Unknown quantifier type: {expr.quantification_op}.')
def visit_object_compare_expression(self, expr: E.ObjectCompareExpression) -> Any:
v1, v2 = self.forward_args(expr.arguments[0], expr.arguments[1])
values = list()
batched_variables = list()
for i, v in enumerate([v1, v2]):
if v is QINDEX:
arg = expr.arguments[i]
assert isinstance(arg, E.VariableExpression), 'Quantified object comparison only supports variable arguments.'
batched_variables.append(arg.variable.name)
values.append(torch.arange(0, len(self.executor.state.object_type2name[arg.variable.dtype.typename]), dtype=torch.int64))
elif isinstance(v, StateObjectReference):
values.append(torch.tensor(v.index, dtype=torch.int64))
else:
raise ValueError(f'Unsupported value type: {v}.')
        if v1 is QINDEX and v2 is QINDEX:
            value = values[0].unsqueeze(1).eq(values[1].unsqueeze(0))
        else:
            value = values[0].eq(values[1])
if expr.compare_op is E.CompareOpType.EQ:
pass
elif expr.compare_op is E.CompareOpType.NEQ:
value = torch.logical_not(value)
else:
raise ValueError(f'Unknown compare op type for object types: {expr.compare_op}.')
return TensorValue(BOOL, batched_variables, value, batch_dims=self.executor.state.batch_dims)
def visit_value_compare_expression(self, expr: E.ValueCompareExpression) -> Any:
v1, v2 = self.forward_args(expr.arguments[0], expr.arguments[1])
mapping = {
E.CompareOpType.EQ: 'equal',
E.CompareOpType.NEQ: 'not_equal',
E.CompareOpType.LT: 'less',
E.CompareOpType.LEQ: 'less_equal',
E.CompareOpType.GT: 'greater',
E.CompareOpType.GEQ: 'greater_equal',
}
target_op = mapping[expr.compare_op]
        if isinstance(v1.dtype, (PyObjValueType, NamedTensorValueType, ScalarValueType)):
            rv = self.forward_external_function(f'type::{v1.dtype.typename}::{target_op}', [v1, v2], BOOL, expression=expr)
        else:
            raise NotImplementedError(f'Unsupported value comparison for dtype {v1.dtype} and {v2.dtype}.')
return rv
def visit_condition_expression(self, expr: E.ConditionExpression) -> Any:
raise NotImplementedError('Condition expression is not supported in the expression evaluation.')
def visit_find_one_expression(self, expr: E.FindOneExpression) -> Any:
with self.executor.new_bounded_variables({expr.variable: QINDEX}):
values = self.visit(expr.expression)
assert values.batch_dims == 0
assert len(values.batch_variables) == 1
x = (values.tensor > 0.5)
objects = x.nonzero().squeeze(-1).detach().cpu().tolist()
if len(objects) == 0:
raise RuntimeError('No object found. Currently the executor does not support this case.')
names = self.executor.state.object_type2name[expr.variable.dtype.typename]
return StateObjectReference(names[objects[0]], objects[0], expr.variable.dtype)
def visit_find_all_expression(self, expr: E.FindAllExpression) -> Any:
with self.executor.new_bounded_variables({expr.variable: QINDEX}):
values = self.visit(expr.expression)
assert values.batch_dims == 0
assert len(values.batch_variables) == 1
x = (values.tensor > 0.5)
objects = x.nonzero().squeeze(-1).detach().cpu().tolist()
names = self.executor.state.object_type2name[expr.variable.dtype.typename]
return StateObjectList(expr.return_type, [StateObjectReference(names[i], i, expr.variable.dtype) for i in objects])
def visit_predicate_equal_expression(self, expr: E.PredicateEqualExpression, feature: Optional[TensorValue] = None, value: Optional[TensorValue] = None) -> TensorValueExecutorReturnType:
if feature is None or value is None:
feature, value = self.forward_args(expr.predicate, expr.value)
feature, value = expand_argument_values([feature, value])
        if isinstance(feature.dtype, (PyObjValueType, NamedTensorValueType)):
            rv = self.forward_external_function(f'type::{feature.dtype.typename}::equal', [feature, value], BOOL, expression=expr)
        else:
            raise NotImplementedError(f'Unsupported predicate-equal computation for dtype {feature.dtype} and {value.dtype}.')
return rv
def visit_assign_expression(self, expr: E.AssignExpression):
# TODO(Jiayuan Mao @ 2024/01/22): is this really the right thing to do?
if self.executor.effect_update_from_simulation or self.executor.effect_update_from_execution:
return
state: CrowState = self.executor.state
argument_values = list(self.forward_args(*expr.predicate.arguments, force_tuple=True))
target_value = self.forward_args(expr.value)
for i, (arg, value) in enumerate(zip(expr.predicate.arguments, argument_values)):
if value == QINDEX:
                assert isinstance(arg, E.VariableExpression), 'Quantified assignment only supports variable arguments.'
# TODO(Jiayuan Mao @ 2024/08/3): I think we need to think about how to align the batch variable dimensions...
pass
elif isinstance(value, StateObjectList):
argument_values[i] = value.array_accessor
elif isinstance(value, StateObjectReference):
argument_values[i] = value.index
# if state.features[expr.predicate.function.name].quantized:
# if not value.quantized:
# value = self.executor.value_quantizer.quantize_value(value)
# else:
# if value.quantized:
# value = self.executor.value_quantizer.unquantize_value(value)
function_name = expr.predicate.function.name
if function_name not in state.features:
if function_name not in self.executor.domain.features and function_name not in self.executor.domain.functions:
raise ValueError(f'Function {function_name} not found in the state.')
state.init_dirty_feature(self.executor.domain.functions[function_name])
if isinstance(target_value, ListValue):
if len(target_value.values) == 0:
return
target_value = TensorValue(target_value.values[0].dtype, 1 + len(target_value.values[0].batch_variables), torch.stack([x.tensor for x in target_value.values]))
state.features[function_name][tuple(argument_values)] = target_value
if 'dirty_features' in state.internals and function_name in state.internals['dirty_features']:
state.features[function_name].tensor_optimistic_values[tuple(argument_values)] = 0
def visit_conditional_select_expression(self, expr: E.ConditionalSelectExpression) -> TensorValueExecutorReturnType:
value, condition = self.forward_args(expr.predicate, expr.condition)
value = value.clone()
if value.tensor_mask is None:
value.tensor_mask = condition.tensor
else:
value.tensor_mask = torch.min(value.tensor_mask, condition.tensor)
return value
def visit_deictic_select_expression(self, expr: E.DeicticSelectExpression) -> TensorValueExecutorReturnType:
with self.executor.new_bounded_variables({expr.variable: QINDEX}):
return self.visit(expr.expression)
CONDITIONAL_ASSIGN_QUANTIZE = False
def visit_conditional_assign_expression(self, expr: E.ConditionalAssignExpression):
state = self.executor.state
argument_values = self.forward_args(*expr.predicate.arguments, force_tuple=True)
argument_values = [v.index if isinstance(v, StateObjectReference) else v for v in argument_values]
value = self.forward_args(expr.value)
condition = self.forward_args(expr.condition)
condition_tensor = jactorch.quantize(condition.tensor) if self.CONDITIONAL_ASSIGN_QUANTIZE else condition.tensor
feature = state.features[expr.predicate.function.name]
        origin_tensor = feature[argument_values].tensor  # using feature[...] rather than feature.tensor[...] so that TensorizedPyObjValues are also handled.
# assert value.tensor.dim() == condition_tensor.dim() or value.tensor.dim() == 0
if value.is_tensorized_pyobj:
raise NotImplementedError('Cannot make conditional assignments for tensorized pyobj.')
else:
if condition_tensor.dim() < value.tensor.dim():
condition_tensor = condition_tensor.unsqueeze(-1)
state.features[expr.predicate.function.name].tensor[argument_values] = (
condition_tensor.to(origin_tensor.dtype) * value.tensor + (1 - condition_tensor).to(origin_tensor.dtype) * origin_tensor
)
def visit_deictic_assign_expression(self, expr: E.DeicticAssignExpression):
with self.executor.new_bounded_variables({expr.variable: QINDEX}):
self.visit(expr.expression)
def forward_args(self, *args, force_tuple: bool = False, expand_list_arguments: bool = False) -> Union[TensorValueExecutorReturnType, Tuple[TensorValueExecutorReturnType, ...]]:
if len(args) == 1 and not force_tuple:
rvs = self.visit(args[0])
else:
rvs = tuple(self.visit(arg) for arg in args)
        if expand_list_arguments:
            # NB: list expansion assumes ``rvs`` is a tuple; all current call sites pass force_tuple=True together with expand_list_arguments.
expanded_rvs = list()
for rv in rvs:
if isinstance(rv, ListValue):
expanded_rvs.extend(rv.values)
else:
expanded_rvs.append(rv)
if not force_tuple and len(expanded_rvs) == 1:
return expanded_rvs[0]
else:
return tuple(expanded_rvs)
else:
return rvs
def forward_external_function(
self, function_name: str, argument_values: Sequence[TensorValueExecutorReturnType],
return_type: Union[TensorValueTypeBase, PyObjValueType], auto_broadcast: bool = True, expression: Optional[Expression] = None
) -> TensorValue:
external_function = self.executor.get_function_implementation(function_name)
assert isinstance(external_function, CrowPythonFunctionRef)
function_def = expression.function if isinstance(expression, E.FunctionApplicationExpression) else None
return external_function.forward(argument_values, return_type=return_type, auto_broadcast=auto_broadcast, function_def=function_def)
class CrowExecutionCSPVisitor(CrowExecutionDefaultVisitor):
def __init__(self, executor: CrowExecutor):
super().__init__(executor)
def forward_external_function(
self, function_name: str, argument_values: Sequence[TensorValueExecutorReturnType],
return_type: Union[TensorValueTypeBase, PyObjValueType], auto_broadcast: bool = True, expression: Optional[E.FunctionApplicationExpression] = None
) -> TensorValue:
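        # Three cases are handled below: (1) some arguments carry optimistic values, so the
        # result becomes a new optimistic value with constraints recorded in the CSP;
        # (2) the function is simulation/execution dependent and the CSP timestamp is behind
        # the current state, so the result is treated optimistically as well; (3) everything
        # is concrete, and we fall back to the default evaluation (possibly after restoring
        # the simulation state first).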
need_reset_simulation_state = False
need_simulation_optimistic_execution = False
# This is a hack to handle the case where the expression is ValueCompareExpression
if isinstance(expression, E.FunctionApplicationExpression):
function = expression.function
            if isinstance(function, CrowFunction) and (function.is_simulation_dependent or function.is_execution_dependent):
assert self.executor.state is not None
assert self.executor.csp is not None
# print(f'function_name: {function_name}, csp_t={self.executor.csp.get_state_timestamp()}, state_t={self.executor.state.simulation_state_index}')
if self.executor.csp.get_state_timestamp() != self.executor.state.simulation_state_index:
need_simulation_optimistic_execution = True
else:
if self.executor.state.simulation_state_index > 0:
need_reset_simulation_state = True
argument_values = expand_argument_values(argument_values)
tensor_values = [argv for argv in argument_values if isinstance(argv, TensorValue)]
optimistic_masks = [is_optimistic_value(argv.tensor_optimistic_values) for argv in argument_values if isinstance(argv, TensorValue) and argv.tensor_optimistic_values is not None]
if len(optimistic_masks) > 0 or need_simulation_optimistic_execution:
if need_simulation_optimistic_execution:
if len(optimistic_masks) == 0:
if len(tensor_values) > 0:
optimistic_mask = torch.ones(size=tensor_values[0].tensor.shape[:tensor_values[0].total_batch_dims], dtype=torch.bool, device=tensor_values[0].tensor.device)
elif all(isinstance(argv, StateObjectReference) for argv in argument_values):
optimistic_mask = torch.ones(size=tuple(), dtype=torch.bool, device=None)
else:
raise NotImplementedError('Unsupported case for simulation dependent functions. Either all arguments are StateObjectReference or at least one of the values are TensorValue.')
else:
optimistic_mask = optimistic_masks[0].clone()
optimistic_mask[...] = True
rv = TensorValue(
return_type,
batch_variables=tensor_values[0].batch_variables if len(tensor_values) > 0 else [],
tensor=torch.zeros(optimistic_mask.shape),
batch_dims=tensor_values[0].batch_dims if len(tensor_values) > 0 else 0
)
else:
optimistic_mask = torch.stack(optimistic_masks, dim=-1).any(dim=-1)
if need_reset_simulation_state:
with self.executor.simulation_interface.restore_context():
self.executor.simulation_interface.restore_state_keep(self.executor.state.simulation_state, self.executor.state.simulation_state_index)
rv = super().forward_external_function(function_name, argument_values, return_type=return_type, auto_broadcast=auto_broadcast, expression=expression)
else:
rv = super().forward_external_function(function_name, argument_values, return_type=return_type, auto_broadcast=auto_broadcast, expression=expression)
if optimistic_mask.sum().item() == 0:
return rv
rv.init_tensor_optimistic_values()
if self.executor.optimistic_execution:
rv.tensor_optimistic_values[optimistic_mask.nonzero(as_tuple=True)] = OPTIM_MAGIC_NUMBER_MAGIC
else:
expr_string = expression.cached_string(-1)
if isinstance(expression, E.ValueCompareExpression):
if expression.compare_op is E.CompareOpType.EQ:
constraint_function = Constraint.EQUAL
else:
raise NotImplementedError(f'Unsupported compare op type: {expression.compare_op} for optimistic execution.')
elif isinstance(expression, E.PredicateEqualExpression):
constraint_function = Constraint.EQUAL
elif isinstance(expression, E.FunctionApplicationExpression):
constraint_function = expression.function
else:
raise NotImplementedError(f'Unsupported expression type: {expression} for optimistic execution.')
for ind in optimistic_mask.nonzero().tolist():
ind = tuple(ind)
new_identifier = self.executor.csp.new_var(return_type, wrap=True)
rv.tensor_optimistic_values[ind] = new_identifier.identifier
self.csp.add_constraint(Constraint.from_function(
constraint_function,
[_fast_index(argv, ind) for argv in argument_values],
new_identifier
), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
return rv
if need_reset_simulation_state:
with self.executor.simulation_interface.restore_context():
self.executor.simulation_interface.restore_state_keep(self.executor.state.simulation_state, self.executor.state.simulation_state_index)
return super().forward_external_function(function_name, argument_values, return_type=return_type, auto_broadcast=auto_broadcast, expression=expression)
else:
return super().forward_external_function(function_name, argument_values, return_type=return_type, auto_broadcast=auto_broadcast, expression=expression)
def visit_function_application_expression(self, expr: E.FunctionApplicationExpression, argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None) -> TensorValueExecutorReturnType:
return super().visit_function_application_expression(expr, argument_values)
def visit_bool_expression(self, expr: E.BoolExpression, argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None) -> TensorValueExecutorReturnType:
if argument_values is None:
argument_values = self.forward_args(*expr.arguments, force_tuple=True, expand_list_arguments=True)
argument_values = list(expand_argument_values(argument_values))
for argv in argument_values:
assert argv.dtype == BOOL, 'Boolean expression only supports boolean values in CSP mode.'
optimistic_masks = [is_optimistic_value(argv.tensor_optimistic_values) for argv in argument_values if isinstance(argv, TensorValue) and argv.tensor_optimistic_values is not None]
if len(optimistic_masks) > 0:
optimistic_mask = torch.stack(optimistic_masks, dim=-1).any(dim=-1)
if optimistic_mask.sum().item() > 0:
rv = super().visit_bool_expression(expr, argument_values)
rv.init_tensor_optimistic_values()
if self.executor.optimistic_execution:
rv.tensor_optimistic_values[optimistic_mask.nonzero(as_tuple=True)] = OPTIM_MAGIC_NUMBER_MAGIC
else:
expr_string = expr.cached_string(-1)
for ind in optimistic_mask.nonzero().tolist():
ind = tuple(ind)
this_argv = [argv.fast_index(ind, wrap=False) for argv in argument_values]
determined = None
if expr.return_type == BOOL:
if expr.bool_op is BoolOpType.NOT:
pass # nothing we can do.
elif expr.bool_op is BoolOpType.AND:
if 0 in this_argv or False in this_argv:
determined = False
elif expr.bool_op is BoolOpType.OR:
if 1 in this_argv or True in this_argv:
determined = True
this_argv = [v for v in this_argv if isinstance(v, OptimisticValue)]
else: # generalized boolean operations.
pass
if determined is None:
new_identifier = self.csp.new_var(BOOL)
rv.tensor_optimistic_values[ind] = new_identifier
self.csp.add_constraint(Constraint(
expr.bool_op,
this_argv,
cvt_opt_value(new_identifier, BOOL),
), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
else:
rv[ind] = determined
return rv
else:
return super().visit_bool_expression(expr, argument_values)
else: # if len(optimistic_masks) == 0
return super().visit_bool_expression(expr, argument_values)
def visit_quantification_expression(self, expr: E.QuantificationExpression, value: Optional[TensorValue] = None) -> Any:
if value is None:
with self.executor.new_bounded_variables({expr.variable: QINDEX}):
value = self.forward_args(expr.expression)
assert isinstance(value, TensorValue)
assert value.dtype == BOOL, 'Quantification expression only supports boolean values in CSP mode.'
value.init_tensor_optimistic_values()
rv = super().visit_quantification_expression(expr, value)
dim = value.batch_variables.index(expr.variable.name) + value.batch_dims
value_transposed = value.tensor
optimistic_values_transposed = value.tensor_optimistic_values
if dim != value.tensor.ndim - 1:
value_transposed = value_transposed.transpose(dim, -1) # put the target dimension last.
optimistic_values_transposed = optimistic_values_transposed.transpose(dim, -1)
optimistic_mask_transposed = is_optimistic_value(optimistic_values_transposed)
value_transposed = torch.where(
optimistic_mask_transposed,
optimistic_values_transposed,
value_transposed.to(optimistic_values_transposed.dtype)
)
optimistic_mask = optimistic_mask_transposed.any(dim=-1)
if optimistic_mask.sum().item() == 0:
return rv
rv.init_tensor_optimistic_values()
if self.executor.optimistic_execution:
rv.tensor_optimistic_values[optimistic_mask.nonzero(as_tuple=True)] = OPTIM_MAGIC_NUMBER_MAGIC
else:
expr_string = expr.cached_string(-1)
for ind in optimistic_mask.nonzero().tolist():
ind = tuple(ind)
this_argv = value_transposed[ind].tolist()
determined = None
if expr.quantification_op is QuantificationOpType.FORALL:
if 0 in this_argv or False in this_argv:
determined = False
else:
if 1 in this_argv or True in this_argv:
determined = True
this_argv = list(filter(is_optimistic_value, this_argv))
if determined is None:
new_identifier = self.csp.new_var(BOOL)
rv.tensor_optimistic_values[ind] = new_identifier
self.csp.add_constraint(Constraint(
expr.quantification_op,
[OptimisticValue(value.dtype, int(v)) for v in this_argv],
OptimisticValue(value.dtype, new_identifier),
), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
else:
rv.tensor[ind] = determined
return rv
def visit_predicate_equal_expression(self, expr: E.PredicateEqualExpression, feature: Optional[TensorValue] = None, value: Optional[TensorValue] = None) -> Any:
if feature is None or value is None:
feature, value = self.forward_args(expr.predicate, expr.value)
feature, value = expand_argument_values([feature, value])
feature.init_tensor_optimistic_values()
value.init_tensor_optimistic_values()
optimistic_mask = torch.logical_or(is_optimistic_value(feature.tensor_optimistic_values), is_optimistic_value(value.tensor_optimistic_values))
if optimistic_mask.sum().item() > 0:
raise NotImplementedError('Optimistic execution is not supported for predicate equal expression.')
rv = super().visit_predicate_equal_expression(expr, feature, value)
return rv
# feature.init_tensor_optimistic_values()
# value.init_tensor_optimistic_values()
# optimistic_mask = torch.logical_or(is_optimistic_value(feature.tensor_optimistic_values), is_optimistic_value(value.tensor_optimistic_values))
# if optimistic_mask.sum().item() == 0:
# return rv
# rv.init_tensor_optimistic_values()
# if self.executor.optimistic_execution:
# rv.tensor_optimistic_values[optimistic_mask.nonzero(as_tuple=True)] = OPTIM_MAGIC_NUMBER_MAGIC
# else:
# expr_string = expr.cached_string(-1)
# for ind in optimistic_mask.nonzero().tolist():
# ind = tuple(ind)
# this_argv = feature.fast_index(ind), value.fast_index(ind)
# new_identifier = self.csp.new_var(BOOL)
# rv.tensor_optimistic_values[ind] = new_identifier
# self.csp.add_constraint(EqualityConstraint(
# *[cvt_opt_value(v, feature.dtype) for v in this_argv],
# OptimisticValue(BOOL, new_identifier)
# ), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
# return rv
def visit_assign_expression(self, expr: E.AssignExpression) -> Any:
if self.executor.effect_update_from_simulation or self.executor.effect_update_from_execution:
feature = self.executor.state.features[expr.predicate.function.name]
feature.init_tensor_optimistic_values()
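            # Rather than evaluating the effect expression, allocate a fresh optimistic
            # variable for every affected entry and constrain it, via
            # SimulationFluentConstraintFunction, to equal the value that the simulation
            # (or execution) will produce at `effect_state_index`.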
argument_values = self.forward_args(*expr.predicate.arguments, force_tuple=True)
            assert self.executor.effect_state_index is not None, 'Effect state index must be set if the target predicate will be updated from simulation.'
expr_string = expr.cached_string(-1)
for entry_values in _expand_tensor_indices(feature, argument_values):
if self.executor.optimistic_execution:
raise RuntimeError('Optimistic execution is not supported for effect update from simulation.')
else:
opt_identifier = self.csp.new_var(feature.dtype, wrap=True)
feature.tensor_optimistic_values[entry_values] = opt_identifier.identifier
                    self.csp.add_constraint(Constraint(
                        SimulationFluentConstraintFunction(self.executor.effect_state_index, expr.predicate.function, entry_values, is_execution_constraint=self.executor.effect_update_from_execution),
                        [],
                        opt_identifier,
                    ), note=f'{expr_string}::{entry_values}' if len(entry_values) > 0 else expr_string)
else:
return super().visit_assign_expression(expr)
def visit_conditional_select_expression(self, expr: E.ConditionalSelectExpression) -> TensorValueExecutorReturnType:
return super().visit_conditional_select_expression(expr)
def visit_deictic_select_expression(self, expr: E.DeicticSelectExpression) -> Any:
return super().visit_deictic_select_expression(expr)
def visit_conditional_assign_expression(self, expr: E.ConditionalAssignExpression) -> Any:
if self.executor.effect_update_from_simulation or self.executor.effect_update_from_execution:
raise NotImplementedError('Conditional assign is not supported in simulation mode.')
if self.executor.optimistic_execution:
raise RuntimeError('Optimistic execution is not supported for conditional assign.')
state = self.executor.state
argument_values = self.forward_args(*expr.predicate.arguments, force_tuple=True)
value = self.forward_args(expr.value)
condition = self.forward_args(expr.condition)
condition_tensor = jactorch.quantize(condition.tensor) if self.CONDITIONAL_ASSIGN_QUANTIZE else condition.tensor
condition_tensor = (condition_tensor > 0.5).to(torch.bool)
feature = state.features[expr.predicate.function.name]
        origin_tensor = feature[argument_values].tensor  # using feature[...] rather than feature.tensor[...] so that TensorizedPyObjValues are also handled.
# assert value.tensor.dim() == condition_tensor.dim() or value.tensor.dim() == 0
# NB(Jiayuan Mao @ 2023/08/15): conditional assignment does not support "soft" assignment.
if value.is_tensorized_pyobj:
raise NotImplementedError('Cannot make conditional assignments for tensorized pyobj.')
else:
if condition_tensor.dim() < value.tensor.dim():
condition_tensor_expanded = condition_tensor.unsqueeze(-1)
else:
condition_tensor_expanded = condition_tensor
feature.tensor[argument_values] = (
condition_tensor_expanded.to(origin_tensor.dtype) * value.tensor + (1 - condition_tensor_expanded).to(origin_tensor.dtype) * origin_tensor
)
feature.init_tensor_optimistic_values()
if value.tensor_optimistic_values is not None:
feature.tensor_optimistic_values[argument_values] = (
condition_tensor.to(torch.int64) * value.tensor_optimistic_values +
(1 - condition_tensor).to(torch.int64) * state.features[expr.predicate.function.name].tensor_optimistic_values[argument_values]
)
        if condition.tensor_optimistic_values is not None:
            optimistic_mask = is_optimistic_value(condition.tensor_optimistic_values)
            if optimistic_mask.sum().item() > 0:
                expr_string = expr.cached_string(-1)
dtype = expr.predicate.function.return_type
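                # Encode `new = value if condition else old` with auxiliary Boolean variables:
                # eq_1 := (new == value), eq_2 := (new == old), followed by the clauses
                # (NOT condition) OR eq_1 and (condition OR eq_2).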
for ind in optimistic_mask.nonzero().tolist():
ind = tuple(ind)
new_identifier = self.csp.new_var(dtype, wrap=True)
neg_condition_identifier = self.csp.new_var(BOOL, wrap=True)
eq_1_identifier = self.csp.new_var(BOOL, wrap=True)
eq_2_identifier = self.csp.new_var(BOOL, wrap=True)
condition_identifier = condition.tensor_optimistic_values[ind].item()
self.csp.add_constraint(EqualityConstraint(
new_identifier,
cvt_opt_value(value.fast_index(ind), dtype),
eq_1_identifier,
), note=f'{expr_string}::{ind}::eq-1')
self.csp.add_constraint(EqualityConstraint(
new_identifier,
cvt_opt_value(origin_tensor.fast_index(ind) if isinstance(origin_tensor, TensorizedPyObjValues) else origin_tensor[ind], dtype),
eq_2_identifier
), note=f'{expr_string}::{ind}::eq-2')
self.csp.add_constraint(Constraint(
BoolOpType.NOT,
[OptimisticValue(BOOL, condition_identifier)],
neg_condition_identifier
), note=f'{expr_string}::{ind}::neg-cond')
self.csp.add_constraint(Constraint(
BoolOpType.OR,
[neg_condition_identifier, eq_1_identifier],
cvt_opt_value(True, BOOL)
), note=f'{expr_string}::{ind}::implies-new')
self.csp.add_constraint(Constraint(
BoolOpType.OR,
[OptimisticValue(BOOL, condition_identifier), eq_2_identifier],
cvt_opt_value(True, BOOL)
), note=f'{expr_string}::{ind}::implies-old')
feature.tensor_optimistic_values[ind] = new_identifier.identifier
def visit_deictic_assign_expression(self, expr: E.DeicticAssignExpression) -> Any:
return super().visit_deictic_assign_expression(expr)
def _iter_tensor_indices(target_tensor: torch.Tensor) -> Iterator[Tuple[int, ...]]:
"""Iterate from the indices of a tensor. """
for ind in torch.nonzero(torch.ones_like(target_tensor)):
yield tuple(ind.tolist())
def _expand_tensor_indices(target_value: TensorValue, input_indices: Tuple[Union[int, slice, StateObjectReference], ...]) -> Iterator[Tuple[int, ...]]:
"""Iterate over entry indices based on the input indices. Supported indices are int and QINDEX (:).
Args:
target_value: the target value, used to determine the size of ``QINDEX``.
input_indices: the indices to iterate over.
Yields:
the entry indices.
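
    Example (illustrative; assume ``target_value`` has batch shape ``(2, 2)``)::

        >>> list(_expand_tensor_indices(target_value, (0, QINDEX)))
        [(0, 0), (0, 1)]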
"""
indices = list()
for i, ind in enumerate(input_indices):
if isinstance(ind, int):
indices.append(torch.tensor([ind], dtype=torch.int64))
elif isinstance(ind, slice):
assert ind.step is None and ind.start is None and ind.stop is None # == ':'
indices.append(torch.arange(target_value.tensor.shape[i], dtype=torch.int64))
elif isinstance(ind, StateObjectReference):
indices.append(torch.tensor([ind.index], dtype=torch.int64))
else:
raise ValueError(f'Invalid index type: {ind}')
if len(indices) == 0:
yield tuple()
return
if len(indices) == 1:
for x in indices[0]:
            yield (x.item(),)
return
indices = torch.meshgrid(*indices, indexing='ij')
indices = [i.flatten() for i in indices]
for i in range(len(indices[0])):
yield tuple(indices[j][i].item() for j in range(len(indices)))