Skip to content

Commit

Permalink
fix (ctgan): add discriminator to model attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
MazenAli committed Nov 21, 2024
1 parent c0ea824 commit fc0b3e3
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def __init__(
self._transformer = None
self._data_sampler = None
self._generator = None
self._discriminator = None
self.loss_values = None

@staticmethod
Expand Down Expand Up @@ -330,7 +331,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim
).to(self._device)

discriminator = Discriminator(
self._discriminator = Discriminator(
data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac
).to(self._device)

Expand All @@ -342,7 +343,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
)

optimizerD = optim.Adam(
discriminator.parameters(),
self._discriminator.parameters(),
lr=self._discriminator_lr,
betas=(0.5, 0.9),
weight_decay=self._discriminator_decay,
Expand Down Expand Up @@ -395,10 +396,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
real_cat = real
fake_cat = fakeact

y_fake = discriminator(fake_cat)
y_real = discriminator(real_cat)
y_fake = self._discriminator(fake_cat)
y_real = self._discriminator(real_cat)

pen = discriminator.calc_gradient_penalty(
pen = self._discriminator.calc_gradient_penalty(
real_cat, fake_cat, self._device, self.pac
)
loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
Expand All @@ -423,9 +424,9 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
fakeact = self._apply_activate(fake)

if c1 is not None:
y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
y_fake = self._discriminator(torch.cat([fakeact, c1], dim=1))
else:
y_fake = discriminator(fakeact)
y_fake = self._discriminator(fakeact)

if condvec is None:
cross_entropy = 0
Expand Down Expand Up @@ -520,3 +521,5 @@ def set_device(self, device):
self._device = device
if self._generator is not None:
self._generator.to(self._device)
if self._discriminator is not None:
self._discriminator.to(self._device)

0 comments on commit fc0b3e3

Please sign in to comment.