Source code for concepts.dm.crow.crow_function

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : crow_function.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/16/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

from typing import Optional, Any, Sequence, Tuple, Dict

from jacinle.utils.enum import JacEnum
from jacinle.utils.meta import repr_from_str
from jacinle.utils.printing import indent_text

from concepts.dsl.dsl_types import ObjectType
from concepts.dsl.dsl_functions import FunctionType, Function
from concepts.dsl.expression import ValueOutputExpression

# Public names of this module (what ``from ... import *`` exports).
__all__ = ['CrowFunctionEvaluationMode', 'CrowFunctionBase', 'CrowFeature', 'CrowFunction']


class CrowFunctionEvaluationMode(JacEnum):
    """The evaluation mode of a function.

    This enum has three values:

    - ``FUNCTIONAL``: the function is a pure function.
    - ``SIMULATION``: the function is a simulation-dependent function, i.e., it is a function that can only be
      evaluated given the current state in simulation.
    - ``EXECUTION``: the function is an execution-dependent function, i.e., it is a function that can only be
      evaluated given the current state in execution.
    """

    FUNCTIONAL = 'functional'
    SIMULATION = 'simulation'
    EXECUTION = 'execution'

    @classmethod
    def from_bools(cls, simulation: bool, execution: bool) -> 'CrowFunctionEvaluationMode':
        """Build an evaluation mode from two boolean flags.

        Args:
            simulation: whether the function is simulation-dependent.
            execution: whether the function is execution-dependent.

        Returns:
            ``SIMULATION`` if ``simulation`` is set, ``EXECUTION`` if only ``execution`` is set,
            and ``FUNCTIONAL`` if neither is set.

        Raises:
            AssertionError: if both flags are set (checked with ``assert``, so only when assertions are enabled).
        """
        if simulation:
            assert not execution, 'Cannot set both simulation and execution mode.'
            return cls.SIMULATION
        if execution:
            return cls.EXECUTION
        return cls.FUNCTIONAL

    def get_prefix(self) -> str:
        """Return the textual prefix of this mode.

        Returns:
            ``''`` for ``FUNCTIONAL``; otherwise ``'[[<value>]]'``, e.g. ``'[[simulation]]'``.
        """
        # Look the member up on the class (not through ``self``): member-from-member access is
        # discouraged by the enum docs and has not behaved consistently across Python versions.
        if self is type(self).FUNCTIONAL:
            return ''
        return f'[[{self.value}]]'
class CrowFunctionBase(Function):
    """Common base class of :class:`CrowFeature` and :class:`CrowFunction`.

    It adds the ``is_static`` / ``is_cacheable`` flags on top of :class:`Function`, a :meth:`flags` summary,
    and a human-readable ``__str__``.
    """

    def __init__(
        self,
        name: str,
        ftype: FunctionType,
        derived_expression: Optional[ValueOutputExpression] = None,
    ):
        """Initialize the function.

        Args:
            name: the name of the function.
            ftype: the function type (signature); forwarded to :class:`Function`.
            derived_expression: an optional defining expression; forwarded to :class:`Function`.
        """
        super().__init__(name, ftype, derived_expression)
        self.is_static = False
        self.is_cacheable = self._guess_is_cacheable()

    is_static: bool
    """Whether the function is static (i.e., its grounded value will never change)."""

    is_cacheable: bool
    """Whether the function can be cached. Specifically, if it contains only "ObjectTypes" as arguments, it can be statically evaluated."""

    def _guess_is_cacheable(self) -> bool:
        """Return whether the function can be cached. Specifically, if it contains only "ObjectTypes" as arguments, it can be statically evaluated."""
        # ``all`` short-circuits on the first non-object argument and is True for zero arguments,
        # exactly like the explicit loop it replaces.
        return all(isinstance(arg_def.dtype, ObjectType) for arg_def in self.arguments)

    def mark_static(self, flag: bool = True):
        """Mark a predicate as static (i.e., its grounded value will never change).

        Args:
            flag: Whether to mark the predicate as static.
        """
        self.is_static = flag

    @property
    def is_feature(self) -> bool:
        """Whether the object is defined as a state feature."""
        return False

    @property
    def is_function(self) -> bool:
        """Whether the object is defined as a function."""
        return False

    def flags(self, short: bool = False) -> Dict[str, bool]:
        """Return the flags of the function.

        Args:
            short: if True, return the short keys used by ``__str__`` (``static``, ``cacheable``);
                otherwise return ``is_derived``, ``is_static`` and ``is_cacheable``.

        Returns:
            A fresh dict mapping flag names to their boolean values.
        """
        if short:
            return {
                'static': self.is_static,
                'cacheable': self.is_cacheable,
            }
        return {
            'is_derived': self.is_derived,
            'is_static': self.is_static,
            'is_cacheable': self.is_cacheable,
        }

    def __str__(self) -> str:
        """Render as ``name[flag, ...](arg, ...)``, followed by an indented ``return <expr>`` line when derived."""
        # Only flags that are set are shown; the brackets are omitted entirely when none are set.
        active_flags = ', '.join(f for f, v in self.flags(short=True).items() if v)
        flag_str = f'[{active_flags}]' if active_flags else ''
        arg_str = ', '.join(str(arg) for arg in self.arguments)
        fmt = f'{self.name}{flag_str}({arg_str})'
        if self.is_derived:
            fmt += ':\n' + indent_text('return ' + str(self.derived_expression))
        return fmt

    __repr__ = repr_from_str
