This project is a continuation of SwiFT and the official code repository for 'Predicting task-related brain activity from resting-state brain dynamics with fMRI Transformer.' Feel free to ask the authors any questions about this project.
Contact
- First author
- Junbeom Kwon: kjb961013@snu.ac.kr
- Corresponding author
- Professor Jiook Cha: connectome@snu.ac.kr
Effective use of this repository requires familiarity with a couple of technologies: PyTorch and PyTorch Lightning. Knowledge of an experiment logging framework such as Weights & Biases or Neptune is also recommended.
This repository implements SwiFUN.
- Our code offers the following:
- Trainer based on PyTorch Lightning for running SwiFT and SwiFUN (same as Swin UNETR).
- Data preprocessing/loading pipelines for 4D fMRI datasets.
We highly recommend using our conda environment.
# clone project
git clone https://github.com/Transconnectome/SwiFUN.git
# install project
cd SwiFUN
conda env create -f envs/py39.yaml
conda activate py39
Our directory structure looks like this:
├── notebooks <- Useful Jupyter notebook examples are given (TBU)
├── output <- Experiment log and checkpoints will be saved here
├── project
│ ├── module <- Every module is given in this directory
│ │ ├── models <- Models (Swin fMRI Transformer)
│ │ ├── utils
│ │ │ ├── data_module.py <- Dataloader & codes for matching fMRI scans and target variables
│ │ │ └── data_preprocessing_and_load
│ │ │ ├── datasets.py <- Dataset Class for each dataset
│ │ │ └── preprocessing.py <- Preprocessing codes for step 6
│ │ └── pl_classifier.py <- LightningModule
│ └── main.py <- Main code that trains and tests the 4DSwinTransformer model
│
├── test
│ ├── module_test_swin.py <- Code for debugging SwinTransformer
│ └── module_test_swin4d.py <- Code for debugging 4DSwinTransformer
│
├── sample_scripts <- Example shell scripts for training
│
├── .gitignore <- List of files/folders ignored by git
├── export_DDP_vars.sh <- setup file for running torch DistributedDataParallel (DDP)
└── README.md
- Single forward & backward pass for debugging the SwinTransformer4D model.
cd SwiFUN/
python test/module_test_swin4d.py
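The snippet below is a minimal, self-contained sketch of what such a debug pass looks like. The `TinyStandIn` module and the `(B, C, H, W, D, T)` input layout are illustrative assumptions, not the repository's actual SwinTransformer4D; see `test/module_test_swin4d.py` for the real usage.

```python
# Minimal sketch: one forward and one backward pass on a dummy 4D fMRI batch.
# TinyStandIn is a placeholder model, NOT the repository's SwinTransformer4D.
import torch
import torch.nn as nn

class TinyStandIn(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(1, 4, kernel_size=3, padding=1)
        self.head = nn.Linear(4, 1)

    def forward(self, x):  # x: (B, C, H, W, D, T), an assumed layout
        b, c, h, w, d, t = x.shape
        x = x.permute(0, 5, 1, 2, 3, 4).reshape(b * t, c, h, w, d)
        feats = self.conv(x).mean(dim=(2, 3, 4))          # (B*T, 4)
        return self.head(feats).view(b, t).mean(dim=1)    # one score per sample

model = TinyStandIn()
dummy = torch.randn(1, 1, 96, 96, 96, 4)   # batch, channel, 96^3 volume, 4 frames
loss = model(dummy).pow(2).mean()          # dummy loss
loss.backward()                            # single backward pass
print("forward/backward OK, loss =", loss.item())
```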
You can check the argument list by using -h:
python project/main.py --data_module dummy --classifier_module default -h
4.2 Hidden Arguments for PyTorch Lightning
PyTorch Lightning offers useful arguments for training. For example, we used --max_epochs and --default_root_dir in our experiments. We recommend referring to the PyTorch Lightning Trainer documentation for the full argument list.
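Under the hood these flags map onto PyTorch Lightning's Trainer. A minimal sketch is shown below; the exact Trainer construction in this repository's main.py may differ.

```python
# Sketch of how the "hidden" CLI flags correspond to standard Lightning Trainer options.
import pytorch_lightning as pl

trainer = pl.Trainer(
    max_epochs=10,                # --max_epochs 10
    default_root_dir="./output",  # --default_root_dir ./output (logs & checkpoints)
)
```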
- Training SwiFUN interactively
# interactive
cd SwiFUN
bash sample_scripts/sample_train_swifun.sh
This bash script was tested on a Linux server cluster with 8 RTX 3090 GPUs. You should modify the following lines:
[to be updated]
cd {path to your 'SwiFUN' directory}
source /usr/anaconda3/etc/profile.d/conda.sh # initialize conda; the path may differ for your own conda installation.
conda activate {conda env name}
MAIN_ARGS='--loggername neptune --classifier_module v6 --dataset_name {dataset_name} --image_path {path to the image data}' # This script assumes a preprocessed HCP dataset. You can still run the code with "--dataset_name Dummy".
DEFAULT_ARGS='--project_name {neptune project name}'
export NEPTUNE_API_TOKEN="{Neptune API token allocated to each user}"
export CUDA_VISIBLE_DEVICES={usable GPU number}
- Training SwiFUN with Slurm (if you run the code on Slurm-based clusters). Please refer to a Slurm tutorial for the commands.
cd SwiFUN
sbatch sample_scripts/sample_train_swifun.slurm
We offer two options for loggers (a short configuration sketch follows this list).
- Tensorboard (https://www.tensorflow.org/tensorboard)
- Logs & model checkpoints are saved under --default_root_dir.
- Logging test results with Tensorboard is not available.
- Neptune AI (https://neptune.ai/)
- Generate a new workspace and project on the Neptune website.
- Academic workspace offers 200GB of storage and collaboration for free.
- Add export NEPTUNE_API_TOKEN="YOUR API TOKEN" to your script.
- Specify the "--project_name" argument with your Neptune project name, e.g., "--project_name user-id/project".
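The sketch below shows how the two logger types are typically constructed in PyTorch Lightning. The actual logger setup lives in project/main.py, and argument names may vary slightly across pytorch_lightning / neptune-client versions.

```python
# Sketch: configuring a Tensorboard or Neptune logger for a Lightning Trainer.
import os
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, NeptuneLogger

# Tensorboard: logs and checkpoints end up under the chosen save directory
tb_logger = TensorBoardLogger(save_dir="./output")

# Neptune: requires NEPTUNE_API_TOKEN and a "{workspace}/{project}" name
neptune_logger = NeptuneLogger(
    api_key=os.environ["NEPTUNE_API_TOKEN"],
    project="user-id/project",   # matches the --project_name argument
)

trainer = pl.Trainer(logger=tb_logger)  # or logger=neptune_logger
```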
The preprocessing code is based on TFF, the initial repository by GonyRosenman.
To prepare your own dataset, you should run one of the following minimal preprocessing pipelines:
- fMRIPrep: preprocessing with fMRIPrep
- FSL: the UKB preprocessing pipeline
- We ensure that each brain is registered to the MNI space, and the whole brain mask is applied to remove non-brain regions.
- We are investigating how additional preprocessing steps to remove confounding factors such as head movement impact performance.
After the minimal preprocessing steps, you should perform additional preprocessing to use SwiFT (the preprocessing code is at 'project/module/utils/data_preprocessing_and_load/preprocessing.py'). The main steps are listed below; a simplified sketch follows the list.
- normalization: voxel-wise normalization (not used) and whole-brain z-normalization (mainly used)
- convert fMRI volumes to 16-bit floating point to save storage and reduce the I/O bottleneck.
- save each fMRI volume separately as a torch (.pt) file to facilitate window-based training.
- crop background (non-brain) voxels so that no spatial dimension exceeds 96 voxels.
- you should open your fMRI scans to determine a cropping level that does not cut out brain regions
- you can use nilearn to visualize your fMRI data (see the official documentation).
from nilearn import plotting
from nilearn.image import mean_img
plotting.view_img(mean_img(fmri_filename), threshold=None)
- if your dimensions are under 96, you can pad non-brain voxels in the 'datasets.py' file.
- refer to the comments in 'preprocessing.py' to adjust it for your own dataset.
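The following is a simplified sketch of these steps (z-normalization, float16 conversion, cropping to at most 96 voxels per dimension, one .pt file per volume). The repository's actual implementation is in 'project/module/utils/data_preprocessing_and_load/preprocessing.py'; the function name, cropping logic, and the exact contents of global_stats.pt here are illustrative assumptions.

```python
# Simplified preprocessing sketch (not the repository's preprocessing.py).
import os
import numpy as np
import nibabel as nib
import torch

def preprocess_subject(nifti_path, out_dir, target=96):
    os.makedirs(out_dir, exist_ok=True)
    data = nib.load(nifti_path).get_fdata()          # (X, Y, Z, T)

    # whole-brain z-normalization over non-background voxels
    brain = data[data != 0]
    data = (data - brain.mean()) / (brain.std() + 1e-8)

    # crop each spatial dimension to at most 96 voxels (adjust per dataset so
    # that no brain tissue is cut off; here we naively crop from the origin)
    slices = tuple(slice(0, min(s, target)) for s in data.shape[:3])
    data = data[slices]

    # save each volume separately as a float16 torch tensor
    for t in range(data.shape[3]):
        vol = torch.from_numpy(np.ascontiguousarray(data[..., t])).to(torch.float16)
        torch.save(vol, os.path.join(out_dir, f"frame_{t}.pt"))

    # per-subject statistics (assumed contents of global_stats.pt)
    stats = {"min": float(data.min()), "max": float(data.max()), "mean": float(data.mean())}
    torch.save(stats, os.path.join(out_dir, "global_stats.pt"))
```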
The resulting data structure is as follows (a small sanity-check snippet is given after the tree):
├── {Dataset name}_MNI_to_TRs
├── img <- Every normalized volume is located in this directory
│ ├── sub-01 <- subject name
│ │ ├── frame_0.pt <- each torch .pt file contains one volume of an fMRI sequence (total number of .pt files = length of the fMRI sequence)
│ │ ├── frame_1.pt
│ │ │ :
│ │ ├── frame_{T}.pt <- the last volume in an fMRI sequence (length T)
│ │ └── global_stats.pt <- min, max, and mean values of the fMRI for the subject
│ └── sub-02
│ │ ├── frame_0.pt
│ │ ├── frame_1.pt
│ │ ├── :
└── metadata
└── metafile.csv <- file containing target variable
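The snippet below is a small sketch for sanity-checking a preprocessed subject directory laid out as above; the example path is hypothetical and the keys stored in global_stats.pt follow the description in this README.

```python
# Sanity-check one preprocessed subject directory (path is a hypothetical example).
import glob
import os
import torch

subject_dir = "MyDataset_MNI_to_TRs/img/sub-01"
frames = sorted(glob.glob(os.path.join(subject_dir, "frame_*.pt")))
print("number of volumes:", len(frames))
print("one volume shape:", torch.load(frames[0]).shape)   # expected 96 x 96 x 96
print("global stats:", torch.load(os.path.join(subject_dir, "global_stats.pt")))
```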
- The data loading pipeline works by pairing images and metadata in 'project/module/utils/data_module.py' and passing the paired image-label tuples to the Dataset classes in 'project/module/utils/data_preprocessing_and_load/datasets.py.'
- you should implement code for combining the image path, subject name, and target variables in 'project/module/utils/data_module.py'
- you should define a Dataset class for your dataset in 'project/module/utils/data_preprocessing_and_load/datasets.py.' In the Dataset class (__getitem__), specify how many background voxels to add or remove so that the volumes are shaped 96 x 96 x 96 (a minimal sketch is given below).
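The sketch below illustrates a Dataset class along these lines. The real classes live in 'datasets.py', and the names used here (samples, sequence_length, _pad_to_96) are illustrative assumptions rather than the repository's API.

```python
# Sketch of a custom Dataset that pads/crops each stored volume to 96^3 in __getitem__.
import glob
import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

class MyfMRIDataset(Dataset):
    def __init__(self, samples, sequence_length=20):
        # samples: list of (subject_img_dir, target_value) tuples built in data_module.py
        self.samples = samples
        self.sequence_length = sequence_length

    def _pad_to_96(self, vol):
        # symmetric padding (negative values crop) so the volume becomes 96 x 96 x 96
        pads = []
        for size in reversed(vol.shape):   # F.pad expects last-dimension-first order
            extra = 96 - size
            pads += [extra // 2, extra - extra // 2]
        return F.pad(vol, pads)

    def __getitem__(self, idx):
        subject_dir, target = self.samples[idx]
        frame_paths = sorted(
            glob.glob(os.path.join(subject_dir, "frame_*.pt")),
            key=lambda p: int(os.path.splitext(os.path.basename(p))[0].split("_")[1]),
        )[: self.sequence_length]
        frames = [self._pad_to_96(torch.load(p).float()) for p in frame_paths]
        img = torch.stack(frames, dim=-1).unsqueeze(0)   # (1, 96, 96, 96, T)
        return img, torch.tensor(target, dtype=torch.float32)

    def __len__(self):
        return len(self.samples)
```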
@article{kwon2024predicting,
title={Predicting task-related brain activity from resting-state brain dynamics with fMRI Transformer},
author={Kwon, Junbeom and Seo, Jungwoo and Wang, Heehwan and Moon, Taesup and Yoo, Shinjae and Cha, Jiook},
journal={bioRxiv},
pages={2024--05},
year={2024},
publisher={Cold Spring Harbor Laboratory}
}