-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathpytorch_redis.py
66 lines (50 loc) · 1.64 KB
/
pytorch_redis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Shows how to store and load data from redis using a PyTorch
Dataset and DataLoader (with multiple workers).
@author: ptrblck
"""
import redis
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
# Create random data and push to redis
r = redis.Redis(host='localhost', port=6379, db=0)
nb_images = 100
for idx in range(nb_images):
    # Every sample is stored under its integer index as one flat buffer:
    # the 3 * 24 * 24 pixel values followed by a single class label.
    # Pixels use the same integer dtype as the label so the reader can decode
    # the whole buffer with one np.frombuffer call.
    # The dtype is spelled np.int64 rather than np.long: np.long is missing
    # from some NumPy releases and its width depends on the platform's C long,
    # which would make the stored layout platform-dependent.
    data = np.random.randint(0, 256, (3, 24, 24), dtype=np.int64).tobytes()
    target = np.random.randint(0, 10, (1,), dtype=np.int64).tobytes()
    r.set(idx, data + target)
# Create RedisDataset
class RedisDataset(Dataset):
    """Map-style dataset whose samples are fetched from Redis by index.

    Each value stored in Redis is decoded as a flat buffer of int64 numbers:
    all elements but the last are the image pixels, the last one is the
    target label.

    Args:
        redis_host: Host name passed to ``redis.Redis``.
        redis_port: Port passed to ``redis.Redis``.
        redis_db: Database number passed to ``redis.Redis``.
        length: Value reported by ``len(dataset)``.
        transform: Optional callable applied to the ``uint8`` image array
            before it is returned.
        image_shape: Shape the pixel values are reshaped to. Defaults to
            ``(3, 24, 24)``.
    """

    # Fixed-width dtype of the stored buffers. np.long is deliberately
    # avoided: it is absent from some NumPy releases and its width follows
    # the platform's C long.
    _WIRE_DTYPE = np.int64

    def __init__(self,
                 redis_host='localhost',
                 redis_port=6379,
                 redis_db=0,
                 length=0,
                 transform=None,
                 image_shape=(3, 24, 24)):
        self.db = redis.Redis(host=redis_host, port=redis_port, db=redis_db)
        self.length = length
        self.transform = transform
        self.image_shape = tuple(image_shape)

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored under ``index``.

        Returns:
            A tuple ``(x, y)`` where ``x`` is the ``uint8`` image of shape
            ``image_shape`` (or whatever ``transform`` returns for it) and
            ``y`` is a 0-dim ``torch.int64`` tensor holding the label.

        Raises:
            KeyError: If nothing is stored under ``index``.
        """
        raw = self.db.get(index)
        if raw is None:
            # Fail with the offending key instead of an opaque error from
            # np.frombuffer(None).
            raise KeyError(f"no sample stored in Redis under key {index!r}")

        values = np.frombuffer(raw, dtype=self._WIRE_DTYPE)
        # astype() copies, so the image is writable even though the
        # frombuffer view over the bytes object is read-only.
        x = values[:-1].reshape(self.image_shape).astype(np.uint8)
        y = torch.tensor(values[-1]).long()

        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        """Return the number of samples given at construction time."""
        return self.length
# Load samples from redis using multiprocessing
def chw_to_hwc(image):
    """Reorder a (C, H, W) array to (H, W, C).

    ``transforms.ToTensor`` interprets a 3-D ndarray as H x W x C and moves
    the last axis to the front. Feeding it the channels-first array produced
    by the dataset directly would therefore yield a (24, 3, 24) tensor.
    Defined at module level so worker processes can pickle it.
    """
    return image.transpose(1, 2, 0)


# The guard keeps worker processes that re-import this module (spawn start
# method, e.g. on Windows and macOS) from starting their own DataLoader.
if __name__ == "__main__":
    dataset = RedisDataset(
        length=nb_images,
        transform=transforms.Compose([
            transforms.Lambda(chw_to_hwc),
            # Converts back to (C, H, W) and scales uint8 values to [0, 1].
            transforms.ToTensor(),
        ]),
    )
    loader = DataLoader(
        dataset,
        batch_size=10,
        num_workers=2,
        shuffle=True,
    )

    for data, target in loader:
        print(data.shape)
        print(target.shape)