Source code for concepts.algorithm.rrt.rrt

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : rrt.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 12/01/2019
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""Basic RRT algorithm.

The following algorithms and data structures have a generic interface. They are not designed for any specific robot.
"""

import numpy as np
import numpy.random as npr

from typing import Optional, Tuple, List
from concepts.algorithm.configuration_space import ProblemSpace

# Names exported by a star-import of this module. Note that ``traverse_rrt_bfs`` and
# ``optimize_path_forward`` are defined in this module but are not listed here.
__all__ = ['RRTNode', 'RRTTree', 'smooth_path', 'get_smooth_path', 'optimize_path', 'rrt', 'birrt']


class RRTNode(object):
    """A node of an RRT tree.

    Each node stores a configuration, a reference to its parent (``None`` for a root),
    and the list of its children.
    """

    def __init__(self, config, parent=None):
        """Create a node.

        Args:
            config: the configuration stored at this node.
            parent: the parent node, or ``None``. This only records the link on this node;
                it does not append this node to ``parent.children`` (use :meth:`attach_to` for that).
        """
        self.config = config
        self.parent = parent
        self.children = list()

    def add_to_children(self, other):
        """Append ``other`` to the children of this node.

        Args:
            other: the node to add as a child. Its ``parent`` attribute is not modified.

        Returns:
            this node, so that calls can be chained.
        """
        self.children.append(other)
        return self

    def attach_to(self, other):
        """Make ``other`` the parent of this node and register this node as its child.

        Args:
            other: the new parent node.

        Returns:
            this node, so that calls can be chained.
        """
        self.parent = other
        self.parent.add_to_children(self)
        return self

    def backtrace(self, config=True):
        """Return the chain from the root of this node's tree down to this node.

        Args:
            config: if True, return the configurations stored at the nodes;
                otherwise return the nodes themselves.

        Returns:
            a list ordered from the root (first) to this node (last).
        """
        # Walk up the parent links iteratively: a recursive walk would raise
        # RecursionError on chains deeper than the interpreter recursion limit.
        path = []
        node = self
        while node is not None:
            path.append(node.config if config else node)
            node = node.parent
        path.reverse()
        return path

    @classmethod
    def from_states(cls, states):
        """Wrap one state, or a list of states, into parentless node(s).

        Args:
            states: either a ``list`` of states or a single state. Only ``list`` instances
                are treated as collections; any other object is wrapped as one state.

        Returns:
            a list of nodes if ``states`` is a list, otherwise a single node.
        """
        if isinstance(states, list):
            return [cls(s) for s in states]
        return cls(states)

    def __repr__(self):
        # The parent is formatted with its own __repr__, so the string contains the
        # whole ancestor chain of this node.
        return f'RRTNode(config={self.config}, parent={self.parent})'
def traverse_rrt_bfs(nodes):
    """Return every node reachable from ``nodes`` in breadth-first order.

    Args:
        nodes: an iterable of starting nodes (e.g., the roots of a tree). Every visited node
            must expose a ``children`` attribute that is an iterable of nodes.

    Returns:
        a new list containing the starting nodes followed by their descendants, in the order
        a FIFO queue would visit them. The input collection is not modified.
    """
    # The result list doubles as the FIFO queue: ``head`` is the index of the next node to
    # expand. This avoids re-slicing the queue on every pop, which is quadratic in the
    # number of nodes.
    results = list(nodes)
    head = 0
    while head < len(results):
        results.extend(results[head].children)
        head += 1
    return results
class RRTTree(object):
    """The RRT tree."""

    def __init__(self, pspace, roots):
        """Create a tree (or forest) from one or more root nodes.

        Args:
            pspace: the problem space; stored and used as the default distance provider in :meth:`nearest`.
            roots: a single :class:`RRTNode` or a list of root nodes.
        """
        root_list = [roots] if isinstance(roots, RRTNode) else roots
        assert isinstance(root_list, list)

        self.pspace = pspace
        self.roots = root_list
        self.size = len(root_list)

    def extend(self, parent, child_config):
        """Add a new node holding ``child_config`` as a child of ``parent``.

        Args:
            parent: the node to attach the new node to.
            child_config: the configuration stored at the new node.

        Returns:
            the newly created node.
        """
        new_node = RRTNode(child_config)
        new_node.attach_to(parent)
        self.size += 1
        return new_node

    def nearest(self, config, pspace=None):
        """Find the node whose configuration is closest to ``config``.

        Args:
            config: the query configuration.
            pspace: the object whose ``distance(a, b)`` method is used. Defaults to the problem
                space given at construction time.

        Returns:
            the node with the smallest distance among all nodes reachable from the roots
            (the first one visited wins ties), or ``None`` if no node has a distance
            strictly smaller than infinity.
        """
        space = self.pspace if pspace is None else pspace

        closest_node = None
        closest_distance = np.inf
        for candidate in traverse_rrt_bfs(self.roots):
            candidate_distance = space.distance(candidate.config, config)
            if candidate_distance < closest_distance:
                closest_node, closest_distance = candidate, candidate_distance
        return closest_node
def smooth_path(pspace, path, nr_attemps=100, use_fine_path: bool = False):
    """Shorten a path by repeatedly trying to shortcut between two randomly chosen waypoints.

    Args:
        pspace: the problem space. ``validate_path(a, b)`` is used when ``use_fine_path`` is False,
            and ``try_extend_path(a, b)`` when it is True.
        path: the list of configurations to smooth.
        nr_attemps: the number of shortcut attempts. ``None`` is treated as 100.
        use_fine_path: if True, the path is first densified with ``get_smooth_path`` and each
            successful shortcut splices in the sub-path returned by ``try_extend_path``.

    Returns:
        the smoothed path. The input list itself is not modified.
    """
    attempts = 100 if nr_attemps is None else nr_attemps

    if use_fine_path:
        path = get_smooth_path(pspace, path)

    for _ in range(attempts):
        n = len(path)
        if n <= 2:
            # Nothing between the two endpoints is left to remove.
            break

        # Uniform pair sampling: draw two indices from [0, n - 2], then use the smaller one
        # as the left end and (the larger one + 1) as the right end, so that left < right <= n - 1.
        # An alternative scheme would be: left = randint(0, n - 1); right = randint(left + 1, n).
        first = npr.randint(0, n - 1)
        second = npr.randint(0, n - 1)
        left, right = min(first, second), max(first, second) + 1

        if use_fine_path:
            success, _, subpath = pspace.try_extend_path(path[left], path[right])
            if success:
                path = path[:left + 1] + subpath[1:] + path[right:]
        elif pspace.validate_path(path[left], path[right]):
            path = path[:left + 1] + path[right:]

    return path
def get_smooth_path(pspace, path):
    """Densify a path by connecting consecutive waypoints with ``pspace.cspace.gen_path``.

    Args:
        pspace: the problem space; only ``pspace.cspace.gen_path(a, b)`` is used. Its result is
            indexed with ``[1]`` to obtain the list of configurations, whose first element is dropped.
        path: a non-empty list of configurations.

    Returns:
        a new list starting with ``path[0]``, followed by the generated segments towards each
        subsequent waypoint. Each segment starts from the last configuration added so far.
    """
    dense = [path[0]]
    for waypoint in path[1:]:
        generated = pspace.cspace.gen_path(dense[-1], waypoint)
        segment = generated[1]
        # Skip the first element of the segment: it is the configuration we started from.
        dense.extend(segment[1:])
    return dense
def optimize_path_forward(pspace, path):
    """Run one greedy forward pass that skips waypoints which can be bypassed.

    Starting from ``path[0]``, the pass keeps a cursor on the current target waypoint. While
    ``pspace.validate_path`` accepts the move from the last kept configuration to the waypoint
    *after* the cursor, the cursor simply advances (the current waypoint is skipped). Otherwise,
    the pass takes a single step along ``pspace.cspace.gen_path`` towards the current waypoint.

    Args:
        pspace: the problem space; ``validate_path`` and ``cspace.gen_path`` are used.
        path: the list of configurations. It must contain at least two elements.

    Returns:
        a new list of configurations starting at ``path[0]``.
    """
    result = [path[0]]
    cursor = 1
    last_index = len(path) - 1

    while cursor < last_index:
        anchor = result[-1]
        if pspace.validate_path(anchor, path[cursor + 1]):
            # The next waypoint is directly reachable: drop the current one.
            cursor += 1
            continue

        _, segment = pspace.cspace.gen_path(anchor, path[cursor])
        result.append(segment[1])
        if len(segment) == 2:
            # That single step already landed on the current waypoint; move on to the next.
            cursor += 1

    # Connect the last kept configuration to the waypoint the cursor ended on.
    tail = pspace.cspace.gen_path(result[-1], path[cursor])[1]
    result.extend(tail[1:])
    return result
def optimize_path(pspace, path, *args, **kwargs):
    """Shorten a path by alternating forward and backward greedy passes until it stops shrinking.

    The path is first densified with ``get_smooth_path``. Then ``optimize_path_forward`` is
    applied to the path and to its reverse, repeatedly, until a round does not reduce the number
    of configurations.

    Args:
        pspace: the problem space, forwarded to ``get_smooth_path`` and ``optimize_path_forward``.
        path: the list of configurations to optimize.
        *args: accepted for interface compatibility; unused.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        the optimized path, in the same direction as the input.
    """
    # get_smooth_path takes (pspace, path); passing only the path raises a TypeError.
    path = get_smooth_path(pspace, path)
    while True:
        length = len(path)
        path = optimize_path_forward(pspace, path)
        path.reverse()
        path = optimize_path_forward(pspace, path)
        path.reverse()  # restore the original direction.
        if len(path) >= length:
            break
    return path
def rrt(pspace, start_state, goal_state, nr_iterations=1000, p_sample_goal=0.05, nr_smooth_iterations=None):
    """Run the single-tree RRT algorithm from ``start_state`` towards ``goal_state``.

    Args:
        pspace: the problem space; ``sample`` and ``try_extend_path`` are used here.
        start_state: the start state (or a list of states), wrapped with ``RRTNode.from_states``.
        goal_state: the goal configuration.
        nr_iterations: the maximum number of extension attempts.
        p_sample_goal: the probability of extending towards the goal instead of a random sample.
        nr_smooth_iterations: forwarded to ``smooth_path`` as the number of smoothing attempts.

    Returns:
        a tuple ``(path, tree)``. ``path`` is the smoothed path if an extension towards the goal
        succeeded, and ``None`` otherwise.
    """
    tree = RRTTree(pspace, RRTNode.from_states(start_state))

    for _ in range(nr_iterations):
        toward_goal = npr.uniform() < p_sample_goal
        target = goal_state if toward_goal else pspace.sample()

        nearest_node = tree.nearest(target)
        success, reached, _ = pspace.try_extend_path(nearest_node.config, target)
        if reached is None:
            # No progress could be made from the nearest node; draw a new target.
            continue

        new_node = tree.extend(nearest_node, reached)
        if toward_goal and success:
            return smooth_path(pspace, new_node.backtrace(), nr_smooth_iterations), tree

    return None, tree
def birrt(
    pspace: ProblemSpace, start_state, goal_state, nr_iterations=1000, nr_smooth_iterations=None,
    smooth_fine_path=False, verbose: bool = False
) -> Tuple[Optional[List[np.ndarray]], Tuple[RRTTree, RRTTree]]:
    """Run the bidirectional RRT algorithm between ``start_state`` and ``goal_state``.

    Two trees are grown, one from the start and one from the goal. In each iteration one tree is
    extended towards a random sample, then the other tree tries to connect to the newly added
    configuration. The roles of the two trees are swapped after every iteration.

    Args:
        pspace: the problem space; ``validate_config``, ``sample`` and ``try_extend_path`` are used here.
        start_state: the start configuration.
        goal_state: the goal configuration.
        nr_iterations: the maximum number of iterations.
        nr_smooth_iterations: forwarded to ``smooth_path`` as the number of smoothing attempts.
        smooth_fine_path: forwarded to ``smooth_path`` as ``use_fine_path``.
        verbose: if True, print the input validation results and whether the direct shortcut was used.

    Returns:
        a tuple ``(path, (start_tree, goal_tree))``. ``path`` is ``None`` if either input state is
        invalid or if the trees could not be connected within ``nr_iterations``.
    """
    # NOTE(review): RRTNode.from_states treats a ``list`` as several states, whereas the states are
    # passed as-is to validate_config / try_extend_path below -- presumably states are expected to
    # be non-list objects (e.g., np.ndarray); confirm with callers.
    rrt_start = RRTTree(pspace, RRTNode.from_states(start_state))
    rrt_goal = RRTTree(pspace, RRTNode.from_states(goal_state))
    swapped = False

    if verbose:
        print(f'birrt::input check: {pspace.validate_config(start_state)=}, {pspace.validate_config(goal_state)=}')

    if not pspace.validate_config(start_state) or not pspace.validate_config(goal_state):
        return None, (rrt_start, rrt_goal)

    # Try to directly connect the start state and the goal state.
    success, next_config, _ = pspace.try_extend_path(start_state, goal_state)
    if success:
        if verbose:
            print('birrt::shortcut: Directly connected the start and goal states.')
        rrt_start.extend(rrt_start.nearest(goal_state), goal_state)
        return smooth_path(pspace, [start_state, goal_state], 0, use_fine_path=smooth_fine_path), (rrt_start, rrt_goal)

    for _ in range(nr_iterations):
        next_config = pspace.sample()
        node_start = rrt_start.nearest(next_config)
        success, next_config, _ = pspace.try_extend_path(node_start.config, next_config)

        if next_config is not None:
            new_node_start = rrt_start.extend(node_start, next_config)
            node_goal = rrt_goal.nearest(next_config)
            success, next_config, _ = pspace.try_extend_path(node_goal.config, next_config)

            if success:
                path1 = new_node_start.backtrace()
                path2 = node_goal.backtrace()
                path = list(path1) + list(reversed(path2))
                # Still add the new node to rrt_goal for better visualization.
                _ = rrt_goal.extend(node_goal, next_config)

                if swapped:
                    # The local names currently refer to the opposite trees: restore the
                    # start-to-goal direction of the path and the order of the trees.
                    path.reverse()
                    rrt_start, rrt_goal = rrt_goal, rrt_start
                return smooth_path(pspace, path, nr_smooth_iterations, use_fine_path=smooth_fine_path), (rrt_start, rrt_goal)
            elif next_config is not None:
                _ = rrt_goal.extend(node_goal, next_config)

        # Alternate the roles of the two trees.
        rrt_start, rrt_goal, swapped = rrt_goal, rrt_start, not swapped

    if swapped:
        # Return the trees as (start, goal), consistent with the success path, even when the
        # loop ended after an odd number of swaps.
        rrt_start, rrt_goal = rrt_goal, rrt_start
    return None, (rrt_start, rrt_goal)