From 5a6c2c8db5d0b654014eb7eec72fd5c317648ed6 Mon Sep 17 00:00:00 2001 From: "huy.nguyen" Date: Tue, 7 Jan 2025 15:40:13 +0100 Subject: [PATCH] back to original code --- examples/vit_mnist.py | 198 +----------------------------------------- 1 file changed, 2 insertions(+), 196 deletions(-) diff --git a/examples/vit_mnist.py b/examples/vit_mnist.py index ade69b7..64819a5 100644 --- a/examples/vit_mnist.py +++ b/examples/vit_mnist.py @@ -1,6 +1,6 @@ # MIT License -# Copyright (c) 2023 Jérémy Fix, Xuan-Huy Nguyen +# Copyright (c) 2023 Jérémy Fix # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -40,15 +40,6 @@ import torchvision import torchvision.transforms.v2 as v2_transforms -import lightning as L -from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint -from lightning.pytorch.loggers import TensorBoardLogger -from lightning.pytorch.callbacks.progress import TQDMProgressBar -from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm -from lightning.pytorch.utilities import rank_zero_only - -from torchmetrics.classification import Accuracy - import torchcvnn.nn as c_nn import torchcvnn.models as c_models @@ -311,190 +302,5 @@ def train(): ) -# Pytorch Lightning code -class TBLogger(TensorBoardLogger): - @rank_zero_only - def log_metrics(self, metrics, step): - metrics.pop('epoch', None) - metrics = {k: v for k, v in metrics.items() if ('step' not in k) and ('val' not in k)} - return super().log_metrics(metrics, step) - - -class CustomProgressBar(TQDMProgressBar): - - def get_metrics(self, trainer, model): - items = super().get_metrics(trainer, model) - items.pop("v_num", None) - return items - - def init_train_tqdm(self) -> Tqdm: - """Override this to customize the tqdm bar for training.""" - bar = super().init_train_tqdm() - bar.ascii = ' >' - return bar - - def init_validation_tqdm(self): - bar = super().init_validation_tqdm() - bar.ascii = ' >' - return bar - - -class cMNISTModel(L.LightningModule): - - def __init__(self): - super().__init__() - - self.ce_loss = nn.CrossEntropyLoss() - self.model = Model() - self.accuracy = Accuracy(task='multiclass', num_classes=10) - - self.train_step_outputs = {} - self.valid_step_outputs = {} - - def forward(self, x): - return self.model(x) - - def configure_optimizers(self): - return torch.optim.Adam(params=self.parameters(), lr=3e-4) - - def training_step(self, batch, batch_idx): - data, label = batch - logits = self(data) - - loss = self.ce_loss(logits, label) - acc = self.accuracy(logits, label) - - self.log('step_loss', loss, prog_bar=True, sync_dist=True) - self.log('step_metrics', acc, prog_bar=True, sync_dist=True) - - if not self.train_step_outputs: - self.train_step_outputs = { - 'step_loss': [loss], - 'step_metrics': [acc] - } - else: - self.train_step_outputs['step_loss'].append(loss) - self.train_step_outputs['step_metrics'].append(acc) - - return loss - - def validation_step(self, batch: torch.Tensor, batch_idx: int): - images, labels = batch - logits = self(images) - - loss = self.ce_loss(logits, labels) - acc = self.accuracy(logits, labels) - self.log('step_loss', loss, prog_bar=True, sync_dist=True) - self.log('step_metrics', acc, prog_bar=True, sync_dist=True) - - if not self.valid_step_outputs: - self.valid_step_outputs = { - 'step_loss': [loss], - 'step_metrics': [acc] - } - else: - self.valid_step_outputs['step_loss'].append(loss) - self.valid_step_outputs['step_metrics'].append(acc) - - def on_train_epoch_end(self) -> None: - _log_dict = { - 'Loss/loss': torch.tensor(self.train_step_outputs['step_loss']).mean(), - 'Metrics/accuracy': torch.tensor(self.train_step_outputs['step_metrics']).mean() - } - - self.loggers[0].log_metrics(_log_dict, self.current_epoch) - self.train_step_outputs.clear() - - def on_validation_epoch_end(self) -> None: - mean_loss_value = torch.tensor(self.valid_step_outputs['step_loss']).mean() - mean_metrics_value = torch.tensor(self.valid_step_outputs['step_metrics']).mean() - - _log_dict = { - 'Loss/loss': mean_loss_value, - 'Metrics/accuracy': mean_metrics_value - } - - self.loggers[1].log_metrics(_log_dict, self.current_epoch) - - self.log('val_loss', mean_loss_value, sync_dist=True) - self.log('val_Accuracy', mean_metrics_value, sync_dist=True) - self.valid_step_outputs.clear() - - -def lightning_train(version: int): - batch_size = 64 - epochs = 15 - cdtype = torch.complex64 - torch.set_float32_matmul_precision('high') - - # Dataloading - train_dataset = torchvision.datasets.MNIST( - root="./data", - train=True, - download=True, - transform=v2_transforms.Compose( - [v2_transforms.PILToTensor(), v2_transforms.ToDtype(cdtype)] - ), - ) - valid_dataset = torchvision.datasets.MNIST( - root="./data", - train=False, - download=True, - transform=v2_transforms.Compose( - [v2_transforms.PILToTensor(), v2_transforms.ToDtype(cdtype)] - ), - ) - - # Train dataloader - train_loader = torch.utils.data.DataLoader( - train_dataset, - batch_size=batch_size, - shuffle=True, - num_workers=4, - persistent_workers=True, - pin_memory=True - ) - - # Valid dataloader - valid_loader = torch.utils.data.DataLoader( - valid_dataset, - batch_size=batch_size, - shuffle=False, - num_workers=4, - persistent_workers=True, - pin_memory=True - ) - - model = cMNISTModel() - trainer = L.Trainer( - max_epochs=epochs, - num_sanity_val_steps=0, - benchmark=True, - enable_checkpointing=True, - callbacks=[ - CustomProgressBar(), - EarlyStopping( - monitor='val_loss', - verbose=True, - patience=5, - min_delta=0.005 - ), - LearningRateMonitor(logging_interval='epoch'), - ModelCheckpoint( - dirpath='weights_storage/vit_mnist', - monitor='val_Accuracy', - verbose=True, - mode='max' - ) - ], - logger=[ - TBLogger('training_logs', name=None, sub_dir='train', version=version), - TBLogger('training_logs', name=None, sub_dir='valid', version=version) - ] - ) - - trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader) - - if __name__ == "__main__": - lightning_train(0) + train() \ No newline at end of file