# Source code for concepts.pdsketch.crow.crow_planner

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : crow_planner.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 11/09/2023
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

from typing import Any, Optional, Union, Iterable, Iterator, Tuple, List, Dict
from dataclasses import dataclass

import jacinle
import torch

from concepts.dsl.dsl_types import UnnamedPlaceholder, Variable
from concepts.dsl.expression import VariableExpression, ConstantExpression, ValueOutputExpression, ListExpansionExpression, is_and_expr, iter_exprs
from concepts.dsl.expression_visitor import IdentityExpressionVisitor
from concepts.dsl.constraint import OptimisticValue, GroupConstraint, ConstraintSatisfactionProblem, AssignmentDict
from concepts.dsl.value import ListValue
from concepts.dsl.tensor_value import TensorValue
from concepts.pdsketch.operator import OperatorApplicationExpression, OperatorApplier
from concepts.pdsketch.executor import PDSketchExecutor, PDSketchSGC
from concepts.pdsketch.domain import State
from concepts.pdsketch.regression_rule import CSPCommitFlag, AchieveExpression, FindExpression, SubgoalSerializability
from concepts.pdsketch.planners.optimistic_search_bilevel_utils import OptimisticSearchSymbolicPlan
from concepts.pdsketch.planners.optimistic_search import ground_actions
from concepts.pdsketch.csp_solvers.dpll_sampling import csp_dpll_sampling_solve
from concepts.pdsketch.crow.regression_utils import surface_fol_downcast, ground_fol_expression
from concepts.pdsketch.crow.regression_utils import evaluate_bool_scalar_expression, ground_operator_application_expression, gen_applicable_regression_rules, ApplicableRegressionRuleGroup
from concepts.pdsketch.crow.crow_state import PartiallyOrderedPlan, TotallyOrderedPlan


@dataclass
class SearchState:
    """Placeholder record for a search state; it currently declares no fields."""
