#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : dpll_sampling.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 11/20/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
"""The DPLL-Sampling algorithm for solving CSPs. This algorithm is specifically designed
for solving CSPs with mixed Boolean and continuous variables. At a high level, the algorithm
uses DPLL-style search to find an assignment to the Boolean variables. After the values of all Boolean
variables are fixed, the algorithm uses a sampling-based method to find an assignment to the continuous variables.
"""
from typing import Optional, Union, Tuple, List, Dict
import itertools
import collections
import jacinle
from concepts.dsl.dsl_types import BOOL, ValueType, NamedTensorValueType, TensorValueTypeBase, PyObjValueType
from concepts.dsl.dsl_functions import Function
from concepts.dsl.value import ListValue
from concepts.dsl.tensor_value import TensorValue
from concepts.dsl.constraint import Constraint, GroupConstraint, ConstraintSatisfactionProblem, OptimisticValue, OptimisticValueRecord, Assignment, AssignmentType, AssignmentDict, SimulationFluentConstraintFunction
from concepts.dsl.expression import VariableExpression, ValueOutputExpression, BoolOpType, QuantificationOpType, BoolExpression, PredicateEqualExpression, FunctionApplicationExpression
from concepts.dsl.expression import is_and_expr
from concepts.dsl.expression_utils import iter_exprs
from concepts.pdsketch.executor import PDSketchExecutor, GeneratorManager
from concepts.pdsketch.predicate import Predicate
from concepts.pdsketch.generator import Generator, FancyGenerator
# Public API of this module. ``dpll_filter_duplicated_constraints`` is a documented public helper alongside the
# other ``dpll_filter_*`` functions, so it is exported as well.
__all__ = [
    'CSPNotSolvable', 'CSPNoGenerator', 'ConstraintList',
    'dpll_apply_assignments',
    'dpll_filter_deterministic_equal', 'dpll_filter_deterministic_clauses', 'dpll_filter_unused_rhs', 'dpll_filter_duplicated_constraints',
    'dpll_find_bool_variable', 'dpll_find_grounded_function_application', 'dpll_find_typegen_variable', 'dpll_find_gen_variable_combined',
    'GeneratorMatchingInputType', 'GeneratorMatchingOutputType', 'GeneratorMatchingIOReturnType', 'GeneratorMatchingReturnType',
    'csp_dpll_sampling_solve', 'csp_dpll_simplify'
]
# A list of constraints and group constraints. Entries may be ``None``: the filtering functions below remove
# constraints in place by overwriting them with ``None``.
ConstraintList = List[Union[Constraint, GroupConstraint, None]]
[docs]
class CSPNotSolvable(Exception):
    """Raised when the CSP is proven unsatisfiable, e.g. when a fully-determined constraint evaluates to false."""
[docs]
class CSPNoGenerator(Exception):
    """Raised when no generator can be matched to make further progress on the CSP.
    This does NOT imply that the CSP is unsolvable; it only means the available generators cannot be used to solve it."""
def _determined(*args) -> bool:
    """Helper function: if all arguments are determined.
    Args:
        *args: the arguments.
    Returns:
        True if all arguments are determined.
    """
    for x in args:
        if isinstance(x, OptimisticValue):
            return False
    return True
def _ground_assignment_value_partial(assignments: Dict[int, Assignment], dtype: ValueType, identifier: int) -> Union[TensorValue, OptimisticValue]:
    """Resolve a variable through the assignment dictionary, following chains of EQUAL assignments.
    Unlike :meth:`concepts.dsl.constraint.ground_assignment_value`, which returns the actual value object, this function
    also handles undetermined variables: if the end of the EQUAL chain has no assignment, it returns an
    :class:`~concepts.dsl.constraint.OptimisticValue` for that variable.
    Args:
        assignments: the assignment dictionary.
        dtype: the type of the variable; used to build the OptimisticValue returned for an undetermined variable.
        identifier: the identifier of the variable to resolve.
    Returns:
        the payload (``.d``) of the first non-EQUAL assignment reached along the chain, or an OptimisticValue for the
        last identifier on the chain when that identifier has no assignment.
    """
    # NOTE(review): a cycle of EQUAL assignments would make this loop run forever.
    current = identifier
    while True:
        record = assignments.get(current)
        if record is None:
            # The chain ends at an unassigned variable: it is still undetermined.
            return OptimisticValue(dtype, current)
        if record.t is not AssignmentType.EQUAL:
            return record.d
        current = record.d
[docs]
def dpll_apply_assignments(executor: PDSketchExecutor, constraints: ConstraintList, assignments: Dict[int, Assignment]) -> ConstraintList:
    """Apply the assignments to the constraints: every variable determined in the assignment dictionary is replaced by its value.
    Every constraint that becomes fully determined (except simulation-fluent constraints) is also checked. If such a
    constraint is violated, :class:`CSPNotSolvable` is raised. If it is satisfied, it is dropped from the result.
    ``None`` entries (removed constraints) are skipped. Constraints whose RHS variable is marked IGNORE are dropped.
    A group constraint is kept only while at least one of its member constraints is still present in the input list.
    Args:
        executor: the executor.
        constraints: the list of constraints; may contain ``None`` placeholders.
        assignments: the dictionary of assignments.
    Returns:
        the list of constraints that have not been satisfied (without ``None`` entries).
    Raises:
        CSPNotSolvable: when a fully-determined constraint is not satisfied.
    """
    new_constraints = list()
    for c in constraints:
        if c is None:
            continue
        if isinstance(c, GroupConstraint):
            # Keep the group only while one of its members is still in the list. The scan must skip ``None``
            # placeholders too: the dpll_filter_* functions null out removed constraints before calling this function.
            has_unsatisfied_subconstraint = any(
                c2 is not None and not c2.is_group_constraint and c2.group is not None and c2.group == c
                for c2 in constraints
            )
            if has_unsatisfied_subconstraint:
                new_constraints.append(c)
            continue
        # If the return value of the constraint is ignored, simply ignore the entire constraint.
        if isinstance(c.rv, OptimisticValue) and c.rv.identifier in assignments and assignments[c.rv.identifier].t is AssignmentType.IGNORE:
            continue
        # Ground the arguments and the return value.
        new_args = list(c.arguments)
        for i, x in enumerate(c.arguments):
            if isinstance(x, OptimisticValue) and x.identifier in assignments:
                new_args[i] = _ground_assignment_value_partial(assignments, x.dtype, x.identifier)
        new_rv = c.rv
        if isinstance(c.rv, OptimisticValue) and c.rv.identifier in assignments:
            new_rv = _ground_assignment_value_partial(assignments, c.rv.dtype, c.rv.identifier)
        # Evaluate the constraint once everything is determined; simulation-fluent constraints are never evaluated here.
        nc = Constraint(c.function, new_args, new_rv, note=c.note, group=c.group)
        if _determined(nc.rv) and _determined(*nc.arguments) and not isinstance(nc.function, SimulationFluentConstraintFunction):
            if _check_constraint(executor, nc):
                continue
            else:
                raise CSPNotSolvable(f'Constraint {c} is not satisfied.')
        new_constraints.append(nc)
    return new_constraints
