Skip to content

Commit

Permalink
add utilities for box evaluations
Browse files Browse the repository at this point in the history
  • Loading branch information
adelmemariani committed Nov 27, 2023
1 parent 324b9a6 commit 6868249
Show file tree
Hide file tree
Showing 5 changed files with 569 additions and 78 deletions.
32 changes: 0 additions & 32 deletions chebai/models/box_eval.py

This file was deleted.

79 changes: 33 additions & 46 deletions chebai/models/electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from chebai.preprocessing.datasets.chebi import extract_class_hierarchy
import torch
import csv

import pytorch_lightning as pl
from chebai.models.base import ChebaiBaseNet

logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
Expand Down Expand Up @@ -58,15 +58,15 @@ def forward(self, data):
mask = data["mask"]
with torch.no_grad():
dis_tar = (
torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1)
torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1)
).int()
disc_tar_one_hot = torch.eq(
torch.arange(max_seq_len, device=self.device)[None, :], dis_tar[:, None]
)
gen_tar = features[disc_tar_one_hot]
gen_tar_one_hot = torch.eq(
torch.arange(self.generator_config.vocab_size, device=self.device)[
None, :
None, :
],
gen_tar[:, None],
)
Expand Down Expand Up @@ -118,7 +118,7 @@ def forward(self, input, target):

def filter_dict(d, filter_key):
    """Select the entries of ``d`` whose key starts with ``filter_key``.

    Keys are converted with ``str`` before matching, and the matched prefix
    is removed from the keys of the result, so the returned keys are always
    strings.

    Args:
        d: Mapping to filter.
        filter_key: Prefix a (stringified) key must start with to be kept.

    Returns:
        A new dict mapping each prefix-stripped key to its original value.
    """
    # The prefix length is the same for every key, so compute it once.
    prefix_length = len(filter_key)
    return {
        str(k)[prefix_length:]: v
        for k, v in d.items()
        if str(k).startswith(filter_key)
    }
