Annotate no longer accepts spans <= 1; added progress bars
galileosteinberg committed Jul 25, 2024
1 parent 5709a75 commit 2bc9cf7
Showing 2 changed files with 55 additions and 48 deletions.
98 changes: 52 additions & 46 deletions benchmarks/bioid_ner_benchmark.py
@@ -35,7 +35,8 @@
RESULTS_DIR = os.path.join(HERE, 'results', "bioid_ner_performance",
gilda.__version__)
MODULE = pystow.module('gilda', 'biocreative')
URL = 'https://biocreative.bioinformatics.udel.edu/media/store/files/2017/BioIDtraining_2.tar.gz'
URL = ('https://biocreative.bioinformatics.udel.edu/media/store/files/2017'
'/BioIDtraining_2.tar.gz')

STOPLIST_PATH = os.path.join(HERE, 'data', 'ner_stoplist.txt')

@@ -49,8 +50,8 @@ def __init__(self):
print("Instantiating benchmarker...")
self.equivalences = self._load_equivalences()
self.paper_level_grounding = defaultdict(set)
self.processed_data = self.process_xml_files() # xml files processesed
self.annotations_df = self._process_annotations_table() # csv annotations
self.processed_data = self.process_xml_files() # XML files processed
self.annotations_df = self._process_annotations_table() # CSV annotations
self.stoplist = self._load_stoplist() # Load stoplist
self.gilda_annotations_map = defaultdict(list)
self.annotations_count = 0
@@ -289,10 +290,12 @@ def _get_plaintext(self, don_article: str) -> str:

def annotate_entities_with_gilda(self):
"""Performs NER on the XML files using gilda.annotate()"""
tqdm.write("Annotating corpus with Gilda...")
print("Annotating corpus with Gilda...")

total_gilda_annotations = 0
for _, item in self.processed_data.iterrows():
for _, item in tqdm(self.processed_data.iterrows(),
total=self.processed_data.shape[0],
desc="Annotating with Gilda"):
doc_id = item['doc_id']
figure = item['figure']
text = item['text']
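
Note on the loop above: DataFrame.iterrows() is a plain generator with no length, so tqdm needs an explicit total to render a determinate bar. A minimal, self-contained sketch of the same pattern (toy data, not the benchmark corpus):

import pandas as pd
from tqdm import tqdm

df = pd.DataFrame({'doc_id': ['a', 'b'], 'text': ['alpha', 'beta']})

# iterrows() exposes no __len__, so pass total=df.shape[0] explicitly;
# otherwise tqdm shows only a running count, with no percentage or ETA.
for _, item in tqdm(df.iterrows(), total=df.shape[0],
                    desc="Annotating with Gilda"):
    _ = item['text'].upper()  # stand-in for the real annotation call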
@@ -311,23 +314,23 @@ def annotate_entities_with_gilda(self):

self.gilda_annotations_map[(doc_id, figure)].append(annotation)

if doc_id == '3868508' and figure == 'Figure_1-A':
print(f"Scored NER Match: {annotation}")
print(f"Annotated Text Segment: "
f"{text[annotation.start:annotation.end]} at "
f"indices {annotation.start} to {annotation.end}")
for i, scored_match in enumerate(annotation.matches):
print(f"Scored Match {i + 1}: {scored_match}")
print(
f"DB: {scored_match.term.db}, "
f"ID: {scored_match.term.id}")
print(
f"Score: {scored_match.score}, "
f"Match: {scored_match.match}")
print("\n")
# if doc_id == '3868508' and figure == 'Figure_1-A':
# tqdm.write(f"Scored NER Match: {annotation}")
# tqdm.write(f"Annotated Text Segment: "
# f"{text[annotation.start:annotation.end]} at "
# f"indices {annotation.start} to {annotation.end}")
# for i, scored_match in enumerate(annotation.matches):
# tqdm.write(f"Scored Match {i + 1}: {scored_match}")
# tqdm.write(
# f"DB: {scored_match.term.db}, "
# f"ID: {scored_match.term.id}")
# tqdm.write(
# f"Score: {scored_match.score}, "
# f"Match: {scored_match.match}")
# tqdm.write("\n")

tqdm.write("Finished annotating corpus with Gilda...")
print(f"Total Gilda annotations: {total_gilda_annotations}")
# tqdm.write(f"Total Gilda annotations: {total_gilda_annotations}")
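
The switch from print to tqdm.write in this method is not cosmetic: a bare print inside an active tqdm loop interleaves with the bar's carriage-return redraws and garbles the terminal, while tqdm.write clears the bar, emits the line, and redraws the bar below it. A minimal sketch:

from time import sleep
from tqdm import tqdm

for i in tqdm(range(3), desc="Demo"):
    sleep(0.1)
    # print(f"step {i}") would break the bar's line in most terminals;
    # tqdm.write() prints the message above the bar and leaves it intact.
    tqdm.write(f"step {i}")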

def evaluate_gilda_performance(self):
"""Calculates precision, recall, and F1"""
@@ -346,9 +349,11 @@ def evaluate_gilda_performance(self):
row['first left'], row['last right'])
ref_dict[key].append((set(row['obj']), row['obj_synonyms']))

print(f"Total reference annotations: {len(ref_dict)}")
# print(f"Total reference annotations: {len(ref_dict)}")

for (doc_id, figure), annotations in self.gilda_annotations_map.items():
for (doc_id, figure), annotations in (
tqdm(self.gilda_annotations_map.items(),
desc="Evaluating Annotations")):
for annotation in annotations:
key = (doc_id, figure, annotation.text, annotation.start,
annotation.end)
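
The evaluation is driven by exact-span lookup: reference annotations are bucketed in a defaultdict(list) keyed by (doc_id, figure, text, start, end), so each Gilda annotation finds its candidate references in constant time instead of scanning. A minimal sketch of the pattern (identifiers and offsets invented):

from collections import defaultdict

ref_dict = defaultdict(list)
# One reference annotation: 'TP53' at offsets 10-14 of a figure caption,
# stored as (curie set, synonym groundings), mirroring the code above.
ref_dict[('doc1', 'Figure_1', 'TP53', 10, 14)].append(({'HGNC:11998'}, set()))

# A prediction matches only if document, figure, surface text, and both
# offsets agree exactly with a reference entry.
key = ('doc1', 'Figure_1', 'TP53', 10, 14)
for curies, synonyms in ref_dict.get(key, []):
    print('candidate groundings:', curies)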
@@ -366,14 +371,13 @@
match_found = True
break

if match_found:
if doc_id == '3868508' and figure == "Figure_1-A":
print(f"Gilda Annotation: {annotation}")
# print(f"Reference Annotations: {ref_annotations}")
print(f"Match Found: {match_found}")
print(f"Matching Reference: {matching_refs}")
# if match_found:
# if doc_id == '3868508' and figure == "Figure_1-A":
# print(f"Gilda Annotation: {annotation}")
# print(f"Match Found: {match_found}")
# print(f"Matching Reference: {matching_refs}")

break
# break

if match_found:
break
@@ -384,11 +388,12 @@ def evaluate_gilda_performance(self):
if annotation.matches: # Check if there are any matches
metrics['top_match']['fp'] += 1

print(f"20 Most Common False Positives: "
f"{false_positives_counter.most_common(20)}")
# print(f"20 Most Common False Positives: "
# f"{false_positives_counter.most_common(20)}")

# False negative calculation using ref dict
for key, refs in ref_dict.items():
for key, refs in tqdm(ref_dict.items(),
desc="Calculating False Negatives"):
doc_id, figure = key[0], key[1]
gilda_annotations = self.gilda_annotations_map.get((doc_id, figure),
[])
@@ -424,7 +429,6 @@ def evaluate_gilda_performance(self):
'f1': f1
}
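
For reference, the precision, recall, and f1 values assembled above follow the standard definitions; a self-contained check of the arithmetic (counts invented, not benchmark output):

tp, fp, fn = 40, 10, 20  # invented counts

precision = tp / (tp + fp) if tp + fp else 0.0  # 0.8
recall = tp / (tp + fn) if tp + fn else 0.0     # ~0.667
f1 = (2 * precision * recall / (precision + recall)
      if precision + recall else 0.0)           # ~0.727
print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")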


counts_table = pd.DataFrame([
{
'Match Type': 'All Matches',
@@ -462,7 +466,8 @@ def evaluate_gilda_performance(self):
false_positives_df = pd.DataFrame(false_positives_counter.items(),
columns=['False Positive Text',
'Count'])
false_positives_df = false_positives_df.sort_values(by='Count', ascending=False)
false_positives_df = false_positives_df.sort_values(by='Count',
ascending=False)
false_positives_df.to_csv(
os.path.join(RESULTS_DIR, 'false_positives.csv'), index=False)
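
The false-positive report above goes Counter -> DataFrame -> sorted CSV. A compact sketch of that pipeline (strings and output path invented):

from collections import Counter
import pandas as pd

false_positives_counter = Counter(['actin', 'actin', 'media', 'cell'])

fp_df = pd.DataFrame(false_positives_counter.items(),
                     columns=['False Positive Text', 'Count'])
fp_df = fp_df.sort_values(by='Count', ascending=False)
fp_df.to_csv('false_positives.csv', index=False)  # illustrative path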

@@ -489,25 +494,26 @@ def get_famplex_members():
fplx_members = get_famplex_members()


# def main(results: str):
def main():
# results_path = os.path.expandvars(os.path.expanduser(results))
# os.makedirs(results_path, exist_ok=True)
def main(results: str = RESULTS_DIR):
results_path = os.path.expandvars(os.path.expanduser(results))
os.makedirs(results_path, exist_ok=True)

benchmarker = BioIDNERBenchmarker()
benchmarker.annotate_entities_with_gilda()
benchmarker.evaluate_gilda_performance()
counts, precision_recall = benchmarker.get_results_tables()
print("Counts table:")

print(f"Counts Table:")
print(counts.to_markdown(index=False))
print("Precision and Recall table:")
print(f"Precision and Recall table: ")
print(precision_recall.to_markdown(index=False))
# time = datetime.now().strftime('%y%m%d-%H%M%S')
# result_stub = pathlib.Path(results_path).joinpath(f'benchmark_{time}')
# counts.to_csv(result_stub.with_suffix(".counts.csv"), index=False)
# precision_recall.to_csv(result_stub.with_suffix(".precision_recall.csv"),
# index=False)
# print(f'Results saved to {results_path}')

time = datetime.now().strftime('%y%m%d-%H%M%S')
result_stub = pathlib.Path(results_path).joinpath(f'benchmark_{time}')
counts.to_csv(result_stub.with_suffix(".counts.csv"), index=False)
precision_recall.to_csv(result_stub.with_suffix(".precision_recall.csv"),
index=False)
print(f'Results saved to {results_path}')


if __name__ == '__main__':
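
The re-enabled block in main() uses a common results-persistence pattern: expand the output directory, timestamp a filename stub, and derive sibling CSVs with with_suffix(). A standalone sketch (the directory is illustrative); since the stub has no extension and the timestamp contains no dots, with_suffix() appends rather than clobbering anything:

import os
import pathlib
from datetime import datetime

results_path = os.path.expandvars(os.path.expanduser('~/gilda_results'))
os.makedirs(results_path, exist_ok=True)

time = datetime.now().strftime('%y%m%d-%H%M%S')
result_stub = pathlib.Path(results_path).joinpath(f'benchmark_{time}')
print(result_stub.with_suffix('.counts.csv'))
print(result_stub.with_suffix('.precision_recall.csv'))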
5 changes: 3 additions & 2 deletions gilda/ner.py
@@ -150,8 +150,9 @@ def annotate(
spaces = ' ' * (c[0] - len(raw_span) -
raw_word_coords[idx][0])
raw_span += spaces + rw
# if len(txt_span) <= 1:
# continue

if len(raw_span) <= 1:
continue
context = text if context_text is None else context_text
matches = grounder.ground(raw_span,
context=context,
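
The substantive fix in gilda/ner.py: the old guard on txt_span stays commented out and a live check on raw_span takes its place, so candidate spans of one character or less are skipped before grounding; single-letter strings are rarely unambiguous entity mentions and were a source of spurious matches. A toy sketch of the guard in isolation (in the real code raw_span is reassembled from raw_word_coords in the source text):

# Candidate surface strings, invented for illustration.
candidates = ['p', 'IL-2', 'A', 'TP53']

kept = []
for raw_span in candidates:
    if len(raw_span) <= 1:  # drop one-character spans before grounding
        continue
    kept.append(raw_span)

print(kept)  # ['IL-2', 'TP53']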