def _check_constraint(executor: PDSketchExecutor, c: Constraint) -> bool:
    """Helper function: check whether a fully-determined constraint is satisfied.
    Args:
        executor: the executor, used to look up function implementations and type-specific equality functions.
        c: the constraint; its arguments and return value are expected to be determined.
    Returns:
        True if the constraint is satisfied. Simulation-fluent constraints and generator-placeholder predicates always
        yield False, because they cannot be verified by direct evaluation.
    """
    if c.function is BoolOpType.NOT:
        return c.arguments[0].item() == (not c.rv.item())
    elif c.function in (QuantificationOpType.FORALL, BoolOpType.AND):
        return all([x.item() for x in c.arguments]) == c.rv.item()
    elif c.function in (QuantificationOpType.EXISTS, BoolOpType.OR):
        return any([x.item() for x in c.arguments]) == c.rv.item()
    elif c.is_equal_constraint:
        if c.arguments[0].dtype == BOOL:
            return (c.arguments[0].item() == c.arguments[1].item()) == c.rv.item()
        else:
            return _check_eq(executor, c.arguments[0].dtype, c.arguments[0], c.arguments[1]) == c.rv.item()
    elif isinstance(c.function, SimulationFluentConstraintFunction):
        return False
    else:
        assert isinstance(c.function, Predicate)
        # NB(Jiayuan Mao @ 09/05): for generator placeholders, they can only be set true through the corresponding generators.
        if c.function.is_generator_placeholder:
            return False
        func = executor.get_function_implementation(c.function.name)
        rv = func(*c.arguments, return_type=c.function.return_type)
        if rv.dtype == BOOL:
            return (rv.item() > 0.5) == c.rv.item()
        else:
            # Compare against the expected value object itself, not ``c.rv.item()``: ``_check_eq`` needs a TensorValue
            # (it reads ``.tensor`` or calls the type's equality function), as in the equality branch above.
            return _check_eq(executor, c.function.return_type, rv, c.rv)
def _check_eq(executor: PDSketchExecutor, dtype: Union[TensorValueTypeBase, PyObjValueType], v1: TensorValue, v2: TensorValue) -> bool:
    """Helper function: check if two values are equal. Internally used by :meth:`_check_constraint`.
    Intrinsically quantized tensor types are compared element-wise on their underlying tensors. Other types must be
    :class:`NamedTensorValueType` with a registered ``type::<typename>::equal`` function in the executor.
    Args:
        executor: the executor.
        dtype: the type of the values.
        v1: the first value.
        v2: the second value.
    Returns:
        True if the two values are equal.
    """
    if isinstance(dtype, TensorValueTypeBase) and dtype.is_intrinsically_quantized():
        # Reduce with ``all()`` so that non-scalar quantized values are compared element-wise
        # instead of failing in ``item()`` (which only accepts single-element tensors).
        return bool((v1.tensor == v2.tensor).all().item())
    assert isinstance(dtype, NamedTensorValueType)
    eq_function = executor.get_function_implementation('type::' + dtype.typename + '::equal')
    return bool(eq_function(v1, v2).item())
[docs]
def dpll_filter_deterministic_equal(executor: PDSketchExecutor, constraints: ConstraintList, assignments: Dict[int, Assignment]) -> Tuple[bool, ConstraintList]:
    """Simplify equality constraints whose outcome is (partially) known.
    The following rewrites are applied:
    - ``(x == x) == True`` is removed; ``(x == x) == False`` raises :class:`CSPNotSolvable`.
    - ``(x == y) == True`` records ``x`` and ``y`` as EQUAL in ``assignments`` (or records a VALUE assignment when one
      side is a concrete value) and removes the constraint. Both sides are first resolved through existing EQUAL
      chains, so an assignment made earlier (including earlier in this same pass) is never overwritten and no EQUAL
      cycle is created. If both sides are already bound, the constraint is kept and checked by
      :func:`dpll_apply_assignments`.
    - ``(x == y) == False`` over Boolean values is rewritten into ``NOT(x) == y``.
    - ``(x == x) == ?z`` assigns ``?z := True`` and removes the constraint.
    ``constraints`` and ``assignments`` are modified in place.
    Args:
        executor: the executor.
        constraints: the list of constraints.
        assignments: the dictionary of assignments.
    Returns:
        a tuple of (whether we have made progress, the list of constraints that have not been satisfied).
    Raises:
        CSPNotSolvable: when a constraint of the form ``(x == x) == False`` is found.
    """
    progress = False
    for i, c in enumerate(constraints):
        if c is None or c.is_group_constraint or not c.is_equal_constraint:
            continue
        lhs, rhs = c.arguments[0], c.arguments[1]
        lhs_is_var = isinstance(lhs, OptimisticValue)
        rhs_is_var = isinstance(rhs, OptimisticValue)
        same_variable = lhs_is_var and rhs_is_var and lhs.identifier == rhs.identifier
        if isinstance(c.rv, TensorValue):
            # If the constraint looks like `x == x`, it is either trivially true or unsatisfiable.
            if same_variable:
                if c.rv.item():
                    constraints[i] = None
                    progress = True
                    continue
                raise CSPNotSolvable(f'Constraint {c} can not be satisfied: {c.arguments[0]} is not equal to itself.')
            if c.rv.item():
                # (x == y) == True: merge x and y, or bind the variable side to the value side.
                if lhs_is_var and rhs_is_var:
                    lhs_root = _dpll_equal_root(assignments, lhs.identifier)
                    rhs_root = _dpll_equal_root(assignments, rhs.identifier)
                    if lhs_root == rhs_root:
                        constraints[i] = None  # Already known to be equal.
                    elif lhs_root not in assignments:
                        assignments[lhs_root] = Assignment(AssignmentType.EQUAL, rhs_root)
                        constraints[i] = None
                    elif rhs_root not in assignments:
                        assignments[rhs_root] = Assignment(AssignmentType.EQUAL, lhs_root)
                        constraints[i] = None
                    # Otherwise both sides are already bound: keep the constraint so that the
                    # dpll_apply_assignments call below grounds and checks it.
                elif lhs_is_var or rhs_is_var:
                    variable, value = (lhs, rhs) if lhs_is_var else (rhs, lhs)
                    root = _dpll_equal_root(assignments, variable.identifier)
                    if root not in assignments:
                        assignments[root] = Assignment(AssignmentType.VALUE, value)
                        constraints[i] = None
                    # Otherwise the variable is already bound: keep the constraint so that the
                    # dpll_apply_assignments call below grounds and checks it.
                else:
                    raise AssertionError('Sanity check failed.')
                progress = True
            elif lhs.dtype == BOOL:
                # (x == y) == False over Booleans is equivalent to NOT(x) == y.
                constraints[i] = Constraint(BoolOpType.NOT, [lhs], rhs, note=c.note, group=c.group)
                progress = True
        elif same_variable:
            # (x == x) == ?z: ?z must be True.
            assignments[c.rv.identifier] = Assignment(AssignmentType.VALUE, True)
            constraints[i] = None
            progress = True
    if progress:
        return progress, dpll_apply_assignments(executor, constraints, assignments)
    return progress, constraints


