This repository has been archived by the owner on Dec 4, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_retriever.py
106 lines (92 loc) · 2.86 KB
/
train_retriever.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from pytorch_lightning import seed_everything
from config import EXPERIMENT_ROOT, PROJECT_NAME, args, set_template
from dataloader import dataloader_factory
from model import BERT, NARM, LRURec, SASRec
from trainer import BERTTrainer, LRUTrainer, RNNTrainer, SASTrainer
# Export the wandb project name from config so trainer wandb runs are grouped
# under it. Fail soft (with a hint) if PROJECT_NAME is unusable, since wandb
# logging is optional (gated by args.use_wandb below).
try:
    os.environ["WANDB_PROJECT"] = PROJECT_NAME
except Exception:  # narrowed from bare `except:` so SystemExit/KeyboardInterrupt still propagate
    print("WANDB_PROJECT not available, please set it in config.py")
def main(args, export_root=None):
    """Train a retriever model, evaluate it, and dump reranking candidates.

    Args:
        args: parsed experiment config (model_code, dataset_code, seed, ...).
        export_root: directory for checkpoints/artifacts; defaults to
            EXPERIMENT_ROOT/<model_code>/<dataset_code>.

    Raises:
        ValueError: if args.model_code is not one of the supported models.
    """
    seed_everything(args.seed)
    train_loader, val_loader, test_loader = dataloader_factory(args)

    # Single dispatch table instead of two parallel if/elif chains; also turns
    # the silent UnboundLocalError on an unknown model_code into a clear error.
    factories = {
        "lru": (LRURec, LRUTrainer),
        "bert": (BERT, BERTTrainer),
        "sas": (SASRec, SASTrainer),
        "narm": (NARM, RNNTrainer),
    }
    if args.model_code not in factories:
        raise ValueError(f"Unknown model_code: {args.model_code!r}")
    model_cls, trainer_cls = factories[args.model_code]
    model = model_cls(args)

    if export_root is None:  # `is None`, not `== None`
        export_root = os.path.join(EXPERIMENT_ROOT, args.model_code, args.dataset_code)

    if args.model_code == "narm":
        # NARM uses a fixed epoch budget (preserved from the original branch).
        args.num_epochs = 100

    trainer = trainer_cls(
        args,
        model,
        train_loader,
        val_loader,
        test_loader,
        export_root,
        args.use_wandb,
    )
    trainer.train()
    trainer.test()
    # Generate val/test candidates for the downstream reranking stage.
    trainer.generate_candidates(os.path.join(export_root, "retrieved.pkl"))
if __name__ == "__main__":
    set_template(args)
    if not args.hyperparam_search:
        # Single run with the default export location.
        main(args, export_root=None)
    else:
        # Grid-search weight decay x dropout; each combination gets its own
        # export directory named "<decay>_<dropout>".
        decay_grid = [0, 0.01]
        dropout_grid = [0, 0.1, 0.2, 0.3, 0.4, 0.5]
        for decay in decay_grid:
            for dropout in dropout_grid:
                args.weight_decay = decay
                args.bert_dropout = dropout
                args.bert_attn_dropout = dropout
                run_root = (
                    f"{EXPERIMENT_ROOT}/{args.model_code}/{args.dataset_code}"
                    f"/{decay}_{dropout}"
                )
                main(args, export_root=run_root)