BrainLM: A foundation model for brain activity recordings
Yale University · Wu Tsai Institute · Kavli Institute
We introduce the Brain Language Model (BrainLM), a foundation model for brain activity dynamics trained on 6,700 hours of fMRI recordings. Utilizing self-supervised masked-prediction training, BrainLM demonstrates proficiency in both fine-tuning and zero-shot inference tasks. Fine-tuning allows for the accurate prediction of clinical variables like age, anxiety, and PTSD as well as forecasting of future brain states. Critically, the model generalizes well to entirely new external cohorts not seen during training.
In zero-shot inference mode, BrainLM can identify intrinsic functional networks directly from raw fMRI data without any network-based supervision during training. The model also generates interpretable latent representations that reveal relationships between brain activity patterns and cognitive states. Overall, BrainLM offers a versatile and interpretable framework for elucidating the complex spatiotemporal dynamics of human brain activity.
Implementation: The public repository provides a decoder-only transformer with block-causal attention over timepoint-major tokens, Hugging Face Hub integration, and support for pretraining and finetuning on Arrow-format fMRI data.
BrainLM expects fMRI recordings in parcellated form: one time series per brain region (parcel).
Recordings (424 parcels, e.g. AAL): each sample is a matrix of shape (timepoints × num_parcels), e.g. 200 × 424 or 200 × 400 (Schaefer). Each column is one parcel's BOLD time series, stored as a single recording column per sample (default column name: Voxelwise_RobustScaler_Normalized_Recording). Optional metadata columns (age, clinical scores) are used for finetuning.
Coordinates (424 × (X, Y, Z)): a separate Arrow dataset with one row per parcel and columns X, Y, Z (e.g. in MNI space). It is loaded once and broadcast to every sample.
Train and validation data are Hugging Face Arrow datasets saved with dataset.save_to_disk(path). Each example has a recording matrix and optional metadata.
# Train/val: each example
recording_col: (num_timepoints, num_parcels) # e.g. (500, 424)
Age.At.MHQ: float # optional, for finetuning
PHQ9.Severity: float # optional
# Coords: one dataset, num_parcels rows
X, Y, Z: float # one row per parcel
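A minimal sketch of one example matching this schema, using numpy arrays as stand-ins (in the repo these would be wrapped as Hugging Face datasets, e.g. via `Dataset.from_dict(...).save_to_disk(path)`); the metadata values here are made up:

```python
import numpy as np

rng = np.random.default_rng(0)
num_timepoints, num_parcels = 200, 424

# One training example: a normalized recording matrix plus optional targets.
example = {
    "Voxelwise_RobustScaler_Normalized_Recording": rng.standard_normal(
        (num_timepoints, num_parcels)
    ).astype(np.float32),
    "Age.At.MHQ": 57.0,    # optional scalar target for finetuning (made-up value)
    "PHQ9.Severity": 4.0,  # optional (made-up value)
}

# Coordinate table: one (X, Y, Z) row per parcel, shared by all samples.
coords = rng.uniform(-90, 90, size=(num_parcels, 3)).astype(np.float32)

print(example["Voxelwise_RobustScaler_Normalized_Recording"].shape)  # (200, 424)
print(coords.shape)  # (424, 3)
```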
Timepoint-major order; patches per parcel; xyz spatial + sinusoidal temporal encodings.
Each parcel's time series (e.g. 200 timepoints) is split into patches of size P (e.g. 20) → 10 tokens per parcel. Each patch is projected to hidden size and receives xyz spatial (learned from parcel coordinates) and sinusoidal temporal encoding.
Tokens are arranged as: [CLS, all_parcels_t0, all_parcels_t1, …, all_parcels_tN]. Parcels at the same timepoint can attend to each other; attention across time is causal.
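The patching and timepoint-major ordering above can be sketched with numpy (shapes follow the running example: 200 timepoints, 424 parcels, patch size 20; the 256-dim sinusoidal encoding is illustrative, not the repo's exact implementation):

```python
import numpy as np

T, N, P = 200, 424, 20             # timepoints, parcels, patch size
n_patches = T // P                 # 10 temporal patches per parcel

recording = np.random.default_rng(1).standard_normal((T, N)).astype(np.float32)

# Split time into patches: (T, N) -> (n_patches, P, N) -> (n_patches, N, P).
patches = recording.reshape(n_patches, P, N).transpose(0, 2, 1)

# Timepoint-major token sequence: all parcels at patch 0, then patch 1, ...
# (the CLS token would be prepended after projection to hidden size).
tokens = patches.reshape(n_patches * N, P)   # (4240, 20)

# Sinusoidal temporal encoding, one vector per temporal patch index,
# shared across parcels at the same timepoint.
def sinusoidal_encoding(positions, dim):
    freqs = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    angles = positions[:, None] * freqs[None, :]
    enc = np.zeros((len(positions), dim), dtype=np.float32)
    enc[:, 0::2] = np.sin(angles)
    enc[:, 1::2] = np.cos(angles)
    return enc

temporal_enc = sinusoidal_encoding(np.arange(n_patches), dim=256)
print(tokens.shape, temporal_enc.shape)   # (4240, 20) (10, 256)
```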
Block-causal attention: within-timepoint bidirectional, across-time causal. Masked reconstruction loss; CLS token for finetuning.
Parcels at the same timepoint attend to each other (bidirectional). Parcels at time t can attend to all times s ≤ t; no look-ahead. Implemented with a block lower-triangular mask.
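A minimal numpy sketch of that block lower-triangular mask (CLS handling omitted; True means attention is allowed):

```python
import numpy as np

def block_causal_mask(n_timepoints, parcels_per_timepoint):
    # Token i belongs to timepoint t[i] in timepoint-major order.
    t = np.repeat(np.arange(n_timepoints), parcels_per_timepoint)
    # Allowed iff the query's timepoint is >= the key's timepoint:
    # bidirectional within a timepoint block, causal across blocks.
    return t[:, None] >= t[None, :]

m = block_causal_mask(n_timepoints=3, parcels_per_timepoint=2)
# Tokens at time 0 see only time 0; tokens at time 2 see times 0..2.
print(m.astype(int))
```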
Two masking strategies are supported. Random: a fraction of tokens (e.g. 75%) is replaced with a learned [MASK] embedding. Forward: only the last temporal token of each parcel is masked. The loss is MSE or MAE on the masked positions.
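Both masking strategies can be sketched in numpy; zeros stand in for the learned [MASK] embedding, and the "predictions" are dummy values, since the real reconstructions come from the transformer:

```python
import numpy as np

rng = np.random.default_rng(0)
n_patches, n_parcels, patch = 10, 424, 20
n_tokens = n_patches * n_parcels
tokens = rng.standard_normal((n_tokens, patch)).astype(np.float32)

# Random strategy: hide a fraction of tokens (75% here, as in the text).
mask_ratio = 0.75
masked = rng.random(n_tokens) < mask_ratio
corrupted = tokens.copy()
corrupted[masked] = 0.0                      # placeholder for [MASK]

# Forward strategy: mask only the last temporal patch of every parcel
# (timepoint-major order, so the final block of n_parcels tokens).
last = np.zeros(n_tokens, dtype=bool)
last[(n_patches - 1) * n_parcels:] = True

# Reconstruction loss: MSE computed only on the masked positions.
predictions = rng.standard_normal(tokens.shape).astype(np.float32)  # dummy output
mse = np.mean((predictions[masked] - tokens[masked]) ** 2)
print(round(float(mse), 3))
```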
A 3-layer MLP on the CLS token output supports scalar regression (e.g. age, PHQ9, PCL). Same input format as pretraining; optional metadata columns for targets.
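A shape-level numpy sketch of such a regression head (the layer widths, ReLU activations, and initialization are illustrative assumptions, not taken from the repo):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 256   # model hidden size; the CLS output has this dimension

# Hypothetical 3-layer MLP head mapping the CLS embedding to one scalar.
W1, b1 = rng.standard_normal((hidden, 128)) * 0.02, np.zeros(128)
W2, b2 = rng.standard_normal((128, 64)) * 0.02, np.zeros(64)
W3, b3 = rng.standard_normal((64, 1)) * 0.02, np.zeros(1)

def regression_head(cls_embedding):
    h = np.maximum(cls_embedding @ W1 + b1, 0.0)   # ReLU
    h = np.maximum(h @ W2 + b2, 0.0)
    return h @ W3 + b3                             # predicted scalar, e.g. age

cls_batch = rng.standard_normal((8, hidden))       # CLS outputs for a batch of 8
print(regression_head(cls_batch).shape)            # (8, 1)
```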
Generate a small Arrow dataset and run BrainLM end-to-end without needing large fMRI archives.
Create Arrow datasets (train, val, coords) with the expected schema.
python generate_sample_data.py --output_dir ./sample_data --num_train 100 --num_val 20
Run pretraining on the synthetic data.
python train.py \
--output_dir ./runs/demo \
--train_dataset_path ./sample_data/train \
--val_dataset_path ./sample_data/val \
--coords_dataset_path ./sample_data/coords \
--num_timepoints_per_voxel 200 \
--timepoint_patching_size 20 \
--hidden_size 256 \
--num_hidden_layers 4 \
--max_train_samples 20
Finetune on a scalar target (e.g. age) using the same data format.
python finetune.py \
--model_name_or_path ./runs/demo \
--train_dataset_path ./sample_data/train \
--val_dataset_path ./sample_data/val \
--coords_dataset_path ./sample_data/coords \
--variable_of_interest_col_name Age.At.MHQ \
--output_dir ./runs/finetune_demo
Clone the repo and install dependencies.
git clone https://github.com/josueortc/BrainLM.git
cd BrainLM
pip install -r requirements.txt
# or: pip install -e .
If a pretrained model is on the Hub (e.g. josueortc/brainlm):
from transformers import AutoConfig, AutoModelForPreTraining
config = AutoConfig.from_pretrained("josueortc/brainlm", trust_remote_code=True)
model = AutoModelForPreTraining.from_pretrained("josueortc/brainlm", trust_remote_code=True)
If you use BrainLM in your research, please consider citing:
@inproceedings{ortegacaro2024brainlm,
  title={BrainLM: A foundation model for brain activity recordings},
  author={Ortega Caro, Josue and Oliveira Fonseca, Antonio Henrique and Rizvi, Syed A and Rosati, Matteo and Averill, Christopher and Cross, James L and Mittal, Prateek and Zappala, Emanuele and Dhodapkar, Rahul Madhav and Abdallah, Chadi and van Dijk, David},
  booktitle={ICLR},
  year={2024},
  url={https://openreview.net/forum?id=RwI7ZEfR27}
}