def _dpll_equal_root(assignments: Dict[int, Assignment], identifier: int) -> int:
    """Follow the chain of EQUAL assignments starting at ``identifier`` and return the identifier at its end.
    The returned identifier either has no assignment or has a non-EQUAL (e.g. VALUE) assignment.
    """
    while identifier in assignments and assignments[identifier].t is AssignmentType.EQUAL:
        identifier = assignments[identifier].d
    return identifier
[docs]
def dpll_filter_unused_rhs(executor: PDSketchExecutor, constraints: ConstraintList, assignments: Dict[int, Assignment], index2record: Dict[int, OptimisticValueRecord]) -> ConstraintList:
    """Mark as IGNORE every variable that is never used except as the RHS of exactly one constraint, then re-apply assignments.
    Such a variable is not needed by anything else, so its single defining constraint can be dropped.
    Args:
        executor: the executor.
        constraints: the list of constraints.
        assignments: the dictionary of assignments (updated in place with IGNORE entries).
        index2record: the dictionary of variable records; actionable variables are never ignored.
    Returns:
        the list of constraints that have not been satisfied, after removing all unused variables.
    """
    # Scoring: being actionable or appearing as a constraint argument adds 100; each RHS occurrence adds 1.
    # A total of exactly 1 therefore means "appears once, and only as the RHS of a single constraint".
    scores: Dict[int, int] = collections.defaultdict(int)
    for identifier, record in index2record.items():
        if record.actionable:
            scores[identifier] += 100
    for constraint in constraints:
        if constraint.is_group_constraint:
            continue
        for argument in constraint.arguments:
            if isinstance(argument, OptimisticValue):
                scores[argument.identifier] += 100
        if isinstance(constraint.rv, OptimisticValue):
            scores[constraint.rv.identifier] += 1
    ignorable = [identifier for identifier, score in scores.items() if score == 1]
    for identifier in ignorable:
        assignments[identifier] = Assignment(AssignmentType.IGNORE, None)
    return dpll_apply_assignments(executor, constraints, assignments)
[docs]
def dpll_filter_deterministic_clauses(executor: PDSketchExecutor, constraints: ConstraintList, assignments: Dict[int, Assignment]) -> Tuple[bool, ConstraintList]:
    """Propagate Boolean constraints (AND/OR/NOT/FORALL/EXISTS and equality) whose outcome is already forced.
    When the RHS of a Boolean constraint is determined, some LHS arguments may be forced. For example,
    ``AND(x, y, z) == true`` implies ``x == true, y == true, z == true``. Likewise, ``OR(x, y, z) == true`` with ``x``
    and ``y`` known to be false implies ``z == true``. Conversely, when every LHS argument of a Boolean constraint is
    determined, its RHS is assigned. Equality constraints with two determined arguments also get their RHS assigned.
    New assignments are written into ``assignments`` in place. Redundant constraints are set to ``None`` in
    ``constraints`` (also in place) before the list is re-grounded with :func:`dpll_apply_assignments`.
    Args:
        executor: the executor.
        constraints: the list of constraints.
        assignments: the dictionary of assignments.
    Returns:
        a tuple of (whether we have made progress, the list of constraints that have not been satisfied).
    Raises:
        CSPNotSolvable: when a Boolean constraint is found to be unsatisfiable.
    """
    progress = False
    for i, c in enumerate(constraints):
        if c is None or c.is_group_constraint:
            continue
        if isinstance(c.function, (QuantificationOpType, BoolOpType)):
            if _determined(c.rv):
                if (
                    (c.function in (QuantificationOpType.FORALL, BoolOpType.AND)) or
                    (c.function in (QuantificationOpType.EXISTS, BoolOpType.OR) and len(c.arguments) <= 1)
                ):
                    if c.rv.item():
                        # AND(x, y, z) == true: every argument must be true.
                        for x in c.arguments:
                            if isinstance(x, OptimisticValue):
                                assignments[x.identifier] = Assignment(AssignmentType.VALUE, True)
                                progress = True
                            elif not x.item():
                                raise CSPNotSolvable(f'Constraint {c} can not be satisfied: one of its conjuncts is false.')
                    else:
                        # AND(x, y, z) == false: at least one argument must be false.
                        # Compare the unwrapped values (``item()``), consistently with the OR branch below.
                        determined_values = [x.item() for x in c.arguments if _determined(x)]
                        if False in determined_values:
                            # Already satisfied by a false argument.
                            progress = True
                            constraints[i] = None
                        elif len(determined_values) == len(c.arguments):
                            raise CSPNotSolvable(f'Constraint {c} can not be satisfied: all of its conjuncts are true.')
                        elif len(determined_values) == len(c.arguments) - 1:
                            # Exactly one undetermined argument is left: it must be false.
                            for x in c.arguments:
                                if not _determined(x):
                                    progress = True
                                    assignments[x.identifier] = Assignment(AssignmentType.VALUE, False)
                elif c.function in (QuantificationOpType.EXISTS, BoolOpType.OR):
                    if not c.rv.item():
                        # OR(x, y, z) == false: every argument must be false.
                        for x in c.arguments:
                            if isinstance(x, OptimisticValue):
                                progress = True
                                assignments[x.identifier] = Assignment(AssignmentType.VALUE, False)
                            elif x.item():
                                raise CSPNotSolvable(f'Constraint {c} can not be satisfied: one of its disjuncts is true.')
                    else:
                        # OR(x, y, z) == true: at least one argument must be true.
                        determined_values = [x.item() for x in c.arguments if _determined(x)]
                        if True in determined_values:
                            # Already satisfied by a true argument.
                            progress = True
                            constraints[i] = None
                        elif len(determined_values) == len(c.arguments):
                            raise CSPNotSolvable(f'Constraint {c} can not be satisfied: all of its disjuncts are false.')
                        elif len(determined_values) == len(c.arguments) - 1:
                            # Exactly one undetermined argument is left: it must be true.
                            for x in c.arguments:
                                if not _determined(x):
                                    progress = True
                                    assignments[x.identifier] = Assignment(AssignmentType.VALUE, True)
                elif c.function is BoolOpType.NOT:
                    progress = True
                    if isinstance(c.arguments[0], OptimisticValue):
                        assignments[c.arguments[0].identifier] = Assignment(AssignmentType.VALUE, not c.rv.item())
                    # Otherwise the constraint is fully determined; the dpll_apply_assignments call below checks it.
            elif _determined(*c.arguments):
                # Every LHS argument is known: compute the RHS.
                progress = True
                if c.function in (QuantificationOpType.FORALL, BoolOpType.AND):
                    assignments[c.rv.identifier] = Assignment(AssignmentType.VALUE, all(x.item() for x in c.arguments))
                elif c.function in (QuantificationOpType.EXISTS, BoolOpType.OR):
                    assignments[c.rv.identifier] = Assignment(AssignmentType.VALUE, any(x.item() for x in c.arguments))
                elif c.function is BoolOpType.NOT:
                    assignments[c.rv.identifier] = Assignment(AssignmentType.VALUE, not c.arguments[0].item())
        elif c.is_equal_constraint and _determined(*c.arguments):
            progress = True
            if isinstance(c.rv, OptimisticValue):
                lhs, rhs = c.arguments[0], c.arguments[1]
                # Same comparison rule as _check_constraint: direct comparison for Booleans, _check_eq otherwise.
                if lhs.dtype == BOOL:
                    is_equal = lhs.item() == rhs.item()
                else:
                    is_equal = _check_eq(executor, lhs.dtype, lhs, rhs)
                assignments[c.rv.identifier] = Assignment(AssignmentType.VALUE, is_equal)
            # Otherwise the constraint is fully determined; the dpll_apply_assignments call below checks it.
    if progress:
        return progress, dpll_apply_assignments(executor, constraints, assignments)
    return progress, constraints
