#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : obj2urdf.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 08/25/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
"""
Convert OBJ/STL mesh files into URDF files.
Copied from https://github.com/harvard-microrobotics/object2urdf
"""
import numpy as np
import tempfile
import os
import copy
trimesh_available = True
try:
    import trimesh
except ImportError:
    trimesh_available = False
    trimesh = None
import xml.etree.ElementTree as ET
from scipy.spatial.transform import Rotation
__all__ = ['ObjectUrdfBuilder']
class ObjectUrdfBuilder(object):
    """Build URDF files for the OBJ/STL meshes found under a folder.

    For each mesh the builder copies a URDF prototype, fills in the visual and
    collision mesh paths, optionally re-centres the geometry (see the
    ``center`` argument of :meth:`build_urdf`), optionally runs a VHACD convex
    decomposition to produce a separate collision mesh, and writes
    ``<name>.urdf``.
    """

    def __init__(self, object_folder, log_file=None, urdf_prototype=None, use_trimesh_vhacd: bool = True, use_pybullet_vhacd: bool = True):
        """Create a builder.

        Args:
            object_folder: root folder containing the object meshes; also the
                default output folder for generated URDFs.
            log_file: path of the VHACD log file used by the pybullet backend.
                A persistent temporary file is created when omitted.
            urdf_prototype: URDF template file name, relative to
                ``object_folder``. Defaults to ``urdf_prototype.urdf`` located
                next to this module.
            use_trimesh_vhacd: try trimesh's VHACD interface first.
            use_pybullet_vhacd: allow pybullet's ``vhacd`` (used when the
                trimesh backend is disabled or unavailable).

        Raises:
            ImportError: if trimesh is not installed.
        """
        if not trimesh_available:
            raise ImportError('trimesh is not available. Please install trimesh to use this module.')
        self.object_folder = os.path.abspath(object_folder)
        if log_file is not None:
            self.log_file = os.path.abspath(log_file)
        else:
            # Create (and immediately close) a temp file that outlives this call;
            # pybullet only needs its path.
            fd, self.log_file = tempfile.mkstemp(suffix='.log', prefix='vhacd')
            os.close(fd)
        # Suffix of generated collision meshes: ``<mesh>_vhacd.obj``.
        self.suffix = "vhacd"
        self.use_trimesh_vhacd = use_trimesh_vhacd
        self.use_pybullet_vhacd = use_pybullet_vhacd
        if urdf_prototype is None:
            prototype_path = os.path.join(os.path.dirname(__file__), "urdf_prototype.urdf")
        else:
            prototype_path = os.path.join(self.object_folder, urdf_prototype)
        self.urdf_base = self._read_xml(prototype_path)

    def _get_files_recursively(self, start_directory, filter_extension, exclude_suffix):
        """Yield ``(root, file_name, absolute_path)`` for matching files.

        Extensions are compared case-insensitively. Files ending with
        ``exclude_suffix + filter_extension`` (e.g. generated ``*_vhacd.obj``
        meshes) are skipped.
        """
        extension = filter_extension.lower()
        excluded = (exclude_suffix + filter_extension).lower()
        for root, _, files in os.walk(start_directory):
            for file in files:
                lowered = file.lower()
                if lowered.endswith(extension) and not lowered.endswith(excluded):
                    yield (root, file, os.path.abspath(os.path.join(root, file)))

    def _read_xml(self, filename):
        """Parse an XML/URDF file and return its root element."""
        return ET.parse(filename).getroot()

    def _list2str(self, in_list):
        """Join the values of ``in_list`` with single spaces (URDF vector format)."""
        return " ".join(str(el) for el in in_list)

    def _str2list(self, in_str):
        """Parse a whitespace-separated URDF vector (e.g. ``"0 0 1"``) into floats.

        Any run of whitespace is accepted as a separator, and leading/trailing
        whitespace is ignored.
        """
        return [float(el) for el in in_str.split()]

    def get_center_of_mass(self, filename):
        """Return a copy of the mesh's center of mass.

        For files that load as a ``trimesh.Scene`` (combined meshes), the
        scene centroid is returned instead.
        """
        mesh = trimesh.load(filename)
        if isinstance(mesh, trimesh.Scene):
            print("Imported combined mesh: using centroid rather than center of mass")
            return copy.deepcopy(mesh.centroid)
        # Copy so callers never alias trimesh's cached array.
        return copy.deepcopy(mesh.center_mass)

    def get_geometric_center(self, filename):
        """Return a copy of the mesh centroid."""
        mesh = trimesh.load(filename)
        return copy.deepcopy(mesh.centroid)

    def get_face(self, filename, edge):
        """Return the centroid moved onto one face of the axis-aligned bounding box.

        Args:
            filename: mesh file to load.
            edge: one of ``'top'``/``'xy_pos'`` (max z), ``'bottom'``/``'xy_neg'``
                (min z), ``'xz_pos'``/``'xz_neg'`` (max/min y) or
                ``'yz_pos'``/``'yz_neg'`` (max/min x). Any other value returns
                the plain centroid.
        """
        mesh = trimesh.load(filename)
        bounds = mesh.bounds
        face = copy.deepcopy(mesh.centroid)
        # edge -> (axis index, bounds row: 0 = min corner, 1 = max corner)
        face_axes = {
            'top': (2, 1), 'xy_pos': (2, 1),
            'bottom': (2, 0), 'xy_neg': (2, 0),
            'xz_pos': (1, 1), 'xz_neg': (1, 0),
            'yz_pos': (0, 1), 'yz_neg': (0, 0),
        }
        if edge in face_axes:
            axis, side = face_axes[edge]
            face[axis] = bounds[side][axis]
        return face

    def do_vhacd(self, filename, outfile, debug=False, **kwargs):
        """Convex-decompose ``filename`` and write the result to ``outfile``.

        The trimesh VHACD interface is tried first when enabled. pybullet's
        ``vhacd`` is used only if trimesh is disabled or reports no backend
        (AttributeError/ValueError); it no longer overwrites a successful
        trimesh result. Extra ``kwargs`` are forwarded to the backend.

        Raises:
            ModuleNotFoundError: if the pybullet fallback is needed but
                pybullet is not installed.
        """
        done = False
        if self.use_trimesh_vhacd and trimesh_available:
            try:
                mesh = trimesh.load(filename)
                convex_list = trimesh.interfaces.vhacd.convex_decomposition(mesh, debug=debug, **kwargs)
                convex = trimesh.util.concatenate(convex_list)
                convex.export(outfile)
                done = True
            except (AttributeError, ValueError):
                print("No direct VHACD backend available, trying pybullet")
        if not done and self.use_pybullet_vhacd:
            try:
                import pybullet as p
            except ModuleNotFoundError:
                print(
                    '\n' + "ERROR - pybullet module not found: If you want to do convex decomposisiton, make sure you install pybullet (https://pypi.org/project/pybullet) or install VHACD directly (https://github.com/mikedh/trimesh/issues/404)" + '\n'
                )
                raise
            p.vhacd(filename, outfile, self.log_file, **kwargs)
            done = True
        if not done:
            print("WARNING - convex decomposition was not performed for %s" % filename)

    def save_to_obj(self, filename):
        """Export ``filename`` as an OBJ next to it and return the OBJ path."""
        name, _ = os.path.splitext(filename)
        obj_filename = name + '.obj'
        mesh = trimesh.load(filename)
        mesh.export(obj_filename)
        return obj_filename

    def replace_urdf_attribute(self, urdf, feild, attribute, value):
        """Set one attribute on the element at path ``feild`` (created if missing).

        Returns the (modified in place) ``urdf`` element.
        """
        return self.replace_urdf_attributes(urdf, feild, {attribute: value})

    def replace_urdf_attributes(self, urdf, feild, attribute_dict, sub_feild=None):
        """Update attributes of the element at ElementPath ``feild``.

        Missing trailing path components are created as nested sub-elements
        under the deepest existing ancestor (the root when none matches).
        ``sub_feild`` lists extra components, deepest first, to create below
        ``feild``; the caller's list is not modified.

        Returns the (modified in place) ``urdf`` element.
        """
        missing = [] if sub_feild is None else list(sub_feild)
        path = feild
        field_obj = urdf.find(path)
        while field_obj is None:
            path, _, tail = path.rpartition('/')
            if tail:
                missing.append(tail)
            # An empty path means nothing matched: attach to the root instead of
            # calling find('') (which always returns None) forever.
            field_obj = urdf.find(path) if path else urdf
        for child in reversed(missing):
            field_obj = ET.SubElement(field_obj, child)
        field_obj.attrib.update(attribute_dict)
        return urdf

    def update_urdf(self, object_file, object_name, collision_file=None, override=None, mass_center=None):
        """Return a filled-in copy of the URDF prototype for one object.

        Args:
            object_file: visual mesh path written into the URDF.
            object_name: value for the root element's ``name`` attribute.
            collision_file: collision mesh path; defaults to ``object_file``.
            override: optional parsed XML element. For each of its child
                elements, every element with the same tag in the URDF gets
                its attributes updated. Each same-tag child of those elements
                is replaced by a copy of the override child.
            mass_center: optional point in mesh coordinates to move to the link
                origin. The visual/collision ``origin`` becomes
                ``xyz = R(rpy) @ (-scale * mass_center + offset)``, using the
                prototype's collision origin (offset, rpy) and mesh scale.
                The caller's value is not modified.
        """
        if collision_file is None:
            collision_file = object_file
        new_urdf = copy.deepcopy(self.urdf_base)
        self.replace_urdf_attribute(new_urdf, './/visual/geometry/mesh', 'filename', object_file)
        self.replace_urdf_attribute(new_urdf, './/collision/geometry/mesh', 'filename', collision_file)
        new_urdf.attrib['name'] = object_name

        if override is not None:
            for override_el in override:
                for out_el in new_urdf.findall('.//' + override_el.tag):
                    out_el.attrib.update(override_el.attrib)
                    # Drop children that the override replaces.
                    for child in override_el:
                        stale = out_el.find(child.tag)
                        if stale is not None:
                            out_el.remove(stale)
                    # Copy so several matched elements never share child nodes.
                    out_el.extend(copy.deepcopy(list(override_el)))

        if mass_center is not None:
            offset_ob = new_urdf.find('.//collision/origin')
            if offset_ob is not None:
                offset = self._str2list(offset_ob.attrib.get('xyz', '0 0 0'))
                rpy = self._str2list(offset_ob.attrib.get('rpy', '0 0 0'))
            else:
                offset = [0, 0, 0]
                rpy = [0, 0, 0]
            scale_ob = new_urdf.find('.//collision/geometry/mesh')
            if scale_ob is not None:
                scale = self._str2list(scale_ob.attrib.get('scale', '1 1 1'))
            else:
                scale = [1, 1, 1]
            # Work on a fresh array: never mutate the caller's (possibly cached) data.
            center = -np.asarray(mass_center, dtype=float) * np.asarray(scale, dtype=float) + np.asarray(offset, dtype=float)
            rot_matrix = Rotation.from_euler('xyz', rpy).as_matrix()
            center = rot_matrix @ center
            origin_attrs = {'xyz': self._list2str(center), 'rpy': self._list2str(rpy)}
            self.replace_urdf_attributes(new_urdf, './/visual/origin', origin_attrs)
            self.replace_urdf_attributes(new_urdf, './/collision/origin', origin_attrs)
        return new_urdf

    def save_urdf(self, new_urdf, filename, overwrite=False):
        """Write ``new_urdf`` to ``filename`` (relative to ``object_folder``).

        An absolute ``filename`` is used as-is. An existing file is kept
        unless ``overwrite`` is True. Missing parent folders are created.
        """
        out_file = os.path.join(self.object_folder, filename)
        if os.path.exists(out_file) and not overwrite:
            return
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        with open(out_file, "wb") as f:
            f.write(ET.tostring(new_urdf))

    def build_urdf(
        self, filename, output_folder=None,
        force_overwrite=False, decompose_concave=False, force_decompose=False,
        center='mass', **kwargs
    ):
        """Build and save ``<name>.urdf`` in ``output_folder`` for one mesh file.

        Args:
            filename: OBJ/STL mesh file.
            output_folder: where the URDF is written; mesh paths inside the
                URDF are relative to it. Defaults to ``object_folder``.
            force_overwrite: overwrite an existing URDF.
            decompose_concave: generate a ``*_vhacd.obj`` collision mesh (STL
                inputs are first exported to OBJ).
            force_decompose: re-run the decomposition even if its output exists.
            center: ``'mass'``, ``'geometric'``, a bounding-box face name (see
                :meth:`get_face`), or anything else for no re-centring.
            **kwargs: forwarded to :meth:`do_vhacd`.

        The object name is the first path component of the mesh below
        ``output_folder`` (normally its sub-folder). For meshes outside that
        folder, the mesh's parent-folder name is used. A sibling ``.ovr`` file,
        if present, is applied as an override.

        Returns:
            The generated URDF root element.

        Raises:
            ValueError: if ``decompose_concave`` is set for a non-OBJ/STL file.
        """
        if output_folder is None:
            output_folder = self.object_folder
        output_folder = os.path.abspath(output_folder)
        filename = os.path.abspath(filename)

        # Path-aware relative path (commonprefix is character-based and breaks
        # on sibling folders such as "obj" vs "objects").
        rel_parts = os.path.relpath(filename, output_folder).split(os.sep)
        if rel_parts[0] == os.pardir:
            name = os.path.basename(os.path.dirname(filename))
        else:
            name = rel_parts[0]
        rel = '/'.join(rel_parts)  # URDF paths always use forward slashes

        file_name_raw, file_extension = os.path.splitext(filename)
        override_file = file_name_raw + '.ovr'
        overrides = self._read_xml(override_file) if os.path.exists(override_file) else None

        if center == 'mass':
            mass_center = self.get_center_of_mass(filename)
        elif center == 'geometric':
            mass_center = self.get_geometric_center(filename)
        elif center in ['top', 'bottom', 'xy_pos', 'xy_neg', 'xz_pos', 'xz_neg', 'yz_pos', 'yz_neg']:
            mass_center = self.get_face(filename, center)
        else:
            mass_center = None

        if decompose_concave:
            ext = file_extension.lower()
            if ext == '.stl':
                obj_filename = self.save_to_obj(filename)
                visual_file = os.path.splitext(rel)[0] + '.obj'
            elif ext == '.obj':
                obj_filename = filename
                visual_file = rel
            else:
                raise ValueError("Your filetype needs to be an STL or OBJ to perform concave decomposition")
            decomposed_tail = '_' + self.suffix + '.obj'
            outfile = os.path.splitext(obj_filename)[0] + decomposed_tail
            collision_file = os.path.splitext(visual_file)[0] + decomposed_tail
            # Only decompose when no result exists yet, unless forced.
            if not os.path.exists(outfile) or force_decompose:
                self.do_vhacd(obj_filename, outfile, **kwargs)
            urdf_out = self.update_urdf(visual_file, name, collision_file=collision_file, override=overrides, mass_center=mass_center)
        else:
            urdf_out = self.update_urdf(rel, name, override=overrides, mass_center=mass_center)
        self.save_urdf(urdf_out, os.path.join(output_folder, name + '.urdf'), force_overwrite)
        return urdf_out

    def build_library(self, **kwargs):
        """Build URDFs for every OBJ, and every STL in folders without an OBJ.

        ``kwargs`` are forwarded to :meth:`build_urdf`.
        """
        print("\nFOLDER: %s" % (self.object_folder))
        obj_files = self._get_files_recursively(self.object_folder, filter_extension='.obj', exclude_suffix=self.suffix)
        stl_files = self._get_files_recursively(self.object_folder, filter_extension='.stl', exclude_suffix=self.suffix)
        obj_folders = set()
        for root, _, full_file in obj_files:
            obj_folders.add(root)
            print('\tBuilding: %s' % os.path.relpath(full_file, self.object_folder))
            self.build_urdf(full_file, **kwargs)
        for root, _, full_file in stl_files:
            # Folders that already provide an OBJ take precedence over STL.
            if root in obj_folders:
                continue
            print('\tBuilding: %s' % os.path.relpath(full_file, self.object_folder))
            self.build_urdf(full_file, **kwargs)