Source code for concepts.dm.crow.executors.crow_executor

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : crow_executor.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/17/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""This file contains the executor classes for the Crow framework.

All the executors are based on :class:`~concepts.dsl.tensor_value.TensorValue` classes.
They support all expressions defined in :mod:`~concepts.dsl.expression`, including the basic
function application expressions and a few first-order logic expression types. The executors
are designed to be "differentiable", which means we can directly do backpropagation on the
computed values.

The main entry for the executor is the :class:`CrowExecutor` class.
Internally it contains two executor implementations: the basic one, and an "optimistic" one,
which handles the case where the value of a variable can be unknown and "optimistic".
"""

import itertools
import contextlib
from typing import Any, Optional, Union, Iterator, Sequence, Tuple, List, Mapping, Dict, Callable

import torch
import jactorch
from jacinle.logging import get_logger

import concepts.dsl.expression as E
from concepts.dsl.dsl_types import BOOL, INT64, FLOAT32, STRING, AutoType, TensorValueTypeBase, ScalarValueType, VectorValueType, NamedTensorValueType, PyObjValueType, QINDEX, Variable
from concepts.dsl.expression import Expression, BoolOpType, QuantificationOpType
from concepts.dsl.expression_visitor import ExpressionVisitor
from concepts.dsl.value import ListValue
from concepts.dsl.tensor_value import TensorizedPyObjValues, TensorValue, MaskedTensorStorage
from concepts.dsl.tensor_state import StateObjectReference, StateObjectList, TensorState, NamedObjectTensorState
from concepts.dsl.tensor_value_utils import expand_argument_values
from concepts.dsl.constraint import OPTIM_MAGIC_NUMBER_MAGIC, is_optimistic_value, OptimisticValue, cvt_opt_value, Constraint, EqualityConstraint, ConstraintSatisfactionProblem, SimulationFluentConstraintFunction
from concepts.dsl.parsers.parser_base import ParserBase
from concepts.dsl.executors.tensor_value_executor import BoundedVariablesDictCompatible, TensorValueExecutorBase, TensorValueExecutorReturnType

from concepts.dm.crow.crow_function import CrowFeature, CrowFunction, CrowFunctionEvaluationMode
from concepts.dm.crow.crow_domain import CrowDomain, CrowState
from concepts.dm.crow.executors.python_function import CrowPythonFunctionRef, CrowPythonFunctionCrossRef, CrowSGC


# Module-level logger obtained from jacinle's ``get_logger`` helper, keyed by this file's path.
logger = get_logger(__file__)

# Public names exported by ``from <this module> import *``.
__all__ = [
    'CrowExecutor', 'CrowExecutionDefaultVisitor', 'CrowExecutionCSPVisitor',
]


class CrowExecutor(TensorValueExecutorBase):
    """Executor that evaluates Crow expressions.

    Evaluation is delegated to :class:`CrowExecutionDefaultVisitor`, or to
    :class:`CrowExecutionCSPVisitor` whenever a CSP is attached to the executor
    (see :meth:`_execute`).
    """

    def __init__(self, domain: CrowDomain, parser: Optional[ParserBase] = None):
        """Initialize a Crow expression executor.

        Args:
            domain: the domain of this executor.
            parser: the parser to be used. This argument is optional. If provided, the execute function can take strings as input.
        """
        super().__init__(domain, parser)

        self._csp = None
        self._optimistic_execution = False
        self._sgc = None
        self._default_visitor = CrowExecutionDefaultVisitor(self)
        self._csp_visitor = CrowExecutionCSPVisitor(self)
        self._register_default_function_implementations()

        self._effect_update_from_simulation = False
        self._effect_update_from_execution = False
        self._effect_action_index = None

    @property
    def csp(self) -> Optional[ConstraintSatisfactionProblem]:
        """The CSP that describes the constraints in past executions."""
        return self._csp

    @property
    def sgc(self) -> Optional[CrowSGC]:
        """The SGC (state-goal-constraints) context."""
        return self._sgc

    @property
    def optimistic_execution(self) -> bool:
        """Whether to execute the expression optimistically (i.e., treat all CSP constraints True)."""
        return self._optimistic_execution

    @contextlib.contextmanager
    def with_csp(self, csp: Optional[ConstraintSatisfactionProblem]):
        """A context manager to temporarily set the CSP of the executor.

        The previous CSP is restored when the block exits, including when it exits with an exception.
        """
        old_csp = self._csp
        self._csp = csp
        try:
            yield
        finally:
            self._csp = old_csp

    @contextlib.contextmanager
    def with_sgc(self, sgc: Optional[CrowSGC]):
        """A context manager to temporarily set the SGC of the executor.

        The previous SGC is restored when the block exits, including when it exits with an exception.
        """
        old_sgc = self._sgc
        self._sgc = sgc
        try:
            yield
        finally:
            self._sgc = old_sgc

    def _register_default_function_implementations(self):
        """Register the built-in ``type::<typename>::<op>`` implementations.

        For every domain type plus BOOL, INT64, FLOAT32, and STRING:

        - scalar types (and named types with a scalar parent) get arithmetic, negation, comparison, and hash;
        - named types with a vector parent get arithmetic, negation, and comparisons reduced over the last dimension;
        - Python-object types whose base typename is ``'string'`` get ``equal`` and ``not_equal``.

        Finally, every external function cross reference declared by the domain is registered.
        """
        batched = dict(support_batch=True, auto_broadcast=True, unwrap_values=False)

        # The operator tables are built once, outside the loop; none of these lambdas refers to the loop variable `t`.
        arithmetic_ops = (
            ('add', lambda x, y: x + y),
            ('sub', lambda x, y: x - y),
            ('mul', lambda x, y: x * y),
            ('div', lambda x, y: x / y),
        )
        scalar_comparison_ops = (
            ('equal', torch.eq),
            ('not_equal', torch.ne),
            ('less', torch.lt),
            ('less_equal', torch.le),
            ('greater', torch.gt),
            ('greater_equal', torch.ge),
        )
        vector_comparison_ops = (
            ('equal', lambda x, y: torch.eq(x, y).all(dim=-1)),
            # Two vectors differ as soon as ANY component differs, which keeps `not_equal` the negation of `equal`.
            ('not_equal', lambda x, y: torch.ne(x, y).any(dim=-1)),
            ('less', lambda x, y: torch.lt(x, y).all(dim=-1)),
            ('less_equal', lambda x, y: torch.le(x, y).all(dim=-1)),
            ('greater', lambda x, y: torch.gt(x, y).all(dim=-1)),
            ('greater_equal', lambda x, y: torch.ge(x, y).all(dim=-1)),
        )

        def register(typename, op_name, impl, **kwargs):
            self.register_function_implementation(f'type::{typename}::{op_name}', CrowPythonFunctionRef(impl, **kwargs))

        def register_numeric_ops(dtype, comparison_ops):
            # NB(Jiayuan Mao @ 2024/07/11): the dtype is handed to the *Impl wrapper objects instead of being
            # captured by a lambda, because a lambda would not capture the current value of the loop variable.
            for op_name, op in arithmetic_ops:
                register(dtype.typename, op_name, _CrowArithFunctionImpl(dtype, op), **batched)
            register(dtype.typename, 'neg', _CrowUnaryFunctionImpl(dtype, lambda x: -x), **batched)
            for op_name, op in comparison_ops:
                register(dtype.typename, op_name, _CrowComparisonFunctionImpl(BOOL, op), **batched)

        for t in itertools.chain(self.domain.types.values(), [BOOL, INT64, FLOAT32, STRING]):
            if (isinstance(t, NamedTensorValueType) and isinstance(t.parent_type, ScalarValueType)) or isinstance(t, ScalarValueType):
                register_numeric_ops(t, scalar_comparison_ops)
                register(t.typename, 'hash', lambda x: x.tensor.item(), support_batch=False, unwrap_values=False)
            elif isinstance(t, NamedTensorValueType) and isinstance(t.parent_type, VectorValueType):
                register_numeric_ops(t, vector_comparison_ops)
            elif isinstance(t, PyObjValueType):
                if t.base_typename == 'string':
                    def compare(x, y):
                        return TensorValue(BOOL, x.batch_variables, torch.tensor(x.tensor.values == y.tensor.values, dtype=torch.bool), x.batch_dims)

                    def compare_neg(x, y):
                        return TensorValue(BOOL, x.batch_variables, torch.tensor(x.tensor.values != y.tensor.values, dtype=torch.bool), x.batch_dims)

                    register(t.typename, 'equal', compare, **batched)
                    register(t.typename, 'not_equal', compare_neg, **batched)

        for fname, cross_ref_name in self.domain.external_function_crossrefs.items():
            self.register_function_implementation(fname, CrowPythonFunctionCrossRef(cross_ref_name))

    _domain: CrowDomain

    @property
    def domain(self) -> CrowDomain:
        """The domain of this executor."""
        return self._domain

    @property
    def state(self) -> CrowState:
        """The state of the executor."""
        return self._state

    @property
    def effect_update_from_simulation(self) -> bool:
        """A context variable indicating whether the current effect should be updated from simulation, instead of the evaluation of expressions."""
        return self._effect_update_from_simulation

    @property
    def effect_update_from_execution(self) -> bool:
        """A context variable indicating whether the current effect should be updated from the execution of the operator."""
        return self._effect_update_from_execution

    @property
    def effect_action_index(self) -> Optional[int]:
        """The action index set by :meth:`update_effect_mode` (``None`` outside of that context unless set)."""
        return self._effect_action_index

    def parse(self, string: Union[str, Expression], *, state: Optional[CrowState] = None, variables: Optional[Sequence[Variable]] = None) -> Expression:
        """Parse a string into an expression using the domain; an :class:`Expression` input is returned unchanged."""
        if isinstance(string, Expression):
            return string
        return self._domain.parse(string, state=state, variables=variables)

    _function_implementations: Dict[str, Union[CrowPythonFunctionRef, CrowPythonFunctionCrossRef]]

    @property
    def function_implementations(self) -> Dict[str, Union[CrowPythonFunctionRef, CrowPythonFunctionCrossRef]]:
        """The mapping from function names to their registered implementations or cross references."""
        return self._function_implementations

    def register_function_implementation(self, name: str, func: Union[Callable, CrowPythonFunctionRef, CrowPythonFunctionCrossRef]):
        """Register an implementation under ``name``.

        Args:
            name: the name of the function.
            func: a :class:`CrowPythonFunctionRef` (stored as the result of its ``set_executor(self)`` call), a
                :class:`CrowPythonFunctionCrossRef` (stored as is), or any other callable (wrapped in a new
                :class:`CrowPythonFunctionRef`).
        """
        if isinstance(func, CrowPythonFunctionRef):
            self._function_implementations[name] = func.set_executor(self)
        elif isinstance(func, CrowPythonFunctionCrossRef):
            self._function_implementations[name] = func
        else:
            self._function_implementations[name] = CrowPythonFunctionRef(func)

    def get_function_implementation(self, name: str) -> CrowPythonFunctionRef:
        """Look up the implementation registered under ``name``, following cross references.

        Args:
            name: the name of the function.

        Returns:
            the first registered entry along the cross-reference chain that is not a cross reference.

        Raises:
            KeyError: if a name along the chain is not registered, or if the cross references form a cycle.
        """
        visited = set()
        current = name
        while current in self._function_implementations:
            if current in visited:
                # Without this guard a cyclic chain of cross references would never terminate.
                raise KeyError(f'Function {name} has a cyclic cross reference through {current}.')
            visited.add(current)
            func = self._function_implementations[current]
            if not isinstance(func, CrowPythonFunctionCrossRef):
                return func
            current = func.cross_ref_name
        raise KeyError(f'Function {current} not found.')

    def execute(
        self, expression: Union[Expression, str],
        state: Optional[TensorState] = None,
        bounded_variables: Optional[BoundedVariablesDictCompatible] = None,
        csp: Optional[ConstraintSatisfactionProblem] = None,
        sgc: Optional[CrowSGC] = None,
        optimistic_execution: bool = False
    ) -> TensorValueExecutorReturnType:
        """Execute an expression.

        Args:
            expression: the expression to execute.
            state: the state to use. If None, the current state of the executor will be used.
            bounded_variables: the bounded variables to use. If None, the current bounded variables of the executor will be used.
            csp: the constraint satisfaction problem to use. If None, the current CSP of the executor will be used.
            sgc: the SGC (state-goal-constraints) context to use. If None, the current SGC context of the executor will be used.
            optimistic_execution: whether to execute the expression optimistically (i.e., treat all CSP constraints True).

        Returns:
            the TensorValue object.
        """
        if isinstance(expression, str):
            # Collect the variables the parser should know about; only the keys of `bounded_variables` are needed.
            all_variables = list()
            if bounded_variables is not None:
                for k in bounded_variables:
                    if isinstance(k, str):
                        all_variables.append(Variable(k, AutoType))
                    elif isinstance(k, Variable):
                        all_variables.append(k)
            expression = self.parse(expression, state=state, variables=all_variables)

        state = state if state is not None else self._state
        csp = csp if csp is not None else self._csp
        sgc = sgc if sgc is not None else self._sgc
        bounded_variables = bounded_variables if bounded_variables is not None else self._bounded_variables

        with self.with_state(state), self.with_csp(csp), self.with_sgc(sgc), self.with_bounded_variables(bounded_variables):
            backup = self._optimistic_execution
            self._optimistic_execution = optimistic_execution
            try:
                return self._execute(expression)
            finally:
                self._optimistic_execution = backup

    def _execute(self, expression: Expression) -> TensorValueExecutorReturnType:
        """Dispatch to the CSP visitor when a CSP is attached, and to the default visitor otherwise."""
        if self.csp is not None:
            return self._csp_visitor.visit(expression)
        return self._default_visitor.visit(expression)

    @contextlib.contextmanager
    def update_effect_mode(self, evaluation_mode: CrowFunctionEvaluationMode, action_index: Optional[int] = None):
        """A context manager to temporarily set the effect-update flags and the action index.

        Args:
            evaluation_mode: SIMULATION turns on :attr:`effect_update_from_simulation`; EXECUTION turns on
                :attr:`effect_update_from_execution`; any other mode turns both off.
            action_index: the value exposed through :attr:`effect_action_index` inside the context.

        The previous values are restored when the block exits, including when it exits with an exception.
        """
        old_from_simulation = self._effect_update_from_simulation
        old_from_execution = self._effect_update_from_execution
        old_action_index = self._effect_action_index
        self._effect_update_from_simulation = evaluation_mode is CrowFunctionEvaluationMode.SIMULATION
        self._effect_update_from_execution = evaluation_mode is CrowFunctionEvaluationMode.EXECUTION
        self._effect_action_index = action_index
        try:
            yield
        finally:
            self._effect_update_from_simulation = old_from_simulation
            self._effect_update_from_execution = old_from_execution
            self._effect_action_index = old_action_index
def _fast_index(value, ind):
    """Index ``value`` with ``ind``.

    A :class:`TensorValue` is indexed through its ``fast_index`` method. Any other value is returned
    unchanged when ``ind`` is empty, and rejected with a ``ValueError`` otherwise.
    """
    if not isinstance(value, TensorValue):
        if len(ind) != 0:
            raise ValueError(f'Unsupported value type: {value}.')
        return value
    return value.fast_index(ind)


class _CrowUnaryFunctionImpl:
    """Callable applying a unary tensor operator and wrapping the result as a :class:`TensorValue` of ``dtype``."""

    def __init__(self, dtype, op):
        self.dtype = dtype
        self.op = op

    def __call__(self, x):
        result = self.op(x.tensor)
        # Batch metadata is copied from the operand.
        return TensorValue(self.dtype, x.batch_variables, result, x.batch_dims)


class _CrowArithFunctionImpl:
    """Callable applying a binary tensor operator and wrapping the result as a :class:`TensorValue` of ``dtype``."""

    def __init__(self, dtype, op):
        self.dtype = dtype
        self.op = op

    def __call__(self, x, y):
        result = self.op(x.tensor, y.tensor)
        # Batch metadata is copied from the first operand.
        return TensorValue(self.dtype, x.batch_variables, result, x.batch_dims)


class _CrowComparisonFunctionImpl:
    """Callable applying a binary comparison operator and wrapping the result as a BOOL :class:`TensorValue`.

    ``dtype`` is stored on the instance, but the returned value is always typed BOOL.
    """

    def __init__(self, dtype, op):
        self.dtype = dtype
        self.op = op

    def __call__(self, x, y):
        result = self.op(x.tensor, y.tensor)
        # Batch metadata is copied from the first operand.
        return TensorValue(BOOL, x.batch_variables, result, x.batch_dims)
[docs] class CrowExecutionDefaultVisitor(ExpressionVisitor): """The underlying default implementation for :class:`CrowExecutor`. This function does not handle CSPs (a.k.a. optimistic execution)."""
def __init__(self, executor: CrowExecutor):
    """Initialize a CrowExecutionDefaultVisitor.

    Args:
        executor: the executor that uses this visitor.
    """
    self.executor = executor
@property
def csp(self) -> ConstraintSatisfactionProblem:
    """The CSP currently exposed by the owning executor (read from ``executor.csp``)."""
    owner = self.executor
    return owner.csp
def visit_null_expression(self, expr: E.NullExpression) -> Any:
    """Evaluate a null expression; the result is always ``None``."""
    return None
def visit_variable_expression(self, expr: E.VariableExpression) -> TensorValueExecutorReturnType:
    """Look up the value currently bound to the variable of ``expr``.

    Variables typed ``AutoType`` are resolved by name only; all others are resolved through the
    executor's bounded-variable table, keyed first by type name and then by variable name.
    """
    var = expr.variable
    if var.dtype is not AutoType:
        per_type = self.executor.bounded_variables[var.dtype.typename]
        return per_type[var.name]
    return self.executor.retrieve_bounded_variable_by_name(var.name)
def visit_object_constant_expression(self, expr: E.ObjectConstantExpression) -> Union[StateObjectReference, ListValue]:
    """Resolve an object constant into a state object reference (or a list of references).

    If the constant's ``name`` already is a :class:`StateObjectReference` or :class:`StateObjectList`, it is
    returned directly. Otherwise the name is resolved against the executor state, which must be a
    :class:`NamedObjectTensorState`.
    """
    constant = expr.constant
    if isinstance(constant.name, (StateObjectReference, StateObjectList)):
        return constant.name

    state = self.executor.state
    assert isinstance(state, NamedObjectTensorState)

    def as_reference(obj):
        # Pair the object's name with its typed index in the current state.
        return StateObjectReference(obj.name, state.get_typed_index(obj.name, obj.dtype.typename), obj.dtype)

    if not isinstance(constant, ListValue):
        return as_reference(constant)
    return StateObjectList(constant.dtype, [as_reference(element) for element in constant.values])
def visit_constant_expression(self, expr: E.ConstantExpression) -> TensorValueExecutorReturnType:
    """Return the constant stored in ``expr``; it must be a :class:`TensorValue` or a :class:`ListValue`."""
    constant = expr.constant
    assert isinstance(constant, (TensorValue, ListValue))
    return constant
def visit_list_creation_expression(self, expr: E.ListCreationExpression) -> Any:
    """Evaluate every argument of ``expr`` and pack the results into a :class:`ListValue` of the expression's return type."""
    elements = self.forward_args(*expr.arguments, force_tuple=True)
    return ListValue(expr.return_type, elements)