[docs]
def dpll_filter_duplicated_constraints(executor: PDSketchExecutor, constraints: ConstraintList) -> Tuple[bool, ConstraintList]:
    """Remove repeated constraints, keeping only the first occurrence (e.g. two copies of ``x == 1`` become one).
    Duplicates are detected by their ``constraint_str()`` representation; group constraints are never deduplicated.
    Removed entries are set to ``None`` in ``constraints`` (in place) before the list is re-grounded.
    Args:
        executor: the executor.
        constraints: the list of constraints.
    Returns:
        a tuple of (whether we have made progress, the list of constraints that have not been satisfied).
    """
    seen_keys = set()
    removed_any = False
    for index, constraint in enumerate(constraints):
        if constraint.is_group_constraint:
            continue
        # TODO(Jiayuan Mao @ 2023/11/24): since constraint_str contains shortened encodings for TensorValues, we should not use it here as a hash.
        key = constraint.constraint_str()
        if key not in seen_keys:
            seen_keys.add(key)
            continue
        constraints[index] = None
        removed_any = True
    if not removed_any:
        return removed_any, constraints
    return removed_any, dpll_apply_assignments(executor, constraints, {})
[docs]
def dpll_find_bool_variable(executor: PDSketchExecutor, constraints: ConstraintList, assignments: Dict[int, Assignment]) -> Optional[int]:
    """Pick an unassigned Boolean variable to branch on.
    Heuristic: choose the variable with the most occurrences (as an argument or as the RHS) across all non-group
    constraints. Ties go to the variable encountered first.
    Args:
        executor: the executor.
        constraints: the list of constraints.
        assignments: the dictionary of assignments; variables already in it are skipped.
    Returns:
        the identifier of the chosen variable, or None if no unassigned Boolean variable remains.
    """
    occurrences = collections.Counter()
    for constraint in constraints:
        if constraint.is_group_constraint:
            continue
        for value in (*constraint.arguments, constraint.rv):
            if isinstance(value, OptimisticValue) and value.identifier not in assignments and value.dtype == BOOL:
                occurrences[value.identifier] += 1
    if not occurrences:
        return None
    (best_identifier, _), = occurrences.most_common(1)
    return best_identifier
[docs]
def dpll_find_grounded_function_application(executor: PDSketchExecutor, constraints: ConstraintList) -> Optional[Constraint]:
    """Find the first (non-group) constraint that applies a :class:`Function` to fully-determined arguments.
    Args:
        executor: the executor.
        constraints: the list of constraints.
    Returns:
        the first such constraint in list order, or None if there is none.
    """
    grounded = (
        constraint for constraint in constraints
        if not constraint.is_group_constraint and _determined(*constraint.arguments) and isinstance(constraint.function, Function)
    )
    return next(grounded, None)
[docs]
def dpll_find_typegen_variable(executor: PDSketchExecutor, dtype: ValueType) -> Optional[Generator]:
    """Find a generator that can produce a value of type ``dtype`` without any inputs.
    A generator qualifies when its function takes no arguments, the first entry of its return type equals ``dtype``,
    and its ``certifies`` expression is an AND expression.
    Args:
        executor: the executor; the generators of ``executor.domain`` are scanned in iteration order.
        dtype: the type of the value to generate. Must be a :class:`NamedTensorValueType`.
    Returns:
        the first qualifying generator, or None if no generator qualifies.
    """
    assert isinstance(dtype, NamedTensorValueType)

    def qualifies(generator) -> bool:
        function = generator.function
        if len(function.arguments) != 0:
            return False
        if not function.return_type[0] == dtype:
            return False
        certifies = generator.certifies
        return isinstance(certifies, BoolExpression) and certifies.bool_op is BoolOpType.AND

    return next((generator for generator in executor.domain.generators.values() if qualifies(generator)), None)
