Skip to content

Commit

Permalink
reformat with isort and black
Browse files Browse the repository at this point in the history
  • Loading branch information
sfluegel committed Jan 3, 2024
1 parent c108686 commit a832015
Show file tree
Hide file tree
Showing 19 changed files with 78 additions and 101 deletions.
1 change: 1 addition & 0 deletions chebai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

import torch

MODULE_PATH = os.path.abspath(os.path.dirname(__file__))
Expand Down
5 changes: 3 additions & 2 deletions chebai/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json
import os

from lightning.pytorch.callbacks import BasePredictionWriter
import torch
import os
import json


class ChebaiPredictionWriter(BasePredictionWriter):
Expand Down
9 changes: 5 additions & 4 deletions chebai/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer, LightningModule
import os

from lightning.fabric.utilities.cloud_io import _is_dir
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.utilities.rank_zero import rank_zero_info
from lightning_utilities.core.rank_zero import rank_zero_warn

Expand Down
5 changes: 2 additions & 3 deletions chebai/cli.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Dict, Set

from lightning.pytorch.cli import LightningCLI, LightningArgumentParser
from lightning.pytorch.cli import LightningArgumentParser, LightningCLI

from chebai.trainer.CustomTrainer import CustomTrainer
from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.models.base import ChebaiBaseNet


class ChebaiCLI(LightningCLI):
Expand Down
6 changes: 3 additions & 3 deletions chebai/loggers/custom.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from datetime import datetime
from typing import Optional, Union, Literal
from typing import Literal, Optional, Union
import os

import wandb
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.loggers import WandbLogger
import os
import wandb


class CustomLogger(WandbLogger):
Expand Down
8 changes: 5 additions & 3 deletions chebai/loss/semantic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch
from chebai.models.electra import extract_class_hierarchy
import os
import csv
import os
import pickle

import torch

from chebai.models.electra import extract_class_hierarchy

IMPLICATION_CACHE_FILE = "chebi.cache"


Expand Down
22 changes: 4 additions & 18 deletions chebai/models/electra.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,20 @@
import os.path
import pickle
import random
from math import pi
from tempfile import TemporaryDirectory
import logging
from typing import Dict
from math import pi

import torchmetrics
from torch import nn
from torch.nn.utils.rnn import (
pack_padded_sequence,
pad_packed_sequence,
pad_sequence,
)
from torch.nn.utils.rnn import pad_sequence
from transformers import (
ElectraConfig,
ElectraForMaskedLM,
ElectraForMultipleChoice,
ElectraForPreTraining,
ElectraForSequenceClassification,
ElectraModel,
PretrainedConfig,
)
from chebai.preprocessing.reader import MASK_TOKEN_INDEX, CLS_TOKEN
from chebai.preprocessing.datasets.chebi import extract_class_hierarchy
from chebai.loss.pretraining import ElectraPreLoss # noqa
import torch
import csv

from chebai.loss.pretraining import ElectraPreLoss # noqa
from chebai.models.base import ChebaiBaseNet
from chebai.preprocessing.reader import CLS_TOKEN, MASK_TOKEN_INDEX

logging.getLogger("pysmiles").setLevel(logging.CRITICAL)

Expand Down
4 changes: 2 additions & 2 deletions chebai/models/lnn_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from lnn import Implies, Model, Not, Predicate, Variable, World
from owlready2 import get_ontology
from lnn import Model, Predicate, Variable, World, Implies, Not
import tqdm
import fastobo
import pyhornedowl
import tqdm


def get_name(iri: str):
Expand Down
6 changes: 3 additions & 3 deletions chebai/models/strontex.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import abc
import torch
import typing
import numpy as np
import networkx as nx

import networkx as nx
import numpy as np
import torch

FeatureType = typing.TypeVar("FeatureType")
LabelType = typing.TypeVar("LabelType")
Expand Down
10 changes: 5 additions & 5 deletions chebai/preprocessing/datasets/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import random
import typing
from typing import List, Union
import os
import random
import typing

from torch.utils.data import DataLoader
from lightning.pytorch.core.datamodule import LightningDataModule
from lightning_utilities.core.rank_zero import rank_zero_info
from torch.utils.data import DataLoader
import lightning as pl
import torch
import tqdm
import lightning as pl

from chebai.preprocessing import reader as dr
from lightning_utilities.core.rank_zero import rank_zero_info


class XYBaseDataModule(LightningDataModule):
Expand Down
4 changes: 1 addition & 3 deletions chebai/preprocessing/datasets/chebi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
from collections import OrderedDict
import os
import pickle
import random

from iterstrat.ml_stratifiers import (
MultilabelStratifiedShuffleSplit,
MultilabelStratifiedKFold,
MultilabelStratifiedShuffleSplit,
)

import fastobo
import networkx as nx
import pandas as pd
Expand Down
8 changes: 6 additions & 2 deletions chebai/preprocessing/datasets/pubchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
import tqdm

from chebai.preprocessing import reader as dr
from chebai.preprocessing.datasets.base import XYBaseDataModule, DataLoader
from chebai.preprocessing.datasets.chebi import ChEBIOver100, ChEBIOver50, ChEBIOverX
from chebai.preprocessing.datasets.base import DataLoader, XYBaseDataModule
from chebai.preprocessing.datasets.chebi import (
ChEBIOver50,
ChEBIOver100,
ChEBIOverX,
)


class PubChem(XYBaseDataModule):
Expand Down
25 changes: 13 additions & 12 deletions chebai/preprocessing/datasets/tox21.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import random