class CrowFeature(CrowFunctionBase):
    """A state feature, i.e. a :class:`CrowFunctionBase` whose arguments must all be object-typed.

    Besides the base flags, a feature records whether it is an observation variable and/or a state variable,
    together with an optional default value.
    """

    def __init__(
        self,
        name: str,
        ftype: FunctionType,
        derived_expression: Optional[ValueOutputExpression] = None,
        observation: Optional[bool] = None,
        state: Optional[bool] = None,
        default: Optional[Any] = None,
    ):
        """Initialize the feature.

        Args:
            name: the name of the feature.
            ftype: the feature type (signature); forwarded to :class:`CrowFunctionBase`.
            derived_expression: an optional defining expression; forwarded to :class:`CrowFunctionBase`.
            observation: whether the feature is an observation variable. ``None`` means "guess": the
                feature is then an observation variable iff it is not derived.
            state: whether the feature is a state variable. ``None`` means "guess" (the guess is ``True``).
            default: an optional default value, stored unchanged as ``self.default``.

        Raises:
            AssertionError: if some argument is not an :class:`ObjectType` (checked with ``assert``).
        """
        super().__init__(name, ftype, derived_expression)

        # Explicit flags take precedence over the heuristics. The observation flag is stored
        # before the state heuristic is consulted.
        if observation is None:
            observation = self._guess_is_observation()
        self.is_observation_variable = observation

        if state is None:
            state = self._guess_is_state()
        self.is_state_variable = state

        self.default = default
        self._check_argument_types()

    is_static: bool
    is_cacheable: bool

    is_observation_variable: bool
    """Whether this feature is an observation variable."""

    is_state_variable: bool
    """Whether this feature is a state variable."""

    def _check_argument_types(self):
        """Assert that every argument of the feature has an object type."""
        for argument in self.arguments:
            assert isinstance(argument.dtype, ObjectType), f'Invalid argument type {argument.dtype} for feature {self.name}.'

    def _guess_is_observation(self) -> bool:
        """Heuristic default for ``observation``: non-derived features are observation variables."""
        return not self.is_derived

    def _guess_is_state(self) -> bool:
        """Heuristic default for ``state``: every feature is a state variable."""
        return True

    @property
    def is_feature(self) -> bool:
        return True

    def flags(self, short: bool = False) -> Dict[str, bool]:
        """Return the feature's flags.

        The cacheable entry of the base flags is dropped; presumably it is redundant because the argument
        check above requires object-typed arguments. The observation/state entries are appended after the
        remaining base entries.
        """
        result = super().flags(short)
        if short:
            cacheable_key, observation_key, state_key = 'cacheable', 'observation', 'state'
        else:
            cacheable_key, observation_key, state_key = 'is_cacheable', 'is_observation_variable', 'is_state_variable'
        result.pop(cacheable_key)
        result[observation_key] = self.is_observation_variable
        result[state_key] = self.is_state_variable
        return result
