Source code for concepts.pdsketch.domain

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : domain.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/30/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import itertools
from typing import TYPE_CHECKING, Any, Optional, Union, Sequence, Tuple, List, Dict

import jacinle
import torch
from jacinle.utils.printing import indent_text, stprint

from concepts.dsl.dsl_types import BOOL, FLOAT32, INT64, ObjectType, TensorValueTypeBase, PyObjValueType, TupleType, ListType, ScalarValueType, VectorValueType, NamedTensorValueType, Variable, ObjectConstant
from concepts.dsl.dsl_functions import FunctionReturnType, FunctionArgumentListType, FunctionType, Function
from concepts.dsl.dsl_domain import DSLDomainBase
from concepts.dsl.constraint import OPTIM_MAGIC_NUMBER_MAGIC
from concepts.dsl.expression import Expression, ExpressionDefinitionContext, VariableExpression, ValueOutputExpression, cvt_expression_list
from concepts.dsl.expression import FunctionApplicationExpression, AssignExpression, ConditionalAssignExpression, DeicticAssignExpression
from concepts.dsl.constraint import is_optimistic_value
from concepts.dsl.tensor_value import TensorValue
from concepts.dsl.tensor_state import NamedObjectTensorState

from concepts.pdsketch.predicate import Predicate, flatten_expression, get_used_state_variables
from concepts.pdsketch.operator import Precondition, Effect, Implementation, Operator, MacroOperator, OperatorApplier, OperatorApplicationExpression
from concepts.pdsketch.regression_rule import RegressionRuleBodyItemType, RegressionRule
from concepts.pdsketch.generator import Generator, FancyGenerator

if TYPE_CHECKING:
    from concepts.pdsketch.executor import PDSketchExecutor

# Module-level logger (used e.g. by `Domain.define_type` to warn about shadowed type names).
logger = jacinle.get_logger(__file__)

# Public API of this module.
__all__ = ['Domain', 'Problem', 'State']


class _TypedVariableView(object):
    """Use `domain.typed_variable['type_name']('variable_name')`"""

    def __init__(self, domain):
        self.domain = domain

    def __getitem__(self, typename):
        def function(string):
            return Variable(string, self.domain.types[typename])
        return function


