Choosing a base distribution

We may replace the default standard Gaussian base distribution with any torch distribution that is also a torch.nn.Module. Some such distributions are already implemented in torchflows.base_distributions. The example below uses a diagonal Gaussian base distribution with mean 3 and standard deviation 2 in every dimension.

import torch
from torchflows.flows import Flow
from torchflows.architectures import RealNVP
from torchflows.base_distributions.gaussian import DiagonalGaussian

torch.manual_seed(0)
event_shape = (10,)
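# Diagonal Gaussian with mean 3 and standard deviation 2 in each of the 10 dimensions
base_distribution = DiagonalGaussian(
    loc=torch.full(size=event_shape, fill_value=3.0),
    scale=torch.full(size=event_shape, fill_value=2.0),
)
flow = Flow(RealNVP(event_shape), base_distribution=base_distribution)

# Draw 10 samples from the flow; x_new has shape (10, *event_shape)
x_new = flow.sample((10,))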
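To make the role of the base distribution concrete, the sketch below shows the two operations a flow typically needs from a diagonal Gaussian base: drawing a sample and evaluating its log density. It is a pure-Python illustration, not the torchflows implementation; the helper names sample_diagonal_gaussian and log_prob_diagonal_gaussian are hypothetical. The log density is the sum of independent univariate normal log densities, one per dimension.

```python
import math
import random


def sample_diagonal_gaussian(loc, scale, rng):
    # Hypothetical helper: z_i = loc_i + scale_i * eps_i with eps_i ~ N(0, 1)
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(loc, scale)]


def log_prob_diagonal_gaussian(z, loc, scale):
    # Hypothetical helper: sum over dimensions of
    # -0.5 * ((z_i - loc_i) / scale_i)^2 - log(scale_i) - 0.5 * log(2 * pi)
    total = 0.0
    for zi, m, s in zip(z, loc, scale):
        total += -0.5 * ((zi - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2 * math.pi)
    return total


# Same parameters as the torchflows example: mean 3, standard deviation 2, 10 dimensions
event_size = 10
loc = [3.0] * event_size
scale = [2.0] * event_size

rng = random.Random(0)
z = sample_diagonal_gaussian(loc, scale, rng)
print(len(z))
# The density peaks at the mean, where only the normalizing terms remain
print(log_prob_diagonal_gaussian(loc, loc, scale))
```

Moving one standard deviation away from the mean in every dimension lowers the log density by 0.5 per dimension, which is a quick sanity check on the formula.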

Nontrivial event shapes

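When the event has more than one axis, the base distribution must work with flattened data: it is defined over vectors of length event_size, the product of the event dimensions, rather than over the original event shape. We show an example below.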

Note

The requirement to work with flattened data may change in the future.

import torch
from torchflows.flows import Flow
from torchflows.architectures import RealNVP
from torchflows.base_distributions.gaussian import DiagonalGaussian

torch.manual_seed(0)
event_shape = (2, 3, 5)
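# The base distribution is defined on flattened events of size 2 * 3 * 5 = 30
event_size = int(torch.prod(torch.as_tensor(event_shape)))
base_distribution = DiagonalGaussian(
    loc=torch.full(size=(event_size,), fill_value=3.0),
    scale=torch.full(size=(event_size,), fill_value=2.0),
)
# The flow itself still receives the unflattened event shape
flow = Flow(RealNVP(event_shape), base_distribution=base_distribution)

# Samples are returned in the original event shape: (10, 2, 3, 5)
x_new = flow.sample((10,))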
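The sketch below illustrates the bookkeeping behind flattening an event of shape (2, 3, 5) into a vector of length 30. It assumes row-major (C-order) flattening, which is what torch.flatten and torch.reshape use on contiguous tensors; whether torchflows flattens in exactly this order internally is an assumption here. The helpers event_size_of, flat_index, and unflatten_index are hypothetical and use only the standard library.

```python
import math


def event_size_of(event_shape):
    # Number of scalar entries in one event: the product of its dimensions
    return math.prod(event_shape)


def flat_index(multi_index, event_shape):
    # Row-major (C-order) position of a multi-index within the flattened event
    flat = 0
    for i, n in zip(multi_index, event_shape):
        if not 0 <= i < n:
            raise IndexError(f"index {i} out of range for axis of size {n}")
        flat = flat * n + i
    return flat


def unflatten_index(flat, event_shape):
    # Inverse of flat_index: peel off the last axis first with divmod
    multi = []
    for n in reversed(event_shape):
        flat, r = divmod(flat, n)
        multi.append(r)
    return tuple(reversed(multi))


event_shape = (2, 3, 5)
event_size = event_size_of(event_shape)
print(event_size)  # 30: the length of loc and scale in the example above
print(flat_index((1, 2, 4), event_shape))  # 29: the last entry of the flattened event
print(unflatten_index(29, event_shape))  # (1, 2, 4)
```

Because the mapping is a bijection between multi-indices and positions 0 to event_size - 1, a base distribution over flat vectors loses no information about the event's structure.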