# One slot per generator input variable (``?c<k>``), holding the matched concrete value, or None if unmatched.
GeneratorMatchingInputType = List[Union[TensorValue, None]]
# One slot per generator output variable (``?g<k>``), holding the matched CSP variable, or None if unmatched.
GeneratorMatchingOutputType = List[Union[OptimisticValue, None]]
# The ``(inputs, outputs)`` pair produced by matching one constraint against a generator; ``(None, None)`` on failure.
GeneratorMatchingIOReturnType = Tuple[Union[GeneratorMatchingInputType, None], Union[GeneratorMatchingOutputType, None]]
def _match_generator(c: Constraint, g: Generator, certifies_expr: Optional[ValueOutputExpression] = None, allow_star_matching: bool = False) -> GeneratorMatchingIOReturnType:
    """Try to match a single constraint against (one clause of) a generator's ``certifies`` expression.
    Recognized patterns:
    - ``(pred ...) == True`` against a certifies expression ``(pred ...)``;
    - ``(pred ...) == False`` against ``(not (pred ...))``, and an equality constraint ``== False`` against
      ``(not (<predicate> == <value>))`` when the argument types agree;
    - ``(pred ...) == <determined value>`` against ``(pred ...) == ?c<k>``.
    Generator variables are identified by name: ``?g<k>`` is the k-th generator output (it must match an undetermined
    constraint argument) and ``?c<k>`` is the k-th generator input (it must match a determined argument).
    Args:
        c: the constraint to match.
        g: the generator.
        certifies_expr: the expression to match against; defaults to ``g.flatten_certifies``.
        allow_star_matching: if True, a ``??`` expression argument accepts any determined constraint argument without
            binding it.
    Returns:
        a tuple ``(inputs, outputs)`` with one slot per generator input/output variable (slots that were not matched
        stay None), or ``(None, None)`` when the constraint does not match.
    """
    def gen_input_output(func_arguments, rv_variable=None):
        """Bind constraint arguments (and optionally the constraint's RHS) to generator variables by name."""
        inputs: GeneratorMatchingInputType = [None for _ in range(len(g.input_vars))]
        outputs: GeneratorMatchingOutputType = [None for _ in range(len(g.output_vars))]
        # NB: zip stops at the shorter sequence; the arities of the constraint and the expression are not cross-checked here.
        for argc, argg in zip(c.arguments, func_arguments):
            if isinstance(argc, OptimisticValue):
                # Undetermined constraint argument: it can only be produced by a generator output.
                if argg.name.startswith('?g'):
                    outputs[int(argg.name[2:])] = argc
                else:
                    return None, None
            else:
                # Determined constraint argument: it feeds a generator input (or is skipped by a star).
                if argg.name.startswith('?c'):
                    inputs[int(argg.name[2:])] = argc
                elif allow_star_matching and argg.name == '??':
                    continue
                else:
                    return None, None
        if rv_variable is not None:
            # The constraint's (determined) RHS must feed a generator input.
            if rv_variable.name.startswith('?c'):
                inputs[int(rv_variable.name[2:])] = c.rv
            else:
                return None, None
        return inputs, outputs
    if certifies_expr is None:
        certifies_expr = g.flatten_certifies
    if isinstance(c.rv, TensorValue) and c.rv.dtype == BOOL:
        if c.rv.item():  # match (pred ?x ?y) == True
            if isinstance(certifies_expr, FunctionApplicationExpression):
                if c.function.name == certifies_expr.function.name:
                    return gen_input_output(certifies_expr.arguments)
        else:
            if isinstance(certifies_expr, BoolExpression) and certifies_expr.bool_op is BoolOpType.NOT:  # match (pred ?x ?y) == False
                inner_expr = certifies_expr.arguments[0]
                if isinstance(inner_expr, FunctionApplicationExpression):
                    if c.function.name == inner_expr.function.name:
                        return gen_input_output(inner_expr.arguments)
                elif c.is_equal_constraint and isinstance(inner_expr, PredicateEqualExpression):  # match (equal ?x ?y) == False
                    # NOTE(review): both sides of the equality are passed as "arguments" and read via ``.name``;
                    # this presumably assumes both are variable expressions — confirm.
                    if c.arguments[0].dtype == inner_expr.predicate.return_type:
                        return gen_input_output([inner_expr.predicate, inner_expr.value])
    if isinstance(c.rv, TensorValue) and isinstance(c.function, Function):  # match (pred ?x ?y) == ?z
        if isinstance(certifies_expr, PredicateEqualExpression) and c.function.name == certifies_expr.predicate.function.name:
            if isinstance(certifies_expr.value, VariableExpression):
                return gen_input_output(certifies_expr.predicate.arguments, certifies_expr.value)
    return None, None
# A list of candidate ``(constraint, generator, inputs, outputs)`` matches, or None when nothing matches. Some producers
# below put a list of constraints, or None, in the first slot instead of a single constraint.
GeneratorMatchingReturnType = Union[List[Tuple[Constraint, Generator, Union[GeneratorMatchingInputType, None], Union[GeneratorMatchingOutputType, None]]], None]
def _find_gen_variable_group(executor: PDSketchExecutor, constraints: ConstraintList) -> GeneratorMatchingReturnType:
    """Collect the candidate generators attached to group constraints.
    Args:
        executor: the executor (not used here; kept so that all matchers share one signature).
        constraints: the list of constraints.
    Returns:
        a list with one ``(group_constraint, *candidate)`` tuple per entry of each group constraint's
        ``candidate_generators``, or None when there is none.
    """
    matches = [
        (constraint, *candidate)
        for constraint in constraints
        if constraint.is_group_constraint
        for candidate in constraint.candidate_generators
    ]
    return matches if matches else None
def _find_gen_variable(executor: PDSketchExecutor, constraints: ConstraintList) -> GeneratorMatchingReturnType:
    """Match every non-group constraint against every generator with :func:`_match_generator`.
    If some output variable is produced by exactly one match, that single match is returned (as a one-element list).
    Otherwise, all matches whose generator has the highest ``priority`` are returned.
    Args:
        executor: the executor.
        constraints: the list of constraints.
    Returns:
        a list of ``(constraint, generator, inputs, outputs)`` tuples, or None if nothing matches.
    """
    # Step 1: find all applicable generators.
    candidates = list()
    for constraint in constraints:
        if constraint.is_group_constraint:
            continue
        for generator in executor.domain.generators.values():
            inputs, outputs = _match_generator(constraint, generator)
            if inputs is not None:
                candidates.append((constraint, generator, inputs, outputs))
    # Step 2: a variable that only one match can produce is handled first.
    by_target: Dict[int, list] = collections.defaultdict(list)
    for candidate in candidates:
        for target in candidate[3]:
            by_target[target.identifier].append(candidate)
    for matches in by_target.values():
        if len(matches) == 1:
            return matches
    # Step 3: otherwise keep only the matches with the highest generator priority.
    if not candidates:
        return None
    top_priority = max(candidate[1].priority for candidate in candidates)
    return [candidate for candidate in candidates if candidate[1].priority == top_priority]
