#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : manipulator_interface.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 07/23/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
import numpy as np
from typing import Any, Optional, Union, Tuple
from concepts.utils.typing_utils import VecNf, Vec3f, Vec4f
from concepts.dm.crowhat.manipulation_utils.pose_utils import canonicalize_pose, pose_difference
class PlanningWorldInterface(object):
    """Interface to a planning world whose objects can be listed and whose object poses can be read and written.

    The public methods delegate to underscore-prefixed hooks (``_get_objects``, ``_get_object_pose``,
    ``_set_object_pose``). Subclasses override the hooks; the base implementations raise ``NotImplementedError``.
    """

    def get_objects(self):
        """Return the objects of the world, exactly as produced by ``_get_objects``."""
        return self._get_objects()

    def _get_objects(self):
        raise NotImplementedError()

    def get_object_pose(self, identifier: Union[str, int]) -> Tuple[Vec3f, Vec4f]:
        """Return the pose of an object.

        Args:
            identifier: the object identifier (a string or an integer).

        Returns:
            A ``(position, quaternion)`` pair, as produced by ``_get_object_pose``.
        """
        return self._get_object_pose(identifier)

    def _get_object_pose(self, identifier: Union[str, int]) -> Tuple[Vec3f, Vec4f]:
        raise NotImplementedError()

    def set_object_pose(self, identifier: Union[str, int], pose: Tuple[Vec3f, Vec4f]):
        """Set the pose of an object.

        Args:
            identifier: the object identifier (a string or an integer).
            pose: the new ``(position, quaternion)`` pair, forwarded unchanged to ``_set_object_pose``.
        """
        self._set_object_pose(identifier, pose)

    def _set_object_pose(self, identifier: Union[str, int], pose: Tuple[Vec3f, Vec4f]):
        raise NotImplementedError()
 