def visit_list_expansion_expression(self, expr: E.ListExpansionExpression) -> Any:
    """Reject list expansion expressions.

    Raises:
        RuntimeError: always; list expansion cannot be evaluated by this visitor.
    """
    raise RuntimeError('List expansion is not supported in the expression evaluation.')
def visit_function_application_expression(
    self, expr: E.FunctionApplicationExpression,
    argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None
) -> TensorValueExecutorReturnType:
    """Evaluate a function application.

    The function is handled by the first matching case:

    1. generator placeholder: the result is an all-True BOOL tensor; when a CSP is attached and execution is
       not optimistic, one constraint per tensor index is added to the CSP;
    2. cacheable function whose name is stored in ``state.features``: the value is read from the state, unless
       the feature is marked dirty and has negative optimistic entries, in which case the derived expression
       is evaluated instead;
    3. derived function: the derived expression is evaluated with the arguments bound;
    4. anything else: forwarded to the registered external function.

    Args:
        expr: the function application expression.
        argument_values: pre-computed argument values. If None, the arguments of ``expr`` are evaluated.

    Returns:
        the value of the function application.
    """
    function = expr.function
    return_type = function.return_type
    state = self.executor.state
    assert isinstance(function, (CrowFeature, CrowFunction))

    if argument_values is None:
        argument_values = self.forward_args(*expr.arguments, force_tuple=True)
    has_list_values = any(isinstance(v, ListValue) for v in argument_values)

    def evaluate_derived(bound_values):
        # Evaluate the derived expression with the function's formal arguments bound to `bound_values`.
        bindings = dict(zip(function.arguments, bound_values))
        with self.executor.with_bounded_variables(bindings):
            return self._rename_derived_function_application(
                self.visit(function.derived_expression), function.arguments, expr.arguments, bound_values
            )

    if isinstance(function, CrowFunction) and function.is_generator_placeholder:
        # The generator placeholder always evaluates to True.
        if has_list_values:
            raise NotImplementedError('List values are not supported in the generator placeholder function.')

        argument_values = expand_argument_values(argument_values)
        # The first TensorValue argument supplies the batch shape, batch variables, and device.
        batched_value = next((argv for argv in argument_values if isinstance(argv, TensorValue)), None)
        assert batched_value is not None

        rv = torch.ones(
            batched_value.tensor.shape[:batched_value.total_batch_dims],
            dtype=torch.bool, device=batched_value.tensor.device
        )
        assert return_type == BOOL

        # Add "true" asserts to the csp.
        if self.csp is not None and not self.executor.optimistic_execution:
            expr_string = expr.cached_string(-1)
            for ind in _iter_tensor_indices(rv):
                constraint = Constraint.from_function(
                    function, [argv.fast_index(tuple(ind)) for argv in argument_values], True
                )
                note = f'{expr_string}::{ind}' if len(ind) > 0 else expr_string
                self.csp.add_constraint(constraint, note=note)

        return TensorValue(BOOL, batched_value.batch_variables, rv, batch_dims=state.batch_dims)

    if function.is_cacheable and function.name in state.features:
        # Object references are replaced by their indices so they can be used to index the feature tensor.
        argument_values = [v.index if isinstance(v, StateObjectReference) else v for v in argument_values]

        batch_variables = []
        anonymous_index = 0
        for i, (arg, value) in enumerate(zip(expr.arguments, argument_values)):
            if value == QINDEX:
                batch_variables.append(arg.variable.name)
            elif isinstance(value, StateObjectList):
                # Object lists contribute an anonymous batch variable and are indexed via their array accessor.
                batch_variables.append(f'@{anonymous_index}')
                anonymous_index += 1
                argument_values[i] = value.array_accessor

        feature = state.features[function.name]
        value = feature[tuple(argument_values)]

        if 'dirty_features' in state.internals and function.name in state.internals['dirty_features']:
            if has_list_values:
                raise NotImplementedError('List values are not supported for dirty features.')
            value_opt = feature.tensor_optimistic_values[tuple(argument_values)]
            if (value_opt < 0).any().item():
                assert function.is_derived
                return evaluate_derived(argument_values)

        return value.rename_batch_variables(batch_variables)

    if function.is_derived:
        return evaluate_derived(argument_values)

    # dynamic predicate is exactly the same thing as a pre-defined external function.
    # only supports external functions with a single return value.
    return self.forward_external_function(function.name, argument_values, return_type, expression=expr)
