#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : vae1d.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 02/17/2022
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
import jactorch.nn as jacnn
from typing import Optional, List, Any

Tensor = torch.Tensor

class BaseVAE(nn.Module):
    def __init__(self) -> None:
        super().__init__()
    training: bool
    def encode(self, input: Tensor) -> List[Tensor]:
        raise NotImplementedError 
    def decode(self, input: Tensor) -> Any:
        raise NotImplementedError 
    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """Reparameterization trick to sample from N(mu, var) from N(0,1).
        Args:
            mu: Mean of the latent Gaussian [B x D]
            logvar: Standard deviation of the latent Gaussian [B x D]
        Returns:
            z: Samples from the latent Gaussian [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu 
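
    # Illustrative sketch of the trick in isolation (hypothetical shapes and values):
    #   >>> vae = VanillaVAE1d(input_dim=16, latent_dim=4, hidden_dims=[64, 32])
    #   >>> mu, logvar = torch.zeros(8, 4), torch.zeros(8, 4)
    #   >>> z = vae.reparameterize(mu, logvar)   # z.shape == (8, 4), z ~ N(0, I)
    # Gradients flow to mu and logvar because the randomness enters only through eps.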
    def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
        raise NotImplementedError
    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        mu, log_var = self.encode(input, **kwargs)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z, **kwargs), input, mu, log_var] 
    def loss_function(self, recons: Tensor, input: Tensor, mu: Tensor, log_var: Tensor, **kwargs) -> dict:
        kld_weight = 1  # Weight of the KL term relative to the reconstruction term.
        recons_loss = F.mse_loss(recons, input)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=-1))
        loss = recons_loss + kld_weight * kld_loss
        return {'loss': loss, 'loss/recon': recons_loss.item(), 'loss/kl': kld_loss.item()}
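
    # Note: the KL term above is the closed-form KL divergence between the encoder
    # distribution q(z|x) = N(mu, diag(exp(log_var))) and the standard normal prior:
    #   KL(q || p) = -0.5 * sum_d (1 + log_var_d - mu_d^2 - exp(log_var_d)),
    # averaged over the batch, so `loss` is an MSE reconstruction term plus a
    # weighted KL regularizer (the negative ELBO up to constants).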
 
class VanillaVAE1d(BaseVAE):
    def __init__(self, input_dim: int, latent_dim: int, hidden_dims: Optional[List[int]] = None, **kwargs) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dims = hidden_dims
        self.encoder = jacnn.MLPLayer(input_dim, latent_dim * 2, hidden_dims, activation=nn.LeakyReLU(), flatten=False, last_activation=False)
        self.decoder = jacnn.MLPLayer(latent_dim, input_dim, list(reversed(hidden_dims)) if hidden_dims is not None else None, activation=nn.LeakyReLU(), flatten=False, last_activation=False)
    training: bool
    input_dim: int
    """The dimension of the input data."""
    latent_dim: int
    """The dimension of the latent space."""
    hidden_dims: Optional[List[int]]
    """The hidden dimensions of the encoder and decoder. If None, the encoder and decoder are single-layer networks."""
    def encode(self, input: Tensor) -> List[Tensor]:
        result = self.encoder(input)
        mu, log_var = result.split(self.latent_dim, dim=-1)
        return [mu, log_var] 
    def decode(self, z: Tensor) -> Tensor:
        return self.decoder(z) 
    def sample(self, nr_samples: int, current_device: int, **kwargs) -> Tensor:
        z = torch.randn(nr_samples, self.latent_dim)
        z = z.to(current_device)
        samples = self.decode(z)
        return samples 
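
    # A hedged usage sketch (hypothetical dimensions, not part of the original API surface):
    #   >>> vae = VanillaVAE1d(input_dim=16, latent_dim=4, hidden_dims=[64, 32])
    #   >>> x = torch.randn(8, 16)
    #   >>> recon, x_in, mu, log_var = vae(x)
    #   >>> losses = vae.loss_function(recon, x_in, mu, log_var)
    #   >>> new_samples = vae.sample(5, 'cpu')   # decode 5 draws from the prior, shape [5, 16]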
 
class ConditionalVAE1d(BaseVAE):
    def __init__(self, input_dim: int, condition_dim: int, latent_dim: int, hidden_dims: Optional[List[int]] = None, **kwargs) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.condition_dim = condition_dim
        self.latent_dim = latent_dim
        self.hidden_dims = hidden_dims
        self.encoder = jacnn.MLPLayer(input_dim + condition_dim, latent_dim * 2, hidden_dims, activation=nn.LeakyReLU(), flatten=False, last_activation=False)
        self.decoder = jacnn.MLPLayer(latent_dim + condition_dim, input_dim, list(reversed(hidden_dims)) if hidden_dims is not None else None, activation=nn.LeakyReLU(), flatten=False, last_activation=False)
    training: bool
    input_dim: int
    """The dimension of the input data."""
    condition_dim: int
    """The dimension of the condition."""
    latent_dim: int
    """The dimension of the latent space."""
    hidden_dims: Optional[List[int]]
    """The hidden dimensions of the encoder and decoder. If None, the encoder and decoder are single-layer networks."""
    def encode(self, input: Tensor, label: Tensor) -> List[Tensor]:
        result = self.encoder(torch.cat([input, label], dim=-1))
        mu, log_var = result.split(self.latent_dim, dim=-1)
        return [mu, log_var] 
    def decode(self, z: Tensor, label: Tensor) -> Tensor:
        return self.decoder(torch.cat([z, label], dim=-1)) 
    def sample(self, label: Tensor, nr_samples: int, **kwargs) -> Tensor:
        z = torch.randn(nr_samples, self.latent_dim, device=label.device)
        label = label.unsqueeze(0).expand((nr_samples, label.size(0)))
        samples = self.decode(z, label)
        return samples
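

if __name__ == '__main__':
    # Minimal smoke test (a sketch with hypothetical dimensions, not part of the
    # original module): run a forward pass and the loss for both VAE variants.
    x = torch.randn(8, 16)   # a batch of 8 inputs of dimension 16
    c = torch.randn(8, 3)    # per-sample conditions of dimension 3

    vae = VanillaVAE1d(input_dim=16, latent_dim=4, hidden_dims=[64, 32])
    recon, x_in, mu, log_var = vae(x)
    print(vae.loss_function(recon, x_in, mu, log_var))

    cvae = ConditionalVAE1d(input_dim=16, condition_dim=3, latent_dim=4, hidden_dims=[64, 32])
    recon, x_in, mu, log_var = cvae(x, label=c)
    print(cvae.loss_function(recon, x_in, mu, log_var))
    print(cvae.sample(c[0], nr_samples=5).shape)  # expected: torch.Size([5, 16])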