def _find_gen_variable_advanced(executor: PDSketchExecutor, constraints: ConstraintList) -> GeneratorMatchingReturnType:
    """Match generators whose flattened ``certifies`` is a conjunction against *sets* of constraints.
    Each clause of the conjunction is first matched individually against every non-group constraint with
    :func:`_match_generator` (star-matching enabled). For each generator, these individual matches are then greedily
    merged into groups whose input/output bindings are compatible. A group is returned when (1) it covers every clause
    that does not contain a ``??`` variable, and (2) all generator inputs and outputs are bound.
    Args:
        executor: the executor (used to enumerate generators and to compare values with :func:`_check_eq`).
        constraints: the list of constraints.
    Returns:
        a list of ``(matched_constraints, generator, inputs, outputs)`` tuples, where ``matched_constraints`` is a list,
        or None if nothing matches.
    """
    def match_io(list1, list2):
        """Return True if two partial binding lists agree on every slot that is bound in both.
        Two OptimisticValues agree when their identifiers are equal; two TensorValues agree when :func:`_check_eq` says
        so. A mix of the two never agrees. Raises TypeError for any other value type in ``list1``."""
        for x, y in zip(list1, list2):
            if x is None or y is None:
                continue
            if isinstance(x, OptimisticValue):
                if isinstance(y, OptimisticValue):
                    if x.identifier != y.identifier:
                        return False
                else:
                    return False
            elif isinstance(x, TensorValue):
                if isinstance(y, TensorValue):
                    if not _check_eq(executor, x.dtype, x, y):
                        return False
                else:
                    return False
            else:
                raise TypeError(f'Invalid type: {type(x)}.')
        return True
    def is_star_expression(sub_certifies):
        """Return True if the expression contains a ``??`` variable anywhere."""
        for x in iter_exprs(sub_certifies):
            if isinstance(x, VariableExpression) and x.name == '??':
                return True
        return False
    # generator -> list of (constraint, inputs, outputs, index of the matched clause).
    generator2matched = collections.defaultdict(list)
    for c in constraints:
        if c.is_group_constraint:
            continue
        for g in executor.domain.generators.values():
            if is_and_expr(g.flatten_certifies):
                for sub_certifies_index, sub_certifies in enumerate(g.flatten_certifies.arguments):
                    i, o = _match_generator(c, g, sub_certifies, allow_star_matching=True)
                    if i is not None:
                        generator2matched[g].append((c, i, o, sub_certifies_index))
    all_generators = list()
    # TODO: implement the exact matching algorithm.
    for g, matched in generator2matched.items():
        # Greedy grouping: each match joins the first existing group with compatible bindings, else starts a new one.
        all_matches = list()
        for result_index in range(len(matched)):
            c, i, o, sub_certifies_index = matched[result_index]
            used = False
            # constraints, g, i, o, matched_sub_certifies_index
            for mcs, mg, mi, mo, matched_sci in all_matches:
                if match_io(i, mi) and match_io(o, mo):
                    # Fill the group's unbound slots in place (mi/mo are the lists from the group's first match).
                    for j in range(len(mi)):
                        if mi[j] is None:
                            mi[j] = i[j]
                    for j in range(len(mo)):
                        if mo[j] is None:
                            mo[j] = o[j]
                    mcs.append(c)
                    matched_sci.add(sub_certifies_index)
                    used = True
                    break
            if not used:
                all_matches.append(([c], g, i, o, {sub_certifies_index}))
        this_is_star_expression = list()
        for sub_certifies in g.flatten_certifies.arguments:
            this_is_star_expression.append(is_star_expression(sub_certifies))
        for mcs, mg, mi, mo, matched_sci in all_matches:
            # the matched constraints should cover all sentences in the flatten_expression.
            # Clauses containing ``??`` are optional and do not need to be covered.
            match_succ = True
            for sub_certifies_index, sub_certifies in enumerate(g.flatten_certifies.arguments):
                if not this_is_star_expression[sub_certifies_index] and sub_certifies_index not in matched_sci:
                    match_succ = False
                    break
            if not match_succ:
                continue
            # Only fully-bound groups (every generator input and output known) are usable.
            if None not in mi and None not in mo:
                all_generators.append((mcs, mg, mi, mo))
    return all_generators if len(all_generators) > 0 else None
def _find_fancy_gen_variable(
    executor: PDSketchExecutor,
    csp: ConstraintSatisfactionProblem,
    constraints: ConstraintList, assignments: AssignmentDict
) -> Optional[List[Tuple[List[Constraint], Dict[int, Union[TensorValueTypeBase, PyObjValueType]], FancyGenerator]]]:
    """Find fancy generators whose certified predicates appear among the current constraints.
    Fancy generators are visited in descending ``priority`` order (``sorted`` is stable, so ties keep the domain's
    iteration order). Every clause of a fancy generator's flattened ``certifies`` expression must be a function
    application whose arguments are all ``??`` (this is asserted). A generator is returned when at least one
    constraint has the same function name as one of its clauses.
    Args:
        executor: the executor.
        csp: the constraint satisfaction problem (not used in this function body).
        constraints: the list of constraints.
        assignments: the dictionary of assignments (not used in this function body).
    Returns:
        a list of ``(matched_constraints, variable_dtypes, generator)`` tuples, or None if no fancy generator matches.
        ``variable_dtypes`` maps the identifier of each OptimisticValue found while scanning the constraints to its dtype.
    """
    results = list()
    for g in sorted(executor.domain.fancy_generators.values(), key=lambda generator: generator.priority, reverse=True):
        g: FancyGenerator
        this_constraints = list()
        this_variable_dtypes = dict()
        assert is_and_expr(g.flatten_certifies)
        for certifies_expr in g.flatten_certifies.arguments:
            assert isinstance(certifies_expr, FunctionApplicationExpression)
            for arg in certifies_expr.arguments:
                assert isinstance(arg, VariableExpression) and arg.name == '??'
            # NOTE(review): unlike the other matchers, group constraints are not skipped here; confirm that every
            # entry of ``constraints`` exposes ``.function.name``. A constraint matching several clauses is appended
            # once per clause.
            for c in constraints:
                if c.function.name == certifies_expr.function.name:
                    this_constraints.append(c)
                # NOTE(review): this loop runs for every constraint, not only the matched ones above, so
                # ``this_variable_dtypes`` collects the undetermined variables of all constraints. Confirm this is intended.
                for arg in itertools.chain(c.arguments, [c.rv]):
                    if isinstance(arg, OptimisticValue):
                        this_variable_dtypes[arg.identifier] = arg.dtype
        if len(this_constraints) == 0:
            continue
        results.append((this_constraints, this_variable_dtypes, g))
    if len(results) > 0:
        return results
    return None
