Event shapes
Torchflows can model events (individual data points) that are tensors of arbitrary shape. For example, we can model events with shape (2, 3, 5) as follows:
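import torch
from torchflows.flows import Flow
from torchflows.architectures import RealNVP

torch.manual_seed(0)

# Each event (data point) is a tensor of shape (2, 3, 5)
event_shape = (2, 3, 5)
n_data = 1000
x_train = torch.randn(size=(n_data, *event_shape))
print(x_train.shape)  # torch.Size([1000, 2, 3, 5])

# Build the flow for this event shape and fit it to the training data
flow = Flow(RealNVP(event_shape))
flow.fit(x_train, show_progress=True)

# Draw 500 new events; the sample shape is prepended to the event shape
x_new = flow.sample((500,))
print(x_new.shape)  # torch.Size([500, 2, 3, 5])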
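To make the split between batch and event dimensions in the example above concrete, here is a minimal sketch in plain PyTorch. It is not a description of how Torchflows handles event shapes internally. It only illustrates one common convention: the trailing dimensions form the event, the leading dimensions index data points, and a density assigns one value per event. The names `n_event_dims`, `batch_shape`, `x_flat`, and `log_p` are illustrative assumptions, and a standard normal density stands in for a fitted flow.

```python
import math
import torch

torch.manual_seed(0)

event_shape = (2, 3, 5)
x = torch.randn(1000, *event_shape)

# Split the tensor shape into leading batch dims and trailing event dims.
n_event_dims = len(event_shape)
batch_shape = x.shape[:-n_event_dims]   # one entry per data point
event_size = math.prod(event_shape)     # 2 * 3 * 5 elements per event

# Flatten each event into a vector, then restore the original event shape.
x_flat = x.reshape(*batch_shape, event_size)
x_restored = x_flat.reshape(*batch_shape, *event_shape)

# A density over events yields one scalar per event, so elementwise
# log-densities are summed over the trailing event dimensions.
# (Standard normal used here as a stand-in for a fitted flow.)
event_dims = tuple(range(-n_event_dims, 0))
log_p = torch.distributions.Normal(0.0, 1.0).log_prob(x).sum(dim=event_dims)

print(x_flat.shape)
print(log_p.shape)
```

Under this convention, the event shape stays fixed while the batch shape varies with the number of data points or samples, which matches the shapes printed in the example above.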