#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : tensor_value_executor.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 11/03/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
"""Tensor-based expression executor.
The high-level interface for tensor-based expression is that we can execute an expression with a given state and a set of
bounded variables. The executor will return a tensor value.
The state is represented using :class:`concepts.dsl.tensor_state.TensorState` or :class:`concepts.dsl.tensor_state.NamedObjectTensorState`, which internally stores a dictionary
mapping from string (the state variable name, e.g., ``is_hot``) to a :class:`concepts.dsl.tensor_value.TensorValue` class.
The bounded variables are essentially a dictionary mapping from strings (the variable name, e.g., ``x``) to its value. There are
two types of values: (1) a :class:`concepts.dsl.tensor_value.TensorValue` class, which represents an actual value (e.g., a vector representation);
(2) a :class:`StateObjectReference` instance or a QINDEX (a.k.a., ``slice(None)``), which represents a reference to an object in the state.
With the bounded variables, the expressions can have variables, which are essentially placeholders for the actual values. For example:

.. code-block:: python

    domain = FunctionDomain()
    # Define an object type `person`.
    domain.define_type(ObjectType('person'))
    # Define a state variable `is_friend` with type `person x person -> bool`.
    domain.define_function(Function('is_friend', FunctionType([ObjectType('person'), ObjectType('person')], BOOL)))
    x = VariableExpression(Variable('x', ObjectType('person')))
    y = VariableExpression(Variable('y', ObjectType('person')))
    relation = FunctionApplication(domain.functions['is_friend'], [x, y])

Then we can execute the expression with a given state and bounded variables:

.. code-block:: python

    # See the documentation for NamedObjectTensorState for more details.
    state = NamedObjectTensorState({
        'is_friend': TensorValue(BOOL, ['x', 'y'], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 0, 1]], dtype=torch.bool))
    }, object_names={
        'Alice': ObjectType('person'),
        'Bob': ObjectType('person'),
        'Charlie': ObjectType('person'),
    })
    executor = FunctionDomainTensorValueExecutor(domain)
    # For both of the following lines, the result is a tensor value with value `True`.
    # Use the constructed expression:
    executor.execute(relation, state, {'x': 'Alice', 'y': 'Bob'})
    # To use the default parser:
    executor.execute('is_friend(x, y)', state, {'x': 'Alice', 'y': 'Bob'})
"""
import contextlib
from typing import Optional, Union, Tuple, Sequence, Dict
import torch
from concepts.dsl.dsl_types import ObjectType, ValueType, TensorValueTypeBase, NamedTensorValueType, PyObjValueType, ListType, ObjectConstant, Variable, UnnamedPlaceholder, QINDEX
from concepts.dsl.dsl_types import BOOL, INT64, FLOAT32
from concepts.dsl.dsl_domain import DSLDomainBase
from concepts.dsl.function_domain import FunctionDomain
from concepts.dsl.value import ListValue
from concepts.dsl.constraint import Constraint, SimulationFluentConstraintFunction
from concepts.dsl.tensor_value import TensorValue, scalar
from concepts.dsl.tensor_state import StateObjectReference, TensorState, NamedObjectTensorState
from concepts.dsl.expression import Expression, VariableExpression, ObjectConstantExpression, ConstantExpression, FunctionApplicationExpression, ValueCompareExpression, BoolOpType, QuantificationOpType
from concepts.dsl.constraint import OptimisticValue, ConstraintSatisfactionProblem, OPTIM_MAGIC_NUMBER
from concepts.dsl.parsers.parser_base import ParserBase
from concepts.dsl.parsers.function_expression_parser import FunctionExpressionParser
from concepts.dsl.executors.executor_base import DSLExecutorBase
from concepts.dsl.executors.value_quantizers import ValueQuantizer, PyObjectStore
# Names exported by ``from ... import *`` for this module.
__all__ = [
'BoundedVariablesDict', 'BoundedVariablesDictCompatible',
'compose_bvdict', 'compose_bvdict_args', 'get_bvdict',
'TensorValueExecutorReturnType', 'TensorValueExecutorBase', 'FunctionDomainTensorValueExecutor'
]
# Nested mapping: typename -> variable name -> bound value.
# NOTE(review): compose_bvdict also stores ListValue instances as values, which this alias does not list —
# confirm whether ListValue should be added to the annotation.
BoundedVariablesDict = Dict[str, Dict[str, Union[StateObjectReference, slice, TensorValue]]]
"""Internal representation of a bounded variable dictionary. It stores a nested two-layer dictionary, where the first layer
stores the type of the object, and the second layer stores the name of the object. The value can be either a :class:`concepts.dsl.tensor_value.TensorValue`
or a :class:`StateObjectReference` instance (representing the reference to a single object)."""
# Keys accepted in a raw (not yet composed) dictionary: either the variable name or the Variable itself.
BoundedVariablesDictCompatibleKeyType = Union[str, Variable]
# Raw values accepted in a not-yet-composed dictionary; compose_bvdict decides per key/value type how each is stored.
BoundedVariablesDictCompatibleValueType = Union[str, int, slice, bool, float, torch.Tensor, TensorValue, ObjectConstant, StateObjectReference]
# Everything compose_bvdict accepts: None, a sequence of variables, a flat raw dictionary, or an already-composed nested dictionary.
BoundedVariablesDictCompatible = Union[
None, Sequence[Variable],
Dict[BoundedVariablesDictCompatibleKeyType, BoundedVariablesDictCompatibleValueType],
BoundedVariablesDict
]
"""Compatible types with :class:`BoundedVariablesDict`. They can be converted to :class:`BoundedVariablesDict` using :func:`compose_bvdict`."""
def _get_state_object_reference(state, dtype, value):
    """Normalize a raw object designator into an object reference.

    Integers (typed indices), strings (object names) and :class:`ObjectConstant` instances are resolved against
    ``state`` into a :class:`StateObjectReference`; for these inputs ``state`` must be a :class:`NamedObjectTensorState`.
    Slices and existing :class:`StateObjectReference` instances are returned unchanged.

    Args:
        state: the state used to look up object names and typed indices.
        dtype: the object type attached to references built from integers and strings.
        value: the raw designator.

    Returns:
        a :class:`StateObjectReference`, or the original slice.

    Raises:
        TypeError: if ``value`` is none of the supported types.
    """
    if isinstance(value, (int, str, ObjectConstant)):
        # All three designators need name/index lookups that only a named state provides.
        assert isinstance(state, NamedObjectTensorState)
        if isinstance(value, int):
            object_name = state.object_type2name[dtype.typename][value]
            return StateObjectReference(object_name, value, dtype)
        if isinstance(value, str):
            return StateObjectReference(value, state.get_typed_index(value), dtype)
        # ObjectConstant carries its own dtype, which takes precedence over the ``dtype`` argument.
        return StateObjectReference(value.name, state.get_typed_index(value.name, typename=value.dtype.typename), value.dtype)

    if isinstance(value, (slice, StateObjectReference)):
        return value

    raise TypeError(f'Invalid object reference type: {type(value)}.')