[docs]
def dpll_find_gen_variable_combined(executor: PDSketchExecutor, csp: ConstraintSatisfactionProblem, constraints: ConstraintList, assignments: AssignmentDict) -> GeneratorMatchingReturnType:
    """Find candidate generators by trying several strategies in order and returning the first non-None result:
    0. :func:`_find_gen_variable_group`: candidate generators attached to group constraints.
    1. :func:`_find_gen_variable`: single-constraint matches (a uniquely-produced variable first, else the
       highest-priority matches).
    2. :func:`_find_gen_variable_advanced`: multi-constraint matches against conjunctive certifies, using star-matching.
    3. :func:`dpll_find_typegen_variable`: type-matching. For the first unassigned variable of a named tensor type
       (in ``csp.index2record`` order) that has a zero-input generator, that generator is returned.
    Args:
        executor: the executor.
        csp: the constraint satisfaction problem.
        constraints: the list of constraints.
        assignments: the dictionary of assignments.
    Returns:
        the list of candidate matches, or None if no strategy finds one.
    """
    for finder in (_find_gen_variable_group, _find_gen_variable, _find_gen_variable_advanced):
        matches = finder(executor, constraints)
        if matches is not None:
            return matches
    for identifier, record in csp.index2record.items():
        dtype = record.dtype
        if identifier in assignments or not isinstance(dtype, NamedTensorValueType):
            continue
        generator = dpll_find_typegen_variable(executor, dtype)
        if generator is not None:
            return [(None, generator, [], [OptimisticValue(dtype, identifier)])]
    return None
