# train_graph.py
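"""Train a GNN graph classifier on processed system-call graphs (ADFA-LD by default).

Labels are binarized to 'normal' vs. 'malware', a model chosen via --model is
trained with SGD under a OneCycleLR schedule, and per-epoch metrics are
recorded through ExperimentTracker.

Example invocation (all flags are optional; shown values are the defaults, and
the dataset path is assumed to exist):

    python train_graph.py --model GIN --epochs 250 --batch_size 256
"""
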
import argparse

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.optim.lr_scheduler import OneCycleLR
from torch_geometric.loader import DataLoader

from data import CustomGraphDataset, load_dataset
from models import GNNModel
from utils import ExperimentTracker, set_random_seeds

def parse_arguments():
    parser = argparse.ArgumentParser(description="Train a GNN model on a graph dataset")
    # Dataset and experiment configuration
    parser.add_argument("--dataset_path", type=str, default='datasets/ADFA-LD/processed_graphs.pkl', help="Path to the dataset")
    parser.add_argument("--experiment_name", type=str, default=None, help="Name of the experiment (defaults to the model name)")
    # Model parameters
    parser.add_argument("--model", type=str, choices=['MLP', 'GCN', 'Sage', 'GIN', 'GAT', 'GATv2'], default='GIN', help="Type of GNN encoder model to use")
    parser.add_argument("--hidden_channels", type=int, default=512, help="Number of hidden channels in the model")
    parser.add_argument("--num_layers", type=int, default=4, help="Number of layers in the model")
    parser.add_argument("--embedding_dim", type=int, default=128, help="Dimensionality of the embedding layer")
    parser.add_argument("--dropout", type=float, default=0.6, help="Dropout rate for the model")
    parser.add_argument("--heads", type=int, default=8, help="Number of attention heads; only used by the GAT/GATv2 models")
    # Note: with type=str, the `None` entry in choices is only reachable as a
    # default, never from the command line (where it would parse as the string 'None').
    parser.add_argument("--norm", type=str, default='batch', choices=[None, 'batch', 'layer'], help="Type of normalization to use, applicable to most models")
    # Optimization parameters
    parser.add_argument("--epochs", type=int, default=250, help="Number of epochs to train")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size for training and evaluation")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="Initial learning rate")
    parser.add_argument("--max_lr", type=float, default=0.01, help="Maximum learning rate for the OneCycleLR scheduler")
    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum for the SGD optimizer")
    parser.add_argument("--weight_decay", type=float, default=1e-6, help="Weight decay for the optimizer")
    args = parser.parse_args()
    # Fall back to the model name only when no name was given, so that an
    # explicit --experiment_name is not silently overridden.
    if args.experiment_name is None:
        args.experiment_name = args.model
    return args

def load_data(dataset_path):
    graph_data, vocab_size = load_dataset(dataset_path)
    graphs = [data['graph'] for data in graph_data]
    labels = [data['label'] for data in graph_data]
    # Stratify on the original (pre-binarization) labels so every attack
    # category keeps its proportions across the train/test split
    train_graphs, test_graphs, train_labels, test_labels = train_test_split(
        graphs, labels, test_size=0.3, random_state=42, stratify=labels
    )
    # Binarize labels: everything that is not 'normal' counts as 'malware'
    train_labels = ['normal' if label == 'normal' else 'malware' for label in train_labels]
    test_labels = ['normal' if label == 'normal' else 'malware' for label in test_labels]
    # Encode labels as integers, fitting the encoder on the training split only
    label_encoder = LabelEncoder()
    train_labels = label_encoder.fit_transform(train_labels)
    test_labels = label_encoder.transform(test_labels)
    # Attach the encoded label to each graph as its target tensor
    for graph, label in zip(train_graphs, train_labels):
        graph.y = torch.tensor([label], dtype=torch.long)
    for graph, label in zip(test_graphs, test_labels):
        graph.y = torch.tensor([label], dtype=torch.long)
    train_dataset = CustomGraphDataset(train_graphs, len(label_encoder.classes_), training=True)
    test_dataset = CustomGraphDataset(test_graphs, len(label_encoder.classes_), training=False)
    return train_dataset, test_dataset, vocab_size, label_encoder

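# Assumed dataset layout, inferred from the calls above (this file does not
# define load_dataset itself): it returns (graph_data, vocab_size), where
# graph_data is a list of dicts such as
#     {'graph': <torch_geometric.data.Data>, 'label': 'normal' or an attack name}
# and vocab_size is the node-token vocabulary size used by the embedding layer.
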
def initialize_model(vocab_size, num_node_features, num_classes, num_steps_per_epoch, args, device):
    model = GNNModel(vocab_size, args.embedding_dim, num_node_features, args.hidden_channels,
                     args.num_layers, num_classes, args.dropout, 'relu',
                     model_type=args.model, heads=args.heads, norm=args.norm).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    # OneCycleLR is stepped once per batch, so it needs the number of batches per epoch
    scheduler = OneCycleLR(optimizer, max_lr=args.max_lr, pct_start=0.3,
                           steps_per_epoch=num_steps_per_epoch, epochs=args.epochs)
    criterion = torch.nn.CrossEntropyLoss()
    return model, optimizer, scheduler, criterion

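# Note on the schedule: with PyTorch's default div_factor (25), OneCycleLR
# starts at max_lr / 25 (0.01 / 25 = 4e-4 here), ramps up to max_lr over the
# first 30% of all batches (pct_start=0.3), then anneals back down for the
# rest of training. The --learning_rate set on the SGD optimizer is therefore
# effectively superseded by the scheduler.
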
def train_epoch(model, train_loader, optimizer, scheduler, criterion, device):
    model.train()
    total_loss = []
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR advances once per batch, not per epoch
        total_loss.append(loss.item())
    return np.mean(total_loss)

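# Because scheduler.step() runs once per batch, OneCycleLR is stepped exactly
# steps_per_epoch * epochs times over a full run, matching the total it was
# constructed with. drop_last=True on the training loader also avoids a tiny
# trailing batch, which matters when --norm is 'batch' (the default).
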
def evaluate_model(model, loader, device):
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)
            preds.extend(pred.cpu().numpy())
            labels.extend(data.y.cpu().numpy())
    return np.array(preds), np.array(labels)

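# Illustrative only (actual reporting lives in ExperimentTracker): the arrays
# returned above plug straight into sklearn metrics, e.g.
#
#     from sklearn.metrics import accuracy_score, f1_score
#     preds, labels = evaluate_model(model, test_loader, device)
#     print(accuracy_score(labels, preds), f1_score(labels, preds))
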
def main():
    args = parse_arguments()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Set random seeds before any data splitting or weight init for reproducibility
    set_random_seeds()
    train_dataset, test_dataset, vocab_size, label_encoder = load_data(args.dataset_path)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)
    # Extract `num_node_features` and `num_classes` to pass to `initialize_model`
    num_node_features = train_dataset[0].num_node_features
    num_classes = len(label_encoder.classes_)
    num_steps_per_epoch = len(train_loader)
    model, optimizer, scheduler, criterion = initialize_model(
        vocab_size, num_node_features, num_classes, num_steps_per_epoch, args, device
    )
    experiment_tracker = ExperimentTracker(model, optimizer, scheduler, label_encoder, args)
    for epoch in range(1, args.epochs + 1):
        loss = train_epoch(model, train_loader, optimizer, scheduler, criterion, device)
        train_preds, train_labels = evaluate_model(model, train_loader, device)
        test_preds, test_labels = evaluate_model(model, test_loader, device)
        current_lr = optimizer.param_groups[0]['lr']
        # Record this epoch's loss, metrics, and learning rate
        experiment_tracker.update_and_save(epoch, loss, train_preds, train_labels, test_preds, test_labels, current_lr)
        print(experiment_tracker)

if __name__ == "__main__":
    main()