#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : graph_env.py
# Author : Honghua Dong
# Email  : dhh19951@gmail.com
# Date   : 04/27/2018
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
import numpy as np
import gym
from typing import Optional, Tuple
from jacinle.utils.tqdm import tqdm
from concepts.benchmark.common.random_env import RandomizedEnv
from concepts.benchmark.algorithm_env.graph import random_generate_graph, random_generate_graph_dnc, random_generate_special_graph
__all__ = ['GraphEnvBase', 'GraphPathEnv']
class GraphEnvBase(RandomizedEnv):
    """Graph Env Base.

    Stores the settings used to randomly generate a graph and exposes the most recently generated one through
    :attr:`graph`. Subclasses define the actual task and are expected to assign :attr:`action_space` /
    :attr:`observation_space`.
    """

    def __init__(self, nr_nodes: int, p: float = 0.5, directed: bool = False, gen_method: str = 'edge', np_random: Optional[np.random.RandomState] = None, seed: Optional[int] = None):
        """Initialize the environment.

        Args:
            nr_nodes: the number of nodes in the graph.
            p: parameter for random generation. (Default: 0.5)
                - (edge method): The probability that an edge doesn't exist in directed graph.
                - (dnc method): Control the range of the sample of out-degree.
                - other methods: Unused.
            directed: directed or undirected graph. Default: `False` (undirected)
            gen_method: use which method to randomly generate a graph.
                - 'edge': By sampling the existence of each edge.
                - 'dnc': Sample out-degree (:math:`m`) of each node, and link to nearest neighbors in the unit square.
                - 'list': generate a chain-like graph.
                - any other value is forwarded as-is to ``random_generate_special_graph``.
            np_random: random state, forwarded to :class:`RandomizedEnv`.
            seed: random seed, forwarded to :class:`RandomizedEnv`.
        """
        super().__init__(np_random, seed)
        self._nr_nodes = nr_nodes
        self._p = p
        self._directed = directed
        self._gen_method = gen_method
        self._graph = None  # populated by _gen_random_graph()

    @property
    def action_space(self):
        """The action space. Raises :class:`NotImplementedError` until a subclass assigns one."""
        space = getattr(self, '_action_space', None)
        if space is None:
            raise NotImplementedError
        return space

    @action_space.setter
    def action_space(self, space):
        # A setter is required so subclasses can write ``self.action_space = ...``; with a getter-only
        # property that assignment raises AttributeError.
        self._action_space = space

    @property
    def observation_space(self):
        """The observation space. Raises :class:`NotImplementedError` until a subclass assigns one."""
        space = getattr(self, '_observation_space', None)
        if space is None:
            raise NotImplementedError
        return space

    @observation_space.setter
    def observation_space(self, space):
        self._observation_space = space

    @property
    def graph(self):
        """The generated graph (``None`` until :meth:`_gen_random_graph` has been called)."""
        return self._graph

    def _gen_random_graph(self):
        """Generate a new random graph with the configured method and store it in :attr:`graph`."""
        n = self._nr_nodes
        p = self._p
        # NOTE(review): self.np_random is not forwarded to the generators below, so graph sampling may not
        # follow ``seed``; confirm whether the generators accept a random-state argument.
        if self._gen_method == 'edge':
            self._graph = random_generate_graph(n, p, self._directed)
        elif self._gen_method == 'dnc':
            self._graph = random_generate_graph_dnc(n, p, self._directed)
        else:
            # Every other method name (e.g. 'list') goes to the special-graph generator; ``p`` is not used there.
            self._graph = random_generate_special_graph(n, self._gen_method, self._directed)
