Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor #8

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import copy
import time

import torch
from utils import get_optimizer, get_model
import torch.nn as nn
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.autograd import Variable
import copy
from torch.optim import lr_scheduler

from optimization import Optimization
class Client():
from utils import get_optimizer, get_model


class Client:
def __init__(self, cid, data, device, project_dir, model_name, local_epoch, lr, batch_size, drop_rate, stride):
self.cid = cid
self.project_dir = project_dir
Expand All @@ -17,15 +20,15 @@ def __init__(self, cid, data, device, project_dir, model_name, local_epoch, lr,
self.local_epoch = local_epoch
self.lr = lr
self.batch_size = batch_size

self.dataset_sizes = self.data.train_dataset_sizes[cid]
self.train_loader = self.data.train_loaders[cid]

self.full_model = get_model(self.data.train_class_sizes[cid], drop_rate, stride)
self.classifier = self.full_model.classifier.classifier
self.full_model.classifier.classifier = nn.Sequential()
self.model = self.full_model
self.distance=0
self.distance = 0
self.optimization = Optimization(self.train_loader, self.device)
# print("class name size",class_names_size[cid])

Expand All @@ -40,21 +43,20 @@ def train(self, federated_model, use_cuda):

optimizer = get_optimizer(self.model, self.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)

criterion = nn.CrossEntropyLoss()

since = time.time()

print('Client', self.cid, 'start training')
for epoch in range(self.local_epoch):
print('Epoch {}/{}'.format(epoch, self.local_epoch - 1))
print('Local Epoch of {}: {}/{}'.format(self.cid, epoch, self.local_epoch - 1))
print('-' * 10)

scheduler.step()
self.model.train(True)
running_loss = 0.0
running_corrects = 0.0

for data in self.train_loader:
inputs, labels = data
b, c, h, w = inputs.shape
Expand All @@ -65,7 +67,7 @@ def train(self, federated_model, use_cuda):
labels = Variable(labels.cuda().detach())
else:
inputs, labels = Variable(inputs), Variable(labels)

optimizer.zero_grad()

outputs = self.model(inputs)
Expand All @@ -86,7 +88,7 @@ def train(self, federated_model, use_cuda):
'train', epoch_loss, epoch_acc))

self.y_loss.append(epoch_loss)
self.y_err.append(1.0-epoch_acc)
self.y_err.append(1.0 - epoch_acc)

time_elapsed = time.time() - since
print('Client', self.cid, ' Training complete in {:.0f}m {:.0f}s'.format(
Expand All @@ -97,12 +99,11 @@ def train(self, federated_model, use_cuda):
time_elapsed // 60, time_elapsed % 60))

# save_network(self.model, self.cid, 'last', self.project_dir, self.model_name, gpu_ids)

self.classifier = self.model.classifier.classifier
self.distance = self.optimization.cdw_feature_distance(federated_model, self.old_classifier, self.model)
self.model.classifier.classifier = nn.Sequential()


def generate_soft_label(self, x, regularization):
    """Produce knowledge-distillation soft labels for batch ``x``.

    Delegates to the Optimization helper, running the client's local
    (headless) model over ``x``; ``regularization`` is forwarded unchanged.
    """
    kd_helper = self.optimization
    return kd_helper.kd_generate_soft_label(self.model, x, regularization)

Expand All @@ -116,4 +117,4 @@ def get_train_loss(self):
return self.y_loss[-1]

def get_cos_distance_weight(self):
    """Return the cosine-distance weight for this client.

    ``self.distance`` is 0 until train() has run; afterwards it holds the
    cosine-distance-weight value computed by the Optimization helper.
    """
    # Diff artifact fixed: the pasted diff carried both the old and new
    # copies of this return line; only one is kept.
    return self.distance
108 changes: 54 additions & 54 deletions data_utils.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,76 @@
from torch.utils.data import Dataset
from PIL import Image
from torchvision import datasets, transforms
import os
import json

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets, transforms

from random_erasing import RandomErasing


class ImageDataset(Dataset):
    """Thin ``Dataset`` wrapper over an ImageFolder-style image list.

    Each entry of ``imgs`` is a ``(path, label)`` pair; images are opened
    lazily in ``__getitem__`` and passed through ``transform``.
    """

    def __init__(self, imgs, transform=None):
        # NOTE(review): __getitem__ calls self.transform unconditionally,
        # so leaving transform=None will raise at access time — callers in
        # this project always supply a transform pipeline.
        self.imgs = imgs
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        # Diff artifact fixed: the pasted diff contained both the old
        # (`data,label`) and new (`data, label`) unpack lines; only the
        # post-change line is kept.
        data, label = self.imgs[index]
        return self.transform(Image.open(data)), label


class Data():
class Data:
def __init__(self, datasets, data_dir, batch_size, erasing_p, color_jitter, train_all):
self.datasets = datasets.split(',')
self.batch_size = batch_size
self.erasing_p = erasing_p
self.color_jitter = color_jitter
self.data_dir = data_dir
self.train_all = '_all' if train_all else ''

self.data_transforms = {}
self.train_loaders = {}
self.train_dataset_sizes = {}
self.train_class_sizes = {}
self.client_list = []
self.test_loaders = {}
self.gallery_meta = {}
self.query_meta = {}
self.kd_loader = None

def transform(self):
    """Build the train/val torchvision pipelines into ``self.data_transforms``.

    Reads ``self.erasing_p`` (random-erasing probability, applied when > 0)
    and ``self.color_jitter`` (bool) to optionally extend the training
    pipeline. Diff artifact fixed: the pasted diff carried both the old and
    new copies of each transform list; only the post-change version is kept.
    """
    transform_train = [
        transforms.Resize((256, 128), interpolation=3),
        transforms.Pad(10),
        transforms.RandomCrop((256, 128)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]

    transform_val = [
        transforms.Resize(size=(256, 128), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]

    # Random erasing is appended so it runs after ToTensor (on the tensor).
    if self.erasing_p > 0:
        transform_train = transform_train + [RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])]

    # Color jitter is prepended so it runs before ToTensor (on the PIL image).
    if self.color_jitter:
        transform_train = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)] \
            + transform_train

    self.data_transforms = {
        'train': transforms.Compose(transform_train),
        'val': transforms.Compose(transform_val),
    }

