Tutorial 1.4: Use Enumerative Search to Learn a Function

[1]:
# From tutorial/1-dsl/1-types-and-functions
from concepts.dsl.dsl_types import ValueType, ConstantType, BOOL, FLOAT32, INT64, VectorValueType, FormatContext
from concepts.dsl.dsl_functions import Function, FunctionTyping
from concepts.dsl.function_domain import FunctionDomain

# Value types of the toy visual-reasoning DSL. `concept_name` is a ConstantType;
# its constants ('box', 'sphere', 'red', ...) are registered at the end of this cell.
t_item = ValueType('item')
t_item_set = ValueType('item_set')
t_concept_name = ConstantType('concept_name')
t_shape = ValueType('shape')
t_color = ValueType('color')
t_size = VectorValueType(FLOAT32, 3, alias='size')
t_int = INT64  # built-in integer type; not passed to define_type below

domain = FunctionDomain()

# Register the custom types, in the same order as before.
for dsl_type in (t_item, t_item_set, t_concept_name, t_color, t_shape, t_size):
    domain.define_type(dsl_type)

# Primitive signatures, written as FunctionTyping[return_type](*argument_types).
for primitive in (
    Function('scene', FunctionTyping[t_item_set]()),
    Function('filter_color', FunctionTyping[t_item_set](t_item_set, t_concept_name)),
    Function('filter_shape', FunctionTyping[t_item_set](t_item_set, t_concept_name)),
    Function('unique', FunctionTyping[t_item](t_item_set)),
    Function('color_of', FunctionTyping[t_color](t_item)),
    Function('shape_of', FunctionTyping[t_shape](t_item)),
    Function('size_of', FunctionTyping[t_size](t_item)),
    Function('same_color', FunctionTyping[BOOL](t_color, t_color)),
    Function('same_shape', FunctionTyping[BOOL](t_shape, t_shape)),
    Function('same_size', FunctionTyping[BOOL](t_size, t_size)),
    Function('count', FunctionTyping[t_int](t_item_set)),
):
    domain.define_function(primitive)

# Concept-name constants (shape and color names share the same constant type).
for concept in ('box', 'sphere', 'red', 'blue', 'green'):
    domain.define_const(t_concept_name, concept)
[2]:
# From tutorial/1-dsl/2-execution
from dataclasses import dataclass, field
from typing import Tuple, List
from concepts.dsl.executors.function_domain_executor import FunctionDomainExecutor

@dataclass
class Item:
    """One object in a scene, described by its color, shape and size."""

    color: str  # color name, e.g. 'red'
    shape: str  # shape name, e.g. 'box'
    size: Tuple[float, float, float]  # three-component size, compared per component by `same_size`


@dataclass
class Scene:
    """A grounding for the executor: the list of items returned by `scene()`."""

    items: List[Item]

class Executor(FunctionDomainExecutor):
    """Executor that grounds the toy DSL on a `Scene`.

    Each method whose name matches a function defined in `domain` implements that
    primitive; creating the executor registers them (see the log output below).
    The scene passed as `grounding=` to `execute` is available as `self.grounding`.
    """

    # Maximum per-component difference for two sizes to be considered the same.
    SIZE_TOLERANCE = 0.1

    def scene(self):
        """Return all items of the current grounding scene."""
        return self.grounding.items

    def filter_color(self, inputs, color_name):
        """Return the items of `inputs` whose color equals `color_name`."""
        return [o for o in inputs if o.color == color_name]

    def filter_shape(self, inputs, shape_name):
        """Return the items of `inputs` whose shape equals `shape_name`."""
        return [o for o in inputs if o.shape == shape_name]

    def unique(self, inputs):
        """Return the only item of `inputs`; raise AssertionError unless there is exactly one."""
        assert len(inputs) == 1, f'unique() expects exactly one item, got {len(inputs)}.'
        return inputs[0]

    def color_of(self, obj):
        """Return the color of `obj`."""
        return obj.color

    def shape_of(self, obj):
        """Return the shape of `obj`."""
        return obj.shape

    def size_of(self, obj):
        """Return the size of `obj`."""
        return obj.size

    def same_color(self, c1, c2):
        """Return True if the two colors are equal."""
        return c1 == c2

    def same_shape(self, s1, s2):
        """Return True if the two shapes are equal."""
        return s1 == s2

    def same_size(self, z1, z2):
        """Return True if both sizes have the same length and every component differs by less than SIZE_TOLERANCE."""
        # zip() silently truncates to the shorter input, which would let a
        # mismatched-length pair compare as "same"; reject it explicitly.
        if len(z1) != len(z2):
            return False
        return all(abs(sz1 - sz2) < self.SIZE_TOLERANCE for sz1, sz2 in zip(z1, z2))

    def count(self, inputs):
        """Return the number of items in `inputs`."""
        return len(inputs)

