Source code for concepts.dm.crow.planners.regression_dependency

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : regression_dependency.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 08/5/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

"""Dependencies for the regression planning."""

import os
import tempfile
from dataclasses import dataclass, field
from typing import Optional, Sequence, Tuple, List, Dict

from jacinle.utils.printing import indent_text
from concepts.dm.crow.planners.regression_planning import SupportedCrowExpressionType
from concepts.dm.crow.behavior_utils import format_behavior_statement


@dataclass(unsafe_hash=True)
class RegressionTraceStatement(object):
    """A single statement recorded in a regression planning trace.

    Instances are hashable so they can be used as dictionary keys (e.g. in ``RegressionDependencyGraph``).
    Equality compares all fields. The hash excludes ``scope``: it is a ``dict`` and therefore unhashable, and
    leaving it out of the hash keeps ``a == b`` implying ``hash(a) == hash(b)``.
    """

    # The recorded statement.
    stmt: 'SupportedCrowExpressionType'
    # Id of the scope the statement belongs to; used for formatting when ``new_scope_id`` is None.
    scope_id: Optional[int] = None
    # Id of a scope introduced by this statement, if any; takes precedence over ``scope_id`` when formatting.
    new_scope_id: Optional[int] = None
    # Free-form note appended to the formatted string as a ``note:`` line.
    additional_info: Optional[str] = None
    # The scope dict associated with the statement (not hashed; see class docstring).
    scope: Optional[dict] = field(default=None, hash=False)
    # Statement this one was derived from; printed as a ``derived from:`` line.
    derived_from: Optional['SupportedCrowExpressionType'] = None

    def node_string(self, scopes: Dict[int, dict]) -> str:
        """Format the statement for display.

        Args:
            scopes: the scope table, forwarded to ``format_behavior_statement``.

        Returns:
            the formatted statement, followed by a ``derived from:`` line if ``derived_from`` is set and a
            ``note:`` line if ``additional_info`` is set.
        """
        scope_id = self.new_scope_id if self.new_scope_id is not None else self.scope_id
        basic_fmt = format_behavior_statement(self.stmt, scopes, scope_id)
        if self.derived_from is not None:
            basic_fmt += '\n derived from: ' + indent_text(format_behavior_statement(self.derived_from, scopes, scope_id), 1, indent_first=False)
        if self.additional_info is not None:
            basic_fmt += '\n note: ' + self.additional_info
        return basic_fmt
