# Source code for concepts.dm.crow.interfaces.controller_interface

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : controller_interface.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/16/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""A controller interface connects the controller commands output by a policy or a planner to the robot simulation/physical system.
Here, we distinguish between the simulation interface and the physical interface by whether they support state save/restore.
"""

import contextlib
from typing import Optional, Union, Sequence, Tuple, List, Dict, Callable
from dataclasses import dataclass

from concepts.dsl.tensor_value import TensorValue
from concepts.dm.crow.behavior import CrowEffectApplier
from concepts.dm.crow.controller import CrowControllerApplier
from concepts.dm.crow.crow_domain import CrowState
from concepts.dm.crow.executors.crow_executor import CrowExecutor

# Names exported by ``from ... import *``.
__all__ = [
    'CrowControllerExecutionError',
    'CrowControllerInterfaceBase',
    'CrowSimulationControllerInterface',
    'CrowPhysicalControllerInterface',
]


class CrowControllerExecutionError(Exception):
    """Exception used to signal that executing a controller failed.

    The ``step_without_error`` and ``step_with_saved_state`` helpers of the controller interfaces
    catch this exception and report the failure as ``False`` instead of propagating it.
    """
@dataclass
class CrowControllerReflex:
    """A reflex function registered for a controller, together with its ordering information."""

    # The function that is invoked when the reflex is called.
    callable: Callable
    # Ordering key: reflexes with a larger priority are tried first.
    priority: int
    # Registration counter value; among equal priorities, the larger uid is tried first.
    uid: int
class CrowControllerReflexCaller:
    """Callable that tries a list of reflexes in order and returns the first non-None result."""

    def __init__(self, reflex_list: List['CrowControllerReflex']):
        """Store the reflexes to try; they are invoked in the order given.

        Args:
            reflex_list: the reflexes to call, already sorted by the caller.
        """
        self._reflex_list = reflex_list

    def __call__(self, *args, **kwargs):
        """Invoke each reflex with the given arguments until one returns something other than None.

        Returns:
            The first non-None result, or None if every reflex returned None (or the list is empty).
        """
        for entry in self._reflex_list:
            outcome = entry.callable(*args, **kwargs)
            if outcome is None:
                # This reflex did not produce a result; fall through to the next one.
                continue
            return outcome
        return None
class CrowControllerInterfaceBase(object):
    """Common base class of all controller interfaces.

    By convention, a controller interface receives a controller name together with a list of
    arguments and calls the function registered under that name with those arguments. A failed
    execution should be reported by raising an exception.
    """

    def __init__(self, executor: Optional[CrowExecutor] = None, mock: bool = False):
        """Create an interface with no registered controllers or reflexes.

        Args:
            executor: optional executor, exposed through the ``executor`` property.
            mock: if True, ``step_internal`` only prints the controller call instead of running it.
        """
        self._executor = executor
        self._mock = mock
        self._controllers = dict()
        self._controller_reflexes = dict()
        # Monotonic counter used to give every registered reflex a distinct uid.
        self._controller_reflexes_uid = 0

    @property
    def executor(self) -> Optional[CrowExecutor]:
        """The executor passed at construction time (may be None)."""
        return self._executor

    @property
    def controllers(self) -> Dict[str, Callable]:
        """Mapping from controller name to the registered controller function."""
        return self._controllers

    @property
    def controller_reflexes(self) -> Dict[Tuple[str, str], CrowControllerReflex]:
        """Mapping from ``(controller_name, reflex_name)`` to the registered reflex."""
        return self._controller_reflexes

    def reset(self):
        """Reset hook; a no-op in the base class."""
        pass

    def register_controller(self, name: str, function: Callable):
        """Register ``function`` under ``name`` (replacing any previous entry) and return self."""
        self.controllers[name] = function
        return self

    def register_reflex(self, controller_name: str, reflex_name: str, function: Callable, priority: int = 10):
        """Register a reflex for a controller and return self.

        Args:
            controller_name: the controller the reflex belongs to.
            reflex_name: the name of the reflex; together with ``controller_name`` it forms the key.
            function: the reflex function.
            priority: ordering key used by ``reflex``; larger values are tried first.
        """
        self._controller_reflexes_uid += 1
        entry = CrowControllerReflex(function, priority, self._controller_reflexes_uid)
        self.controller_reflexes[(controller_name, reflex_name)] = entry
        return self

    def reflex(self, controller_name: str, reflex_name: Optional[Union[str, Sequence[str]]] = None):
        """Build a caller over the reflexes of a controller.

        Args:
            controller_name: the controller whose reflexes are selected.
            reflex_name: None selects every reflex of the controller, a string selects that single
                reflex, and any other value is iterated as a collection of reflex names.

        Returns:
            A ``CrowControllerReflexCaller`` over the selected reflexes, ordered by descending
            ``(priority, uid)``.

        Raises:
            KeyError: if an explicitly named reflex has not been registered.
        """
        registry = self.controller_reflexes
        if reflex_name is None:
            selected = [entry for key, entry in registry.items() if key[0] == controller_name]
        elif isinstance(reflex_name, str):
            selected = [registry[(controller_name, reflex_name)]]
        else:
            selected = [registry[(controller_name, each)] for each in reflex_name]
        ordered = sorted(selected, key=lambda entry: (entry.priority, entry.uid), reverse=True)
        return CrowControllerReflexCaller(ordered)

    def step(self, action: CrowControllerApplier, **kwargs) -> Optional[bool]:
        """Execute one action through ``step_internal``.

        Effect appliers are skipped: nothing is executed for them and None is returned.
        """
        if not isinstance(action, CrowEffectApplier):
            return self.step_internal(action.name, *action.arguments, **kwargs)
        return None

    def step_without_error(self, action: CrowControllerApplier, **kwargs) -> bool:
        """Run ``step`` and turn a ``CrowControllerExecutionError`` into a False return value.

        Returns:
            True if ``step`` finished without raising ``CrowControllerExecutionError``, else False.
            Other exception types still propagate.
        """
        try:
            self.step(action, **kwargs)
        except CrowControllerExecutionError:
            return False
        else:
            return True

    def step_internal(self, name: str, *args, **kwargs) -> Optional[bool]:
        """Call the controller registered under ``name`` with the given arguments.

        ``TensorValue`` arguments whose dtype is a Python-object value type are replaced by their
        ``.item()`` before use. In mock mode the call is only printed and None is returned, without
        checking that the controller exists.

        Raises:
            ValueError: if not in mock mode and no controller is registered under ``name``.
        """

        def unwrap(values):
            # Replace py-object TensorValues by the object they hold; leave everything else as is.
            return [
                value.item() if isinstance(value, TensorValue) and value.dtype.is_pyobj_value_type else value
                for value in values
            ]

        if self._mock:
            print('Mocked controller:', name)
            for i, arg in enumerate(unwrap(args)):
                print(f' Arg {i}: {arg}')
            return None

        if name not in self.controllers:
            raise ValueError(f"Controller {name} not found.")
        call_args = unwrap(args)
        return self.controllers[name](*call_args, **kwargs)
class CrowSimulationControllerInterface(CrowControllerInterfaceBase):
    """Controller interface for backends that support saving and restoring their state.

    Concrete subclasses must implement ``save_state``, ``restore_state``, ``restore_state_keep`` and
    ``get_crow_state``; the versions here raise ``NotImplementedError``. The interface also keeps a
    counter that is incremented once for every call to ``step``.
    """

    def __init__(self, executor: Optional[CrowExecutor] = None, mock: bool = False):
        """Create the interface with the action counter at zero.

        Args:
            executor: optional executor, forwarded to the base class.
            mock: forwarded to the base class; if True, controller calls are only printed.
        """
        super().__init__(executor, mock=mock)
        self._action_counter = 0

    def step_with_saved_state(self, action: CrowControllerApplier, **kwargs) -> Tuple[bool, int]:
        """Save the state, then execute the action, reporting failure instead of raising.

        The same keyword arguments are passed to both ``save_state`` and ``step``.

        Args:
            action: the action to take.

        Returns:
            bool: whether the execution is successful (False if ``CrowControllerExecutionError`` was raised).
            int: the identifier of the state saved before the action was executed.
        """
        state_identifier = self.save_state(**kwargs)
        try:
            self.step(action, **kwargs)
        except CrowControllerExecutionError:
            return False, state_identifier
        return True, state_identifier

    def step(self, action: CrowControllerApplier, **kwargs) -> Optional[bool]:
        """Execute one action via the base class and increment the action counter.

        The counter is incremented in a ``finally`` clause, so failed executions are counted too.

        Returns:
            Whatever the base class ``step`` returns.
        """
        try:
            return super().step(action, **kwargs)
        finally:
            self.increment_action_counter()

    def reset_action_counter(self):
        """Set the action counter back to zero."""
        self._action_counter = 0

    def get_action_counter(self) -> int:
        """Return the current value of the action counter."""
        return self._action_counter

    def increment_action_counter(self):
        """Increase the action counter by one."""
        self._action_counter += 1

    def save_state(self, **kwargs) -> int:
        """Save the current state and return its identifier. Must be implemented by subclasses."""
        raise NotImplementedError

    def restore_state(self, state_identifier: int, **kwargs):
        """Restore the state saved under ``state_identifier``. Must be implemented by subclasses."""
        raise NotImplementedError

    def restore_state_keep(self, state_identifier: int, action_counter: Optional[int] = None, **kwargs):
        """Restore a saved state. Must be implemented by subclasses.

        NOTE(review): how this differs from ``restore_state`` and how ``action_counter`` is meant to
        be used is not visible in this class -- confirm against the concrete implementations.
        """
        raise NotImplementedError

    def get_crow_state(self) -> CrowState:
        """Get the state of the simulation interface. Must be implemented by subclasses."""
        raise NotImplementedError()

    @contextlib.contextmanager
    def restore_context(self, verbose: bool = False, **kwargs):
        """Context manager that restores the saved state and the action counter on exit.

        The state is saved on entry and restored in a ``finally`` clause, so the restore also
        happens when the body raises.

        Args:
            verbose: accepted for interface compatibility; currently unused here.
            **kwargs: accepted for interface compatibility; currently not forwarded to
                ``save_state`` or ``restore_state``.
        """
        state_identifier = self.save_state()
        action_counter = self._action_counter
        try:
            yield
        finally:
            self.restore_state(state_identifier)
            self._action_counter = action_counter
class CrowPhysicalControllerInterface(CrowControllerInterfaceBase):
    """Controller interface for physical systems.

    It adds nothing on top of ``CrowControllerInterfaceBase``; in particular it has no state
    save/restore methods.
    """