executor = Executor(domain)
15 16:55:46 Function scene automatically registered.
15 16:55:46 Function filter_color automatically registered.
15 16:55:46 Function filter_shape automatically registered.
15 16:55:46 Function unique automatically registered.
15 16:55:46 Function color_of automatically registered.
15 16:55:46 Function shape_of automatically registered.
15 16:55:46 Function size_of automatically registered.
15 16:55:46 Function same_color automatically registered.
15 16:55:46 Function same_shape automatically registered.
15 16:55:46 Function same_size automatically registered.
15 16:55:46 Function count automatically registered.
[3]:
# scene1: one red, one blue and one (larger) green box.
scene1 = Scene(items=[
    Item(color='red', shape='box', size=(1, 1, 1)),
    Item(color='blue', shape='box', size=(1, 1, 1)),
    Item(color='green', shape='box', size=(2, 2, 2)),
])

# scene2: two identical red boxes.
scene2 = Scene(items=[
    Item(color='red', shape='box', size=(1, 1, 1)),
    Item(color='red', shape='box', size=(1, 1, 1)),
])
[4]:
# The program we will later try to recover: count the red items in the scene.
target_expr = domain.f_count(
    domain.f_filter_color(domain.f_scene(), 'red')
)
print(target_expr)

# Evaluate the same expression against each example scene.
for label, example_scene in (('scene1:', scene1), ('scene2:', scene2)):
    print(label, executor.execute(target_expr, grounding=example_scene))
count(filter_color(scene(), V(red, dtype=concept_name)))
scene1: V(1, dtype=int64)
scene2: V(2, dtype=int64)
[5]:
from concepts.dsl.learning.function_domain_search import FunctionDomainExpressionEnumerativeSearcher
[6]:
enumerator = FunctionDomainExpressionEnumerativeSearcher(domain)

# Enumerate candidate expressions returning `t_int`, up to depth 3, letting the
# searcher fill constant arguments (the concept names) itself.
candidate_expressions = enumerator.gen_function_application_expressions(
    return_type=t_int, max_depth=3, search_constants=True,
)

lambda_format = FormatContext(function_format_lambda=True)
with lambda_format.as_default():
    for candidate in candidate_expressions:
        print(candidate.expression)

print(f'In total: {len(candidate_expressions)} candidate expressions.')
count(scene())
count(filter_color(scene(), V(box, dtype=concept_name)))
count(filter_color(scene(), V(sphere, dtype=concept_name)))
count(filter_color(scene(), V(red, dtype=concept_name)))
count(filter_color(scene(), V(blue, dtype=concept_name)))
count(filter_color(scene(), V(green, dtype=concept_name)))
count(filter_shape(scene(), V(box, dtype=concept_name)))
count(filter_shape(scene(), V(sphere, dtype=concept_name)))
count(filter_shape(scene(), V(red, dtype=concept_name)))
count(filter_shape(scene(), V(blue, dtype=concept_name)))
count(filter_shape(scene(), V(green, dtype=concept_name)))
In total: 11 candidate expressions.
[7]:
from concepts.dsl.learning.function_domain_search import learn_expression_from_examples
[8]:
# One example per scene: ([], expected output of the target, grounding scene).
# The empty list presumably holds the (zero) explicit arguments of the expression.
io_examples = [
    ([], executor.execute(target_expr, grounding=example_scene), example_scene)
    for example_scene in (scene1, scene2)
]
[9]:
# Learn from the explicitly enumerated candidates; two outputs are compared by
# their underlying `.value`. The result shown below is the target expression.
learn_expression_from_examples(
    domain,
    executor,
    candidate_expressions=candidate_expressions,
    input_output=io_examples,
    criterion=lambda lhs, rhs: lhs.value == rhs.value,
)
[9]:
FunctionApplicationExpression<count(filter_color(scene(), V(red, dtype=concept_name)))>
[10]:
# Same search, but without supplying candidates: outputs are again compared
# by their underlying `.value`.
learn_expression_from_examples(
    domain,
    executor,
    criterion=lambda lhs, rhs: lhs.value == rhs.value,
    input_output=io_examples,
    candidate_expressions=None,  # The algorithm will automatically infer the target type.
)
[10]:
FunctionApplicationExpression<count(filter_color(scene(), V(red, dtype=concept_name)))>