def crow_recursive(
    executor: PDSketchExecutor, state: State, goal_expr: Union[str, ValueOutputExpression], *,
    is_goal_serialized: bool = True,
    enable_reordering: bool = False,
    enable_csp: bool = True,
    max_actions: int = 10,
    max_csp_branching_factor: int = 5,
    max_beam_size: int = 20,
    allow_empty_plan_for_optimistic_goal: bool = False,
    verbose: bool = True
) -> Tuple[Iterable[Any], Dict[str, Any]]:
    """Compositional Regression and Optimization Wayfinder.

    Args:
        executor: the executor.
        state: the initial state.
        goal_expr: the goal expression. A string is parsed with ``executor.parse``.
        is_goal_serialized: whether the goal is serialized already. Otherwise, it will be treated as a conjunction.
        enable_reordering: whether to enable reordering of subgoals in regression rules. Note that non-zero
            reordering prefixes currently raise ``NotImplementedError``.
        enable_csp: whether to maintain a constraint satisfaction problem for optimistic values. When False,
            the search runs with ``csp=None``.
        max_actions: the maximum number of nested high-level (regression-rule) expansions along one search path.
        max_csp_branching_factor: currently unused by this function.
        max_beam_size: currently unused by this function.
        allow_empty_plan_for_optimistic_goal: whether to also goal-test (and possibly accept the empty plan for)
            goals that contain optimistic constants. Goals without optimistic constants are always tested directly.
        verbose: whether to print verbose information.

    Returns:
        A tuple ``(plans, search_stat)``. ``plans`` is a list of action sequences sorted by length (shortest first);
        ``search_stat`` is a dictionary with the key ``'nr_expanded_nodes'``.
    """
    if isinstance(goal_expr, str):
        goal_expr = executor.parse(goal_expr)

    search_cache = dict()
    search_stat = {'nr_expanded_nodes': 0}

    # NB(Jiayuan Mao @ 2023/09/09): the cache only works for previous_actions == [].
    # That is, we only cache the search results that start from the initial state.

    def return_with_cache(goal_set, previous_actions, rv):
        """Cache ``rv`` for ``goal_set`` (first write wins) when searching from the initial state; return ``rv``."""
        if len(previous_actions) == 0:
            goal_str = goal_set.gen_string()
            if goal_str not in search_cache:
                search_cache[goal_str] = rv
        return rv

    def try_retrieve_cache(goal_set, previous_actions):
        """Return the cached result for ``goal_set`` when searching from the initial state, otherwise None."""
        if len(previous_actions) == 0:
            goal_str = goal_set.gen_string()
            if goal_str in search_cache:
                return search_cache[goal_str]
        return None

    @jacinle.log_function(verbose=False)
    def dfs(
        s: State, g: PartiallyOrderedPlan, c: Tuple[ValueOutputExpression, ...],
        csp: Optional[ConstraintSatisfactionProblem], previous_actions: List[OperatorApplier],
        return_all: bool = False, tail_csp_solve: bool = False, nr_high_level_actions: int = 0
    ) -> Iterator[Tuple[State, ConstraintSatisfactionProblem, List[OperatorApplier]]]:
        """Depth-first search for all possible plans.

        Args:
            s: the current state.
            g: the current goal.
            c: the list of constraints to maintain.
            csp: the current constraint satisfaction problem (None when the CSP is disabled).
            previous_actions: the previous actions.
            return_all: whether to return all possible plans. If False, only return the first plan found.
            tail_csp_solve: whether to solve the CSP after the expansion of the current goal.
            nr_high_level_actions: the number of high-level expansions on the current search path.

        Returns:
            a list of plans, sorted by the number of actions. Each plan is a tuple of (final_state, csp, actions).
        """
        if verbose:
            jacinle.log_function.print('Current goal', g, f'return_all={return_all}', f'previous_actions={previous_actions}')

        # With a CSP, every non-top-level search enumerates all plans.
        # NOTE(review): presumably because a branch may later become infeasible once the CSP is solved.
        if enable_csp and not tail_csp_solve:
            return_all = True

        # Ignore the STRONG and ORDER serializability flags.
        if (rv := try_retrieve_cache(g, previous_actions)) is not None:
            return rv

        all_possible_plans = list()

        flatten_goals = list(g.iter_goals())
        # If the current goal contains no optimistic constant (or the caller explicitly allows it), directly test it.
        if allow_empty_plan_for_optimistic_goal or not _has_optimistic_constant_expression(*flatten_goals):
            rv, is_optimistic, new_csp = evaluate_bool_scalar_expression(executor, flatten_goals, s, dict(), csp, csp_note='goal_test')
            if rv:
                all_possible_plans.append((s, new_csp, previous_actions))
                if not is_optimistic:
                    # If there is no optimistic value, we can stop the search from here.
                    # NB(Jiayuan Mao @ 2023/09/11): note that even if `return_all` is True, we still return here.
                    # This corresponds to an early stopping behavior that defines the space of all possible plans.
                    return return_with_cache(g, previous_actions, all_possible_plans)

        if nr_high_level_actions > max_actions:
            return return_with_cache(g, previous_actions, all_possible_plans)

        search_stat['nr_expanded_nodes'] += 1

        candidate_regression_rules = gen_applicable_regression_rules(executor, s, g, c)
        if _len_candidate_regression_rules(candidate_regression_rules) == 0:
            return return_with_cache(g, previous_actions, all_possible_plans)

        some_rule_success = False
        for chain_index, subgoal_index, this_candidate_regression_rules in candidate_regression_rules:
            other_goals = g.exclude(chain_index, subgoal_index)
            cur_goal = g.chains[chain_index].sequence[subgoal_index]
            if verbose:
                jacinle.log_function.print('Now trying to excluding goal', cur_goal)

            if len(other_goals) == 0:
                other_goals_plans = [(s, csp, previous_actions)]
            else:
                # TODO(Jiayuan Mao @ 2023/09/09): change this list to an actual generator call.
                need_return_all = any(rule.max_rule_prefix_length > 0 for rule, _ in this_candidate_regression_rules) or enable_csp
                other_goals_plans = list(dfs(s, other_goals, c, csp, previous_actions, nr_high_level_actions=nr_high_level_actions, return_all=need_return_all))

            # Shortcut: the current goal may already hold after achieving the other goals.
            for cur_state, cur_csp, cur_actions in other_goals_plans:
                rv, is_optimistic, new_csp = evaluate_bool_scalar_expression(executor, [cur_goal], cur_state, dict(), cur_csp, csp_note='goal_test_shortcut')
                if rv:
                    all_possible_plans.append((cur_state, new_csp, cur_actions))
                    if not is_optimistic:
                        # NB(Jiayuan Mao @ 2023/09/11): another place where we stop the search and ignores the `return_all` flag.
                        continue

            if len(this_candidate_regression_rules) == 0:
                continue
            if len(other_goals_plans) == 0:
                continue

            if len(other_goals) == 0 or not enable_reordering:
                max_prefix_length = 0
            else:
                # TODO(Jiayuan Mao @ 2023/11/27): set this number to a very large number, and then use a flag to control the applicability of the reorderings.
                max_prefix_length = max(rule.max_reorder_prefix_length for rule, _ in this_candidate_regression_rules)

            # Maps a rule index to True once the rule should not be tried again (plan found, or prefix failed).
            prefix_stop_mark = dict()
            for prefix_length in range(max_prefix_length + 1):
                for regression_rule_index, (rule, bounded_variables) in enumerate(this_candidate_regression_rules):
                    if prefix_length > rule.max_reorder_prefix_length:
                        continue
                    if prefix_stop_mark.get(regression_rule_index, False):
                        continue

                    if verbose:
                        jacinle.log_function.print('Applying rule', rule, 'for', cur_goal, 'and prefix length', prefix_length, 'goal is', g)

                    if prefix_length == 0:
                        # TODO(Jiayuan Mao @ 2023/11/19): there is a bug for max_rule_prefix_length when there is a list expansion.
                        if rule.max_rule_prefix_length > 0:
                            previous_possible_branches = other_goals_plans
                        else:
                            previous_possible_branches = [other_goals_plans[0]]
                    else:
                        # TODO(Jiayuan Mao @ 2023/11/24): implement this by planning for
                        # `other_goals.add_chain(grounded_subgoals[:prefix_length])` first.
                        raise NotImplementedError('Reordering is not implemented yet.')

                    if len(previous_possible_branches) == 0:
                        if verbose:
                            jacinle.log_function.print('Prefix planning failed!!! Stop.')
                        # If it's not possible to achieve the subset of goals, then it's not possible to achieve the
                        # whole goal. Therefore, this rule is marked as stopped for all longer prefixes as well.
                        prefix_stop_mark[regression_rule_index] = True
                        continue

                    for prev_state, prev_csp, prev_actions in previous_possible_branches:
                        # Construct the placeholder csp and the sequence of grounded subgoals.
                        grounded_subgoals = list()
                        placeholder_csp = ConstraintSatisfactionProblem()
                        placeholder_bounded_variables = bounded_variables.copy()
                        for item in rule.body:
                            if isinstance(item, AchieveExpression):
                                grounded_subgoals.append(AchieveExpression(ground_fol_expression(item.goal, placeholder_bounded_variables), maintains=[]))
                            elif isinstance(item, FindExpression):
                                for variable in item.variables:
                                    placeholder_bounded_variables[variable] = _create_find_expression_variable(variable, csp=placeholder_csp, bounded_variables=placeholder_bounded_variables)
                                grounded_subgoals.append(FindExpression([], ground_fol_expression(item.goal, placeholder_bounded_variables)))
                            elif isinstance(item, OperatorApplicationExpression):
                                cur_action = ground_operator_application_expression(item, placeholder_bounded_variables, csp=placeholder_csp, add_csp_variables=False)
                                grounded_subgoals.append(cur_action)
                            elif isinstance(item, ListExpansionExpression):
                                subgoals = executor.execute(item.expression, s, placeholder_bounded_variables, sgc=PDSketchSGC(s, g, c))
                                grounded_subgoals.extend(subgoals.sequence)
                            elif isinstance(item, CSPCommitFlag):
                                grounded_subgoals.append(item)
                            else:
                                raise ValueError(f'Unknown item type {type(item)} in rule {item}.')

                        # Each branch: (state, csp, actions, mapping from placeholder-CSP identifiers to real CSP values).
                        possible_branches = [(prev_state, prev_csp, prev_actions, {})]
                        for i in range(prefix_length, len(grounded_subgoals)):
                            item = grounded_subgoals[i]
                            next_possible_branches = list()

                            if isinstance(item, AchieveExpression):
                                if not enable_csp and item.serializability is SubgoalSerializability.STRONG and len(possible_branches) > 1:
                                    possible_branches = [min(possible_branches, key=lambda x: len(x[2]))]

                            need_return_all = enable_csp or i < rule.max_rule_prefix_length
                            for cur_state, cur_csp, cur_actions, cur_csp_variable_mapping in possible_branches:
                                if isinstance(item, AchieveExpression):
                                    new_csp = cur_csp.clone() if cur_csp is not None else None
                                    subgoal, new_csp_variable_mapping = _map_csp_placeholder_goal(item.goal, new_csp, placeholder_csp, placeholder_bounded_variables, cur_csp_variable_mapping)
                                    next_possible_branches.extend([(*x, new_csp_variable_mapping) for x in dfs(
                                        cur_state, PartiallyOrderedPlan.from_single_goal(subgoal), c + item.maintains, new_csp, cur_actions,
                                        return_all=need_return_all, nr_high_level_actions=nr_high_level_actions + 1
                                    )])
                                elif isinstance(item, FindExpression):
                                    if cur_csp is None:
                                        raise RuntimeError('FindExpression must be used with a CSP.')
                                    new_csp = cur_csp.clone()
                                    subgoal, new_csp_variable_mapping = _map_csp_placeholder_goal(item.goal, new_csp, placeholder_csp, placeholder_bounded_variables, cur_csp_variable_mapping)
                                    with new_csp.with_group(subgoal) as group:
                                        rv = executor.execute(subgoal, cur_state, {}, csp=new_csp).item()
                                        if isinstance(rv, OptimisticValue):
                                            new_csp.add_equal_constraint(rv)
                                            # NOTE(review): this evaluates generator contexts at the initial `state`,
                                            # not at `cur_state`; confirm that this is intended.
                                            _mark_solver(executor, state, bounded_variables, group)
                                    next_possible_branches.append((cur_state, new_csp, cur_actions, new_csp_variable_mapping))
                                elif isinstance(item, OperatorApplier):
                                    # TODO(Jiayuan Mao @ 2023/09/11): vectorize this operation, probably only useful when `return_all` is True.
                                    new_csp = cur_csp.clone() if cur_csp is not None else None
                                    subaction, new_csp_variable_mapping = _map_csp_placeholder_action(item, new_csp, placeholder_csp, placeholder_bounded_variables, cur_csp_variable_mapping)
                                    succ, new_state = executor.apply(subaction, cur_state, csp=new_csp, clone=True, action_index=len(cur_actions))
                                    if succ:
                                        next_possible_branches.append((new_state, new_csp, cur_actions + [subaction], new_csp_variable_mapping))
                                elif isinstance(item, CSPCommitFlag):
                                    # Solve the CSP accumulated so far and commit to the solution: ground the state, the
                                    # actions, and the placeholder mapping, then continue with an empty CSP.
                                    assignments = csp_dpll_sampling_solve(executor, cur_csp)
                                    if assignments is not None:
                                        new_state = _map_csp_variable_state(cur_state, cur_csp, assignments)
                                        new_csp = ConstraintSatisfactionProblem()
                                        new_actions = ground_actions(executor, cur_actions, assignments)
                                        # The assignments were solved from `cur_csp`, so the mapping is grounded against the same CSP.
                                        new_csp_variable_mapping = _map_csp_variable_mapping(cur_csp_variable_mapping, cur_csp, assignments)
                                        next_possible_branches.append((new_state, new_csp, new_actions, new_csp_variable_mapping))
                                        # TODO(Jiayuan Mao @ 2023/11/27): okay we need to implement some kind of tracking of "bounded_variables."
                                        # This need to be done by tracking some kind of mapping for optimistic variables in "grounded_subgoals."
                                else:
                                    raise TypeError(f'Unknown item: {item}')
                            possible_branches = next_possible_branches

                        found_plan = False
                        # TODO(Jiayuan Mao @ 2023/09/11): implement this via maintains checking.
                        for cur_state, cur_csp, actions, _ in possible_branches:
                            rv, is_optimistic, new_csp = evaluate_bool_scalar_expression(executor, flatten_goals, cur_state, dict(), csp=cur_csp, csp_note=f'subgoal_test: {"; ".join([str(x) for x in flatten_goals])}')
                            if rv:
                                if verbose:
                                    jacinle.log_function.print('Found a plan', [str(x) for x in actions], 'for goal', g)
                                if is_optimistic and tail_csp_solve:
                                    assignments = csp_dpll_sampling_solve(executor, new_csp, verbose=True)
                                    if assignments is not None:
                                        # Keep the (state, csp, actions) layout shared by every other entry.
                                        all_possible_plans.append((cur_state, new_csp, ground_actions(executor, actions, assignments)))
                                        found_plan = True
                                else:
                                    all_possible_plans.append((cur_state, new_csp, actions))
                                    found_plan = True

                        if found_plan:
                            prefix_stop_mark[regression_rule_index] = True
                            some_rule_success = True
                            # TODO(Jiayuan Mao @ 2023/09/06): since we have changed the order of prefix_length for-loop and the regression rule for-loop.
                            # We need to use an additional dictionary to store whether we have found a plan for a particular regression rule.
                            # Right now this doesn't matter because we only use the first plan.

                        if not return_all and some_rule_success:
                            break  # Break `for prev_state, ... in previous_possible_branches`.
                    if not return_all and some_rule_success:
                        break  # Break `for regression_rule_index, ... in enumerate(this_candidate_regression_rules)`.
                if not return_all and some_rule_success:
                    break  # Break `for prefix_length in range(max_prefix_length + 1)`.
            if not return_all and some_rule_success:
                break  # Break `for chain_index, ... in candidate_regression_rules`.

        if len(all_possible_plans) == 0:
            if verbose:
                jacinle.log_function.print('No possible plans for goal', g)
            return return_with_cache(g, previous_actions, [])

        # TODO(Jiayuan Mao @ 2023/11/19): add plan de-duplication back (`_unique_plans`).
        sorted_plans = sorted(all_possible_plans, key=lambda x: len(x[2]))
        return return_with_cache(g, previous_actions, sorted_plans)

    if is_and_expr(goal_expr):
        if len(goal_expr.arguments) == 1 and goal_expr.arguments[0].return_type.is_list_type:
            goal_set = [goal_expr]
        else:
            goal_set = list(goal_expr.arguments)
    else:
        goal_set = [goal_expr]

    goal_set = PartiallyOrderedPlan((TotallyOrderedPlan(goal_set, is_ordered=is_goal_serialized),))
    candidate_plans = dfs(state, goal_set, tuple(), csp=ConstraintSatisfactionProblem() if enable_csp else None, previous_actions=list(), tail_csp_solve=True)
    candidate_plans = [actions for _, _, actions in candidate_plans]
    return candidate_plans, search_stat