class Domain(DSLDomainBase):
    """The planning domain definition."""

    def __init__(self, name: Optional[str] = None, pdsketch_version: int = 2):
        """Initialize a planning domain.

        Args:
            name: The name of the domain.
            pdsketch_version: the version of the PDSketch language (2 or 3). It selects the expression parser
                used by :meth:`parse`.
        """
        super().__init__(name)

        self.pdsketch_version = pdsketch_version
        self.operators = dict()
        self.operator_templates = dict()
        self.regression_rules = dict()
        self.axioms = dict()
        self.generators = dict()
        self.fancy_generators = dict()
        self.external_functions = dict()
        self.external_function_crossrefs = dict()
        self.tv = self.typed_variable = _TypedVariableView(self)

    pdsketch_version: int
    """The version of the PDSketch language. Currently, two supported versions are 2 and 3. This will be used to determine the parsing behavior of the domain."""

    name: str
    """The name of the domain."""

    types: Dict[str, Union[ObjectType, PyObjValueType, TensorValueTypeBase]]
    """The types defined in the domain, as a dictionary from type names to types."""

    functions: Dict[str, Predicate]
    """A mapping from function name to the corresponding :class:`~concepts.pdsketch.predicate.Predicate` class.
    Note that, unlike the basic :class:`~concepts.dsl.dsl_domain.DSLDomainBase`, in planning domain, all functions should be of type :class:`~concepts.pdsketch.predicate.Predicate`."""

    constants: Dict[str, ObjectConstant]
    """The constants defined in the domain, as a dictionary from constant names to values."""

    operators: Dict[str, Union[Operator, MacroOperator]]
    """A mapping of operators: from operator name to the corresponding :class:`~concepts.pdsketch.operator.Operator` class."""

    operator_templates: Dict[str, Operator]
    """A mapping of operator templates: from operator name to the corresponding :class:`~concepts.pdsketch.operator.Operator` class."""

    regression_rules: Dict[str, RegressionRule]
    """A mapping of regression rules: from regression rule name to the corresponding :class:`~concepts.pdsketch.operator.RegressionRule` class."""

    axioms: Dict[str, Operator]
    """A mapping of axioms: from axiom name to the corresponding :class:`~concepts.pdsketch.operator.Operator` class."""

    generators: Dict[str, Generator]
    """A mapping of generators: from generator name to the corresponding :class:`~concepts.pdsketch.generator.Generator` class."""

    fancy_generators: Dict[str, FancyGenerator]
    """A mapping of fancy generators: from fancy generator name to the corresponding :class:`~concepts.pdsketch.generator.FancyGenerator` class."""

    external_functions: Dict[str, Function]
    """A mapping of external functions: from function name to the corresponding :class:`~concepts.dsl.dsl_functions.Function` class."""

    external_function_crossrefs: Dict[str, str]
    """A mapping from function name to another function name. This is useful when defining one function as an derived function of another function."""

    tv: _TypedVariableView
    """A helper function that returns a variable with the given type.
    For example, `domain.tv['object']('x')` returns a variable of type `object` with name `x`."""

    def __getattr__(self, item):
        """PDDL-style shortcut attribute access.

        Underscores in ``item`` are replaced by dashes, and the prefix selects the table to look up:
        ``t_`` -> :attr:`types`, ``p_`` / ``f_`` -> :attr:`functions`, ``op_`` -> :attr:`operators`,
        ``gen_`` -> :attr:`generators`. For example, ``domain.p_on_table`` returns ``domain.functions['on-table']``.

        Raises:
            AttributeError: for dunder names (``__xxx__``).
            KeyError: if the prefix is recognized but the name is not in the corresponding table.
            NameError: if the attribute does not start with any of the recognized prefixes.
        """
        if item.startswith('__') and item.endswith('__'):
            raise AttributeError

        # NB(Jiayuan Mao @ 09/03): PDDL definition convention.
        item = item.replace('_', '-')
        if item.startswith('t-'):
            return self.types[item[2:]]
        elif item.startswith('p-') or item.startswith('f-'):
            return self.functions[item[2:]]
        elif item.startswith('op-'):
            return self.operators[item[3:]]
        elif item.startswith('gen-'):
            return self.generators[item[4:]]
        raise NameError('Unknown attribute: {}.'.format(item))

    def set_name(self, name: str):
        """Set the name of the domain.

        Args:
            name: the new name of the domain.
        """
        self.name = name

    BUILTIN_TYPES = ['object', 'pyobject', 'bool', 'int64', 'float32', '__totally_ordered_plan__', '__partially_ordered_plan__']

    BUILTIN_NUMERIC_TYPES = {
        'bool': BOOL,
        'int64': INT64,
        'float32': FLOAT32
    }

    BUILTIN_PYOBJ_TYPES = {
        '__control__': PyObjValueType('__control__', alias='__control__'),
        '__regression_body_item__': PyObjValueType('__regression_body_item__', alias='__regression_body_item__'),
        '__totally_ordered_plan__': ListType(PyObjValueType('__regression_body_item__'), alias='__totally_ordered_plan__'),
    }

    def define_type(self, typename, parent_name: Optional[Union[VectorValueType, ScalarValueType, str]] = 'object') -> Union[ObjectType, PyObjValueType, VectorValueType, ScalarValueType]:
        """Define a new type.

        Args:
            typename: the name of the new type.
            parent_name: the parent type of the new type, default to 'object'. It is either one of the strings
                'object', 'pyobject', 'int64', 'float32', or a :class:`VectorValueType` / :class:`ScalarValueType`
                instance. In the latter case the new type is a :class:`NamedTensorValueType` over that value type.

        Returns:
            the newly defined type.

        Raises:
            ValueError: if ``typename`` is a reserved built-in type name (other than 'object', which only warns),
                or if ``parent_name`` is an unknown string.
        """
        if typename == 'object':
            logger.warning_once('Shadowing built-in type name "object".')
        elif typename in type(self).BUILTIN_TYPES:
            raise ValueError('Typename {} is a built-in type.'.format(typename))

        assert isinstance(parent_name, (str, VectorValueType, ScalarValueType)), f'Currently only support inheritance from builtin types: {type(self).BUILTIN_TYPES}.'

        if isinstance(parent_name, str):
            if parent_name == 'object':
                # Object types do not get an equality external function.
                self.types[typename] = ObjectType(typename)
                return self.types[typename]

            if parent_name == 'pyobject':
                dtype = PyObjValueType(typename)
            elif parent_name == 'int64':
                dtype = NamedTensorValueType(typename, INT64)
            elif parent_name == 'float32':
                dtype = NamedTensorValueType(typename, FLOAT32)
            else:
                raise ValueError(f'Unknown parent type: {parent_name}.')
        else:
            # A vector or scalar value type becomes the base of a named tensor type.
            # NOTE(review): the 'int64'/'float32' branches above presumably use this same construction with
            # built-in scalar types — confirm INT64/FLOAT32 are ScalarValueType instances.
            dtype = NamedTensorValueType(typename, parent_name)

        self.types[typename] = dtype
        # Every non-object type gets an external equality function named `type::<typename>::equal`.
        self.declare_external_function(f'type::{typename}::equal', [dtype, dtype], BOOL)
        return self.types[typename]

    def get_type(self, typename: str) -> Union[ObjectType, PyObjValueType, VectorValueType, ScalarValueType, NamedTensorValueType]:
        """Get a type by name.

        Built-in numeric and Python-object types take precedence over user-defined types.

        Args:
            typename: the name of the type.

        Returns:
            the type with the given name.

        Raises:
            ValueError: if the type is unknown.
        """
        if typename in type(self).BUILTIN_NUMERIC_TYPES:
            return type(self).BUILTIN_NUMERIC_TYPES[typename]
        elif typename in type(self).BUILTIN_PYOBJ_TYPES:
            return type(self).BUILTIN_PYOBJ_TYPES[typename]

        if typename not in self.types:
            raise ValueError(f'Unknown type: {typename}, known types are: {list(self.types.keys())}.')
        return self.types[typename]

    def define_predicate(
        self, name: str, arguments: FunctionArgumentListType, return_type: FunctionReturnType = BOOL, *,
        observation: Optional[bool] = None, state: Optional[bool] = None,
        generator_placeholder: bool = False, inplace_generators: Optional[Sequence[str]] = None,
        simulation: bool = False, execution: bool = False,
        is_generator_function: bool = False,
    ):
        """Define a new predicate.

        Args:
            name: the name of the new predicate.
            arguments: the arguments of the new predicate.
            return_type: the return type of the new predicate.
            observation: whether the new predicate is an observation variable.
            state: whether the new predicate is a state variable.
            generator_placeholder: whether the new predicate is a generator placeholder.
            inplace_generators: a list of generators that will be defined in-place for this predicate.
            simulation: whether the new predicate requires the up-to-date simulation state to evaluate.
            execution: whether the new predicate requires the up-to-date execution state to evaluate.
            is_generator_function: whether the new predicate is a generator function.

        Returns:
            the newly defined predicate.
        """
        predicate = Predicate(
            name, FunctionType(arguments, return_type, is_generator_function=is_generator_function),
            observation=observation, state=state, generator_placeholder=generator_placeholder, inplace_generators=inplace_generators,
            simulation=simulation, execution=execution
        )
        self.define_predicate_inner(name, predicate)
        return predicate

    def define_derived(
        self, name: str, arguments: FunctionArgumentListType, return_type: Optional[FunctionReturnType] = None,
        expr: ValueOutputExpression = None, *,
        state: bool = False, generator_placeholder: bool = False,
        simulation: bool = False, execution: bool = False
    ):
        """Define a new derived predicate. Note that a derived predicate can not be an observation variable.

        Args:
            name: the name of the new derived predicate.
            arguments: the arguments of the new derived predicate.
            return_type: the return type of the new derived predicate.
            expr: the expression of the new derived predicate.
            state: whether the new derived predicate is a state variable.
            generator_placeholder: whether the new derived predicate is a generator placeholder.
            simulation: whether the new derived predicate requires the up-to-date simulation state to evaluate.
            execution: whether the new derived predicate requires the up-to-date execution state to evaluate.

        Returns:
            the newly defined derived predicate.
        """
        predicate_def = Predicate(
            name, FunctionType(arguments, return_type),
            observation=False, state=state, generator_placeholder=generator_placeholder,
            derived_expression=expr,
            simulation=simulation, execution=execution
        )
        return self.define_predicate_inner(name, predicate_def)

    def define_predicate_inner(self, name: str, predicate_def: Predicate):
        """Register an already-constructed predicate under ``name``.

        Non-cacheable predicates without a derived expression are also registered in :attr:`external_functions`.

        Args:
            name: the name to register the predicate under.
            predicate_def: the predicate.

        Returns:
            the registered predicate (``predicate_def`` itself).
        """
        self.functions[name] = predicate_def

        # NB(Jiayuan Mao @ 07/21): a non-cacheable function is basically an external function.
        if not predicate_def.is_cacheable and predicate_def.derived_expression is None:
            self.external_functions[name] = predicate_def

        return predicate_def

    def get_predicate(self, name: str) -> Predicate:
        """Get a predicate by name.

        Args:
            name: the name of the predicate.

        Returns:
            the predicate with the given name.

        Raises:
            ValueError: if no function with the given name is defined.
        """
        if name not in self.functions:
            raise ValueError(f'Unknown predicate: {name}.')
        assert isinstance(self.functions[name], Predicate)
        return self.functions[name]

    def define_operator(
        self, name: str, parameters: Sequence[Variable], preconditions: Sequence[Precondition], effects: Sequence[Effect], controller: Implementation,
        template: bool = False, extends: Optional[str] = None,
    ) -> Operator:
        """Define a new operator.

        Args:
            name: the name of the new operator.
            parameters: the parameters of the new operator.
            preconditions: the preconditions of the new operator.
            effects: the effects of the new operator.
            controller: the controller of the new operator.
            template: whether the new operator is a template.
            extends: the parent operator of the new operator.

        Returns:
            the newly defined operator.
        """
        self.operators[name] = op = Operator(
            name, parameters, preconditions, effects, controller,
            extends=extends, is_template=template
        )
        return op

    def define_operator_inner(self, name: str, operator: Operator) -> Operator:
        """Register an already-constructed operator under ``name``; the name must not be in use."""
        assert name not in self.operators
        self.operators[name] = operator
        return operator

    def has_operator(self, name: str) -> bool:
        """Return whether an operator with the given name is defined."""
        return name in self.operators

    def get_operator(self, name: str) -> Operator:
        """Get an operator by name; raises ValueError if it is not defined."""
        if name not in self.operators:
            raise ValueError(f'Operator {name} not found.')
        return self.operators[name]

    def define_regression_rule(
        self, name: str, parameters: Sequence[Variable],
        preconditions: Sequence[Precondition],
        goal_expression: ValueOutputExpression,
        side_effects: Sequence[Effect],
        body: Sequence[RegressionRuleBodyItemType],
        always: bool = False
    ):
        """Define a new regression rule.

        Args:
            name: the name of the new regression rule.
            parameters: the parameters of the new regression rule.
            preconditions: the preconditions of the new regression rule.
            goal_expression: the goal expression of the new regression rule, as a single expression.
            side_effects: the side effects of the new regression rule.
            body: the body of the new regression rule.
            always: whether the new regression rule is always applicable.

        Returns:
            the newly defined regression rule.
        """
        self.regression_rules[name] = rule = RegressionRule(name, parameters, preconditions, goal_expression, side_effects, body, always=always)
        return rule

    def has_regression_rule(self, name: str) -> bool:
        """Return whether a regression rule with the given name is defined."""
        return name in self.regression_rules

    def get_regression_rule(self, name: str) -> RegressionRule:
        """Get a regression rule by name; raises ValueError if it is not defined."""
        if name not in self.regression_rules:
            raise ValueError(f'Regression rule {name} not found.')
        return self.regression_rules[name]

    def define_axiom(self, name: Optional[str], parameters: Sequence[Variable], preconditions: Sequence[Precondition], effects: Sequence[Effect]) -> Operator:
        """Define a new axiom.

        Args:
            name: the name of the new axiom. If None, a name of the form ``axiom_<index>`` is generated.
            parameters: the parameters of the new axiom.
            preconditions: the preconditions of the new axiom.
            effects: the effects of the new axiom.

        Returns:
            the newly defined axiom.
        """
        if name is None:
            name = f'axiom_{len(self.axioms)}'
        self.axioms[name] = op = Operator(name, parameters, preconditions, effects, is_axiom=True)
        return op

    def define_macro(self, name: str, parameters: Sequence[Variable], sub_operators: Sequence[OperatorApplier], preconditions: Sequence[Precondition] = tuple(), effects: Sequence[Effect] = tuple()) -> MacroOperator:
        """Define a new macro.

        Args:
            name: the name of the new macro.
            parameters: the parameters of the new macro.
            sub_operators: the sub operators of the new macro.
            preconditions: the preconditions of the new macro.
            effects: the effects of the new macro.

        Returns:
            the newly defined macro.
        """
        self.operators[name] = op = MacroOperator(name, parameters, sub_operators, preconditions=preconditions, effects=effects)
        return op

    def define_generator(
        self, name: str, parameters: Sequence[Variable],
        certifies: ValueOutputExpression,
        context: Sequence[Union[VariableExpression, ValueOutputExpression]],
        generates: Sequence[Union[VariableExpression, ValueOutputExpression]],
        implementation: Optional[Implementation] = None,
        priority: int = 0, unsolvable: bool = False
    ) -> Generator:
        """Define a new generator.

        Args:
            name: the name of the new generator.
            parameters: the parameters of the new generator.
            certifies: the certified condition of the new generator.
            context: the context of the new generator.
            generates: the generates of the new generator.
            implementation: the implementation of the new generator.
            priority: the priority of the new generator.
            unsolvable: whether the new generator is unsolvable (its priority is then forced to ``int(1e9)``).

        Returns:
            the newly defined generator.

        Raises:
            ValueError: if a generator or fancy generator with the same name is already defined.
        """
        # Check for duplicates before registering anything, so that a rejected definition leaves no stale entry in
        # `self.external_functions`. Generators and fancy generators share one namespace (see `get_generator`).
        if self.has_generator(name):
            raise ValueError(f'Duplicate generator: {name}.')

        if unsolvable:
            priority = int(1e9)

        context: List[Union[VariableExpression, ValueOutputExpression]] = cvt_expression_list(context)
        generates: List[Union[VariableExpression, ValueOutputExpression]] = cvt_expression_list(generates)

        # The generator is exposed as a function from the context values (?c0, ?c1, ...) to the generated values
        # (?g0, ?g1, ...). A single output is returned directly; multiple outputs are wrapped in a TupleType.
        arguments = [Variable(f'?c{i}', c.return_type) for i, c in enumerate(context)]
        return_type = [target.return_type for target in generates]
        if len(return_type) == 1:
            return_type = return_type[0]
        else:
            return_type = TupleType(return_type)
        output_vars = [Variable(f'?g{i}', g.return_type) for i, g in enumerate(generates)]
        return_names = [v.name for v in output_vars]
        if len(return_names) == 1:
            return_names = return_names[0]

        identifier = f'generator::{name}'
        function = Function(identifier, FunctionType(arguments, return_type, return_name=return_names))

        # Flatten `certifies`, mapping each context / generates expression to its ?c* / ?g* variable.
        all_variables = {c: cv for c, cv in zip(context, arguments)}
        all_variables.update({g: gv for g, gv in zip(generates, output_vars)})
        ctx = ExpressionDefinitionContext(*arguments, *output_vars, domain=self)
        flatten_certifies = flatten_expression(certifies, all_variables, ctx, flatten_cacheable_expression=True)

        # A solvable generator without an inline implementation is registered as an external function.
        if not unsolvable and implementation is None:
            self.external_functions[identifier] = function

        self.generators[name] = generator = Generator(
            name, parameters, certifies,
            context=context, generates=generates,
            function=function, output_vars=output_vars, flatten_certifies=flatten_certifies,
            implementation=implementation, priority=priority, unsolvable=unsolvable
        )
        return generator

    def define_fancy_generator(
        self, name: str, certifies: ValueOutputExpression,
        implementation: Optional[Implementation] = None,
        priority: int = 10, unsolvable: bool = False
    ) -> FancyGenerator:
        """Declare a new fancy generator. The difference between a fancy generator and a normal generator is that a fancy generator is not directional.
        That is, it can generate a set of variables satisfies the constraints, without requiring specific `contexts` to `generates` directions.
        Therefore, we don't need to specify the `context` and `generates` of a fancy generator.

        Args:
            name: the name of the new fancy generator.
            certifies: the certified condition of the new fancy generator.
            implementation: the implementation of the new fancy generator.
            priority: the priority of the new fancy generator.
            unsolvable: whether the new fancy generator is unsolvable (its priority is then forced to ``int(1e9)``).

        Returns:
            the newly declared fancy generator.

        Raises:
            ValueError: if a generator or fancy generator with the same name is already defined.
        """
        # Previously only `self.generators` was checked, so duplicate *fancy* generators were silently overwritten.
        # The check also runs before any registration, so a rejected definition leaves no stale external function.
        if self.has_generator(name):
            raise ValueError(f'Duplicate generator: {name}.')

        if unsolvable:
            priority = int(1e9)

        identifier = f'generator::{name}'
        # TODO(Jiayuan Mao @ 2023/04/04): fix the typing for this.
        function = Function(identifier, FunctionType([], []))
        flatten_certifies = certifies

        if not unsolvable and implementation is None:
            self.external_functions[identifier] = function

        self.fancy_generators[name] = generator = FancyGenerator(
            name, certifies,
            function=function, flatten_certifies=flatten_certifies,
            implementation=implementation, priority=priority, unsolvable=unsolvable
        )
        return generator

    def has_generator(self, name: str) -> bool:
        """Return whether a generator or fancy generator with the given name is defined."""
        return name in self.generators or name in self.fancy_generators

    def get_generator(self, name: str) -> Union[Generator, FancyGenerator]:
        """Get a generator by name, looking at normal generators first and then fancy generators.

        Raises:
            ValueError: if no generator with the given name is defined.
        """
        if name in self.generators:
            return self.generators[name]
        if name in self.fancy_generators:
            return self.fancy_generators[name]
        raise ValueError(f'Generator {name} not found.')

    def declare_external_function(self, function_name: str, argument_types: FunctionArgumentListType, return_type: FunctionReturnType, kwargs: Optional[Dict[str, Any]] = None) -> Function:
        """Declare an external function.

        Args:
            function_name: the name of the external function.
            argument_types: the argument types of the external function.
            return_type: the return type of the external function.
            kwargs: the keyword arguments forwarded to the :class:`Predicate` constructor. Supported keyword arguments are:

                - ``observation``: whether the external function is an observation variable.
                - ``state``: whether the external function is a state variable.

        Returns:
            the declared external function (a :class:`Predicate`).
        """
        if kwargs is None:
            kwargs = dict()
        self.external_functions[function_name] = Predicate(function_name, FunctionType(argument_types, return_type), **kwargs)
        return self.external_functions[function_name]

    def declare_external_function_crossref(self, function_name: str, cross_ref_name: str):
        """Declare a cross-reference to an external function.

        This is useful when one function is an derived function of another function.

        Args:
            function_name: the name of the external function.
            cross_ref_name: the name of the cross-reference.
        """
        self.external_function_crossrefs[function_name] = cross_ref_name

    def parse(self, string: Union[str, Expression], state: Optional['State'] = None, variables: Optional[Sequence[Variable]] = None) -> Expression:
        """Parse a string into an expression.

        The parser is selected by :attr:`pdsketch_version`: version 2 uses the PDSketch parser, version 3 uses the
        CROW parser.

        Args:
            string: the string to be parsed. If it is already an :class:`Expression`, it is returned unchanged.
            state: the state to be used during parsing. Only forwarded to the version-3 parser.
            variables: the variables to be used in the expression.

        Returns:
            the parsed expression.

        Raises:
            ValueError: if :attr:`pdsketch_version` is neither 2 nor 3.
        """
        if isinstance(string, Expression):
            return string

        if self.pdsketch_version == 2:
            from concepts.pdsketch.parsers.pdsketch_parser import parse_expression
            return parse_expression(self, string, variables)
        elif self.pdsketch_version == 3:
            from concepts.dm.crow.parsers.crow_parser import parse_expression
            return parse_expression(self, string, state=state, variables=variables)
        else:
            raise ValueError(f'Unknown PDSketch version: {self.pdsketch_version}.')

    def make_executor(self) -> 'PDSketchExecutor':
        """Make an executor for this domain."""
        from concepts.pdsketch.executor import PDSketchExecutor
        return PDSketchExecutor(self)

    def incremental_define(self, string: str):
        """Incrementally define new parts of the domain.

        Args:
            string: the string to be parsed and defined.
        """
        from concepts.pdsketch.parsers.pdsketch_parser import load_domain_string_incremental
        load_domain_string_incremental(self, string)

    def incremental_define3(self, string: str):
        """Incrementally define new parts of the domain using PDSketch3.

        Args:
            string: the string to be parsed and defined.

        Returns:
            whatever the CROW incremental loader returns.
        """
        from concepts.dm.crow.parsers.crow_parser import load_domain_string_incremental
        return load_domain_string_incremental(self, string)

    def print_summary(self, external_functions: bool = False, full_generators: bool = False):
        """Print a summary of the domain.

        Args:
            external_functions: whether to also print the external functions.
            full_generators: whether to print the full generator objects instead of their short strings.
        """
        print(f'Domain {self.name}')
        stprint(key='Types: ', data=self.types, indent=1, sort_key=False)
        stprint(key='Functions: ', data=self.functions, indent=1, sort_key=False)
        if external_functions:
            stprint(key='External Functions: ', data=self.external_functions, indent=1, sort_key=False)
        if full_generators:
            stprint(key='Generators: ', data=self.generators, indent=1, sort_key=False)
            stprint(key='Fancy Generators: ', data=self.fancy_generators, indent=1, sort_key=False)
        else:
            print(' Generators:')
            if len(self.generators) > 0:
                for gen in self.generators.values():
                    print(indent_text(gen.short_str(), level=2))
            else:
                print(' <Empty>')
            print(' Fancy Generators:')
            if len(self.fancy_generators) > 0:
                for gen in self.fancy_generators.values():
                    print(indent_text(gen.short_str(), level=2))
            else:
                print(' <Empty>')
        print(' Operators:')
        if len(self.operators) > 0:
            for op in self.operators.values():
                if not op.is_macro and op.extends is not None:
                    print(indent_text(f'(:action {op.name} extends {op.extends})', level=2))
                else:
                    print(indent_text(op.pddl_str(), level=2))
        else:
            print(' <Empty>')
        print(' Axioms:')
        if len(self.axioms) > 0:
            for op in self.axioms.values():
                print(indent_text(op.pddl_str(), level=2))
        else:
            print(' <Empty>')
        print(' Regression Rules:')
        if len(self.regression_rules) > 0:
            for op in self.regression_rules.values():
                print(indent_text(op.pddl_str(), level=2))
        else:
            print(' <Empty>')

    def post_init(self):
        """Post-initialization of the domain.

        This function should be called by the domain generator after all the domain definitions (predicates and operators) are done.
        Currently, the following post-initialization steps are performed:

        1. Analyze the static predicates.
        """
        self._analyze_static_predicates()

    def _analyze_static_predicates(self):
        """Run static analysis on the predicates to determine which predicates are static.

        A state variable is static iff no (non-macro) operator or axiom assigns to it. A cacheable derived predicate
        is static iff every state variable used in its derived expression is static.
        """
        dynamic = set()
        for op in itertools.chain(self.operators.values(), self.axioms.values()):
            if isinstance(op, MacroOperator):
                continue
            for eff in op.effects:
                if isinstance(eff.assign_expr, (AssignExpression, ConditionalAssignExpression)):
                    dynamic.add(eff.assign_expr.predicate.function.name)
                elif isinstance(eff.assign_expr, DeicticAssignExpression):
                    expr = eff.unwrapped_assign_expr
                    assert isinstance(expr, (AssignExpression, ConditionalAssignExpression))
                    dynamic.add(expr.predicate.function.name)
                else:
                    raise TypeError(f'Unknown effect type: {eff.assign_expr}.')

        # Pass 1: mark every state variable. This must complete before pass 2, which reads `is_static` of the state
        # variables; a single pass would depend on the definition order of `self.functions`.
        for p in self.functions.values():
            if p.is_state_variable:
                p.mark_static(p.name not in dynamic)

        # Pass 2: propagate the static flag to cacheable derived predicates.
        for p in self.functions.values():
            if not p.is_state_variable and p.is_cacheable and p.derived_expression is not None:
                used_predicates = get_used_state_variables(p.derived_expression)
                p.mark_static(all(predicate_def.is_static for predicate_def in used_predicates))
