Continuous bijections

class torchflows.bijections.continuous.base.ContinuousBijection(event_shape: Size | Tuple[int, ...], f: ODEFunction, context_shape: Size | Tuple[int, ...] = None, end_time: float = 1.0, solver: str = 'euler', atol: float = 1e-05, rtol: float = 1e-05, **kwargs)

Base class for bijections of continuous normalizing flows.

Reference: Chen et al. “Neural Ordinary Differential Equations” (2019); https://arxiv.org/abs/1806.07366.
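The referenced paper's instantaneous change of variables is the identity that continuous normalizing flows build on. For a state z(t) driven by f, it reads as follows, where T plays the role of end_time. The sign convention torchflows uses for its returned log determinant is not specified on this page.

```latex
\frac{d\mathbf{z}(t)}{dt} = f(\mathbf{z}(t), t),
\qquad
\frac{\partial \log p(\mathbf{z}(t))}{\partial t}
  = -\operatorname{tr}\!\left(\frac{\partial f}{\partial \mathbf{z}(t)}\right),
\qquad
\log p(\mathbf{z}(T)) = \log p(\mathbf{z}(0))
  - \int_0^T \operatorname{tr}\!\left(\frac{\partial f}{\partial \mathbf{z}(t)}\right) dt.
```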

__init__(event_shape: Size | Tuple[int, ...], f: ODEFunction, context_shape: Size | Tuple[int, ...] = None, end_time: float = 1.0, solver: str = 'euler', atol: float = 1e-05, rtol: float = 1e-05, **kwargs)

ContinuousBijection constructor.

Parameters:
  • event_shape – shape of the event tensor.

  • f – ODE function defining the dynamics to be integrated over time.

  • context_shape – shape of the context tensor.

  • end_time – integrate f from time 0 to this time. Default: 1.

  • solver – name of the ODE solver used for numerical integration. Default: 'euler'.

  • atol – absolute tolerance for numerical integration. Default: 1e-5.

  • rtol – relative tolerance for numerical integration. Default: 1e-5.

  • kwargs – unused.
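To make end_time and the default 'euler' solver concrete, here is a minimal pure-Python sketch of fixed-step Euler integration that also accumulates the trace integral from the change-of-variables formula. This is not the torchflows implementation: the helper euler_flow, its jac_trace argument and the linear example dynamics are hypothetical. A real continuous bijection operates on torch tensors, obtains the trace differently (for example via autograd), and hands integration to a torchdiffeq solver.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def euler_flow(
    f: Callable[[float, Vector], Vector],
    jac_trace: Callable[[float, Vector], float],
    x: Vector,
    end_time: float = 1.0,
    n_steps: int = 1000,
) -> Tuple[Vector, float]:
    """Fixed-step Euler integration of dx/dt = f(t, x) from t = 0 to end_time.

    Also accumulates the integral of tr(df/dx) along the trajectory, which for
    the exact flow equals log|det dz/dx| (Liouville's formula).
    """
    h = end_time / n_steps
    t = 0.0
    z = list(x)
    log_det = 0.0
    for _ in range(n_steps):
        dz = f(t, z)
        log_det += h * jac_trace(t, z)  # trace term uses the state at the start of the step
        z = [zi + h * dzi for zi, dzi in zip(z, dz)]
        t += h
    return z, log_det


# Hypothetical example dynamics: linear ODE dx_i/dt = a_i * x_i, so tr(df/dx) = sum(A).
A = [0.5, -1.0]


def linear_f(t: float, z: Vector) -> Vector:
    return [a * zi for a, zi in zip(A, z)]


def linear_trace(t: float, z: Vector) -> float:
    return sum(A)


z, log_det = euler_flow(linear_f, linear_trace, [1.0, 2.0], end_time=1.0)
exact_z = [math.exp(0.5), 2.0 * math.exp(-1.0)]  # closed form: x_i * exp(a_i * end_time)
print(z, exact_z, log_det)  # log_det is, up to rounding, sum(A) * end_time = -0.5
```

Because the trace is constant in this example, the accumulated log determinant is exact; the state itself carries Euler's discretization error, which shrinks roughly in proportion to the step size.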

forward(x: Tensor, integration_times: Tensor = None, noise: Tensor = None, **kwargs) → Tuple[Tensor, Tensor]

Forward pass for the continuous bijection.

Parameters:
  • x (torch.Tensor) – tensor with shape (*batch_shape, *event_shape).

  • integration_times (torch.Tensor) –

  • noise (torch.Tensor) –

  • kwargs – keyword arguments to be passed to self.inverse.

Returns:

transformed tensor and log determinant of the transformation.

Return type:

Tuple[torch.Tensor, torch.Tensor]

inverse(z: Tensor, integration_times: Tensor = None, **kwargs) → Tuple[Tensor, Tensor]

Inverse pass of the continuous bijection.

Parameters:
  • z (torch.Tensor) – tensor with shape (*batch_shape, *event_shape).

  • integration_times (torch.Tensor) –

  • kwargs – keyword arguments passed to self.f.before_odeint before the torchdiffeq solver is called.

Returns:

transformed tensor and log determinant of the transformation.

Return type:

Tuple[torch.Tensor, torch.Tensor]
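As a rough illustration of how forward and inverse relate, the sketch below assumes (this page does not say so) that the inverse integrates the same ODE backward in time, from end_time to 0. The round trip then recovers the input only up to discretization error, while the two accumulated trace integrals cancel. The helper euler_integrate and the linear dynamics are hypothetical; the sketch does not reproduce torchflows' handling of integration_times or its log-determinant sign convention.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def euler_integrate(
    f: Callable[[float, Vector], Vector],
    jac_trace: Callable[[float, Vector], float],
    x: Vector,
    t0: float,
    t1: float,
    n_steps: int = 10_000,
) -> Tuple[Vector, float]:
    """Fixed-step Euler from t0 to t1; passing t1 < t0 integrates backward in time."""
    h = (t1 - t0) / n_steps  # negative step size when integrating backward
    t = t0
    z = list(x)
    log_det = 0.0
    for _ in range(n_steps):
        dz = f(t, z)
        log_det += h * jac_trace(t, z)  # accumulate the trace integral along the path
        z = [zi + h * dzi for zi, dzi in zip(z, dz)]
        t += h
    return z, log_det


A = [0.5, -1.0]  # hypothetical linear dynamics dx_i/dt = a_i * x_i


def linear_f(t: float, z: Vector) -> Vector:
    return [a * zi for a, zi in zip(A, z)]


def linear_trace(t: float, z: Vector) -> float:
    return sum(A)


x = [1.0, 2.0]
end_time = 1.0
z, log_det_fwd = euler_integrate(linear_f, linear_trace, x, 0.0, end_time)      # "forward"
x_rec, log_det_inv = euler_integrate(linear_f, linear_trace, z, end_time, 0.0)  # "inverse"

# Explicit Euler steps are not exact inverses of each other, so a small error remains.
round_trip_error = max(abs(a - b) for a, b in zip(x_rec, x))
print(round_trip_error, log_det_fwd + log_det_inv)
```

With a fixed-step solver, the step count controls how closely the round trip recovers the input; with an adaptive solver, atol and rtol would play that role.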