forked from danieltan07/dagmm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path data_loader.py
77 lines (59 loc) · 2.1 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import os
import random
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
import h5py
import numpy as np
import collections
import numbers
import math
import pandas as pd
class KDD99Loader(object):
    """Map-style dataset over a preprocessed KDD Cup '99 ``.npz`` archive.

    The archive at ``data_path`` must hold an array under the key ``"kdd"``
    whose last column is the label and whose remaining columns are features.

    Rows labelled ``0`` are shuffled and split in half. The first half is the
    training split. The second half, followed by every row labelled ``1``, is
    the test split.

    NOTE(review): this appears to follow the DAGMM protocol, in which the
    label-0 class is modelled as "normal" and label-1 rows are scored as
    anomalies. Confirm the label encoding against the preprocessing script
    that produced the archive.

    Args:
        data_path: Path to the ``.npz`` archive.
        mode: ``"train"`` selects the training split. Any other value selects
            the test split. The attribute may be reassigned after construction
            to switch splits on the same instance.
        seed: Optional integer seed for the train/test split.
            - ``None`` (default) draws from NumPy's global RNG, exactly as
              before.
            - An int makes separately constructed instances (e.g. one for
              training, one for testing) share the same split, so the test
              split cannot contain training rows.
    """

    def __init__(self, data_path, mode="train", seed=None):
        self.mode = mode

        # Read the array once, then close the archive. Every NpzFile[key]
        # access re-reads the member from disk. Without closing, the file
        # handle would stay open until garbage collection.
        with np.load(data_path) as archive:
            kdd = archive["kdd"]

        labels = kdd[:, -1]
        features = kdd[:, :-1]

        # Label-1 rows go only into the test split.
        label1_mask = labels == 1
        normal_data = features[label1_mask]
        normal_labels = labels[label1_mask]

        # Label-0 rows are divided between the train and test splits.
        label0_mask = labels == 0
        attack_data = features[label0_mask]
        attack_labels = labels[label0_mask]
        n_attack = attack_data.shape[0]

        # permutation(n) is arange(n) shuffled by the same RNG, consuming the
        # same random draw. So seed=None reproduces the original
        # global-RNG split.
        rng = np.random if seed is None else np.random.RandomState(seed)
        rand_idx = rng.permutation(n_attack)
        n_train = n_attack // 2
        train_idx = rand_idx[:n_train]
        test_idx = rand_idx[n_train:]

        self.train = attack_data[train_idx]
        self.train_labels = attack_labels[train_idx]

        # Test split: held-out label-0 rows first, then all label-1 rows.
        self.test = np.concatenate((attack_data[test_idx], normal_data), axis=0)
        self.test_labels = np.concatenate(
            (attack_labels[test_idx], normal_labels), axis=0
        )

    def _current_split(self):
        """Return ``(features, labels)`` for the split chosen by ``self.mode``."""
        if self.mode == "train":
            return self.train, self.train_labels
        return self.test, self.test_labels

    def __len__(self):
        """Return the number of records in the split chosen by ``self.mode``."""
        data, _ = self._current_split()
        return data.shape[0]

    def __getitem__(self, index):
        """Return ``(features, label)`` for one record, both cast to float32."""
        data, labels = self._current_split()
        return np.float32(data[index]), np.float32(labels[index])
def get_loader(data_path, batch_size, mode='train'):
    """Build a DataLoader over the KDD99 split selected by ``mode``.

    Args:
        data_path: Path forwarded to ``KDD99Loader``.
        batch_size: Number of records per batch.
        mode: ``'train'`` enables per-epoch shuffling. Any other value keeps
            the dataset's stored order.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping a fresh ``KDD99Loader``.
    """
    # Shuffle only when training; evaluation keeps the stored order.
    return DataLoader(
        dataset=KDD99Loader(data_path, mode),
        batch_size=batch_size,
        shuffle=(mode == 'train'),
    )