diff --git a/client.py b/client.py index 9d93e80..bed4bcf 100644 --- a/client.py +++ b/client.py @@ -1,13 +1,16 @@ +import copy import time + import torch -from utils import get_optimizer, get_model import torch.nn as nn -from torch.optim import lr_scheduler -import torch.nn.functional as F from torch.autograd import Variable -import copy +from torch.optim import lr_scheduler + from optimization import Optimization -class Client(): +from utils import get_optimizer, get_model + + +class Client: def __init__(self, cid, data, device, project_dir, model_name, local_epoch, lr, batch_size, drop_rate, stride): self.cid = cid self.project_dir = project_dir @@ -17,7 +20,7 @@ def __init__(self, cid, data, device, project_dir, model_name, local_epoch, lr, self.local_epoch = local_epoch self.lr = lr self.batch_size = batch_size - + self.dataset_sizes = self.data.train_dataset_sizes[cid] self.train_loader = self.data.train_loaders[cid] @@ -25,7 +28,7 @@ def __init__(self, cid, data, device, project_dir, model_name, local_epoch, lr, self.classifier = self.full_model.classifier.classifier self.full_model.classifier.classifier = nn.Sequential() self.model = self.full_model - self.distance=0 + self.distance = 0 self.optimization = Optimization(self.train_loader, self.device) # print("class name size",class_names_size[cid]) @@ -40,21 +43,20 @@ def train(self, federated_model, use_cuda): optimizer = get_optimizer(self.model, self.lr) scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1) - + criterion = nn.CrossEntropyLoss() since = time.time() - print('Client', self.cid, 'start training') for epoch in range(self.local_epoch): - print('Epoch {}/{}'.format(epoch, self.local_epoch - 1)) + print('Local Epoch of {}: {}/{}'.format(self.cid, epoch, self.local_epoch - 1)) print('-' * 10) scheduler.step() self.model.train(True) running_loss = 0.0 running_corrects = 0.0 - + for data in self.train_loader: inputs, labels = data b, c, h, w = inputs.shape @@ -65,7 +67,7 @@ def train(self, federated_model, use_cuda): labels = Variable(labels.cuda().detach()) else: inputs, labels = Variable(inputs), Variable(labels) - + optimizer.zero_grad() outputs = self.model(inputs) @@ -86,7 +88,7 @@ def train(self, federated_model, use_cuda): 'train', epoch_loss, epoch_acc)) self.y_loss.append(epoch_loss) - self.y_err.append(1.0-epoch_acc) + self.y_err.append(1.0 - epoch_acc) time_elapsed = time.time() - since print('Client', self.cid, ' Training complete in {:.0f}m {:.0f}s'.format( @@ -97,12 +99,11 @@ def train(self, federated_model, use_cuda): time_elapsed // 60, time_elapsed % 60)) # save_network(self.model, self.cid, 'last', self.project_dir, self.model_name, gpu_ids) - + self.classifier = self.model.classifier.classifier self.distance = self.optimization.cdw_feature_distance(federated_model, self.old_classifier, self.model) self.model.classifier.classifier = nn.Sequential() - def generate_soft_label(self, x, regularization): return self.optimization.kd_generate_soft_label(self.model, x, regularization) @@ -116,4 +117,4 @@ def get_train_loss(self): return self.y_loss[-1] def get_cos_distance_weight(self): - return self.distance \ No newline at end of file + return self.distance diff --git a/data_utils.py b/data_utils.py index 0aa3d65..d454670 100644 --- a/data_utils.py +++ b/data_utils.py @@ -1,13 +1,15 @@ -from torch.utils.data import Dataset -from PIL import Image -from torchvision import datasets, transforms import os -import json + import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import datasets, transforms + from random_erasing import RandomErasing + class ImageDataset(Dataset): - def __init__(self, imgs, transform = None): + def __init__(self, imgs, transform=None): self.imgs = imgs self.transform = transform @@ -15,11 +17,11 @@ def __len__(self): return len(self.imgs) def __getitem__(self, index): - data,label = self.imgs[index] + data, label = self.imgs[index] return self.transform(Image.open(data)), label -class Data(): +class Data: def __init__(self, datasets, data_dir, batch_size, erasing_p, color_jitter, train_all): self.datasets = datasets.split(',') self.batch_size = batch_size @@ -27,39 +29,48 @@ def __init__(self, datasets, data_dir, batch_size, erasing_p, color_jitter, trai self.color_jitter = color_jitter self.data_dir = data_dir self.train_all = '_all' if train_all else '' - + self.data_transforms = {} + self.train_loaders = {} + self.train_dataset_sizes = {} + self.train_class_sizes = {} + self.client_list = [] + self.test_loaders = {} + self.gallery_meta = {} + self.query_meta = {} + self.kd_loader = None + def transform(self): transform_train = [ - transforms.Resize((256,128), interpolation=3), - transforms.Pad(10), - transforms.RandomCrop((256,128)), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ] + transforms.Resize((256, 128), interpolation=3), + transforms.Pad(10), + transforms.RandomCrop((256, 128)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ] transform_val = [ - transforms.Resize(size=(256,128),interpolation=3), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ] + transforms.Resize(size=(256, 128), interpolation=3), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ] if self.erasing_p > 0: transform_train = transform_train + [RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])] if self.color_jitter: - transform_train = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)] + transform_train + transform_train = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)] \ + + transform_train self.data_transforms = { 'train': transforms.Compose(transform_train), 'val': transforms.Compose(transform_val), - } + } def preprocess_kd_data(self, dataset): loader, image_dataset = self.preprocess_one_train_dataset(dataset) self.kd_loader = loader - def preprocess_one_train_dataset(self, dataset): """preprocess a training dataset, construct a data loader. """ @@ -68,10 +79,10 @@ def preprocess_one_train_dataset(self, dataset): image_dataset = datasets.ImageFolder(data_path) loader = torch.utils.data.DataLoader( - ImageDataset(image_dataset.imgs, self.data_transforms['train']), + ImageDataset(image_dataset.imgs, self.data_transforms['train']), batch_size=self.batch_size, - shuffle=True, - num_workers=2, + shuffle=True, + num_workers=2, pin_memory=False) return loader, image_dataset @@ -79,64 +90,52 @@ def preprocess_one_train_dataset(self, dataset): def preprocess_train(self): """preprocess training data, constructing train loaders """ - self.train_loaders = {} - self.train_dataset_sizes = {} - self.train_class_sizes = {} - self.client_list = [] - for dataset in self.datasets: self.client_list.append(dataset) - + loader, image_dataset = self.preprocess_one_train_dataset(dataset) self.train_dataset_sizes[dataset] = len(image_dataset) self.train_class_sizes[dataset] = len(image_dataset.classes) self.train_loaders[dataset] = loader - + print('Train dataset sizes:', self.train_dataset_sizes) print('Train class sizes:', self.train_class_sizes) - + def preprocess_test(self): """preprocess testing data, constructing test loaders """ - self.test_loaders = {} - self.gallery_meta = {} - self.query_meta = {} - - for test_dir in self.datasets: - test_dir = 'data/'+test_dir+'/pytorch' - - dataset = test_dir.split('/')[1] + for dataset in self.datasets: + test_dir = os.path.join(self.data_dir, dataset, 'pytorch') gallery_dataset = datasets.ImageFolder(os.path.join(test_dir, 'gallery')) query_dataset = datasets.ImageFolder(os.path.join(test_dir, 'query')) - + gallery_dataset = ImageDataset(gallery_dataset.imgs, self.data_transforms['val']) query_dataset = ImageDataset(query_dataset.imgs, self.data_transforms['val']) self.test_loaders[dataset] = {key: torch.utils.data.DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=8, - pin_memory=True) for key, dataset in {'gallery': gallery_dataset, 'query': query_dataset}.items()} - + dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=8, + pin_memory=True) for key, dataset in {'gallery': gallery_dataset, 'query': query_dataset}.items()} gallery_cameras, gallery_labels = get_camera_ids(gallery_dataset.imgs) self.gallery_meta[dataset] = { - 'sizes': len(gallery_dataset), + 'sizes': len(gallery_dataset), 'cameras': gallery_cameras, 'labels': gallery_labels } query_cameras, query_labels = get_camera_ids(query_dataset.imgs) self.query_meta[dataset] = { - 'sizes': len(query_dataset), + 'sizes': len(query_dataset), 'cameras': query_cameras, 'labels': query_labels } - print('Query Sizes:', self.query_meta[dataset]['sizes']) - print('Gallery Sizes:', self.gallery_meta[dataset]['sizes']) + print('Query Sizes:', self.query_meta[dataset]['sizes']) + print('Gallery Sizes:', self.gallery_meta[dataset]['sizes']) def preprocess(self): self.transform() @@ -144,6 +143,7 @@ def preprocess(self): self.preprocess_test() self.preprocess_kd_data('cuhk02') + def get_camera_ids(img_paths): """get camera id and labels by image path """ @@ -151,14 +151,14 @@ def get_camera_ids(img_paths): labels = [] for path, v in img_paths: filename = os.path.basename(path) - if filename[:3]!='cam': + if filename[:3] != 'cam': label = filename[0:4] camera = filename.split('c')[1] camera = camera.split('s')[0] else: label = filename.split('_')[2] camera = filename.split('_')[1] - if label[0:2]=='-1': + if label[0:2] == '-1': labels.append(-1) else: labels.append(int(label)) diff --git a/main.py b/main.py index f68f142..eeeca61 100644 --- a/main.py +++ b/main.py @@ -1,39 +1,34 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division + import argparse -import torch -import time +import multiprocessing as mp import os -import yaml -import random -import numpy as np -import scipy.io -import pathlib import sys -import json -import copy -import multiprocessing as mp -import torch.nn.functional as F +import time + import matplotlib -matplotlib.use('agg') -import matplotlib.pyplot as plt -from PIL import Image +import torch + from client import Client +from data_utils import Data from server import Server from utils import set_random_seed -from data_utils import Data +matplotlib.use('agg') mp.set_start_method('spawn', force=True) sys.setrecursionlimit(10000) -version = torch.__version__ +version = torch.__version__ parser = argparse.ArgumentParser(description='Training') -parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2') -parser.add_argument('--model_name',default='ft_ResNet50', type=str, help='output model name') -parser.add_argument('--project_dir',default='.', type=str, help='project path') -parser.add_argument('--data_dir',default='data',type=str, help='training dir path') -parser.add_argument('--datasets',default='Market,DukeMTMC-reID,cuhk03-np-detected,cuhk01,MSMT17,viper,prid,3dpes,ilids',type=str, help='datasets used') -parser.add_argument('--train_all', action='store_true', help='use all training data' ) +parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0 0,1,2 0,2') +parser.add_argument('--model_name', default='ft_ResNet50', type=str, help='output model name') +parser.add_argument('--project_dir', default='.', type=str, help='project path') +parser.add_argument('--data_dir', default='data', type=str, help='training dir path') +parser.add_argument('--datasets', + default='Market,DukeMTMC-reID,cuhk03-np-detected,cuhk01,MSMT17,viper,prid,3dpes,ilids', type=str, + help='datasets used') +parser.add_argument('--train_all', action='store_true', default=True, help='use all training data') parser.add_argument('--stride', default=2, type=int, help='stride') parser.add_argument('--lr', default=0.05, type=float, help='learning rate') parser.add_argument('--drop_rate', default=0.5, type=float, help='drop rate') @@ -42,27 +37,29 @@ parser.add_argument('--local_epoch', default=1, type=int, help='number of local epochs') parser.add_argument('--batch_size', default=32, type=int, help='batch size') parser.add_argument('--num_of_clients', default=9, type=int, help='number of clients') +parser.add_argument('--rounds', default=300, type=int, help='training rounds') # arguments for data transformation parser.add_argument('--erasing_p', default=0, type=float, help='Random Erasing probability, in [0,1]') -parser.add_argument('--color_jitter', action='store_true', help='use color jitter in training' ) +parser.add_argument('--color_jitter', action='store_true', help='use color jitter in training') # arguments for testing federated model -parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last') -parser.add_argument('--multi', action='store_true', help='use multiple query' ) -parser.add_argument('--multiple_scale',default='1', type=str,help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2') -parser.add_argument('--test_dir',default='all',type=str, help='./test_data') +parser.add_argument('--which_epoch', default='last', type=str, help='0,1,2,3...or last') +parser.add_argument('--multi', action='store_true', help='use multiple query') +parser.add_argument('--multiple_scale', default='1', type=str, help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2') # arguments for optimization -parser.add_argument('--cdw', action='store_true', help='use cosine distance weight for model aggregation, default false' ) -parser.add_argument('--kd', action='store_true', help='apply knowledge distillation, default false' ) -parser.add_argument('--regularization', action='store_true', help='use regularization during distillation, default false' ) +parser.add_argument('--cdw', action='store_true', + help='use cosine distance weight for model aggregation, default false') +parser.add_argument('--kd', action='store_true', help='apply knowledge distillation, default false') +parser.add_argument('--regularization', action='store_true', + help='use regularization during distillation, default false') def train(): args = parser.parse_args() print(args) - + use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") @@ -70,56 +67,54 @@ def train(): data = Data(args.datasets, args.data_dir, args.batch_size, args.erasing_p, args.color_jitter, args.train_all) data.preprocess() - + clients = {} for cid in data.client_list: clients[cid] = Client( - cid, - data, - device, - args.project_dir, - args.model_name, - args.local_epoch, - args.lr, - args.batch_size, - args.drop_rate, - args.stride) + cid, + data, + device, + args.project_dir, + args.model_name, + args.local_epoch, + args.lr, + args.batch_size, + args.drop_rate, + args.stride) server = Server( - clients, - data, - device, - args.project_dir, - args.model_name, - args.num_of_clients, - args.lr, - args.drop_rate, - args.stride, + clients, + data, + device, + args.project_dir, + args.model_name, + args.num_of_clients, + args.lr, + args.drop_rate, + args.stride, args.multiple_scale) - dir_name = os.path.join(args.project_dir, 'model', args.model_name) - if not os.path.isdir(dir_name): - os.mkdir(dir_name) + save_dir = os.path.join(args.project_dir, 'model') + if not os.path.isdir(save_dir): + os.mkdir(save_dir) + save_dir = os.path.join(save_dir, "{}_{}".format(args.model_name, args.rounds)) + if not os.path.isdir(save_dir): + os.mkdir(save_dir) print("=====training start!========") - rounds = 800 - for i in range(rounds): - print('='*10) + for i in range(args.rounds): + print('=' * 10) print("Round Number {}".format(i)) - print('='*10) + print('=' * 10) server.train(i, args.cdw, use_cuda) - save_path = os.path.join(dir_name, 'federated_model.pth') + save_path = os.path.join(save_dir, 'federated_model.pth') torch.save(server.federated_model.cpu().state_dict(), save_path) - if (i+1)%10 == 0: - server.test(use_cuda) + if (i + 1) % 10 == 0: if args.kd: server.knowledge_distillation(args.regularization) - server.test(use_cuda) + server.test(use_cuda, save_dir) server.draw_curve() + if __name__ == '__main__': train() - - - - diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..ba4336e --- /dev/null +++ b/run.sh @@ -0,0 +1,13 @@ +mkdir -p log +now=$(date +"%Y%m%d_%H%M") + +root_dir=/mnt/lustre/$(whoami) +project_dir=$root_dir/projects/FedReID +data_dir=$root_dir/fedreid_data + +export PYTHONPATH=$PYTHONPATH:${project_dir} + +srun -u --partition=innova --job-name=FedReID300 \ + -n8 --gres=gpu:8 --ntasks-per-node=8 \ + python ${project_dir}/main.py --data_dir $data_dir --rounds 300 \ + --datasets "Market-1501,DukeMTMC-reID,cuhk03,cuhk01,MSMT17,viper,prid,3dpes,ilids" 2>&1 | tee log/fedreid_${now}.log & \ No newline at end of file diff --git a/server.py b/server.py index 2cd46e9..86d7a39 100644 --- a/server.py +++ b/server.py @@ -1,16 +1,16 @@ -import os +import copy import math -import json +import os +import random + import matplotlib.pyplot as plt -from utils import get_model, extract_feature -import torch.nn as nn -import torch import scipy.io -import copy -from data_utils import ImageDataset -import random +import torch +import torch.nn as nn import torch.optim as optim -from torchvision import datasets + +from utils import get_model, extract_feature + def add_model(dst_model, src_model, dst_no_data, src_no_data): if dst_model is None: @@ -22,9 +22,10 @@ def add_model(dst_model, src_model, dst_no_data, src_no_data): with torch.no_grad(): for name1, param1 in params1: if name1 in dict_params2: - dict_params2[name1].set_(param1.data*src_no_data + dict_params2[name1].data*dst_no_data) + dict_params2[name1].set_(param1.data * src_no_data + dict_params2[name1].data * dst_no_data) return dst_model + def scale_model(model, scale): params = model.named_parameters() dict_params = dict(params) @@ -33,6 +34,7 @@ def scale_model(model, scale): dict_params[name].set_(dict_params[name].data * scale) return model + def aggregate_models(models, weights): """aggregate models based on weights params: @@ -45,13 +47,23 @@ def aggregate_models(models, weights): total_no_data = weights[0] for i in range(1, len(models)): model = add_model(model, models[i], total_no_data, weights[i]) - model = scale_model(model, 1.0 / (total_no_data+weights[i])) + model = scale_model(model, 1.0 / (total_no_data + weights[i])) total_no_data = total_no_data + weights[i] return model -class Server(): - def __init__(self, clients, data, device, project_dir, model_name, num_of_clients, lr, drop_rate, stride, multiple_scale): +class Server: + def __init__(self, + clients, + data, + device, + project_dir, + model_name, + num_of_clients, + lr, + drop_rate, + stride, + multiple_scale): self.project_dir = project_dir self.data = data self.device = device @@ -68,13 +80,14 @@ def __init__(self, clients, data, device, project_dir, model_name, num_of_client for s in multiple_scale.split(','): self.multiple_scale.append(math.sqrt(float(s))) - self.full_model = get_model(750, drop_rate, stride).to(device) + self.full_model = get_model(750, + drop_rate, + stride).to(device) self.full_model.classifier.classifier = nn.Sequential() - self.federated_model=self.full_model + self.federated_model = self.full_model self.federated_model.eval() self.train_loss = [] - def train(self, epoch, cdw, use_cuda): models = [] loss = [] @@ -88,8 +101,8 @@ def train(self, epoch, cdw, use_cuda): models.append(self.clients[i].get_model()) data_sizes.append(self.clients[i].get_data_sizes()) - if epoch==0: - self.L0 = torch.Tensor(loss) + if epoch == 0: + self.L0 = torch.Tensor(loss) avg_loss = sum(loss) / self.num_of_clients @@ -97,11 +110,11 @@ def train(self, epoch, cdw, use_cuda): print("number of clients used:", len(models)) print('Train Epoch: {}, AVG Train Loss among clients of lost epoch: {:.6f}'.format(epoch, avg_loss)) print() - + self.train_loss.append(avg_loss) - + weights = data_sizes - + if cdw: print("cos distance weights:", cos_distance_weights) weights = cos_distance_weights @@ -118,21 +131,23 @@ def draw_curve(self): os.mkdir(dir_name) plt.savefig(os.path.join(dir_name, 'train.png')) plt.close('all') - - def test(self, use_cuda): - print("="*10) + + def test(self, use_cuda, save_path): + print("=" * 10) print("Start Tesing!") - print("="*10) - print('We use the scale: %s'%self.multiple_scale) - + print("=" * 10) + print('We use the scale: %s' % self.multiple_scale) + for dataset in self.data.datasets: self.federated_model = self.federated_model.eval() if use_cuda: self.federated_model = self.federated_model.cuda() - + with torch.no_grad(): - gallery_feature = extract_feature(self.federated_model, self.data.test_loaders[dataset]['gallery'], self.multiple_scale) - query_feature = extract_feature(self.federated_model, self.data.test_loaders[dataset]['query'], self.multiple_scale) + gallery_feature = extract_feature(self.federated_model, self.data.test_loaders[dataset]['gallery'], + self.multiple_scale) + query_feature = extract_feature(self.federated_model, self.data.test_loaders[dataset]['query'], + self.multiple_scale) result = { 'gallery_f': gallery_feature.numpy(), @@ -142,36 +157,30 @@ def test(self, use_cuda): 'query_label': self.data.query_meta[dataset]['labels'], 'query_cam': self.data.query_meta[dataset]['cameras']} - scipy.io.savemat(os.path.join(self.project_dir, - 'model', - self.model_name, - 'pytorch_result.mat'), - result) - - print(self.model_name) - print(dataset) + scipy.io.savemat(os.path.join(save_path, 'pytorch_result.mat'), result) - os.system('python evaluate.py --result_dir {} --dataset {}'.format(os.path.join(self.project_dir, 'model', self.model_name), dataset)) + os.system('python evaluate.py --result_dir {} --dataset {}'.format(save_path, dataset)) def knowledge_distillation(self, regularization): MSEloss = nn.MSELoss().to(self.device) - optimizer = optim.SGD(self.federated_model.parameters(), lr=self.lr*0.01, weight_decay=5e-4, momentum=0.9, nesterov=True) + optimizer = optim.SGD(self.federated_model.parameters(), lr=self.lr * 0.01, weight_decay=5e-4, momentum=0.9, + nesterov=True) self.federated_model.train() - for _, (x, target) in enumerate(self.data.kd_loader): + for _, (x, target) in enumerate(self.data.kd_loader): x, target = x.to(self.device), target.to(self.device) # target=target.long() optimizer.zero_grad() - soft_target = torch.Tensor([[0]*512]*len(x)).to(self.device) - + soft_target = torch.Tensor([[0] * 512] * len(x)).to(self.device) + for i in self.client_list: i_label = (self.clients[i].generate_soft_label(x, regularization)) soft_target += i_label soft_target /= len(self.client_list) - + output = self.federated_model(x) - + loss = MSEloss(output, soft_target) loss.backward() optimizer.step() - print("train_loss_fine_tuning", loss.data) \ No newline at end of file + print("train_loss_fine_tuning", loss.data)