def _infer_list_length(variable: Variable, bounded_variables: Dict[Variable, Any]) -> int:
    """Infer the length of a list-typed `FindExpression` variable from the already bounded list values.

    The length of the last ``ListValue`` encountered while iterating ``bounded_variables`` is used.

    Raises:
        ValueError: if no ``ListValue`` is bound.
    """
    length = -1
    for value in bounded_variables.values():
        if isinstance(value, ListValue):
            length = len(value)
    if length == -1:
        raise ValueError(f'Cannot create a list variable {variable} without specifying the length.')
    return length


def _create_find_expression_variable(variable: Variable, csp: ConstraintSatisfactionProblem, bounded_variables: Dict[Variable, Any]):
    """Create a TensorValue that corresponds to a variable inside a `FindExpression`.

    Args:
        variable: the variable in the FindExpression.
        csp: the current CSP; new actionable variables are allocated in it.
        bounded_variables: the already bounded variables (used to infer the length of list-typed variables).

    Returns:
        a ``ListValue`` of optimistic TensorValues for list-typed variables, otherwise a single optimistic TensorValue.
    """
    if variable.dtype.is_list_type:
        length = _infer_list_length(variable, bounded_variables)
        return ListValue(variable.dtype, [
            TensorValue.from_optimistic_value(csp.new_actionable_var(variable.dtype.element_type, wrap=True))
            for _ in range(length)
        ])
    return TensorValue.from_optimistic_value(csp.new_actionable_var(variable.dtype, wrap=True))