def preprocess_kd_data(self, dataset):
    """Build and cache the data loader used for knowledge distillation."""
    # Only the loader is kept; the underlying ImageFolder is discarded.
    kd_loader, _image_dataset = self.preprocess_one_train_dataset(dataset)
    self.kd_loader = kd_loader


def preprocess_one_train_dataset(self, dataset):
"""preprocess a training dataset, construct a data loader.
"""
Expand All @@ -68,97 +79,86 @@ def preprocess_one_train_dataset(self, dataset):
image_dataset = datasets.ImageFolder(data_path)

loader = torch.utils.data.DataLoader(
ImageDataset(image_dataset.imgs, self.data_transforms['train']),
ImageDataset(image_dataset.imgs, self.data_transforms['train']),
batch_size=self.batch_size,
shuffle=True,
num_workers=2,
shuffle=True,
num_workers=2,
pin_memory=False)

return loader, image_dataset

def preprocess_train(self):
    """Preprocess training data, constructing one train loader per dataset.

    Fills ``self.client_list``, ``self.train_loaders``,
    ``self.train_dataset_sizes`` and ``self.train_class_sizes``, keyed by
    dataset name.
    """
    self.train_loaders = {}
    self.train_dataset_sizes = {}
    self.train_class_sizes = {}
    self.client_list = []

    for name in self.datasets:
        self.client_list.append(name)
        loader, folder = self.preprocess_one_train_dataset(name)
        self.train_dataset_sizes[name] = len(folder)
        self.train_class_sizes[name] = len(folder.classes)
        self.train_loaders[name] = loader

    print('Train dataset sizes:', self.train_dataset_sizes)
    print('Train class sizes:', self.train_class_sizes)

def preprocess_test(self):
    """Preprocess testing data, constructing gallery/query loaders and metadata.

    For every dataset in ``self.datasets`` this fills:
      * ``self.test_loaders[dataset]`` -- {'gallery': DataLoader, 'query': DataLoader}
      * ``self.gallery_meta[dataset]`` -- sizes / camera ids / labels
      * ``self.query_meta[dataset]``   -- sizes / camera ids / labels

    Diff artifact fixed: the pasted diff carried both the old
    (``'data/'+test_dir``) and new (``os.path.join(self.data_dir, ...)``)
    path-building lines plus duplicated dict entries; only the post-change
    version is kept. The dict-comprehension variable is also renamed so it
    no longer shadows the outer ``dataset`` loop variable.
    """
    self.test_loaders = {}
    self.gallery_meta = {}
    self.query_meta = {}

    for dataset in self.datasets:
        test_dir = os.path.join(self.data_dir, dataset, 'pytorch')
        gallery_dataset = datasets.ImageFolder(os.path.join(test_dir, 'gallery'))
        query_dataset = datasets.ImageFolder(os.path.join(test_dir, 'query'))

        # Re-wrap the ImageFolder file lists with the 'val' transform pipeline.
        gallery_dataset = ImageDataset(gallery_dataset.imgs, self.data_transforms['val'])
        query_dataset = ImageDataset(query_dataset.imgs, self.data_transforms['val'])

        self.test_loaders[dataset] = {
            split: torch.utils.data.DataLoader(
                ds,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=8,
                pin_memory=True)
            for split, ds in {'gallery': gallery_dataset, 'query': query_dataset}.items()}

        gallery_cameras, gallery_labels = get_camera_ids(gallery_dataset.imgs)
        self.gallery_meta[dataset] = {
            'sizes': len(gallery_dataset),
            'cameras': gallery_cameras,
            'labels': gallery_labels
        }

        query_cameras, query_labels = get_camera_ids(query_dataset.imgs)
        self.query_meta[dataset] = {
            'sizes': len(query_dataset),
            'cameras': query_cameras,
            'labels': query_labels
        }

        print('Query Sizes:', self.query_meta[dataset]['sizes'])
        print('Gallery Sizes:', self.gallery_meta[dataset]['sizes'])

def preprocess(self):
    """Run the full preprocessing pipeline in order.

    Builds the transform pipelines, then the train and test loaders, and
    finally the knowledge-distillation loader over the 'cuhk02' dataset.
    """
    for stage in (self.transform, self.preprocess_train, self.preprocess_test):
        stage()
    self.preprocess_kd_data('cuhk02')


def get_camera_ids(img_paths):
"""get camera id and labels by image path
"""
camera_ids = []
labels = []
for path, v in img_paths:
filename = os.path.basename(path)
if filename[:3]!='cam':
if filename[:3] != 'cam':
label = filename[0:4]
camera = filename.split('c')[1]
camera = camera.split('s')[0]
else:
label = filename.split('_')[2]
camera = filename.split('_')[1]
if label[0:2]=='-1':
if label[0:2] == '-1':
labels.append(-1)
else:
labels.append(int(label))
Expand Down
Loading