class CrowFunction(CrowFunctionBase):
    """A function (as opposed to a :class:`CrowFeature`).

    Besides the base flags, a function records whether it is a generator placeholder, its inplace generators,
    whether it is an SGC function, and its :class:`CrowFunctionEvaluationMode`.
    """

    def __init__(
        self,
        name: str,
        ftype: FunctionType,
        derived_expression: Optional[ValueOutputExpression] = None,
        generator_placeholder: bool = False,
        inplace_generators: Optional[Sequence[str]] = None,
        sgc: bool = False,
        simulation: bool = False,
        execution: bool = False,
    ):
        """Initialize the function.

        Args:
            name: the name of the function.
            ftype: the function type (signature); forwarded to :class:`CrowFunctionBase`.
            derived_expression: an optional defining expression; forwarded to :class:`CrowFunctionBase`.
            generator_placeholder: whether the function is a generator placeholder.
            inplace_generators: names of inplace generators; stored as a tuple (empty when ``None``).
            sgc: whether the function is an SGC (state-goal-constraint) function.
            simulation: whether the function is simulation-dependent.
            execution: whether the function is execution-dependent.

        Raises:
            AssertionError: if both ``simulation`` and ``execution`` are set (raised by
                :meth:`CrowFunctionEvaluationMode.from_bools` when assertions are enabled).
        """
        super().__init__(name, ftype, derived_expression)
        self.is_generator_placeholder = generator_placeholder
        # Normalize to an immutable tuple so the attribute is never None.
        self.inplace_generators = tuple(inplace_generators) if inplace_generators is not None else ()
        self.is_sgc_function = sgc
        self.evaluation_mode = CrowFunctionEvaluationMode.from_bools(simulation, execution)

    is_static: bool
    is_cacheable: bool

    evaluation_mode: CrowFunctionEvaluationMode
    """The evaluation mode of the function. This enum has three values: ``FUNCTIONAL`` (a pure function),
    ``SIMULATION`` (can only be evaluated given the current state in simulation), and ``EXECUTION``
    (can only be evaluated given the current state in execution)."""

    is_generator_placeholder: bool
    """Whether the function is a generator placeholder."""

    inplace_generators: Tuple[str, ...]
    """The list of inplace generators. This is usually used together with generator-placeholder functions."""

    is_sgc_function: bool
    """Whether the function is a SGC (state-goal-constraint) function."""

    @property
    def is_function(self) -> bool:
        return True

    @property
    def is_simulation_dependent(self) -> bool:
        """Whether the function is simulation-dependent."""
        return self.evaluation_mode == CrowFunctionEvaluationMode.SIMULATION

    @property
    def is_execution_dependent(self) -> bool:
        """Whether the function is execution-dependent."""
        return self.evaluation_mode == CrowFunctionEvaluationMode.EXECUTION

    def flags(self, short: bool = False) -> Dict[str, bool]:
        """Return the base flags extended with the generator / SGC / simulation / execution flags.

        Args:
            short: if True, use the abbreviated keys (``gen``, ``sgc``, ``sim``, ``exe``) shown by ``__str__``.
        """
        result = super().flags(short)
        if short:
            result.update({
                'gen': self.is_generator_placeholder,
                'sgc': self.is_sgc_function,
                'sim': self.is_simulation_dependent,
                'exe': self.is_execution_dependent,
            })
        else:
            result.update({
                'is_generator_placeholder': self.is_generator_placeholder,
                'is_sgc_function': self.is_sgc_function,
                'is_simulation_dependent': self.is_simulation_dependent,
                'is_execution_dependent': self.is_execution_dependent,
            })
        return result