Pattern Reconstruction and Interpretation with a Structured Multimodal Transformer
An interpretable transformer framework for spatiotemporal neural dynamics. PRISMT combines masked autoencoding with causal temporal attention to learn structured representations from multi-region brain recordings and reveal which regions, timepoints, and cross-modal interactions drive neural computation.
Yale University · Department of Neuroscience · Kavli Institute · Wu Tsai Institute
Understanding how distributed neural circuits coordinate across brain regions and time is a central challenge in neuroscience. Standard analysis methods often treat brain areas independently or rely on hand-crafted features, missing the structured spatiotemporal dependencies that underlie perception, learning, and decision-making. PRISMT addresses this gap with a general-purpose transformer architecture designed for interpretable analysis of multi-region neural recordings.
PRISMT tokenizes spatiotemporal brain activity — where each scalar measurement (one brain area at one timepoint) becomes a token — and applies masked autoencoding with causal temporal attention to learn structured representations. The causal mask ensures that predictions at time t depend only on past and present activity, respecting the temporal structure of neural computation. Through attention rollout, the model produces interpretable maps that reveal which brain regions, timepoints, and cross-modal interactions are most informative for a given task.
PRISMT is designed to work with any trial-based spatiotemporal neural data — widefield calcium imaging, multi-region electrophysiology, or multimodal recordings. It supports both classification (e.g., decoding behavioral states, learning phases, or genotypes) and reconstruction tasks, making it a flexible tool for discovery-driven neuroscience.
In the motivating application, we applied PRISMT to simultaneous dual-color mesoscopic imaging of acetylcholine (ACh) and calcium signaling in the neocortex of awake mice across multiple stages of visual task learning. PRISMT identified key cortical regions that reorganize with learning and revealed that cholinergic signaling exhibits spatiotemporally selective plasticity in frontal cortical subnetworks during visual perceptual learning.
Dual-color mesoscopic imaging captures both acetylcholine and calcium dynamics across the cortical surface.
~25 Allen Atlas Regions
Bilateral cortical areas parcellated using the Allen Brain Atlas, spanning visual, somatosensory, motor, and frontal cortex.
ACh + Calcium
Simultaneous imaging of cholinergic (ACh) and neural (calcium) activity, creating a multimodal view of cortical dynamics.
Timepoints × Brain Areas
Each trial is a matrix of shape (timepoints, brain_areas) containing dF/F or z-scored fluorescence signals.
All data is stored in a standardized .mat structure with the following organization:
standardized_data
├── n_datasets: scalar
└── dataset_001: struct
├── dff: (trials, timepoints, brain_areas) % raw fluorescence
├── zscore: (trials, timepoints, brain_areas) % z-scored signal
├── stim: (trials, 1) % stimulus condition
├── response: (trials, 1) % behavioral response
├── phase: (trials, 1) % learning phase
└── mouse: (trials, 1) % subject identity
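The layout above can be round-tripped from Python with SciPy. The sketch below writes a tiny synthetic dataset in the standardized structure and reads it back; field names match the tree above, but the shapes and the `standardized_demo.mat` filename are illustrative, and it assumes MATLAB v7-format files (v7.3 files are HDF5 and would need `h5py` instead).

```python
# Sketch: writing and reading the standardized .mat layout with SciPy.
# Assumes MATLAB v7-format files; v7.3 files are HDF5 and need h5py.
import numpy as np
from scipy.io import loadmat, savemat

trials, timepoints, areas = 8, 50, 25  # illustrative sizes
dataset = {
    "dff":      np.random.randn(trials, timepoints, areas),
    "zscore":   np.random.randn(trials, timepoints, areas),
    "stim":     np.random.randint(0, 2, (trials, 1)),
    "response": np.random.randint(0, 2, (trials, 1)),
    "phase":    np.random.randint(0, 3, (trials, 1)),
    "mouse":    np.ones((trials, 1), dtype=int),
}
savemat("standardized_demo.mat",
        {"standardized_data": {"n_datasets": 1, "dataset_001": dataset}})

# struct_as_record=False gives attribute access matching the tree above
mat = loadmat("standardized_demo.mat", squeeze_me=True, struct_as_record=False)
ds = mat["standardized_data"].dataset_001
assert ds.dff.shape == (trials, timepoints, areas)
```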
PRISMTransformer supports configurable tokenization strategies. The choice of tokenization determines the granularity at which the model operates.
Each brain area becomes a single token. The full time series for that area is projected to the hidden dimension via a linear layer.
Each scalar value (one timepoint, one brain area) becomes a token. This creates a fine-grained sequence that enables masked reconstruction.
Tokens are arranged in row-major order with time as the outer axis and brain area as the inner axis, so all areas at a given timepoint appear consecutively before the next timepoint.
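The ordering can be made concrete with a small sketch. The variable names and the `token_index` helper are illustrative, not part of the PRISMT API:

```python
# Sketch: scalar tokenization in time-major (row-major) order.
import numpy as np

timepoints, areas = 4, 3
trial = np.arange(timepoints * areas).reshape(timepoints, areas)  # (T, A)

# Flatten time-major: all areas at t=0, then all areas at t=1, ...
tokens = trial.reshape(-1)                  # sequence of length T * A

def token_index(t, a):
    """Sequence position of (timepoint t, brain area a)."""
    return t * areas + a

assert tokens[token_index(2, 1)] == trial[2, 1]
```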
PRISMTransformer uses causal temporal attention for both classification and masked reconstruction, with configurable tokenization.
Uses the same CausalTemporalAttention mechanism as reconstruction. With area-level tokenization, all tokens share a single effective timepoint, so the causal mask reduces to full bidirectional attention: the same mechanism, with the behavior appropriate for classification.
After training, attention weights are composed across layers (with residual mixing) to produce a single CLS-to-areas attention map, revealing which brain regions most influence the classification decision.
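This layer-wise composition follows the standard attention-rollout recipe (Abnar & Zuidema, 2020): average over heads, mix in the identity to account for residual connections, renormalize, and multiply across layers. The sketch below is a minimal version; shapes and names are illustrative, not PRISMT's actual analysis code.

```python
# Sketch of attention rollout with residual mixing.
import numpy as np

def attention_rollout(attn_per_layer):
    """attn_per_layer: list of (heads, seq, seq) attention matrices."""
    seq = attn_per_layer[0].shape[-1]
    rollout = np.eye(seq)
    for attn in attn_per_layer:
        a = attn.mean(axis=0)                  # average over heads
        a = 0.5 * a + 0.5 * np.eye(seq)        # residual connection
        a = a / a.sum(axis=-1, keepdims=True)  # renormalize rows
        rollout = a @ rollout                  # compose across layers
    return rollout

# Toy example: 3 layers, 4 heads, CLS token + 5 area tokens.
layers = [np.random.dirichlet(np.ones(6), size=(4, 6)) for _ in range(3)]
cls_map = attention_rollout(layers)[0, 1:]     # CLS-to-areas attention map
```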
Supports classification (learning phase, genotype) and regression. The CLS token's final representation is projected through a linear head.
Random brain areas are masked at the input, replaced with a shared learnable [MASK] token. The model learns to reconstruct the masked values from the remaining context.
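The masking step can be sketched in PyTorch as below. The masking ratio, shapes, and variable names are assumptions for illustration; only the mechanism (a single learnable [MASK] embedding substituted at masked positions) comes from the description above.

```python
# Sketch: replacing masked tokens with a shared learnable [MASK] embedding.
import torch
import torch.nn as nn

batch, seq_len, hidden_dim = 2, 60, 128        # illustrative sizes
mask_token = nn.Parameter(torch.zeros(hidden_dim))  # shared, learnable

tokens = torch.randn(batch, seq_len, hidden_dim)    # embedded inputs
mask = torch.rand(batch, seq_len) < 0.25            # hypothetical 25% ratio

# Broadcast the single mask embedding over every masked position.
masked_tokens = torch.where(mask.unsqueeze(-1), mask_token, tokens)
```

The model is then trained to reconstruct the original values at the masked positions from the surrounding unmasked context.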
A structured attention mask enforces temporal causality: tokens at time t can attend to all tokens at time s ≤ t. Within each timepoint, all brain areas attend freely to each other.
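With time-major scalar tokens, this block-causal structure reduces to a comparison of token timepoints. A minimal sketch (shapes illustrative):

```python
# Sketch: block-causal attention mask for time-major scalar tokens.
# True = token i may attend to token j.
import numpy as np

timepoints, areas = 3, 2
t_of_token = np.repeat(np.arange(timepoints), areas)  # timepoint of each token

# Token i attends to token j iff t(j) <= t(i); all areas within the
# same timepoint therefore attend to each other freely.
mask = t_of_token[None, :] <= t_of_token[:, None]     # (seq, seq) boolean
```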
Both ACh and calcium modalities share the same token space. The model learns cross-modal dependencies through the attention mechanism, discovering how neuromodulatory patterns relate to neural activity.
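One simple way to realize a shared token space is to stack the two modalities along the brain-area axis before tokenization, so cross-modal dependencies are handled by ordinary attention; whether PRISMT concatenates exactly this way is an assumption of this sketch.

```python
# Sketch: one token space for two modalities by stacking area channels.
# The concatenation scheme is an assumption, not PRISMT's documented layout.
import numpy as np

timepoints, areas = 50, 25
ach = np.random.randn(timepoints, areas)   # ACh signal per area
ca = np.random.randn(timepoints, areas)    # calcium signal per area

joint = np.concatenate([ach, ca], axis=1)  # (timepoints, 2 * areas)
tokens = joint.reshape(-1)                 # time-major scalar tokens
```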
Interact with the visualizations below to understand how PRISMT processes cortical data.
Click on individual cells to mask tokens (each cell is one timepoint × brain area scalar). Then watch how the PRISMTransformer reconstructs the masked values from surrounding context using causal temporal attention.
Hover over cells to see which tokens can attend to which. The causal mask ensures that information flows only forward in time, while brain areas within the same timepoint can freely communicate.
Standardize your widefield imaging data into the expected .mat format using the provided MATLAB scripts or the PRISMT GUI.
% In MATLAB: standardize your data
run_prismt_gui % Launch the PRISMT GUI
% Or use the command-line standardization script:
standardize_data('your_data.mat', 'output_standardized.mat')
Set model hyperparameters and training configuration. PRISMT supports both command-line arguments and the MATLAB GUI for configuration.
# Train a classification model (e.g., learning phase)
python train.py \
--data_path data/standardized.mat \
--task_name phase_classification \
--hidden_dim 128 \
--num_heads 4 \
--num_layers 3 \
--batch_size 16 \
--epochs 100 \
--lr 5e-5 \
--scheduler_type cosine_warmup
Extract attention rollout maps and generate brain visualizations to identify which cortical regions drive classification.
# Run analysis and generate brain maps
python analyze_results.py \
--checkpoint results/best_model.pt \
--data_path data/standardized.mat \
--output_dir results/analysis/
Use Optuna-based HPO to find the best model configuration for your specific dataset.
# Run HPO with Optuna
python hpo_optuna.py \
--data_path data/standardized.mat \
--n_trials 50 \
--task_name phase_classification
If you find PRISMT useful for your research, please consider citing:
@article{ortegacaro2025cholinergic,
title={Selective changes in cortical cholinergic signaling during learning},
author={Ortega Caro, Josue and Batchelor, Hannah M. and Lohani, Sweyta
and van Dijk, David and Cardin, Jessica A.},
journal={bioRxiv},
year={2025},
doi={10.1101/2025.08.29.673096},
publisher={Cold Spring Harbor Laboratory}
}