This repository contains the code for our paper PriViT ([arXiv:2310.04604](https://arxiv.org/abs/2310.04604)). Here we provide all the files and instructions needed to replicate our training, testing, and benchmarking pipelines.
This section walks you through the basic steps to get the code running on your local machine for development and testing. For detailed instructions about each script, refer to the specific subsections below.
You will need to create two separate conda environments: one for training PriViT models, and another for benchmarking the trained models on SecretFlow using the SEMI2K protocol. We have exported both environments, so you can recreate them from the YAML files provided in src/: privit_training_environmnet.yml contains the packages for the training environment, and spu-jax.yml contains the packages for the benchmarking environment.
```
git clone privit
cd privit/src
conda env create -n compression --file privit_training_environmnet.yml
```
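The benchmarking environment can presumably be created the same way from its YAML file; the environment name spu-jax below matches the activation commands used later in this README:

```
conda env create -n spu-jax --file spu-jax.yml
```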
You can find the training entry points in the Slurm script files listed below. Ensure that the correct data path is set in dataset.py.
- tinyimagenet.sbatch: Slurm script to train a PriViT model on the Tiny ImageNet dataset.
- cifar100.sbatch: Slurm script to train a PriViT model on the CIFAR-100 dataset.
- cifar10.sbatch: Slurm script to train a PriViT model on the CIFAR-10 dataset.
- train.py: Contains the primary training logic of PriViT.
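For example, on a Slurm cluster the CIFAR-10 run can be queued as follows (a sketch; adjust any partition or account flags to your cluster):

```
sbatch cifar10.sbatch
# check the queued job (output file and job name depend on the #SBATCH headers in the script)
squeue -u $USER
```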
We have also released model checkpoints here.
You can find the inference script in the inference.sbatch file; the primary inference logic is in inference.py.
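As with training, the inference job can presumably be submitted through Slurm:

```
sbatch inference.sbatch
```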
The folder benchmark/ contains all the Flax code used for benchmarking the performance of these PyTorch models with the SecretFlow framework on the SEMI2K protocol. For detailed instructions on how to set up a SecretFlow benchmarking deployment, refer to their documentation. 2pc.json is the configuration file; update the IP addresses of the two nodes in this file (a sketch of its typical layout is shown below).
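For reference, here is a minimal sketch of what a two-party config can look like, modeled on the 2pc.json format used in SPU's distributed examples; all addresses and ports below are placeholders you must adapt:

```json
{
  "nodes": {
    "node:0": "192.168.0.10:9920",
    "node:1": "192.168.0.11:9920"
  },
  "devices": {
    "SPU": {
      "kind": "SPU",
      "config": {
        "node_ids": ["node:0", "node:1"],
        "spu_internal_addrs": ["192.168.0.10:9930", "192.168.0.11:9930"],
        "runtime_config": {
          "protocol": "SEMI2K",
          "field": "FM64"
        }
      }
    }
  }
}
```

Once the addresses are set, start the server on each of the two nodes: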
```
# on node 0
conda activate spu-jax
python nodectl.py -c 2pc.json start --node_id node:0 &> node0.log &

# on node 1
conda activate spu-jax
python nodectl.py -c 2pc.json start --node_id node:1 &> node1.log &
```
To benchmark the PriViT model, run the following on node 0 (--dataset accepts tiny_imagenet, cifar10, or cifar100):

```
python privit_secretflow.py --config 2pc.json --checkpoint "/path/to/checkpoint" --dataset tiny_imagenet
```
To benchmark the MPCViT model, run the following on node 0:

```
python mpcvit_secretflow.py --config 2pc.json --checkpoint "/path/to/checkpoint" --dataset tiny_imagenet
```
These scripts load the PyTorch checkpoints of PriViT/MPCViT and convert them into Flax-compatible parameters.
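The conversion step boils down to mapping each PyTorch tensor to a NumPy array and adjusting layouts to Flax conventions. The helper below is a minimal illustrative sketch, not the actual logic in privit_secretflow.py; in particular, the name mapping between PyTorch-style and Flax-style module names is elided:

```python
import numpy as np
import torch

def torch_to_flax_params(checkpoint_path):
    """Load a PyTorch state_dict and return Flax-style NumPy parameters."""
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    params = {}
    for name, tensor in state_dict.items():
        array = tensor.detach().cpu().numpy()
        # Flax Dense kernels are stored as (in_features, out_features),
        # the transpose of torch.nn.Linear's (out, in) weight layout.
        if name.endswith(".weight") and array.ndim == 2:
            array = array.T
        params[name] = array
    return params
```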
Ablation studies are performed using the following scripts:
- train_without_kd.py: Training without knowledge distillation (KD).
- train_without_pretrain.py: Training without using pretrained checkpoints.
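Both scripts can presumably be launched directly from the training environment (their exact command-line flags mirror train.py and are not reproduced here):

```
python train_without_kd.py
python train_without_pretrain.py
```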
If you find our work helpful to your research, please cite our paper:
```
@misc{dhyani2023privit,
      title={PriViT: Vision Transformers for Fast Private Inference},
      author={Naren Dhyani and Jianqiao Mo and Minsu Cho and Ameya Joshi and Siddharth Garg and Brandon Reagen and Chinmay Hegde},
      year={2023},
      eprint={2310.04604},
      archivePrefix={arXiv},
      primaryClass={cs.CR}
}
```