
Commit

Merge branch 'dev' into roles-preprocessing
dawerner committed Oct 27, 2023
2 parents 3e04e51 + bc5c603 commit a1284cf
Showing 21 changed files with 224 additions and 126 deletions.
12 changes: 11 additions & 1 deletion chebai/cli.py
@@ -1,5 +1,15 @@
from lightning.pytorch.cli import LightningCLI


class ChebaiCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
for kind in ("train", "val", "test"):
for average in ("micro", "macro"):
parser.link_arguments(
"model.init_args.out_dim",
f"model.init_args.{kind}_metrics.init_args.metrics.{average}-f1.init_args.num_labels",
)


def cli():
r = LightningCLI(save_config_callback=None)
r = ChebaiCLI(save_config_callback=None, parser_kwargs={"parser_mode": "omegaconf"})
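A minimal sketch of what the linked arguments resolve to, assuming each *_metrics argument is a torchmetrics MetricCollection holding MultilabelF1Score instances keyed "micro-f1" and "macro-f1" (the metric classes and the out_dim value are assumptions for illustration, not part of this diff):

from torchmetrics import MetricCollection
from torchmetrics.classification import MultilabelF1Score

out_dim = 854  # assumed label count; link_arguments copies model.init_args.out_dim here


def make_f1_collection(num_labels):
    # One collection per split ("train", "val", "test"); the CLI fills num_labels automatically.
    return MetricCollection(
        {
            "micro-f1": MultilabelF1Score(num_labels=num_labels, average="micro"),
            "macro-f1": MultilabelF1Score(num_labels=num_labels, average="macro"),
        }
    )


train_metrics = make_f1_collection(out_dim)
val_metrics = make_f1_collection(out_dim)
test_metrics = make_f1_collection(out_dim)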
15 changes: 15 additions & 0 deletions chebai/loss/mixed.py
@@ -0,0 +1,15 @@
from torch import nn


class MixedDataLoss(nn.Module):
def __init__(self, base_loss: nn.Module):
super().__init__()
self.base_loss = base_loss

def forward(self, input, target, **kwargs):
nnl = kwargs.pop("non_null_labels", None)
if nnl:
inp = input[nnl]
else:
inp = input
return self.base_loss(inp, target, **kwargs)
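A short usage sketch (tensor shapes and indices are made up for illustration): MixedDataLoss restricts the predictions to the labelled rows before delegating to the wrapped loss.

import torch

from chebai.loss.mixed import MixedDataLoss

loss_fn = MixedDataLoss(torch.nn.BCEWithLogitsLoss())

logits = torch.randn(4, 10)                    # 4 samples, 10 classes
labels = torch.randint(0, 2, (2, 10)).float()  # only 2 samples carry labels
labelled_rows = [0, 2]                         # indices of the labelled samples

loss = loss_fn(logits, labels, non_null_labels=labelled_rows)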
116 changes: 116 additions & 0 deletions chebai/loss/semantic.py
@@ -0,0 +1,116 @@
import torch
from chebai.models.electra import extract_class_hierarchy
import os
import csv
import pickle

IMPLICATION_CACHE_FILE = "chebi.cache"


class ImplicationLoss(torch.nn.Module):
def __init__(
self, path_to_chebi, path_to_label_names, base_loss: torch.nn.Module = None
):
super().__init__()
self.base_loss = base_loss
label_names = _load_label_names(path_to_label_names)
hierarchy = _load_implications(path_to_chebi)
implication_filter = _build_implication_filter(label_names, hierarchy)
self.implication_filter_l = implication_filter[:, 0]
self.implication_filter_r = implication_filter[:, 1]

def forward(self, input, target, **kwargs):
if target is not None:
base_loss = self.base_loss(input, target.float())
else:
base_loss = 0
pred = torch.sigmoid(input)
l = pred[:, self.implication_filter_l]
r = pred[:, self.implication_filter_r]
# implication_loss = torch.sqrt(torch.mean(torch.sum(l*(1-r), dim=-1), dim=0))
implication_loss = self._calculate_implication_loss(l, r)
return base_loss + implication_loss

def _calculate_implication_loss(self, l, r):
capped_difference = torch.relu(l - r)
return torch.mean(
torch.sum(
(torch.softmax(capped_difference, dim=-1) * capped_difference), dim=-1
),
dim=0,
)


class DisjointLoss(ImplicationLoss):
def __init__(
self,
path_to_chebi,
path_to_label_names,
path_to_disjointedness,
base_loss: torch.nn.Module = None,
):
super().__init__(path_to_chebi, path_to_label_names, base_loss)
label_names = _load_label_names(path_to_label_names)
hierarchy = _load_implications(path_to_chebi)
self.disjoint_filter_l, self.disjoint_filter_r = _build_disjointness_filter(
path_to_disjointedness, label_names, hierarchy
)

def forward(self, input, target, **kwargs):
loss = super().forward(input, target, **kwargs)
pred = torch.sigmoid(input)
l = pred[:, self.disjoint_filter_l]
r = pred[:, self.disjoint_filter_r]
disjointness_loss = self._calculate_implication_loss(l, 1 - r)
return loss + disjointness_loss


def _load_label_names(path_to_label_names):
with open(path_to_label_names) as fin:
label_names = [int(line.strip()) for line in fin]
return label_names


def _load_implications(path_to_chebi, implication_cache=IMPLICATION_CACHE_FILE):
if os.path.isfile(implication_cache):
with open(implication_cache, "rb") as fin:
hierarchy = pickle.load(fin)
else:
hierarchy = extract_class_hierarchy(path_to_chebi)
with open(implication_cache, "wb") as fout:
pickle.dump(hierarchy, fout)
return hierarchy


def _build_implication_filter(label_names, hierarchy):
return torch.tensor(
[
(i1, i2)
for i1, l1 in enumerate(label_names)
for i2, l2 in enumerate(label_names)
if l2 in hierarchy.pred[l1]
]
)


def _build_disjointness_filter(path_to_disjointedness, label_names, hierarchy):
disjoints = set()
label_dict = dict(map(reversed, enumerate(label_names)))

with open(path_to_disjointedness, "rt") as fin:
reader = csv.reader(fin)
for l1_raw, r1_raw in reader:
l1 = int(l1_raw)
r1 = int(r1_raw)
disjoints.update(
{
(label_dict[l2], label_dict[r2])
for r2 in hierarchy.succ[r1]
if r2 in label_names
for l2 in hierarchy.succ[l1]
if l2 in label_names and l2 < r2
}
)

dis_filter = torch.tensor(list(disjoints))
return dis_filter[:, 0], dis_filter[:, 1]
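A toy illustration of the penalty computed by _calculate_implication_loss (the probability values are made up): whenever a subclass is predicted with higher confidence than its superclass, relu(l - r) is positive, and the softmax weighting emphasises the largest violations.

import torch

l = torch.tensor([[0.9, 0.2, 0.7]])  # predicted probabilities of subclasses
r = torch.tensor([[0.4, 0.8, 0.7]])  # predicted probabilities of their superclasses

capped_difference = torch.relu(l - r)  # violations: tensor([[0.5, 0.0, 0.0]])
penalty = torch.mean(
    torch.sum(torch.softmax(capped_difference, dim=-1) * capped_difference, dim=-1),
    dim=0,
)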
14 changes: 8 additions & 6 deletions chebai/models/base.py
@@ -15,8 +15,10 @@ class ChebaiBaseNet(LightningModule):
def __init__(
self,
criterion: torch.nn.Module = None,
out_dim=None,
metrics: Optional[Dict[str, torch.nn.Module]] = None,
out_dim: Optional[int] = None,
train_metrics: Optional[torch.nn.Module] = None,
val_metrics: Optional[torch.nn.Module] = None,
test_metrics: Optional[torch.nn.Module] = None,
pass_loss_kwargs=True,
**kwargs,
):
@@ -25,9 +27,9 @@ def __init__(
self.save_hyperparameters(ignore=["criterion"])
self.out_dim = out_dim
self.optimizer_kwargs = kwargs.get("optimizer_kwargs", dict())
self.train_metrics = metrics["train"]
self.validation_metrics = metrics["validation"]
self.test_metrics = metrics["test"]
self.train_metrics = train_metrics
self.validation_metrics = val_metrics
self.test_metrics = test_metrics
self.pass_loss_kwargs = pass_loss_kwargs

def __init_subclass__(cls, **kwargs):
@@ -88,7 +90,7 @@ def _execute(self, batch, batch_idx, metrics, prefix="", log=True, sync_dist=False):
f"{prefix}loss",
loss.item(),
batch_size=batch.x.shape[0],
on_step=False,
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
110 changes: 0 additions & 110 deletions chebai/models/electra.py
@@ -226,116 +226,6 @@ def forward(self, data, **kwargs):
)


IMPLICATION_CACHE_FILE = "chebi.cache"


def _load_label_names(path_to_label_names):
with open(path_to_label_names) as fin:
label_names = [int(line.strip()) for line in fin]
return label_names


def _load_implications(path_to_chebi, implication_cache=IMPLICATION_CACHE_FILE):
if os.path.isfile(implication_cache):
with open(implication_cache, "rb") as fin:
hierarchy = pickle.load(fin)
else:
hierarchy = extract_class_hierarchy(path_to_chebi)
with open(implication_cache, "wb") as fout:
pickle.dump(hierarchy, fout)
return hierarchy


def _build_implication_filter(label_names, hierarchy):
return torch.tensor(
[
(i1, i2)
for i1, l1 in enumerate(label_names)
for i2, l2 in enumerate(label_names)
if l2 in hierarchy.pred[l1]
]
)


def _build_disjointness_filter(path_to_disjointedness, label_names, hierarchy):
disjoints = set()
label_dict = dict(map(reversed, enumerate(label_names)))

with open(path_to_disjointedness, "rt") as fin:
reader = csv.reader(fin)
for l1_raw, r1_raw in reader:
l1 = int(l1_raw)
r1 = int(r1_raw)
disjoints.update(
{
(label_dict[l2], label_dict[r2])
for r2 in hierarchy.succ[r1]
if r2 in label_names
for l2 in hierarchy.succ[l1]
if l2 in label_names and l2 < r2
}
)

dis_filter = torch.tensor(list(disjoints))
return dis_filter[:, 0], dis_filter[:, 1]


class ElectraChEBILoss(nn.Module):
def __init__(
self, path_to_chebi, path_to_label_names, base_loss: torch.nn.Module = None
):
super().__init__()
self.base_loss = base_loss
label_names = _load_label_names(path_to_label_names)
hierarchy = _load_implications(path_to_chebi)
implication_filter = _build_implication_filter(label_names, hierarchy)
self.implication_filter_l = implication_filter[:, 0]
self.implication_filter_r = implication_filter[:, 1]

def forward(self, input, target, **kwargs):
if "non_null_labels" in kwargs:
n = kwargs["non_null_labels"]
inp = input[n]
else:
inp = input
if target is not None:
base_loss = self.base_loss(inp, target.float())
else:
base_loss = 0
pred = torch.sigmoid(input)
l = pred[:, self.implication_filter_l]
r = pred[:, self.implication_filter_r]
# implication_loss = torch.sqrt(torch.mean(torch.sum(l*(1-r), dim=-1), dim=0))
implication_loss = torch.mean(torch.mean(torch.relu(l - r), dim=-1), dim=0)
return base_loss + implication_loss


class ElectraChEBIDisjointLoss(ElectraChEBILoss):
def __init__(
self,
path_to_chebi,
path_to_label_names,
path_to_disjointedness,
base_loss: torch.nn.Module = None,
):
super().__init__(path_to_chebi, path_to_label_names, base_loss)
label_names = _load_label_names(path_to_label_names)
hierarchy = _load_implications(path_to_chebi)
self.disjoint_filter_l, self.disjoint_filter_r = _build_disjointness_filter(
path_to_disjointedness, label_names, hierarchy
)

def forward(self, input, target, **kwargs):
loss = super().forward(input, target, **kwargs)
pred = torch.sigmoid(input)
l = pred[:, self.disjoint_filter_l]
r = pred[:, self.disjoint_filter_r]
disjointness_loss = torch.mean(
torch.mean(torch.relu(l - (1 - r)), dim=-1), dim=0
)
return loss + disjointness_loss


class ElectraLegacy(ChebaiBaseNet):
NAME = "ElectraLeg"

6 changes: 5 additions & 1 deletion chebai/preprocessing/datasets/base.py
@@ -23,6 +23,7 @@ def __init__(
data_limit: typing.Optional[int] = None,
label_filter: typing.Optional[int] = None,
balance_after_filter: typing.Optional[float] = None,
num_workers: int = 1,
**kwargs,
):
super().__init__(**kwargs)
@@ -38,6 +39,7 @@ def __init__(
self.label_filter is None
), "Filter balancing requires a filter"
self.balance_after_filter = balance_after_filter
self.num_workers = num_workers
os.makedirs(self.raw_dir, exist_ok=True)
os.makedirs(self.processed_dir, exist_ok=True)

@@ -112,7 +114,9 @@ def _load_data_from_file(self, path):
return data

def train_dataloader(self, *args, **kwargs) -> DataLoader:
return self.dataloader("train", shuffle=True, **kwargs)
return self.dataloader(
"train", shuffle=True, num_workers=self.num_workers, **kwargs
)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return self.dataloader("validation", shuffle=False, **kwargs)
19 changes: 14 additions & 5 deletions chebai/preprocessing/datasets/pubchem.py
@@ -20,7 +20,7 @@

from chebai.preprocessing import reader as dr
from chebai.preprocessing.datasets.base import XYBaseDataModule, DataLoader
from chebai.preprocessing.datasets.chebi import ChEBIOver100
from chebai.preprocessing.datasets.chebi import ChEBIOver100, ChEBIOver50, ChEBIOverX


class PubChem(XYBaseDataModule):
@@ -220,17 +220,18 @@ def raw_dir(self):
return os.path.join("data", self._name, "raw")


class PubToxAndChEBI100(XYBaseDataModule):
class PubToxAndChebiX(XYBaseDataModule):
READER = dr.ChemDataReader
CHEBI_X = ChEBIOverX

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.labeled = ChEBIOver100(*args, **kwargs)
self.labeled = self.CHEBI_X(*args, **kwargs)
self.unlabeled = PubchemChem(*args, **kwargs)
super().__init__(*args, **kwargs)

@property
def _name(self):
return "PubToxUChebi100"
return "PubToxU" + self.labeled._name

def dataloader(self, kind, **kwargs):
labeled_data = torch.load(
@@ -260,3 +261,11 @@ def processed_file_names(self):
def setup_processed(self):
self.labeled.setup()
self.unlabeled.setup()


class PubToxAndChebi100(PubToxAndChebiX):
CHEBI_X = ChEBIOver100


class PubToxAndChebi50(PubToxAndChebiX):
CHEBI_X = ChEBIOver50
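With the CHEBI_X hook, a mixed dataset is fully determined by the ChEBIOverX subclass it points at; the sketch below is functionally the same as the PubToxAndChebi50 added here and only shows the pattern for wiring in another threshold class.

from chebai.preprocessing.datasets.chebi import ChEBIOver50
from chebai.preprocessing.datasets.pubchem import PubToxAndChebiX


class MyMixedChebi50(PubToxAndChebiX):
    # Behaves like PubToxAndChebi50; _name resolves to "PubToxU" + the ChEBI dataset's _name.
    CHEBI_X = ChEBIOver50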
1 change: 1 addition & 0 deletions configs/data/chebi100.yml
@@ -0,0 +1 @@
class_path: chebai.preprocessing.datasets.chebi.ChEBIOver100
1 change: 1 addition & 0 deletions configs/data/chebi100_mixed.yml
@@ -0,0 +1 @@
class_path: chebai.preprocessing.datasets.pubchem.PubToxAndChebi100
1 change: 1 addition & 0 deletions configs/data/chebi50.yml
@@ -0,0 +1 @@
class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50
1 change: 1 addition & 0 deletions configs/data/chebi50_mixed.yml
@@ -0,0 +1 @@
class_path: chebai.preprocessing.datasets.pubchem.PubToxAndChebi50
1 change: 1 addition & 0 deletions configs/loss/bce.yml
@@ -0,0 +1 @@
class_path: torch.nn.BCEWithLogitsLoss
5 changes: 5 additions & 0 deletions configs/loss/semantic_loss.yml
@@ -0,0 +1,5 @@
class_path: chebai.loss.semantic.DisjointLoss
init_args:
path_to_chebi: data/ChEBI100/raw/chebi.obo
path_to_label_names: data/ChEBI100/raw/classes.txt
path_to_disjointedness: disjoint.csv
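For reference, this config instantiates roughly the following object (paths copied from the YAML; the config sets no base_loss, so it stays at its default of None):

from chebai.loss.semantic import DisjointLoss

criterion = DisjointLoss(
    path_to_chebi="data/ChEBI100/raw/chebi.obo",
    path_to_label_names="data/ChEBI100/raw/classes.txt",
    path_to_disjointedness="disjoint.csv",
)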