class RegressionDependencyGraph(object):
    """A dependency graph over regression trace statements.

    Nodes are stored in insertion order in ``nodes``. ``node2index`` maps each node back to its index, and
    ``edges`` maps a parent node to the list of indices of its children.
    """

    nodes: List['RegressionTraceStatement']
    node2index: Dict['RegressionTraceStatement', int]
    edges: Dict['RegressionTraceStatement', List[int]]

    def __init__(self, scopes: Dict[int, dict]):
        """Create an empty graph.

        Args:
            scopes: the scope table, passed to each node's ``node_string`` when formatting.
        """
        self.scopes = scopes
        self.nodes = list()
        self.node2index = dict()
        self.edges = dict()

    def add_node(self, node: 'RegressionTraceStatement') -> 'RegressionDependencyGraph':
        """Append a node and record its index. Returns ``self`` for chaining."""
        self.nodes.append(node)
        self.node2index[node] = len(self.nodes) - 1
        return self

    def connect(self, x: 'RegressionTraceStatement', y: 'RegressionTraceStatement') -> 'RegressionDependencyGraph':
        """Connect two nodes in the dependency graph. x is the "parent" of y.

        Args:
            x: the parent node.
            y: the child node; must already have been added with :meth:`add_node`.

        Returns:
            ``self``, for chaining.
        """
        self.edges.setdefault(x, list()).append(self.node2index[y])
        return self

    def print(self, i: int = 0, indent_level: int = 0) -> None:
        """Print the subtree rooted at node ``i``, each child indented one level deeper than its parent."""
        print(indent_text(f'{i}::' + self.nodes[i].node_string(self.scopes), indent_level))
        for child in self.edges.get(self.nodes[i], []):
            self.print(child, indent_level + 1)

    def sort_nodes_into_levels(self) -> List[List[int]]:
        """Group node indices by their height in the graph.

        A node without children has level 0. Any other node has level one more than the maximum level of its
        children. Nodes that are not reachable from node 0 get levels the same way, each acting as the root of
        its own subtree.

        Returns:
            a list whose ``k``-th entry holds, in increasing order, the indices of all nodes at level ``k``.
            Returns an empty list if the graph has no nodes.
        """
        levels: Dict[int, int] = dict()

        def dfs(i: int) -> int:
            # Memoized, so a node reachable through several parents is expanded only once.
            if i in levels:
                return levels[i]
            max_child_level = -1
            for child in self.edges.get(self.nodes[i], []):
                max_child_level = max(dfs(child), max_child_level)
            levels[i] = max_child_level + 1
            return levels[i]

        # Start from node 0 (the root), then cover any nodes it does not reach, so every node gets a level.
        for i in range(len(self.nodes)):
            if i not in levels:
                dfs(i)

        if not levels:
            return list()
        max_level = max(levels.values())
        return [[j for j in range(len(self.nodes)) if levels[j] == k] for k in range(max_level + 1)]

    def render_graphviz(self, filename: Optional[str] = None) -> None:
        """Render the graph with Graphviz.

        Args:
            filename: output path. ``.png`` and ``.pdf`` are rendered with Graphviz, and ``.dot`` receives only
                the DOT source. If None, the graph is rendered to a temporary PDF, which is then opened with the
                ``open`` command.

        Raises:
            ImportError: if the ``graphviz`` Python package is not installed.
            ValueError: if ``filename`` has an unsupported extension.
        """
        try:
            import graphviz
        except ImportError as e:
            raise ImportError('Please install graphviz first by running "pip install graphviz".') from e

        dot = graphviz.Digraph(comment='Regression Dependency Graph')
        for i, node in enumerate(self.nodes):
            # '\\l' is the two characters backslash + 'l' (Graphviz's left-justified line break). The original
            # '\l' literal had the same value but is an invalid Python escape sequence.
            dot.node(str(i), node.node_string(self.scopes).replace('\n', '\\l') + '\\l', shape='rectangle')

        # One invisible anchor node per level, linked to that level's nodes and chained from the highest level
        # down; presumably this is meant to order the layout by level.
        levels = self.sort_nodes_into_levels()
        for i, level in enumerate(levels):
            dot.node(f'level_{i}', '', ordering='out', style='invis')
            for j in level:
                dot.edge(f'level_{i}', str(j), style='invis')
        for i in range(len(levels) - 1, 0, -1):
            dot.edge(f'level_{i}', f'level_{i - 1}', style='invis')

        for x, ys in self.edges.items():
            for y in ys:
                dot.edge(str(self.node2index[x]), str(y))

        if filename is None:
            with tempfile.NamedTemporaryFile(suffix='.pdf') as f:
                dot.render(f.name[:-4], format='pdf', cleanup=True)
                print(f'Graphviz file saved to "{f.name}". Now opening it in the default PDF viewer...')
                # NOTE(review): `open` is the macOS file launcher; other platforms need a different command.
                os.system(f'open "{f.name}"')
                import time
                time.sleep(3)  # We need to sleep for a while to prevent the file from being deleted too early.
            return

        if filename.endswith('.png'):
            dot.render(filename[:-4], format='png', cleanup=True)
        elif filename.endswith('.pdf'):
            dot.render(filename[:-4], format='pdf', cleanup=True)
        elif filename.endswith('.dot'):
            # Write only the DOT source; `render` would also produce `<filename>.pdf` and need the Graphviz binary.
            dot.save(filename)
        else:
            raise ValueError(f'Unsupported file format: {filename}. Only PNG, PDF, and DOT are supported.')
        print(f'Graphviz file saved to "{filename}".')
def recover_graph_from_trace(trace: Sequence[RegressionTraceStatement], scopes: Dict[int, dict]) -> RegressionDependencyGraph:
    """Build a dependency graph from a linear regression trace.

    The first statement becomes node 0 (the root). Each later statement becomes a child of the most recent
    statement whose ``new_scope_id`` equals its ``scope_id``. Statements whose ``scope_id`` has not been opened
    that way are added as nodes without a parent.

    Args:
        trace: the trace statements, in order.
        scopes: the scope table, stored on the graph for formatting.

    Returns:
        the recovered graph. An empty ``trace`` yields an empty graph.
    """
    graph = RegressionDependencyGraph(scopes)
    if len(trace) == 0:
        return graph

    scope_to_node: Dict[Optional[int], RegressionTraceStatement] = dict()
    root = trace[0]
    graph.add_node(root)
    # The root is registered even when its new_scope_id is None (later statements are only registered if not None).
    scope_to_node[root.new_scope_id] = root

    for stmt in trace[1:]:
        graph.add_node(stmt)
        if stmt.scope_id in scope_to_node:
            graph.connect(scope_to_node[stmt.scope_id], stmt)
        if stmt.new_scope_id is not None:
            scope_to_node[stmt.new_scope_id] = stmt
    return graph