def _rename_derived_function_application(
    self,
    rv: TensorValue,
    function_argument_variables,
    outer_arguments,
    argument_values
):
    """Map the batch-variable names of a derived function's result back to the caller's names.

    Args:
        rv: the value produced by evaluating the derived expression. Anything that is not a
            :class:`TensorValue` is returned untouched.
        function_argument_variables: the formal argument variables of the function definition.
        outer_arguments: the argument expressions at the call site.
        argument_values: the evaluated values of the call-site arguments.

    Returns:
        ``rv`` itself when it is not a tensor value; otherwise a clone of ``rv`` (made with
        ``clone_tensor=False``) whose batch variables use the call-site variable names for every
        argument that was passed as ``QINDEX``.
    """
    if not isinstance(rv, TensorValue):
        return rv

    output_batch_variables = list(rv.batch_variables)
    triples = zip(function_argument_variables, outer_arguments, argument_values)
    for function_argument_variable, outer_argument_expr, argument_value in triples:
        if argument_value is not QINDEX:
            # Only quantified (QINDEX) arguments show up as batch variables that need renaming.
            continue
        assert function_argument_variable.name in output_batch_variables, f'Variable {function_argument_variable.name} not found in the output batch variables {output_batch_variables}. Report this as a bug.'
        assert isinstance(outer_argument_expr, E.VariableExpression), 'Only variable arguments can be QINDEX. Report this as a bug.'
        position = output_batch_variables.index(function_argument_variable.name)
        output_batch_variables[position] = outer_argument_expr.variable.name

    renamed = rv.clone(clone_tensor=False)
    return renamed.rename_batch_variables(output_batch_variables, clone=False)
