-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimagehistorybuffer.py
57 lines (50 loc) · 2.46 KB
/
imagehistorybuffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Adapted from https://github.com/mjdietzx/SimGAN/, under MIT license
import numpy as np
import torch
class ImageHistoryBuffer(object):
    """Bounded pool of previously generated images, stored in one tensor.

    Dimension 0 of ``image_history_buffer`` indexes images. The pool is
    reshuffled after every insertion, so a prefix of it is a random subset of
    the stored images. The buffer never grows beyond ``max_size`` images.
    """

    def __init__(self, shape, max_size, batch_size, device):
        """
        Initialize the class's state.
        :param shape: Shape of the initially empty buffer tensor; its first dimension should be 0
            (e.g. (0, img_height, img_width, img_channels) -- use whatever per-image layout the caller uses).
        :param max_size: Maximum number of images that can be stored in the image history buffer.
        :param batch_size: Batch size used to train GAN; batch_size // 2 is the default add/get count.
        :param device: torch device on which the buffer is allocated.
        """
        self.image_history_buffer = torch.zeros(shape, device=device)
        self.max_size = max_size
        self.batch_size = batch_size

    def add_to_image_history_buffer(self, images, nb_to_add=None):
        """
        To be called during training of GAN. By default add batch_size // 2 images to the image history buffer each
        time the generator generates a new batch of images.

        While there is free space the new images are appended; once the buffer is full, the oldest-positioned
        entries (a random subset, since the buffer is shuffled after every insertion) are evicted so the buffer
        never holds more than `max_size` images. The buffer is reshuffled afterwards.
        :param images: Tensor of images (usually a batch) to be added to the image history buffer.
        :param nb_to_add: The number of images from `images` to add to the image history buffer
            (batch_size // 2 by default; an explicit 0 adds nothing).
        """
        if nb_to_add is None:
            nb_to_add = self.batch_size // 2
        buffer = self.image_history_buffer
        # Store plain data: detaching avoids keeping the generator's autograd graph alive across
        # iterations, and matching device/dtype keeps torch.cat valid and the buffer's dtype stable.
        new_images = images[:nb_to_add].detach().to(device=buffer.device, dtype=buffer.dtype)
        combined = torch.cat((buffer, new_images), dim=0)
        # Evict from the front down to max_size. New images sit at the end, so they are kept
        # (unless nb_to_add alone exceeds max_size); the evicted old images are a random subset
        # because the buffer was shuffled after the previous insertion.
        overflow = combined.size(0) - self.max_size
        if overflow > 0:
            combined = combined[overflow:]
        # Reshuffle so that taking a prefix in get_from_image_history_buffer yields a random sample.
        self.image_history_buffer = combined[torch.randperm(combined.size(0))]

    def get_from_image_history_buffer(self, nb_to_get=None):
        """
        Get a random sample of images from the history buffer.
        :param nb_to_get: Number of images to get from the image history buffer
            (batch_size // 2 by default; an explicit 0 returns no images).
        :return: The first `nb_to_get` images of the buffer (fewer if it holds fewer). Because the buffer is
            reshuffled on every insertion this is a random sample, but repeated calls without an intervening
            insertion return the same images. An empty buffer yields an empty tensor whose first dimension is 0.
            The result is a view into the buffer, which add_to_image_history_buffer never modifies in place.
        """
        if nb_to_get is None:
            nb_to_get = self.batch_size // 2
        # Slicing past the end never raises, so no special case is needed for a short or empty buffer.
        return self.image_history_buffer[:nb_to_get]