
Diffusion-Driven-Test-Time-Adaptation-via-Synthetic-Domain-Alignment

This repository is the official PyTorch implementation of SDA.

arXiv

Everything to the Synthetic: Diffusion-driven Test-time Adaptation via Synthetic-Domain Alignment
Jiayi Guo, Junhao Zhao, Chunjiang Ge, Chaoqun Du, Zanlin Ni, Shiji Song, Humphrey Shi, Gao Huang

Synthetic-Domain Alignment (SDA) is a novel test-time adaptation (TTA) framework that simultaneously aligns the domains of the source model and the target data with the same synthetic domain of a diffusion model.

Overview

SDA is a two-stage TTA framework that aligns both the source model and the target data with the synthetic domain. In Stage 1, the source-domain model is adapted into a synthetic-domain model by fine-tuning on synthetic data. This synthetic data is first generated by a conditional diffusion model from domain-agnostic class labels, then re-synthesized through an unconditional diffusion process so that it shares the domain of the projected target data in Stage 2. In Stage 2, the target data is projected into the synthetic domain with unconditional diffusion, and the synthetic-domain model makes its predictions on the projected data.
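The data flow of the two stages can be summarized in a short Python sketch. It is schematic, not code from this repository: `sda_pipeline` and its callable arguments are hypothetical placeholders standing in for DiT sampling, ADM-based projection, MMPreTrain fine-tuning, and classifier inference. The point it illustrates is that the same unconditional projection is applied to the synthetic training images and to the target test images.

```python
from typing import Any, Callable, List, Sequence


def sda_pipeline(
    source_model: Any,
    class_labels: Sequence[Any],
    target_images: Sequence[Any],
    generate_conditional: Callable[[Any], Any],   # placeholder for DiT sampling
    project_unconditional: Callable[[Any], Any],  # placeholder for ADM-based projection
    finetune: Callable[[Any, List[tuple]], Any],  # placeholder for fine-tuning
    predict: Callable[[Any, Any], Any],           # placeholder for inference
) -> List[Any]:
    # Stage 1a: generate one synthetic image per domain-agnostic class label.
    synthetic = [generate_conditional(y) for y in class_labels]
    # Stage 1b: re-synthesize with the unconditional diffusion process that
    # Stage 2 also uses, so training and test images share one domain.
    aligned = [project_unconditional(x) for x in synthetic]
    # Stage 1c: adapt the source model into a synthetic-domain model.
    synthetic_model = finetune(source_model, list(zip(aligned, class_labels)))
    # Stage 2: project each target image into the synthetic domain, then predict.
    return [predict(synthetic_model, project_unconditional(x)) for x in target_images]
```

With toy stand-ins (integers for images and models), the call order can be inspected directly: the projection runs on the three synthetic images first and then on the two target images.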

News

  • [2024.06.07] Code, data, and models released!

Setup

Installation

git clone https://github.com/SHI-Labs/Diffusion-Driven-Test-Time-Adaptation-via-Synthetic-Domain-Alignment.git
cd Diffusion-Driven-Test-Time-Adaptation-via-Synthetic-Domain-Alignment
conda env create -f environment.yml
conda activate SDA
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
mim install mmcv-full 
mim install mmcls

Dataset

Download ImageNet-C and generate ImageNet-W following the official repos.

For a quick evaluation, also download our re-synthesized ImageNet-C-Syn and ImageNet-W-Syn via Google Drive or Tsinghua Cloud. Place all datasets into data/.

data
|——ImageNet-C
|   |——gaussian_noise
|       |——5
|           |——n01440764
|               |——*.JPEG
|——ImageNet-C-Syn
|   |——gaussian_noise
|       |——5
|           |——n01440764
|               |——*.JPEG
|——ImageNet-W
|   |——val
|       |——n01440764
|           |——*.JPEG
|——ImageNet-W-Syn
|   |——val
|       |——n01440764
|           |——*.JPEG

You can also re-synthesize the test datasets yourself by following the official DDA repository or by using our scripts/align.sh.
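Before running evaluation, it can help to verify that data/ matches the layout above. The following is a minimal sketch that is not part of the repository: `missing_dirs` and `EXPECTED` are hypothetical names, and the check only covers the `gaussian_noise`/severity-5 subset shown in the tree. Other corruptions and severities follow the same `<corruption>/<severity>/<wnid>/*.JPEG` pattern and can be added to `EXPECTED`.

```python
from pathlib import Path
from typing import Dict, List, Tuple

# Hypothetical table derived from the tree above: dataset folder -> path to
# the directory that holds one sub-folder per ImageNet class (wnid).
EXPECTED: Dict[str, Tuple[str, ...]] = {
    "ImageNet-C": ("gaussian_noise", "5"),
    "ImageNet-C-Syn": ("gaussian_noise", "5"),
    "ImageNet-W": ("val",),
    "ImageNet-W-Syn": ("val",),
}


def missing_dirs(root: str = "data") -> List[str]:
    """Return class roots that are missing/empty and class folders without JPEGs."""
    problems: List[str] = []
    for name, subpath in EXPECTED.items():
        class_root = Path(root, name, *subpath)
        class_dirs = (
            sorted(p for p in class_root.iterdir() if p.is_dir())
            if class_root.is_dir()
            else []
        )
        if not class_dirs:
            # Either the directory does not exist or it has no class folders.
            problems.append(str(class_root))
        for cls in class_dirs:
            # Each class folder (e.g. n01440764) should hold at least one image.
            if not any(cls.glob("*.JPEG")):
                problems.append(str(cls))
    return problems
```

An empty list from `missing_dirs("data")` means every expected class folder exists and contains at least one `.JPEG` file.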

Model

Download pretrained checkpoints (diffusion models and classifiers) via the following command:

bash scripts/download_ckpt.sh

For a quick evaluation, also download our finetuned checkpoints via Google Drive or Tsinghua Cloud. Place the checkpoints into finetuned_ckpt/.

Evaluation

We provide example commands to evaluate finetuned models on both ImageNet-C and ImageNet-W:

bash scripts/eval.sh

You can also evaluate a customized model with the following command templates:

# ImageNet-C
CUDA_VISIBLE_DEVICES=0 python eval/test_ensemble.py <config> <finetuned ckpt> \
--originckpt <pretrained ckpt> --metrics accuracy --datatype C --ensemble sda --corruption <corruption type> --data_prefix1 data/ImageNet-C --data_prefix2 data/ImageNet-C-Syn

# ImageNet-W
CUDA_VISIBLE_DEVICES=0 python eval/test_ensemble.py <config> <finetuned ckpt> \
--originckpt <pretrained ckpt> --metrics accuracy --datatype W --ensemble sda --data_prefix1 data/ImageNet-W --data_prefix2 data/ImageNet-W-Syn

You may need to create a new config for your customized model based on our evaluation configs.
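To sweep all corruption types with the ImageNet-C template above, the argument lists can be generated programmatically. This is a hedged sketch, not what scripts/eval.sh does: `imagenet_c_commands` is a hypothetical helper, and it assumes that `--corruption` accepts the folder names of the standard ImageNet-C release (e.g. `gaussian_noise`). The config and checkpoint paths are left as parameters, just like the `<config>` and `<ckpt>` placeholders in the template.

```python
import shlex
from typing import List

# The 15 standard ImageNet-C corruption types (folder names in data/ImageNet-C).
CORRUPTIONS = [
    "gaussian_noise", "shot_noise", "impulse_noise",
    "defocus_blur", "glass_blur", "motion_blur", "zoom_blur",
    "snow", "frost", "fog", "brightness", "contrast",
    "elastic_transform", "pixelate", "jpeg_compression",
]


def imagenet_c_commands(config: str, finetuned_ckpt: str, origin_ckpt: str) -> List[List[str]]:
    """Build one eval/test_ensemble.py argv per corruption, mirroring the template."""
    cmds = []
    for corruption in CORRUPTIONS:
        cmds.append([
            "python", "eval/test_ensemble.py", config, finetuned_ckpt,
            "--originckpt", origin_ckpt,
            "--metrics", "accuracy",
            "--datatype", "C",
            "--ensemble", "sda",
            "--corruption", corruption,
            "--data_prefix1", "data/ImageNet-C",
            "--data_prefix2", "data/ImageNet-C-Syn",
        ])
    return cmds


if __name__ == "__main__":
    # Print shell-ready lines; placeholder paths must be replaced by real ones.
    for cmd in imagenet_c_commands("<config>", "<finetuned ckpt>", "<pretrained ckpt>"):
        print(shlex.join(cmd))
```

Each list can also be passed directly to `subprocess.run(cmd, check=True, env={**os.environ, "CUDA_VISIBLE_DEVICES": "0"})`; building argument lists instead of shell strings avoids quoting problems with unusual paths.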

Training/Fine-tuning

Step 1: Synthetic data generation via conditional diffusion

Run the following command to generate a synthetic dataset via DiT:

bash scripts/gen.sh

The synthetic dataset covers all 1000 ImageNet classes, with 50 images per class (50,000 images in total):

data
|——DiT-XL-2-DiT-XL-2-256x256-size-256-vae-ema-cfg-1.0-seed-0
|       |——0000
|       |    |——*.png
|       |——0001
|       |——....
|       |——9999
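After generation, a quick sanity check is to count the images: 1000 classes × 50 images per class gives 50,000 files. The sketch below is hypothetical (`count_pngs` and `EXPECTED_TOTAL` are not repository names) and assumes, as in the tree above, that the PNG files sit exactly one folder below the output directory.

```python
from pathlib import Path

# 1000 ImageNet classes x 50 images per class.
EXPECTED_TOTAL = 1000 * 50


def count_pngs(root: str) -> int:
    """Count *.png files exactly one level below root (root/<subfolder>/*.png)."""
    return sum(1 for _ in Path(root).glob("*/*.png"))


if __name__ == "__main__":
    out = "data/DiT-XL-2-DiT-XL-2-256x256-size-256-vae-ema-cfg-1.0-seed-0"
    if Path(out).is_dir():
        print(count_pngs(out), "of", EXPECTED_TOTAL, "images found")
```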

Step 2: Synthetic data alignment via unconditional diffusion

Run the following command to project the synthetic dataset into the domain of ADM, the unconditional diffusion model used for alignment:

bash scripts/align.sh

We also provide example commands for projecting the target data (ImageNet-C/W) into the domain of ADM, following the same procedure as DDA. See scripts/align.sh for details.
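Projection with an unconditional diffusion model (as in DDA) starts by partially noising an image and then denoising it with the model's reverse process. The sketch below shows only the generic DDPM forward-noising step, x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε, under a linear β schedule from 1e-4 to 0.02 over 1000 steps. It is an illustration, not the repository's code: the reverse process needs the ADM checkpoint and DDA's guidance, which are not reproduced here, and the step `t=500` in the usage is an arbitrary example value, not the one used by align.sh.

```python
import numpy as np


def linear_alphas_cumprod(num_steps: int = 1000) -> np.ndarray:
    """Cumulative products a_bar_t of (1 - beta_t) for a linear beta schedule."""
    betas = np.linspace(1e-4, 0.02, num_steps, dtype=np.float64)
    return np.cumprod(1.0 - betas)


def forward_noise(
    x0: np.ndarray, t: int, alphas_cumprod: np.ndarray, rng: np.random.Generator
) -> np.ndarray:
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(a_bar_t) x_0, (1 - a_bar_t) I); t is 0-indexed."""
    a_bar = alphas_cumprod[t]
    noise = rng.standard_normal(x0.shape)
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * noise


if __name__ == "__main__":
    # Example: a random "image" scaled to [-1, 1], as diffusion models expect.
    rng = np.random.default_rng(0)
    image = rng.uniform(-1.0, 1.0, size=(3, 256, 256))
    x_t = forward_noise(image, t=500, alphas_cumprod=linear_alphas_cumprod(), rng=rng)
    print(x_t.shape)
```

Small `t` keeps most of the image content, while large `t` destroys it; choosing an intermediate step trades off how much of the input's domain-specific appearance the reverse process can replace.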

Step 3: Synthetic data fine-tuning

Important: Our fine-tuning code is built on MMPreTrain, which requires a different mmcv version from the one used by the DDA-based data alignment (Step 2). Therefore, you need to create a separate environment before fine-tuning:

conda env create -f environment_ft.yml
conda activate SDA-FT
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
mim install mmcv
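Because the two environments differ mainly in their mmcv package, a quick way to check which one is active is to inspect the installed distributions. This is a hedged sketch with hypothetical helper names; it relies on the fact that `mmcv-full` only exists as a 1.x package, while MMPreTrain uses the 2.x line published as `mmcv`.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def classify_env(mmcv_full: Optional[str], mmcv: Optional[str]) -> str:
    """Hypothetical helper: map installed mmcv packages to the matching environment."""
    # mmcv-full 1.x is installed in the SDA environment (environment.yml).
    if mmcv_full is not None and mmcv_full.split(".")[0] == "1":
        return "SDA"
    # mmcv 2.x is installed in the SDA-FT environment (environment_ft.yml).
    if mmcv is not None and mmcv.split(".")[0] == "2":
        return "SDA-FT"
    return "unknown"


if __name__ == "__main__":
    print(classify_env(installed_version("mmcv-full"), installed_version("mmcv")))
```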

We provide fine-tuning configs for five models used in our paper. Run the following command to start synthetic data fine-tuning:

bash scripts/finetune.sh

In our implementation, we use a 30-epoch fine-tuning schedule. Empirically, we find that 15 epochs of training are sufficient for evaluation.
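Since 15 epochs are reported to be sufficient, you may want to evaluate the epoch-15 checkpoint instead of the final one. A minimal sketch, assuming MMEngine's default epoch-based checkpoint naming (`epoch_<n>.pth` in the work directory) and that the epoch-15 file was kept (a config's checkpoint `interval` or `max_keep_ckpts` may skip or delete it). The helper names are hypothetical.

```python
import re
from pathlib import Path
from typing import Dict


def epoch_checkpoints(work_dir: str) -> Dict[int, Path]:
    """Map epoch number -> checkpoint path for files named epoch_<n>.pth."""
    ckpts: Dict[int, Path] = {}
    for path in Path(work_dir).glob("epoch_*.pth"):
        match = re.fullmatch(r"epoch_(\d+)\.pth", path.name)
        if match:
            ckpts[int(match.group(1))] = path
    return ckpts


def checkpoint_for_epoch(work_dir: str, epoch: int = 15) -> Path:
    """Return the checkpoint saved after `epoch`, or raise if it is not present."""
    ckpts = epoch_checkpoints(work_dir)
    if epoch not in ckpts:
        raise FileNotFoundError(
            f"no epoch_{epoch}.pth in {work_dir}; available epochs: {sorted(ckpts)}"
        )
    return ckpts[epoch]
```

The returned path can then be used as the `<finetuned ckpt>` argument of the evaluation command.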

If you want to fine-tune other models, please refer to MMPreTrain to set up a new config.

Results

  • ImageNet-C

  • ImageNet-W

  • Visualization: Grad-CAM results, with predicted classes and confidence scores shown above the images.

Citation

If you find our work helpful, please star 🌟 this repo and cite 📑 our paper. Thanks for your support!

@article{guo2024sda,
  title={Everything to the Synthetic: Diffusion-driven Test-time Adaptation via Synthetic-domain Alignment},
  author={Jiayi Guo and Junhao Zhao and Chunjiang Ge and Chaoqun Du and Zanlin Ni and Shiji Song and Humphrey Shi and Gao Huang},
  journal={arXiv},
  year={2024}
}

Acknowledgements

We thank MMPreTrain (model fine-tuning), DiT (data synthesis), and DDA (data alignment).

Contact

guo-jy20 at mails dot tsinghua dot edu dot cn