FLUX

Geometry-Aware Longitudinal Flow Matching with Mixture of Experts

FLUX (Flow matching for Unpaired longitudinal data with miXture-of-experts) couples adjacent-marginal transport along a chain of snapshots with metric flow matching and a MixtureVelocityNet for simultaneous generative modeling and unsupervised regime discovery.

Josue Ortega Caro · Collaborators

Yale University · Wu Tsai Institute · Kavli Institute

Code · Paper (forthcoming)
Figure overview: five unpaired marginal snapshots (t₁–t₅) and a three-regime manifold with colored samples.
(A) Unpaired marginals at each timepoint; (B) manifold with distinct dynamical regimes.

Longitudinal flows on curved manifolds

Many scientific systems are observed only as a sequence of unpaired population snapshots at discrete timepoints. Each snapshot gives one marginal distribution, and the underlying dynamics and regime transitions must be inferred from the resulting chain of marginals. Flow matching learns a velocity field whose ODE transports a source distribution to a target distribution, without simulating a neural ODE during training and without committing to a restrictive diffusion schedule.
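To make the flow matching objective concrete, here is a minimal sketch of the regression target for a single source/target pair of marginals. It assumes plain straight-line conditional paths (not the learned Riemannian paths FLUX uses), and the names linear_path, target_velocity, and flow_matching_loss are illustrative only, not part of the FLUX code base. It is written in dependency-free Python; a real implementation would operate on tensors.

```python
import random

def linear_path(x0, x1, t):
    """Point on the straight-line path from x0 to x1 at time t in [0, 1]."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]

def target_velocity(x0, x1):
    """Time derivative of the straight-line path: constant, equal to x1 - x0."""
    return [b - a for a, b in zip(x0, x1)]

def flow_matching_loss(velocity_fn, pairs, rng):
    """Monte Carlo estimate of E ||v(x_t, t) - (x1 - x0)||^2 over the given pairs."""
    total = 0.0
    for x0, x1 in pairs:
        t = rng.random()                      # one random time per pair
        x_t = linear_path(x0, x1, t)
        v = velocity_fn(x_t, t)
        u = target_velocity(x0, x1)
        total += sum((vi - ui) ** 2 for vi, ui in zip(v, u))
    return total / len(pairs)

rng = random.Random(0)
# Toy data: every target sample is its source sample shifted by (1, 2).
pairs = []
for _ in range(64):
    x0 = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
    pairs.append((x0, [x0[0] + 1.0, x0[1] + 2.0]))

# The constant field (1, 2) matches the target; the zero field does not.
exact_loss = flow_matching_loss(lambda x, t: [1.0, 2.0], pairs, rng)
zero_loss = flow_matching_loss(lambda x, t: [0.0, 0.0], pairs, rng)
print(exact_loss, zero_loss)
```

No ODE is integrated anywhere in the loss: training only needs samples of x_t and the path derivative, which is the property the paragraph above refers to.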

FLUX unifies three components: longitudinal adjacent-marginal training, learned Riemannian interpolation (RBF or GAGA-pullback metrics plus a bend network), and a mixture-of-experts velocity with Straight-Through Gumbel-Softmax gating. Transport therefore stays consistent along the full chain, while the gating reveals distinct dynamical modes when they exist.
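The gating step can be illustrated with a small forward-pass sketch of Gumbel-Softmax routing over expert velocities. This is an assumption-laden toy in plain Python, not the MixtureVelocityNet implementation: the function names are hypothetical, the experts are fixed vectors rather than networks, and the straight-through gradient trick is only described in comments because it needs an autograd framework.

```python
import math
import random

def gumbel_softmax_gate(logits, tau, rng):
    """Return (soft, hard): relaxed gate probabilities and a one-hot expert choice."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise is -log(-log(U)) with U ~ Uniform(0, 1); clamp U away from 0 and 1.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))
        noisy.append((logit + gumbel) / tau)
    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    top = max(noisy)
    exps = [math.exp(z - top) for z in noisy]
    total = sum(exps)
    soft = [e / total for e in exps]
    winner = max(range(len(soft)), key=lambda i: soft[i])
    hard = [1.0 if i == winner else 0.0 for i in range(len(soft))]
    # Straight-through: with autograd one would output hard + soft - stop_gradient(soft),
    # so the forward value is one-hot while gradients flow through `soft`.
    return soft, hard

def mixture_velocity(expert_velocities, gate):
    """Gate-weighted sum of expert outputs; a one-hot gate selects a single expert."""
    dim = len(expert_velocities[0])
    return [sum(g * v[d] for g, v in zip(gate, expert_velocities)) for d in range(dim)]

rng = random.Random(0)
soft, hard = gumbel_softmax_gate([0.3, -0.2], tau=0.5, rng=rng)
experts = [[1.0, 0.0], [0.0, -1.0]]   # toy outputs of two hypothetical experts
velocity = mixture_velocity(experts, hard)
print(soft, hard, velocity)
```

In PyTorch, torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True) provides the same forward behavior together with the straight-through gradient.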

What FLUX adds

  • Coupled pipeline: adjacent-marginal chaining with metric flow matching and learned conditional paths across the marginal chain, including on high-dimensional data and with irregularly spaced sessions.
  • MixtureVelocityNet: MoE velocity with Gumbel-Softmax gating for unsupervised regime discovery. Manifold-aware paths matter here: they keep the gating from collapsing when distinct regimes exist.
  • Benchmarks: Lorenz attractor with distinct dynamical regimes, NeuralTable widefield calcium imaging across learning, embryoid body (EB) scRNA-seq across differentiation stages, plus generative and regime-discovery baselines (including a 2D multi-marginal control from the paper).
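As a rough illustration of adjacent-marginal chaining, the sketch below draws one training batch from a randomly chosen pair of consecutive snapshots and samples a time inside that pair's interval, which is one simple way to accommodate irregular spacing. Everything here is an assumption for illustration: the helper names are hypothetical, source and target samples are drawn independently (no coupling), and FLUX's actual sampler may differ.

```python
import random

def adjacent_pairs(num_marginals):
    """Index pairs (k, k + 1) of consecutive snapshots in a time-ordered chain."""
    return [(k, k + 1) for k in range(num_marginals - 1)]

def sample_batch(marginals, times, batch_size, rng):
    """Draw an unpaired batch from one randomly chosen adjacent pair of marginals."""
    k, k_next = rng.choice(adjacent_pairs(len(marginals)))
    # Source and target samples are drawn independently: the data are unpaired.
    x0 = [rng.choice(marginals[k]) for _ in range(batch_size)]
    x1 = [rng.choice(marginals[k_next]) for _ in range(batch_size)]
    # Times are sampled inside [t_k, t_{k+1}], so uneven gaps are respected.
    t = [rng.uniform(times[k], times[k_next]) for _ in range(batch_size)]
    return (k, k_next), x0, x1, t

# Toy chain: four 1-D marginals observed at irregularly spaced times.
times = [0.0, 1.0, 2.5, 6.0]
rng = random.Random(0)
marginals = [[rng.gauss(center, 0.1) for _ in range(50)] for center in times]

pair, x0, x1, t = sample_batch(marginals, times, batch_size=8, rng=rng)
print(pair, t)
```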

Three-stage training

Geometry and bend networks are frozen before velocity training. Evaluation uses compute_metrics.py.

Stage        Script                                             Output
1. Geometry  train_benchmark_rbf.py or train_benchmark_vae.py   rbf_network_best.pth or best_gaga_model.pth
2. Bend      train_benchmark_bend.py                            bend_network_best.pth
3. Velocity  train_benchmark_velocity.py                        velocity_network_best.pth
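Because each stage consumes the checkpoint written by the previous one, it can help to check which stage to run next. The helper below is a hypothetical convenience script, not part of the repository; it assumes the checkpoints listed above are written directly into the directory passed as --save_dir.

```python
from pathlib import Path

# Checkpoint file names as listed in the stage table above.
STAGE_ORDER = ("geometry", "bend", "velocity")
STAGE_OUTPUTS = {
    "geometry": ("rbf_network_best.pth", "best_gaga_model.pth"),  # either one suffices
    "bend": ("bend_network_best.pth",),
    "velocity": ("velocity_network_best.pth",),
}

def completed_stages(save_dir):
    """Stages whose checkpoint exists, stopping at the first missing one."""
    done = []
    for stage in STAGE_ORDER:
        if any((Path(save_dir) / name).is_file() for name in STAGE_OUTPUTS[stage]):
            done.append(stage)
        else:
            break
    return done

def next_stage(save_dir):
    """Name of the next stage to train, or None when all three are done."""
    done = completed_stages(save_dir)
    return STAGE_ORDER[len(done)] if len(done) < len(STAGE_ORDER) else None

print(next_stage("saved_models/lorenz"))
```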

Datasets

  • Lorenz (8 marginals): chaotic vs. subcritical regimes; sharp distributional shift and multi-step composition metrics.
  • NeuralTable (widefield Ca²⁺): 41 cortical areas, 22 day-level marginals, Go/No-Go visual learning across phases.
  • Embryoid body, EB (scRNA-seq): ~17k cells, five timepoints, three differentiation stages (pluripotent, commitment, differentiated) in PCA embedding space.

Additional experiments in the paper use a low-dimensional multi-marginal control (synthetic 2D marginals along a smooth curve) to isolate geometry-aware interpolation where explicit regime structure is absent.
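For readers who want to experiment with a control of this kind, the following sketch generates toy 2D marginals centered at evenly spaced positions along a smooth curve. It is purely hypothetical: the half-circle arc, the Gaussian noise level, and the function names are choices made here for illustration and are not the construction used in the paper.

```python
import math
import random

def curve_point(s):
    """Point on a half-circle arc for s in [0, 1]; an arbitrary smooth curve."""
    angle = math.pi * s
    return (math.cos(angle), math.sin(angle))

def make_marginals(num_marginals, samples_per_marginal, noise_std, rng):
    """One Gaussian cloud per snapshot, centered on successive points of the curve."""
    marginals = []
    for k in range(num_marginals):
        cx, cy = curve_point(k / (num_marginals - 1))
        cloud = [(rng.gauss(cx, noise_std), rng.gauss(cy, noise_std))
                 for _ in range(samples_per_marginal)]
        marginals.append(cloud)
    return marginals

rng = random.Random(0)
marginals = make_marginals(num_marginals=5, samples_per_marginal=200,
                           noise_std=0.05, rng=rng)
print(len(marginals), len(marginals[0]))
```

Samples in different clouds are generated independently, so the snapshots are unpaired by construction.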

Install and run (Lorenz)

See the repository README for full options, conda setup, and dataset-specific flags.

Clone and environment

git clone https://github.com/josueortc/flux.git
cd flux
conda create -n flux python=3.10 -y && conda activate flux
# Install PyTorch for your platform, then:
pip install -r requirements.txt

1. Metric (RBF)

Learn the Riemannian metric for your benchmark dataset.

python scripts/benchmark_data/train_benchmark_rbf.py \
  --dataset lorenz --lorenz_mode day_marginals --num_marginals 8 \
  --save_dir saved_models/lorenz
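For intuition about what a data-dependent metric buys, the sketch below uses a fixed, unlearned RBF score in the spirit of the RBF metrics used in metric flow matching: the metric factor is small near data and large away from it, so paths that stay on the data manifold have lower energy. This is an assumption for illustration only. The functions, the isotropic form, and the bandwidth are hypothetical, and the parameterization actually trained by train_benchmark_rbf.py may differ.

```python
import math

def rbf_score(x, centers, bandwidth):
    """h(x) = sum_k exp(-||x - c_k||^2 / (2 * bandwidth^2)); large near the data."""
    return sum(
        math.exp(-sum((xi - ci) ** 2 for xi, ci in zip(x, c)) / (2.0 * bandwidth ** 2))
        for c in centers
    )

def metric_scale(x, centers, bandwidth, eps=1e-3):
    """Isotropic metric factor g(x) = 1 / (h(x) + eps): cheap near data, costly far away."""
    return 1.0 / (rbf_score(x, centers, bandwidth) + eps)

def path_energy(path, centers, bandwidth):
    """Discrete energy: sum of g(midpoint) * ||segment||^2 / dt, with dt = 1 / n."""
    n = len(path) - 1
    energy = 0.0
    for a, b in zip(path[:-1], path[1:]):
        mid = [(ai + bi) / 2.0 for ai, bi in zip(a, b)]
        sq_len = sum((bi - ai) ** 2 for ai, bi in zip(a, b))
        energy += metric_scale(mid, centers, bandwidth) * sq_len * n
    return energy

# Toy data on a half circle; compare a path along the arc with the straight chord.
centers = [(math.cos(math.pi * k / 20), math.sin(math.pi * k / 20)) for k in range(21)]
steps = 50
arc = [(math.cos(math.pi * i / steps), math.sin(math.pi * i / steps)) for i in range(steps + 1)]
chord = [(1.0 - 2.0 * i / steps, 0.0) for i in range(steps + 1)]

arc_energy = path_energy(arc, centers, bandwidth=0.2)
chord_energy = path_energy(chord, centers, bandwidth=0.2)
print(arc_energy < chord_energy)
```

The arc is longer in Euclidean terms, yet it is cheaper under this metric because the chord crosses a region with no data.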

2. Bend network

Train the bend network using the saved geometry checkpoint.

python scripts/benchmark_data/train_benchmark_bend.py \
  --dataset lorenz --lorenz_mode day_marginals --num_marginals 8 \
  --geo_model_path saved_models/lorenz/rbf_network_best.pth \
  --save_dir saved_models/lorenz
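One common way to parameterize a bent conditional path, used in metric flow matching, is x_t = (1 - t) x0 + t x1 + t (1 - t) φ(x0, x1, t), where φ is the bend network. The sketch below assumes that form with a fixed toy φ in place of a trained network; whether FLUX's bend network uses exactly this parameterization is an assumption, and the function names are hypothetical.

```python
def bent_path(x0, x1, t, bend_fn):
    """x_t = (1 - t) x0 + t x1 + t (1 - t) phi(x0, x1, t); the correction vanishes at t = 0, 1."""
    phi = bend_fn(x0, x1, t)
    return [(1.0 - t) * a + t * b + t * (1.0 - t) * p for a, b, p in zip(x0, x1, phi)]

def path_velocity(x0, x1, t, bend_fn, h=1e-5):
    """Central finite-difference estimate of dx_t/dt (autograd would be used in practice)."""
    ahead = bent_path(x0, x1, t + h, bend_fn)
    behind = bent_path(x0, x1, t - h, bend_fn)
    return [(a - b) / (2.0 * h) for a, b in zip(ahead, behind)]

def toy_bend(x0, x1, t):
    """Stand-in for a trained bend network: a constant upward push."""
    return [0.0, 2.0]

x0, x1 = [0.0, 0.0], [1.0, 0.0]
start = bent_path(x0, x1, 0.0, toy_bend)     # equals x0
end = bent_path(x0, x1, 1.0, toy_bend)       # equals x1
middle = bent_path(x0, x1, 0.5, toy_bend)    # straight midpoint lifted by 0.25 * phi
velocity = path_velocity(x0, x1, 0.25, toy_bend)
print(start, middle, end, velocity)
```

Because the t (1 - t) factor is zero at both endpoints, the path always starts and ends at the sampled points regardless of what the bend network outputs; the velocity regression target becomes the derivative of the bent path rather than x1 - x0.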

3. Velocity + MoE

Train MixtureVelocityNet with Gumbel routing, then evaluate.

python scripts/benchmark_data/train_benchmark_velocity.py \
  --dataset lorenz --lorenz_mode day_marginals --num_marginals 8 \
  --geo_model_path saved_models/lorenz/rbf_network_best.pth \
  --bend_model_path saved_models/lorenz/bend_network_best.pth \
  --use_gumbel_routing --num_experts 2 --save_dir saved_models/lorenz

python scripts/benchmark_data/compute_metrics.py \
  --dataset lorenz --model_dir saved_models/lorenz
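The four Lorenz commands above can also be assembled from one place so that paths and shared flags stay in sync. The driver below is a hypothetical convenience wrapper, not a script shipped with the repository; it only reuses the script names and flags shown above, and it prints the commands rather than running them.

```python
import shlex

SCRIPTS = "scripts/benchmark_data"

def build_lorenz_commands(save_dir="saved_models/lorenz", num_marginals=8, num_experts=2):
    """Return the geometry, bend, velocity, and evaluation commands as argv lists."""
    common = ["--dataset", "lorenz", "--lorenz_mode", "day_marginals",
              "--num_marginals", str(num_marginals)]
    geo = f"{save_dir}/rbf_network_best.pth"
    bend = f"{save_dir}/bend_network_best.pth"
    return [
        ["python", f"{SCRIPTS}/train_benchmark_rbf.py", *common,
         "--save_dir", save_dir],
        ["python", f"{SCRIPTS}/train_benchmark_bend.py", *common,
         "--geo_model_path", geo, "--save_dir", save_dir],
        ["python", f"{SCRIPTS}/train_benchmark_velocity.py", *common,
         "--geo_model_path", geo, "--bend_model_path", bend,
         "--use_gumbel_routing", "--num_experts", str(num_experts),
         "--save_dir", save_dir],
        ["python", f"{SCRIPTS}/compute_metrics.py",
         "--dataset", "lorenz", "--model_dir", save_dir],
    ]

commands = build_lorenz_commands()
for argv in commands:
    # To execute from the repository root instead: subprocess.run(argv, check=True)
    print(shlex.join(argv))
```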

Citation

If you use FLUX in your research, please cite:

@article{ortegacaro2026flux,
  title={FLUX: Geometry-Aware Longitudinal Flow Matching with Mixture of Experts},
  author={Ortega Caro, Josue and others},
  journal={Under review},
  year={2026}
}