def compose_bvdict(input_dict: BoundedVariablesDictCompatible, state: Optional[TensorState] = None) -> BoundedVariablesDict:
    """Compose a bounded variable dict from raw inputs.

    Args:
        input_dict: the input dict. The supported forms are:

            1. ``None`` or an empty dictionary, which yields an empty dictionary.
            2. A list or tuple of :class:`concepts.dsl.dsl_types.Variable` instances, which represents a set of variables with no values.
               Each of them is bound to ``QINDEX``.
            3. A dictionary mapping from :class:`concepts.dsl.dsl_types.Variable` instances to the actual value.
            4. A dictionary mapping from strings (the name of the variables) to values.
            5. An already-composed nested dictionary (typename -> name -> value). It is detected by inspecting only the
               first value of the dictionary, and a copy with copied inner dictionaries is returned.

            Acceptable values are:

            1. A :class:`str`, which represents a reference to an object in the state (so the state must be object-named).
            2. An integer, which represents a reference to an object in the state (so the state must be object-named).
            3. A QINDEX (a.k.a., ``slice(None)``), which represents all objects in the state of a given type (so the state must be object-named).
            4. A :class:`concepts.dsl.tensor_value.TensorValue` instance, which represents an actual value.
            5. A :class:`StateObjectReference` instance, which represents a reference to an object in the state (so the state must be object-named).
            6. A :class:`bool`, :class:`int`, :class:`float`, or :class:`torch.Tensor` instance, which represents an actual value. They will be converted to a :class:`concepts.dsl.tensor_value.TensorValue` instance.
            7. A :class:`concepts.dsl.value.ListValue` instance for list-typed variables. When the element type is an object type,
               the elements are normalized to object references (unless the list is ``QINDEX``).
            8. An :class:`ObjectConstant` instance, which is resolved to a :class:`StateObjectReference` using the state.

        state: the state. It is required whenever object names or indices have to be resolved.

    Returns:
        a dictionary mapping from strings (the typename) to a dictionary mapping from strings (the name of the variables) to values.

    Raises:
        TypeError: if the input, a variable type, or a key/value pair is not supported.
        RuntimeError: if a key is an :class:`OptimisticValue`.
    """
    if input_dict is None:
        return dict()

    if isinstance(input_dict, dict):
        if not input_dict:
            return {}

        # A dictionary of dictionaries is treated as already composed; only the first value is inspected.
        sample_value = next(iter(input_dict.values()))
        if isinstance(sample_value, dict):
            return {k: v.copy() for k, v in input_dict.items()}

        output_dict = dict()
        for var, value in input_dict.items():
            if isinstance(var, Variable):
                # Part 1: the variable corresponds to an object.
                if isinstance(var.dtype, ObjectType):
                    output_dict.setdefault(var.typename, dict()).setdefault(var.name, _get_state_object_reference(state, var.dtype, value))
                # Part 2: the variable corresponds to a list.
                # NB: this used to be split into two ``ListType`` branches, the second of which (the one normalizing
                # object elements) could never be reached. Both are merged here, at the position of the first branch.
                elif isinstance(var.dtype, ListType):
                    assert isinstance(value, ListValue)
                    if isinstance(var.dtype.element_type, ObjectType) and not (value.values == QINDEX):
                        value = ListValue(var.dtype, [_get_state_object_reference(state, var.dtype.element_type, v) for v in value.values])
                    output_dict.setdefault(var.dtype.typename, {})[var.name] = value
                # Part 3: the variable corresponds to a Python object.
                elif isinstance(var.dtype, PyObjValueType):
                    if not isinstance(value, TensorValue):
                        value = TensorValue.from_scalar(value, var.dtype)
                    output_dict.setdefault(var.dtype.typename, {})[var.name] = value
                # Part 4: the variable corresponds to a PyTorch tensor.
                elif isinstance(var.dtype, TensorValueTypeBase):
                    if isinstance(value, TensorValue):
                        pass
                    elif isinstance(value, (bool, int, float, torch.Tensor)):
                        value = TensorValue.from_scalar(value, var.dtype)
                    elif isinstance(value, UnnamedPlaceholder):
                        value = TensorValue.from_optimistic_value_int(OPTIM_MAGIC_NUMBER, var.dtype)  # Just a placeholder.
                    else:
                        raise TypeError(f'Invalid value type for variable {var}: {type(value)}.')
                    output_dict.setdefault(var.dtype.typename, {})[var.name] = value
                else:
                    raise TypeError(f'Invalid variable type: {var.dtype}.')
            elif isinstance(var, OptimisticValue):
                raise RuntimeError('Invalid branch; OptimisticValue should be handled in the previous branch. Report a bug to the developers.')
            elif isinstance(var, str) and isinstance(value, str):
                # Name -> object name: the typename is looked up from the state.
                assert state is not None
                typename, value_index = state.get_typename(value), state.get_typed_index(value)
                value = StateObjectReference(value, value_index)
                output_dict.setdefault(typename, dict()).setdefault(var, value)
            elif isinstance(var, str) and isinstance(value, ObjectConstant):
                assert state is not None
                typename = value.typename
                value_index = state.get_typed_index(value.name, typename)
                value = StateObjectReference(value.name, value_index, value.dtype)
                output_dict.setdefault(typename, dict()).setdefault(var, value)
            elif isinstance(var, str) and isinstance(value, StateObjectReference):
                assert state is not None
                assert value.dtype is not None
                output_dict.setdefault(value.dtype.typename, dict()).setdefault(var, value)
            elif isinstance(var, str) and isinstance(value, ListValue):
                output_dict.setdefault(value.dtype.typename, dict()).setdefault(var, value)
            elif isinstance(var, str) and isinstance(value, TensorValue):
                output_dict.setdefault(value.dtype.typename, dict()).setdefault(var, value)
            else:
                raise TypeError(f'Invalid KV pair: {var} -> {value}.')
        return output_dict

    if isinstance(input_dict, (list, tuple)):
        # The input is a list of variables without values: bind each of them to all objects of its type.
        output_dict = dict()
        for var in input_dict:
            assert isinstance(var, Variable)
            output_dict.setdefault(var.typename, dict()).setdefault(var.name, QINDEX)
        return output_dict

    raise TypeError(f'Invalid input type: {type(input_dict)}.')
