Continuous bijections
- class torchflows.bijections.continuous.base.ContinuousBijection(event_shape: Size | Tuple[int, ...], f: ODEFunction, context_shape: Size | Tuple[int, ...] = None, end_time: float = 1.0, solver: str = 'euler', atol: float = 1e-05, rtol: float = 1e-05, **kwargs)
Base class for bijections of continuous normalizing flows.
Reference: Chen et al. “Neural Ordinary Differential Equations” (2019); https://arxiv.org/abs/1806.07366.
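The cited paper's instantaneous change of variables is the relation continuous normalizing flows rely on. For dynamics f, the state and its log-density evolve as below. As a result, the log determinant of the map from time t_0 to t_1 is the time integral of the Jacobian trace. Going by the constructor's description of end_time, t_0 = 0 and t_1 = end_time. This summary is taken from the reference, not from the torchflows source, and which time direction forward or inverse uses is not specified here.

```latex
\frac{d z(t)}{d t} = f(z(t), t), \qquad
\frac{\partial \log p(z(t))}{\partial t} = -\operatorname{tr}\!\left(\frac{\partial f}{\partial z(t)}\right)

\log p(z(t_1)) = \log p(z(t_0)) - \int_{t_0}^{t_1} \operatorname{tr}\!\left(\frac{\partial f}{\partial z(t)}\right) dt,
\qquad
\log\left|\det \frac{\partial z(t_1)}{\partial z(t_0)}\right| = \int_{t_0}^{t_1} \operatorname{tr}\!\left(\frac{\partial f}{\partial z(t)}\right) dt
```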
- __init__(event_shape: Size | Tuple[int, ...], f: ODEFunction, context_shape: Size | Tuple[int, ...] = None, end_time: float = 1.0, solver: str = 'euler', atol: float = 1e-05, rtol: float = 1e-05, **kwargs)
ContinuousBijection constructor.
- Parameters:
event_shape – shape of the event tensor.
f – ODE function defining the dynamics to be integrated.
context_shape – shape of the context tensor.
end_time – integrate f from time 0 to this time. Default: 1.
solver – name of the torchdiffeq solver method to use. Default: 'euler'.
atol – absolute tolerance for numerical integration. Default: 1e-5.
rtol – relative tolerance for numerical integration. Default: 1e-5.
kwargs – unused.
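The constructor parameters describe a standard ODE solve. f gives the dynamics, integration runs from time 0 to end_time, and the log determinant is accumulated alongside the state. The sketch below uses fixed-step Euler and rests on a few assumptions: euler_cnf is a hypothetical helper, the Jacobian trace is computed exactly, and n_steps is an explicit argument. How many steps ContinuousBijection actually takes is not documented here. atol and rtol are left out because they mainly control adaptive-step solvers, not fixed-step ones such as Euler.

```python
import numpy as np


def euler_cnf(f, jac_trace, z0, end_time=1.0, n_steps=1000):
    """Hypothetical helper: integrate dz/dt = f(z, t) from t=0 to t=end_time with explicit Euler.

    Alongside z, accumulate the log-determinant of the flow map via
    d(log det)/dt = tr(df/dz), evaluated exactly here by `jac_trace`.
    z0 has shape (batch, dim); the returned log_det has shape (batch,).
    """
    h = end_time / n_steps
    z = np.array(z0, dtype=float)
    log_det = np.zeros(z.shape[0])
    t = 0.0
    for _ in range(n_steps):
        # Both updates use the state at the start of the step (explicit Euler).
        dz = f(z, t)
        dlog = jac_trace(z, t)
        z = z + h * dz
        log_det = log_det + h * dlog
        t += h
    return z, log_det


# Linear dynamics f(z) = A z with diagonal A.
A = np.diag([0.5, -1.0])


def f(z, t):
    # Row-wise A z for a batch of row vectors.
    return z @ A.T


def jac_trace(z, t):
    # The Jacobian of A z is A everywhere, so its trace is constant.
    return np.full(z.shape[0], np.trace(A))


z0 = np.array([[1.0, 2.0], [-0.5, 0.3]])
zT, log_det = euler_cnf(f, jac_trace, z0, end_time=1.0, n_steps=1000)
```

For linear dynamics the exact answer is known: z(T) = exp(A T) z_0 and log det = tr(A) T. That makes the Euler result easy to check. Its state error shrinks roughly in proportion to the step size.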
- forward(x: Tensor, integration_times: Tensor = None, noise: Tensor = None, **kwargs) → Tuple[Tensor, Tensor]
Forward pass for the continuous bijection.
- Parameters:
x (torch.Tensor) – tensor with shape (*batch_shape, *event_shape).
integration_times (torch.Tensor) –
noise (torch.Tensor) –
kwargs – keyword arguments to be passed to self.inverse.
- Returns:
transformed tensor and log determinant of the transformation.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
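An ODE flow is inverted by integrating the same dynamics backward in time, and the two log determinants are negatives of each other. The sketch below illustrates this general property. It uses a hypothetical euler_integrate helper and made-up dynamics, and it does not reproduce how ContinuousBijection orders its integration times internally. With a fixed-step solver the round trip is only approximate, with an error that shrinks with the step size.

```python
import numpy as np


def euler_integrate(f, jac_trace, z0, t0, t1, n_steps=2000):
    """Hypothetical helper: fixed-step Euler from t0 to t1 (t1 < t0 integrates backward).

    Returns the final state and the accumulated log-determinant, i.e. the
    integral of tr(df/dz) along the path from t0 to t1.
    """
    h = (t1 - t0) / n_steps  # negative when integrating backward in time
    z = np.array(z0, dtype=float)
    log_det = np.zeros(z.shape[0])
    t = t0
    for _ in range(n_steps):
        log_det = log_det + h * jac_trace(z, t)
        z = z + h * f(z, t)
        t += h
    return z, log_det


def f(z, t):
    # Nonlinear, time-dependent, elementwise dynamics.
    return np.tanh(z) * (1.0 + t)


def jac_trace(z, t):
    # The Jacobian is diagonal with entries (1 - tanh(z)^2) * (1 + t); sum them per row.
    return np.sum((1.0 - np.tanh(z) ** 2) * (1.0 + t), axis=1)


x = np.array([[0.3, -1.2, 0.8]])
z, log_det_fwd = euler_integrate(f, jac_trace, x, 0.0, 1.0)      # time 0 -> 1
x_rec, log_det_inv = euler_integrate(f, jac_trace, z, 1.0, 0.0)  # time 1 -> 0
```

Here x_rec is close to x, and log_det_fwd + log_det_inv is close to zero. Both discrepancies come from the Euler discretization.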
- inverse(z: Tensor, integration_times: Tensor = None, **kwargs) → Tuple[Tensor, Tensor]
Inverse pass of the continuous bijection.
- Parameters:
z – tensor with shape (*batch_shape, *event_shape).
integration_times –
kwargs – keyword arguments passed to self.f.before_odeint, which is called before the torchdiffeq solver runs.
- Returns:
transformed tensor and log determinant of the transformation.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
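The parameter descriptions above mention two details: inputs have shape (*batch_shape, *event_shape), and keyword arguments are forwarded to self.f.before_odeint. The small sketch below illustrates both. CountingODEFunction and solve_sketch are hypothetical stand-ins, not torchflows' ODEFunction or its solver call. The sketch also assumes two things the reference does not state: the log determinant is returned per batch element with shape batch_shape, and before_odeint resets per-solve state. The (t, z) argument order follows torchdiffeq's func(t, y) convention.

```python
import math

import numpy as np


class CountingODEFunction:
    """Hypothetical ODE function with a before_odeint hook (not torchflows' ODEFunction)."""

    def __init__(self, rate):
        self.rate = rate
        self.n_evals = 0

    def before_odeint(self, **kwargs):
        # Reset per-solve statistics; kwargs could carry per-solve options.
        self.n_evals = 0

    def __call__(self, t, z):
        self.n_evals += 1
        return self.rate * z  # elementwise linear dynamics

    def trace_jacobian(self, t, z):
        # df/dz = rate * I, so tr(df/dz) = rate * dim for every batch element.
        return np.full(z.shape[0], self.rate * z.shape[1])


def solve_sketch(f, z, event_shape, end_time=1.0, n_steps=100, **kwargs):
    """Flatten (*batch_shape, *event_shape) to (batch, dim), Euler-integrate, restore shapes."""
    batch_shape = z.shape[: z.ndim - len(event_shape)]
    dim = math.prod(event_shape)
    z_flat = z.reshape(-1, dim)
    f.before_odeint(**kwargs)  # hook runs once, before the solve
    h = end_time / n_steps
    log_det = np.zeros(z_flat.shape[0])
    t = 0.0
    for _ in range(n_steps):
        log_det = log_det + h * f.trace_jacobian(t, z_flat)
        z_flat = z_flat + h * f(t, z_flat)
        t += h
    return z_flat.reshape(z.shape), log_det.reshape(batch_shape)


f = CountingODEFunction(rate=-0.2)
z = np.ones((4, 5, 2, 3))  # batch_shape (4, 5), event_shape (2, 3)
x, log_det = solve_sketch(f, z, event_shape=(2, 3), end_time=1.0, n_steps=100)
```

Because the hook resets the counter, f.n_evals reflects only the most recent solve (100 evaluations here). Each log_det entry approximates rate * dim * end_time = -1.2.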