train.py
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from networks.Vit_new import VisionTransformer as conv_ViT_seg
from networks.mlvit import MLViTSeg
from networks.Seg import MLPmedical
from trainer import trainer_synapse
from send_email import sendEmailText
# Pin training to a specific GPU card
# torch.cuda.set_device(0)
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
                    default='../data/Synapse/train_npz', help='root dir for data')
parser.add_argument('--dataset', type=str,
                    default='Synapse', help='experiment name')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int,
                    default=9, help='output channel of network')
parser.add_argument('--max_iterations', type=int,
                    default=30000, help='maximum number of training iterations')
parser.add_argument('--max_epochs', type=int,
                    default=500, help='maximum number of training epochs')
parser.add_argument('--batch_size', type=int,
                    default=5, help='batch size per GPU')
parser.add_argument('--n_gpu', type=int, default=1, help='total number of GPUs')
parser.add_argument('--deterministic', type=int, default=1,
                    help='whether to use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.0003,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=384, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')
parser.add_argument('--n_skip', type=int,
                    default=3, help='number of skip connections to use')
parser.add_argument('--vit_name', type=str,
                    default='R50-ViT-B_16', help='select one ViT model')
parser.add_argument('--vit_patches_size', type=int,
                    default=16, help='ViT patch size, default is 16')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
parser.add_argument('--local_rank', default=1, type=int, help='local rank')
args = parser.parse_args()
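# Example invocation (the values shown are just the defaults declared above):
#   python train.py --dataset Synapse --vit_name R50-ViT-B_16 --max_epochs 500 --batch_size 5 --base_lr 0.0003 --img_size 384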
if __name__ == "__main__":
    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
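    # Seed the Python, NumPy and (CUDA) PyTorch RNGs; together with
    # cudnn.deterministic=True above, runs with the same seed should be
    # reproducible on the same hardware/software stack.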
    dataset_name = args.dataset
    dataset_config = {
        'Synapse': {
            'root_path': '/home/kkk/medical/pop2/data/train_npz',
            'list_dir': '/home/kkk/medical/snp/project_TransUNet/TransUNet/lists/lists_Synapse',
            'num_classes': 5,
        },
    }
    # if args.batch_size != 24 and args.batch_size % args.batch_size == 0:
    #     args.base_lr *= args.batch_size / 24
    args.num_classes = dataset_config[dataset_name]['num_classes']
    args.root_path = dataset_config[dataset_name]['root_path']
    args.list_dir = dataset_config[dataset_name]['list_dir']
    args.is_pretrain = True
    args.exp = 'TU_' + dataset_name + str(args.img_size)
    snapshot_path = "/home/kkk/medical/model/{}/{}".format(args.exp, 'TU')
    snapshot_path = snapshot_path + '_pretrain' if args.is_pretrain else snapshot_path
    snapshot_path += '_' + args.vit_name + 'POP_HS-MLPv2_without pretrain'
    snapshot_path = snapshot_path + '_skip' + str(args.n_skip)
    snapshot_path = snapshot_path + '_vitpatch' + str(args.vit_patches_size) if args.vit_patches_size != 16 else snapshot_path
    snapshot_path = snapshot_path + '_' + str(args.max_iterations)[0:2] + 'k' if args.max_iterations != 30000 else snapshot_path
    snapshot_path = snapshot_path + '_epo' + str(args.max_epochs) if args.max_epochs != 30 else snapshot_path
    snapshot_path = snapshot_path + '_bs' + str(args.batch_size)
    snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.01 else snapshot_path
    snapshot_path = snapshot_path + '_' + str(args.img_size)
    snapshot_path = snapshot_path + '_s' + str(args.seed) if args.seed != 1234 else snapshot_path
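    # The per-dataset settings above override the CLI defaults, and the suffixes
    # encode the run configuration; with the defaults this resolves to roughly
    # '/home/kkk/medical/model/TU_Synapse384/TU_pretrain_R50-ViT-B_16POP_HS-MLPv2_without pretrain_skip3_epo500_bs5_lr0.0003_384'.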
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    config_vit = CONFIGS_ViT_seg[args.vit_name]
    config_vit.n_classes = args.num_classes
    config_vit.n_skip = args.n_skip
    if args.vit_name.find('R50') != -1:
        config_vit.patches.grid = (int(args.img_size / args.vit_patches_size), int(args.img_size / args.vit_patches_size))
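        # For the R50 hybrid backbone the token grid is img_size / patch_size per
        # side, i.e. 384 / 16 = 24, giving a 24x24 grid with the defaults above.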
    # net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()
    # net = MLViTSeg(config_vit).cuda()
    net = MLPmedical(config_vit).cuda()
    pretrain = True
    if pretrain:
        pretrain_path = '/home/kkk/medical/model/TU_Synapse384/TU_pretrain_R50-ViT-B_16_HS-MLPv2_without pretrain_skip3_epo500_bs5_lr0.0003_384/epoch_499.pth'
        pretrained_dict = torch.load(pretrain_path)
        model_dict = net.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        net.load_state_dict(model_dict)
        print('loaded pretrained model')
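    # Partial weight loading: only checkpoint entries whose keys also exist in the
    # current model are copied in; all other parameters keep their fresh
    # initialization (a matching key with a different shape would still raise).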
    # net = conv_ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()
    # net.load_from(weights=np.load(config_vit.pretrained_path))
    trainer = {'Synapse': trainer_synapse}
    trainer[dataset_name](args, net, snapshot_path)
    # sendEmailText('code done', '[email protected]', 'sfjlmvqiktdfbggd', '[email protected]', 'Hi!\nYour training code is finished.\n', 'smtp.qq.com')