Flow wrappers

The Flow object contains a base distribution and a bijection.

class torchflows.flows.BaseFlow(event_shape, base_distribution: Distribution | str = 'standard_normal')

Base normalizing flow class.

fit(x_train: Tensor, n_epochs: int = 500, lr: float = 0.05, batch_size: int | str = 1024, shuffle: bool = True, show_progress: bool = False, w_train: Tensor = None, context_train: Tensor = None, x_val: Tensor = None, w_val: Tensor = None, context_val: Tensor = None, keep_best_weights: bool = True, early_stopping: bool = False, early_stopping_threshold: int = 50, max_batch_size_mb: int = None, time_limit_seconds: float | int = None)

Fit the normalizing flow to a dataset.

Fitting the flow means finding the bijection parameters that maximize the log-likelihood of the training data. The parameters are updated iteratively for a specified number of epochs. If context data is provided, the normalizing flow learns the distribution of the data conditional on the context.
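As an illustration of the fitting objective, the following is a minimal sketch in plain Python, not torchflows code. It fits a one-dimensional affine flow x = mu + exp(log_sigma) * z, with z drawn from a standard normal base, by full-batch gradient ascent on the mean log-likelihood. The helpers affine_log_prob and fit_affine are hypothetical. The defaults mirror n_epochs and lr below, but fit itself works with mini-batches and whatever bijection the flow was built with.

```python
import math
import random
import statistics

def affine_log_prob(x: float, mu: float, log_sigma: float) -> float:
    """Log density of x under the flow x = mu + exp(log_sigma) * z, z ~ N(0, 1)."""
    # Map x back to the base space.
    z = (x - mu) / math.exp(log_sigma)
    # Change of variables: log N(z; 0, 1) + log|dz/dx|, and log|dz/dx| = -log_sigma.
    return -0.5 * z * z - 0.5 * math.log(2 * math.pi) - log_sigma

def fit_affine(data, n_epochs: int = 500, lr: float = 0.05):
    """Maximize the mean log-likelihood of data by full-batch gradient ascent."""
    mu, log_sigma = 0.0, 0.0
    n = len(data)
    for _ in range(n_epochs):
        sigma = math.exp(log_sigma)
        # Analytic gradients of the mean of affine_log_prob over the data.
        grad_mu = sum(x - mu for x in data) / (n * sigma ** 2)
        grad_log_sigma = sum(((x - mu) / sigma) ** 2 for x in data) / n - 1.0
        # Ascending the log-likelihood is the same as descending the loss -log p(x).
        mu += lr * grad_mu
        log_sigma += lr * grad_log_sigma
    return mu, log_sigma

random.seed(0)
data = [random.gauss(2.0, 1.5) for _ in range(1000)]
mu, log_sigma = fit_affine(data)
mean_log_likelihood = sum(affine_log_prob(x, mu, log_sigma) for x in data) / len(data)

# For a Gaussian the maximum-likelihood solution is known in closed form, which lets us check the fit.
mle_mu, mle_sigma = statistics.fmean(data), statistics.pstdev(data)
```

Because this toy flow is a Gaussian, the fitted mu and exp(log_sigma) should land close to the sample mean and the population standard deviation.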

Parameters:
  • x_train – training data with shape (n_training_data, *event_shape).

  • n_epochs – number of training epochs (full passes over the training data).

  • lr – learning rate. In general, lower learning rates are recommended for bijections with many parameters.

  • batch_size – in each epoch, split training data into batches of this size and perform a parameter update for each batch.

  • shuffle – if True, shuffle the training data. This prevents biased parameter updates when nearby training samples are similar, e.g. when the data is sorted.

  • show_progress – show a progress bar with the current batch loss.

  • w_train – training data weights with shape (n_training_data,).

  • context_train – training data context tensor with shape (n_training_data, *context_shape).

  • x_val – validation data with shape (n_validation_data, *event_shape).

  • w_val – validation data weights with shape (n_validation_data,).

  • context_val – validation data context tensor with shape (n_validation_data, *context_shape).

  • keep_best_weights – if True and validation data is provided, keep the bijection weights with the highest probability of validation data.

  • early_stopping – if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs.

  • early_stopping_threshold – if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs.

  • max_batch_size_mb (int) – maximum batch size in megabytes.

  • time_limit_seconds (Union[float, int]) – maximum allowed training time in seconds.
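The interaction of early_stopping, early_stopping_threshold and keep_best_weights can be sketched as bookkeeping over a sequence of per-epoch validation losses. This is a simplification under stated assumptions, not the torchflows implementation. simulate_fit_schedule is a hypothetical helper, the losses are made-up illustration values, and "no improvement" is taken to mean a loss that is not strictly lower than the best one so far.

```python
import math
from typing import List, Optional, Tuple

def simulate_fit_schedule(
    val_losses: List[float],
    early_stopping: bool = False,
    early_stopping_threshold: int = 50,
) -> Tuple[int, Optional[int]]:
    """Return (last epoch that ran, epoch whose weights keep_best_weights would keep)."""
    best_loss = math.inf
    best_epoch = None
    epochs_without_improvement = 0
    last_epoch = -1
    for epoch, loss in enumerate(val_losses):
        last_epoch = epoch
        if loss < best_loss:
            # New best validation loss: a real training loop would snapshot the weights here.
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if early_stopping and epochs_without_improvement >= early_stopping_threshold:
            break
    return last_epoch, best_epoch

losses = [5.0, 4.0, 3.0, 3.5, 3.2, 2.9, 3.0, 3.1, 3.0, 1.0]
early = simulate_fit_schedule(losses, early_stopping=True, early_stopping_threshold=3)  # (8, 5)
full = simulate_fit_schedule(losses)  # (9, 9)
```

With a threshold of 3, training stops at epoch 8 and never sees the final improvement at epoch 9, which is the usual trade-off of a small threshold.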

regularization()

Compute the regularization term used in training.

variational_fit(target_log_prob: callable, n_epochs: int = 500, lr: float = 0.05, n_samples: int = 1, early_stopping: bool = False, early_stopping_threshold: int = 50, keep_best_weights: bool = True, show_progress: bool = False, check_for_divergences: bool = False, time_limit_seconds: float | int = None)

Train the normalizing flow to approximate a target distribution given its (unnormalized) log probability density.

Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset. Refer to Rezende, Mohamed: “Variational Inference with Normalizing Flows” (2015) for more details (https://arxiv.org/abs/1505.05770, loss definition in Equation 15, training pseudocode for conditional flows in Algorithm 1).

Parameters:
  • target_log_prob (callable) – function that computes the unnormalized target log density for a batch of points. It receives an input batch with shape (*batch_shape, *event_shape) and returns a tensor with shape (*batch_shape).

  • n_epochs (int) – number of training epochs.

  • lr (float) – learning rate for the AdamW optimizer.

  • n_samples (int) – number of samples used to estimate the variational loss in each training step.

  • show_progress (bool) – if True, show a progress bar during training.
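For intuition about the variational loss, here is a minimal plain-Python sketch. It assumes a one-dimensional affine flow q and estimates E_q[log q(x) - log p̃(x)] from n_samples draws, where log p̃ is the unnormalized target log density. The expectation equals KL(q || p) - log Z, with Z the unknown normalizing constant of the target, so minimizing it minimizes the reverse KL divergence. The function names are hypothetical, and the sketch only evaluates the loss; it does not perform the AdamW updates.

```python
import math
import random
from typing import Callable, Tuple

def affine_sample_with_log_prob(mu: float, log_sigma: float, rng: random.Random) -> Tuple[float, float]:
    """Draw x = mu + exp(log_sigma) * z with z ~ N(0, 1) and return (x, log q(x))."""
    z = rng.gauss(0.0, 1.0)
    x = mu + math.exp(log_sigma) * z
    log_q = -0.5 * z * z - 0.5 * math.log(2 * math.pi) - log_sigma
    return x, log_q

def variational_loss(
    mu: float,
    log_sigma: float,
    target_log_prob: Callable[[float], float],
    n_samples: int,
    rng: random.Random,
) -> float:
    """Monte Carlo estimate of E_q[log q(x) - target_log_prob(x)] from n_samples draws."""
    total = 0.0
    for _ in range(n_samples):
        x, log_q = affine_sample_with_log_prob(mu, log_sigma, rng)
        total += log_q - target_log_prob(x)
    return total / n_samples

# Unnormalized log density of N(2, 1): the constant -0.5 * log(2 * pi) is deliberately dropped.
def target(x: float) -> float:
    return -0.5 * (x - 2.0) ** 2

rng = random.Random(0)
loss_at_target = variational_loss(2.0, 0.0, target, n_samples=1, rng=rng)
loss_far_away = variational_loss(0.0, 0.0, target, n_samples=10_000, rng=rng)
```

At mu = 2, log_sigma = 0 the flow equals the target and every sample contributes exactly -0.5 * log(2 * pi) = -log Z, so the estimate has zero variance there. Away from the optimum, the estimate approaches KL(q || p) - log Z, which is 2 - 0.5 * log(2 * pi) for the second call.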

class torchflows.flows.Flow(bijection: Bijection, **kwargs)

Normalizing flow class. Inherits from BaseFlow.

This class represents a bijective transformation of a base distribution, which is a standard Gaussian by default. A normalizing flow is itself a distribution: we can sample from it or use it to compute the density of inputs.

__init__(bijection: Bijection, **kwargs)

Flow constructor.

Parameters:
  • bijection (Bijection) – transformation component of the normalizing flow.

  • kwargs – keyword arguments passed to BaseFlow.

forward_with_log_prob(x: Tensor, context: Tensor = None)

Transform the input x to the space of the base distribution.

Parameters:
  • x (torch.Tensor) – input tensor.

  • context (torch.Tensor) – context tensor upon which the transformation is conditioned.

Returns:

transformed tensor and the logarithm of the absolute value of the Jacobian determinant of the transformation.

Return type:

Tuple[torch.Tensor, torch.Tensor]
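To show what the two returned values mean, the sketch below implements a forward pass with log-determinant for a hypothetical elementwise affine bijection on a single unbatched event, using plain Python lists instead of tensors. The parameter names shift and log_scale are assumptions for illustration. For a batch of inputs, the log-determinant would have one entry per batch element.

```python
import math
from typing import List, Tuple

def affine_forward_with_log_prob(
    x: List[float], shift: List[float], log_scale: List[float]
) -> Tuple[List[float], float]:
    """Map x to base space with z_i = (x_i - shift_i) * exp(-log_scale_i)."""
    z = [(xi - b) * math.exp(-s) for xi, b, s in zip(x, shift, log_scale)]
    # dz/dx is diagonal with entries exp(-log_scale_i), so log|det| is the sum of their logs.
    log_det = -sum(log_scale)
    return z, log_det

# z is approximately [0.0, 1.0] and log_det equals -log(2).
z, log_det = affine_forward_with_log_prob([1.0, 3.0], shift=[1.0, 1.0], log_scale=[0.0, math.log(2.0)])
```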

log_prob(x: Tensor, context: Tensor = None) → Tensor

Compute the logarithm of the probability density of input x according to the normalizing flow.

Parameters:
  • x (torch.Tensor) – input tensor.

  • context (torch.Tensor) – context tensor.

Returns:

tensor of log probabilities.

Return type:

torch.Tensor
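Assuming the default standard normal base, log_prob follows from the change-of-variables formula log p(x) = log N(z; 0, I) + log|det ∂z/∂x|, where z is the output of the forward pass. A minimal plain-Python sketch for a hypothetical elementwise affine bijection, checked against the closed-form Gaussian density from the standard library:

```python
import math
from statistics import NormalDist
from typing import List

def standard_normal_log_prob(z: List[float]) -> float:
    """Log density of a standard normal with independent dimensions."""
    return sum(-0.5 * zi * zi - 0.5 * math.log(2 * math.pi) for zi in z)

def affine_flow_log_prob(x: List[float], shift: List[float], log_scale: List[float]) -> float:
    # Forward pass to the base space, as in forward_with_log_prob.
    z = [(xi - b) * math.exp(-s) for xi, b, s in zip(x, shift, log_scale)]
    log_det = -sum(log_scale)
    # Change of variables: log p(x) = log N(z; 0, I) + log|det dz/dx|.
    return standard_normal_log_prob(z) + log_det

# A 1-D affine flow with shift 3 and scale 2 is exactly N(3, 2), so the two values must agree.
flow_value = affine_flow_log_prob([5.0], shift=[3.0], log_scale=[math.log(2.0)])
reference = math.log(NormalDist(mu=3.0, sigma=2.0).pdf(5.0))
```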

sample(sample_shape: int | Size | Tuple[int, ...], context: Tensor = None, no_grad: bool = False, return_log_prob: bool = False) → Tensor | Tuple[Tensor, Tensor]

Sample from the normalizing flow.

If context is given, draw a batch of samples with shape sample_shape for each context tensor. Otherwise, draw a single batch of samples with shape sample_shape.

Parameters:
  • sample_shape – batch shape of the samples to draw; each sample has shape event_shape.

  • context (torch.Tensor) – context tensor with shape c.

  • no_grad (bool) – if True, do not track gradients in the inverse pass.

  • return_log_prob – if True, return log probabilities of sampled points as the second tuple component.

Returns:

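samples with shape (*sample_shape, *event_shape) if no context is given, or (*sample_shape, *c, *event_shape) if context is given. If return_log_prob is True, a tuple of the samples and their log probabilities is returned instead.

Return type:

torch.Tensor, or Tuple[torch.Tensor, torch.Tensor] if return_log_prob is True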
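Sampling runs the bijection in the inverse direction: draw z from the base distribution, then map it to data space. The sketch below assumes a hypothetical elementwise affine bijection, no context, and an integer sample count n, so the samples have shape (n, d); it is not the torchflows implementation. When return_log_prob is requested, this sketch reuses z from the inverse pass instead of running a separate forward pass.

```python
import math
import random
from typing import List, Tuple, Union

def affine_flow_sample(
    n: int,
    shift: List[float],
    log_scale: List[float],
    rng: random.Random,
    return_log_prob: bool = False,
) -> Union[List[List[float]], Tuple[List[List[float]], List[float]]]:
    samples, log_probs = [], []
    for _ in range(n):
        # Draw one base sample with event shape (d,).
        z = [rng.gauss(0.0, 1.0) for _ in shift]
        # Inverse pass from base space to data space: x_i = shift_i + exp(log_scale_i) * z_i.
        samples.append([b + math.exp(s) * zi for zi, b, s in zip(z, shift, log_scale)])
        if return_log_prob:
            # Reuse z instead of running the forward pass again:
            # log q(x) = log N(z; 0, I) - sum(log_scale).
            log_base = sum(-0.5 * zi * zi - 0.5 * math.log(2 * math.pi) for zi in z)
            log_probs.append(log_base - sum(log_scale))
    return (samples, log_probs) if return_log_prob else samples

rng = random.Random(0)
x, log_q = affine_flow_sample(5, shift=[0.0, 1.0], log_scale=[0.0, 0.5], rng=rng, return_log_prob=True)
```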

class torchflows.flows.FlowMixture(flows: List[Flow], weights: List[float] = None, trainable_weights: bool = False)

Base class for mixtures of normalizing flows. Inherits from BaseFlow.

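A mixture uses flow objects as its components, each paired with a weight; together, the weights form a categorical distribution over the components. It is a standard statistical mixture: its density is the weighted sum of the component densities.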

__init__(flows: List[Flow], weights: List[float] = None, trainable_weights: bool = False)

FlowMixture constructor.

Parameters:
  • flows (List[Flow]) – normalizing flow components.

  • weights (List[float]) – mixture weights corresponding to flow components. All weights must be greater than 0. The sum of the weights must equal 1.

  • trainable_weights (bool) – if True, makes the weights trainable.
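The weight constraints above can be checked as follows. check_mixture_weights is a hypothetical helper, not part of torchflows, and treating weights=None as uniform is an assumption of this sketch. The second function shows one common way to keep trainable weights valid, a softmax over unconstrained logits; whether FlowMixture parameterizes its trainable weights this way is not stated here.

```python
import math
from typing import List, Optional

def check_mixture_weights(weights: Optional[List[float]], n_components: int, tol: float = 1e-6) -> List[float]:
    if weights is None:
        # Assumption for this sketch: no weights means uniform weights.
        return [1.0 / n_components] * n_components
    if len(weights) != n_components:
        raise ValueError("expected one weight per flow component")
    if any(w <= 0 for w in weights):
        raise ValueError("all weights must be greater than 0")
    if not math.isclose(sum(weights), 1.0, abs_tol=tol):
        raise ValueError("weights must sum to 1")
    return list(weights)

def weights_from_logits(logits: List[float]) -> List[float]:
    # Softmax: any real-valued logits map to positive weights that sum to 1,
    # so an optimizer can update the logits freely. Subtracting the max avoids overflow.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]
```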

log_prob(x: Tensor, context: Tensor = None) → Tensor

Compute the log probability density of inputs x.

Parameters:
  • x (torch.Tensor) – input tensor.

  • context (torch.Tensor) – context tensor.

Returns:

tensor of log probabilities.

Return type:

torch.Tensor
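For a mixture, log p(x) = log Σ_k w_k p_k(x). A numerically stable way to compute this from component log densities is the log-sum-exp trick, sketched below in plain Python with Gaussian log densities standing in for the component flows' log_prob outputs. The helper names are hypothetical.

```python
import math
from typing import Callable, List

def logsumexp(values: List[float]) -> float:
    # Factor out the maximum so exp() never overflows and the dominant term never underflows.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def mixture_log_prob(x: float, weights: List[float], component_log_probs: List[Callable[[float], float]]) -> float:
    # log p(x) = log sum_k w_k p_k(x) = logsumexp_k (log w_k + log p_k(x))
    return logsumexp([math.log(w) + log_p(x) for w, log_p in zip(weights, component_log_probs)])

def normal_log_prob(mu: float, sigma: float) -> Callable[[float], float]:
    # Stand-in for a component flow's log_prob: a Gaussian, i.e. an affine flow of a standard normal.
    return lambda x: -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

components = [normal_log_prob(0.0, 1.0), normal_log_prob(3.0, 0.5)]
near = mixture_log_prob(1.0, [0.3, 0.7], components)
far = mixture_log_prob(60.0, [0.3, 0.7], components)
```

Adding the densities directly would fail at x = 60, where both component densities underflow to 0.0 in double precision; the log-sum-exp version stays finite.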

sample(n: int, context: Tensor = None, no_grad: bool = False, return_log_prob: bool = False) → Tensor

Sample from the flow mixture.

Parameters:
  • n (int) – number of samples to draw.

  • context (torch.Tensor) – context tensor.

  • no_grad (bool) – if True, do not track gradients in the inverse pass during sampling.

  • return_log_prob – if True, return log probabilities of sampled points as the second tuple component.

Returns:

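tensor of drawn samples, or a tuple of the samples and their log probabilities if return_log_prob is True.

Return type:

torch.Tensor, or Tuple[torch.Tensor, torch.Tensor] if return_log_prob is True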
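Sampling from a mixture is typically done by ancestral sampling: pick a component for each sample according to the categorical weights, then draw that sample from the chosen component. A plain-Python sketch with hypothetical Gaussian samplers standing in for the component flows:

```python
import random
from typing import Callable, List

def mixture_sample(
    n: int,
    weights: List[float],
    component_samplers: List[Callable[[random.Random], float]],
    rng: random.Random,
) -> List[float]:
    # Step 1: pick a component index for every sample from the categorical weights.
    indices = rng.choices(range(len(weights)), weights=weights, k=n)
    # Step 2: draw each sample from the component it was assigned to.
    return [component_samplers[k](rng) for k in indices]

rng = random.Random(0)
samplers = [lambda r: r.gauss(-5.0, 0.1), lambda r: r.gauss(5.0, 0.1)]
draws = mixture_sample(4000, [0.25, 0.75], samplers, rng)
fraction_from_second = sum(x > 0 for x in draws) / len(draws)
```

If log probabilities are requested as well, each one must be the full mixture density at that sample (a weighted sum over all components), not only the log density of the component that generated it.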