Source code for concepts.benchmark.logic_induction.graph_dataset

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : graph_dataset.py
# Author : Honghua Dong
# Email  : dhh19951@gmail.com
# Date   : 05/07/2018
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.

import numpy as np

from torch.utils.data.dataset import Dataset
from torchvision import datasets

import jacinle.random as random

from concepts.benchmark.algorithm_env.graph import random_generate_graph, random_generate_graph_dnc

# Public task datasets exported by this module (GraphDatasetBase is not listed).
__all__ = [
    'GraphOutDegreeDataset',
    'GraphConnectivityDataset',
    'GraphAdjacentDataset',
]


class GraphDatasetBase(Dataset):
    """Base class for datasets whose items are generated graphs.

    A graph is generated on every access. The index ``item`` only selects the
    node count, cycling through the configured ``(min, max)`` range, so
    ``epoch_size`` just controls the length reported by ``__len__``.

    Args:
        nr_nodes: a scalar ``n`` (graphs then have between ``max(n // 2, 1)``
            and ``n`` nodes) or a ``(min, max)`` pair. Python and NumPy
            scalars are both accepted.
        p: parameter forwarded to the graph generator (presumably an edge
            probability -- confirm against the generator). Either a scalar,
            used as-is, or a ``(low, high)`` pair from which a value is drawn
            per graph as ``low + random.rand() * (high - low)``.
        epoch_size: value returned by ``__len__``.
        directed: forwarded to the graph generator.
        gen_method: ``'dnc'`` selects ``random_generate_graph_dnc``; any other
            value selects ``random_generate_graph``.

    Raises:
        ValueError: if a ``(min, max)`` pair for ``nr_nodes`` does not have
            exactly two entries or has ``min > max``.
    """

    def __init__(self, nr_nodes, p, epoch_size, directed=False, gen_method='dnc'):
        # np.ndim() is 0 for Python *and* NumPy scalars, so e.g. np.int64 is
        # accepted here (a strict ``type(x) is int`` check rejected it).
        if np.ndim(nr_nodes) == 0:
            nr_nodes = int(nr_nodes)
            self.nr_nodes = (max(nr_nodes // 2, 1), nr_nodes)
        else:
            self.nr_nodes = tuple(nr_nodes)
        if len(self.nr_nodes) != 2:
            raise ValueError('nr_nodes must be a scalar or a (min, max) pair, got {!r}.'.format(nr_nodes))
        if self.nr_nodes[0] > self.nr_nodes[1]:
            # Otherwise the modulus in _gen_graph is zero or negative.
            raise ValueError('nr_nodes range must satisfy min <= max, got {!r}.'.format(self.nr_nodes))
        self.p = p
        self.epoch_size = epoch_size
        self.directed = directed
        self.gen_method = gen_method

    def _sample_p(self):
        """Return the generator parameter ``p`` to use for the next graph."""
        # Scalars (float, int, np.float64, ...) are used unchanged; anything
        # else is treated as a (low, high) range.
        if np.ndim(self.p) == 0:
            return self.p
        low, high = self.p[0], self.p[1]
        return low + random.rand() * (high - low)

    def _gen_graph(self, item):
        """Generate the graph for index ``item``."""
        lo, hi = self.nr_nodes
        # Deterministically cycle the node count through [lo, hi] by index.
        nr_nodes = item % (hi - lo + 1) + lo
        p = self._sample_p()
        gen_graph = random_generate_graph_dnc if self.gen_method == 'dnc' else random_generate_graph
        return gen_graph(nr_nodes, p, directed=self.directed)

    def __len__(self):
        return self.epoch_size
class GraphOutDegreeDataset(GraphDatasetBase):
    """Per-node task: does the node's out-degree equal ``degree``?

    Each item is a dict with:
        n: ``graph._nr_nodes``.
        relations: ``graph.get_edges()`` with a trailing singleton axis added.
        target: float array, 1.0 where ``graph.get_out_degree() == degree``
            and 0.0 elsewhere.

    Args:
        degree: out-degree value the target tests for.
        Remaining arguments are forwarded unchanged to ``GraphDatasetBase``.
    """

    def __init__(self, nr_nodes, p, epoch_size, degree=2, directed=False, gen_method='dnc'):
        super().__init__(nr_nodes, p, epoch_size, directed=directed, gen_method=gen_method)
        self.degree = degree

    def __getitem__(self, item):
        graph = self._gen_graph(item)
        node_count = graph._nr_nodes
        edge_rel = np.expand_dims(graph.get_edges(), axis=-1)
        hits = graph.get_out_degree() == self.degree
        return {
            'n': node_count,
            'relations': edge_rel,
            'target': hits.astype('float'),
        }
class GraphConnectivityDataset(GraphDatasetBase):
    """Pairwise task: the target is ``graph.get_connectivity`` of a generated graph.

    Each item is a dict with:
        n: ``graph._nr_nodes``.
        relations: ``graph.get_edges()`` with a trailing singleton axis added.
        target: ``graph.get_connectivity(dist_limit, exclude_self=True)``
            (presumably pairwise reachability within ``dist_limit`` steps,
            with ``None`` meaning no limit -- confirm against the Graph class).

    Args:
        dist_limit: forwarded to ``graph.get_connectivity``.
        Remaining arguments are forwarded unchanged to ``GraphDatasetBase``.
    """

    def __init__(self, nr_nodes, p, epoch_size, dist_limit=None, directed=False, gen_method='dnc'):
        super().__init__(nr_nodes, p, epoch_size, directed=directed, gen_method=gen_method)
        self.dist_limit = dist_limit

    def __getitem__(self, item):
        graph = self._gen_graph(item)
        sample = {'n': graph._nr_nodes}
        sample['relations'] = np.expand_dims(graph.get_edges(), axis=-1)
        sample['target'] = graph.get_connectivity(self.dist_limit, exclude_self=True)
        return sample
class GraphAdjacentDataset(GraphDatasetBase):
    """Per-node, per-color task: which colors occur at a node or its neighbours.

    Every node of a generated graph gets a random color. ``target[i, c]`` is 1
    when node ``i`` itself has color ``c`` or when ``graph.has_edge(i, j)``
    holds for some node ``j`` with color ``c``; otherwise it is 0.

    Each item is a dict with:
        n: ``graph._nr_nodes``.
        relations: ``graph.get_edges()`` with a trailing singleton axis added.
        states: an ``(n, nr_colors)`` one-hot color matrix, or, in MNIST mode,
            the stacked MNIST images (each with a leading axis added).
        colors: the per-node color indices.
        target: the ``(n, nr_colors)`` 0/1 matrix described above.

    Args:
        nr_colors: number of distinct colors.
        is_mnist_colors: if true, colors are the labels of randomly drawn
            MNIST samples and ``states`` holds their images; requires
            ``nr_colors == 10``.
        is_train: selects the MNIST training (True) or test split.
        mnist_root: root directory passed to ``torchvision.datasets.MNIST``
            (downloaded there if missing); defaults to ``'../data'``.
        Remaining arguments are forwarded unchanged to ``GraphDatasetBase``.

    Raises:
        ValueError: if ``is_mnist_colors`` is set and ``nr_colors != 10``.
    """

    def __init__(self, nr_nodes, p, epoch_size, nr_colors, directed=False, gen_method='dnc',
                 is_mnist_colors=False, is_train=True, mnist_root='../data'):
        super().__init__(nr_nodes, p, epoch_size, directed, gen_method)
        self._nr_colors = nr_colors
        self._mnist_colors = is_mnist_colors
        if is_mnist_colors:
            # MNIST labels take 10 values; a real check (not ``assert``) so it
            # still fires under ``python -O``.
            if nr_colors != 10:
                raise ValueError('MNIST colors require nr_colors == 10, got {}.'.format(nr_colors))
            self.mnist = datasets.MNIST(mnist_root, train=is_train, download=True, transform=None)

    def _sample_mnist_colors(self, n):
        """Draw ``n`` random MNIST samples; return ``(images, labels)`` as arrays."""
        size = len(self.mnist)
        digits, colors = [], []
        for _ in range(n):
            digit, color = self.mnist[random.randint(size)]
            # Add a leading axis to each image before stacking.
            digits.append(np.array(digit)[np.newaxis])
            colors.append(color)
        return np.array(digits), np.array(colors)

    def __getitem__(self, item):
        graph = self._gen_graph(item)
        n = graph._nr_nodes
        if self._mnist_colors:
            states, colors = self._sample_mnist_colors(n)
        else:
            colors = random.randint(self._nr_colors, size=n)
            states = np.zeros((n, self._nr_colors))
            states[np.arange(n), colors] = 1

        adjacent = np.zeros((n, self._nr_colors))
        for i in range(n):
            # A node always "sees" its own color.
            adjacent[i, colors[i]] = 1
            for j in range(n):
                if graph.has_edge(i, j):
                    adjacent[i, colors[j]] = 1

        return dict(
            n=n,
            relations=np.expand_dims(graph.get_edges(), axis=-1),
            states=states,
            colors=colors,
            target=adjacent,
        )