class GraphPathEnv(GraphEnvBase):
    """Env for Finding a path from starting node to the destination.

    Each episode samples a target distance from ``dist_range``. It then regenerates random graphs until some ordered
    pair of nodes is exactly that far apart, and uses one such pair as ``(start, destination)``. The agent moves
    along edges. The episode ends with reward 1 when the destination is reached, or with reward 0 once the number
    of steps reaches the sampled distance.
    """

    def __init__(self, nr_nodes: int, dist_range: Tuple[int, int], p: float = 0.5, directed: bool = False, gen_method: str = 'edge', np_random: Optional[np.random.RandomState] = None, seed: Optional[int] = None):
        """Initialize the environment.

        Args:
            nr_nodes: the number of nodes in the graph.
            dist_range: the (inclusive) sampling range of distance between starting node and the destination.
            p: parameter for random generation. (Default: 0.5)
                - (edge method): The probability that an edge doesn't exist in directed graph.
                - (dnc method): Control the range of the sample of out-degree.
                - other methods: Unused.
            directed: directed or undirected graph. Default: `False` (undirected)
            gen_method: use which method to randomly generate a graph.
                - 'edge': By sampling the existence of each edge.
                - 'dnc': Sample out-degree (:math:`m`) of each node, and link to the nearest neighbors in the unit square.
                - 'list': generate a chain-like graph.
            np_random: random state. If None, a new random state will be created based on the seed.
            seed: random seed. If None, a randomly chosen seed will be used.
        """
        super().__init__(nr_nodes, p, directed, gen_method, np_random=np_random, seed=seed)
        self._dist_range = dist_range
        self._dist = None  # target start->destination distance of the current episode
        self._dist_matrix = None  # matrix from Graph.get_shortest(), indexed as [src, dst]
        self._task = None  # (start, destination) node pair
        self._current = None  # node the agent currently occupies
        self._steps = None  # number of steps taken in the current episode
        # NOTE(review): step() consumes a single node index, while this space describes a pair of indices;
        # confirm which one is intended.
        self.action_space = gym.spaces.MultiDiscrete([nr_nodes, nr_nodes])

    @property
    def action_space(self):
        """The action space, assigned in ``__init__`` and :meth:`reset_nr_nodes`."""
        return self._action_space

    @action_space.setter
    def action_space(self, space):
        # Overrides the base-class property so that ``self.action_space = ...`` is a valid assignment.
        self._action_space = space

    @classmethod
    def make(cls, n: int, dist_range: Tuple[int, int], p: float = 0.5, directed: bool = False, gen_method: str = 'edge', seed: Optional[int] = None) -> gym.Env:
        """Construct the environment with ``n`` nodes; the remaining arguments are forwarded to the constructor."""
        env = cls(n, dist_range, p=p, directed=directed, gen_method=gen_method, seed=seed)
        return env

    @property
    def dist(self) -> int:
        """The sampled start->destination distance of the current episode (also the step budget)."""
        return self._dist

    def reset_nr_nodes(self, nr_nodes: int):
        """Change the number of nodes. Graphs generated by subsequent :meth:`reset` calls use the new size."""
        self._nr_nodes = nr_nodes
        self.action_space = gym.spaces.MultiDiscrete([nr_nodes, nr_nodes])

    def reset(self, **kwargs):
        """Restart the environment.

        Samples a target distance, then regenerates graphs until one contains a node pair at exactly that
        distance (this loop does not stop otherwise). Extra keyword arguments are ignored.

        Returns:
            the initial state, as produced by :meth:`get_state`.
        """
        self._dist = self._gen_random_distance()
        self._task = None
        while True:
            self._gen_random_graph()
            self._dist_matrix = self._graph.get_shortest()
            self._task = self._gen_random_task()
            if self._task is not None:
                break
        self._current = self._task[0]
        self._steps = 0
        return self.get_state()

    def step(self, action):
        """Move to the target node from the current node if has_edge(current -> target).

        Moves to a node that is not adjacent to the current one are ignored, but still consume a step.

        Returns:
            a ``(state, reward, done, info)`` tuple; reward is 1 when the destination is reached, otherwise 0.
        """
        if self._current == self._task[1]:
            return self.get_state(), 1, True, {}
        if self._graph.has_edge(self._current, action):
            self._current = action
        if self._current == self._task[1]:
            return self.get_state(), 1, True, {}
        self._steps += 1
        if self._steps >= self.dist:
            return self.get_state(), 0, True, {}
        return self.get_state(), 0, False, {}

    def _gen_random_distance(self):
        """Sample the target distance uniformly from ``dist_range`` (inclusive), capped at ``nr_nodes - 1``."""
        lower, upper = self._dist_range
        # A simple path over nr_nodes nodes has at most nr_nodes - 1 edges.
        upper = min(upper, self._nr_nodes - 1)
        if upper < lower:
            raise ValueError('Empty distance range [{}, {}] for a graph with {} nodes.'.format(lower, upper, self._nr_nodes))
        return self.np_random.randint(upper - lower + 1) + lower

    def _gen_random_task(self):
        """Sample the starting node and the destination according to the distance.

        Returns:
            a ``(start, destination)`` pair whose distance equals ``self._dist``, or ``None`` if no such pair exists.
        """
        st, ed = np.where(self._dist_matrix == self._dist)
        if len(st) == 0:
            return None
        ind = self.np_random.randint(len(st))
        return st[ind], ed[ind]

    def get_state(self):
        """Return the observation, built from two channels stacked along the last axis.

        Channel 0 is the 2-D edge matrix from ``Graph.get_edges()``. Channel 1 has the same shape, with the row of
        the current node set to 1 and every other entry set to 0.
        """
        relation = self._graph.get_edges()
        current_state = np.zeros_like(relation)
        current_state[self._current, :] = 1
        return np.stack([relation, current_state], axis=-1)

    def oracle_policy(self, state):
        """Oracle policy: move to a random neighbor that is strictly closer to the destination.

        Args:
            state: an observation as returned by :meth:`get_state`; only the edge channel ``state[..., 0]`` is used.

        Returns:
            the index of the next node, or 0 if the agent is already at the destination.

        Raises:
            RuntimeError: if no neighbor of the current node is closer to the destination.
        """
        current = self._current
        target = self._task[1]
        if current == target:
            return 0
        is_neighbor = state[current, :, 0] == 1
        # Computed separately on purpose: ``&`` binds tighter than ``<`` in Python, so writing
        # ``is_neighbor & dist < d`` would compare ``(is_neighbor & dist)`` against ``d`` instead.
        is_closer = self._dist_matrix[:, target] < self._dist_matrix[current, target]
        possible_actions = is_neighbor & is_closer
        if not np.any(possible_actions):
            raise RuntimeError('No action found.')
        return self.np_random.choice(np.where(possible_actions)[0])

    def generate_data(self, nr_data_points: int):
        """Roll out the oracle policy for ``nr_data_points`` fresh episodes.

        Returns:
            a list of dicts with keys ``'states'`` (observations, including the initial one), ``'actions'``,
            ``'optimal_steps'`` (the sampled distance) and ``'actual_steps'`` (number of actions taken).
        """
        data = list()
        for _ in tqdm(range(nr_data_points)):
            obs = self.reset()
            states, actions = [obs], list()
            while True:
                action = self.oracle_policy(obs)
                if action is None:
                    raise RuntimeError('No action found.')
                obs, _, finished, _ = self.step(action)
                states.append(obs)
                actions.append(action)
                if finished:
                    break
            data.append({'states': states, 'actions': actions, 'optimal_steps': self._dist, 'actual_steps': len(actions)})
        return data