def visit_list_function_application_expression(self, expr: E.ListFunctionApplicationExpression) -> Any:
    """Evaluate a list function application.

    The arguments are evaluated once (always as a tuple) and the call is then delegated to
    :meth:`visit_function_application_expression` with those pre-computed values.
    """
    evaluated_arguments = self.forward_args(*expr.arguments, force_tuple=True)
    return self.visit_function_application_expression(expr, evaluated_arguments)
# if nr_values is None: # return self.visit_function_application_expression(expr, argument_values) # else: # rvs = list() # for i in range(nr_values): # this_argv = tuple(argv.values[i] if isinstance(argv, ListValue) else argv for argv in argument_values) # rv = self.visit_function_application_expression(expr, this_argv) # rvs.append(rv) # return ListValue(expr.return_type, rvs)
def visit_bool_expression(
    self,
    expr: E.BoolExpression,
    argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None
) -> TensorValueExecutorReturnType:
    """Evaluate a boolean expression (NOT / AND / OR / XOR / IMPLIES) over tensor values.

    AND and OR are computed as an element-wise min / max over the stacked arguments, XOR as the
    parity of their sum, and IMPLIES as ``max(not a, b)``.

    Args:
        expr: the boolean expression.
        argument_values: optional pre-computed argument values. When ``None``, the arguments of
            ``expr`` are evaluated here (list-valued arguments are expanded).

    Returns:
        the resulting tensor value. For AND / OR / XOR with a single argument, that argument is
        returned as-is.

    Raises:
        RuntimeError: if XOR is applied to a tensor that requires gradients.
        ValueError: if the boolean operator is unknown.
    """
    if argument_values is None:
        argument_values = self.forward_args(*expr.arguments, force_tuple=True, expand_list_arguments=True)
    argument_values = expand_argument_values(argument_values)

    assert len(argument_values) > 0
    assert all(isinstance(v, TensorValue) for v in argument_values)

    first = argument_values[0]
    dtype = first.dtype
    batch_variables = first.batch_variables
    bool_op = expr.bool_op

    def _negate(tensor):
        # ``1 - x`` raises a RuntimeError for torch.bool tensors, so bool tensors are inverted
        # with logical_not; other dtypes keep the "soft" negation ``1 - x``.
        return torch.logical_not(tensor) if tensor.dtype == torch.bool else 1 - tensor

    if bool_op is BoolOpType.NOT:
        assert len(argument_values) == 1
        return TensorValue(
            dtype, batch_variables,
            _negate(first.tensor),
            batch_dims=self.executor.state.batch_dims
        )

    if bool_op is BoolOpType.IMPLIES:
        assert len(argument_values) == 2
        return TensorValue(
            dtype, batch_variables,
            MaskedTensorStorage(torch.max(_negate(first.tensor), argument_values[1].tensor), None, first.tensor_mask),
            batch_dims=self.executor.state.batch_dims
        )

    if bool_op is BoolOpType.AND or bool_op is BoolOpType.OR or bool_op is BoolOpType.XOR:
        if len(argument_values) == 1:
            return first
        if bool_op is BoolOpType.XOR:
            for argv in argument_values:
                if argv.tensor.requires_grad:
                    raise RuntimeError('XOR does not support gradients.')

        stacked = torch.stack([argv.tensor for argv in argument_values], dim=-1)
        if bool_op is BoolOpType.AND:
            reduced = stacked.amin(dim=-1)
        elif bool_op is BoolOpType.OR:
            reduced = stacked.amax(dim=-1)
        else:
            reduced = stacked.sum(dim=-1) % 2
        # The mask of the first argument is carried over to the result.
        return TensorValue(
            dtype, batch_variables,
            MaskedTensorStorage(reduced, None, first.tensor_mask),
            batch_dims=self.executor.state.batch_dims
        )

    raise ValueError(f'Unknown bool op type: {expr.bool_op}.')
def visit_quantification_expression(
    self,
    expr: E.QuantificationExpression,
    value: Optional[TensorValue] = None
) -> TensorValueExecutorReturnType:
    """Evaluate a FORALL / EXISTS expression by reducing over the quantified batch dimension.

    Args:
        expr: the quantification expression.
        value: optional pre-computed value of the inner expression. When ``None``, the inner
            expression is evaluated with the quantified variable bound to ``QINDEX``.

    Returns:
        a tensor value with the quantified variable removed from its batch variables: the
        ``amin`` (FORALL) or ``amax`` (EXISTS) along that dimension.

    Raises:
        ValueError: if the quantification operator is neither FORALL nor EXISTS.
    """
    if value is None:
        with self.executor.new_bounded_variables({expr.variable: QINDEX}):
            value = self.forward_args(expr.expression)
    assert isinstance(value, TensorValue)

    op = expr.quantification_op
    is_forall = op is QuantificationOpType.FORALL
    is_exists = op is QuantificationOpType.EXISTS

    dtype = value.dtype
    all_variables = value.batch_variables
    variable_index = all_variables.index(expr.variable.name)
    batch_variables = all_variables[:variable_index] + all_variables[variable_index + 1:]

    tensor = value.tensor
    mask = value.tensor_mask
    if mask is not None:
        # Masked-out entries are replaced by the neutral element of the reduction:
        # 1 for FORALL (min) and 0 for EXISTS (max).
        if is_forall:
            tensor = (tensor * mask + (1 - mask)).to(tensor.dtype)
        elif is_exists:
            tensor = (tensor * mask).to(tensor.dtype)
        else:
            raise ValueError(f'Unknown quantification op type: {expr.quantification_op}.')

    reduce_dim = variable_index + value.batch_dims
    if is_forall:
        reduced = tensor.amin(dim=reduce_dim)
    elif is_exists:
        reduced = tensor.amax(dim=reduce_dim)
    else:
        raise ValueError(f'Unknown quantifier type: {expr.quantification_op}.')

    # NOTE(review): the mask handed to the result still has the un-reduced shape of the input
    # value — confirm that MaskedTensorStorage tolerates this.
    return TensorValue(
        dtype, batch_variables,
        MaskedTensorStorage(reduced, None, mask),
        batch_dims=self.executor.state.batch_dims
    )
