Source code for concepts.benchmark.manip_tabletop.paint_factory.crow_domains.paint_factory_crow_interface

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : paint_factory_crow_interface.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 05/03/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import cv2
import torch
import numpy as np

from concepts.benchmark.manip_tabletop.paint_factory.paint_factory import PaintFactoryEnv
from concepts.dm.crow.interfaces.perception_interface import CrowPerceptionInterface
from concepts.dm.crow.crow_domain import CrowDomain, CrowState

__all__ = ['PaintFactoryPerceptionInterface']


class PaintFactoryPerceptionInterface(CrowPerceptionInterface):
    def __init__(self, env: PaintFactoryEnv, domain: CrowDomain):
        """Initialize the perception interface with the underlying simulation environment and the CROW domain."""
        super().__init__()

        self.env = env
        self.domain = domain

    def get_crow_state(self) -> CrowState:
        """Get the state of the perception interface."""
        # Register the robot and every object in the environment as typed entities.
        objects = dict()
        objects['robot'] = 'Robot'
        for _, item_info in self.env.iter_objects():
            objects[item_info['name']] = 'Object'

        state = CrowState.make_empty_state(self.domain, objects)
        state.fast_set_value('robot_identifier', ['robot'], self.env.robot.get_body_id())

        # Populate per-object pose, identifier, and type predicates.
        for item_id, item_info in self.env.iter_objects():
            state.fast_set_value('object_pose', [item_info['name']], torch.tensor(self.env.world.get_body_state_by_id(item_id).pos[:2]))  # Only use (x, y).
            state.fast_set_value('object_identifier', [item_info['name']], item_id)
            if item_info['name'].startswith('block'):
                state.fast_set_value('is_block', [item_info['name']], True)
            elif item_info['name'].startswith('bowl'):
                state.fast_set_value('is_machine', [item_info['name']], True)
            elif item_info['name'] == 'target-box':
                state.fast_set_value('is_target', [item_info['name']], True)
            elif item_info['name'] in ('clean-box', 'dry-box'):
                state.fast_set_value('is_machine', [item_info['name']], True)
            else:
                raise ValueError('Unknown object name: {}'.format(item_info['name']))

        # Render a top-down image and extract a per-object image patch.
        img, depth, segm = self.env.world.render_image(self.env.topdown_camera, normalize_depth=True)
        tensors = list()
        for item_id, item_info in self.env.iter_objects():
            item_segm = (segm == item_id)
            # Use OpenCV to get the bounding box of the object's segmentation mask.
            bounding_box = cv2.boundingRect(item_segm.astype(np.uint8))  # (x, y, w, h)
            bounding_box = (bounding_box[0], bounding_box[1], bounding_box[0] + bounding_box[2], bounding_box[1] + bounding_box[3])
            # Expand the bounding box by 5 pixels and clip to the image size.
            bounding_box = (
                max(0, bounding_box[0] - 5),
                max(0, bounding_box[1] - 5),
                min(img.shape[1], bounding_box[2] + 5),
                min(img.shape[0], bounding_box[3] + 5),
            )
            # Crop the patch, resize it to 32x32, and flatten it into a feature vector.
            item_img = img[bounding_box[1]:bounding_box[3], bounding_box[0]:bounding_box[2]]
            item_img = cv2.resize(item_img, (32, 32), interpolation=cv2.INTER_AREA)
            item_img = item_img.flatten()
            tensors.append(torch.tensor(item_img, dtype=torch.float32))
        state.batch_set_value('object_image', torch.stack(tensors, dim=0))

        return state
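

# --- Usage sketch (added for illustration; not part of the original module) ---
# A minimal sketch of how this interface might be driven. It assumes that
# `PaintFactoryEnv()` can be constructed with default arguments and that a
# `CrowDomain` for the paint-factory benchmark has been loaded elsewhere
# (e.g. from the benchmark's CDL domain file); both assumptions are
# hypothetical and may need to be adapted to the actual benchmark setup.
#
# if __name__ == '__main__':
#     env = PaintFactoryEnv()
#     domain: CrowDomain = ...  # placeholder: load the paint-factory CrowDomain here
#     interface = PaintFactoryPerceptionInterface(env, domain)
#     state = interface.get_crow_state()  # CrowState with poses, types, and 32x32 image patches
#     print(state)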