def _create_find_expression_variable_placeholder(variable: Variable, bounded_variables: Dict[Variable, Any]):
    """Create a placeholder that corresponds to a variable inside a `FindExpression`.

    Unlike `_create_find_expression_variable`, this function only creates `UnnamedPlaceholder` objects and does not
    allocate CSP variables.

    Args:
        variable: the variable in the FindExpression.
        bounded_variables: the already bounded variables (used to infer the length of list-typed variables).
    """
    if variable.dtype.is_list_type:
        length = _infer_list_length(variable, bounded_variables)
        return ListValue(variable.dtype, [UnnamedPlaceholder(variable.dtype) for _ in range(length)])
    return UnnamedPlaceholder(variable.dtype)


def _mark_solver(executor: PDSketchExecutor, state: State, bounded_variables: Dict[Variable, Any], group: GroupConstraint):
    """Attach to ``group`` every domain generator whose ``certifies`` expression matches the group's expression.

    A generator is attached only when all of its context expressions evaluate to non-optimistic values in ``state``,
    and every generated variable is matched to a (list of) single optimistic value(s).

    Args:
        executor: the executor.
        state: the state in which the generator contexts are evaluated.
        bounded_variables: the variable bindings used when evaluating the generator contexts.
        group: the current group constraint; ``group.candidate_generators`` is extended in place.
    """
    for generator in executor.domain.generators.values():
        if (matching := surface_fol_downcast(generator.certifies, group.expression)) is None:
            continue

        matching_success = True
        inputs = list()
        for var in generator.context:
            value = executor.execute(var, state, bounded_variables, optimistic_execution=True)
            if _has_optimistic_value_or_list(value):
                matching_success = False
                break
            inputs.append(value)
        if not matching_success:
            continue

        outputs = list()
        for var in generator.generates:
            if not (isinstance(var, VariableExpression) and var.name in matching and _is_single_optimistic_value_or_list(matching[var.name])):
                matching_success = False
                break
            outputs.append(_cvt_single_optimistic_value_or_list(matching[var.name]))

        if matching_success:
            group.candidate_generators.append((generator, inputs, outputs))


