Source code for concepts.dsl.tensor_value_utils

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : tensor_value_utils.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 06/18/2024
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import torch
from typing import Union, Sequence, List

from concepts.dsl.tensor_value import TensorValue
from concepts.dsl.tensor_state import StateObjectReference

__all__= ['expand_argument_values']


[docs] def expand_argument_values( argument_values: Sequence[Union[TensorValue, int, str, slice, StateObjectReference]], handle_wildcard: bool = False ) -> List[TensorValue]: """Expand a list of argument values to the same batch size. Args: argument_values: a list of argument values. handle_wildcard: whether to handle the wildcard variable '??'. If set to True, the function will return the original argument values without expanding them if any of the argument values contains the wildcard variable. Returns: the result list of argument values. All return values will have the same batch size. """ if handle_wildcard: has_slot_var = False for arg in argument_values: if isinstance(arg, TensorValue): for var in arg.batch_variables: if var == '??': has_slot_var = True break if has_slot_var: return list(argument_values) if len(argument_values) < 2: return list(argument_values) argument_values = list(argument_values) batch_variables = list() batch_sizes = list() for arg in argument_values: if isinstance(arg, TensorValue): for var in arg.batch_variables: if var not in batch_variables: batch_variables.append(var) batch_sizes.append(arg.get_variable_size(var)) else: assert isinstance(arg, (int, str, slice, StateObjectReference)), arg masks = list() for i, arg in enumerate(argument_values): if isinstance(arg, TensorValue): argument_values[i] = arg.expand(batch_variables, batch_sizes) if argument_values[i].tensor_mask is not None: masks.append(argument_values[i].tensor_mask) if len(masks) > 0: final_mask = torch.stack(masks, dim=-1).amin(dim=-1) for arg in argument_values: if isinstance(arg, TensorValue): arg.tensor_mask = final_mask arg._mask_certified_flag = True # now we have corrected the mask. return argument_values