Source code for concepts.pdsketch.regression_utils

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : regression_utils.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 10/30/2023
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""Utility functions for regression search."""

import itertools
from typing import Any, Optional, Union, Sequence, Tuple, NamedTuple, List, Dict

import jacinle
import torch

from concepts.dsl.dsl_types import ObjectType, ObjectConstant, UnnamedPlaceholder, Variable, QINDEX
from concepts.dsl.constraint import ConstraintSatisfactionProblem, EqualityConstraint, GroupConstraint
from concepts.dsl.constraint import OptimisticValue, AssignmentDict
from concepts.dsl.executors.tensor_value_executor import BoundedVariablesDictCompatible
from concepts.dsl.expression import (
    BoolExpression, FunctionApplicationExpression, ListCreationExpression, ListFunctionApplicationExpression, ListExpansionExpression,
    ConstantExpression, ObjectConstantExpression, QuantificationExpression, ValueOutputExpression, VariableExpression,
    is_and_expr, iter_exprs
)
from concepts.dsl.expression_visitor import IdentityExpressionVisitor
from concepts.dsl.tensor_value import TensorValue
from concepts.dsl.tensor_state import StateObjectReference
from concepts.dsl.value import ValueBase, ListValue
from concepts.pdsketch.domain import State
from concepts.pdsketch.executor import PDSketchSGC, PDSketchExecutor
from concepts.pdsketch.operator import OperatorApplier, OperatorApplicationExpression
from concepts.pdsketch.regression_rule import RegressionRule, RegressionRuleApplier, RegressionRuleApplicationExpression, AchieveExpression, FindExpression, RuntimeAssignExpression, RegressionCommitFlag
from concepts.pdsketch.crow.crow_state import TotallyOrderedPlan, PartiallyOrderedPlan

__all__ = [
    'surface_fol_downcast', 'ground_fol_expression', 'ground_fol_expression_v2',
    'ground_operator_application_expression', 'ground_regression_application_expression',
    'evaluate_bool_scalar_expression',
    'ApplicableRegressionRuleItem', 'ApplicableRegressionRuleGroup', 'gen_applicable_regression_rules', 'len_candidate_regression_rules',
    'create_find_expression_csp_variable', 'create_find_expression_variable_placeholder',
    'mark_constraint_group_solver',
    'has_optimistic_value_or_list', 'is_single_optimistic_value_or_list', 'cvt_single_optimistic_value_or_list',
    'has_optimistic_constant_expression',
    'map_csp_placeholder_goal', 'map_csp_placeholder_action', 'map_csp_placeholder_regression_rule_applier', 'map_csp_variable_mapping', 'map_csp_variable_state',
    'gen_grounded_subgoals_with_placeholders'
]


[docs]def surface_fol_downcast(expression_1: ValueOutputExpression, expression_2: ValueOutputExpression) -> Optional[Dict[str, Union[Variable, ObjectConstant]]]: """Trying to downcast the `expression_1` to the same form as `expression_2`. Downcasting means that we try to replace variables in `expression_1` with constants in `expression_2` to make them the same. Args: expression_1: the first expression. expression_2: the second expression. Returns: the downcasted mapping if the downcasting is successful, otherwise None. """ current_mapping = dict() # @jacinle.log_function(verbose=True) def dfs(expr1, expr2): nonlocal current_mapping if isinstance(expr1, VariableExpression): if expr1.name in current_mapping: expr2_name = _get_variable_or_constant_name(expr2) return expr2_name == current_mapping[expr1.name].name else: if isinstance(expr2, VariableExpression): current_mapping[expr1.name] = expr2.variable return True elif isinstance(expr2, ObjectConstantExpression): current_mapping[expr1.name] = expr2.constant return True elif isinstance(expr2, ConstantExpression): current_mapping[expr1.name] = expr2.constant return True elif isinstance(expr2, ListCreationExpression): current_mapping[expr1.name] = ListValue(expr1.return_type, [_get_variable_or_constant_object(x) for x in expr2.arguments]) return True else: return False elif isinstance(expr1, ObjectConstantExpression): if isinstance(expr2, ObjectConstantExpression): return expr1.name == expr2.name else: return False elif isinstance(expr1, (FunctionApplicationExpression, ListFunctionApplicationExpression)): if not isinstance(expr2, (FunctionApplicationExpression, ListFunctionApplicationExpression)): return False if expr1.function.name != expr2.function.name: return False if len(expr1.arguments) != len(expr2.arguments): return False for arg1, arg2 in zip(expr1.arguments, expr2.arguments): if not dfs(arg1, arg2): return False return True elif isinstance(expr1, BoolExpression): if not isinstance(expr2, BoolExpression): return False if expr1.bool_op != expr2.bool_op: return False if len(expr1.arguments) != len(expr2.arguments): return False for arg1, arg2 in zip(expr1.arguments, expr2.arguments): if not dfs(arg1, arg2): return False return True elif isinstance(expr1, QuantificationExpression): if not isinstance(expr2, QuantificationExpression): return False if expr1.quantification_op != expr2.quantification_op: return False assert expr1.variable.name not in current_mapping current_mapping[expr1.variable.name] = expr2.variable try: return dfs(expr1.expression, expr2.expression) finally: del current_mapping[expr1.variable.name] else: raise TypeError(f'Unsupported expression type: {type(expr1)}') rv = dfs(expression_1, expression_2) if rv: return current_mapping return None
[docs]def ground_fol_expression(expression: ValueOutputExpression, variable_mapping: Dict[Variable, str]) -> ValueOutputExpression: """Ground the given FOL expression with the given variable mapping. Args: expression: the expression to ground. variable_mapping: the variable mapping, which is a mapping from the Variable object to the constant name. Returns: the grounded expression. """ name2symbol = dict() for var, content in variable_mapping.items(): if isinstance(content, ListValue): if isinstance(content.element_type, ObjectType): name2symbol[var.name] = ObjectConstantExpression(content) else: name2symbol[var.name] = ConstantExpression(content) elif isinstance(content, ObjectConstant): name2symbol[var.name] = ObjectConstantExpression(content) elif isinstance(content, ValueBase): name2symbol[var.name] = ConstantExpression(content) elif isinstance(content, Variable): name2symbol[var.name] = VariableExpression(content) elif isinstance(content, str): name2symbol[var.name] = ObjectConstantExpression(ObjectConstant(content, var.dtype)) else: raise TypeError(f'Unsupported type: {type(content)}') bounded_variables = set() def dfs(e): if isinstance(e, VariableExpression): if e.name not in bounded_variables: return name2symbol[e.name] return e elif isinstance(e, ObjectConstantExpression): return e elif isinstance(e, ConstantExpression): return e elif isinstance(e, ListFunctionApplicationExpression): return ListFunctionApplicationExpression(e.function, [dfs(arg) for arg in e.arguments]) elif isinstance(e, FunctionApplicationExpression): return FunctionApplicationExpression(e.function, [dfs(arg) for arg in e.arguments]) elif isinstance(e, BoolExpression): return BoolExpression(e.bool_op, [dfs(arg) for arg in e.arguments]) elif isinstance(e, QuantificationExpression): bounded_variables.add(e.variable.name) rv = QuantificationExpression(e.quantification_op, e.variable, dfs(e.expression)) bounded_variables.remove(e.variable.name) return rv else: raise TypeError(f'Unsupported expression type: {type(e)}') return dfs(expression)
[docs]def ground_fol_expression_v2(expression: ValueOutputExpression, variable_mapping: Dict[str, Union[ListValue, ObjectConstant, ValueBase, Variable]]) -> ValueOutputExpression: """Ground the given FOL expression with the given variable mapping. Args: expression: the expression to ground. variable_mapping: the variable mapping, which is a mapping from the Variable object to the constant name. Returns: the grounded expression. """ name2symbol = dict() for var, content in variable_mapping.items(): if isinstance(content, ListValue): if isinstance(content.element_type, ObjectType): name2symbol[var] = ObjectConstantExpression(content) else: name2symbol[var] = ConstantExpression(content) elif isinstance(content, ObjectConstant): name2symbol[var] = ObjectConstantExpression(content) elif isinstance(content, ValueBase): name2symbol[var] = ConstantExpression(content) elif isinstance(content, Variable): name2symbol[var] = VariableExpression(content) else: raise TypeError(f'Unsupported type: {type(content)}') bounded_variables = set() def dfs(e): if isinstance(e, VariableExpression): if e.name not in bounded_variables: return name2symbol[e.name] return e elif isinstance(e, ObjectConstantExpression): return e elif isinstance(e, ListFunctionApplicationExpression): return ListFunctionApplicationExpression(e.function, [dfs(arg) for arg in e.arguments]) elif isinstance(e, FunctionApplicationExpression): arguments = [dfs(arg) for arg in e.arguments] return FunctionApplicationExpression(e.function, arguments) elif isinstance(e, BoolExpression): return BoolExpression(e.bool_op, [dfs(arg) for arg in e.arguments]) elif isinstance(e, QuantificationExpression): bounded_variables.add(e.variable.name) rv = QuantificationExpression(e.quantification_op, e.variable, dfs(e.expression)) bounded_variables.remove(e.variable.name) return rv else: raise TypeError(f'Unsupported expression type: {type(e)}') return dfs(expression)
def _get_variable_or_constant_name(expr: Union[VariableExpression, ObjectConstantExpression]) -> str: """Get the name of the given variable or constant expression.""" if isinstance(expr, VariableExpression): return expr.name elif isinstance(expr, ObjectConstantExpression): return expr.name else: raise TypeError(f'Unsupported type: {type(expr)} for _get_variable_or_constant_name.') def _get_variable_or_constant_object(expr: Union[VariableExpression, ObjectConstantExpression]) -> Union[Variable, ObjectConstant]: """Get the object of the given variable or constant expression.""" if isinstance(expr, VariableExpression): return expr.variable elif isinstance(expr, ObjectConstantExpression): return expr.constant else: raise TypeError(f'Unsupported type: {type(expr)} for _get_variable_or_constant_object.')
[docs]def ground_operator_application_expression(expression: OperatorApplicationExpression, variable_mapping: Dict[Variable, str], csp: Optional[ConstraintSatisfactionProblem] = None, rule_applier: Optional[RegressionRuleApplier] = None) -> OperatorApplier: """Ground the given operator application expression with the given variable mapping. Args: expression: the expression to ground. variable_mapping: the variable mapping, which is a mapping from the Variable object to the constant name. csp: the constraint satisfaction problem to add the constraints to. rule_applier: the rule applier to use. Returns: the grounded expression. """ name2symbol = {var.name: name for var, name in variable_mapping.items()} arguments = list() for arg in expression.arguments: if isinstance(arg, VariableExpression): symbol = name2symbol[arg.variable.name] arguments.append(symbol.name if isinstance(symbol, ObjectConstant) else symbol) elif isinstance(arg, UnnamedPlaceholder): if csp is not None: arguments.append(TensorValue.from_optimistic_value(csp.new_var(arg.dtype, wrap=True))) else: arguments.append(arg) elif isinstance(arg, str): arguments.append(arg) else: raise TypeError(f'Unknown argument type: {type(arg)}') return OperatorApplier(expression.operator, arguments, regression_rule=rule_applier)
[docs]def ground_regression_application_expression(expression: RegressionRuleApplicationExpression, variable_mapping: Dict[Variable, str], csp: Optional[ConstraintSatisfactionProblem] = None) -> RegressionRuleApplier: """Ground the given regression application expression with the given variable mapping. Args: expression: the expression to ground. variable_mapping: the variable mapping, which is a mapping from the Variable object to the constant name. csp: the constraint satisfaction problem to add the constraints to. Returns: the grounded expression. """ name2symbol = {var.name: name for var, name in variable_mapping.items()} arguments = list() for arg in expression.arguments: if isinstance(arg, VariableExpression): symbol = name2symbol[arg.variable.name] arguments.append(symbol.name if isinstance(symbol, ObjectConstant) else symbol) elif isinstance(arg, UnnamedPlaceholder): if csp is not None: arguments.append(TensorValue.from_optimistic_value(csp.new_var(arg.dtype, wrap=True))) else: arguments.append(arg) elif isinstance(arg, str): arguments.append(arg) else: raise TypeError(f'Unknown argument type: {type(arg)}') return RegressionRuleApplier( expression.regression_rule, arguments, maintains=_ground_maintains_expressions(expression.maintains, variable_mapping), serializability=expression.serializability, csp_serializability=expression.csp_serializability )
[docs]def evaluate_bool_scalar_expression( executor: PDSketchExecutor, expr: Union[ValueOutputExpression, Sequence[ValueOutputExpression]], state: State, bounded_variables: BoundedVariablesDictCompatible, csp: ConstraintSatisfactionProblem, csp_note: str = '' ) -> Tuple[bool, bool, Optional[ConstraintSatisfactionProblem]]: if csp is not None: csp = csp.clone() is_optimistic = False if not isinstance(expr, (list, tuple)): expr = [expr] for e in expr: rv = executor.execute(e, state=state, bounded_variables=bounded_variables, csp=csp).item() if isinstance(rv, OptimisticValue): csp.add_constraint(EqualityConstraint.from_bool(rv, True), note=csp_note) is_optimistic = True elif float(rv) < 0.5: return False, is_optimistic, csp return True, is_optimistic, csp
[docs]class ApplicableRegressionRuleItem(NamedTuple): regression_rule: RegressionRule bounded_variables: BoundedVariablesDictCompatible
[docs]class ApplicableRegressionRuleGroup(NamedTuple): chain_index: int subgoal_index: int regression_rules: List[ApplicableRegressionRuleItem]
[docs]def gen_applicable_regression_rules( executor: PDSketchExecutor, state: State, goals: PartiallyOrderedPlan, maintains: Sequence[ValueOutputExpression], return_all_candidates: bool = True ) -> List[ApplicableRegressionRuleGroup]: from concepts.pdsketch.planners.optimistic_search_bilevel_utils import extract_bounded_variables_from_nonzero, extract_bounded_variables_from_nonzero_dc # TODO(Jiayuan Mao @ 2023/09/10): implement maintains. candidate_regression_rules = list() for chain_index, chain in goals.iter_feasible_chains(): if chain.is_ordered: subgoal_indices = [len(chain) - 1] else: subgoal_indices = list(range(len(chain))) for subgoal_index in subgoal_indices: subgoal = chain.sequence[subgoal_index] this_chain_candidate_regression_rules = list() if isinstance(subgoal, RegressionRuleApplier): this_chain_candidate_regression_rules.append(ApplicableRegressionRuleItem(subgoal.regression_rule, {arg: argv for arg, argv in zip(subgoal.regression_rule.arguments, subgoal.arguments)})) else: # For each subgoal in the goal_set, try to find a list of applicable regression rules. # If one of the regression rules is always applicable, then we can stop searching. for regression_rule in executor.domain.regression_rules.values(): goal_expr = regression_rule.goal_expression if (variable_binding := surface_fol_downcast(goal_expr, subgoal)) is None: continue bounded_variables = dict() for v in regression_rule.goal_arguments: value = variable_binding[v.name] bounded_variables[v] = value if len(regression_rule.binding_arguments) > 0: for v in regression_rule.binding_arguments: bounded_variables[v] = QINDEX if len(regression_rule.preconditions_conjunction.arguments) > 0: if len(regression_rule.binding_arguments) > 4 and len(regression_rule.preconditions_conjunction.arguments) >= 2: rv = extract_bounded_variables_from_nonzero_dc(executor, state, regression_rule, bounded_variables, use_optimistic=False) else: sgc = PDSketchSGC(state, regression_rule.goal_expression, maintains) rv = executor.execute(regression_rule.preconditions_conjunction, state=state, bounded_variables=bounded_variables, sgc=sgc) rv = extract_bounded_variables_from_nonzero(state, rv, regression_rule, default_bounded_variables=bounded_variables, use_optimistic=False) else: rv = None if rv is None: type_binding_arguments = [state.object_type2name[v.dtype.typename] for v in regression_rule.binding_arguments] all_forall = True for i, arg in enumerate(regression_rule.binding_arguments): if arg.quantifier_flag == 'forall': type_binding_arguments[i] = type_binding_arguments[i][:1] if len(type_binding_arguments[i]) > 0 else [] else: all_forall = False for binding_arguments in itertools.product(*type_binding_arguments): cbv = bounded_variables.copy() for variable, value in zip(regression_rule.binding_arguments, binding_arguments): cbv[variable] = value if all_forall and regression_rule.always: this_chain_candidate_regression_rules = [ApplicableRegressionRuleItem(regression_rule, cbv)] break this_chain_candidate_regression_rules.append(ApplicableRegressionRuleItem(regression_rule, cbv)) else: all_forall, candidate_bounded_variables = rv if all_forall and regression_rule.always: candidate_bounded_variables = _expand_type_binding_arguments(state, regression_rule, candidate_bounded_variables, return_first=True) this_chain_candidate_regression_rules = [ApplicableRegressionRuleItem(regression_rule, candidate_bounded_variables[0])] break else: candidate_bounded_variables = _expand_type_binding_arguments(state, regression_rule, candidate_bounded_variables, return_first=False) this_chain_candidate_regression_rules.extend([ApplicableRegressionRuleItem(regression_rule, cbv) for cbv in candidate_bounded_variables]) candidate_regression_rules.append(ApplicableRegressionRuleGroup(chain_index, subgoal_index, this_chain_candidate_regression_rules)) # TODO(Jiayuan Mao @ 2024/01/20): implement this within the search process. if not return_all_candidates: # if we don't need to return all candidates, we can return only the first applicable regression rule if it's always applicable (for each subgoal). filtered_candidate_regression_rules = list() for item in candidate_regression_rules: simplified_regression_rules = list() for regression_rule in item.regression_rules: if regression_rule.regression_rule.always: simplified_regression_rules = [regression_rule] break else: simplified_regression_rules.append(regression_rule) filtered_candidate_regression_rules.append(ApplicableRegressionRuleGroup(item.chain_index, item.subgoal_index, simplified_regression_rules)) candidate_regression_rules = filtered_candidate_regression_rules return candidate_regression_rules
def _expand_type_binding_arguments(state: State, regression_rule: RegressionRule, candidate_bounded_variables, return_first: bool = False): output_candidate_bounded_variables = list() for cbv in candidate_bounded_variables: groups_variables = list() groups_values = list() for k, v in cbv.items(): if v is QINDEX: groups_variables.append(k) groups_values.append(state.object_type2name[k.dtype.typename]) if len(groups_variables) == 0: output_candidate_bounded_variables.append(cbv) if return_first: return output_candidate_bounded_variables else: for binding_arguments in itertools.product(*groups_values): cbv = cbv.copy() for variable, value in zip(groups_variables, binding_arguments): cbv[variable] = value output_candidate_bounded_variables.append(cbv) if return_first: return output_candidate_bounded_variables return output_candidate_bounded_variables
[docs]def len_candidate_regression_rules(candidate_regression_rules: List[ApplicableRegressionRuleGroup]) -> int: """Compute the number of candidate regression rules.""" return sum(len(x.regression_rules) for x in candidate_regression_rules)
[docs]def create_find_expression_csp_variable(variable: Variable, csp: ConstraintSatisfactionProblem, bounded_variables: Dict[Variable, Any]): """Create a TensorValue that corresponds to a variable inside a `FindExpression`. Args: variable: the variable in the FindExpression. csp: the current CSP. bounded_variables: the already bounded variables. """ if variable.dtype.is_list_type: length = -1 for v in bounded_variables.values(): if isinstance(v, ListValue): length = len(v) if length == -1: raise ValueError(f'Cannot create a list variable {variable} without specifying the length.') return ListValue(variable.dtype, [TensorValue.from_optimistic_value(csp.new_actionable_var(variable.dtype.element_type, wrap=True)) for _ in range(length)]) else: return TensorValue.from_optimistic_value(csp.new_actionable_var(variable.dtype, wrap=True))
[docs]def create_find_expression_variable_placeholder(variable: Variable, bounded_variables: Dict[Variable, Any]): """Create a TensorValue that corresponds to a variable inside a `FindExpression`. Unlike `_create_find_expression_variable`, this function only creates placeholder variables. Args: variable: the variable in the FindExpression. bounded_variables: the already bounded variables. """ if variable.dtype.is_list_type: length = -1 for v in bounded_variables.values(): if isinstance(v, ListValue): length = len(v) if length == -1: raise ValueError(f'Cannot create a list variable {variable} without specifying the length.') return ListValue(variable.dtype, [UnnamedPlaceholder(variable.dtype) for _ in range(length)]) else: return UnnamedPlaceholder(variable.dtype)
[docs]def mark_constraint_group_solver(executor: PDSketchExecutor, state: State, bounded_variables: Dict[Variable, Any], group: GroupConstraint): """Mark the solver for the current state. Args: executor: the executor. state: the current state. bounded_variables: the already bounded variables. group: the current group constraint. """ for generator in executor.domain.generators.values(): if (matching := surface_fol_downcast(generator.certifies, group.expression)) is not None: matching_success = True inputs = list() outputs = list() for var in generator.context: value = executor.execute(var, state, bounded_variables, optimistic_execution=True) if has_optimistic_value_or_list(value): matching_success = False break inputs.append(value) for var in generator.generates: this_matching_success = False if isinstance(var, VariableExpression): if var.name in matching and is_single_optimistic_value_or_list(matching[var.name]): this_matching_success = True outputs.append(cvt_single_optimistic_value_or_list(matching[var.name])) if not this_matching_success: matching_success = False if matching_success: group.candidate_generators.append((generator, inputs, outputs))
[docs]def has_optimistic_value_or_list(x: Union[ListValue, TensorValue]) -> bool: """Check if there is any optimistic value in the input TensorValue or a list of TensorValue's.""" if isinstance(x, ListValue): return any(has_optimistic_value_or_list(y) for y in x.values) elif isinstance(x, TensorValue): return x.has_optimistic_value() else: raise ValueError(f'Unknown value type {type(x)}')
[docs]def is_single_optimistic_value_or_list(x: Union[ListValue, TensorValue]) -> bool: """Check if the input TensorValue is a single optimistic value or a list of TensorValue's that are all single optimistic values.""" if isinstance(x, ListValue): return all(is_single_optimistic_value_or_list(y) for y in x.values) elif isinstance(x, TensorValue): return x.is_single_optimistic_value() else: raise ValueError(f'Unknown value type {type(x)}')
[docs]def cvt_single_optimistic_value_or_list(x: Union[ListValue, TensorValue]) -> Union[ListValue, OptimisticValue]: """Convert a single optimistic value stored in a TensorValue to an OptimisticValue. If the input is a list of TensorValue's, convert them to a list of OptimisticValue's.""" if isinstance(x, ListValue): return ListValue(x.dtype, [cvt_single_optimistic_value_or_list(y) for y in x.values]) elif isinstance(x, TensorValue): return x.single_elem() else: raise ValueError(f'Unknown value type {type(x)}')
[docs]def has_optimistic_constant_expression(*expressions: Union[ValueOutputExpression, RegressionRuleApplier]): """Check if there is a ConstantExpression whose value is an optimistic constant. Useful when checking if the subgoal is fully "grounded." """ for expression in expressions: if isinstance(expression, RegressionRuleApplier): expression = expression.goal_expression for e in iter_exprs(expression): if isinstance(e, ConstantExpression) and has_optimistic_value_or_list(e.constant): return True return False
[docs]def make_rule_applier(rule: RegressionRule, bounded_variables: Dict[str, ValueOutputExpression]) -> RegressionRuleApplier: """Make a rule applier from a regression rule and a set of bounded variables.""" canonized_bounded_variables = dict() for k, v in bounded_variables.items(): if isinstance(k, Variable): k = k.name if isinstance(v, ObjectConstant): v = v.name canonized_bounded_variables[k] = v arguments = [canonized_bounded_variables[x.name] for x in rule.arguments] return RegressionRuleApplier(rule, arguments)
class _ReplaceCSPVariableVisitor(IdentityExpressionVisitor): def __init__(self, csp: ConstraintSatisfactionProblem, previous_csp: ConstraintSatisfactionProblem, csp_variable_mapping: Dict[int, Any], reg_variable_mapping: Optional[Dict[str, ObjectConstant]]): self.csp = csp self.previous_csp = previous_csp self.csp_variable_mapping = csp_variable_mapping self.reg_variable_mapping = reg_variable_mapping if reg_variable_mapping is not None else dict() def _replace_opt_value(self, value: Any): if isinstance(value, ListValue): return ListValue(value.dtype, [self._replace_opt_value(x) for x in value.values]) elif isinstance(value, TensorValue): if value.is_single_optimistic_value(): identifier = value.single_elem().identifier if identifier in self.csp_variable_mapping: return self.csp_variable_mapping[identifier] else: self.csp_variable_mapping[identifier] = TensorValue.from_optimistic_value(self.csp.new_actionable_var(value.dtype, wrap=True)) return self.csp_variable_mapping[identifier] else: return value else: raise ValueError(f'Unknown value type {type(value)}') def visit_constant_expression(self, expr: ConstantExpression) -> ConstantExpression: return ConstantExpression(self._replace_opt_value(expr.constant)) def visit_variable_expression(self, expr: VariableExpression) -> Union[VariableExpression, ObjectConstantExpression, ConstantExpression]: if expr.variable.name in self.reg_variable_mapping: value = self.reg_variable_mapping[expr.variable.name] if isinstance(value, ObjectConstant): return ObjectConstantExpression(value) elif isinstance(value, TensorValue): return ConstantExpression(value) else: raise ValueError(f'Unknown value type {type(value)}') return expr # subgoal, new_csp_variable_mapping = _map_csp_placeholder_goal(item.goal, new_csp, placeholder_csp, placeholder_bounded_variables, cur_bounded_variables, csp_variable_mapping)
[docs]def map_csp_placeholder_goal( subgoal: ValueOutputExpression, csp: ConstraintSatisfactionProblem, placeholder_csp: ConstraintSatisfactionProblem, cur_csp_variable_mapping: Dict[int, TensorValue], cur_reg_variable_mapping: Optional[Dict[str, Any]] = None ) -> Tuple[ValueOutputExpression, Dict[int, TensorValue]]: """Map the CSP variables in the subgoal to the CSP variables in the placeholder CSP.""" new_mapping = cur_csp_variable_mapping.copy() visitor = _ReplaceCSPVariableVisitor(csp, placeholder_csp, new_mapping, cur_reg_variable_mapping) new_subgoal = visitor.visit(subgoal) return new_subgoal, new_mapping
[docs]def map_csp_placeholder_action( action: OperatorApplier, csp: ConstraintSatisfactionProblem, placeholder_csp: ConstraintSatisfactionProblem, cur_csp_variable_mapping: Dict[int, TensorValue], cur_reg_variable_mapping: Optional[Dict[str, Any]] = None, ) -> Tuple[OperatorApplier, Dict[int, TensorValue]]: """Map the CSP variables in the action to the CSP variables in the placeholder CSP.""" new_mapping = cur_csp_variable_mapping.copy() new_arguments = list() for value in action.arguments: if isinstance(value, Variable): if cur_reg_variable_mapping is not None and value.name in cur_reg_variable_mapping: new_arguments.append(cur_reg_variable_mapping[value.name].name) else: raise KeyError(f'Unknown variable {value.name}') elif isinstance(value, TensorValue): if value.is_single_optimistic_value(): identifier = value.single_elem().identifier if identifier in new_mapping: new_arguments.append(new_mapping[identifier]) else: new_mapping[identifier] = TensorValue.from_optimistic_value(csp.new_actionable_var(value.dtype, wrap=True)) new_arguments.append(new_mapping[identifier]) else: new_arguments.append(value) else: new_arguments.append(value) new_action = OperatorApplier(action.operator, new_arguments, regression_rule=action.regression_rule) return new_action, new_mapping
[docs]def map_csp_placeholder_regression_rule_applier( rule: RegressionRuleApplier, csp: ConstraintSatisfactionProblem, placeholder_csp: ConstraintSatisfactionProblem, cur_csp_variable_mapping: Dict[int, TensorValue], cur_reg_variable_mapping: Optional[Dict[str, Any]] = None ) -> Tuple[RegressionRuleApplier, Dict[int, TensorValue]]: """Map the CSP variables in the regression rule applier to the CSP variables in the placeholder CSP.""" new_mapping = cur_csp_variable_mapping.copy() new_arguments = list() for value in rule.arguments: if isinstance(value, Variable): if cur_reg_variable_mapping is not None and value.name in cur_reg_variable_mapping: new_arguments.append(cur_reg_variable_mapping[value.name].name) else: raise KeyError(f'Unknown variable {value.name}') elif isinstance(value, TensorValue): if value.is_single_optimistic_value(): identifier = value.single_elem().identifier if identifier in new_mapping: new_arguments.append(new_mapping[identifier]) else: new_mapping[identifier] = TensorValue.from_optimistic_value(csp.new_actionable_var(value.dtype, wrap=True)) new_arguments.append(new_mapping[identifier]) else: new_arguments.append(value) else: new_arguments.append(value) new_rule = RegressionRuleApplier(rule.regression_rule, new_arguments) return new_rule, new_mapping
[docs]def map_csp_variable_mapping( csp_variable_mapping: Dict[int, TensorValue], csp: ConstraintSatisfactionProblem, assignments: AssignmentDict ) -> Dict[int, TensorValue]: """Map the CSP variable mapping to the new variable mapping.""" new_mapping = dict() for identifier, value in csp_variable_mapping.items(): if isinstance(value, TensorValue): if value.is_single_optimistic_value(): new_identifier = value.single_elem().identifier if new_identifier in assignments: new_value = csp.ground_assignment_value_partial(assignments, new_identifier) if isinstance(new_value, OptimisticValue): new_mapping[identifier] = TensorValue.from_optimistic_value(new_value) elif isinstance(new_value, TensorValue): new_mapping[identifier] = new_value else: raise TypeError(f'Unknown value type {type(new_value)}') else: new_mapping[identifier] = value else: raise TypeError(f'Unknown value type {type(value)}') return new_mapping
[docs]def map_csp_variable_state( state: State, csp: ConstraintSatisfactionProblem, assignments: AssignmentDict ) -> State: """Map the CSP variable state to the new variable state.""" new_state = state.clone() for feature_name, tensor_value in new_state.features.items(): if tensor_value.tensor_optimistic_values is None: continue for ind in torch.nonzero(tensor_value.tensor_optimistic_values).tolist(): ind = tuple(ind) identifier = tensor_value.tensor_optimistic_values[ind].item() if identifier in assignments: new_value = csp.ground_assignment_value_partial(assignments, identifier) if isinstance(new_value, OptimisticValue): tensor_value.tensor_optimistic_values[ind] = new_value.identifier elif isinstance(new_value, TensorValue): tensor_value.tensor[ind] = new_value.tensor tensor_value.tensor_optimistic_values[ind] = 0 else: raise TypeError(f'Unknown value type {type(new_value)}') return new_state
GroundedSubgoalItem = Union[AchieveExpression, FindExpression, OperatorApplier, RegressionRuleApplier, RegressionCommitFlag] def _ground_maintains_expressions(maintains: Tuple[ValueOutputExpression, ...], bounded_variables): return tuple(ground_fol_expression(e, bounded_variables) for e in maintains)
[docs]def gen_grounded_subgoals_with_placeholders( executor: PDSketchExecutor, state: State, goal: ValueOutputExpression, constraints: Sequence[ValueOutputExpression], candidate_regression_rules: List[ApplicableRegressionRuleItem], enable_csp: bool ) -> Dict[int, Tuple[List[GroundedSubgoalItem], Optional[ConstraintSatisfactionProblem], int]]: """Generated a set of subgoals with placeholders for CSP variables. Args: executor: the executor. state: the current state. goal: the goal expression. constraints: the constraints. candidate_regression_rules: the candidate regression rules. enable_csp: whether to enable the constraint satisfaction problem. Returns: the grounded subgoals. It is a dictionary mapping from the index of the regression rule to a tuple: - the grounded subgoals (which can be AchieveExpression, FindExpression, OperatorApplier, RegressionRuleApplier, or RegressionCommitFlag) - the constraint satisfaction problem (for tracking placeholder variables). - the length of the prefix that can be reordered. """ grounded_subgoals_cache = dict() for regression_rule_index, (rule, bounded_variables) in enumerate(candidate_regression_rules): grounded_subgoals = list() placeholder_csp = ConstraintSatisfactionProblem() if enable_csp else None placeholder_bounded_variables = bounded_variables.copy() rule_applier = make_rule_applier(rule, placeholder_bounded_variables) for i, item in enumerate(rule.body): if isinstance(item, AchieveExpression): grounded_subgoals.append(AchieveExpression( ground_fol_expression(item.goal, placeholder_bounded_variables), maintains=_ground_maintains_expressions(item.maintains, placeholder_bounded_variables), serializability=item.serializability, csp_serializability=item.csp_serializability )) elif isinstance(item, FindExpression): if not enable_csp: raise ValueError('FindExpression must be used with a constraint satisfaction problem.') for variable in item.variables: if isinstance(variable.dtype, ObjectType): placeholder_bounded_variables[variable] = variable else: placeholder_bounded_variables[variable] = create_find_expression_csp_variable(variable, csp=placeholder_csp, bounded_variables=placeholder_bounded_variables) grounded_subgoals.append(FindExpression(item.variables, ground_fol_expression(item.goal, placeholder_bounded_variables), serializability=item.serializability, csp_serializability=item.csp_serializability, ordered=item.ordered)) elif isinstance(item, OperatorApplicationExpression): cur_action = ground_operator_application_expression(item, placeholder_bounded_variables, csp=placeholder_csp, rule_applier=rule_applier) grounded_subgoals.append(cur_action) elif isinstance(item, RegressionRuleApplicationExpression): cur_action = ground_regression_application_expression(item, placeholder_bounded_variables, csp=placeholder_csp) grounded_subgoals.append(cur_action) elif isinstance(item, ListExpansionExpression): if is_and_expr(item.expression) and len(item.expression.arguments) == 1 and item.expression.arguments[0].return_type.is_list_type: # handles ... (and p({x, y, z}, ...)) subgoals = executor.execute(item.expression.arguments[0], state, placeholder_bounded_variables, sgc=PDSketchSGC(state, goal, constraints)) grounded_subgoals.extend(subgoals.values) else: subgoals = executor.execute(item.expression, state, placeholder_bounded_variables, sgc=PDSketchSGC(state, goal, constraints)) assert isinstance(subgoals, TotallyOrderedPlan), f'ListExpansionExpression must be used with a TotallyOrderedPlan, got {type(subgoals)}' grounded_subgoals.extend(subgoals.sequence) elif isinstance(item, RuntimeAssignExpression): placeholder_bounded_variables[item.variable] = item.variable grounded_subgoals.append(RuntimeAssignExpression(item.variable, ground_fol_expression(item.value, placeholder_bounded_variables))) elif isinstance(item, RegressionCommitFlag): grounded_subgoals.append(item) else: raise ValueError(f'Unknown item type {type(item)} in rule {item}.') # pass the serializability information to the previous subgoal. max_reorder_prefix_length = 0 for i, item in enumerate(grounded_subgoals): if isinstance(item, RegressionCommitFlag): if i > 0 and isinstance(grounded_subgoals[i - 1], (AchieveExpression, FindExpression)): grounded_subgoals[i - 1].serializability = item.goal_serializability if isinstance(item, (AchieveExpression, FindExpression)): if item.sequential_decomposable is False: max_reorder_prefix_length = i + 1 grounded_subgoals_cache[regression_rule_index] = (grounded_subgoals, placeholder_csp, max_reorder_prefix_length) return grounded_subgoals_cache