Commit 5a6c2c8

ouioui199 committed Jan 7, 2025
1 parent f57307e commit 5a6c2c8
Showing 1 changed file with 2 additions and 196 deletions.
198 changes: 2 additions & 196 deletions examples/vit_mnist.py
@@ -1,6 +1,6 @@
# MIT License

-# Copyright (c) 2023 Jérémy Fix, Xuan-Huy Nguyen
+# Copyright (c) 2023 Jérémy Fix

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -40,15 +40,6 @@
import torchvision
import torchvision.transforms.v2 as v2_transforms

-import lightning as L
-from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
-from lightning.pytorch.loggers import TensorBoardLogger
-from lightning.pytorch.callbacks.progress import TQDMProgressBar
-from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm
-from lightning.pytorch.utilities import rank_zero_only
-
-from torchmetrics.classification import Accuracy
-
import torchcvnn.nn as c_nn
import torchcvnn.models as c_models

@@ -311,190 +302,5 @@ def train():
)


-# Pytorch Lightning code
-class TBLogger(TensorBoardLogger):
-    @rank_zero_only
-    def log_metrics(self, metrics, step):
-        metrics.pop('epoch', None)
-        metrics = {k: v for k, v in metrics.items() if ('step' not in k) and ('val' not in k)}
-        return super().log_metrics(metrics, step)
-
-
-class CustomProgressBar(TQDMProgressBar):
-
-    def get_metrics(self, trainer, model):
-        items = super().get_metrics(trainer, model)
-        items.pop("v_num", None)
-        return items
-
-    def init_train_tqdm(self) -> Tqdm:
-        """Override this to customize the tqdm bar for training."""
-        bar = super().init_train_tqdm()
-        bar.ascii = ' >'
-        return bar
-
-    def init_validation_tqdm(self):
-        bar = super().init_validation_tqdm()
-        bar.ascii = ' >'
-        return bar
-
-
-class cMNISTModel(L.LightningModule):
-
-    def __init__(self):
-        super().__init__()
-
-        self.ce_loss = nn.CrossEntropyLoss()
-        self.model = Model()
-        self.accuracy = Accuracy(task='multiclass', num_classes=10)
-
-        self.train_step_outputs = {}
-        self.valid_step_outputs = {}
-
-    def forward(self, x):
-        return self.model(x)
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(params=self.parameters(), lr=3e-4)
-
-    def training_step(self, batch, batch_idx):
-        data, label = batch
-        logits = self(data)
-
-        loss = self.ce_loss(logits, label)
-        acc = self.accuracy(logits, label)
-
-        self.log('step_loss', loss, prog_bar=True, sync_dist=True)
-        self.log('step_metrics', acc, prog_bar=True, sync_dist=True)
-
-        if not self.train_step_outputs:
-            self.train_step_outputs = {
-                'step_loss': [loss],
-                'step_metrics': [acc]
-            }
-        else:
-            self.train_step_outputs['step_loss'].append(loss)
-            self.train_step_outputs['step_metrics'].append(acc)
-
-        return loss
-
-    def validation_step(self, batch: torch.Tensor, batch_idx: int):
-        images, labels = batch
-        logits = self(images)
-
-        loss = self.ce_loss(logits, labels)
-        acc = self.accuracy(logits, labels)
-        self.log('step_loss', loss, prog_bar=True, sync_dist=True)
-        self.log('step_metrics', acc, prog_bar=True, sync_dist=True)
-
-        if not self.valid_step_outputs:
-            self.valid_step_outputs = {
-                'step_loss': [loss],
-                'step_metrics': [acc]
-            }
-        else:
-            self.valid_step_outputs['step_loss'].append(loss)
-            self.valid_step_outputs['step_metrics'].append(acc)
-
-    def on_train_epoch_end(self) -> None:
-        _log_dict = {
-            'Loss/loss': torch.tensor(self.train_step_outputs['step_loss']).mean(),
-            'Metrics/accuracy': torch.tensor(self.train_step_outputs['step_metrics']).mean()
-        }
-
-        self.loggers[0].log_metrics(_log_dict, self.current_epoch)
-        self.train_step_outputs.clear()
-
-    def on_validation_epoch_end(self) -> None:
-        mean_loss_value = torch.tensor(self.valid_step_outputs['step_loss']).mean()
-        mean_metrics_value = torch.tensor(self.valid_step_outputs['step_metrics']).mean()
-
-        _log_dict = {
-            'Loss/loss': mean_loss_value,
-            'Metrics/accuracy': mean_metrics_value
-        }
-
-        self.loggers[1].log_metrics(_log_dict, self.current_epoch)
-
-        self.log('val_loss', mean_loss_value, sync_dist=True)
-        self.log('val_Accuracy', mean_metrics_value, sync_dist=True)
-        self.valid_step_outputs.clear()
-
-
-def lightning_train(version: int):
-    batch_size = 64
-    epochs = 15
-    cdtype = torch.complex64
-    torch.set_float32_matmul_precision('high')
-
-    # Dataloading
-    train_dataset = torchvision.datasets.MNIST(
-        root="./data",
-        train=True,
-        download=True,
-        transform=v2_transforms.Compose(
-            [v2_transforms.PILToTensor(), v2_transforms.ToDtype(cdtype)]
-        ),
-    )
-    valid_dataset = torchvision.datasets.MNIST(
-        root="./data",
-        train=False,
-        download=True,
-        transform=v2_transforms.Compose(
-            [v2_transforms.PILToTensor(), v2_transforms.ToDtype(cdtype)]
-        ),
-    )
-
-    # Train dataloader
-    train_loader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=4,
-        persistent_workers=True,
-        pin_memory=True
-    )
-
-    # Valid dataloader
-    valid_loader = torch.utils.data.DataLoader(
-        valid_dataset,
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=4,
-        persistent_workers=True,
-        pin_memory=True
-    )
-
-    model = cMNISTModel()
-    trainer = L.Trainer(
-        max_epochs=epochs,
-        num_sanity_val_steps=0,
-        benchmark=True,
-        enable_checkpointing=True,
-        callbacks=[
-            CustomProgressBar(),
-            EarlyStopping(
-                monitor='val_loss',
-                verbose=True,
-                patience=5,
-                min_delta=0.005
-            ),
-            LearningRateMonitor(logging_interval='epoch'),
-            ModelCheckpoint(
-                dirpath='weights_storage/vit_mnist',
-                monitor='val_Accuracy',
-                verbose=True,
-                mode='max'
-            )
-        ],
-        logger=[
-            TBLogger('training_logs', name=None, sub_dir='train', version=version),
-            TBLogger('training_logs', name=None, sub_dir='valid', version=version)
-        ]
-    )
-
-    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
-
-
if __name__ == "__main__":
-    lightning_train(0)
+    train()

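Note for anyone keeping a local copy of the removed Lightning module: its epoch hooks averaged the step values collected in Python lists by calling torch.tensor(...) on a list of 0-dim tensors. A minimal standalone sketch of that aggregation pattern, using torch.stack, the idiomatic way to combine a list of scalar tensors (every name below is illustrative, not part of the repository):

import torch

step_losses = []                           # 0-dim loss tensors collected during an epoch
for _ in range(10):                        # stand-in for the real training loop
    loss = torch.rand(())                  # stand-in for a real step loss
    step_losses.append(loss.detach())      # detach: only the value is logged

epoch_loss = torch.stack(step_losses).mean()   # stack into one 1-D tensor, then average
print(f"epoch mean loss: {epoch_loss.item():.4f}")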