def _has_optimistic_value_or_list(x: Union[ListValue, TensorValue]) -> bool:
    """Return True if ``x`` (or any element of a list value, recursively) contains an optimistic value."""
    if isinstance(x, ListValue):
        return any(_has_optimistic_value_or_list(y) for y in x.values)
    elif isinstance(x, TensorValue):
        return x.has_optimistic_value()
    else:
        raise ValueError(f'Unknown value type {type(x)}')


def _is_single_optimistic_value_or_list(x: Union[ListValue, TensorValue]) -> bool:
    """Return True if ``x`` is a single optimistic value, or a list value whose elements all are (recursively)."""
    if isinstance(x, ListValue):
        return all(_is_single_optimistic_value_or_list(y) for y in x.values)
    elif isinstance(x, TensorValue):
        return x.is_single_optimistic_value()
    else:
        raise ValueError(f'Unknown value type {type(x)}')


def _cvt_single_optimistic_value_or_list(x: Union[ListValue, TensorValue]) -> Union[ListValue, OptimisticValue]:
    """Convert a single-optimistic TensorValue to its element (``x.single_elem()``), recursing into list values."""
    if isinstance(x, ListValue):
        return ListValue(x.dtype, [_cvt_single_optimistic_value_or_list(y) for y in x.values])
    elif isinstance(x, TensorValue):
        return x.single_elem()
    else:
        raise ValueError(f'Unknown value type {type(x)}')