def compose_bvdict_args(arguments_def: Sequence[Variable], arguments: Sequence[BoundedVariablesDictCompatibleValueType], state: Optional[TensorState] = None) -> BoundedVariablesDict:
    """Compose a bounded variable dict from a positional argument list.

    Each argument definition is paired with the argument at the same position, and the resulting mapping is handed to
    :func:`compose_bvdict`. Pairing stops at the shorter of the two sequences.

    Args:
        arguments_def: the definition of the arguments, including their name and dtypes.
        arguments: the actual arguments, in the same order as ``arguments_def``.
        state: the state, forwarded to :func:`compose_bvdict`.

    Returns:
        a bounded variable dictionary.
    """
    variable_to_value = {variable: argument for variable, argument in zip(arguments_def, arguments)}
    return compose_bvdict(variable_to_value, state=state)
def get_bvdict(bvdict: BoundedVariablesDict, variable: Variable) -> Union[StateObjectReference, slice, TensorValue]:
    """Look up the value bound to a variable in a bounded variable dict.

    Args:
        bvdict: the bounded variable dict (typename -> name -> value).
        variable: the variable; its ``typename`` selects the first layer and its ``name`` the second.

    Returns:
        the value of the variable.

    Raises:
        KeyError: if the typename or the variable name is not present.
    """
    variables_of_type = bvdict[variable.typename]
    return variables_of_type[variable.name]