from chebai.preprocessing.datasets.chebi import JCIExtendedTokenData
from chebai.preprocessing.datasets.pubchem import Hazardous
from chebai.preprocessing.datasets.base import XYBaseDataModule, MergedDataset
from tempfile import NamedTemporaryFile, TemporaryDirectory
from urllib import request
from sklearn.model_selection import train_test_split, GroupShuffleSplit
import csv
import gzip
import os
import csv
import random
import shutil
import zipfile

from rdkit import Chem
from sklearn.model_selection import GroupShuffleSplit, train_test_split
import numpy as np
import pysmiles
import torch

from chebai.preprocessing import reader as dr
import pysmiles
import numpy as np
from rdkit import Chem
import zipfile
import shutil
from chebai.preprocessing.datasets.base import MergedDataset, XYBaseDataModule
from chebai.preprocessing.datasets.chebi import JCIExtendedTokenData
from chebai.preprocessing.datasets.pubchem import Hazardous


class Tox21MolNet(XYBaseDataModule):
Expand Down
7 changes: 2 additions & 5 deletions chebai/preprocessing/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@

from pysmiles.read_smiles import _tokenize
from transformers import RobertaTokenizerFast
import selfies as sf
import deepsmiles
import selfies as sf

from chebai.preprocessing.collate import (
DefaultCollater,
RaggedCollater,
)
from chebai.preprocessing.collate import DefaultCollater, RaggedCollater

EMBEDDING_OFFSET = 10
PADDING_TOKEN_INDEX = 0
Expand Down
1 change: 0 additions & 1 deletion chebai/result/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import tqdm

from chebai.models.base import ChebaiBaseNet
from chebai.preprocessing.reader import DataReader

PROCESSORS = dict()

Expand Down
17 changes: 10 additions & 7 deletions chebai/result/classification.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import os

from torchmetrics.classification import (
MultilabelF1Score,
MultilabelPrecision,
MultilabelRecall,
)
from chebai.callbacks.epoch_metrics import MacroF1

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import torch
import tqdm

from chebai.callbacks.epoch_metrics import MacroF1
from chebai.models import ChebaiBaseNet
from chebai.preprocessing.datasets import XYBaseDataModule

Expand Down Expand Up @@ -40,7 +40,8 @@ def evaluate_model(
batch_size: int = 32,
):
"""Runs model on test set of data_module (or, if filename is not None, on data set found in that file).
If buffer_dir is set, results will be saved in buffer_dir. Returns tensors with predictions and labels."""
If buffer_dir is set, results will be saved in buffer_dir. Returns tensors with predictions and labels.
"""
model.eval()
collate = data_module.reader.COLLATER()

Expand Down Expand Up @@ -142,6 +143,7 @@ def print_metrics(preds, labels, device, classes=None, top_k=10, markdown_output
precision_micro = MultilabelPrecision(preds.shape[1], average="micro").to(
device=device
)
macro_adjust = 1
recall_macro = MultilabelRecall(preds.shape[1], average="macro").to(device=device)
recall_micro = MultilabelRecall(preds.shape[1], average="micro").to(device=device)
print(f"Macro-Precision: {precision_macro(preds, labels) * macro_adjust:3f}")
Expand All @@ -154,13 +156,14 @@ def print_metrics(preds, labels, device, classes=None, top_k=10, markdown_output
)
print(f"| --- | --- | --- | --- | --- | --- | --- |")
print(
f"| | {f1_macro(preds, labels):3f} | {f1_micro(preds, labels):3f} | {precision_macro(preds, labels):3f} | {precision_micro(preds, labels):3f} | {recall_macro(preds, labels):3f} | {recall_micro(preds, labels):3f} |"
f"| | {f1_macro(preds, labels):3f} | {f1_micro(preds, labels):3f} | {precision_macro(preds, labels):3f} | "
f"{precision_micro(preds, labels):3f} | {recall_macro(preds, labels):3f} | "
f"{recall_micro(preds, labels):3f} |"
)

classwise_f1_fn = MultilabelF1Score(preds.shape[1], average=None).to(device=device)
classwise_f1 = classwise_f1_fn(preds, labels)
best_classwise_f1 = torch.topk(classwise_f1, top_k).indices
worst_classwise_f1 = torch.topk(classwise_f1, top_k, largest=False).indices
print(f"Top {top_k} classes (F1-score):")
for i, best in enumerate(best_classwise_f1):
print(
Expand Down
12 changes: 7 additions & 5 deletions chebai/result/pretraining.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from chebai.result.base import ResultProcessor
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import chebai.models.electra as electra
from chebai.loss.pretraining import ElectraPreLoss
import torch
import tqdm

from chebai.loss.pretraining import ElectraPreLoss
from chebai.result.base import ResultProcessor
import chebai.models.electra as electra


def visualise_loss(logs_path):
df = pd.read_csv(os.path.join(logs_path, "metrics.csv"))
Expand Down
26 changes: 5 additions & 21 deletions chebai/trainer/CustomTrainer.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,14 @@
from typing import List, Optional
import logging
import os
from typing import Optional, Union, List

import pandas as pd
from lightning import Trainer, LightningModule
from lightning import LightningModule, Trainer
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning_utilities.core.rank_zero import (
WarningCache,
rank_zero_warn,
rank_zero_info,
)
from lightning.pytorch.loggers import CSVLogger
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from lightning.pytorch.callbacks.model_checkpoint import _is_dir

from chebai.loggers.custom import CustomLogger
from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.preprocessing.collate import RaggedCollater
from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader
from torch.nn.utils.rnn import pad_sequence
import torch
import pandas as pd
import torch

from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader

log = logging.getLogger(__name__)

Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from setuptools import setup
from setuptools import find_packages
from setuptools import find_packages, setup

packages = find_packages()
print(packages)
Expand Down

0 comments on commit a832015

Please sign in to comment.