def visit_object_compare_expression(self, expr: E.ObjectCompareExpression) -> Any:
    """Compare two object-typed arguments by their indices (EQ / NEQ).

    Each argument is either a concrete :class:`StateObjectReference` (compared through its
    ``index``) or ``QINDEX`` (compared against every object index of the variable's type, adding
    a batch variable named after that variable).

    Returns:
        a BOOL tensor value; when both arguments are ``QINDEX`` the result is the pairwise
        comparison matrix.

    Raises:
        ValueError: for an unsupported argument value or an operator other than EQ / NEQ.
    """
    v1, v2 = self.forward_args(expr.arguments[0], expr.arguments[1])

    values = list()
    batched_variables = list()
    for i, v in enumerate([v1, v2]):
        if v is QINDEX:
            arg = expr.arguments[i]
            assert isinstance(arg, E.VariableExpression), 'Quantified object comparison only supports variable arguments.'
            batched_variables.append(arg.variable.name)
            nr_objects = len(self.executor.state.object_type2name[arg.variable.dtype.typename])
            values.append(torch.arange(0, nr_objects, dtype=torch.int64))
        elif isinstance(v, StateObjectReference):
            values.append(torch.tensor(v.index, dtype=torch.int64))
        else:
            raise ValueError(f'Unsupported value type: {v}.')

    lhs, rhs = values
    if v1 is QINDEX and v2 is QINDEX:
        # Two quantified arguments: build the full pairwise comparison matrix.
        value = lhs.unsqueeze(1).eq(rhs.unsqueeze(0))
    else:
        value = lhs.eq(rhs)

    if expr.compare_op is E.CompareOpType.NEQ:
        value = torch.logical_not(value)
    elif expr.compare_op is not E.CompareOpType.EQ:
        raise ValueError(f'Unknown compare op type for object types: {expr.compare_op}.')

    return TensorValue(BOOL, batched_variables, value, batch_dims=self.executor.state.batch_dims)
def visit_value_compare_expression(self, expr: E.ValueCompareExpression) -> Any:
    """Compare two values by dispatching to the external function ``type::<typename>::<op>``.

    The type name is taken from the dtype of the first argument. Supported dtypes are
    :class:`PyObjValueType`, :class:`NamedTensorValueType` and :class:`ScalarValueType`.

    Raises:
        NotImplementedError: if the dtype of the first argument is not supported.
    """
    v1, v2 = self.forward_args(expr.arguments[0], expr.arguments[1])

    op_names = {
        E.CompareOpType.EQ: 'equal',
        E.CompareOpType.NEQ: 'not_equal',
        E.CompareOpType.LT: 'less',
        E.CompareOpType.LEQ: 'less_equal',
        E.CompareOpType.GT: 'greater',
        E.CompareOpType.GEQ: 'greater_equal',
    }
    target_op = op_names[expr.compare_op]

    if not isinstance(v1.dtype, (PyObjValueType, NamedTensorValueType, ScalarValueType)):
        raise NotImplementedError(f'Unsupported FeatureEqual computation for dtype {v1.dtype} and {v2.dtype}.')

    return self.forward_external_function(f'type::{v1.dtype.typename}::{target_op}', [v1, v2], BOOL, expression=expr)