class Problem(object):
    """A planning problem: a set of typed objects, an initial state given as propositions, and a goal expression."""

    def __init__(self, domain: Optional[Domain] = None):
        """Create an empty problem, optionally attached to a planning domain."""
        self.domain = domain
        self.objects = {}
        self.predicates = []
        self.goal = None

    objects: Dict[str, str]
    """The set of objects, which are mappings from object names to object type names."""

    predicates: List[FunctionApplicationExpression]
    """The initial state, which is a set of propositions."""

    goal: Optional[ValueOutputExpression]
    """The goal expression."""

    def add_object(self, name: str, typename: str):
        """Register an object.

        Args:
            name: the object name.
            typename: the name of the object's type.
        """
        self.objects[name] = typename

    def add_proposition(self, proposition: FunctionApplicationExpression):
        """Append a proposition to the initial state.

        Args:
            proposition: the proposition to append.
        """
        self.predicates.append(proposition)

    def set_goal(self, goal: ValueOutputExpression):
        """Replace the goal of the problem.

        Args:
            goal: the new goal expression.
        """
        self.goal = goal

    def to_state(self, executor: 'PDSketchExecutor') -> 'State':
        """Build a :class:`State` from this problem.

        The state's objects are the problem objects (in insertion order) followed by the domain constants. The
        initial-state propositions are then defined on it through a ``StateDefinitionHelper``.

        Args:
            executor: the executor whose domain resolves object types and predicates.

        Returns:
            the constructed state.
        """
        domain = executor.domain

        names = list(self.objects.keys())
        obj_types = [domain.types[self.objects[obj_name]] for obj_name in names]
        for const in domain.constants.values():
            names.append(const.name)
            obj_types.append(const.dtype)

        state = State(None, names, obj_types)

        from concepts.pdsketch.executor import StateDefinitionHelper
        helper = StateDefinitionHelper(executor, state)
        helper.define_predicates([
            helper.get_predicate(prop.function.name)(*[arg.constant.name for arg in prop.arguments])
            for prop in self.predicates
        ])
        return state