def _has_optimistic_constant_expression(*expressions: ValueOutputExpression) -> bool:
    """Check if there is a ConstantExpression whose value is an optimistic constant.

    Useful when checking if the subgoal is fully "grounded."
    """
    for expression in expressions:
        for e in iter_exprs(expression):
            if isinstance(e, ConstantExpression) and _has_optimistic_value_or_list(e.constant):
                return True
    return False


def _len_candidate_regression_rules(candidate_regression_rules: List[ApplicableRegressionRuleGroup]) -> int:
    """Compute the total number of candidate regression rules across all groups."""
    return sum(len(x.regression_rules) for x in candidate_regression_rules)


class _ReplaceCSPVariableVisitor(IdentityExpressionVisitor):
    """Rewrite constant expressions, replacing placeholder-CSP optimistic values with variables of ``csp``.

    ``variable_mapping`` (mutated in place) maps placeholder identifiers to the replacement TensorValues, so repeated
    occurrences of the same placeholder map to the same new variable.
    """

    def __init__(self, csp: ConstraintSatisfactionProblem, previous_csp: ConstraintSatisfactionProblem, variable_mapping: Dict[int, Any]):
        self.csp = csp
        self.previous_csp = previous_csp
        self.variable_mapping = variable_mapping

    def _replace_opt_value(self, value: Any):
        """Return ``value`` with every single optimistic value mapped to (possibly newly created) ``csp`` variables."""
        if isinstance(value, ListValue):
            return ListValue(value.dtype, [self._replace_opt_value(x) for x in value.values])
        elif isinstance(value, TensorValue):
            if value.is_single_optimistic_value():
                identifier = value.single_elem().identifier
                if identifier not in self.variable_mapping:
                    self.variable_mapping[identifier] = TensorValue.from_optimistic_value(self.csp.new_actionable_var(value.dtype, wrap=True))
                return self.variable_mapping[identifier]
            else:
                return value
        else:
            raise ValueError(f'Unknown value type {type(value)}')

    def visit_constant_expression(self, expr: ConstantExpression) -> ConstantExpression:
        return ConstantExpression(self._replace_opt_value(expr.constant))


