Bijections
All normalizing flow (NF) transformations are bijections or compositions of bijections.
Base bijections
The following classes define the forward and inverse pass methods that all bijections inherit.
- class torchflows.bijections.base.Bijection(event_shape: Size | Tuple[int, ...], context_shape: Size | Tuple[int, ...] = None, **kwargs)
Base class for all bijections.
- __init__(event_shape: Size | Tuple[int, ...], context_shape: Size | Tuple[int, ...] = None, **kwargs)
Bijection constructor.
- Parameters:
event_shape – shape of the event tensor.
context_shape – shape of the context tensor.
kwargs – unused.
- forward(x: Tensor, context: Tensor = None) → Tuple[Tensor, Tensor]
Forward bijection map. Returns the output vector and the log Jacobian determinant of the forward transform.
- Parameters:
x (torch.Tensor) – input array with shape (*batch_shape, *event_shape).
context (torch.Tensor) – context array with shape (*batch_shape, *context_shape).
- Returns:
output array and log determinant. The output array has shape (*batch_shape, *event_shape); the log determinant has shape (*batch_shape,).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- inverse(z: Tensor, context: Tensor = None) → Tuple[Tensor, Tensor]
Inverse bijection map. Returns the output vector and the log Jacobian determinant of the inverse transform.
- Parameters:
z (torch.Tensor) – input array with shape (*batch_shape, *event_shape).
context (torch.Tensor) – context array with shape (*batch_shape, *context_shape).
- Returns:
output array and log determinant. The output array has shape (*batch_shape, *event_shape); the log determinant has shape (*batch_shape,).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
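The forward/inverse contract above (return the transformed tensor together with a log determinant of shape (*batch_shape,)) can be illustrated with a minimal elementwise affine map. This is a hypothetical, self-contained sketch: ElementwiseAffine and its log_scale/shift parameters are illustrative names, and the class derives from torch.nn.Module rather than torchflows.bijections.base.Bijection so that it runs without the library. It only mirrors the documented method signatures and return shapes.

```python
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn


class ElementwiseAffine(nn.Module):
    """Hypothetical sketch of the forward/inverse contract (not a torchflows class).

    forward: z = x * exp(log_scale) + shift, applied elementwise over event_shape.
    The context argument is accepted for signature compatibility but unused.
    """

    def __init__(self, event_shape: Tuple[int, ...]):
        super().__init__()
        self.event_shape = torch.Size(event_shape)
        self.log_scale = nn.Parameter(torch.zeros(self.event_shape))
        self.shift = nn.Parameter(torch.zeros(self.event_shape))

    def _batch_shape(self, t: torch.Tensor) -> torch.Size:
        # Everything before the trailing event dimensions is batch.
        return t.shape[: t.dim() - len(self.event_shape)]

    def forward(
        self, x: torch.Tensor, context: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        z = x * torch.exp(self.log_scale) + self.shift
        # Diagonal Jacobian: log|det| = sum of log scales, identical for each batch element.
        log_det = self.log_scale.sum().expand(self._batch_shape(x))
        return z, log_det

    def inverse(
        self, z: torch.Tensor, context: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = (z - self.shift) * torch.exp(-self.log_scale)
        # The inverse map's log determinant is the negation of the forward one.
        log_det = (-self.log_scale.sum()).expand(self._batch_shape(z))
        return x, log_det


# Usage: a batch of 5 events, each of shape (3,).
layer = ElementwiseAffine(event_shape=(3,))
with torch.no_grad():
    layer.log_scale.fill_(math.log(2.0))  # scale every coordinate by 2
    layer.shift.fill_(1.0)
x = torch.randn(5, 3)
z, log_det = layer.forward(x)          # z has shape (5, 3), log_det has shape (5,)
x_rec, inv_log_det = layer.inverse(z)  # x_rec recovers x, inv_log_det equals -log_det
```

Because the map is elementwise, its Jacobian is diagonal, so the log determinant reduces to a sum over the event dimensions and is then broadcast over the batch dimensions, matching the (*batch_shape,) shape documented for both passes.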
- class torchflows.bijections.base.BijectiveComposition(event_shape: Size | Tuple[int, ...], layers: List[Bijection], context_shape: Size | Tuple[int, ...] = None, **kwargs)
Composition of bijections. Inherits from Bijection.
- __init__(event_shape: Size | Tuple[int, ...], layers: List[Bijection], context_shape: Size | Tuple[int, ...] = None, **kwargs)
BijectiveComposition constructor.
- Parameters:
event_shape – shape of the event tensor.
layers (List[Bijection]) – bijection layers.
context_shape – shape of the context tensor.
kwargs – unused.
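The bookkeeping a composition of bijections has to do can be sketched as follows, assuming each layer follows the forward/inverse contract documented for Bijection: the forward pass applies the layers in list order and sums their log determinants, while the inverse pass applies each layer's inverse in reverse order. Shift, Scale and CompositionSketch are hypothetical names built on torch.nn.Module so the example is self-contained; this is an illustration of the idea, not the source of BijectiveComposition.

```python
import math
from typing import List, Tuple

import torch
import torch.nn as nn


class Shift(nn.Module):
    """Hypothetical layer: z = x + b. Volume-preserving, so log|det| = 0."""

    def __init__(self, b: float, event_shape: Tuple[int, ...]):
        super().__init__()
        self.b = b
        self.event_shape = torch.Size(event_shape)

    def _zeros(self, t: torch.Tensor) -> torch.Tensor:
        return t.new_zeros(t.shape[: t.dim() - len(self.event_shape)])

    def forward(self, x, context=None):
        return x + self.b, self._zeros(x)

    def inverse(self, z, context=None):
        return z - self.b, self._zeros(z)


class Scale(nn.Module):
    """Hypothetical layer: z = c * x with c > 0, so log|det| = n_event * log(c)."""

    def __init__(self, c: float, event_shape: Tuple[int, ...]):
        super().__init__()
        self.c = c
        self.event_shape = torch.Size(event_shape)

    def _log_det(self, t: torch.Tensor) -> torch.Tensor:
        batch_shape = t.shape[: t.dim() - len(self.event_shape)]
        return t.new_full(batch_shape, self.event_shape.numel() * math.log(self.c))

    def forward(self, x, context=None):
        return x * self.c, self._log_det(x)

    def inverse(self, z, context=None):
        return z / self.c, -self._log_det(z)


class CompositionSketch(nn.Module):
    """Hypothetical composition: forward runs layers in order, inverse in reverse."""

    def __init__(self, event_shape: Tuple[int, ...], layers: List[nn.Module]):
        super().__init__()
        self.event_shape = torch.Size(event_shape)
        self.layers = nn.ModuleList(layers)

    def _zeros(self, t: torch.Tensor) -> torch.Tensor:
        return t.new_zeros(t.shape[: t.dim() - len(self.event_shape)])

    def forward(self, x, context=None):
        log_det = self._zeros(x)
        for layer in self.layers:
            x, ld = layer.forward(x, context)
            log_det = log_det + ld  # log determinants of composed maps add up
        return x, log_det

    def inverse(self, z, context=None):
        log_det = self._zeros(z)
        for layer in reversed(list(self.layers)):  # undo the last layer first
            z, ld = layer.inverse(z, context)
            log_det = log_det + ld
        return z, log_det


flow = CompositionSketch((3,), [Shift(1.0, (3,)), Scale(2.0, (3,))])
x = torch.randn(4, 3)
z, log_det = flow.forward(x)          # z = 2 * (x + 1), log_det = 3 * log(2) per row
x_rec, inv_log_det = flow.inverse(z)  # x_rec recovers x, inv_log_det = -log_det
```

The Shift-then-Scale pair shows why the inverse must run in reverse order: undoing the shift before the scale would give (z - 1) / 2 = x + 0.5 instead of x.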
Bijection subclasses for different NF families
To make forward and inverse passes in NF layers efficient, we subclass the base bijections separately for each family of NF architectures. The pages below list the base classes for each family together with the classes already implemented.
Inverting a bijection
Each bijection can be inverted with the invert function.
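Conceptually, inverting a bijection swaps the roles of its two passes: the forward pass of the inverted object is the original inverse pass and vice versa, with the log determinants carried along unchanged. The sketch below illustrates that idea with a hypothetical Inverted wrapper and a hypothetical elementwise Exp bijection, both written on torch.nn.Module; it does not reproduce the signature or internals of the library's invert function.

```python
from typing import Optional

import torch
import torch.nn as nn


class Exp(nn.Module):
    """Hypothetical bijection z = exp(x), elementwise over a 1-D event."""

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
        # dz/dx = exp(x) elementwise, so log|det J| = sum of x over the event dim.
        return torch.exp(x), x.sum(dim=-1)

    def inverse(self, z: torch.Tensor, context: Optional[torch.Tensor] = None):
        x = torch.log(z)
        return x, -x.sum(dim=-1)


class Inverted(nn.Module):
    """Hypothetical wrapper: its forward/inverse are the wrapped bijection's inverse/forward."""

    def __init__(self, bijection: nn.Module):
        super().__init__()
        self.bijection = bijection

    def forward(self, x, context=None):
        return self.bijection.inverse(x, context)

    def inverse(self, z, context=None):
        return self.bijection.forward(z, context)


log_layer = Inverted(Exp())  # forward pass is now z = log(x)
x = torch.rand(4, 3) + 0.5   # strictly positive inputs, as log requires
z, log_det = log_layer.forward(x)          # log_det = -sum(log x) per row
x_rec, inv_log_det = log_layer.inverse(z)  # x_rec recovers x
```

Wrapping rather than copying keeps a single set of parameters, so training the inverted object also updates the original bijection.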