Base distributions
Existing base distributions
- class torchflows.base_distributions.gaussian.DiagonalGaussian(loc: Tensor, scale: Tensor, trainable_loc: bool = False, trainable_scale: bool = False)
Diagonal Gaussian distribution. Extends torch.distributions.Distribution and torch.nn.Module.
- __init__(loc: Tensor, scale: Tensor, trainable_loc: bool = False, trainable_scale: bool = False)
DiagonalGaussian constructor.
- Parameters:
loc (torch.Tensor) – location vector with shape (event_size,).
scale (torch.Tensor) – scale vector with shape (event_size,).
trainable_loc (bool) – if True, make the location trainable.
trainable_scale (bool) – if True, make the scale trainable.
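For reference, the density that a diagonal Gaussian base distribution represents can be reproduced with stock torch.distributions classes. The sketch below uses Independent over Normal rather than torchflows itself. It assumes that DiagonalGaussian implements the standard diagonal-Gaussian density, shows the expected shapes of samples and log probabilities, and checks the closed-form log density at the mean.

```python
import math

import torch
from torch.distributions import Independent, Normal

# Reference diagonal Gaussian built from stock torch.distributions classes
# (not torchflows); loc and scale both have shape (event_size,).
loc = torch.tensor([0.0, 1.0, -2.0])
scale = torch.tensor([1.0, 0.5, 2.0])

# Independent(..., 1) treats the last dimension as the event dimension,
# so log_prob sums the per-coordinate log densities.
reference = Independent(Normal(loc, scale), 1)

samples = reference.sample(torch.Size([5]))  # shape (5, event_size)
log_probs = reference.log_prob(samples)       # shape (5,)

# Closed form at the mean: -0.5 * d * log(2*pi) - sum(log(scale)).
d = loc.shape[0]
expected_at_loc = -0.5 * d * math.log(2 * math.pi) - torch.log(scale).sum()
```

Wrapping Normal in Independent(..., 1) is what turns per-coordinate log densities into one log density per sample, matching the (event_size,) event shape above.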
- class torchflows.base_distributions.gaussian.DenseGaussian(loc: Tensor, cov: Tensor, trainable_loc: bool = False)
Dense Gaussian distribution. Extends torch.distributions.Distribution and torch.nn.Module.
- __init__(loc: Tensor, cov: Tensor, trainable_loc: bool = False)
DenseGaussian constructor.
- Parameters:
loc (torch.Tensor) – location vector with shape (event_size,).
cov (torch.Tensor) – covariance matrix with shape (event_size, event_size).
trainable_loc (bool) – if True, make the location trainable.
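A dense Gaussian differs from the diagonal case only in its covariance. The sketch below uses torch.distributions.MultivariateNormal as a stand-in (not the torchflows class) to show one way of building a valid cov argument and the closed-form log density at the mean. The construction A @ A.T + 0.1 * I is our own choice for illustration; any symmetric positive definite matrix works.

```python
import math

import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
event_size = 3
loc = torch.zeros(event_size)

# A dense covariance must be symmetric positive definite; A @ A.T plus a
# small diagonal jitter is one common way to build such a matrix.
a = torch.randn(event_size, event_size)
cov = a @ a.T + 0.1 * torch.eye(event_size)
cov = 0.5 * (cov + cov.T)  # enforce exact symmetry

# Reference dense Gaussian from stock PyTorch (not torchflows).
reference = MultivariateNormal(loc, covariance_matrix=cov)
samples = reference.sample(torch.Size([4]))  # shape (4, event_size)

# Closed form at the mean: -0.5 * (d * log(2*pi) + log det(cov)).
expected_at_loc = -0.5 * (event_size * math.log(2 * math.pi) + torch.logdet(cov))
```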
- class torchflows.base_distributions.mixture.DiagonalGaussianMixture(locs: Tensor, scales: Tensor, weights: Tensor = None, trainable_locs: bool = False, trainable_scales: bool = False)
Mixture distribution of diagonal Gaussians. Extends Mixture.
- __init__(locs: Tensor, scales: Tensor, weights: Tensor = None, trainable_locs: bool = False, trainable_scales: bool = False)
DiagonalGaussianMixture constructor.
- Parameters:
locs (torch.Tensor) – tensor of locations with shape (n_components, event_size).
scales (torch.Tensor) – tensor of scales with shape (n_components, event_size).
weights (torch.Tensor) – tensor of weights with shape (n_components,).
trainable_locs (bool) – if True, make locations trainable.
trainable_scales (bool) – if True, make scales trainable.
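A diagonal Gaussian mixture with the shapes listed above can be mimicked with torch.distributions.MixtureSameFamily. The sketch below is a reference computation with stock PyTorch classes, not the torchflows implementation. It checks that the mixture log density equals the log-sum-exp of the weighted component log densities.

```python
import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

torch.manual_seed(0)
n_components, event_size = 2, 3
locs = torch.tensor([[-2.0, 0.0, 1.0], [2.0, 1.0, -1.0]])  # (n_components, event_size)
scales = torch.ones(n_components, event_size)               # (n_components, event_size)
weights = torch.tensor([0.3, 0.7])                          # (n_components,)

# Each component is a diagonal Gaussian; MixtureSameFamily mixes them.
components = Independent(Normal(locs, scales), 1)  # batch (n_components,), event (event_size,)
mixture = MixtureSameFamily(Categorical(probs=weights), components)

x = mixture.sample(torch.Size([6]))  # shape (6, event_size)
log_p = mixture.log_prob(x)          # shape (6,)

# Mixture density: log sum_k w_k * N(x; loc_k, diag(scale_k ** 2)).
per_component = components.log_prob(x.unsqueeze(1))  # shape (6, n_components)
manual = torch.logsumexp(per_component + torch.log(weights), dim=1)
```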
- class torchflows.base_distributions.mixture.DenseGaussianMixture(locs: Tensor, covs: Tensor, weights: Tensor = None, trainable_locs: bool = False)
Mixture distribution of dense Gaussians. Extends Mixture.
- __init__(locs: Tensor, covs: Tensor, weights: Tensor = None, trainable_locs: bool = False)
DenseGaussianMixture constructor.
- Parameters:
locs (torch.Tensor) – tensor of locations with shape (n_components, event_size).
covs (torch.Tensor) – tensor of covariance matrices with shape (n_components, event_size, event_size).
weights (torch.Tensor) – tensor of weights with shape (n_components,).
trainable_locs (bool) – if True, make locations trainable.
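The dense case follows the same pattern with MultivariateNormal components. In this sketch (stock PyTorch, not torchflows) we pass uniform weights explicitly instead of relying on whatever default weights=None selects, and we build each covariance matrix to be symmetric positive definite.

```python
import torch
from torch.distributions import Categorical, MixtureSameFamily, MultivariateNormal

torch.manual_seed(0)
n_components, event_size = 3, 2
locs = torch.randn(n_components, event_size)  # (n_components, event_size)

# One symmetric positive definite matrix per component:
# shape (n_components, event_size, event_size).
a = torch.randn(n_components, event_size, event_size)
covs = a @ a.transpose(-1, -2) + 0.1 * torch.eye(event_size)
covs = 0.5 * (covs + covs.transpose(-1, -2))  # enforce exact symmetry

# Uniform weights passed explicitly (an assumption for this sketch).
weights = torch.full((n_components,), 1.0 / n_components)

components = MultivariateNormal(locs, covariance_matrix=covs)  # batch shape (n_components,)
mixture = MixtureSameFamily(Categorical(probs=weights), components)

x = mixture.sample(torch.Size([8]))  # shape (8, event_size)
log_p = mixture.log_prob(x)          # shape (8,)
```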
Creating new base distributions
To create a new base distribution, we must create a subclass of both torch.distributions.Distribution and torch.nn.Module.
The class should implement the sample method (drawing samples) and the log_prob method (computing log probabilities).
We give an example of a diagonal Gaussian base distribution:
```python
import math

import torch
import torch.distributions
import torch.nn as nn


def sum_except_batch(x: torch.Tensor, event_shape: torch.Size) -> torch.Tensor:
    # Local definition of the helper used below: sum over the trailing
    # event dimensions and keep the leading batch dimensions
    batch_shape = x.shape[:x.dim() - len(event_shape)]
    return x.reshape(*batch_shape, -1).sum(dim=-1)


class DiagonalGaussian(torch.distributions.Distribution, nn.Module):
    def __init__(self, loc: torch.Tensor, scale: torch.Tensor):
        super().__init__(event_shape=loc.shape, validate_args=False)
        self.log_2_pi = math.log(2 * math.pi)
        # Buffers move with .to(device) but are not updated by optimizers
        self.register_buffer('loc', loc)
        self.register_buffer('log_scale', torch.log(scale))

    @property
    def scale(self):
        return torch.exp(self.log_scale)

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        noise = torch.randn(size=(*sample_shape, *self.event_shape)).to(self.loc)
        # Unsqueeze loc and scale to match batch shape; a tuple index is used
        # because an empty list index would select zero elements
        sample_shape_mask = (None,) * len(sample_shape)
        return self.loc[sample_shape_mask] + noise * self.scale[sample_shape_mask]

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        if len(value.shape) <= len(self.event_shape):
            raise ValueError("Incorrect input shape")
        # Unsqueeze loc and scale to match batch shape
        sample_shape_mask = (None,) * (len(value.shape) - len(self.event_shape))
        loc = self.loc[sample_shape_mask]
        scale = self.scale[sample_shape_mask]
        log_scale = self.log_scale[sample_shape_mask]
        # Compute log probability
        elementwise_log_prob = -(0.5 * ((value - loc) / scale) ** 2 + 0.5 * self.log_2_pi + log_scale)
        return sum_except_batch(elementwise_log_prob, self.event_shape)
```
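The example registers loc and log_scale as buffers, so they are saved with the module and follow .to(device), but an optimizer never updates them. The torchflows classes documented above expose trainable_loc and trainable_scale flags. One plausible way to support such flags in a custom class (an assumption, not taken from the torchflows source) is to register each tensor as an nn.Parameter when it should be trainable. The class name TrainableDiagonalGaussian is hypothetical, and this sketch assumes a one-dimensional event shape.

```python
import math

import torch
import torch.distributions
import torch.nn as nn


class TrainableDiagonalGaussian(torch.distributions.Distribution, nn.Module):
    # Hypothetical variant: tensors become nn.Parameter when trainable,
    # and plain buffers otherwise. Assumes loc and scale have shape (event_size,).
    def __init__(self, loc: torch.Tensor, scale: torch.Tensor,
                 trainable_loc: bool = False, trainable_scale: bool = False):
        super().__init__(event_shape=loc.shape, validate_args=False)
        self.log_2_pi = math.log(2 * math.pi)
        if trainable_loc:
            self.loc = nn.Parameter(loc.clone())
        else:
            self.register_buffer('loc', loc.clone())
        log_scale = torch.log(scale)
        if trainable_scale:
            # Optimizing log_scale keeps the scale positive without constraints
            self.log_scale = nn.Parameter(log_scale)
        else:
            self.register_buffer('log_scale', log_scale)

    @property
    def scale(self) -> torch.Tensor:
        return torch.exp(self.log_scale)

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        noise = torch.randn((*sample_shape, *self.event_shape),
                            device=self.loc.device, dtype=self.loc.dtype)
        # loc and scale broadcast over the leading sample dimensions
        return self.loc + noise * self.scale

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        z = (value - self.loc) / self.scale
        elementwise = -(0.5 * z ** 2 + 0.5 * self.log_2_pi + self.log_scale)
        # Sum over the single event dimension
        return elementwise.sum(dim=-1)
```

With trainable_loc=True, loc appears in named_parameters() and receives gradients from log_prob, while log_scale stays a buffer.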