def _map_csp_placeholder_goal(
    subgoal: ValueOutputExpression,
    csp: ConstraintSatisfactionProblem,
    placeholder_csp: ConstraintSatisfactionProblem,
    placeholder_bounded_variables: Dict[Variable, Any],
    csp_variable_mapping: Dict[int, TensorValue]
) -> Tuple[ValueOutputExpression, Dict[int, TensorValue]]:
    """Replace the placeholder-CSP variables in the subgoal with variables of ``csp``.

    ``placeholder_bounded_variables`` is currently unused.

    Returns:
        the rewritten subgoal and an extended copy of ``csp_variable_mapping`` (the input mapping is not modified).
    """
    new_mapping = csp_variable_mapping.copy()
    visitor = _ReplaceCSPVariableVisitor(csp, placeholder_csp, new_mapping)
    new_subgoal = visitor.visit(subgoal)
    return new_subgoal, new_mapping


def _map_csp_placeholder_action(
    action: OperatorApplier,
    csp: ConstraintSatisfactionProblem,
    placeholder_csp: ConstraintSatisfactionProblem,
    placeholder_bounded_variables: Dict[Variable, Any],
    csp_variable_mapping: Dict[int, TensorValue]
) -> Tuple[OperatorApplier, Dict[int, TensorValue]]:
    """Replace the placeholder-CSP variables in the action's TensorValue arguments with variables of ``csp``.

    Non-TensorValue arguments are passed through unchanged. ``placeholder_bounded_variables`` is currently unused.

    Returns:
        the rewritten action and an extended copy of ``csp_variable_mapping`` (the input mapping is not modified).
    """
    new_mapping = csp_variable_mapping.copy()
    # Reuse the goal visitor's value-replacement logic so goals and actions share one mapping scheme.
    visitor = _ReplaceCSPVariableVisitor(csp, placeholder_csp, new_mapping)
    new_arguments = [
        visitor._replace_opt_value(value) if isinstance(value, TensorValue) else value
        for value in action.arguments
    ]
    new_action = OperatorApplier(action.operator, new_arguments)
    return new_action, new_mapping


