Source code for concepts.dsl.expression_visitor

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : expression_visitor.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 10/30/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""A visitor for iterating over expressions."""

from typing import Any, Union

import jacinle
import concepts.dsl.expression as E
from concepts.dsl.expression import Expression

# Public names exported by ``from concepts.dsl.expression_visitor import *``.
__all__ = ['ExpressionVisitor', 'IdentityExpressionVisitor']


class ExpressionVisitor(object):
    """A visitor for iterating over expressions.

    :meth:`visit` dispatches on the runtime type of the expression to one of the
    ``visit_*`` hook methods below. Every hook raises :class:`NotImplementedError`
    by default; subclasses override the hooks for the expression kinds they handle.
    """

    def visit(self, expr: Expression) -> Any:
        """The main entry point of the visitor. It will call the corresponding method for the given expression type.

        The ``isinstance`` checks are tried in order and the first match wins, so the
        order is significant. Some pairs (e.g. ``ObjectConstantExpression`` before
        ``ConstantExpression``, ``ListFunctionApplicationExpression`` before
        ``FunctionApplicationExpression``) appear to rely on the more specific type
        being tested first. Do not reorder without checking the class hierarchy.

        Args:
            expr: the expression to visit.

        Returns:
            the result of the visit.

        Raises:
            TypeError: if ``expr`` is not one of the supported expression types.
        """
        if isinstance(expr, E.NullExpression):
            return self.visit_null_expression(expr)
        elif isinstance(expr, E.VariableExpression):
            return self.visit_variable_expression(expr)
        elif isinstance(expr, E.ObjectConstantExpression):
            return self.visit_object_constant_expression(expr)
        elif isinstance(expr, E.ConstantExpression):
            return self.visit_constant_expression(expr)
        elif isinstance(expr, E.ListCreationExpression):
            return self.visit_list_creation_expression(expr)
        elif isinstance(expr, E.ListExpansionExpression):
            return self.visit_list_expansion_expression(expr)
        elif isinstance(expr, E.ListFunctionApplicationExpression):
            return self.visit_list_function_application_expression(expr)
        elif isinstance(expr, E.FunctionApplicationExpression):
            return self.visit_function_application_expression(expr)
        elif isinstance(expr, E.ConditionalSelectExpression):
            return self.visit_conditional_select_expression(expr)
        elif isinstance(expr, E.DeicticSelectExpression):
            return self.visit_deictic_select_expression(expr)
        elif isinstance(expr, E.BoolExpression):
            return self.visit_bool_expression(expr)
        elif isinstance(expr, E.QuantificationExpression):
            return self.visit_quantification_expression(expr)
        elif isinstance(expr, E.GeneralizedQuantificationExpression):
            return self.visit_generalized_quantification_expression(expr)
        elif isinstance(expr, E.FindOneExpression):
            # ``visit_find_one_expression`` is a declared hook (and is implemented by
            # IdentityExpressionVisitor) but was previously never dispatched to here.
            return self.visit_find_one_expression(expr)
        elif isinstance(expr, E.FindAllExpression):
            return self.visit_find_all_expression(expr)
        elif isinstance(expr, E.ObjectCompareExpression):
            return self.visit_object_compare_expression(expr)
        elif isinstance(expr, E.ValueCompareExpression):
            return self.visit_value_compare_expression(expr)
        elif isinstance(expr, E.ConditionExpression):
            return self.visit_condition_expression(expr)
        elif isinstance(expr, E.PredicateEqualExpression):
            return self.visit_predicate_equal_expression(expr)
        elif isinstance(expr, E.AssignExpression):
            return self.visit_assign_expression(expr)
        elif isinstance(expr, E.ConditionalAssignExpression):
            return self.visit_conditional_assign_expression(expr)
        elif isinstance(expr, E.DeicticAssignExpression):
            return self.visit_deictic_assign_expression(expr)
        else:
            raise TypeError(f'Unknown expression type: {type(expr)}.')

    def visit_null_expression(self, expr: E.NullExpression) -> Any:
        """Hook for :class:`NullExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_variable_expression(self, expr: E.VariableExpression) -> Any:
        """Hook for :class:`VariableExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_object_constant_expression(self, expr: E.ObjectConstantExpression) -> Any:
        """Hook for :class:`ObjectConstantExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_constant_expression(self, expr: E.ConstantExpression) -> Any:
        """Hook for :class:`ConstantExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_list_creation_expression(self, expr: E.ListCreationExpression) -> Any:
        """Hook for :class:`ListCreationExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_list_expansion_expression(self, expr: E.ListExpansionExpression) -> Any:
        """Hook for :class:`ListExpansionExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_function_application_expression(self, expr: E.FunctionApplicationExpression) -> Any:
        """Hook for :class:`FunctionApplicationExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_list_function_application_expression(self, expr: E.ListFunctionApplicationExpression) -> Any:
        """Hook for :class:`ListFunctionApplicationExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_conditional_select_expression(self, expr: E.ConditionalSelectExpression) -> Any:
        """Hook for :class:`ConditionalSelectExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_deictic_select_expression(self, expr: E.DeicticSelectExpression) -> Any:
        """Hook for :class:`DeicticSelectExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_bool_expression(self, expr: E.BoolExpression) -> Any:
        """Hook for :class:`BoolExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_quantification_expression(self, expr: E.QuantificationExpression) -> Any:
        """Hook for :class:`QuantificationExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_generalized_quantification_expression(self, expr: E.GeneralizedQuantificationExpression) -> Any:
        """Hook for :class:`GeneralizedQuantificationExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_find_one_expression(self, expr: E.FindOneExpression) -> Any:
        """Hook for :class:`FindOneExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_find_all_expression(self, expr: E.FindAllExpression) -> Any:
        """Hook for :class:`FindAllExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_object_compare_expression(self, expr: E.ObjectCompareExpression) -> Any:
        """Hook for :class:`ObjectCompareExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_value_compare_expression(self, expr: E.ValueCompareExpression) -> Any:
        """Hook for :class:`ValueCompareExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_condition_expression(self, expr: E.ConditionExpression) -> Any:
        """Hook for :class:`ConditionExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_predicate_equal_expression(self, expr: E.PredicateEqualExpression) -> Any:
        """Hook for :class:`PredicateEqualExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_assign_expression(self, expr: E.AssignExpression) -> Any:
        """Hook for :class:`AssignExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_conditional_assign_expression(self, expr: E.ConditionalAssignExpression) -> Any:
        """Hook for :class:`ConditionalAssignExpression`; subclasses must override."""
        raise NotImplementedError()

    def visit_deictic_assign_expression(self, expr: E.DeicticAssignExpression) -> Any:
        """Hook for :class:`DeicticAssignExpression`; subclasses must override."""
        raise NotImplementedError()