[docs]
def csp_dpll_sampling_solve(
    executor: PDSketchExecutor, csp: ConstraintSatisfactionProblem, *,
    generator_manager: Optional[GeneratorManager] = None,
    max_generator_trials: int = 3,
    enable_ignore: bool = False, solvable_only: bool = False,
    verbose: bool = False
) -> Optional[Union[bool, AssignmentDict]]:
    """Solve the constraint satisfaction problem using the DPLL-sampling algorithm.

    Each search step first simplifies the constraints (deterministic equalities, optionally unused RHS values,
    deterministic clauses and duplicated constraints). If constraints remain, it branches on the first option
    that applies:

    1. an unassigned Boolean variable from :func:`dpll_find_bool_variable` (True is tried before False);
    2. a grounded function application from :func:`dpll_find_grounded_function_application`, evaluated directly;
    3. generators matched by :func:`_find_fancy_gen_variable`;
    4. generators matched by :func:`dpll_find_gen_variable_combined`.

    Generator branches try at most `max_generator_trials` samples per generator, and the search backtracks on failure.

    Args:
        executor: the executor.
        csp: the constraint satisfaction problem.
        generator_manager: the generator manager. If None, a new one is created with ``store_history=False``.
        max_generator_trials: the maximum number of trials for each generator.
        enable_ignore: whether to ignore constraints whose RHS value is not determined.
        solvable_only: whether to only return whether the problem is solvable, without returning the solution.
        verbose: whether to print verbose information.

    Returns:
        When `solvable_only` is True and a solution is found, return True.
        When `solvable_only` is False and a solution is found, return an assignment dictionary. Variables in
        ``csp.index2record`` that the search left unassigned are filled in with their type-based generators.
        When the problem is not solvable, return None (regardless of `solvable_only`).

    Raises:
        CSPNoGenerator: when no generator can be found to solve the problem. However, the problem may still be solvable.
        NotImplementedError: when a variable is still unassigned after the search and no type-based generator exists
            for its type.
    """
    if generator_manager is None:
        generator_manager = GeneratorManager(executor, store_history=False)
    constraints = csp.constraints.copy()
    if verbose:
        jacinle.log_function.print('csp_dpll_sampling_solve: max_generator_trials =', max_generator_trials)
        jacinle.log_function.print('Constraints:', len(constraints))
        jacinle.log_function.print(*[jacinle.indent_text(str(c)) for c in constraints], sep='\n')

    @jacinle.log_function(verbose=False)
    def dfs(constraints, assignments):
        # One DPLL step: returns a satisfying assignment dict, or raises CSPNotSolvable so the caller backtracks.
        if len(constraints) == 0:
            return assignments

        # Deterministic simplification; the equality and clause filters are iterated until they report no progress.
        progress = True
        while progress:
            progress, constraints = dpll_filter_deterministic_equal(executor, constraints, assignments)
        if enable_ignore:
            constraints = dpll_filter_unused_rhs(executor, constraints, assignments, csp.index2record)
        progress = True
        while progress:
            progress, constraints = dpll_filter_deterministic_clauses(executor, constraints, assignments)
        progress, constraints = dpll_filter_duplicated_constraints(executor, constraints)

        if verbose:
            jacinle.log_function.print('Remaining constraints:', len(constraints))
            jacinle.log_function.print(*constraints, sep='\n')
        if len(constraints) == 0:
            return assignments

        if (next_bool_var := dpll_find_bool_variable(executor, constraints, assignments)) is not None:
            # Branch on a Boolean variable: try True first, then False.
            for value in (True, False):
                branch_assignments = assignments.copy()
                branch_assignments[next_bool_var] = Assignment(AssignmentType.VALUE, value)
                try:
                    branch_constraints = dpll_apply_assignments(executor, constraints, branch_assignments)
                    return dfs(branch_constraints, branch_assignments)
                except CSPNotSolvable:
                    pass
            raise CSPNotSolvable()
        elif (next_fapp := dpll_find_grounded_function_application(executor, constraints)) is not None:
            # Evaluate the grounded function application directly and bind its output variable.
            function: Predicate = next_fapp.function
            external_function = executor.get_function_implementation(function.name)
            output = external_function(*next_fapp.arguments, auto_broadcast=False)
            new_assignments = assignments.copy()
            new_assignments[next_fapp.rv.identifier] = Assignment(AssignmentType.VALUE, output)
            try:
                new_constraints = constraints.copy()
                new_constraints[new_constraints.index(next_fapp)] = None
                new_constraints = dpll_apply_assignments(executor, new_constraints, new_assignments)
                return dfs(new_constraints, new_assignments)
            except CSPNotSolvable:
                pass
            raise CSPNotSolvable()
        elif (next_gen_vars := _find_fancy_gen_variable(executor, csp, constraints, assignments)) is not None:
            # Entries are (constraints, variable_dtypes, generator): the generator is at index 2, not 1
            # (index 1 is a dict and has no `unsolvable` attribute).
            # NOTE(review): getattr with a default in case fancy generators do not define `unsolvable`; confirm.
            if len(next_gen_vars) > 0 and getattr(next_gen_vars[0][2], 'unsolvable', False):
                raise CSPNotSolvable()
            for c, _dtype_mapping, g in next_gen_vars:
                generator = iter(generator_manager.call(g, max_generator_trials, tuple(), c))
                for _ in range(max_generator_trials):
                    try:
                        _output_index, outputs = next(generator)
                    except StopIteration:
                        break
                    if outputs is None:
                        break
                    assert isinstance(outputs, dict)
                    new_assignments = assignments.copy()
                    for k, v in outputs.items():
                        new_assignments[k] = Assignment(AssignmentType.VALUE, v)
                    try:
                        # The constraints certified by this generator are consumed.
                        new_constraints = constraints.copy()
                        for cc in c:
                            new_constraints[new_constraints.index(cc)] = None
                        new_constraints = dpll_apply_assignments(executor, new_constraints, new_assignments)
                        return dfs(new_constraints, new_assignments)
                    except CSPNotSolvable:
                        pass
            raise CSPNotSolvable()
        elif (next_gen_vars := dpll_find_gen_variable_combined(executor, csp, constraints, assignments)) is not None:
            # Entries are (constraint(s), generator, arguments, output_targets).
            if len(next_gen_vars) > 0 and next_gen_vars[0][1].unsolvable:
                raise CSPNotSolvable()
            for c, g, args, outputs_target in next_gen_vars:
                generator = iter(generator_manager.call(g, max_generator_trials, tuple(args), c))
                for _ in range(max_generator_trials):
                    try:
                        _output_index, outputs = next(generator)
                    except StopIteration:
                        break
                    if outputs is None:
                        break
                    if not isinstance(outputs, tuple) and g.function.ftype.is_singular_return:
                        outputs = (outputs, )
                    if not g.function.ftype.is_singular_return:
                        assert len(outputs) == len(g.function.return_type)
                    new_assignments = assignments.copy()
                    for output, target in zip(outputs, outputs_target):
                        if isinstance(target, ListValue):
                            assert isinstance(output, ListValue)
                            assert len(output) == len(target)
                            for k, v in zip(target.values, output.values):
                                new_assignments[k.identifier] = Assignment(AssignmentType.VALUE, v)
                        else:
                            new_assignments[target.identifier] = Assignment(AssignmentType.VALUE, output)
                    try:
                        new_constraints = constraints.copy()
                        if isinstance(c, list):
                            for cc in c:
                                new_constraints[new_constraints.index(cc)] = None
                        elif c is not None:
                            # `c` is None for the type-generator fallback of dpll_find_gen_variable_combined;
                            # no constraint is consumed then (`index(None)` would raise ValueError).
                            new_constraints[new_constraints.index(c)] = None
                            if isinstance(c, GroupConstraint):
                                # Also consume the member constraints belonging to this group.
                                for i, cc in enumerate(new_constraints):
                                    if cc is not None and not cc.is_group_constraint and cc.group is not None and id(cc.group) == id(c):
                                        new_constraints[i] = None
                        new_constraints = dpll_apply_assignments(executor, new_constraints, new_assignments)
                        return dfs(new_constraints, new_assignments)
                    except CSPNotSolvable:
                        pass
            raise CSPNotSolvable()
        else:
            raise CSPNoGenerator('Can not find a generator. Constraints:\n  ' + '\n  '.join([str(x) for x in constraints]))

    try:
        assignments = dfs(constraints, {})
        if solvable_only:
            return True
        # Fill in variables that the search never bound, using their type-based generators.
        for name, record in csp.index2record.items():
            dtype = record.dtype
            if name not in assignments:
                g = dpll_find_typegen_variable(executor, dtype)
                if g is None:
                    raise NotImplementedError('Can not find a generator for unbounded variable {}, type {}.'.format(name, dtype))
                output, = executor.get_function_implementation(g.function.name)()
                assignments[name] = Assignment(AssignmentType.VALUE, TensorValue.from_scalar(output, dtype))
        return assignments
    except CSPNotSolvable:
        return None


def csp_dpll_simplify(
    executor: PDSketchExecutor,
    csp: ConstraintSatisfactionProblem,
    enable_ignore: bool = True, return_assignments: bool = False
) -> Union[ConstraintSatisfactionProblem, Tuple[ConstraintSatisfactionProblem, AssignmentDict]]:
    """Simplify the CSP using the deterministic filtering passes of the DPLL algorithm (no branching).

    The passes are repeated until one full round no longer reduces the number of constraints.

    Args:
        executor: the executor.
        csp: the CSP.
        enable_ignore: whether to ignore constraints whose RHS value is not determined.
        return_assignments: whether to also return the assignment dictionary that was passed through the
            filtering passes (presumably populated by them; it starts empty).

    Returns:
        The simplified CSP (a clone of `csp` with the remaining constraints). If `csp` has no constraints,
        `csp` itself is returned. When `return_assignments` is True, a ``(csp, assignments)`` tuple is returned
        in both cases.
    """
    constraints = csp.constraints.copy()
    assignments = dict()
    if len(constraints) == 0:
        # Honor the declared return shape even when there is nothing to simplify.
        if return_assignments:
            return csp, assignments
        return csp

    while True:
        nr_constraints = len(constraints)
        progress = True
        while progress:
            progress, constraints = dpll_filter_deterministic_equal(executor, constraints, assignments)
        if enable_ignore:
            constraints = dpll_filter_unused_rhs(executor, constraints, assignments, csp.index2record)
        progress = True
        while progress:
            progress, constraints = dpll_filter_deterministic_clauses(executor, constraints, assignments)
        # Stop once a full round of filtering no longer removes any constraint.
        if len(constraints) == nr_constraints:
            break

    if return_assignments:
        return csp.clone(constraints), assignments
    return csp.clone(constraints)