Expand All @@ -139,10 +139,10 @@ def _process_batch(self, batch, batch_idx):
batch_first=True,
)
cls_tokens = (
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
)
return dict(
features=torch.cat((cls_tokens, batch.x), dim=1),
Expand All @@ -157,7 +157,7 @@ def as_pretrained(self):
return self.electra.electra

def __init__(
self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs
self, config=None, pretrained_checkpoint=None, load_prefix=None, **kwargs
):
# Remove this property in order to prevent it from being stored as a
# hyper parameter
Expand Down Expand Up @@ -282,7 +282,7 @@ def _build_disjointness_filter(path_to_disjointedness, label_names, hierarchy):

class ElectraChEBILoss(nn.Module):
def __init__(
self, path_to_chebi, path_to_label_names, base_loss: torch.nn.Module = None
self, path_to_chebi, path_to_label_names, base_loss: torch.nn.Module = None
):
super().__init__()
self.base_loss = base_loss
Expand Down Expand Up @@ -312,11 +312,11 @@ def forward(self, input, target, **kwargs):

class ElectraChEBIDisjointLoss(ElectraChEBILoss):
def __init__(
self,
path_to_chebi,
path_to_label_names,
path_to_disjointedness,
base_loss: torch.nn.Module = None,
self,
path_to_chebi,
path_to_label_names,
path_to_disjointedness,
base_loss: torch.nn.Module = None,
):
super().__init__(path_to_chebi, path_to_label_names, base_loss)
label_names = _load_label_names(path_to_label_names)
Expand Down Expand Up @@ -379,10 +379,10 @@ def _process_batch(self, batch, batch_idx):
batch_first=True,
)
cls_tokens = (
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze(
-1
)
* CLS_TOKEN
)
return dict(
features=torch.cat((cls_tokens, batch.x), dim=1),
Expand Down Expand Up @@ -417,7 +417,7 @@ def __init__(self, cone_dimensions=20, **kwargs):
model_dict = torch.load(fin, map_location=self.device)
if model_prefix:
state_dict = {
str(k)[len(model_prefix) :]: v
str(k)[len(model_prefix):]: v
for k, v in model_dict["state_dict"].items()
if str(k).startswith(model_prefix)
}
Expand Down Expand Up @@ -480,22 +480,13 @@ def forward(self, data, **kwargs):
class ChebiBox(Electra):
NAME = "ChebiBox"

def __init__(self, dimensions=12, hidden_size=2000, **kwargs):
def __init__(self, dimensions=2, hidden_size=4000, **kwargs):
super().__init__(**kwargs)
self.dimensions = dimensions
self.max_value = 100

self.boxes = nn.Parameter(
torch.rand((self.config.num_labels, self.dimensions, 2))
3 - torch.rand((self.config.num_labels, self.dimensions, 2)) * 6
)

"""
# for boxes coordinates with values between 0.0 and 5.0 (instead of between 0.0 and 1.0 generated by torch.rand
boxes = torch.rand(self.config.num_labels, self.dimensions, 2) * 5
self.boxes = nn.Parameter(
boxes
)
"""
self.embeddings_to_points = nn.Sequential(
nn.Linear(self.electra.config.hidden_size, hidden_size),
nn.ReLU(),
Expand All @@ -515,39 +506,35 @@ def forward(self, data, **kwargs):
l = torch.min(b, dim=-1)[0]
r = torch.max(b, dim=-1)[0]
p = points.expand(self.config.num_labels, -1, -1).transpose(1, 0)
membership_per_dim = torch.max(torch.stack((nn.functional.relu(l - p), nn.functional.relu(p - r))), dim=0)[0]
# min might be replaced
#m = torch.min(membership_per_dim, dim=-1)[0]
#m = torch.mean(membership_per_dim, dim=-1)
#s = 2 - ( 2 * (torch.sigmoid(m)) )
#logits = torch.logit( (s * 0.99) + 0.001)
max_distance_per_dim = torch.max(torch.stack((nn.functional.relu(l - p), nn.functional.relu(p - r))), dim=0)[0]

m = torch.sum(membership_per_dim, dim=-1)
logits = m
m = torch.sum(max_distance_per_dim, dim=-1)
s = 2 - (2 * (torch.sigmoid(m)))
logits = torch.logit((s * 0.99) + 0.001)

return dict(
boxes=b,
points=p,
embedded_points=points,
logits=logits,
attentions=electra.attentions,
target_mask=data.get("target_mask"),
)


class BoxLoss(pl.LightningModule):
    """Loss over per-class box scores produced by a box-embedding model.

    For positive targets the score ``d`` is penalised with
    ``sqrt(d) * softplus(d - theta)``; for negative targets it is penalised
    with ``softplus(-d) * (1 - d)``. The element-wise losses are averaged
    into a single scalar.

    Args:
        theta: Offset subtracted from the score in the positive-target term.
        **kwargs: Forwarded unchanged to ``pl.LightningModule``.
    """

    def __init__(self, theta=0.4, **kwargs):
        super().__init__(**kwargs)
        self.theta = theta

    def forward(self, outputs, targets):
        """Return the mean box loss.

        Implemented as ``forward`` (not ``__call__``) so that calling the
        module still goes through the regular ``nn.Module`` call machinery.

        Args:
            outputs: Tensor of scores, one entry per (sample, class).
            targets: Tensor broadcastable to ``outputs``; 1 marks a positive
                label and 0 a negative one.

        Returns:
            Scalar tensor: the mean of the element-wise loss.
        """
        d = outputs
        t = targets
        # NOTE(review): torch.sqrt yields NaN for negative entries of ``d``;
        # this assumes the scores are non-negative -- confirm against the
        # model that feeds this loss.
        # softplus(x) == log(1 + exp(x)), evaluated without overflowing for
        # large x.
        positive_term = torch.sqrt(d) * nn.functional.softplus(d - self.theta) * t
        negative_term = nn.functional.softplus(-d) * (1 - d) * (1 - t)
        return torch.mean(positive_term + negative_term)

def softabs(x, eps=0.01):
    """Smooth, everywhere-differentiable approximation of ``abs(x)``.

    Computes ``sqrt(x**2 + eps) - sqrt(eps)``. Subtracting ``sqrt(eps)``
    shifts the curve so that the result is exactly 0 at ``x == 0``.

    Args:
        x: Number or array/tensor supporting ``**`` and ``+``.
        eps: Smoothing constant; larger values round the kink at 0 more.

    Returns:
        The smoothed absolute value, with the same type/shape as ``x``.
    """
    return (x ** 2 + eps) ** 0.5 - eps ** 0.5


def anglify(x):
Expand All @@ -574,8 +561,8 @@ def in_cone_parts(vectors, cone_axes, cone_arcs):
dis = (torch.abs(turn(v, theta_L)) + torch.abs(turn(v, theta_R)) - cone_arc_ang)/(2*pi-cone_arc_ang)
return dis
"""
a = cone_axes - cone_arcs**2
b = cone_axes + cone_arcs**2
a = cone_axes - cone_arcs ** 2
b = cone_axes + cone_arcs ** 2
bigger_than_a = torch.sigmoid(vectors - a)
smaller_than_b = torch.sigmoid(b - vectors)
return bigger_than_a * smaller_than_b
Expand Down
135 changes: 135 additions & 0 deletions chebai/result/box_evals/evals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""
Some utilities for evaluating box models. This code is not yet integrated into the Chebai library.
"""
import torch
import pickle
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import pandas as pd
import numpy as np
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

#-----------------------------------------------------------------

# model outputs are stored in ./outputs
# Plot validation and training micro-F1 side by side from the logged metrics.
val_f1 = pd.read_csv("path to metrics.csv")
# Rows without a value for a metric are dropped per column (presumably the two
# metrics are logged on different rows -- verify against the metrics file).
df1 = val_f1["val_micro-f1"].dropna().tolist()
df2 = val_f1["train_micro-f1"].dropna().tolist()
# zip pairs the two series by position and stops at the shorter one.
pd.DataFrame(zip(df1, df2), columns=["val_micro-f1", "train_micro-f1"]).plot()


#-----------------------------------------------------------------
# checkpoints are stored in ./saved_models
checkpoint = torch.load("/content/drive/MyDrive/Box/saved_models/mean_dim_3/best_epoch=199_val_loss=0.0399.ckpt", map_location=torch.device('cpu'))
# The learned box parameters; the last axis holds the two corners of each box.
mboxes = checkpoint["state_dict"]["boxes"]
corner_1 = mboxes[:, :, 0]
corner_2 = mboxes[:, :, 1]
# Convert each corner tensor to NumPy once, then pair the rows up. Iterating
# over the arrays covers every box in the checkpoint instead of relying on a
# hard-coded class count.
corner_1_np = corner_1.detach().cpu().numpy()
corner_2_np = corner_2.detach().cpu().numpy()
boxes = [[first, second] for first, second in zip(corner_1_np, corner_2_np)]


#-----------------------------------------------------------------
# Extract the most and the least represented classes in the training dataset
# NOTE: pickle.load can execute arbitrary code; only open trusted dataset files.
with open('path to train dataset ... /train.pkl', 'rb') as f:
    train_data = pickle.load(f)

# Every column from index 3 onward is treated as one class label column
# (presumably the first three columns hold non-label data -- verify).
chebi_labels = list(train_data.columns[3:])
# Column sums, keyed by class label.
ds_dict = train_data.iloc[:, 3:].sum().to_dict()

N = 200
# Both rankings use a stable sort, so classes with equal sums keep their
# column order in either list.
top_represented_classes = sorted(ds_dict, key=ds_dict.get, reverse=True)[:N]
lowest_represented_classes = sorted(ds_dict, key=ds_dict.get)[:N]

#visualize_boxes_3d(boxes, chebi_labels, limits=lowest_represented_classes )

def visualize_boxes_3d(boxes, labels, limits, axis_limit=15):
    """Draw the boxes of selected classes as translucent cuboids in a 3D plot.

    Args:
        boxes: Sequence of ``(corner_a, corner_b)`` pairs. Only the first
            three coordinates of each corner are used. The two corners do not
            need to be ordered: every combination of their coordinates is
            used as a vertex, so the same cuboid results either way.
        labels: Sequence of class labels, aligned with ``boxes`` by index.
        limits: Collection of (hashable) labels to draw; boxes whose label is
            not in it are skipped.
        axis_limit: Each axis spans ``[-axis_limit, axis_limit]``.
    """
    # A set makes the per-box membership test O(1) instead of a list scan.
    selected = set(limits)

    plt.figure(figsize=(40, 40))
    ax = plt.axes(projection='3d')
    ax.set_xlim([-axis_limit, axis_limit])
    ax.set_ylim([-axis_limit, axis_limit])
    ax.set_zlim([-axis_limit, axis_limit])

    for idx, (corner_a, corner_b) in enumerate(boxes):
        label = labels[idx]
        if label not in selected:
            # Skip before building any geometry for boxes that are not drawn.
            continue

        x0, y0, z0 = corner_a[0], corner_a[1], corner_a[2]
        x1, y1, z1 = corner_b[0], corner_b[1], corner_b[2]
        # The two faces perpendicular to the z axis, listed in matching order
        # so that bottom[k] and top[k] are joined by a vertical edge.
        bottom = [[x0, y0, z0], [x1, y0, z0], [x1, y1, z0], [x0, y1, z0]]
        top = [[x0, y0, z1], [x1, y0, z1], [x1, y1, z1], [x0, y1, z1]]
        faces = [
            bottom,
            top,
            [bottom[0], bottom[1], top[1], top[0]],
            [bottom[2], bottom[3], top[3], top[2]],
            [bottom[1], bottom[2], top[2], top[1]],
            [top[0], top[3], bottom[3], bottom[0]],
        ]
        poly3d = Poly3DCollection(faces, color='blue', linewidths=1, edgecolors='r', alpha=0.02)
        ax.add_collection3d(poly3d)
        # Anchor the class label at the first corner of the box.
        ax.text(x0, y0, z0, label, color='green', fontsize=24)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('Visualization of Chebi classes in 3D')
    plt.show()

# Example usage:
# visualize_boxes_3d(boxes, chebi_labels, limits=top_represented_classes )
# visualize_boxes_3d(boxes, chebi_labels, limits=lowest_represented_classes )

#-----------------------------------------------------------------
# Calculate containments based on boxes:

def _containment_matrix(boxes):
    """Return the pairwise containment ratios of axis-aligned boxes.

    Entry ``[j, i]`` is the volume of the intersection of box ``i`` and box
    ``j`` divided by the volume of box ``j``, i.e. the fraction of box ``j``
    that lies inside box ``i``. The diagonal, and every row of a box with
    zero volume, is left at 0.

    Args:
        boxes: Sequence of ``(corner_a, corner_b)`` pairs of equal dimension;
            the corners of a box may be given in either order.

    Returns:
        A float array of shape ``(len(boxes), len(boxes))``.
    """
    count = len(boxes)
    if count == 0:
        return np.zeros((0, 0), dtype=float)

    corners = np.asarray(boxes, dtype=float)  # (count, 2, dim)
    # Normalise each box to its per-dimension lower and upper bounds.
    lower = corners.min(axis=1)
    upper = corners.max(axis=1)
    volumes = np.prod(upper - lower, axis=1)

    # Per-dimension overlap of every pair of boxes; a negative extent means
    # the boxes are disjoint in that dimension, so it is clipped to 0.
    # NOTE: this materialises (count, count, dim) arrays.
    overlap = (
        np.minimum(upper[:, None, :], upper[None, :, :])
        - np.maximum(lower[:, None, :], lower[None, :, :])
    )
    intersection_volumes = np.clip(overlap, 0.0, None).prod(axis=-1)

    matrix = np.zeros((count, count), dtype=float)
    # Divide row j by the volume of box j, skipping zero-volume boxes.
    np.divide(
        intersection_volumes,
        volumes[:, None],
        out=matrix,
        where=volumes[:, None] != 0,
    )
    np.fill_diagonal(matrix, 0.0)
    return matrix


n = len(boxes)
containment_matrix = _containment_matrix(boxes)

# A heatmap for containments:
import numpy as np
import matplotlib.pyplot as plt

# Binarise the containment ratios: astype(int) truncates toward zero, so a
# ratio below 1 (partial overlap) becomes 0 and only a ratio of 1 (box j lies
# completely inside box i) becomes 1.
binary_data = containment_matrix.astype(int)
fig, ax = plt.subplots(figsize=(20, 8))
# Both image axes are box indices: row j, column i.
im = ax.imshow(binary_data, cmap='coolwarm')
plt.colorbar(im, ticks=[0, 1], label='True/False')
# Turn off grid lines for this plot.
plt.grid(False)
plt.show()
Loading

0 comments on commit 6868249

Please sign in to comment.