Source code for concepts.pdsketch.planners.discrete_search

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : discrete_search.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 11/17/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""Discrete-space search PDSketch planners."""

import jacinle
import jactorch
import heapq as hq
from typing import Optional, Any, Union, Tuple, Sequence, List, Dict, Callable
from dataclasses import dataclass

from concepts.dsl.expression import ValueOutputExpression
from concepts.dsl.tensor_value import TensorValue
from concepts.algorithm.search.heuristic_search import QueueNode
from concepts.pdsketch.operator import OperatorApplier, gen_all_grounded_actions
from concepts.pdsketch.domain import State
from concepts.pdsketch.executor import PDSketchExecutor
from concepts.pdsketch.planners.solution_score_tracker import MostPromisingTrajectoryTracker
from concepts.pdsketch.csp_solvers.brute_force_sampling import ContinuousValueDict
from concepts.pdsketch.strips.strips_expression import SState
from concepts.pdsketch.strips.strips_grounding import GStripsTranslatorOptimistic
from concepts.pdsketch.strips.strips_heuristics import StripsHeuristic
from concepts.pdsketch.strips.strips_search import get_priority_func

# Names exported by ``from concepts.pdsketch.planners.discrete_search import *``.
__all__ = [
    'apply_action', 'goal_test', 'prepare_search',
    'validate_plan', 'brute_force_search',
    'HeuristicSearchState', 'heuristic_search_strips',
]



def apply_action(executor: PDSketchExecutor, state: State, action: OperatorApplier, forward_derived=True) -> Tuple[bool, State]:
    """Apply an action to a state.

    This is a thin wrapper around ``executor.apply(action, state)``. It does not perform any extra
    forwarding of axioms or derived predicates itself. Whether those are refreshed depends on
    ``executor.apply``, which is not visible here.

    Args:
        executor: the executor.
        state: the state to be applied.
        action: the action to be applied.
        forward_derived: currently unused. It is kept so that existing callers that pass it keep working.

    Returns:
        a tuple of (success, new_state), exactly as returned by ``executor.apply``.
    """
    # NOTE: a successful application used to be followed by
    # ``executor.forward_predicates_and_axioms(ns, forward_state_variables=False, forward_axioms=True,
    # forward_derived=forward_derived)``. That call is disabled, which is why ``forward_derived``
    # has no effect here.
    succ, ns = executor.apply(action, state)
    return succ, ns
def goal_test(
    executor: PDSketchExecutor, state: State, goal_expr: ValueOutputExpression,
    trajectory: Sequence[OperatorApplier], mpt_tracker: Optional[MostPromisingTrajectoryTracker] = None,
    verbose=False, *, threshold: float = 0.5
) -> bool:
    """Test whether a state satisfies a goal expression.

    The goal expression is evaluated on ``state`` and its scalar score is compared against a threshold.
    When ``mpt_tracker`` is given, the score is first offered to the tracker. If ``mpt_tracker.check``
    accepts it, ``trajectory`` is recorded via ``mpt_tracker.update``. The tracker's own ``threshold``
    attribute then replaces the ``threshold`` argument.

    Args:
        executor: the executor.
        state: the state to be tested.
        goal_expr: the goal expression.
        trajectory: the trajectory that leads to the current state.
        mpt_tracker: the tracker for the most promising trajectory.
        verbose: currently unused; kept for interface compatibility.
        threshold: the score a state must strictly exceed to count as a goal state when no
            ``mpt_tracker`` is given. Defaults to 0.5.

    Returns:
        True if the goal score is strictly greater than the effective threshold.
    """
    score = executor.execute(goal_expr, state).item()
    if mpt_tracker is not None:
        if mpt_tracker.check(score):
            mpt_tracker.update(score, trajectory)
        # The tracker decides what counts as "good enough" when it is present.
        threshold = mpt_tracker.threshold
    return score > threshold
@jactorch.no_grad_func
def validate_plan(
    executor: PDSketchExecutor, state: State, goal_expr: Union[str, ValueOutputExpression],
    actions: Sequence[OperatorApplier], *,
    forward_state_variables=True, forward_derived=True,
) -> Tuple[State, TensorValue]:
    """Replay a plan from ``state`` and evaluate the goal on the state it ends in.

    Args:
        executor: the executor.
        state: the initial state.
        goal_expr: the goal expression, either as a string (parsed with ``executor.domain.parse``)
            or as an already-built ``ValueOutputExpression``.
        actions: the sequence of actions to execute, in order.
        forward_state_variables: currently unused (initial-state forwarding is disabled).
        forward_derived: passed through to ``apply_action`` for every step.

    Returns:
        the final state and the goal value (the execution result of the goal expression).

    Raises:
        AssertionError: if ``goal_expr`` is neither a string nor a ``ValueOutputExpression``, or if
            any action fails to apply.
    """
    # TODO(Jiayuan Mao @ 11/28): update!.
    # Initial-state forwarding via executor.forward_predicates_and_axioms(...) is disabled, which is
    # why forward_state_variables is not consumed below.
    if not isinstance(goal_expr, str):
        assert isinstance(goal_expr, ValueOutputExpression)
    else:
        goal_expr = executor.domain.parse(goal_expr)

    current_state = state
    for action in actions:
        ok, current_state = apply_action(executor, current_state, action, forward_derived=forward_derived)
        assert ok, f'Action application failed: {action}.'

    goal_value = executor.execute(goal_expr, current_state)
    return current_state, goal_value
@dataclass
class HeuristicSearchState:
    """A node in the heuristic search tree.

    Bundles the planning state with its STRIPS representation, the actions that led to it from the
    root, and the accumulated path cost.
    """

    # The planning state at this node.
    state: State
    # The STRIPS state associated with ``state``.
    strips_state: SState
    # The actions applied from the root to reach this node.
    trajectory: Tuple[OperatorApplier, ...]
    # The cost so far (g-value).
    g: float
@jactorch.no_grad_func
def heuristic_search_strips(
    executor: PDSketchExecutor, state: State, goal_expr: Union[str, ValueOutputExpression],
    strips_heuristic: str = 'hff', *,
    # search related parameters.
    max_expansions: int = 100000, max_depth: int = 100,
    # heuristic related parameters.
    heuristic_weight: float = float('inf'),
    # external heuristic related parameters.
    external_heuristic_function: Callable[[State, ValueOutputExpression], int] = None,
    # action set / STRIPS compilation related parameters.
    actions: Optional[Sequence[OperatorApplier]] = None,
    continuous_values: Optional[ContinuousValueDict] = None,
    action_filter: Callable[[OperatorApplier], bool] = None,
    strips_forward_relevance_analysis: bool = False,
    strips_backward_relevance_analysis: bool = True,
    strips_use_sas: bool = False,  # whether to use SAS Strips compiler (AODiscretization)
    use_strips_op: bool = False,
    # pruning related parameters.
    use_tuple_desc: bool = True,
    # initialization related parameters.
    forward_state_variables: bool = True, forward_derived: bool = False,
    # non-optimal trajectory tracking related parameters.
    track_most_promising_trajectory: bool = False, prob_goal_threshold: float = 0.5,
    verbose: bool = False, return_extra_info: bool = False
) -> Union[
    Optional[Sequence[OperatorApplier]],
    Tuple[Optional[Sequence[OperatorApplier]], Dict[str, Any]]
]:
    """Perform heuristic search with STRIPS-based heuristics.

    Args:
        executor: the executor.
        state: the initial state.
        goal_expr: the goal expression.
        strips_heuristic: the heuristic to use. Should be a string. Use 'external' to use the external heuristic function.
        max_expansions: the maximum number of expanded nodes.
        max_depth: the maximum depth of the search.
        heuristic_weight: the weight of the heuristic. Use float('inf') to do greedy best-first search.
        external_heuristic_function: the external heuristic function. Should be a function that takes in a state and a
            goal expression, and returns an integer. Only used when ``strips_heuristic == 'external'``.
        actions: the actions to use. If None, use all possible actions.
        continuous_values: the continuous values for action parameters. If None, all action parameters should be discrete.
        action_filter: the action filter. If None, use all possible actions. It should be a function that takes in an
            action and returns a boolean.
        strips_forward_relevance_analysis: whether to perform forward relevance analysis when translating the problem into STRIPS.
        strips_backward_relevance_analysis: whether to perform backward relevance analysis when translating the problem into STRIPS.
        strips_use_sas: whether to use SAS Strips compiler (AODiscretization). Not implemented; raises NotImplementedError.
        use_strips_op: whether to use STRIPS operators when applying actions. Recommended to be False.
        use_tuple_desc: whether to use tuple description to prune the search space.
        forward_state_variables: whether to forward state variables before the search starts.
        forward_derived: whether to forward derived predicates after applying actions.
        track_most_promising_trajectory: whether to track the most promising trajectory.
        prob_goal_threshold: the probability threshold for the most promising trajectory. When there is a trajectory with
            probability greater than this threshold, the search will stop.
        verbose: whether to print verbose information.
        return_extra_info: whether to return extra information, such as the number of expanded nodes.

    Returns:
        the trajectory if succeeded, otherwise None (or the tracked most promising trajectory when
        ``track_most_promising_trajectory`` is set). When `return_extra_info` is True, return a tuple of
        (trajectory, extra_info), where extra_info is a dictionary with keys ``'nr_expansions'`` and
        ``'nr_tested_actions'``.
    """
    state, goal_expr, actions = prepare_search(
        'hsstrips', executor, state, goal_expr,
        actions=actions, action_filter=action_filter, continuous_values=continuous_values,
        forward_state_variables=forward_state_variables, forward_derived=forward_derived,
        verbose=verbose
    )

    if strips_use_sas:
        raise NotImplementedError('SAS strips is not implemented yet.')

    strips_translator = GStripsTranslatorOptimistic(executor, use_string_name=True, prob_goal_threshold=prob_goal_threshold)
    strips_task = strips_translator.compile_task(
        state, goal_expr, actions,
        is_relaxed=False,
        forward_relevance_analysis=strips_forward_relevance_analysis,
        backward_relevance_analysis=strips_backward_relevance_analysis,
    )

    # Build a single heuristic callable over search nodes. It is used both for prioritization and for
    # debug output, so debug printing never touches a heuristic object that may not exist.
    if strips_heuristic == 'external' and external_heuristic_function is not None:
        def heuristic_fn(node: HeuristicSearchState) -> int:
            return external_heuristic_function(node.state, goal_expr)
    else:
        heuristic = StripsHeuristic.from_type(
            strips_heuristic, strips_task, strips_translator,
            forward_relevance_analysis=strips_forward_relevance_analysis,
            backward_relevance_analysis=strips_backward_relevance_analysis,
        )

        def heuristic_fn(node: HeuristicSearchState) -> int:
            return heuristic.compute(node.strips_state)

    priority_func = get_priority_func(heuristic_fn, heuristic_weight)

    mpt_tracker = None
    if track_most_promising_trajectory:
        mpt_tracker = MostPromisingTrajectoryTracker(True, prob_goal_threshold)

    queue: List[QueueNode] = list()
    visited = set()

    def push_node(node: HeuristicSearchState):
        """Enqueue ``node`` unless its tuple description has already been seen (when pruning is on)."""
        if use_tuple_desc:
            desc = node.state.generate_tuple_description(executor.domain)
            if desc in visited:
                return
            visited.add(desc)
        hq.heappush(queue, QueueNode(priority_func(node, node.g), node))
        if heuristic_search_strips.DEBUG:
            print(' hsstrips::push_node:', *node.trajectory)
            print(' ', 'heuristic =', heuristic_fn(node), 'g =', node.g)

    push_node(HeuristicSearchState(state, strips_task.state, tuple(), 0))

    nr_expanded_states = 0
    nr_tested_actions = 0

    def wrap_extra_info(trajectory):
        if return_extra_info:
            return trajectory, {'nr_expansions': nr_expanded_states, 'nr_tested_actions': nr_tested_actions}
        return trajectory

    # Name convention inside the loop:
    #   node / nnode: current / next search node;  s / ns: state before / after the action;
    #   a: the applied action;  ss / nss: STRIPS state before / after;  traj: path from the root;
    #   g: path cost from the root.
    pbar = jacinle.tqdm_pbar(desc='heuristic_search::expanding') if verbose else None
    try:
        while len(queue) > 0 and nr_expanded_states < max_expansions:
            priority, node = hq.heappop(queue)
            nr_expanded_states += 1

            s, ss, traj = node.state, node.strips_state, node.trajectory
            if heuristic_search_strips.DEBUG:
                print('hsstrips::pop_node:')
                print(' trajectory:', *traj, sep='\n ')
                print(' priority =', priority, 'g =', node.g)
                if heuristic_search_strips.DEBUG_INTERACTIVE:
                    input(' Continue?')
            if pbar is not None:
                pbar.set_description(f'heuristic_search::expanding: priority = {priority} g = {node.g}')
                pbar.update()

            goal_reached = goal_test(
                executor, s, goal_expr, trajectory=traj,
                verbose=verbose, mpt_tracker=mpt_tracker
            )
            if goal_reached:
                if verbose:
                    print('hsstrips::search succeeded.')
                    print('hsstrips::total_expansions:', nr_expanded_states)
                return wrap_extra_info(traj)

            if len(traj) >= max_depth:
                continue

            for sa in strips_task.operators:
                a = sa.raw_operator
                nr_tested_actions += 1
                succ, ns = apply_action(executor, s, a, forward_derived=forward_derived)
                if not succ:
                    continue
                nss = sa.apply(ss) if use_strips_op else strips_translator.compile_state(ns.clone(), forward_derived=False)
                push_node(HeuristicSearchState(ns, nss, traj + (a, ), node.g + 1))
    finally:
        # Always release the progress bar, including on early return or exceptions.
        if pbar is not None:
            pbar.close()

    if verbose:
        print('hsstrips::search failed.')
        print('hsstrips::total_expansions:', nr_expanded_states)

    if mpt_tracker is not None:
        return wrap_extra_info(mpt_tracker.solution)
    return wrap_extra_info(None)
# Debug switches read by heuristic_search_strips at run time. Both are off by default.
heuristic_search_strips.DEBUG = False
heuristic_search_strips.DEBUG_INTERACTIVE = False

# Convenience setters: ``heuristic_search_strips.set_debug()`` turns DEBUG on and
# ``set_debug(False)`` turns it off (likewise for ``set_debug_interactive``).
heuristic_search_strips.set_debug = (
    lambda x=True: setattr(heuristic_search_strips, 'DEBUG', x)
)
heuristic_search_strips.set_debug_interactive = (
    lambda x=True: setattr(heuristic_search_strips, 'DEBUG_INTERACTIVE', x)
)