Flow wrappers
The Flow object contains a base distribution and a bijection.
- class torchflows.flows.BaseFlow(event_shape, base_distribution: Distribution | str = 'standard_normal')
Base normalizing flow class.
- fit(x_train: Tensor, n_epochs: int = 500, lr: float = 0.05, batch_size: int | str = 1024, shuffle: bool = True, show_progress: bool = False, w_train: Tensor = None, context_train: Tensor = None, x_val: Tensor = None, w_val: Tensor = None, context_val: Tensor = None, keep_best_weights: bool = True, early_stopping: bool = False, early_stopping_threshold: int = 50, max_batch_size_mb: int = None, time_limit_seconds: float | int = None)
Fit the normalizing flow to a dataset.
Fitting the flow means finding the bijection parameters that maximize the likelihood of the training data. The parameters are updated iteratively for a specified number of epochs. If context data is provided, the normalizing flow learns the distribution of the data conditional on the context.
- Parameters:
x_train – training data with shape (n_training_data, *event_shape).
n_epochs – number of training epochs (full passes over the training data).
lr – learning rate. Lower learning rates are generally recommended for bijections with many parameters.
batch_size – in each epoch, split training data into batches of this size and perform a parameter update for each batch.
shuffle – if True, shuffle the training data. This helps avoid biased parameter updates when nearby training samples are similar.
show_progress – show a progress bar with the current batch loss.
w_train – training data weights with shape (n_training_data,).
context_train – training data context tensor with shape (n_training_data, *context_shape).
x_val – validation data with shape (n_validation_data, *event_shape).
w_val – validation data weights with shape (n_validation_data,).
context_val – validation data context tensor with shape (n_validation_data, *context_shape).
keep_best_weights – if True and validation data is provided, keep the bijection weights that assign the highest probability to the validation data.
early_stopping – if True and validation data is provided, stop the training procedure early once validation loss stops improving for a specified number of consecutive epochs.
early_stopping_threshold – if early_stopping is True, fitting stops after no improvement in validation loss for this many epochs.
max_batch_size_mb (int) – maximum batch size in megabytes.
time_limit_seconds (Union[float, int]) – maximum allowed time for training.
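The loop that these parameters control can be sketched on a toy problem. This is a minimal illustration, not the torchflows implementation: it assumes a one-dimensional affine flow with hand-derived gradients and plain gradient descent in place of a torchflows bijection, PyTorch autograd, and the library's optimizer. It also assumes the loss is the weighted mean negative log-likelihood. The names nll, grad and fit_sketch are hypothetical. It shows how batching, shuffling, sample weights, keep_best_weights and early stopping interact.

```python
import math
import random


def nll(params, xs, ws):
    """Weighted mean negative log-likelihood of a toy 1-D affine flow (hypothetical)."""
    shift, log_scale = params
    total = 0.0
    for x, w in zip(xs, ws):
        z = (x - shift) * math.exp(-log_scale)  # forward pass into base space
        total += w * (0.5 * z * z + 0.5 * math.log(2 * math.pi) + log_scale)
    return total / sum(ws)


def grad(params, xs, ws):
    """Hand-derived gradient of nll with respect to (shift, log_scale)."""
    shift, log_scale = params
    inv_scale = math.exp(-log_scale)
    g_shift = g_log_scale = 0.0
    for x, w in zip(xs, ws):
        z = (x - shift) * inv_scale
        g_shift += w * (-z * inv_scale)
        g_log_scale += w * (1.0 - z * z)
    total_w = sum(ws)
    return g_shift / total_w, g_log_scale / total_w


def fit_sketch(x_train, n_epochs=500, lr=0.05, batch_size=64, shuffle=True,
               w_train=None, x_val=None, w_val=None, keep_best_weights=True,
               early_stopping=False, early_stopping_threshold=50, seed=0):
    rng = random.Random(seed)
    if w_train is None:
        w_train = [1.0] * len(x_train)
    if x_val is not None and w_val is None:
        w_val = [1.0] * len(x_val)
    params = [0.0, 0.0]  # [shift, log_scale]
    best_params, best_val_loss, epochs_without_improvement = list(params), math.inf, 0
    order = list(range(len(x_train)))
    for _ in range(n_epochs):
        if shuffle:
            rng.shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            g = grad(params, [x_train[i] for i in batch], [w_train[i] for i in batch])
            params = [p - lr * gi for p, gi in zip(params, g)]  # one update per batch
        if x_val is not None:
            val_loss = nll(params, x_val, w_val)
            if val_loss < best_val_loss:
                best_params, best_val_loss, epochs_without_improvement = list(params), val_loss, 0
            else:
                epochs_without_improvement += 1
                if early_stopping and epochs_without_improvement >= early_stopping_threshold:
                    break
    if x_val is not None and keep_best_weights:
        params = best_params  # restore the weights with the lowest validation loss
    return params


data_rng = random.Random(1)
x_train = [data_rng.gauss(3.0, 2.0) for _ in range(2000)]
x_val = [data_rng.gauss(3.0, 2.0) for _ in range(500)]
shift, log_scale = fit_sketch(x_train, n_epochs=100, x_val=x_val,
                              early_stopping=True, early_stopping_threshold=10)
print(shift, math.exp(log_scale))  # roughly 3 and 2
```

Because the best weights are tracked separately from the current weights, enabling early stopping only shortens training. With keep_best_weights, the returned parameters are still the ones with the lowest validation loss seen so far.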
- regularization()
Compute the regularization term used in training.
- variational_fit(target_log_prob: callable, n_epochs: int = 500, lr: float = 0.05, n_samples: int = 1, early_stopping: bool = False, early_stopping_threshold: int = 50, keep_best_weights: bool = True, show_progress: bool = False, check_for_divergences: bool = False, time_limit_seconds: float | int = None)
Train the normalizing flow to fit a target log probability.
Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset. Refer to Rezende, Mohamed: “Variational Inference with Normalizing Flows” (2015) for more details (https://arxiv.org/abs/1505.05770, loss definition in Equation 15, training pseudocode for conditional flows in Algorithm 1).
- Parameters:
target_log_prob (callable) – function that computes the unnormalized target log density for a batch of points. It receives an input batch with shape (*batch_shape, *event_shape) and outputs a batch with shape (*batch_shape).
n_epochs (int) – number of training epochs.
lr (float) – learning rate for the AdamW optimizer.
n_samples (int) – number of samples used to estimate the variational loss in each training step.
show_progress (bool) – if True, show a progress bar during training.
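The quantity minimized in stochastic variational inference can be shown on a one-dimensional example. This is a minimal sketch, not the torchflows code: it assumes the variational flow is a 1-D affine map x = mu + sigma * eps applied to a standard normal base, and it estimates the loss E_q[log q(x) - log p̃(x)] (the negative ELBO) with n_samples Monte Carlo draws. The names variational_loss and std_normal_log_prob are hypothetical.

```python
import math
import random


def std_normal_log_prob(z):
    return -0.5 * z * z - 0.5 * math.log(2 * math.pi)


def variational_loss(mu, log_sigma, target_log_prob, n_samples=1, rng=None):
    """Monte Carlo estimate of E_q[log q(x) - target_log_prob(x)] for q = N(mu, sigma^2)."""
    rng = random.Random() if rng is None else rng
    sigma = math.exp(log_sigma)
    total = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)                      # draw from the base distribution
        x = mu + sigma * eps                           # reparameterized sample of q
        log_q = std_normal_log_prob(eps) - log_sigma   # change of variables
        total += log_q - target_log_prob(x)
    return total / n_samples


# Unnormalized target: N(2, 0.5^2) with its normalizing constant dropped.
def target_log_prob(x):
    return -0.5 * ((x - 2.0) / 0.5) ** 2


matched = variational_loss(2.0, math.log(0.5), target_log_prob, n_samples=100, rng=random.Random(0))
mismatched = variational_loss(0.0, math.log(0.5), target_log_prob, n_samples=2000, rng=random.Random(0))
print(matched)     # equals -log(0.5 * sqrt(2 * pi)) up to rounding
print(mismatched)  # larger by roughly KL(q || p) = 8
```

The loss equals KL(q || p) minus the log normalizing constant of the target. Minimizing it over the flow parameters therefore drives q toward the normalized target, even though only the unnormalized density is ever evaluated.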
- class torchflows.flows.Flow(bijection: Bijection, **kwargs)
Normalizing flow class. Inherits from BaseFlow.
This class represents a bijective transformation of a base distribution (a standard Gaussian by default). A normalizing flow is itself a distribution: we can sample from it or use it to compute the density of inputs.
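In standard normalizing-flow notation, the density follows from the change-of-variables formula. This assumes f(x; c) denotes the bijection that maps data x (with optional context c) to the base space, as forward_with_log_prob does, and that p_Z denotes the base density. Sampling runs the inverse direction, x = f^{-1}(z; c) with z drawn from p_Z.

```latex
\log p_X(x \mid c) = \log p_Z\big(f(x; c)\big) + \log \left| \det \frac{\partial f(x; c)}{\partial x} \right|
```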
- __init__(bijection: Bijection, **kwargs)
Flow constructor.
- Parameters:
bijection (Bijection) – transformation component of the normalizing flow.
kwargs – keyword arguments passed to BaseFlow.
- forward_with_log_prob(x: Tensor, context: Tensor = None)
Transform the input x to the space of the base distribution.
- Parameters:
x (torch.Tensor) – input tensor.
context (torch.Tensor) – context tensor upon which the transformation is conditioned.
- Returns:
transformed tensor and the logarithm of the absolute value of the Jacobian determinant of the transformation.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
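The pair that this method returns can be shown with the simplest possible bijection. This is a sketch under stated assumptions: it uses a hypothetical elementwise affine bijection over plain Python lists rather than a torchflows Bijection on tensors. Its Jacobian is diagonal, so the log absolute determinant is a sum of per-dimension terms. A finite-difference check confirms the value.

```python
import math


def affine_forward_with_log_prob(x, shift, log_scale):
    """Hypothetical elementwise affine bijection z_i = (x_i - shift_i) * exp(-log_scale_i)."""
    z = [(xi - s) * math.exp(-ls) for xi, s, ls in zip(x, shift, log_scale)]
    # The Jacobian dz/dx is diagonal with entries exp(-log_scale_i), so
    # log|det J| is the sum of the logs of those entries.
    log_det = -sum(log_scale)
    return z, log_det


def affine_inverse(z, shift, log_scale):
    """Inverse map from base space back to data space."""
    return [zi * math.exp(ls) + s for zi, s, ls in zip(z, shift, log_scale)]


x = [1.0, -2.0, 0.5]
shift, log_scale = [0.5, 1.0, -1.0], [0.2, -0.3, 0.7]
z, log_det = affine_forward_with_log_prob(x, shift, log_scale)

# Check log|det J| against finite-difference derivatives of each output coordinate.
h = 1e-6
numeric_log_det = 0.0
for i in range(len(x)):
    bumped = list(x)
    bumped[i] += h
    z_bumped, _ = affine_forward_with_log_prob(bumped, shift, log_scale)
    numeric_log_det += math.log((z_bumped[i] - z[i]) / h)
print(log_det, numeric_log_det)  # agree to several decimal places
```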
- log_prob(x: Tensor, context: Tensor = None) → Tensor
Compute the logarithm of the probability density of input x according to the normalizing flow.
- Parameters:
x (torch.Tensor) – input tensor.
context (torch.Tensor) – context tensor.
- Returns:
tensor of log probabilities.
- Return type:
torch.Tensor
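Under the change-of-variables formula, log_prob is the base log density of the forward output plus the log absolute Jacobian determinant. The sketch below is a one-dimensional illustration only. It uses a hypothetical nonlinear bijection z = x + 0.5 * tanh(x), which is strictly increasing, and a standard normal base. A numerical integral confirms that the resulting density integrates to 1.

```python
import math


def forward_with_log_prob(x):
    """Hypothetical 1-D bijection z = x + 0.5 * tanh(x) and its log|dz/dx|."""
    z = x + 0.5 * math.tanh(x)
    log_det = math.log(1.0 + 0.5 / math.cosh(x) ** 2)  # dz/dx = 1 + 0.5 sech^2(x) > 0
    return z, log_det


def log_prob(x):
    z, log_det = forward_with_log_prob(x)
    base = -0.5 * z * z - 0.5 * math.log(2 * math.pi)  # standard normal base density
    return base + log_det


# The density should integrate to 1; check with the trapezoidal rule on [-10, 10].
n, a, b = 20000, -10.0, 10.0
h = (b - a) / n
endpoints = 0.5 * (math.exp(log_prob(a)) + math.exp(log_prob(b)))
mass = h * (sum(math.exp(log_prob(a + i * h)) for i in range(n + 1)) - endpoints)
print(mass)  # close to 1.0
```

Dropping the log_det term would give a function that no longer integrates to 1, which is why forward_with_log_prob returns both values.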
- sample(sample_shape: int | Size | Tuple[int, ...], context: Tensor = None, no_grad: bool = False, return_log_prob: bool = False) → Tensor | Tuple[Tensor, Tensor]
Sample from the normalizing flow.
If context is given, draw samples with shape sample_shape for each context tensor. Otherwise, draw samples with shape sample_shape.
- Parameters:
sample_shape – batch shape of the samples to draw.
context (torch.Tensor) – context tensor with shape c.
no_grad (bool) – if True, do not track gradients in the inverse pass.
return_log_prob – if True, return log probabilities of sampled points as the second tuple component.
- Returns:
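samples with shape (*sample_shape, *event_shape) if no context is given, or (*sample_shape, *c, *event_shape) if context is given. If return_log_prob is True, a tuple of the samples and their log probabilities.
- Return type:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]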
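Sampling runs the bijection in the inverse direction. With return_log_prob, the log density can be computed from the same base draw without a separate forward pass. This is a minimal sketch, not the library code: it assumes a hypothetical elementwise affine flow with a standard normal base, an integer sample count n, and an event shape (d,) represented as nested Python lists.

```python
import math
import random


def sample_affine_flow(n, shift, log_scale, return_log_prob=False, rng=None):
    """Draw n samples of a hypothetical elementwise affine flow via its inverse pass."""
    rng = random.Random() if rng is None else rng
    samples, log_probs = [], []
    for _ in range(n):
        z = [rng.gauss(0.0, 1.0) for _ in shift]  # draw from the standard normal base
        x = [s + math.exp(ls) * zi for zi, s, ls in zip(z, shift, log_scale)]  # inverse pass
        base = sum(-0.5 * zi * zi - 0.5 * math.log(2 * math.pi) for zi in z)
        # The inverse pass has log|det| = sum(log_scale); subtracting it gives log p(x).
        samples.append(x)
        log_probs.append(base - sum(log_scale))
    return (samples, log_probs) if return_log_prob else samples


shift, log_scale = [0.0, 1.0, -2.0], [0.1, -0.5, 0.3]
samples, log_probs = sample_affine_flow(5, shift, log_scale, return_log_prob=True, rng=random.Random(0))
print(len(samples), len(samples[0]))  # 5 3, i.e. (*sample_shape, *event_shape)
```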
- class torchflows.flows.FlowMixture(flows: List[Flow], weights: List[float] = None, trainable_weights: bool = False)
Base class for mixtures of normalizing flows. Inherits from BaseFlow.
A mixture uses flow objects as its components, together with categorical mixture weights. It is a standard statistical mixture: its density is the weighted sum of the component densities.
- __init__(flows: List[Flow], weights: List[float] = None, trainable_weights: bool = False)
FlowMixture constructor.
- Parameters:
flows (List[Flow]) – normalizing flow components.
weights (List[float]) – mixture weights corresponding to flow components. All weights must be greater than 0. The sum of the weights must equal 1.
trainable_weights (bool) – if True, makes the weights trainable.
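The weight constraints above can be checked as shown below. This sketch is illustrative only: validate_weights and weights_from_logits are hypothetical helpers. The softmax parameterization is one common way to keep trainable weights positive and normalized, but it is an assumption, not necessarily how FlowMixture implements trainable_weights.

```python
import math


def validate_weights(weights):
    """Enforce the documented constraints: all weights > 0 and summing to 1."""
    if any(w <= 0 for w in weights):
        raise ValueError("all mixture weights must be greater than 0")
    if not math.isclose(sum(weights), 1.0, rel_tol=0.0, abs_tol=1e-6):
        raise ValueError("mixture weights must sum to 1")
    return weights


def weights_from_logits(logits):
    """Softmax: maps unconstrained (trainable) logits to valid mixture weights."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


print(weights_from_logits([math.log(0.3), math.log(0.7)]))  # approximately [0.3, 0.7]
```

Optimizing logits instead of raw weights means any gradient step still yields weights that satisfy both constraints.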
- log_prob(x: Tensor, context: Tensor = None) → Tensor
Compute the log probability density of inputs x.
- Parameters:
x (torch.Tensor) – input tensor.
context (torch.Tensor) – context tensor.
- Returns:
tensor of log probabilities.
- Return type:
torch.Tensor
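A mixture's log density is log Σ_k w_k p_k(x). Computing it directly from exponentiated component log densities underflows far from the data, so the log-sum-exp trick is the usual approach. The sketch below assumes plain Gaussian components standing in for trained flows. mixture_log_prob and gaussian_log_prob are hypothetical helpers, not torchflows functions.

```python
import math


def gaussian_log_prob(x, mean, std):
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def mixture_log_prob(x, component_log_probs, weights):
    """log sum_k w_k p_k(x), computed stably with the log-sum-exp trick."""
    terms = [math.log(w) + lp(x) for lp, w in zip(component_log_probs, weights)]
    m = max(terms)  # factor out the largest term so exp() cannot underflow all terms
    return m + math.log(sum(math.exp(t - m) for t in terms))


components = [lambda x: gaussian_log_prob(x, 0.0, 1.0),
              lambda x: gaussian_log_prob(x, 5.0, 2.0)]
weights = [0.3, 0.7]
print(mixture_log_prob(1.0, components, weights))
print(mixture_log_prob(200.0, components, weights))  # finite, although exp() of each term underflows
```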
- sample(n: int, context: Tensor = None, no_grad: bool = False, return_log_prob: bool = False) → Tensor
Sample from the flow mixture.
- Parameters:
n (int) – number of samples to draw.
context (torch.Tensor) – context tensor.
no_grad (bool) – if True, do not track gradients in the inverse pass during sampling.
return_log_prob – if True, return log probabilities of sampled points as the second tuple component.
- Returns:
tensor of drawn samples.
- Return type:
torch.Tensor
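Sampling from a statistical mixture is usually done by ancestral sampling: draw a component index from the categorical weights, then draw from that component. This is a minimal sketch under that assumption. Whether FlowMixture draws component indices exactly this way is not stated here, and the Gaussian samplers are hypothetical stand-ins for trained flows.

```python
import random


def sample_mixture(n, component_samplers, weights, rng=None):
    """Ancestral sampling: pick a component per draw, then sample from that component."""
    rng = random.Random() if rng is None else rng
    # One categorical draw per sample, with probabilities given by the mixture weights.
    components = rng.choices(range(len(weights)), weights=weights, k=n)
    return [component_samplers[k](rng) for k in components]


# Two well-separated toy components standing in for trained flows.
samplers = [lambda r: r.gauss(-10.0, 1.0), lambda r: r.gauss(10.0, 1.0)]
draws = sample_mixture(20000, samplers, [0.25, 0.75], rng=random.Random(0))
print(sum(x > 0 for x in draws) / len(draws))  # close to 0.75
```

If log probabilities of the samples are needed, they must be evaluated under the full mixture density (a log-sum-exp over all components), not only under the component that produced each sample.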