Geometry-Aware Longitudinal Flow Matching with Mixture of Experts
FLUX (Flow matching for Unpaired longitudinal data with miXture-of-experts) couples adjacent-marginal transport along a chain of snapshots with metric flow matching and a MixtureVelocityNet for simultaneous generative modeling and unsupervised regime discovery.
Yale University · Wu Tsai Institute · Kavli Institute
Many scientific systems are observed only as a sequence of unpaired population snapshots at discrete timepoints, yielding a chain of marginal distributions from which underlying dynamics and regime transitions must be inferred. Flow matching learns a velocity field whose ODE transports a source distribution to a target distribution, without simulating a neural ODE during training or committing to a restrictive diffusion schedule.
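To make the adjacent-marginal idea concrete, here is a minimal sketch of how flow-matching training tuples could be drawn from a chain of unpaired snapshots. It is a simplification and not FLUX's implementation: it assumes straight-line (Euclidean) interpolation and independent pairing of points, whereas FLUX uses a learned Riemannian interpolation. The names `adjacent_pair_batch` and `fm_loss` are hypothetical, and only the Python standard library is used.

```python
import random

def adjacent_pair_batch(snapshots, batch_size, rng):
    """Sample flow-matching tuples (k, t, x_t, u_t) from adjacent snapshots.

    snapshots: list of K populations, each a list of points (tuples of floats).
    Assumption: points are paired independently and joined by a straight line.
    """
    batch = []
    for _ in range(batch_size):
        k = rng.randrange(len(snapshots) - 1)   # adjacent pair (k, k + 1)
        x0 = rng.choice(snapshots[k])           # unpaired data: independent draws
        x1 = rng.choice(snapshots[k + 1])
        t = rng.random()                        # local time in [0, 1)
        x_t = tuple((1 - t) * a + t * b for a, b in zip(x0, x1))
        u_t = tuple(b - a for a, b in zip(x0, x1))  # target velocity on the line
        batch.append((k, t, x_t, u_t))
    return batch

def fm_loss(velocity_fn, batch):
    """Mean squared error between predicted and target velocities."""
    total = 0.0
    for k, t, x_t, u_t in batch:
        v = velocity_fn(x_t, k + t)             # k + t is time along the whole chain
        total += sum((vi - ui) ** 2 for vi, ui in zip(v, u_t))
    return total / len(batch)

rng = random.Random(0)
# Three toy 2D snapshots, each shifted one unit to the right of the previous one.
snapshots = [
    [(k + rng.gauss(0, 0.1), rng.gauss(0, 0.1)) for _ in range(50)]
    for k in range(3)
]
batch = adjacent_pair_batch(snapshots, 256, rng)

# A constant rightward field should fit this toy chain far better than a zero field.
loss_const = fm_loss(lambda x, s: (1.0, 0.0), batch)
loss_zero = fm_loss(lambda x, s: (0.0, 0.0), batch)
print(f"constant field: {loss_const:.3f}, zero field: {loss_zero:.3f}")
```

In practice the velocity function would be a neural network trained by minimizing this loss. The point of the sketch is only that every training tuple comes from two neighboring snapshots, so one field is fitted along the full chain.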
FLUX unifies longitudinal adjacent-marginal training with learned Riemannian interpolation (RBF or GAGA-pullback metrics and a bend network) and a mixture-of-experts velocity with Straight-Through Gumbel-Softmax gating, so transport stays consistent along the full chain while gating reveals distinct dynamical modes when they exist.
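The gating mechanism can be illustrated with a forward-pass sketch of Straight-Through Gumbel-Softmax routing over expert velocities. This is an assumption-laden illustration, not the MixtureVelocityNet code: the function names, the two toy experts, and the fixed logits are hypothetical, and the standard library has no autograd, so only the forward computation is shown. In an autograd framework the straight-through trick uses the hard one-hot gate in the forward pass and the soft probabilities in the backward pass; PyTorch exposes this as `torch.nn.functional.gumbel_softmax(logits, tau, hard=True)`.

```python
import math
import random

def gumbel_softmax_gate(logits, tau, rng):
    """Return (hard one-hot gate, soft probabilities) for a single sample."""
    # Add Gumbel(0, 1) noise via inverse transform; the clamp avoids log(0).
    noisy = [l - math.log(-math.log(max(rng.random(), 1e-12))) for l in logits]
    m = max(noisy)                                  # subtract max for stability
    exps = [math.exp((n - m) / tau) for n in noisy]
    z = sum(exps)
    soft = [e / z for e in exps]
    winner = max(range(len(soft)), key=soft.__getitem__)
    hard = [1.0 if i == winner else 0.0 for i in range(len(soft))]
    return hard, soft

def mixture_velocity(x, experts, gate):
    """Gate-weighted sum of expert velocities at point x."""
    outs = [f(x) for f in experts]
    return tuple(sum(g * o[d] for g, o in zip(gate, outs)) for d in range(len(x)))

# Two hypothetical dynamical regimes in 2D: rotation and contraction.
def rotate(x):
    return (-x[1], x[0])

def contract(x):
    return (-x[0], -x[1])

experts = [rotate, contract]

rng = random.Random(0)
logits = [2.0, 0.0]   # in a real model a gating network would produce these
x = (1.0, 0.5)

hard, soft = gumbel_softmax_gate(logits, tau=0.5, rng=rng)
v = mixture_velocity(x, experts, hard)  # hard gate: exactly one expert is active
print("gate:", hard, "velocity:", v)

# Over many draws the winning expert follows softmax(logits).
draws = 2000
freq0 = sum(gumbel_softmax_gate(logits, 0.5, rng)[0][0] for _ in range(draws)) / draws
print(f"expert 0 selected in {freq0:.2%} of draws")
```

Because the hard gate is one-hot, each sample is routed to a single expert. That is what lets the gate assignments be read as discrete regime labels when distinct modes exist.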
Training proceeds in three stages. The geometry and bend networks are trained first and frozen before velocity training. Evaluation uses compute_metrics.py.
| Stage | Script | Output |
|---|---|---|
| 1. Geometry | train_benchmark_rbf.py or train_benchmark_vae.py | rbf_network_best.pth or best_gaga_model.pth |
| 2. Bend | train_benchmark_bend.py | bend_network_best.pth |
| 3. Velocity | train_benchmark_velocity.py | velocity_network_best.pth |
Additional experiments in the paper use a low-dimensional multi-marginal control (synthetic 2D marginals along a smooth curve) to isolate geometry-aware interpolation where explicit regime structure is absent.
See the repository README for full options, conda setup, and dataset-specific flags.
git clone https://github.com/josueortc/flux.git
cd flux
conda create -n flux python=3.10 -y && conda activate flux
# Install PyTorch for your platform, then:
pip install -r requirements.txt
Learn the Riemannian metric for your benchmark dataset.
python scripts/benchmark_data/train_benchmark_rbf.py \
--dataset lorenz --lorenz_mode day_marginals --num_marginals 8 \
--save_dir saved_models/lorenz
Train the bend network using the saved geometry checkpoint.
python scripts/benchmark_data/train_benchmark_bend.py \
--dataset lorenz --lorenz_mode day_marginals --num_marginals 8 \
--geo_model_path saved_models/lorenz/rbf_network_best.pth \
--save_dir saved_models/lorenz
Train MixtureVelocityNet with Gumbel routing, then evaluate.
python scripts/benchmark_data/train_benchmark_velocity.py \
--dataset lorenz --lorenz_mode day_marginals --num_marginals 8 \
--geo_model_path saved_models/lorenz/rbf_network_best.pth \
--bend_model_path saved_models/lorenz/bend_network_best.pth \
--use_gumbel_routing --num_experts 2 --save_dir saved_models/lorenz
python scripts/benchmark_data/compute_metrics.py \
--dataset lorenz --model_dir saved_models/lorenz
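The four commands above share most of their flags, so it can help to generate them from one place. The following is a small hypothetical driver, not part of the repository. It reuses only the script paths, flags, and checkpoint names shown above, and it assumes each stage writes its checkpoint into `--save_dir` under those names. It prints the commands rather than running them.

```python
import shlex

SCRIPTS = "scripts/benchmark_data"
DATA_FLAGS = ["--dataset", "lorenz", "--lorenz_mode", "day_marginals",
              "--num_marginals", "8"]

def build_pipeline(save_dir="saved_models/lorenz", num_experts=2):
    """Return the geometry, bend, velocity, and evaluation commands in order."""
    # Assumption: checkpoints land in save_dir with the names used above.
    geo = f"{save_dir}/rbf_network_best.pth"
    bend = f"{save_dir}/bend_network_best.pth"
    return [
        ["python", f"{SCRIPTS}/train_benchmark_rbf.py", *DATA_FLAGS,
         "--save_dir", save_dir],
        ["python", f"{SCRIPTS}/train_benchmark_bend.py", *DATA_FLAGS,
         "--geo_model_path", geo, "--save_dir", save_dir],
        ["python", f"{SCRIPTS}/train_benchmark_velocity.py", *DATA_FLAGS,
         "--geo_model_path", geo, "--bend_model_path", bend,
         "--use_gumbel_routing", "--num_experts", str(num_experts),
         "--save_dir", save_dir],
        ["python", f"{SCRIPTS}/compute_metrics.py",
         "--dataset", "lorenz", "--model_dir", save_dir],
    ]

commands = build_pipeline()
for cmd in commands:
    # To execute a stage instead of printing it: subprocess.run(cmd, check=True)
    print(shlex.join(cmd))
```

Keeping the checkpoint paths in one function makes the stage dependencies explicit: the bend stage consumes the geometry checkpoint, and the velocity stage consumes both.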
If you use FLUX in your research, please cite:
@article{ortegacaro2026flux,
title={FLUX: Geometry-Aware Longitudinal Flow Matching with Mixture of Experts},
author={Ortega Caro, Josue and others},
journal={Under review},
year={2026}
}