class State(NamedObjectTensorState):
    """Planning domain state."""

    def init_dirty_feature(self, function: Function):
        """Initialize a dirty feature.

        A dirty feature is a cacheable feature but not in the original state representation. The convention for dirty
        features is that they are initialized with optimistic values being OPTIM_MAGIC_NUMBER_MAGIC. If the feature
        already exists in the state, nothing is changed.

        Args:
            function: the feature to initialize.
        """
        feature_name = function.name
        if feature_name in self.features:
            return

        # One axis per argument, sized by the number of objects of that argument's type (0 if the state holds no
        # object of that type).
        sizes = tuple(
            len(self.object_type2name[arg_def.typename]) if arg_def.typename in self.object_type2name else 0
            for arg_def in function.arguments
        )
        self.features[feature_name] = tensor = TensorValue.make_empty(function.return_type, [var.name for var in function.arguments], sizes)
        tensor.init_tensor_optimistic_values()
        tensor.tensor_optimistic_values.fill_(OPTIM_MAGIC_NUMBER_MAGIC)
        # NB: the internals key is spelled 'ditry_features' (sic). It is kept unchanged because other code may read it.
        self.internals.setdefault('ditry_features', set()).add(feature_name)

    def clone_internals(self):
        """Clone the internal state of the state.

        The set of dirty feature names is copied, so the clone and the original can be updated independently.

        Returns:
            the cloned internals dictionary.
        """
        rv = super().clone_internals()
        if 'ditry_features' in rv:
            rv['ditry_features'] = rv['ditry_features'].copy()
        # Previously the cloned dict was dropped and callers received None.
        return rv

    def simple_quantize(self, domain: Domain, features=None) -> 'State':
        """Make a quantized version of the state.

        Args:
            domain: the planning domain.
            features: the features to use for quantization. If None, use all state variables.

        Returns:
            the quantized state.
        """
        if features is None:
            features = [name for name in self.features.all_feature_names if domain.functions[name].is_state_variable]

        new_tensor_dict = {feature_name: self.features[feature_name].simple_quantize() for feature_name in features}
        # Constructor argument order is (features, object_names, object_types), matching `Problem.to_state`.
        return type(self)(new_tensor_dict, self.object_names, self.object_types)

    def generate_tuple_description(self, domain: Domain) -> Tuple[int, ...]:
        """Generate a tuple description of the state.

        State-variable features are visited in sorted name order. Each contributes its flattened values: the raw
        tensor for intrinsically quantized types, otherwise the quantized tensor. In both cases optimistic entries
        are substituted by their optimistic values.

        Args:
            domain: the planning domain.

        Returns:
            the tuple description of the state.

        Raises:
            RuntimeError: if a state-variable feature is neither intrinsically quantized nor has quantized values.
        """
        rv = list()
        for feature_name in sorted(self.features.all_feature_names):
            if domain.functions[feature_name].is_state_variable:
                feature = self.features[feature_name]
                if isinstance(feature.dtype, TensorValueTypeBase) and feature.dtype.is_intrinsically_quantized():
                    rv.extend(_maybe_apply_optimistic_mask(feature.tensor, feature.tensor_optimistic_values).flatten().tolist())
                elif feature.tensor_quantized_values is not None:
                    rv.extend(_maybe_apply_optimistic_mask(feature.tensor_quantized_values, feature.tensor_optimistic_values).flatten().tolist())
                else:
                    raise RuntimeError(f'Cannot generate tuple description for feature {feature_name}.')
        return tuple(rv)
def _maybe_apply_optimistic_mask(tensor, optimistic_values): if optimistic_values is None: return tensor assert tensor.shape == optimistic_values.shape optimistic_mask = is_optimistic_value(optimistic_values) return torch.where(optimistic_mask, optimistic_values, tensor.to(torch.int64))