PRISMT

Pattern Reconstruction and Interpretation with a Structured Multimodal Transformer

An interpretable transformer framework for spatiotemporal neural dynamics. PRISMT combines masked autoencoding with causal temporal attention to learn structured representations from multi-region brain recordings and reveal which regions, timepoints, and cross-modal interactions drive neural computation.

Josue Ortega Caro Hannah M. Batchelor Sweyta Lohani David van Dijk Jessica A. Cardin

Yale University · Department of Neuroscience · Kavli Institute · Wu Tsai Institute


Interpretable Models for Neural Dynamics

Understanding how distributed neural circuits coordinate across brain regions and time is a central challenge in neuroscience. Standard analysis methods often treat brain areas independently or rely on hand-crafted features, missing the structured spatiotemporal dependencies that underlie perception, learning, and decision-making. PRISMT addresses this gap with a general-purpose transformer architecture designed for interpretable analysis of multi-region neural recordings.

PRISMT tokenizes spatiotemporal brain activity — where each scalar measurement (one brain area at one timepoint) becomes a token — and applies masked autoencoding with causal temporal attention to learn structured representations. The causal mask ensures that predictions at time t depend only on past and present activity, respecting the temporal structure of neural computation. Through attention rollout, the model produces interpretable maps that reveal which brain regions, timepoints, and cross-modal interactions are most informative for a given task.

PRISMT is designed to work with any trial-based spatiotemporal neural data — widefield calcium imaging, multi-region electrophysiology, or multimodal recordings. It supports both classification (e.g., decoding behavioral states, learning phases, or genotypes) and reconstruction tasks, making it a flexible tool for discovery-driven neuroscience.

Application: Cortical Cholinergic Signaling During Learning

In the motivating application, we applied PRISMT to simultaneous dual-color mesoscopic imaging of acetylcholine (ACh) and calcium signaling in the neocortex of awake mice across multiple stages of visual task learning. PRISMT identified key cortical regions that reorganize with learning and revealed that cholinergic signaling exhibits spatiotemporally selective plasticity in frontal cortical subnetworks during visual perceptual learning.

Widefield Calcium Imaging

Dual-color mesoscopic imaging captures both acetylcholine and calcium dynamics across the cortical surface.

Cortical Coverage

~25 Allen Atlas Regions

Bilateral cortical areas parcellated using the Allen Brain Atlas, spanning visual, somatosensory, motor, and frontal cortex.

Dual-Color Signals

ACh + Calcium

Simultaneous imaging of cholinergic (ACh) and neural (calcium) activity, creating a multimodal view of cortical dynamics.

Trial Structure

Timepoints × Brain Areas

Each trial is a matrix of shape (timepoints, brain_areas) containing dF/F or z-scored fluorescence signals.

Standardized Data Format

All data is stored in a standardized .mat structure with the following organization:

standardized_data
  ├── n_datasets: scalar
  └── dataset_001: struct
      ├── dff:      (trials, timepoints, brain_areas)  % raw fluorescence
      ├── zscore:   (trials, timepoints, brain_areas)  % z-scored signal
      ├── stim:     (trials, 1)   % stimulus condition
      ├── response: (trials, 1)   % behavioral response
      ├── phase:    (trials, 1)   % learning phase
      └── mouse:    (trials, 1)   % subject identity
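The layout above can be checked programmatically. A minimal sketch, using a synthetic stand-in for one `dataset_NNN` struct (field names taken from the tree above; the shapes and values here are illustrative only):

```python
import numpy as np

# Synthetic stand-in for one dataset struct, matching the layout above
trials, timepoints, brain_areas = 8, 120, 25
dataset = {
    "dff":      np.random.randn(trials, timepoints, brain_areas),
    "zscore":   np.random.randn(trials, timepoints, brain_areas),
    "stim":     np.zeros((trials, 1), dtype=int),
    "response": np.zeros((trials, 1), dtype=int),
    "phase":    np.zeros((trials, 1), dtype=int),
    "mouse":    np.zeros((trials, 1), dtype=int),
}

def check_dataset(ds):
    """Verify that all fields agree on the number of trials."""
    n = ds["dff"].shape[0]
    assert ds["zscore"].shape == ds["dff"].shape
    for key in ("stim", "response", "phase", "mouse"):
        assert ds[key].shape == (n, 1)
    return n

n_trials = check_dataset(dataset)  # → 8
```

In practice one would first load a real file, e.g. with `scipy.io.loadmat('output_standardized.mat', squeeze_me=True, struct_as_record=False)`, and apply the same shape checks to each `dataset_NNN` entry.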

From Brain Signals to Tokens

PRISMTransformer supports configurable tokenization strategies. The choice of tokenization determines the granularity at which the model operates.

Classification

Area-level Tokenization

Each brain area becomes a single token. The full time series for that area is projected to the hidden dimension via a linear layer.

Token count: N_areas + 1 (CLS)
Embedding: Linear(time_points → hidden_dim)
Output: CLS token → classifier head
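A minimal NumPy sketch of area-level tokenization (in the real model the projection weights and the CLS embedding are learned; random and zero values stand in for them here):

```python
import numpy as np

rng = np.random.default_rng(0)
T, N_areas, D = 100, 25, 128                   # timepoints, brain areas, hidden dim

trial = rng.standard_normal((T, N_areas))      # one trial: (timepoints, areas)
W = rng.standard_normal((T, D)) / np.sqrt(T)   # stand-in for Linear(time_points → hidden_dim)

area_tokens = trial.T @ W                      # (N_areas, D): one token per brain area
cls = np.zeros((1, D))                         # learnable [CLS] embedding in the model
tokens = np.vstack([cls, area_tokens])         # (N_areas + 1, D)
```
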
Reconstruction

Scalar Tokenization

Each scalar value (one timepoint, one brain area) becomes a token. This creates a fine-grained sequence that enables masked reconstruction.

Token count: N_time × N_areas
Embedding: Linear(1 → hidden_dim)
Position: time_emb + area_emb

Token Ordering (Scalar Tokenization)

Tokens are arranged in row-major order: all brain areas at the first timepoint, then all areas at the second, and so on. Token index k therefore corresponds to timepoint k // N_areas and brain area k % N_areas.
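This ordering can be sketched directly with a row-major reshape (values here are illustrative):

```python
import numpy as np

T, N_areas = 4, 3
trial = np.arange(T * N_areas).reshape(T, N_areas)  # (timepoints, brain_areas)

tokens = trial.reshape(-1, 1)   # (T * N_areas, 1) scalar tokens, row-major

# token k corresponds to (time k // N_areas, area k % N_areas)
for k in range(T * N_areas):
    assert tokens[k, 0] == trial[k // N_areas, k % N_areas]
```
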

PRISMTransformer Architecture

PRISMTransformer uses causal temporal attention for both classification and masked reconstruction, with configurable tokenization.

Classification pipeline: input (T, N_areas) → area-level tokenization (transpose, Linear(T → D)) with a prepended CLS token → learnable positional embeddings + dropout → L transformer layers (PreNorm, causal temporal attention, GELU FFN) → CLS head (extract CLS, Linear → classes). Attention rollout from CLS to brain areas yields interpretable maps of which areas drive classification.

Causal Temporal Attention

Uses the same CausalTemporalAttention mechanism as reconstruction. With area-level tokenization, all tokens share a single effective timepoint, so the causal mask permits full bidirectional attention: the mechanism stays consistent across tasks while behaving appropriately for classification.

Attention Rollout

After training, attention weights are composed across layers (with residual mixing) to produce a single CLS-to-areas attention map, revealing which brain regions most influence the classification decision.
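A sketch of this composition (standard attention rollout with residual mixing; the per-layer attention maps here are random stand-ins for trained weights):

```python
import numpy as np

def attention_rollout(layer_attns):
    """Compose per-layer attention maps, accounting for residual connections."""
    n = layer_attns[0].shape[0]
    rollout = np.eye(n)
    for A in layer_attns:
        A_hat = 0.5 * (A + np.eye(n))             # mix in the residual path
        A_hat /= A_hat.sum(axis=-1, keepdims=True)  # renormalize rows
        rollout = A_hat @ rollout
    return rollout

rng = np.random.default_rng(0)
n_tokens = 1 + 25                                 # CLS + brain-area tokens
attns = []
for _ in range(3):                                # one map per layer
    A = rng.random((n_tokens, n_tokens))
    attns.append(A / A.sum(axis=-1, keepdims=True))

R = attention_rollout(attns)
cls_to_areas = R[0, 1:]                           # CLS attention over brain areas
```

Because each mixed matrix is row-stochastic, the rollout stays a proper attention distribution, and the CLS row can be read directly as a per-area importance map.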

Output

Supports classification (learning phase, genotype) and regression. The CLS token's final representation is projected through a linear head.

Reconstruction pipeline: input (T, N_areas) → mask random areas (zeroed and replaced with a learnable [MASK] token) → scalar tokenization (Linear(1 → D), yielding T × N tokens) → time_emb + area_emb positional embeddings → L transformer layers (PreNorm, causal temporal attention, GELU FFN) → reconstruction head (Linear D → 1). Causal temporal masking: a token at time t attends to times s ≤ t, while areas at the same timepoint attend freely to each other.

Masked Autoencoding

Random brain areas are masked at the input, replaced with a shared learnable [MASK] token. The model learns to reconstruct the masked values from the remaining context.
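The masking step can be sketched as follows (the [MASK] embedding is learned in the real model; a zero vector stands in for it here):

```python
import numpy as np

rng = np.random.default_rng(0)
T, N_areas, D = 10, 5, 8
tokens = rng.standard_normal((T, N_areas, D))   # embedded scalar tokens

mask_token = np.zeros(D)                        # learnable [MASK] in the model
masked = rng.choice(N_areas, size=2, replace=False)  # areas to hide

tokens[:, masked] = mask_token   # all timepoints of the masked areas
                                 # must be reconstructed from context
```
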

Causal Temporal Attention

A structured attention mask enforces temporal causality: tokens at time t can attend to all tokens at time s ≤ t. Within each timepoint, all brain areas attend freely to each other.
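Under scalar tokenization this mask has a simple closed form: with tokens ordered time-major, token k belongs to timepoint k // N_areas, and attention is allowed whenever the query's timepoint is ≥ the key's. A sketch:

```python
import numpy as np

def causal_temporal_mask(T, N_areas):
    """Boolean mask, True where attention is allowed (query time >= key time)."""
    times = np.arange(T * N_areas) // N_areas   # timepoint of each token
    return times[:, None] >= times[None, :]

M = causal_temporal_mask(T=3, N_areas=2)
# same timepoint: areas attend freely in both directions
assert M[0, 1] and M[1, 0]
# future timepoints are blocked
assert not M[0, 2]
```
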

Multimodal Integration

Both ACh and calcium modalities share the same token space. The model learns cross-modal dependencies through the attention mechanism, discovering how neuromodulatory patterns relate to neural activity.

Explore the Model

Interact with the visualizations below to understand how PRISMT processes cortical data.

Masking & Reconstruction

Click on individual cells to mask tokens (each cell is one timepoint × brain area scalar). Then watch how the PRISMTransformer reconstructs the masked values from surrounding context using causal temporal attention.


Causal Temporal Attention Mask

Hover over cells to see which tokens can attend to which. The causal mask ensures that information flows only forward in time, while brain areas within the same timepoint can freely communicate.

How to Use PRISMT

1. Prepare Your Data

Standardize your widefield imaging data into the expected .mat format using the provided MATLAB scripts or the PRISMT GUI.

% In MATLAB: standardize your data
run_prismt_gui   % Launch the PRISMT GUI
% Or use the command-line standardization script:
standardize_data('your_data.mat', 'output_standardized.mat')

2. Configure Training

Set model hyperparameters and training configuration. PRISMT supports both command-line arguments and the MATLAB GUI for configuration.

# Train a classification model (e.g., learning phase)
python train.py \
  --data_path data/standardized.mat \
  --task_name phase_classification \
  --hidden_dim 128 \
  --num_heads 4 \
  --num_layers 3 \
  --batch_size 16 \
  --epochs 100 \
  --lr 5e-5 \
  --scheduler_type cosine_warmup

3. Analyze Results

Extract attention rollout maps and generate brain visualizations to identify which cortical regions drive classification.

# Run analysis and generate brain maps
python analyze_results.py \
  --checkpoint results/best_model.pt \
  --data_path data/standardized.mat \
  --output_dir results/analysis/

4. Hyperparameter Optimization

Use Optuna-based HPO to find the best model configuration for your specific dataset.

# Run HPO with Optuna
python hpo_optuna.py \
  --data_path data/standardized.mat \
  --n_trials 50 \
  --task_name phase_classification

Default Configuration

hidden_dim 128
num_heads 4
num_layers 3
ff_dim 256
dropout 0.3
batch_size 16
learning_rate 5e-5
optimizer AdamW
scheduler cosine_warmup
early_stopping patience=15

Cite Our Work

If you find PRISMT useful for your research, please consider citing:

@article{ortegacaro2025cholinergic,
  title={Selective changes in cortical cholinergic signaling during learning},
  author={Ortega Caro, Josue and Batchelor, Hannah M. and Lohani, Sweyta
          and van Dijk, David and Cardin, Jessica A.},
  journal={bioRxiv},
  year={2025},
  doi={10.1101/2025.08.29.673096},
  publisher={Cold Spring Harbor Laboratory}
}