class GeneralArmArmMotionPlanningInterface(object):
    """General interface for robot arms. It specifies a set of basic operations for motion planning:

    - ``ik``: inverse kinematics.
    - ``fk``: forward kinematics.
    - ``jacobian``: jacobian matrix.
    - ``coriolis``: coriolis torque.
    - ``mass``: mass matrix.

    Subclasses override ``get_nr_joints``, ``get_joint_limits`` and the underscore-prefixed hooks
    (``_fk``, ``_ik``, ``_jacobian``, ``_coriolis``, ``_mass``); the base implementations raise
    ``NotImplementedError``. The public wrappers convert joint-space inputs (``qpos``, ``qvel``) to
    ``np.ndarray`` with ``np.asarray`` before calling the hooks.
    """

    @property
    def nr_joints(self) -> int:
        """Number of joints, as reported by :meth:`get_nr_joints`."""
        return self.get_nr_joints()

    def get_nr_joints(self) -> int:
        """Return the number of joints. Must be overridden by subclasses."""
        raise NotImplementedError()

    @property
    def joint_limits(self) -> Tuple[np.ndarray, np.ndarray]:
        """``(lower, upper)`` joint limits from :meth:`get_joint_limits`, each converted to ``np.ndarray``."""
        lower, upper = self.get_joint_limits()
        return np.asarray(lower), np.asarray(upper)

    def get_joint_limits(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the ``(lower, upper)`` joint limits. Must be overridden by subclasses."""
        raise NotImplementedError()

    def fk(self, qpos: VecNf) -> Tuple[Vec3f, Vec4f]:
        """Forward kinematics.

        Args:
            qpos: the joint configuration (any array-like; converted with ``np.asarray``).

        Returns:
            The ``(position, quaternion)`` pose computed by ``_fk``.
        """
        return self._fk(np.asarray(qpos))

    def _fk(self, qpos: np.ndarray) -> Tuple[Vec3f, Vec4f]:
        raise NotImplementedError()

    def ik(self, pos: Union[Vec3f, Tuple[Vec3f, Vec4f]], quat: Optional[Vec4f] = None, qpos: Optional[VecNf] = None, max_distance: Optional[float] = None) -> np.ndarray:
        """Inverse kinematics.

        Args:
            pos: the target position, or a ``(position, quaternion)`` pair; both forms are normalized
                into separate ``pos`` / ``quat`` values by ``canonicalize_pose``.
            quat: the target orientation quaternion, when ``pos`` is a bare position.
            qpos: optional joint configuration forwarded to ``_ik`` (converted with ``np.asarray`` when
                given, like every other joint-space input of this class). How it is used is up to the
                implementation — presumably as a seed / reference configuration.
            max_distance: optional value forwarded to ``_ik``; its interpretation is implementation-defined.

        Returns:
            The joint configuration computed by ``_ik``.
        """
        pos, quat = canonicalize_pose(pos, quat)
        # Keep the hook contract (``Optional[np.ndarray]``) consistent with fk/jacobian/coriolis/mass.
        if qpos is not None:
            qpos = np.asarray(qpos)
        return self._ik(pos, quat, qpos, max_distance=max_distance)

    def _ik(self, pos: np.ndarray, quat: np.ndarray, qpos: Optional[np.ndarray] = None, max_distance: Optional[float] = None) -> np.ndarray:
        raise NotImplementedError()

    def jacobian(self, qpos: VecNf) -> np.ndarray:
        """Return the jacobian matrix computed by ``_jacobian`` at joint configuration ``qpos``."""
        return self._jacobian(np.asarray(qpos))

    def _jacobian(self, qpos: np.ndarray) -> np.ndarray:
        raise NotImplementedError()

    def coriolis(self, qpos: VecNf, qvel: VecNf) -> np.ndarray:
        """Return the coriolis torque computed by ``_coriolis`` for joint positions ``qpos`` and velocities ``qvel``."""
        return self._coriolis(np.asarray(qpos), np.asarray(qvel))

    def _coriolis(self, qpos: np.ndarray, qvel: np.ndarray) -> np.ndarray:
        raise NotImplementedError()

    def mass(self, qpos: VecNf) -> np.ndarray:
        """Return the mass matrix computed by ``_mass`` at joint configuration ``qpos``."""
        return self._mass(np.asarray(qpos))

    def _mass(self, qpos: np.ndarray) -> np.ndarray:
        raise NotImplementedError()

    def differential_ik_qpos_diff(self, current_qpos: VecNf, current_pose: Tuple[Vec3f, Vec4f], next_pose: Tuple[Vec3f, Vec4f]) -> np.ndarray:
        """Use the differential IK to compute the joint difference for the given pose difference.

        Solves ``J(current_qpos) @ dq = pose_difference(current_pose, next_pose)`` for ``dq`` in the
        least-squares sense using ``np.linalg.lstsq``.

        Args:
            current_qpos: the joint configuration at which the jacobian is evaluated.
            current_pose: the current ``(position, quaternion)`` pose.
            next_pose: the target ``(position, quaternion)`` pose.

        Returns:
            The joint-space difference ``dq`` (the least-squares solution).
        """
        current_pose = canonicalize_pose(current_pose)
        next_pose = canonicalize_pose(next_pose)
        J = self.jacobian(current_qpos)  # 6 x N
        return np.linalg.lstsq(J, pose_difference(current_pose, next_pose), rcond=None)[0]
 
class MotionPlanningResult(object):
    """Outcome of a motion-planning query: a success flag plus either a result payload or an error message.

    Instances are truthy exactly when the query succeeded, so callers can write ``if result: ...``.
    Use :meth:`success` / :meth:`fail` as the convenience constructors.
    """

    def __init__(self, success: bool, result: Any, error: str = ''):
        """Create a result.

        Args:
            success: whether the query succeeded (stored as ``self.succeeded``).
            result: the payload of a successful query.
            error: a human-readable error message for a failed query.
        """
        self.succeeded = success
        self.result = result
        self.error = error

    def __bool__(self) -> bool:
        # ``__bool__`` must return a real ``bool``; coerce so a truthy non-bool flag (e.g. ``1``)
        # does not make ``bool(result)`` / ``if result:`` raise ``TypeError``.
        return bool(self.succeeded)

    def __str__(self) -> str:
        if self.succeeded:
            return f'MotionPlanningResult(SUCC: {self.result})'
        return f'MotionPlanningResult(FAIL: error="{self.error}")'

    def __repr__(self) -> str:
        return str(self)

    @classmethod
    def success(cls, result: Any) -> 'MotionPlanningResult':
        """Build a successful result carrying ``result``."""
        return cls(True, result)

    @classmethod
    def fail(cls, error: str) -> 'MotionPlanningResult':
        """Build a failed result with message ``error`` and ``result=None``."""
        return cls(False, None, error)