class IdentityExpressionVisitor(ExpressionVisitor):
    """An expression visitor that rebuilds the expression tree it is given.

    Composite nodes are re-created from the results of recursively visiting their
    children, so a subclass that overrides one ``visit_*`` hook has that override
    applied wherever the corresponding expression kind appears. Null, constant and
    object-constant nodes are returned as-is (the very same object).
    """

    # --- Leaf nodes -------------------------------------------------------------

    def visit_null_expression(self, expr: E.NullExpression) -> E.NullExpression:
        """A null expression has no children; hand back the node itself."""
        return expr

    def visit_constant_expression(self, expr: Expression) -> Expression:
        """Constants have no children; hand back the node itself."""
        return expr

    def visit_object_constant_expression(self, expr: Expression) -> Expression:
        """Object constants have no children; hand back the node itself."""
        return expr

    def visit_variable_expression(self, expr: E.VariableExpression) -> E.VariableExpression:
        """Build a fresh node of the same class referring to the same variable."""
        node_cls = type(expr)
        return node_cls(expr.variable)

    # --- Function application and lists -------------------------------------------

    def visit_function_application_expression(
        self,
        expr: Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression]
    ) -> Union[E.FunctionApplicationExpression, E.ListFunctionApplicationExpression]:
        """Rebuild a function application with every argument visited in order."""
        node_cls = type(expr)
        function = expr.function
        visited_args = list(map(self.visit, expr.arguments))
        return node_cls(function, visited_args)

    def visit_list_function_application_expression(self, expr: E.ListFunctionApplicationExpression) -> E.ListFunctionApplicationExpression:
        """Rebuild a list function application with every argument visited in order."""
        node_cls = type(expr)
        function = expr.function
        visited_args = list(map(self.visit, expr.arguments))
        return node_cls(function, visited_args)

    def visit_list_creation_expression(self, expr: E.ListCreationExpression) -> E.ListCreationExpression:
        """Rebuild a list literal from its visited elements."""
        node_cls = type(expr)
        items = list(map(self.visit, expr.arguments))
        return node_cls(items)

    def visit_list_expansion_expression(self, expr: E.ListExpansionExpression) -> E.ListExpansionExpression:
        """Rebuild a list expansion around its visited inner expression."""
        node_cls = type(expr)
        inner = self.visit(expr.expression)
        return node_cls(inner)

    # --- Logic and quantifiers ---------------------------------------------------
    # These deliberately construct the named base class rather than ``type(expr)``.

    def visit_bool_expression(self, expr: E.BoolExpression) -> E.BoolExpression:
        """Rebuild a boolean connective over its visited operands."""
        operands = list(map(self.visit, expr.arguments))
        return E.BoolExpression(expr.bool_op, operands)

    def visit_quantification_expression(self, expr: E.QuantificationExpression) -> E.QuantificationExpression:
        """Rebuild a quantifier over the same variable with a visited body."""
        body = self.visit(expr.expression)
        return E.QuantificationExpression(expr.quantification_op, expr.variable, body)

    def visit_generalized_quantification_expression(self, expr: E.GeneralizedQuantificationExpression) -> E.GeneralizedQuantificationExpression:
        """Rebuild a generalized quantifier, keeping its declared return type."""
        body = self.visit(expr.expression)
        return E.GeneralizedQuantificationExpression(
            expr.quantification_op, expr.variable, body, return_type=expr.return_type
        )

    def visit_find_one_expression(self, expr: E.FindOneExpression) -> E.FindOneExpression:
        """Rebuild a find-one over the same variable with a visited body."""
        body = self.visit(expr.expression)
        return E.FindOneExpression(expr.variable, body)

    def visit_find_all_expression(self, expr: E.FindAllExpression) -> E.FindAllExpression:
        """Rebuild a find-all over the same variable with a visited body."""
        body = self.visit(expr.expression)
        return E.FindAllExpression(expr.variable, body)

    # --- Comparisons and conditionals --------------------------------------------

    def visit_object_compare_expression(self, expr: E.ObjectCompareExpression) -> E.ObjectCompareExpression:
        """Rebuild an object comparison; the left side is visited before the right."""
        lhs = self.visit(expr.lhs)
        rhs = self.visit(expr.rhs)
        return E.ObjectCompareExpression(expr.compare_op, lhs, rhs)

    def visit_value_compare_expression(self, expr: E.ValueCompareExpression) -> E.ValueCompareExpression:
        """Rebuild a value comparison; the left side is visited before the right."""
        lhs = self.visit(expr.lhs)
        rhs = self.visit(expr.rhs)
        return E.ValueCompareExpression(expr.compare_op, lhs, rhs)

    def visit_condition_expression(self, expr: E.ConditionExpression) -> Any:
        """Rebuild an if-then-else node from its visited condition and branches."""
        node_cls = type(expr)
        condition = self.visit(expr.condition)
        when_true = self.visit(expr.true_value)
        when_false = self.visit(expr.false_value)
        return node_cls(condition, when_true, when_false)

    def visit_predicate_equal_expression(self, expr: E.PredicateEqualExpression) -> E.PredicateEqualExpression:
        """Rebuild a predicate-equality node from its visited predicate and value."""
        node_cls = type(expr)
        predicate = self.visit(expr.predicate)
        value = self.visit(expr.value)
        return node_cls(predicate, value)

    # --- Selection and assignment ------------------------------------------------

    def visit_conditional_select_expression(self, expr: E.ConditionalSelectExpression) -> E.ConditionalSelectExpression:
        """Rebuild a conditional select from its visited predicate and condition."""
        node_cls = type(expr)
        predicate = self.visit(expr.predicate)
        condition = self.visit(expr.condition)
        return node_cls(predicate, condition)

    def visit_deictic_select_expression(self, expr: E.DeicticSelectExpression) -> E.DeicticSelectExpression:
        """Rebuild a deictic select over the same variable with a visited body."""
        node_cls = type(expr)
        body = self.visit(expr.expression)
        return node_cls(expr.variable, body)

    def visit_assign_expression(self, expr: E.AssignExpression) -> E.AssignExpression:
        """Rebuild an assignment from its visited target predicate and value."""
        node_cls = type(expr)
        predicate = self.visit(expr.predicate)
        value = self.visit(expr.value)
        return node_cls(predicate, value)

    def visit_conditional_assign_expression(self, expr: E.ConditionalAssignExpression) -> E.ConditionalAssignExpression:
        """Rebuild a conditional assignment; visits predicate, value, then condition."""
        node_cls = type(expr)
        predicate = self.visit(expr.predicate)
        value = self.visit(expr.value)
        condition = self.visit(expr.condition)
        return node_cls(predicate, value, condition)

    def visit_deictic_assign_expression(self, expr: E.DeicticAssignExpression) -> E.DeicticAssignExpression:
        """Rebuild a deictic assignment over the same variable with a visited body."""
        node_cls = type(expr)
        # NOTE(review): this node's child is read from ``.expr`` whereas
        # DeicticSelectExpression uses ``.expression``; kept as in the original.
        body = self.visit(expr.expr)
        return node_cls(expr.variable, body)