# Source code for concepts.pdsketch.csp_solvers.dpll_sampling

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : dpll_sampling.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 11/20/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""The DPLL-Sampling algorithm for solving CSPs. This algorithm is specifically designed
for solving CSPs with mixed Boolean and continuous variables. At a high level, the algorithm
uses DPLL-style search to find an assignment to the Boolean variables. After the values of all Boolean
variables are fixed, the algorithm uses a sampling-based method to find an assignment to the continuous variables.
"""

from typing import Optional, Union, Tuple, List, Dict

import itertools
import collections
import jacinle

from concepts.dsl.dsl_types import BOOL, ValueType, NamedTensorValueType, TensorValueTypeBase, PyObjValueType
from concepts.dsl.dsl_functions import Function
from concepts.dsl.value import ListValue
from concepts.dsl.tensor_value import TensorValue
from concepts.dsl.constraint import Constraint, GroupConstraint, ConstraintSatisfactionProblem, OptimisticValue, OptimisticValueRecord, Assignment, AssignmentType, AssignmentDict, SimulationFluentConstraintFunction
from concepts.dsl.expression import VariableExpression, ValueOutputExpression, BoolOpType, QuantificationOpType, BoolExpression, PredicateEqualExpression, FunctionApplicationExpression
from concepts.dsl.expression import is_and_expr, iter_exprs

from concepts.pdsketch.executor import PDSketchExecutor, GeneratorManager
from concepts.pdsketch.predicate import Predicate
from concepts.pdsketch.generator import Generator, FancyGenerator

__all__ = [
    'CSPNotSolvable', 'CSPNoGenerator', 'ConstraintList',
    'dpll_apply_assignments',
    'dpll_filter_deterministic_equal', 'dpll_filter_deterministic_clauses', 'dpll_filter_unused_rhs',
    'dpll_filter_duplicated_constraints',
    'dpll_find_bool_variable', 'dpll_find_grounded_function_application', 'dpll_find_typegen_variable', 'dpll_find_gen_variable_combined',
    'GeneratorMatchingInputType', 'GeneratorMatchingOutputType', 'GeneratorMatchingIOReturnType', 'GeneratorMatchingReturnType',
    'csp_dpll_sampling_solve', 'csp_dpll_simplify'
]

# The list of constraints manipulated by the solver. Entries may be ``None``: the filters below blank out
# constraints they have eliminated, and :func:`dpll_apply_assignments` compacts the list afterwards.
ConstraintList = List[Optional[Union[Constraint, GroupConstraint]]]


class CSPNotSolvable(Exception):
    """An exception raised when the CSP (or the branch of the search currently being explored) has no solution."""
class CSPNoGenerator(Exception):
    """An exception raised when no generator can be matched against the remaining constraints.

    This does not mean that the CSP is unsolvable: the solver only lacks a way to propose values for the remaining variables.
    """
def _determined(*args) -> bool: """Helper function: if all arguments are determined. Args: *args: the arguments. Returns: True if all arguments are determined. """ for x in args: if isinstance(x, OptimisticValue): return False return True def _ground_assignment_value_partial(assignments: Dict[int, Assignment], dtype: ValueType, identifier: int) -> Union[TensorValue, OptimisticValue]: """Get the value of a variable based on the assignment dictionary. It will follow the EQUAL assignment types. The key difference between the :meth:`~concepts.dsl.constraint.ground_assignment_value` is the return type of this function. Specifically, the :meth:`concepts.dsl.constraint.ground_assignment_value` (exported) function returns the actual Value object. This function returns a wrapped Value object. When the value is not determined, it will return an OptimisticValue. Args: assignments: the assignment dictionary. dtype: the type of the variable. Returns: the value of the variable, wrapped in either :class:`~concepts.dsl.constraint.DeterminedValue` or :class:`~concepts.dsl.constraint.OptimisticValue`. """ while identifier in assignments and assignments[identifier].t is AssignmentType.EQUAL: identifier = assignments[identifier].d if identifier in assignments: return assignments[identifier].d return OptimisticValue(dtype, identifier)
def dpll_apply_assignments(executor: PDSketchExecutor, constraints: ConstraintList, assignments: Dict[int, Assignment]) -> ConstraintList:
    """Apply the assignments to the constraints.

    Every variable that has been determined in the assignment dictionary is replaced by its value. EQUAL
    assignments are followed to their representative variable. Constraints whose arguments and return value are
    then all determined are checked:

    - satisfied ones are removed from the list;
    - a violated one raises :class:`CSPNotSolvable`.

    Simulation fluent constraints are never checked here. Constraints whose return value has an IGNORE assignment
    are dropped. Group constraints are kept only while at least one of their member constraints is still pending.

    Args:
        executor: the executor.
        constraints: the list of constraints. ``None`` entries (constraints eliminated by the caller) are skipped.
        assignments: the dictionary of assignments.

    Returns:
        the list of constraints that have not been satisfied yet, without ``None`` entries.

    Raises:
        CSPNotSolvable: if a fully-determined constraint is not satisfied.
    """
    new_constraints = list()
    for c in constraints:
        if c is None:
            continue

        if isinstance(c, GroupConstraint):
            # Keep the group only if some non-group constraint still refers to it. Callers blank out eliminated
            # constraints with ``None`` before calling this function, so ``None`` entries must be skipped here too.
            has_unsatisfied_subconstraint = any(
                c2 is not None and not c2.is_group_constraint and c2.group is not None and c2.group == c
                for c2 in constraints
            )
            if has_unsatisfied_subconstraint:
                new_constraints.append(c)
            continue

        # If the return value of the constraint is ignored, simply ignore the entire constraint.
        if isinstance(c.rv, OptimisticValue) and c.rv.identifier in assignments and assignments[c.rv.identifier].t is AssignmentType.IGNORE:
            continue

        # Ground the arguments and the return value.
        new_args = list(c.arguments)
        for i, x in enumerate(c.arguments):
            if isinstance(x, OptimisticValue) and x.identifier in assignments:
                new_args[i] = _ground_assignment_value_partial(assignments, x.dtype, x.identifier)
        new_rv = c.rv
        if isinstance(c.rv, OptimisticValue) and c.rv.identifier in assignments:
            new_rv = _ground_assignment_value_partial(assignments, c.rv.dtype, c.rv.identifier)

        # Evaluate the constraint once everything is determined.
        nc = Constraint(c.function, new_args, new_rv, note=c.note, group=c.group)
        if _determined(nc.rv) and _determined(*nc.arguments) and not isinstance(nc.function, SimulationFluentConstraintFunction):
            if _check_constraint(executor, nc):
                continue
            raise CSPNotSolvable(f'Constraint {c} is not satisfied.')
        new_constraints.append(nc)
    return new_constraints
def _check_constraint(executor: PDSketchExecutor, c: Constraint) -> bool:
    """Helper function: check whether a fully-determined constraint is satisfied.

    Boolean operators and quantifiers are evaluated directly. Equality constraints compare Booleans by value and
    other types with :func:`_check_eq`. Simulation fluent constraints and generator placeholders are never
    considered satisfied here. Other predicates are evaluated through the executor's function implementation and
    compared with the recorded return value.

    Args:
        executor: the executor.
        c: the constraint. All of its arguments and its return value must be determined.

    Returns:
        True if the constraint is satisfied.
    """
    if c.function is BoolOpType.NOT:
        return c.arguments[0].item() == (not c.rv.item())
    elif c.function in (QuantificationOpType.FORALL, BoolOpType.AND):
        return all([x.item() for x in c.arguments]) == c.rv.item()
    elif c.function in (QuantificationOpType.EXISTS, BoolOpType.OR):
        return any([x.item() for x in c.arguments]) == c.rv.item()
    elif c.is_equal_constraint:
        if c.arguments[0].dtype == BOOL:
            return (c.arguments[0].item() == c.arguments[1].item()) == c.rv.item()
        else:
            return _check_eq(executor, c.arguments[0].dtype, c.arguments[0], c.arguments[1]) == c.rv.item()
    elif isinstance(c.function, SimulationFluentConstraintFunction):
        return False
    else:
        assert isinstance(c.function, Predicate)
        # NB(Jiayuan Mao @ 09/05): for generator placeholders, they can only be set true through the corresponding generators.
        if c.function.is_generator_placeholder:
            return False
        func = executor.get_function_implementation(c.function.name)
        rv = func(*c.arguments, return_type=c.function.return_type)
        if rv.dtype == BOOL:
            return (rv.item() > 0.5) == c.rv.item()
        else:
            # Compare against the recorded TensorValue itself: _check_eq reads ``.tensor`` or calls the type's
            # equality function, neither of which works on the raw scalar returned by ``.item()``.
            return _check_eq(executor, c.function.return_type, rv, c.rv)


def _check_eq(executor: PDSketchExecutor, dtype: Union[TensorValueTypeBase, PyObjValueType], v1: TensorValue, v2: TensorValue) -> bool:
    """Helper function: check if two values are equal. Internally used by :meth:`_check_constraint`.

    Intrinsically quantized tensor types are compared tensor-wise. Any other type must be a
    :class:`NamedTensorValueType` and is compared with the domain function ``type::<typename>::equal``.

    Args:
        executor: the executor.
        dtype: the type of the values.
        v1: the first value.
        v2: the second value.

    Returns:
        True if the two values are equal.
    """
    if isinstance(dtype, TensorValueTypeBase) and dtype.is_intrinsically_quantized():
        return (v1.tensor == v2.tensor).item()
    assert isinstance(dtype, NamedTensorValueType)
    eq_function = executor.get_function_implementation('type::' + dtype.typename + '::equal')
    return bool(eq_function(v1, v2).item())
def dpll_filter_deterministic_equal(executor: PDSketchExecutor, constraints: ConstraintList, assignments: Dict[int, Assignment]) -> Tuple[bool, ConstraintList]:
    """Simplify equality constraints whose outcome is already known.

    - ``(x == x) == True`` is removed; ``(x == x) == False`` raises :class:`CSPNotSolvable`.
    - ``(x == y) == True`` binds one side to the other (an EQUAL or VALUE assignment) and is removed.
    - ``(x == y) == False`` over Boolean values is rewritten as ``not(x) == y``.
    - ``(x == x) == ?r`` assigns ``?r = True`` and is removed.

    Args:
        executor: the executor.
        constraints: the list of constraints. Eliminated constraints are replaced by ``None`` in place.
        assignments: the dictionary of assignments. It is updated in place.

    Returns:
        a tuple of (whether we have made progress, the list of constraints that have not been satisfied).

    Raises:
        CSPNotSolvable: if an equality constraint can not be satisfied.
    """
    progress = False
    for i, c in enumerate(constraints):
        if c is None or c.is_group_constraint or not c.is_equal_constraint:
            continue

        lhs, rhs = c.arguments[0], c.arguments[1]
        if isinstance(c.rv, TensorValue):
            # If the constraint looks like `x == x`, we can simply ignore it.
            if isinstance(lhs, OptimisticValue) and isinstance(rhs, OptimisticValue) and lhs.identifier == rhs.identifier:
                if c.rv.item():
                    constraints[i] = None
                    progress = True
                    continue
                else:
                    raise CSPNotSolvable(f'Constraint {c} can not be satisfied: {c.arguments[0]} is not equal to itself.')

            if c.rv.item():
                # (x == y) == True: bind one side to the other.
                # Bindings made earlier in this pass have not been substituted into the remaining constraints yet.
                # Binding an already-bound variable again would overwrite the earlier binding (silently dropping
                # a constraint, or creating an EQUAL cycle such as x -> y -> x). Defer such constraints: the earlier
                # binding already set `progress`, so the caller runs another pass after the assignments are applied.
                if any(isinstance(x, OptimisticValue) and x.identifier in assignments for x in (lhs, rhs)):
                    continue
                if isinstance(lhs, OptimisticValue):
                    if isinstance(rhs, OptimisticValue):
                        assignments[lhs.identifier] = Assignment(AssignmentType.EQUAL, rhs.identifier)
                    else:
                        assignments[lhs.identifier] = Assignment(AssignmentType.VALUE, rhs)
                elif isinstance(rhs, OptimisticValue):
                    # The LHS is determined here, so the RHS can only be bound to its value.
                    assignments[rhs.identifier] = Assignment(AssignmentType.VALUE, lhs)
                else:
                    raise AssertionError('Sanity check failed.')
                constraints[i] = None
                progress = True
            elif lhs.dtype == BOOL:
                # (x == y) == False over Booleans is equivalent to not(x) == y.
                constraints[i] = Constraint(BoolOpType.NOT, [lhs], rhs, note=c.note, group=c.group)
                progress = True
        else:
            # (x == x) == ?r: the return value must be True.
            if isinstance(lhs, OptimisticValue) and isinstance(rhs, OptimisticValue) and lhs.identifier == rhs.identifier:
                assignments[c.rv.identifier] = Assignment(AssignmentType.VALUE, True)
                constraints[i] = None
                progress = True

    if progress:
        return progress, dpll_apply_assignments(executor, constraints, assignments)
    return progress, constraints
def dpll_filter_unused_rhs(executor: PDSketchExecutor, constraints: ConstraintList, assignments: Dict[int, Assignment], index2record: Dict[int, OptimisticValueRecord]) -> ConstraintList:
    """Ignore variables that only occur once, as the return value of a single constraint.

    Such a variable is never used as an input anywhere and is not actionable. It receives an IGNORE assignment,
    which makes :func:`dpll_apply_assignments` drop the constraint that produces it.

    Args:
        executor: the executor.
        constraints: the list of constraints.
        assignments: the dictionary of assignments. It is updated in place.
        index2record: the dictionary of variable records.

    Returns:
        the list of constraints that have not been satisfied, after removing the constraints of ignored variables.
    """
    # Each argument use (or being actionable) scores 100; each return-value use scores 1.
    # A total score of exactly 1 therefore means "used once, only as a return value".
    usage = collections.Counter()
    for identifier, record in index2record.items():
        if record.actionable:
            usage[identifier] += 100

    for constraint in constraints:
        if constraint.is_group_constraint:
            continue
        for argument in constraint.arguments:
            if isinstance(argument, OptimisticValue):
                usage[argument.identifier] += 100
        if isinstance(constraint.rv, OptimisticValue):
            usage[constraint.rv.identifier] += 1

    for identifier, score in usage.items():
        if score == 1:
            assignments[identifier] = Assignment(AssignmentType.IGNORE, None)
    return dpll_apply_assignments(executor, constraints, assignments)
[docs]def dpll_filter_deterministic_clauses(executor: PDSketchExecutor, constraints: ConstraintList, assignments: Dict[int, Assignment]) -> Tuple[bool, ConstraintList]: """Filter out Boolean constraints that have been determined. For example, AND(x, y, z) == true, then x == true, y == true, z == true. This function will remove the AND(x, y, z) constraint. There is another case that this function handles: for Boolean constraints, if everything on the LHS is determined, then we can determine the RHS. Args: executor: the executor. constraints: the list of constraints. assignments: the dictionary of assignments. Returns: a tuple of (whether we have made progress, the list of constraints that have not been satisfied). """ progress = False for i, c in enumerate(constraints): if c.is_group_constraint: continue if isinstance(c.function, (QuantificationOpType, BoolOpType)): if _determined(c.rv): if ( (c.function in (QuantificationOpType.FORALL, BoolOpType.AND)) or (c.function in (QuantificationOpType.EXISTS, BoolOpType.OR) and len(c.arguments) <= 1) ): if c.rv.item(): for x in c.arguments: if isinstance(x, OptimisticValue): assignments[x.identifier] = Assignment(AssignmentType.VALUE, True) progress = True # print('assign', optimistic_value_id(x), True) elif not x: raise CSPNotSolvable() else: # AND(x, y, z) == false determined_values = [x for x in c.arguments if _determined(x)] if False in determined_values: progress = True constraints[i] = None elif len(determined_values) == len(c.arguments): raise CSPNotSolvable() elif len(determined_values) == len(c.arguments) - 1: for x in c.arguments: if not _determined(x): progress = True assignments[x.identifier] = Assignment(AssignmentType.VALUE, False) elif c.function in (QuantificationOpType.EXISTS, BoolOpType.OR): if not c.rv.item(): for x in c.arguments: if isinstance(x, OptimisticValue): progress = True assignments[x.identifier] = Assignment(AssignmentType.VALUE, False) # print('assign', optimistic_value_id(x), False) elif x: 
raise CSPNotSolvable() else: # OR(x, y, z) == TRUE determined_values = [x.item() for x in c.arguments if _determined(x)] if True in determined_values: progress = True constraints[i] = None elif len(determined_values) == len(c.arguments): raise CSPNotSolvable() elif len(determined_values) == len(c.arguments) - 1: for x in c.arguments: if not _determined(x): progress = True assignments[x.identifier] = Assignment(AssignmentType.VALUE, True) elif c.function is BoolOpType.NOT: progress = True assignments[c.arguments[0].identifier] = Assignment(AssignmentType.VALUE, not c.rv.item()) elif _determined(*c.arguments): progress = True if c.function in (QuantificationOpType.FORALL, BoolOpType.AND): assignments[c.rv.identifier] = Assignment(AssignmentType.VALUE, all(x.item() for x in c.arguments)) elif c.function in (QuantificationOpType.EXISTS, BoolOpType.OR): assignments[c.rv.identifier] = Assignment(AssignmentType.VALUE, any(x.item() for x in c.arguments)) elif c.function is BoolOpType.NOT: assignments[c.rv.identifier] = Assignment(AssignmentType.VALUE, not c.arguments[0].item()) elif c.is_equal_constraint and _determined(*c.arguments): progress = True assignments[c.rv.identifier] = Assignment(AssignmentType.VALUE, c.arguments[0].item() == c.arguments[1].item()) if progress: return progress, dpll_apply_assignments(executor, constraints, assignments) return progress, constraints
def dpll_filter_duplicated_constraints(executor: PDSketchExecutor, constraints: ConstraintList) -> Tuple[bool, ConstraintList]:
    """Drop constraints whose textual form repeats an earlier constraint.

    Two constraints are treated as duplicates when their ``constraint_str()`` renderings coincide; the first
    occurrence is kept. For instance, of two copies of ``x == 1`` only one survives.

    Args:
        executor: the executor.
        constraints: the list of constraints. Duplicates are replaced by ``None`` in place.

    Returns:
        a tuple of (whether anything was removed, the list of constraints that have not been satisfied).
    """
    seen_keys = set()
    removed_any = False
    for index, constraint in enumerate(constraints):
        if constraint.is_group_constraint:
            continue
        # TODO(Jiayuan Mao @ 2023/11/24): since constraint_str contains shortened encodings for TensorValues, we should not use it here as a hash.
        key = constraint.constraint_str()
        if key not in seen_keys:
            seen_keys.add(key)
            continue
        constraints[index] = None
        removed_any = True

    if not removed_any:
        return removed_any, constraints
    return removed_any, dpll_apply_assignments(executor, constraints, {})
def dpll_find_bool_variable(executor: PDSketchExecutor, constraints: ConstraintList, assignments: Dict[int, Assignment]) -> Optional[int]:
    """Pick an undetermined Boolean variable to branch on.

    Heuristic: choose the unassigned BOOL variable that occurs most often (as an argument or as a return value)
    across the non-group constraints. Ties go to the variable encountered first.

    Args:
        executor: the executor (unused).
        constraints: the list of constraints.
        assignments: the dictionary of assignments.

    Returns:
        the identifier of the chosen variable, or None if every Boolean variable is already determined.
    """
    occurrences = collections.Counter(
        value.identifier
        for constraint in constraints if not constraint.is_group_constraint
        for value in itertools.chain(constraint.arguments, [constraint.rv])
        if isinstance(value, OptimisticValue) and value.identifier not in assignments and value.dtype == BOOL
    )
    if not occurrences:
        return None
    return max(occurrences, key=occurrences.get)
def dpll_find_grounded_function_application(executor: PDSketchExecutor, constraints: ConstraintList) -> Optional[Constraint]:
    """Find a function-application constraint whose arguments are all determined.

    Such a constraint can be resolved by simply evaluating the function.

    Args:
        executor: the executor (unused).
        constraints: the list of constraints.

    Returns:
        the first such (non-group) constraint in list order, or None if there is none.
    """
    return next(
        (
            constraint for constraint in constraints
            if not constraint.is_group_constraint and _determined(*constraint.arguments) and isinstance(constraint.function, Function)
        ),
        None
    )
def dpll_find_typegen_variable(executor: PDSketchExecutor, dtype: ValueType) -> Optional[Generator]:
    """Find a zero-argument generator that produces values of the given type.

    A generator qualifies when all of the following hold:

    - its function takes no arguments;
    - its first return type equals ``dtype``;
    - its ``certifies`` expression is an AND :class:`BoolExpression`.

    Args:
        executor: the executor, whose ``domain.generators`` are searched.
        dtype: the value type to generate. Must be a :class:`NamedTensorValueType`.

    Returns:
        the first qualifying generator in ``executor.domain.generators`` order, or None if there is none.
    """
    assert isinstance(dtype, NamedTensorValueType)
    for generator in executor.domain.generators.values():
        if len(generator.function.arguments) != 0 or not generator.function.return_type[0] == dtype:
            continue
        certifies = generator.certifies
        if isinstance(certifies, BoolExpression) and certifies.bool_op is BoolOpType.AND:
            return generator
    return None
# Generator inputs matched from a constraint: the entry at index K holds the value bound to generator input ``?cK``.
GeneratorMatchingInputType = List[Optional[TensorValue]]
# Generator outputs matched from a constraint: the entry at index K holds the placeholder bound to output ``?gK``.
GeneratorMatchingOutputType = List[Optional[OptimisticValue]]
GeneratorMatchingIOReturnType = Tuple[Optional[GeneratorMatchingInputType], Optional[GeneratorMatchingOutputType]]


def _match_generator(c: Constraint, g: Generator, certifies_expr: Optional[ValueOutputExpression] = None, allow_star_matching: bool = False) -> GeneratorMatchingIOReturnType:
    """Try to match a single constraint against (part of) a generator's ``certifies`` expression.

    Generator variables are recognized by name:

    - ``?cK`` is the K-th generator input and must face a determined value in the constraint;
    - ``?gK`` is the K-th generator output and must face an undetermined :class:`OptimisticValue`;
    - ``??`` (only when ``allow_star_matching`` is True) matches a determined argument without binding it.

    Args:
        c: the constraint to match.
        g: the generator.
        certifies_expr: the expression to match against. Defaults to ``g.flatten_certifies``.
        allow_star_matching: whether ``??`` variables may match determined arguments.

    Returns:
        a tuple (inputs, outputs), or (None, None) if the constraint does not match. Entries of ``inputs`` or
        ``outputs`` that no argument was bound to are left as None.
    """
    def gen_input_output(func_arguments, rv_variable=None):
        # Pair the constraint arguments positionally with the generator's variable expressions.
        inputs: GeneratorMatchingInputType = [None for _ in range(len(g.input_vars))]
        outputs: GeneratorMatchingOutputType = [None for _ in range(len(g.output_vars))]
        for argc, argg in zip(c.arguments, func_arguments):
            if isinstance(argc, OptimisticValue):
                # Undetermined arguments can only be produced by the generator (outputs ?gK).
                if argg.name.startswith('?g'):
                    outputs[int(argg.name[2:])] = argc
                else:
                    return None, None
            else:
                # Determined arguments can only feed the generator (inputs ?cK), or be skipped by a star.
                if argg.name.startswith('?c'):
                    inputs[int(argg.name[2:])] = argc
                elif allow_star_matching and argg.name == '??':
                    continue
                else:
                    return None, None
        if rv_variable is not None:
            # A determined return value is bound to a generator input.
            if rv_variable.name.startswith('?c'):
                inputs[int(rv_variable.name[2:])] = c.rv
            else:
                return None, None
        return inputs, outputs

    if certifies_expr is None:
        certifies_expr = g.flatten_certifies

    if isinstance(c.rv, TensorValue) and c.rv.dtype == BOOL:
        if c.rv.item():
            # match (pred ?x ?y) == True
            if isinstance(certifies_expr, FunctionApplicationExpression):
                if c.function.name == certifies_expr.function.name:
                    return gen_input_output(certifies_expr.arguments)
        else:
            if isinstance(certifies_expr, BoolExpression) and certifies_expr.bool_op is BoolOpType.NOT:
                # match (pred ?x ?y) == False
                inner_expr = certifies_expr.arguments[0]
                if isinstance(inner_expr, FunctionApplicationExpression):
                    if c.function.name == inner_expr.function.name:
                        return gen_input_output(inner_expr.arguments)
                elif c.is_equal_constraint and isinstance(inner_expr, PredicateEqualExpression):
                    # match (equal ?x ?y) == False
                    if c.arguments[0].dtype == inner_expr.predicate.return_type:
                        return gen_input_output([inner_expr.predicate, inner_expr.value])
    if isinstance(c.rv, TensorValue) and isinstance(c.function, Function):
        # match (pred ?x ?y) == ?z
        # Here the return value is determined, so ?z must be a generator input (?cK) bound to it.
        if isinstance(certifies_expr, PredicateEqualExpression) and c.function.name == certifies_expr.predicate.function.name:
            if isinstance(certifies_expr.value, VariableExpression):
                return gen_input_output(certifies_expr.predicate.arguments, certifies_expr.value)
    return None, None


# A list of (constraint(s), generator, inputs, outputs) tuples, or None when nothing matched.
GeneratorMatchingReturnType = Optional[List[Tuple[Constraint, Generator, Optional[GeneratorMatchingInputType], Optional[GeneratorMatchingOutputType]]]]


def _find_gen_variable_group(executor: PDSketchExecutor, constraints: ConstraintList) -> GeneratorMatchingReturnType:
    """Collect the candidate generators attached to group constraints.

    Each entry of ``c.candidate_generators`` is unpacked after the group constraint itself. This presumably yields
    (group constraint, generator, inputs, outputs) tuples — TODO confirm against GroupConstraint.

    Returns:
        the list of candidates, or None if no group constraint provides any.
    """
    all_generators = list()
    for c in constraints:
        if c.is_group_constraint:
            for g in c.candidate_generators:
                all_generators.append((c, *g))
    if len(all_generators) == 0:
        return None
    return all_generators


def _find_gen_variable(executor: PDSketchExecutor, constraints: ConstraintList) -> GeneratorMatchingReturnType:
    """Find generators whose whole ``certifies`` expression matches a single constraint.

    If some output variable is produced by exactly one (constraint, generator) match, that single match is
    returned. Otherwise all matches of the highest generator priority are returned.

    Returns:
        a list of (constraint, generator, inputs, outputs) tuples, or None if nothing matched.
    """
    # Step 1: find all applicable generators.
    all_generators = list()
    for c in constraints:
        if c.is_group_constraint:
            continue
        for g in executor.domain.generators.values():
            i, o = _match_generator(c, g)
            # jacinle.log_function.print('matching', c, g, i, o)
            if i is not None:
                all_generators.append((c, g, i, o))

    # Step 2: find if there is any variable with only one generator.
    # NOTE(review): `o` may contain None for generator outputs that the constraint did not bind (see
    # gen_input_output), in which case `target.identifier` would raise — TODO confirm this cannot happen.
    target_to_generator: Dict[int, list] = collections.defaultdict(list)
    for c, g, i, o in all_generators:
        for target in o:
            target_to_generator[target.identifier].append((c, g, i, o))
    for target, generators in target_to_generator.items():
        if len(target_to_generator[target]) == 1:
            return target_to_generator[target]

    # Otherwise, keep only the matches whose generator has the highest priority.
    if len(all_generators) > 0:
        max_priority = max([r[1].priority for r in all_generators])
        all_generators = [r for r in all_generators if r[1].priority == max_priority]
        return all_generators

    return None


def _find_gen_variable_advanced(executor: PDSketchExecutor, constraints: ConstraintList) -> GeneratorMatchingReturnType:
    """Find generators whose conjunctive ``certifies`` expression is covered by several constraints.

    Each conjunct of an AND-shaped ``flatten_certifies`` is matched separately (with ``??`` star matching).
    Matches whose bindings are compatible are then merged greedily. A merged match is accepted when it covers
    every non-star conjunct and binds every generator input and output.

    Returns:
        a list of (list of constraints, generator, inputs, outputs) tuples, or None if nothing matched.
    """
    def match_io(list1, list2):
        # Two partial bindings are compatible if they agree wherever both are bound.
        for x, y in zip(list1, list2):
            if x is None or y is None:
                continue
            if isinstance(x, OptimisticValue):
                if isinstance(y, OptimisticValue):
                    if x.identifier != y.identifier:
                        return False
                else:
                    return False
            elif isinstance(x, TensorValue):
                if isinstance(y, TensorValue):
                    if not _check_eq(executor, x.dtype, x, y):
                        return False
                else:
                    return False
            else:
                raise TypeError(f'Invalid type: {type(x)}.')
        return True

    def is_star_expression(sub_certifies):
        # A conjunct containing a `??` variable is optional when checking coverage below.
        for x in iter_exprs(sub_certifies):
            if isinstance(x, VariableExpression) and x.name == '??':
                return True
        return False

    generator2matched = collections.defaultdict(list)
    for c in constraints:
        if c.is_group_constraint:
            continue
        for g in executor.domain.generators.values():
            if is_and_expr(g.flatten_certifies):
                for sub_certifies_index, sub_certifies in enumerate(g.flatten_certifies.arguments):
                    i, o = _match_generator(c, g, sub_certifies, allow_star_matching=True)
                    if i is not None:
                        generator2matched[g].append((c, i, o, sub_certifies_index))

    all_generators = list()
    # TODO: implement the exact matching algorithm.
    for g, matched in generator2matched.items():
        all_matches = list()
        for result_index in range(len(matched)):
            c, i, o, sub_certifies_index = matched[result_index]
            used = False
            # constraints, g, i, o, matched_sub_certifies_index
            for mcs, mg, mi, mo, matched_sci in all_matches:
                if match_io(i, mi) and match_io(o, mo):
                    # NB: mi/mo are the i/o lists of the first match in this group; they are filled in place.
                    for j in range(len(mi)):
                        if mi[j] is None:
                            mi[j] = i[j]
                    for j in range(len(mo)):
                        if mo[j] is None:
                            mo[j] = o[j]
                    mcs.append(c)
                    matched_sci.add(sub_certifies_index)
                    used = True
                    break
            if not used:
                all_matches.append(([c], g, i, o, {sub_certifies_index}))

        this_is_star_expression = list()
        for sub_certifies in g.flatten_certifies.arguments:
            this_is_star_expression.append(is_star_expression(sub_certifies))

        for mcs, mg, mi, mo, matched_sci in all_matches:
            # the matched constraints should cover all sentences in the flatten_expression.
            match_succ = True
            for sub_certifies_index, sub_certifies in enumerate(g.flatten_certifies.arguments):
                if not this_is_star_expression[sub_certifies_index] and sub_certifies_index not in matched_sci:
                    match_succ = False
                    break
            if not match_succ:
                continue
            # Only fully-bound matches can actually be executed.
            if None not in mi and None not in mo:
                all_generators.append((mcs, mg, mi, mo))

    return all_generators if len(all_generators) > 0 else None


def _find_fancy_gen_variable(
    executor: PDSketchExecutor, csp: ConstraintSatisfactionProblem, constraints: ConstraintList, assignments: AssignmentDict
) -> Optional[List[Tuple[List[Constraint], Dict[int, Union[TensorValueTypeBase, PyObjValueType]], FancyGenerator]]]:
    """Match fancy generators against the remaining constraints.

    Fancy generators are visited in decreasing priority. Each must certify a conjunction of function applications
    whose arguments are all ``??``. Every constraint whose function name equals one of those functions is collected,
    together with the dtypes of the undetermined variables it mentions.

    Returns:
        a list of (matched constraints, variable identifier -> dtype, fancy generator) triples, or None if no fancy
        generator matched any constraint.
    """
    results = list()
    for g in sorted(executor.domain.fancy_generators.values(), key=lambda generator: generator.priority, reverse=True):
        g: FancyGenerator
        this_constraints = list()
        this_variable_dtypes = dict()
        assert is_and_expr(g.flatten_certifies)
        for certifies_expr in g.flatten_certifies.arguments:
            assert isinstance(certifies_expr, FunctionApplicationExpression)
            for arg in certifies_expr.arguments:
                assert isinstance(arg, VariableExpression) and arg.name == '??'
            for c in constraints:
                if c.function.name == certifies_expr.function.name:
                    this_constraints.append(c)
                    for arg in itertools.chain(c.arguments, [c.rv]):
                        if isinstance(arg, OptimisticValue):
                            this_variable_dtypes[arg.identifier] = arg.dtype
        if len(this_constraints) == 0:
            continue
        results.append((this_constraints, this_variable_dtypes, g))
    if len(results) > 0:
        return results
    return None
def dpll_find_gen_variable_combined(executor: PDSketchExecutor, csp: ConstraintSatisfactionProblem, constraints: ConstraintList, assignments: AssignmentDict) -> GeneratorMatchingReturnType:
    """Find generators for the remaining constraints by trying several strategies in order.

    1. :func:`_find_gen_variable_group`: candidate generators attached to group constraints.
    2. :func:`_find_gen_variable`: generators whose ``certifies`` expression matches a single constraint.
    3. :func:`_find_gen_variable_advanced`: generators whose conjunctive ``certifies`` expression is covered by several
       constraints, using star-matching.
    4. :func:`dpll_find_typegen_variable`: for the first unassigned variable of a :class:`NamedTensorValueType`
       (in ``csp.index2record`` order), a zero-argument generator for its type. The constraint slot of the
       returned tuple is None in this case.

    Returns:
        the result of the first strategy that finds anything, or None if all of them fail.
    """
    for strategy in (_find_gen_variable_group, _find_gen_variable, _find_gen_variable_advanced):
        matches = strategy(executor, constraints)
        if matches is not None:
            return matches

    for identifier, record in csp.index2record.items():
        dtype = record.dtype
        if identifier in assignments or not isinstance(dtype, NamedTensorValueType):
            continue
        typegen = dpll_find_typegen_variable(executor, dtype)
        if typegen is not None:
            return [(None, typegen, [], [OptimisticValue(dtype, identifier)])]
    return None
def csp_dpll_sampling_solve(
    executor: PDSketchExecutor, csp: ConstraintSatisfactionProblem, *,
    generator_manager: Optional[GeneratorManager] = None,
    max_generator_trials: int = 3,
    enable_ignore: bool = False,
    solvable_only: bool = False,
    verbose: bool = False
) -> Optional[Union[bool, AssignmentDict]]:
    """Solve the constraint satisfaction problem using the DPLL-sampling algorithm.

    Each search node first runs deterministic propagation. It then resolves the remaining constraints in this order:

    1. branch on an undetermined Boolean variable (True first, then False);
    2. evaluate a function application whose arguments are all determined;
    3. sample from a matching fancy generator;
    4. sample from a matching generator.

    Up to ``max_generator_trials`` samples are tried per generator before backtracking.

    Args:
        executor: the executor.
        csp: the constraint satisfaction problem.
        generator_manager: the generator manager. A fresh one (without history) is created if None.
        max_generator_trials: the maximum number of trials for each generator.
        enable_ignore: whether to ignore variables that only appear once, as the return value of a constraint.
        solvable_only: whether to only return whether the problem is solvable, without returning the solution.
        verbose: whether to print verbose information.

    Returns:
        None when no solution is found. Otherwise, True when `solvable_only` is True, or the assignment dictionary
        when `solvable_only` is False. Variables left unassigned by the search are filled in by calling a
        zero-argument generator for their type.

    Raises:
        CSPNoGenerator: when no generator can be found to solve the problem. However, the problem may still be solvable.
        NotImplementedError: when a variable is left unassigned and no zero-argument generator exists for its type.
    """
    if generator_manager is None:
        generator_manager = GeneratorManager(executor, store_history=False)

    constraints = csp.constraints.copy()

    if verbose:
        jacinle.log_function.print('csp_dpll_sampling_solve: max_generator_trials =', max_generator_trials)
        jacinle.log_function.print('Constraints:', len(constraints))
        jacinle.log_function.print(*[jacinle.indent_text(str(c)) for c in constraints], sep='\n')

    @jacinle.log_function(verbose=False)
    def dfs(constraints, assignments):
        if len(constraints) == 0:
            return assignments

        # Deterministic propagation.
        progress = True
        while progress:
            progress, constraints = dpll_filter_deterministic_equal(executor, constraints, assignments)
        if enable_ignore:
            constraints = dpll_filter_unused_rhs(executor, constraints, assignments, csp.index2record)
        progress = True
        while progress:
            progress, constraints = dpll_filter_deterministic_clauses(executor, constraints, assignments)
        _, constraints = dpll_filter_duplicated_constraints(executor, constraints)

        if verbose:
            jacinle.log_function.print('Remaining constraints:', len(constraints))
            jacinle.log_function.print(*constraints, sep='\n')

        if len(constraints) == 0:
            return assignments

        if (next_bool_var := dpll_find_bool_variable(executor, constraints, assignments)) is not None:
            # Branch on a Boolean variable: try True first, then False.
            for value in (True, False):
                branch_assignments = assignments.copy()
                branch_assignments[next_bool_var] = Assignment(AssignmentType.VALUE, value)
                try:
                    branch_constraints = dpll_apply_assignments(executor, constraints, branch_assignments)
                    return dfs(branch_constraints, branch_assignments)
                except CSPNotSolvable:
                    pass
            raise CSPNotSolvable()
        elif (next_fapp := dpll_find_grounded_function_application(executor, constraints)) is not None:
            # All arguments are known: evaluate the function and bind its return value.
            function: Predicate = next_fapp.function
            arguments = next_fapp.arguments
            external_function = executor.get_function_implementation(function.name)
            output = external_function(*arguments, auto_broadcast=False)
            target = next_fapp.rv
            new_assignments = assignments.copy()
            new_assignments[target.identifier] = Assignment(AssignmentType.VALUE, output)
            try:
                new_constraints = constraints.copy()
                new_constraints[new_constraints.index(next_fapp)] = None
                new_constraints = dpll_apply_assignments(executor, new_constraints, new_assignments)
                return dfs(new_constraints, new_assignments)
            except CSPNotSolvable:
                pass
            raise CSPNotSolvable()
        elif (next_gen_vars := _find_fancy_gen_variable(executor, csp, constraints, assignments)) is not None:
            # Entries are (constraints, variable dtypes, fancy generator) triples: the generator is at index 2.
            if len(next_gen_vars) > 0 and next_gen_vars[0][2].unsolvable:
                raise CSPNotSolvable()
            for vv in next_gen_vars:
                c, dtype_mapping, g = vv
                generator = iter(generator_manager.call(g, max_generator_trials, tuple(), c))
                for _ in range(max_generator_trials):
                    try:
                        _, outputs = next(generator)
                    except StopIteration:
                        break
                    if outputs is None:
                        break

                    # Fancy generators return a mapping from variable identifiers to values.
                    assert isinstance(outputs, dict)
                    new_assignments = assignments.copy()
                    for k, v in outputs.items():
                        new_assignments[k] = Assignment(AssignmentType.VALUE, v)
                    try:
                        new_constraints = constraints.copy()
                        for cc in c:
                            new_constraints[new_constraints.index(cc)] = None
                        new_constraints = dpll_apply_assignments(executor, new_constraints, new_assignments)
                        return dfs(new_constraints, new_assignments)
                    except CSPNotSolvable:
                        pass
            raise CSPNotSolvable()
        elif (next_gen_vars := dpll_find_gen_variable_combined(executor, csp, constraints, assignments)) is not None:
            # Entries are (constraint(s), generator, inputs, output targets) tuples.
            if len(next_gen_vars) > 0 and next_gen_vars[0][1].unsolvable:
                raise CSPNotSolvable()
            for vv in next_gen_vars:
                c, g, args, outputs_target = vv
                generator = iter(generator_manager.call(g, max_generator_trials, tuple(args), c))
                for _ in range(max_generator_trials):
                    try:
                        _, outputs = next(generator)
                    except StopIteration:
                        break
                    if outputs is None:
                        break

                    if not isinstance(outputs, tuple) and g.function.ftype.is_singular_return:
                        outputs = (outputs, )
                    if not g.function.ftype.is_singular_return:
                        assert len(outputs) == len(g.function.return_type)

                    new_assignments = assignments.copy()
                    for output, target in zip(outputs, outputs_target):
                        if isinstance(target, ListValue):
                            assert isinstance(output, ListValue)
                            assert len(output) == len(target)
                            for k, v in zip(target.values, output.values):
                                new_assignments[k.identifier] = Assignment(AssignmentType.VALUE, v)
                        else:
                            new_assignments[target.identifier] = Assignment(AssignmentType.VALUE, output)

                    try:
                        new_constraints = constraints.copy()
                        if isinstance(c, list):
                            # Star-matched generators certify several constraints at once.
                            for cc in c:
                                new_constraints[new_constraints.index(cc)] = None
                        elif c is not None:
                            # c is None for type-based generators (see dpll_find_gen_variable_combined): they certify
                            # no constraint, so there is nothing to remove.
                            new_constraints[new_constraints.index(c)] = None
                        if isinstance(c, GroupConstraint):
                            # A group generator also discharges every member constraint of the group.
                            for i, cc in enumerate(new_constraints):
                                if cc is not None and not cc.is_group_constraint and cc.group is not None and id(cc.group) == id(c):
                                    new_constraints[i] = None
                        new_constraints = dpll_apply_assignments(executor, new_constraints, new_assignments)
                        return dfs(new_constraints, new_assignments)
                    except CSPNotSolvable:
                        pass
            raise CSPNotSolvable()
        else:
            raise CSPNoGenerator('Can not find a generator. Constraints:\n ' + '\n '.join([str(x) for x in constraints]))

    try:
        assignments = dfs(constraints, {})
        if solvable_only:
            return True

        # Fill in variables that no constraint determined, using a zero-argument generator for their type.
        # NOTE(review): dpll_find_typegen_variable asserts a NamedTensorValueType, so an unconstrained variable of any
        # other type (e.g. BOOL) raises AssertionError here — TODO confirm whether such variables can remain.
        for name, record in csp.index2record.items():
            dtype = record.dtype
            if name not in assignments:
                g = dpll_find_typegen_variable(executor, dtype)
                if g is None:
                    raise NotImplementedError('Can not find a generator for unbounded variable {}, type {}.'.format(name, dtype))
                else:
                    output, = executor.get_function_implementation(g.function.name)()
                    assignments[name] = Assignment(AssignmentType.VALUE, TensorValue.from_scalar(output, dtype))
        return assignments
    except CSPNotSolvable:
        return None
def csp_dpll_simplify(
    executor: PDSketchExecutor, csp: ConstraintSatisfactionProblem,
    enable_ignore: bool = True,
    return_assignments: bool = False
) -> Union[ConstraintSatisfactionProblem, Tuple[ConstraintSatisfactionProblem, AssignmentDict]]:
    """Simplify the CSP using DPLL-style deterministic propagation (no search, no generators).

    The equality filter, the unused-RHS filter (if enabled) and the Boolean clause filter are applied repeatedly
    until a full round no longer changes the number of constraints.

    Args:
        executor: the executor.
        csp: the CSP.
        enable_ignore: whether to ignore variables that only appear once, as the return value of a constraint.
        return_assignments: whether to return the assignments.

    Returns:
        the simplified CSP, or a tuple of (simplified CSP, assignments) when ``return_assignments`` is True.

    Raises:
        CSPNotSolvable: if propagation finds a constraint that can not be satisfied.
    """
    constraints = csp.constraints.copy()
    assignments = dict()
    if len(constraints) == 0:
        # Nothing to simplify, but still honour the documented return shape.
        return (csp, assignments) if return_assignments else csp

    while True:
        nr_constraints = len(constraints)
        progress = True
        while progress:
            progress, constraints = dpll_filter_deterministic_equal(executor, constraints, assignments)
        if enable_ignore:
            constraints = dpll_filter_unused_rhs(executor, constraints, assignments, csp.index2record)
        progress = True
        while progress:
            progress, constraints = dpll_filter_deterministic_clauses(executor, constraints, assignments)
        if len(constraints) == nr_constraints:
            break

    if return_assignments:
        return csp.clone(constraints), assignments
    return csp.clone(constraints)