# A single execution result: a tensor value, an object reference (or slice), a list value, or None.
TensorValueExecutorReturnTypeElem = Union[TensorValue, slice, StateObjectReference, ListValue, None]
# What an executor may return: one such element, or a tuple of them.
TensorValueExecutorReturnType = Union[TensorValueExecutorReturnTypeElem, Tuple[TensorValueExecutorReturnTypeElem, ...]]
class TensorValueExecutorBase(DSLExecutorBase):
    """The base class for tensor value executors.

    The executor carries two pieces of execution context: the current state and the current bounded variables. Both
    can be replaced temporarily through the context managers defined on this class, which always restore the previous
    context on exit (including when the body raises). Subclasses implement :meth:`_execute`.
    """

    def __init__(self, domain: DSLDomainBase, parser: Optional[ParserBase] = None):
        """Initialize the base class for tensor value executors.

        Args:
            domain: the domain of the executor.
            parser: the parser to use. If None, no parser will be used.
        """
        super().__init__(domain)
        self._parser = parser
        self._state = None
        self._bounded_variables = dict()

    @property
    def parser(self) -> Optional[ParserBase]:
        """The parser for the domain."""
        return self._parser

    @property
    def state(self) -> Optional[TensorState]:
        """The current state of the environment."""
        return self._state

    @property
    def bounded_variables(self) -> BoundedVariablesDict:
        """The bounded variables for the execution. Note that most of the time you should use the :meth:`get_bounded_variable` method to get values for the bounded variable."""
        return self._bounded_variables

    @property
    def value_quantizer(self) -> ValueQuantizer:
        """The value quantizer."""
        # NOTE(review): ``_value_quantizer`` is not assigned in this class's ``__init__``; presumably a base class or
        # subclass provides it — confirm.
        return self._value_quantizer

    @property
    def pyobj_store(self) -> PyObjectStore:
        """The Python object store."""
        # NOTE(review): ``_pyobj_store`` is not assigned in this class's ``__init__``; presumably a base class or
        # subclass provides it — confirm.
        return self._pyobj_store

    @contextlib.contextmanager
    def with_state(self, state: Optional[TensorState] = None):
        """A context manager to temporarily set the state of the executor.

        Args:
            state: the state to use inside the ``with`` block. ``None`` clears the state for the duration of the block.
        """
        old_state = self._state
        self._state = state
        try:
            yield
        finally:
            # Restore even if the body raised; otherwise the temporary state would leak into later executions.
            self._state = old_state

    @contextlib.contextmanager
    def with_bounded_variables(self, bvdict: BoundedVariablesDictCompatible):
        """A context manager to set the bounded variables for the executor.

        The given variables replace (rather than extend) the current ones for the duration of the ``with`` block.

        Args:
            bvdict: the bounded variables. They are converted with :func:`compose_bvdict` using the current state.
        """
        old_bvdict = self._bounded_variables
        self._bounded_variables = compose_bvdict(bvdict, state=self._state)
        try:
            yield
        finally:
            self._bounded_variables = old_bvdict

    @contextlib.contextmanager
    def new_bounded_variables(self, bvdict: BoundedVariablesDictCompatible):
        """A context manager to add additional bounded variables to the executor.

        The variables are inserted into the current bounded variables on entry and removed again on exit. If a name is
        already bound, an :class:`AssertionError` is raised and any variables added so far are removed again.

        Args:
            bvdict: the new bounded variables.
        """
        bvdict = compose_bvdict(bvdict, state=self._state)
        added = []  # (typename, name) pairs inserted by this context manager, so that only those are removed.
        try:
            for typename, variables in bvdict.items():
                for name, value in variables.items():
                    if typename not in self._bounded_variables:
                        self._bounded_variables[typename] = dict()
                    assert name not in self._bounded_variables[typename], f'Variable {name} already exists in bounded variables.'
                    self._bounded_variables[typename][name] = value
                    added.append((typename, name))
            yield
        finally:
            for typename, name in added:
                del self._bounded_variables[typename][name]

    def retrieve_bounded_variable_by_name(self, name: str) -> Union[TensorValue, slice, StateObjectReference]:
        """Retrieve a bounded variable by its name.

        All types are searched and the first match is returned.

        Args:
            name: the name of the variable.

        Returns:
            the value of the variable.

        Raises:
            KeyError: if no bounded variable has the given name.
        """
        for variables in self._bounded_variables.values():
            if name in variables:
                return variables[name]
        raise KeyError(f'Variable {name} not found in the bounded variables.')

    def get_bounded_variable(self, variable: Variable) -> Union[TensorValue, slice, StateObjectReference]:
        """Get the value of a bounded variable.

        Args:
            variable: the variable.

        Returns:
            the value of the variable.
        """
        return get_bvdict(self._bounded_variables, variable)

    def set_parser(self, parser: ParserBase):
        """Set the parser for the executor.

        Args:
            parser: the parser.
        """
        self._parser = parser

    def parse(self, expression: Union[Expression, str]):
        """Parse an expression.

        Args:
            expression: the expression to parse. When the input is already an expression, it will be returned directly.

        Returns:
            the parsed expression.

        Raises:
            ValueError: if the input is not an expression and no parser is set.
        """
        if isinstance(expression, Expression):
            return expression
        if self._parser is None:
            raise ValueError('No parser is set for the executor.')
        return self._parser.parse_expression(expression)

    def execute(
        self, expression: Union[Expression, str],
        state: Optional[TensorState] = None,
        bounded_variables: Optional[BoundedVariablesDictCompatible] = None,
    ) -> TensorValueExecutorReturnType:
        """Execute an expression.

        Args:
            expression: the expression to execute. Strings are parsed with :meth:`parse` first.
            state: the state to use. If None, the current state of the executor will be used.
            bounded_variables: the bounded variables to use. If None, the current bounded variables of the executor will be used.

        Returns:
            the TensorValue object.
        """
        if isinstance(expression, str):
            expression = self.parse(expression)

        state = state if state is not None else self._state
        bounded_variables = bounded_variables if bounded_variables is not None else self._bounded_variables
        # The state is entered first because composing the bounded variables resolves object names against it.
        with self.with_state(state), self.with_bounded_variables(bounded_variables):
            return self._execute(expression)

    def _execute(self, expression: Expression) -> TensorValueExecutorReturnType:
        """Execute a parsed expression in the current context. Must be implemented by subclasses."""
        raise NotImplementedError()

    def check_constraint(self, constraint: Constraint, state: Optional[TensorState] = None):
        """Check whether a constraint holds for the concrete values stored in it.

        Boolean and quantification operators are evaluated directly on the ``.item()`` of the arguments and compared
        with the ``.item()`` of the return value. Equality constraints are compared directly for BOOL / INT64 / FLOAT32
        and through :meth:`check_eq_constraint` otherwise. Constraints on simulation fluents always return False. Any
        other constraint is treated as a function application, which is executed under ``state`` and compared with the
        return value of the constraint.

        Args:
            constraint: the constraint to check.
            state: the state used while executing function applications and equality checks.

        Returns:
            True if the constraint is satisfied.

        Raises:
            TypeError: if a function argument is neither of an object type nor of a value type.
        """
        if constraint.function is BoolOpType.NOT:
            return constraint.arguments[0].item() == (not constraint.rv.item())
        elif constraint.function in (QuantificationOpType.FORALL, BoolOpType.AND):
            return all([x.item() for x in constraint.arguments]) == constraint.rv.item()
        elif constraint.function in (QuantificationOpType.EXISTS, BoolOpType.OR):
            return any([x.item() for x in constraint.arguments]) == constraint.rv.item()
        elif constraint.function is BoolOpType.IMPLIES:
            # Parenthesized on purpose: ``==`` binds tighter than ``or``, so without the parentheses a false premise
            # would make the check pass regardless of the expected return value.
            implication = bool((not constraint.arguments[0].item()) or constraint.arguments[1].item())
            return implication == constraint.rv.item()
        elif constraint.function is BoolOpType.XOR:
            return sum([x.item() for x in constraint.arguments]) % 2 == constraint.rv.item()

        if constraint.is_equal_constraint:
            if constraint.arguments[0].dtype in (BOOL, INT64, FLOAT32):
                return (constraint.arguments[0].item() == constraint.arguments[1].item()) == constraint.rv.item()
            else:
                return self.check_eq_constraint(constraint.arguments[0].dtype, constraint.arguments[0], constraint.arguments[1], constraint.rv.item(), state)

        if isinstance(constraint.function, SimulationFluentConstraintFunction):
            return False

        # assert isinstance(c.function, CrowFunctionBase)
        # # NB(Jiayuan Mao @ 09/05): for generator placeholders, they can only be set true through the corresponding generators.
        # if isinstance(c.function, CrowFunction) and c.function.is_generator_placeholder:
        #     return False

        # Rebuild the function application from the concrete argument values and execute it.
        argument_values = list()
        for argument, argv in zip(constraint.function.arguments, constraint.arguments):
            if isinstance(argument.dtype, ObjectType):
                assert isinstance(argv, StateObjectReference)
                # The reference itself is stored as the "name" of the object constant.
                argument_values.append(ObjectConstantExpression(ObjectConstant(argv, argument.dtype)))
            elif isinstance(argument.dtype, ValueType):
                argument_values.append(ConstantExpression(argv, argument.dtype))
            else:
                raise TypeError(f'Unsupported argument type: {argument.dtype}.')

        func = FunctionApplicationExpression(constraint.function, argument_values)
        with self.with_state(state):
            rv = self._execute(func)
        if rv.dtype == BOOL:
            return (rv.item() > 0.5) == constraint.rv.item()
        else:
            # NOTE(review): unlike the equality branch above, the second operand here is ``constraint.rv.item()``
            # rather than a value object — confirm that this is what ``check_eq_constraint`` expects.
            return self.check_eq_constraint(rv.dtype, rv, constraint.rv.item(), True, state)

    def check_eq_constraint(self, dtype: TensorValueTypeBase, x: TensorValue, y: TensorValue, target: bool, state: Optional[TensorState] = None) -> bool:
        """Check whether the equality test between two values yields the expected result.

        An ``EQ`` :class:`ValueCompareExpression` over the two values is executed under ``state``.

        Args:
            dtype: the type used for both constant operands.
            x: the first value.
            y: the second value.
            target: the expected result of the comparison.
            state: the state used while executing the comparison.

        Returns:
            True if the ``.item()`` of the comparison result equals ``target``.
        """
        expr = ValueCompareExpression(ValueCompareExpression.OpType.EQ, ConstantExpression(x, dtype), ConstantExpression(y, dtype))
        with self.with_state(state):
            return self._execute(expr).item() == target