def visit_condition_expression(self, expr: E.ConditionExpression) -> Any:
    """Reject condition expressions, which this visitor does not evaluate.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        'Condition expression is not supported in the expression evaluation.'
    )
def visit_find_one_expression(self, expr: E.FindOneExpression) -> Any:
    """Return the first object for which the inner expression evaluates above 0.5.

    The inner expression is evaluated with the variable bound to ``QINDEX``; the result must
    have no batch dimensions and exactly one batch variable.

    Returns:
        a :class:`StateObjectReference` to the first matching object.

    Raises:
        RuntimeError: if no object matches.
    """
    with self.executor.new_bounded_variables({expr.variable: QINDEX}):
        values = self.visit(expr.expression)

    assert values.batch_dims == 0
    assert len(values.batch_variables) == 1

    selected = values.tensor > 0.5
    objects = selected.nonzero().squeeze(-1).detach().cpu().tolist()
    if len(objects) == 0:
        raise RuntimeError('No object found. Currently the executor does not support this case.')

    first_index = objects[0]
    names = self.executor.state.object_type2name[expr.variable.dtype.typename]
    return StateObjectReference(names[first_index], first_index, expr.variable.dtype)
def visit_find_all_expression(self, expr: E.FindAllExpression) -> Any:
    """Return every object for which the inner expression evaluates above 0.5.

    The inner expression is evaluated with the variable bound to ``QINDEX``; the result must
    have no batch dimensions and exactly one batch variable.

    Returns:
        a :class:`StateObjectList` (possibly empty) of references to the matching objects.
    """
    with self.executor.new_bounded_variables({expr.variable: QINDEX}):
        values = self.visit(expr.expression)

    assert values.batch_dims == 0
    assert len(values.batch_variables) == 1

    selected = values.tensor > 0.5
    objects = selected.nonzero().squeeze(-1).detach().cpu().tolist()
    names = self.executor.state.object_type2name[expr.variable.dtype.typename]

    references = list()
    for i in objects:
        references.append(StateObjectReference(names[i], i, expr.variable.dtype))
    return StateObjectList(expr.return_type, references)
def visit_predicate_equal_expression(
    self,
    expr: E.PredicateEqualExpression,
    feature: Optional[TensorValue] = None,
    value: Optional[TensorValue] = None
) -> TensorValueExecutorReturnType:
    """Test a predicate against a value through the external function ``type::<typename>::equal``.

    Args:
        expr: the predicate-equal expression.
        feature: optional pre-computed value of the predicate.
        value: optional pre-computed value to compare with. If either of ``feature`` and
            ``value`` is ``None``, both are evaluated from ``expr``.

    Raises:
        NotImplementedError: if the feature dtype is neither :class:`PyObjValueType` nor
            :class:`NamedTensorValueType`.
    """
    if feature is None or value is None:
        feature, value = self.forward_args(expr.predicate, expr.value)
    feature, value = expand_argument_values([feature, value])

    if not isinstance(feature.dtype, (PyObjValueType, NamedTensorValueType)):
        raise NotImplementedError(f'Unsupported FeatureEqual computation for dtype {feature.dtype} and {value.dtype}.')

    return self.forward_external_function(f'type::{feature.dtype.typename}::equal', [feature, value], BOOL, expression=expr)
def visit_assign_expression(self, expr: E.AssignExpression):
    """Write the value of ``expr.value`` into the state feature named by the predicate.

    Nothing happens when the executor is configured to update effects from simulation or from
    execution.

    Raises:
        NotImplementedError: if the target function has no entry in ``state.features``.
    """
    executor = self.executor
    # TODO(Jiayuan Mao @ 2024/01/22): is this really the right thing to do?
    if executor.effect_update_from_simulation or executor.effect_update_from_execution:
        return

    state = executor.state
    predicate = expr.predicate
    argument_values = list(self.forward_args(*predicate.arguments, force_tuple=True))
    target_value = self.forward_args(expr.value)

    # Turn object references / object lists into raw indices; QINDEX entries are kept as-is.
    for i, (arg, value) in enumerate(zip(predicate.arguments, argument_values)):
        if value == QINDEX:
            assert isinstance(arg, E.VariableExpression), 'Quantified object comparison only supports variable arguments.'
            # TODO(Jiayuan Mao @ 2024/08/3): I think we need to think about how to align the batch variable dimensions...
        elif isinstance(value, StateObjectList):
            argument_values[i] = value.array_accessor
        elif isinstance(value, StateObjectReference):
            argument_values[i] = value.index

    function_name = predicate.function.name
    if function_name not in state.features:
        raise NotImplementedError('Assignments to dirty features are not supported.')
    state.features[function_name][tuple(argument_values)] = target_value
def visit_conditional_select_expression(self, expr: E.ConditionalSelectExpression) -> TensorValueExecutorReturnType:
    """Return a clone of the predicate value whose mask is restricted by the condition.

    If the value has no mask yet, the condition tensor becomes its mask; otherwise the new mask
    is the element-wise minimum of the existing mask and the condition tensor.
    """
    value, condition = self.forward_args(expr.predicate, expr.condition)

    selected = value.clone()
    current_mask = selected.tensor_mask
    if current_mask is not None:
        selected.tensor_mask = torch.min(current_mask, condition.tensor)
    else:
        selected.tensor_mask = condition.tensor
    return selected
def visit_deictic_select_expression(self, expr: E.DeicticSelectExpression) -> TensorValueExecutorReturnType:
    """Evaluate the inner expression with ``expr.variable`` bound to ``QINDEX``."""
    bindings = {expr.variable: QINDEX}
    with self.executor.new_bounded_variables(bindings):
        result = self.visit(expr.expression)
    return result
# When true, ``visit_conditional_assign_expression`` passes the condition tensor through
# ``jactorch.quantize`` before using it; when false the raw condition tensor is used.
CONDITIONAL_ASSIGN_QUANTIZE = False
def visit_conditional_assign_expression(self, expr: E.ConditionalAssignExpression):
    """Blend a new value into a state feature, weighted by a condition.

    The addressed entries of the feature tensor are replaced by
    ``condition * value + (1 - condition) * original``.

    Raises:
        NotImplementedError: if the value to assign is a tensorized Python object.
    """
    state = self.executor.state
    argument_values = self.forward_args(*expr.predicate.arguments, force_tuple=True)
    # Index with a tuple, like the other accessors in this class: torch interprets a *list* of
    # integers as advanced indexing along the first dimension, not as one index per dimension.
    argument_values = tuple(v.index if isinstance(v, StateObjectReference) else v for v in argument_values)
    value = self.forward_args(expr.value)
    condition = self.forward_args(expr.condition)

    condition_tensor = jactorch.quantize(condition.tensor) if self.CONDITIONAL_ASSIGN_QUANTIZE else condition.tensor
    feature = state.features[expr.predicate.function.name]
    # Not using feature.tensor[argument_values] here because indexing the value itself also
    # handles TensorizedPyObjValues.
    origin_tensor = feature[argument_values].tensor

    if value.is_tensorized_pyobj:
        raise NotImplementedError('Cannot make conditional assignments for tensorized pyobj.')

    if condition_tensor.dim() < value.tensor.dim():
        # Add a trailing dimension so that the condition broadcasts against the value.
        condition_tensor = condition_tensor.unsqueeze(-1)

    # ``1 - x`` is rejected by torch for bool tensors, so bool conditions are inverted logically.
    if condition_tensor.dtype == torch.bool:
        negated_condition = torch.logical_not(condition_tensor)
    else:
        negated_condition = 1 - condition_tensor

    feature.tensor[argument_values] = (
        condition_tensor.to(origin_tensor.dtype) * value.tensor +
        negated_condition.to(origin_tensor.dtype) * origin_tensor
    )
def visit_deictic_assign_expression(self, expr: E.DeicticAssignExpression):
    """Run the inner assignment with ``expr.variable`` bound to ``QINDEX``; returns nothing."""
    bindings = {expr.variable: QINDEX}
    with self.executor.new_bounded_variables(bindings):
        self.visit(expr.expression)
def forward_args(
    self,
    *args,
    force_tuple: bool = False,
    expand_list_arguments: bool = False
) -> Union[TensorValueExecutorReturnType, Tuple[TensorValueExecutorReturnType, ...]]:
    """Evaluate argument expressions by visiting each of them.

    Args:
        *args: the expressions to evaluate.
        force_tuple: always return a tuple, even for a single result.
        expand_list_arguments: replace every :class:`ListValue` result by its elements.

    Returns:
        a single value when exactly one result remains and ``force_tuple`` is false; otherwise a
        tuple of values.
    """
    single_result = len(args) == 1 and not force_tuple
    if single_result:
        rvs = self.visit(args[0])
    else:
        rvs = tuple(self.visit(arg) for arg in args)

    if not expand_list_arguments:
        return rvs

    # A lone result is wrapped so that the loop below sees it as one element instead of
    # iterating over the value itself.
    candidates = (rvs,) if single_result else rvs
    expanded_rvs = list()
    for rv in candidates:
        if isinstance(rv, ListValue):
            expanded_rvs.extend(rv.values)
        else:
            expanded_rvs.append(rv)

    if not force_tuple and len(expanded_rvs) == 1:
        return expanded_rvs[0]
    return tuple(expanded_rvs)
def forward_external_function(
    self,
    function_name: str,
    argument_values: Sequence[TensorValueExecutorReturnType],
    return_type: Union[TensorValueTypeBase, PyObjValueType],
    auto_broadcast: bool = True,
    expression: Optional[Expression] = None
) -> TensorValue:
    """Call the registered implementation of an external function.

    Args:
        function_name: the name under which the implementation is looked up in the executor.
        argument_values: the evaluated arguments.
        return_type: the expected return type, forwarded to the implementation.
        auto_broadcast: forwarded to the implementation.
        expression: the expression being evaluated. When it is a function application, its
            function definition is forwarded as ``function_def``; otherwise ``None`` is passed.

    Returns:
        whatever the implementation's ``forward`` returns.
    """
    implementation = self.executor.get_function_implementation(function_name)
    assert isinstance(implementation, CrowPythonFunctionRef)

    if isinstance(expression, E.FunctionApplicationExpression):
        function_def = expression.function
    else:
        function_def = None

    return implementation.forward(
        argument_values,
        return_type=return_type,
        auto_broadcast=auto_broadcast,
        function_def=function_def
    )
class CrowExecutionCSPVisitor(CrowExecutionDefaultVisitor):
    """Expression visitor that additionally tracks optimistic (not yet determined) values.

    Each method first computes the plain result through :class:`CrowExecutionDefaultVisitor`.
    For entries that depend on an optimistic value it then either marks the result with
    ``OPTIM_MAGIC_NUMBER_MAGIC`` (when ``executor.optimistic_execution`` is set) or creates a
    fresh CSP variable and records a constraint that ties it to its inputs.
    """

    def __init__(self, executor: CrowExecutor):
        super().__init__(executor)

    def forward_external_function(
        self, function_name: str, argument_values: Sequence[TensorValueExecutorReturnType],
        return_type: Union[TensorValueTypeBase, PyObjValueType], auto_broadcast: bool = True, expression: Optional[E.FunctionApplicationExpression] = None
    ) -> TensorValue:
        """Call an external function and record constraints for optimistic arguments.

        Args:
            function_name: the name of the external function.
            argument_values: the evaluated arguments (broadcast here via ``expand_argument_values``).
            return_type: the return type; also the type of the CSP variables created.
            auto_broadcast: forwarded to the base implementation.
            expression: the expression being evaluated; used for constraint notes and to pick
                the constraint function (``Constraint.EQUAL`` for predicate-equal expressions,
                ``expression.function`` otherwise).

        Returns:
            the function result, with ``tensor_optimistic_values`` set wherever any argument
            was optimistic.
        """
        argument_values = expand_argument_values(argument_values)
        optimistic_masks = [is_optimistic_value(argv.tensor_optimistic_values) for argv in argument_values if isinstance(argv, TensorValue) and argv.tensor_optimistic_values is not None]
        if len(optimistic_masks) > 0:
            # An output entry is optimistic as soon as any argument is optimistic there.
            optimistic_mask = torch.stack(optimistic_masks, dim=-1).any(dim=-1)
            rv = super().forward_external_function(function_name, argument_values, return_type=return_type, auto_broadcast=auto_broadcast, expression=expression)
            if optimistic_mask.sum().item() == 0:
                return rv

            rv.init_tensor_optimistic_values()
            if self.executor.optimistic_execution:
                rv.tensor_optimistic_values[optimistic_mask.nonzero(as_tuple=True)] = OPTIM_MAGIC_NUMBER_MAGIC
            else:
                expr_string = expression.cached_string(-1)
                constraint_function = Constraint.EQUAL if isinstance(expression, E.PredicateEqualExpression) else expression.function
                for ind in optimistic_mask.nonzero().tolist():
                    ind = tuple(ind)
                    new_identifier = self.executor.csp.new_var(return_type, wrap=True)
                    rv.tensor_optimistic_values[ind] = new_identifier.identifier
                    self.csp.add_constraint(Constraint.from_function(
                        constraint_function,
                        [_fast_index(argv, ind) for argv in argument_values],
                        new_identifier
                    ), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
            return rv
        return super().forward_external_function(function_name, argument_values, return_type=return_type, auto_broadcast=auto_broadcast, expression=expression)

    def visit_function_application_expression(self, expr: E.FunctionApplicationExpression, argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None) -> TensorValueExecutorReturnType:
        """Delegate to the base implementation unchanged."""
        return super().visit_function_application_expression(expr, argument_values)

    def visit_bool_expression(self, expr: E.BoolExpression, argument_values: Optional[Tuple[TensorValueExecutorReturnType, ...]] = None) -> TensorValueExecutorReturnType:
        """Evaluate a boolean expression and record constraints for optimistic entries.

        All arguments must be of type BOOL. For an optimistic entry, AND is decided as false if
        any argument is 0 / False, and OR as true if any argument is 1 / True; otherwise a new
        BOOL variable is created and constrained by the boolean operator over the remaining
        optimistic arguments.
        """
        if argument_values is None:
            argument_values = self.forward_args(*expr.arguments, force_tuple=True, expand_list_arguments=True)
        argument_values = list(expand_argument_values(argument_values))

        for argv in argument_values:
            assert argv.dtype == BOOL, 'Boolean expression only supports boolean values in CSP mode.'

        optimistic_masks = [is_optimistic_value(argv.tensor_optimistic_values) for argv in argument_values if isinstance(argv, TensorValue) and argv.tensor_optimistic_values is not None]
        if len(optimistic_masks) > 0:
            optimistic_mask = torch.stack(optimistic_masks, dim=-1).any(dim=-1)
            if optimistic_mask.sum().item() > 0:
                rv = super().visit_bool_expression(expr, argument_values)
                rv.init_tensor_optimistic_values()

                if self.executor.optimistic_execution:
                    rv.tensor_optimistic_values[optimistic_mask.nonzero(as_tuple=True)] = OPTIM_MAGIC_NUMBER_MAGIC
                else:
                    expr_string = expr.cached_string(-1)
                    for ind in optimistic_mask.nonzero().tolist():
                        ind = tuple(ind)

                        this_argv = [argv.fast_index(ind, wrap=False) for argv in argument_values]
                        determined = None
                        if expr.return_type == BOOL:
                            if expr.bool_op is BoolOpType.NOT:
                                pass  # nothing we can do.
                            elif expr.bool_op is BoolOpType.AND:
                                if 0 in this_argv or False in this_argv:
                                    determined = False
                            elif expr.bool_op is BoolOpType.OR:
                                if 1 in this_argv or True in this_argv:
                                    determined = True
                            # Only the optimistic arguments take part in the constraint.
                            this_argv = [v for v in this_argv if isinstance(v, OptimisticValue)]
                        else:  # generalized boolean operations.
                            pass

                        if determined is None:
                            new_identifier = self.csp.new_var(BOOL)
                            rv.tensor_optimistic_values[ind] = new_identifier
                            self.csp.add_constraint(Constraint(
                                expr.bool_op,
                                this_argv,
                                cvt_opt_value(new_identifier, BOOL),
                            ), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
                        else:
                            rv[ind] = determined
                return rv
            else:
                return super().visit_bool_expression(expr, argument_values)
        else:  # if len(optimistic_masks) == 0
            return super().visit_bool_expression(expr, argument_values)

    def visit_quantification_expression(self, expr: E.QuantificationExpression, value: Optional[TensorValue] = None) -> Any:
        """Evaluate a quantification and record constraints for optimistic entries.

        The inner value must be of type BOOL. For an optimistic output entry, FORALL is decided
        as false if any element is 0 / False and EXISTS as true if any element is 1 / True;
        otherwise a new BOOL variable is created and constrained by the quantifier over the
        optimistic elements.
        """
        if value is None:
            with self.executor.new_bounded_variables({expr.variable: QINDEX}):
                value = self.forward_args(expr.expression)
        assert isinstance(value, TensorValue)
        assert value.dtype == BOOL, 'Quantification expression only supports boolean values in CSP mode.'

        value.init_tensor_optimistic_values()
        rv = super().visit_quantification_expression(expr, value)

        dim = value.batch_variables.index(expr.variable.name) + value.batch_dims
        value_transposed = value.tensor
        optimistic_values_transposed = value.tensor_optimistic_values
        if dim != value.tensor.ndim - 1:
            value_transposed = value_transposed.transpose(dim, -1)  # put the target dimension last.
            optimistic_values_transposed = optimistic_values_transposed.transpose(dim, -1)
        optimistic_mask_transposed = is_optimistic_value(optimistic_values_transposed)

        # Merge plain values and optimistic identifiers into one tensor.
        value_transposed = torch.where(
            optimistic_mask_transposed,
            optimistic_values_transposed,
            value_transposed.to(optimistic_values_transposed.dtype)
        )
        optimistic_mask = optimistic_mask_transposed.any(dim=-1)

        if optimistic_mask.sum().item() == 0:
            return rv

        rv.init_tensor_optimistic_values()
        if self.executor.optimistic_execution:
            rv.tensor_optimistic_values[optimistic_mask.nonzero(as_tuple=True)] = OPTIM_MAGIC_NUMBER_MAGIC
        else:
            expr_string = expr.cached_string(-1)
            for ind in optimistic_mask.nonzero().tolist():
                ind = tuple(ind)

                this_argv = value_transposed[ind].tolist()
                determined = None
                if expr.quantification_op is QuantificationOpType.FORALL:
                    if 0 in this_argv or False in this_argv:
                        determined = False
                else:
                    if 1 in this_argv or True in this_argv:
                        determined = True
                this_argv = list(filter(is_optimistic_value, this_argv))

                if determined is None:
                    new_identifier = self.csp.new_var(BOOL)
                    rv.tensor_optimistic_values[ind] = new_identifier
                    self.csp.add_constraint(Constraint(
                        expr.quantification_op,
                        [OptimisticValue(value.dtype, int(v)) for v in this_argv],
                        OptimisticValue(value.dtype, new_identifier),
                    ), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
                else:
                    rv.tensor[ind] = determined
        return rv

    def visit_predicate_equal_expression(self, expr: E.PredicateEqualExpression, feature: Optional[TensorValue] = None, value: Optional[TensorValue] = None) -> Any:
        """Evaluate a predicate-equal test; optimistic operands are not supported.

        Raises:
            NotImplementedError: if any entry of the feature or the value is optimistic.
        """
        if feature is None or value is None:
            feature, value = self.forward_args(expr.predicate, expr.value)
        feature, value = expand_argument_values([feature, value])

        feature.init_tensor_optimistic_values()
        value.init_tensor_optimistic_values()
        optimistic_mask = torch.logical_or(is_optimistic_value(feature.tensor_optimistic_values), is_optimistic_value(value.tensor_optimistic_values))

        if optimistic_mask.sum().item() > 0:
            raise NotImplementedError('Optimistic execution is not supported for predicate equal expression.')

        rv = super().visit_predicate_equal_expression(expr, feature, value)
        return rv

        # feature.init_tensor_optimistic_values()
        # value.init_tensor_optimistic_values()
        # optimistic_mask = torch.logical_or(is_optimistic_value(feature.tensor_optimistic_values), is_optimistic_value(value.tensor_optimistic_values))
        # if optimistic_mask.sum().item() == 0:
        #     return rv
        # rv.init_tensor_optimistic_values()
        # if self.executor.optimistic_execution:
        #     rv.tensor_optimistic_values[optimistic_mask.nonzero(as_tuple=True)] = OPTIM_MAGIC_NUMBER_MAGIC
        # else:
        #     expr_string = expr.cached_string(-1)
        #     for ind in optimistic_mask.nonzero().tolist():
        #         ind = tuple(ind)
        #         this_argv = feature.fast_index(ind), value.fast_index(ind)
        #         new_identifier = self.csp.new_var(BOOL)
        #         rv.tensor_optimistic_values[ind] = new_identifier
        #         self.csp.add_constraint(EqualityConstraint(
        #             *[cvt_opt_value(v, feature.dtype) for v in this_argv],
        #             OptimisticValue(BOOL, new_identifier)
        #         ), note=f'{expr_string}::{ind}' if len(ind) > 0 else expr_string)
        # return rv

    def visit_assign_expression(self, expr: E.AssignExpression) -> Any:
        """Assign to a state feature, or defer the new value to simulation / execution.

        When the executor updates effects from simulation or execution, every addressed entry
        gets a fresh CSP variable constrained by a :class:`SimulationFluentConstraintFunction`;
        otherwise the base implementation performs the assignment.

        Raises:
            RuntimeError: if effects come from simulation / execution while
                ``optimistic_execution`` is enabled (raised for the first addressed entry).
        """
        if self.executor.effect_update_from_simulation or self.executor.effect_update_from_execution:
            feature = self.executor.state.features[expr.predicate.function.name]
            feature.init_tensor_optimistic_values()
            argument_values = self.forward_args(*expr.predicate.arguments, force_tuple=True)

            assert self.executor.effect_action_index is not None, 'Effect action index must be set if the target predicate will be updated from simulation.'
            expr_string = expr.cached_string(-1)
            for entry_values in _expand_tensor_indices(feature, argument_values):
                if self.executor.optimistic_execution:
                    raise RuntimeError('Optimistic execution is not supported for effect update from simulation.')
                else:
                    opt_identifier = self.csp.new_var(feature.dtype, wrap=True)
                    feature.tensor_optimistic_values[entry_values] = opt_identifier.identifier
                    self.csp.add_constraint(Constraint(
                        SimulationFluentConstraintFunction(self.executor.effect_action_index, expr.predicate.function, entry_values, is_execution_constraint=self.executor.effect_update_from_execution),
                        [],
                        opt_identifier,
                        note=f'{expr_string}::{entry_values}' if len(entry_values) > 0 else expr_string
                    ))
        else:
            return super().visit_assign_expression(expr)

    def visit_conditional_select_expression(self, expr: E.ConditionalSelectExpression) -> TensorValueExecutorReturnType:
        """Delegate to the base implementation unchanged."""
        return super().visit_conditional_select_expression(expr)

    def visit_deictic_select_expression(self, expr: E.DeicticSelectExpression) -> Any:
        """Delegate to the base implementation unchanged."""
        return super().visit_deictic_select_expression(expr)

    def visit_conditional_assign_expression(self, expr: E.ConditionalAssignExpression) -> Any:
        """Conditionally assign to a state feature and record constraints for optimistic conditions.

        The condition is thresholded at 0.5 (no "soft" assignment). For every entry whose
        condition is optimistic, a new variable ``new`` is created with the constraints
        ``(not cond) or (new == value)`` and ``cond or (new == original)``.

        Raises:
            NotImplementedError: in simulation / execution effect-update mode, or if the value
                is a tensorized Python object.
            RuntimeError: if ``optimistic_execution`` is enabled.
        """
        if self.executor.effect_update_from_simulation or self.executor.effect_update_from_execution:
            raise NotImplementedError('Conditional assign is not supported in simulation mode.')

        if self.executor.optimistic_execution:
            raise RuntimeError('Optimistic execution is not supported for conditional assign.')

        state = self.executor.state
        argument_values = self.forward_args(*expr.predicate.arguments, force_tuple=True)
        value = self.forward_args(expr.value)
        condition = self.forward_args(expr.condition)

        condition_tensor = jactorch.quantize(condition.tensor) if self.CONDITIONAL_ASSIGN_QUANTIZE else condition.tensor
        condition_tensor = (condition_tensor > 0.5).to(torch.bool)
        # ``1 - condition_tensor`` raises a RuntimeError for bool tensors in torch, so the
        # negation is computed with logical_not.
        negated_condition_tensor = torch.logical_not(condition_tensor)

        feature = state.features[expr.predicate.function.name]
        origin_tensor = feature[argument_values].tensor  # I am not using feature.tensor[argument_values] because the current code will handle TensorizedPyObjValues too.

        # NB(Jiayuan Mao @ 2023/08/15): conditional assignment does not support "soft" assignment.
        if value.is_tensorized_pyobj:
            raise NotImplementedError('Cannot make conditional assignments for tensorized pyobj.')

        if condition_tensor.dim() < value.tensor.dim():
            # Add a trailing dimension so that the condition broadcasts against the value.
            condition_tensor_expanded = condition_tensor.unsqueeze(-1)
            negated_condition_tensor_expanded = negated_condition_tensor.unsqueeze(-1)
        else:
            condition_tensor_expanded = condition_tensor
            negated_condition_tensor_expanded = negated_condition_tensor

        feature.tensor[argument_values] = (
            condition_tensor_expanded.to(origin_tensor.dtype) * value.tensor +
            negated_condition_tensor_expanded.to(origin_tensor.dtype) * origin_tensor
        )
        feature.init_tensor_optimistic_values()
        if value.tensor_optimistic_values is not None:
            feature.tensor_optimistic_values[argument_values] = (
                condition_tensor.to(torch.int64) * value.tensor_optimistic_values +
                negated_condition_tensor.to(torch.int64) * state.features[expr.predicate.function.name].tensor_optimistic_values[argument_values]
            )

        if condition.tensor_optimistic_values is None:
            pass
        else:
            optimistic_mask = is_optimistic_value(condition.tensor_optimistic_values)
            if optimistic_mask.sum().item() == 0:
                pass
            else:
                expr_string = expr.cached_string(-1)
                dtype = expr.predicate.function.return_type
                for ind in optimistic_mask.nonzero().tolist():
                    ind = tuple(ind)

                    new_identifier = self.csp.new_var(dtype, wrap=True)
                    neg_condition_identifier = self.csp.new_var(BOOL, wrap=True)
                    eq_1_identifier = self.csp.new_var(BOOL, wrap=True)
                    eq_2_identifier = self.csp.new_var(BOOL, wrap=True)
                    condition_identifier = condition.tensor_optimistic_values[ind].item()

                    # eq_1 <=> (new == assigned value)
                    self.csp.add_constraint(EqualityConstraint(
                        new_identifier,
                        cvt_opt_value(value.fast_index(ind), dtype),
                        eq_1_identifier,
                    ), note=f'{expr_string}::{ind}::eq-1')
                    # eq_2 <=> (new == original value)
                    self.csp.add_constraint(EqualityConstraint(
                        new_identifier,
                        cvt_opt_value(origin_tensor.fast_index(ind) if isinstance(origin_tensor, TensorizedPyObjValues) else origin_tensor[ind], dtype),
                        eq_2_identifier
                    ), note=f'{expr_string}::{ind}::eq-2')
                    self.csp.add_constraint(Constraint(
                        BoolOpType.NOT,
                        [OptimisticValue(BOOL, condition_identifier)],
                        neg_condition_identifier
                    ), note=f'{expr_string}::{ind}::neg-cond')
                    # cond implies eq_1
                    self.csp.add_constraint(Constraint(
                        BoolOpType.OR,
                        [neg_condition_identifier, eq_1_identifier],
                        cvt_opt_value(True, BOOL)
                    ), note=f'{expr_string}::{ind}::implies-new')
                    # (not cond) implies eq_2
                    self.csp.add_constraint(Constraint(
                        BoolOpType.OR,
                        [OptimisticValue(BOOL, condition_identifier), eq_2_identifier],
                        cvt_opt_value(True, BOOL)
                    ), note=f'{expr_string}::{ind}::implies-old')
                    feature.tensor_optimistic_values[ind] = new_identifier.identifier

    def visit_deictic_assign_expression(self, expr: E.DeicticAssignExpression) -> Any:
        """Delegate to the base implementation unchanged."""
        return super().visit_deictic_assign_expression(expr)
def _iter_tensor_indices(target_tensor: torch.Tensor) -> Iterator[Tuple[int, ...]]: """Iterate from the indices of a tensor. """ for ind in torch.nonzero(torch.ones_like(target_tensor)): yield tuple(ind.tolist()) def _expand_tensor_indices(target_value: TensorValue, input_indices: Tuple[Union[int, slice, StateObjectReference], ...]) -> Iterator[Tuple[int, ...]]: """Iterate over entry indices based on the input indices. Supported indices are int and QINDEX (:). Args: target_value: the target value, used to determine the size of ``QINDEX``. input_indices: the indices to iterate over. Yields: the entry indices. """ indices = list() for i, ind in enumerate(input_indices): if isinstance(ind, int): indices.append(torch.tensor([ind], dtype=torch.int64)) elif isinstance(ind, slice): assert ind.step is None and ind.start is None and ind.stop is None # == ':' indices.append(torch.arange(target_value.tensor.shape[i], dtype=torch.int64)) elif isinstance(ind, StateObjectReference): indices.append(torch.tensor([ind.index], dtype=torch.int64)) else: raise ValueError(f'Invalid index type: {ind}') if len(indices) == 0: yield tuple() return if len(indices) == 1: for x in indices[0]: yield tuple([x.item()], ) return indices = torch.meshgrid(*indices, indexing='ij') indices = [i.flatten() for i in indices] for i in range(len(indices[0])): yield tuple(indices[j][i].item() for j in range(len(indices)))