Skip to content

Commit

Permalink
add data_limit to evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
sfluegel committed Jan 5, 2024
1 parent 7bbe5c1 commit 27ec370
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions chebai/result/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def evaluate_model(
collate = data_module.reader.COLLATER()

data_list = data_module.load_processed_data("test", filename)
data_list = data_list[: data_module.data_limit]
preds_list = []
labels_list = []
if buffer_dir is not None:
Expand Down Expand Up @@ -125,18 +126,11 @@ def load_results_from_buffer(buffer_dir, device):

def print_metrics(preds, labels, device, classes=None, top_k=10, markdown_output=False):
"""Prints relevant metrics, including micro and macro F1, recall and precision, best k classes and worst classes."""
f1_macro = MultilabelF1Score(preds.shape[1], average="macro").to(device=device)
f1_micro = MultilabelF1Score(preds.shape[1], average="micro").to(device=device)

my_f1_macro = MacroF1(preds.shape[1]).to(device=device)

print(
f"Macro-F1 (torchmetrics, unadjusted) on test set with {preds.shape[1]} classes: {f1_macro(preds, labels):3f}"
)
print(f"Macro-F1 (my implementation): {my_f1_macro(preds, labels)}")
print(
f"Micro-F1 on test set with {preds.shape[1]} classes: {f1_micro(preds, labels):3f}"
)
print(f"Macro-F1: {my_f1_macro(preds, labels):3f}")
print(f"Micro-F1: {f1_micro(preds, labels):3f}")
precision_macro = MultilabelPrecision(preds.shape[1], average="macro").to(
device=device
)
Expand All @@ -156,7 +150,7 @@ def print_metrics(preds, labels, device, classes=None, top_k=10, markdown_output
)
print(f"| --- | --- | --- | --- | --- | --- | --- |")
print(
f"| | {f1_macro(preds, labels):3f} | {f1_micro(preds, labels):3f} | {precision_macro(preds, labels):3f} | "
f"| | {my_f1_macro(preds, labels):3f} | {f1_micro(preds, labels):3f} | {precision_macro(preds, labels):3f} | "
f"{precision_micro(preds, labels):3f} | {recall_macro(preds, labels):3f} | "
f"{recall_micro(preds, labels):3f} |"
)
Expand Down

0 comments on commit 27ec370

Please sign in to comment.