def _map_csp_variable_mapping(
    csp_variable_mapping: Dict[int, TensorValue],
    csp: ConstraintSatisfactionProblem,
    assignments: AssignmentDict
) -> Dict[int, TensorValue]:
    """Ground the values of a placeholder mapping with a (partial) CSP solution.

    Args:
        csp_variable_mapping: maps placeholder identifiers to TensorValues.
        csp: the CSP from which ``assignments`` was computed.
        assignments: the solution of ``csp``.

    Returns:
        a new mapping in which assigned optimistic values are replaced by their grounded values.
    """
    new_mapping = dict()
    for identifier, value in csp_variable_mapping.items():
        if not isinstance(value, TensorValue):
            raise TypeError(f'Unknown value type {type(value)}')
        if not value.is_single_optimistic_value():
            new_mapping[identifier] = value
            continue

        new_identifier = value.single_elem().identifier
        # NOTE(review): optimistic values missing from `assignments` are dropped from the mapping; confirm intended.
        if new_identifier in assignments:
            new_value = csp.ground_assignment_value_partial(assignments, new_identifier)
            if isinstance(new_value, OptimisticValue):
                new_mapping[identifier] = TensorValue.from_optimistic_value(new_value)
            elif isinstance(new_value, TensorValue):
                new_mapping[identifier] = new_value
            else:
                raise TypeError(f'Unknown value type {type(new_value)}')
    return new_mapping


def _map_csp_variable_state(
    state: State,
    csp: ConstraintSatisfactionProblem,
    assignments: AssignmentDict
) -> State:
    """Return a clone of ``state`` whose optimistic feature entries are grounded with ``assignments``.

    Entries whose assignment is still optimistic are re-pointed at the new optimistic identifier; entries assigned a
    concrete TensorValue get that tensor written in and their optimistic marker cleared to 0.
    """
    new_state = state.clone()
    for tensor_value in new_state.features.values():
        if tensor_value.tensor_optimistic_values is None:
            continue
        for ind in torch.nonzero(tensor_value.tensor_optimistic_values).tolist():
            ind = tuple(ind)
            identifier = tensor_value.tensor_optimistic_values[ind].item()
            if identifier in assignments:
                new_value = csp.ground_assignment_value_partial(assignments, identifier)
                if isinstance(new_value, OptimisticValue):
                    tensor_value.tensor_optimistic_values[ind] = new_value.identifier
                elif isinstance(new_value, TensorValue):
                    tensor_value.tensor[ind] = new_value.tensor
                    tensor_value.tensor_optimistic_values[ind] = 0
                else:
                    raise TypeError(f'Unknown value type {type(new_value)}')
    return new_state