Source code for concepts.nn.vae1d

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : vae1d.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 02/17/2022
#
# This file is part of HACL-PyTorch.
# Distributed under terms of the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F
import jactorch.nn as jacnn

from typing import Optional, List, Any, TypeVar, Tuple

# Annotation-only stand-in for tensor arguments and return values in the signatures below.
Tensor = TypeVar('torch.Tensor')


class BaseVAE(nn.Module):
    """Abstract base class for variational auto-encoders.

    Subclasses provide :meth:`encode`, :meth:`decode` and :meth:`sample`. This class supplies the
    reparameterization step, the forward pass, and the training loss built on top of them.
    """

    training: bool

    def __init__(self) -> None:
        super().__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        """Encode an input into the parameters of the latent Gaussian.

        Must be overridden. :meth:`forward` unpacks the result as ``mu, log_var``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        """Decode a latent sample back into the input space.

        Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """Reparameterization trick to sample from N(mu, var) using noise drawn from N(0, 1).

        Args:
            mu: Mean of the latent Gaussian [B x D]
            logvar: Log-variance of the latent Gaussian [B x D]

        Returns:
            z: Samples from the latent Gaussian [B x D]
        """
        # exp(0.5 * log(var)) == sqrt(var), i.e. the standard deviation.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
        """Draw samples from the model.

        Must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """Encode ``input``, sample a latent code, and decode it.

        Args:
            input: The data to reconstruct.
            **kwargs: Forwarded unchanged to both :meth:`encode` and :meth:`decode`.

        Returns:
            The list ``[reconstruction, input, mu, log_var]``.
        """
        mu, log_var = self.encode(input, **kwargs)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z, **kwargs), input, mu, log_var]

    def loss_function(self, recons, input, mu, log_var, **kwargs) -> dict:
        """Compute the VAE loss: MSE reconstruction loss plus a weighted KL term.

        Args:
            recons: The reconstruction produced by the decoder.
            input: The reconstruction target.
            mu: Mean of the latent Gaussian.
            log_var: Log-variance of the latent Gaussian.
            **kwargs: May contain ``kld_weight``, the multiplier applied to the KL term in the
                total loss (defaults to 1). Other keys are ignored.

        Returns:
            A dict with ``'loss'`` (the total loss, a tensor), ``'loss/recon'`` (the reconstruction
            loss as a Python float) and ``'loss/kl'`` (a Python float).
        """
        kld_weight = kwargs.get('kld_weight', 1)

        recons_loss = F.mse_loss(recons, input)
        # KL term summed over the last (latent) dimension, then averaged over the remaining ones.
        kld_per_sample = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=-1)
        kld_loss = torch.mean(kld_per_sample)
        loss = recons_loss + kld_weight * kld_loss

        # NOTE: 'loss/kl' is reported with its sign flipped relative to the term added to 'loss'.
        # This is kept as-is so that existing logs and monitors stay comparable.
        return {'loss': loss, 'loss/recon': recons_loss.item(), 'loss/kl': -kld_loss.item()}
[docs]class VanillaVAE1d(BaseVAE):
[docs] def __init__(self, input_dim: int, latent_dim: int, hidden_dims: Optional[List[int]] = None, **kwargs) -> None: super().__init__() self.input_dim = input_dim self.latent_dim = latent_dim self.hidden_dims = hidden_dims self.encoder = jacnn.MLPLayer(input_dim, latent_dim * 2, hidden_dims, activation=nn.LeakyReLU(), flatten=False, last_activation=False) self.decoder = jacnn.MLPLayer(latent_dim, input_dim, list(reversed(hidden_dims)), activation=nn.LeakyReLU(), flatten=False, last_activation=False)
training: bool input_dim: int """The dimension of the input data.""" latent_dim: int """The dimension of the latent space.""" hidden_dims: Optional[List[int]] """The hidden dimensions of the encoder and decoder. If None, the encoder and decoder are single-layer networks."""
[docs] def encode(self, input: Tensor) -> List[Tensor]: result = self.encoder(input) mu, log_var = result.split(self.latent_dim, dim=-1) return [mu, log_var]
[docs] def decode(self, z: Tensor) -> Tensor: return self.decoder(z)
[docs] def sample(self, nr_samples:int, current_device: int, **kwargs) -> Tensor: z = torch.randn(nr_samples, self.latent_dim) z = z.to(current_device) samples = self.decode(z) return samples
[docs]class ConditionalVAE1d(BaseVAE):
[docs] def __init__(self, input_dim: int, condition_dim: int, latent_dim: int, hidden_dims: Optional[List[int]] = None, **kwargs) -> None: super().__init__() self.input_dim = input_dim self.condition_dim = condition_dim self.latent_dim = latent_dim self.hidden_dims = hidden_dims self.encoder = jacnn.MLPLayer(input_dim + condition_dim, latent_dim * 2, hidden_dims, activation=nn.LeakyReLU(), flatten=False, last_activation=False) self.decoder = jacnn.MLPLayer(latent_dim + condition_dim, input_dim, list(reversed(hidden_dims)), activation=nn.LeakyReLU(), flatten=False, last_activation=False)
training: bool input_dim: int """The dimension of the input data.""" condition_dim: int """The dimension of the condition.""" latent_dim: int """The dimension of the latent space.""" hidden_dims: Optional[List[int]] """The hidden dimensions of the encoder and decoder. If None, the encoder and decoder are single-layer networks."""
[docs] def encode(self, input: Tensor, label: Tensor) -> List[Tensor]: result = self.encoder(torch.cat([input, label], dim=-1)) mu, log_var = result.split(self.latent_dim, dim=-1) return [mu, log_var]
[docs] def decode(self, z: Tensor, label: Tensor) -> Tensor: return self.decoder(torch.cat([z, label], dim=-1))
[docs] def sample(self, label: Tensor, nr_samples: int, **kwargs) -> Tensor: z = torch.randn(nr_samples, self.latent_dim) label = label.unsqueeze(0).expand((nr_samples, label.size(0))) samples = self.decode(z, label) return samples