
Commit

added a method to introduce batching with merged clusters #3
DavidFromPandora committed Apr 3, 2022
1 parent 29faf81 commit b6322dd
Showing 10 changed files with 1,609 additions and 253 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -131,3 +131,5 @@ dmypy.json
*.tar.gz
/crosslingual_coreference/models
test.py
/batching.ipynb
/test.py
69 changes: 49 additions & 20 deletions README.md
@@ -13,45 +13,74 @@ pip install crosslingual-coreference
```python
from crosslingual_coreference import Predictor

text = "Do not forget about Momofuku Ando! He created instant noodles in Osaka. At that location, Nissin was founded. Many students survived by eating these noodles, but they don't even know him."
text = (
"Do not forget about Momofuku Ando! He created instant noodles in Osaka. At"
" that location, Nissin was founded. Many students survived by eating these"
" noodles, but they don't even know him."
)

predictor = Predictor(
language="en_core_web_sm", device=-1, model_name="info_xlm"
)

print(predictor.predict(text)["resolved_text"])
# Output
#
# Do not forget about Momofuku Ando!
# Momofuku Ando created instant noodles in Osaka.
# At Osaka, Nissin was founded.
# Many students survived by eating instant noodles,
# but Many students don't even know Momofuku Ando.
```
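In the call above, `device=-1` selects CPU inference (a non-negative value is typically interpreted as a GPU index), and `model_name="info_xlm"` picks one of the models listed under Available models below.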
![](https://raw.githubusercontent.com/Pandora-Intelligence/crosslingual-coreference/master/img/example_en.png)

## Chunking/batching to resolve out-of-memory (OOM) errors

```python
from crosslingual_coreference import Predictor

predictor = Predictor(
language="en_core_web_sm",
device=0,
model_name="info_xlm",
chunk_size=2500,
chunk_overlap=2,
)
```
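The `chunk_size`/`chunk_overlap` parameters suggest that long inputs are split into overlapping chunks whose clusters are then merged (the commit title above mentions merged clusters). A minimal sketch of driving such a predictor, where `long_text` is an invented stand-in for a document too large to resolve in one pass:

```python
# Hypothetical oversized input; any long document would do.
long_text = " ".join(
    ["Do not forget about Momofuku Ando! He created instant noodles in Osaka."]
    * 500
)

# Same predict() API as in the first example; chunking happens internally.
print(predictor.predict(long_text)["resolved_text"][:200])
```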

## Use as a spaCy pipeline component
```python
import spacy

import crosslingual_coreference

text = (
"Do not forget about Momofuku Ando! He created instant noodles in Osaka. At"
" that location, Nissin was founded. Many students survived by eating these"
" noodles, but they don't even know him."
)


nlp = spacy.load("en_core_web_sm")
nlp.add_pipe(
"xx_coref", config={"chunk_size": 2500, "chunk_overlap": 2, "device": 0}
)

doc = nlp(text)
print(doc._.coref_clusters)
# Output
#
# [[[4, 5], [7, 7], [27, 27], [36, 36]],
# [[12, 12], [15, 16]],
# [[9, 10], [27, 28]],
# [[22, 23], [31, 31]]]
print(doc._.resolved_text)
# Output
#
# Do not forget about Momofuku Ando!
# Momofuku Ando created instant noodles in Osaka.
# At Osaka, Nissin was founded.
# Many students survived by eating instant noodles,
# but Many students don't even know Momofuku Ando.
```
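Each cluster is a list of `[start, end]` token spans into the `Doc`, end-inclusive (the resolver below slices `doc[span[0] : span[1] + 1]`). A small sketch, reusing `doc` from the example above, that maps the indices back to surface text:

```python
# Print the mention texts behind each coreference cluster.
# Spans are [start, end] token indices with an inclusive end.
for cluster in doc._.coref_clusters:
    print([doc[start : end + 1].text for start, end in cluster])
# First cluster, roughly: ['Momofuku Ando', 'He', ...]
```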
## Available models
57 changes: 40 additions & 17 deletions crosslingual_coreference/CorefResolver.py
@@ -5,22 +5,27 @@


class CorefResolver(object):
""" a class that implements the logic from
https://towardsdatascience.com/how-to-make-an-effective-coreference-resolution-model-55875d2b5f19"""
"""a class that implements the logic from
https://towardsdatascience.com/how-to-make-an-effective-coreference-resolution-model-55875d2b5f19"""

def __init__(self) -> None:
pass

@staticmethod
def core_logic_part(
document: Doc, coref: List[int], resolved: List[str], mention_span: Span
):
final_token = document[coref[1]]
final_token_tag = str(final_token.tag_).lower()
test_token_test = False
for option in ["PRP$", "POS", 'BEZ']:
for option in ["PRP$", "POS", "BEZ"]:
if option.lower() in final_token_tag:
test_token_test = True
break
if test_token_test:
resolved[coref[0]] = (
mention_span.text + "'s" + final_token.whitespace_
)
else:
resolved[coref[0]] = mention_span.text + final_token.whitespace_
for i in range(coref[0] + 1, coref[1] + 1):
@@ -29,34 +34,52 @@ def core_logic_part(document: Doc, coref: List[int], resolved: List[str], mention_span: Span):

@staticmethod
def get_span_noun_indices(doc: Doc, cluster: List[List[int]]) -> List[int]:
spans = [doc[span[0] : span[1] + 1] for span in cluster]
spans_pos = [[token.pos_ for token in span] for span in spans]
span_noun_indices = [
i
for i, span_pos in enumerate(spans_pos)
if any(pos in span_pos for pos in ["NOUN", "PROPN"])
]
return span_noun_indices

@staticmethod
def get_cluster_head(
doc: Doc, cluster: List[List[int]], noun_indices: List[int]
):
head_idx = noun_indices[0]
head_start, head_end = cluster[head_idx]
head_span = doc[head_start : head_end + 1]
return head_span, [head_start, head_end]

@staticmethod
def is_containing_other_spans(span: List[int], all_spans: List[List[int]]):
return any(
[
s[0] >= span[0] and s[1] <= span[1] and s != span
for s in all_spans
]
)

def replace_corefs(self, document, clusters):
resolved = list(tok.text_with_ws for tok in document)
all_spans = [
span for cluster in clusters for span in cluster
] # flattened list of all spans

for cluster in clusters:
noun_indices = self.get_span_noun_indices(document, cluster)

if noun_indices:
mention_span, mention = self.get_cluster_head(
document, cluster, noun_indices
)

for coref in cluster:
if coref != mention and not self.is_containing_other_spans(
coref, all_spans
):
self.core_logic_part(
document, coref, resolved, mention_span
)
return "".join(resolved)