BrainLM

A foundation model for brain activity recordings


Josue Ortega Caro · Antonio Henrique de Oliveira Fonseca · Syed A. Rizvi · Matteo Rosati · Christopher Averill · James L. Cross · Prateek Mittal · Emanuele Zappala · Rahul Madhav Dhodapkar · Chadi Abdallah · David van Dijk

Yale University · Wu Tsai Institute · Kavli Institute


Foundation model for brain dynamics

We introduce the Brain Language Model (BrainLM), a foundation model for brain activity dynamics trained on 6,700 hours of fMRI recordings. Utilizing self-supervised masked-prediction training, BrainLM demonstrates proficiency in both fine-tuning and zero-shot inference tasks. Fine-tuning allows for the accurate prediction of clinical variables like age, anxiety, and PTSD as well as forecasting of future brain states. Critically, the model generalizes well to entirely new external cohorts not seen during training.

In zero-shot inference mode, BrainLM can identify intrinsic functional networks directly from raw fMRI data without any network-based supervision during training. The model also generates interpretable latent representations that reveal relationships between brain activity patterns and cognitive states. Overall, BrainLM offers a versatile and interpretable framework for elucidating the complex spatiotemporal dynamics of human brain activity.

Implementation: The public repository provides a decoder-only transformer with block-causal attention over timepoint-major tokens, Hugging Face Hub integration, and support for pretraining and finetuning on Arrow-format fMRI data.

Parcellated fMRI (BOLD)

BrainLM expects fMRI recordings in parcellated form: one time series per brain region (parcel).

Parcellated BOLD

424 parcels (e.g. AAL)

Each sample is a matrix of shape (timepoints × num_parcels), e.g. 200 × 424 or 200 × 400 (Schaefer). Each column is one parcel's BOLD time series.

Recording column

One column per sample

Default column name: Voxelwise_RobustScaler_Normalized_Recording. Optional metadata columns (age, clinical scores) for finetuning.

Coordinates

424 × (X, Y, Z)

Separate Arrow dataset: one row per parcel, columns X, Y, Z (e.g. MNI space). Loaded once and broadcast to every sample.

Arrow dataset format

Train and validation data are Hugging Face Arrow datasets saved with dataset.save_to_disk(path). Each example has a recording matrix and optional metadata.

# Train/val: each example
  recording_col:  (num_timepoints, num_parcels)  # e.g. (500, 424)
  Age.At.MHQ:     float  # optional, for finetuning
  PHQ9.Severity:  float  # optional

# Coords: one dataset, num_parcels rows
  X, Y, Z: float  # one row per parcel
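As a concrete sanity check of this schema, the sketch below builds one training example and a coordinates table in plain NumPy; the values are random and illustrative. In practice you would wrap such dicts in a Hugging Face `datasets.Dataset.from_dict(...)` and call `save_to_disk(path)`.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_PARCELS, NUM_TIMEPOINTS = 424, 200

# One training example: a parcellated BOLD matrix plus optional scalar targets.
example = {
    "Voxelwise_RobustScaler_Normalized_Recording": rng.standard_normal(
        (NUM_TIMEPOINTS, NUM_PARCELS)
    ).astype(np.float32),
    "Age.At.MHQ": 54.0,    # optional metadata, only needed for finetuning
    "PHQ9.Severity": 3.0,  # optional
}

# Coordinates table: one row per parcel, columns X, Y, Z (e.g. MNI space).
coords = {
    axis: rng.uniform(-90, 90, NUM_PARCELS).astype(np.float32)
    for axis in ("X", "Y", "Z")
}

print(example["Voxelwise_RobustScaler_Normalized_Recording"].shape)  # (200, 424)
```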

From fMRI to tokens

Timepoint-major order; patches per parcel; xyz spatial + sinusoidal temporal encodings.


Patch-based, timepoint-major

Each parcel's time series (e.g. 200 timepoints) is split into patches of size P (e.g. 20) → 10 tokens per parcel. Each patch is projected to hidden size and receives xyz spatial (learned from parcel coordinates) and sinusoidal temporal encoding.

Token count: (T/P) × num_parcels + 1 (CLS)
Order: all parcels at t=0, then t=1, …
Position: xyz_emb + temporal_emb

Token order (timepoint-major)

Tokens are arranged as: [CLS, all_parcels_t0, all_parcels_t1, …, all_parcels_tN]. Parcels at the same timepoint can attend to each other; attention across time is causal.
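The patching and timepoint-major flattening can be sketched in NumPy as below (shapes follow the example numbers above; the projection to hidden size and the xyz/temporal encodings are omitted):

```python
import numpy as np

T, NUM_PARCELS, P = 200, 424, 20  # timepoints, parcels, patch size
rng = np.random.default_rng(0)
recording = rng.standard_normal((T, NUM_PARCELS)).astype(np.float32)

# Split each parcel's time series into T/P patches of length P:
# (T, parcels) -> (T/P temporal blocks, parcels, P).
patches = recording.reshape(T // P, P, NUM_PARCELS).transpose(0, 2, 1)

# Timepoint-major flattening: all parcels' patch 0, then patch 1, ...
tokens = patches.reshape(-1, P)   # ((T/P) * parcels, P)

num_tokens = tokens.shape[0] + 1  # +1 for the CLS token
print(num_tokens)                 # (200/20) * 424 + 1 = 4241
```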

Decoder-only transformer

Block-causal attention: within-timepoint bidirectional, across-time causal. Masked reconstruction loss; CLS token for finetuning.

Pipeline: fMRI (T, 424) → Mask (random % or forward, with learned [MASK]) → Tokenize (patch + xyz + temporal + CLS) → Transformer (block-causal attention; PreNorm, FFN with GELU; × L layers; t attends to s ≤ t) → Output (reconstruction of masked tokens; CLS → MLP).

Block-causal attention

Parcels at the same timepoint attend to each other (bidirectional). Parcels at time t can attend to all times s ≤ t; no look-ahead. Implemented with a block lower-triangular mask.
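A minimal NumPy sketch of such a block lower-triangular mask (CLS token omitted, and sizes kept tiny for readability):

```python
import numpy as np

def block_causal_mask(num_timepoints: int, num_parcels: int) -> np.ndarray:
    """True where attention is allowed: a query at time t may attend to keys at s <= t."""
    # Timepoint index of every token, in timepoint-major order.
    t = np.repeat(np.arange(num_timepoints), num_parcels)
    # Same timepoint -> bidirectional; earlier timepoint -> allowed; later -> blocked.
    return t[:, None] >= t[None, :]

mask = block_causal_mask(num_timepoints=3, num_parcels=2)
print(mask.astype(int))
```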

Masking

Random: a fraction of tokens (e.g. 75%) replaced with learned [MASK]. Forward: only the last temporal token per parcel masked. Loss is MSE or MAE on masked positions.
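The random-masking scheme and masked-only loss can be sketched as follows; here a zero vector stands in for the learned [MASK] embedding and the "model output" is faked with noise, purely to show which positions enter the loss:

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, hidden = 4240, 256
tokens = rng.standard_normal((num_tokens, hidden)).astype(np.float32)
mask_embedding = np.zeros(hidden, dtype=np.float32)  # stand-in for learned [MASK]

# Random masking: replace a fraction of tokens with the [MASK] embedding.
mask_ratio = 0.75
is_masked = rng.random(num_tokens) < mask_ratio
inputs = np.where(is_masked[:, None], mask_embedding, tokens)

# Loss (MSE here) is computed only on the masked positions.
recon = tokens + 0.1 * rng.standard_normal(tokens.shape).astype(np.float32)
loss = np.mean((recon[is_masked] - tokens[is_masked]) ** 2)
print(float(loss))
```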

Finetuning

A 3-layer MLP on the CLS token output supports scalar regression (e.g. age, PHQ9, PCL). Same input format as pretraining; optional metadata columns for targets.
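As an illustration only (the layer widths and GELU nonlinearity here are assumptions, not the repo's exact head), a 3-layer MLP mapping the CLS embedding to a scalar could look like:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = 256

def gelu(x: np.ndarray) -> np.ndarray:
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

# Hypothetical 3-layer regression head: hidden -> 128 -> 32 -> 1.
W1, b1 = rng.standard_normal((hidden, 128)) * 0.02, np.zeros(128)
W2, b2 = rng.standard_normal((128, 32)) * 0.02, np.zeros(32)
W3, b3 = rng.standard_normal((32, 1)) * 0.02, np.zeros(1)

def regression_head(cls_embedding: np.ndarray) -> float:
    h = gelu(cls_embedding @ W1 + b1)
    h = gelu(h @ W2 + b2)
    return float((h @ W3 + b3)[0])  # scalar target, e.g. age or PHQ9

cls = rng.standard_normal(hidden)
print(regression_head(cls))
```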

Run BrainLM on toy data

Generate a small Arrow dataset and run BrainLM end-to-end without needing large fMRI archives.

Synthetic quickstart (zero dependencies)

1. Generate synthetic data

Create Arrow datasets (train, val, coords) with the expected schema.

python generate_sample_data.py --output_dir ./sample_data --num_train 100 --num_val 20
2. Pretrain

Run pretraining on the synthetic data.

python train.py \
  --output_dir ./runs/demo \
  --train_dataset_path ./sample_data/train \
  --val_dataset_path ./sample_data/val \
  --coords_dataset_path ./sample_data/coords \
  --num_timepoints_per_voxel 200 \
  --timepoint_patching_size 20 \
  --hidden_size 256 \
  --num_hidden_layers 4 \
  --max_train_samples 20
3. Finetune

Finetune on a scalar target (e.g. age) using the same data format.

python finetune.py \
  --model_name_or_path ./runs/demo \
  --train_dataset_path ./sample_data/train \
  --val_dataset_path ./sample_data/val \
  --coords_dataset_path ./sample_data/coords \
  --variable_of_interest_col_name Age.At.MHQ \
  --output_dir ./runs/finetune_demo

How to use BrainLM

1. Install

Clone the repo and install dependencies.

git clone https://github.com/josueortc/BrainLM.git
cd BrainLM
pip install -r requirements.txt
# or: pip install -e .
2. Load from Hugging Face Hub

If a pretrained model is on the Hub (e.g. josueortc/brainlm):

from transformers import AutoConfig, AutoModelForPreTraining

config = AutoConfig.from_pretrained("josueortc/brainlm", trust_remote_code=True)
model = AutoModelForPreTraining.from_pretrained("josueortc/brainlm", trust_remote_code=True)

Cite BrainLM

If you use BrainLM in your research, please consider citing:

@inproceedings{ortegacaro2024brainlm,
  title={BrainLM: A foundation model for brain activity recordings},
  author={Ortega Caro, Josue and Oliveira Fonseca, Antonio Henrique and Rizvi, Syed A and Rosati, Matteo and Averill, Christopher and Cross, James L and Mittal, Prateek and Zappala, Emanuele and Dhodapkar, Rahul Madhav and Abdallah, Chadi and van Dijk, David},
  booktitle={ICLR},
  year={2024},
  url={https://openreview.net/forum?id=RwI7ZEfR27}
}