class FunctionDomainTensorValueExecutor(TensorValueExecutorBase):
    """Similar to :class:`~concepts.dsl.executors.function_domain_executor.FunctionDomainExecutor`, but works for :class:`~concepts.dsl.tensor_value.TensorValue`.

    The two main differences are:

    1. The :meth:`execute` method returns a :class:`~concepts.dsl.tensor_value.TensorValue` object instead of a :class:`~concepts.dsl.value.Value` object.
    2. The class supports binding variables to values during execution. See the documentation for this file and tutorials for details.
    """

    def __init__(self, domain: FunctionDomain, parser: Optional[ParserBase] = None):
        """Initialize a tensor value executor for a function domain.

        Args:
            domain: the domain of the executor.
            parser: the parser to use. If not specified, a :class:`FunctionExpressionParser` for the domain is created
                (with ``allow_variable=True`` and ``escape_string=True``).
        """
        if parser is None:
            parser = FunctionExpressionParser(domain, allow_variable=True, escape_string=True)
        super().__init__(domain, parser)

    _domain: FunctionDomain

    @property
    def domain(self) -> FunctionDomain:
        """The function domain of the executor."""
        return self._domain

    def _execute(self, expr: Expression) -> TensorValueExecutorReturnType:
        """Evaluate an expression against the current state and bounded variables.

        Variables are looked up in the bounded variables, object constants are turned into object references,
        constants are returned as-is, and function applications are either indexed from the state features or
        dispatched to the registered function implementation.

        Raises:
            ValueError: if the expression type is not supported.
        """
        if isinstance(expr, VariableExpression):
            var = expr.variable
            bindings = self._bounded_variables[var.dtype.typename]
            return bindings[var.name]

        if isinstance(expr, ObjectConstantExpression):
            # The constant may already wrap a resolved reference in its ``name`` field.
            if isinstance(expr.constant.name, StateObjectReference):
                return expr.constant.name
            assert isinstance(self._state, NamedObjectTensorState)
            obj = expr.constant
            typed_index = self._state.get_typed_index(obj.name, obj.dtype.typename)
            return StateObjectReference(obj.name, typed_index, obj.dtype)

        if isinstance(expr, ConstantExpression):
            assert isinstance(expr.constant, TensorValue)
            return expr.constant

        if isinstance(expr, FunctionApplicationExpression):
            assert isinstance(self._state, NamedObjectTensorState)
            function = expr.function
            evaluated = [self._execute(argument) for argument in expr.arguments]

            if function.name not in self._state.features:
                # Not a state feature: fall back to the registered implementation.
                assert self.has_function_implementation(function.name)
                implementation = self.get_function_implementation(function.name)
                return implementation(*evaluated)

            # State feature: object references are replaced by their indices before indexing the feature.
            feature_index = tuple([a.index if isinstance(a, StateObjectReference) else a for a in evaluated])
            return self._state.features[function.name][feature_index]

        raise ValueError(f'Unsupported expression type: {type(expr)}.')