Feature/phrase grounding recall filter (#139)
* add phrase_grounding_recall_filter op
* update new OP + make llava_to_dj and dj_to_llava support only_keep_caption mode
* Add unit test for phrase_grounding
* remove hf model caches automatically for unittest
* download required nltk data when initializing the phrase_grounding_recall_filter
* output the cleaning log when the cleaning actually happens
* update Operator docs
* fix some typos
* remove hf models automatically after unit test is finished for clip and blip
Showing 15 changed files with 957 additions and 154 deletions.
286 changes: 286 additions & 0 deletions
data_juicer/ops/filter/phrase_grounding_recall_filter.py
from typing import List

import numpy as np
from jsonargparse.typing import ClosedUnitInterval
from loguru import logger
from PIL import ImageOps

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import (SpecialTokens, iou, load_image,
                                        remove_special_tokens)
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, Filter
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'phrase_grounding_recall_filter'

with AvailabilityChecking(['torch', 'transformers', 'nltk'], OP_NAME):

    import torch
    import transformers  # noqa: F401

    # avoid hanging when calling clip in multiprocessing
    torch.set_num_threads(1)

    import nltk


# NER algorithm adapted from GLIP starts
# https://github.com/microsoft/GLIP/blob/main/maskrcnn_benchmark/engine/predictor_glip.py#L107-L127
def find_noun_phrases(caption: str) -> List[str]:
    caption = caption.lower()
    tokens = nltk.word_tokenize(caption)
    pos_tags = nltk.pos_tag(tokens)

    grammar = 'NP: {<DT>?<JJ.*>*<NN.*>+}'
    cp = nltk.RegexpParser(grammar)
    result = cp.parse(pos_tags)

    noun_phrases = list()
    for subtree in result.subtrees():
        if subtree.label() == 'NP':
            noun_phrases.append(' '.join(t[0] for t in subtree.leaves()))

    return noun_phrases


def remove_punctuation(text: str) -> str:
    punct = [
        '|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^', '\'', '\"', '’',
        '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.'
    ]
    for p in punct:
        text = text.replace(p, '')
    return text.strip()


def run_ner(caption):
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
    noun_phrases = list(set(noun_phrases))  # remove duplicate ners
    return noun_phrases


# NER algorithm adapted from GLIP ends
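# Illustrative example (hypothetical): with the NP grammar above, a caption
# like 'a cat sits on the wooden table' is tagged DT NN VBZ IN DT JJ NN, so
# run_ner() returns noun phrases like ['a cat', 'the wooden table'];
# deduplication goes through a set, so the order of the phrases may vary.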

@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class PhraseGroundingRecallFilter(Filter):
    """Filter to keep samples whose phrase grounding recall (the fraction
    of noun phrases extracted from the text that can be located in the
    corresponding images) is within a specified range."""

    def __init__(self,
                 hf_owlvit='google/owlvit-base-patch32',
                 min_recall: ClosedUnitInterval = 0.1,
                 max_recall: ClosedUnitInterval = 1.0,
                 horizontal_flip: bool = False,
                 vertical_flip: bool = False,
                 any_or_all: str = 'any',
                 reduce_mode: str = 'avg',
                 iou_thr: ClosedUnitInterval = 0.5,
                 large_area_ratio_thr: ClosedUnitInterval = 0.95,
                 conf_thr: ClosedUnitInterval = 0.0,
                 *args,
                 **kwargs):
        """
        Initialization method.
        :param hf_owlvit: Owl-ViT model name on huggingface to locate the
            phrases extracted from the text.
        :param min_recall: The min phrase grounding recall to keep samples.
        :param max_recall: The max phrase grounding recall to keep samples.
        :param horizontal_flip: Flip images horizontally (left to right).
        :param vertical_flip: Flip images vertically (top to bottom).
        :param any_or_all: Keep this sample with 'any' or 'all' strategy of
            all images. 'any': keep this sample if any images meet the
            condition. 'all': keep this sample only if all images meet the
            condition.
        :param reduce_mode: reduce mode when one text corresponds to
            multiple images in a chunk.
            'avg': Take the average of multiple values
            'max': Take the max of multiple values
            'min': Take the min of multiple values
        :param iou_thr: the IoU threshold for the NMS-like post-process. If
            two predicted bboxes overlap with an IoU larger than this
            threshold, the bbox with the lower confidence will be removed.
            Default: 0.5.
        :param large_area_ratio_thr: the area ratio threshold for filtering
            out large predicted bboxes. If the area of a predicted bbox
            accounts for more than this ratio of the whole image area, this
            bbox will be removed. Default: 0.95.
        :param conf_thr: the confidence score threshold for removing
            low-confidence bboxes. If the confidence score of a predicted
            bbox is lower than this threshold, the bbox will be removed.
            Default: 0.0.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self.min_recall = min_recall
        self.max_recall = max_recall
        if reduce_mode not in ['avg', 'max', 'min']:
            raise ValueError(f'Reduce mode [{reduce_mode}] is not supported. '
                             f'Can only be one of ["avg", "max", "min"].')
        if any_or_all not in ['any', 'all']:
            raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                             f'Can only be one of ["any", "all"].')
        self.any = (any_or_all == 'any')
        self.model_type = 'hf_owlvit'
        self.model_key = prepare_model(model_type=self.model_type,
                                       model_key=hf_owlvit)
        self.reduce_mode = reduce_mode
        self.horizontal_flip = horizontal_flip
        self.vertical_flip = vertical_flip

        self.iou_thr = iou_thr
        self.large_area_ratio_thr = large_area_ratio_thr
        self.conf_thr = conf_thr

        requires_nltk_data = ['punkt', 'averaged_perceptron_tagger']
        logger.info(f'Downloading nltk data of {requires_nltk_data}...')
        for nltk_data_pkg in requires_nltk_data:
            nltk.download(nltk_data_pkg)

    def compute_stats(self, sample, context=False):
        # check if it's computed already
        if StatsKeys.phrase_grounding_recall in sample[Fields.stats]:
            return sample

        # there is no image in this sample
        if self.image_key not in sample or not sample[self.image_key]:
            sample[Fields.stats][StatsKeys.phrase_grounding_recall] = np.array(
                [], dtype=np.float64)
            return sample

        # load images
        loaded_image_keys = sample[self.image_key]
        images = {}
        for loaded_image_key in loaded_image_keys:
            if context and loaded_image_key in sample[Fields.context]:
                # load from context
                images[loaded_image_key] = sample[
                    Fields.context][loaded_image_key]
            else:
                if loaded_image_key not in images:
                    # avoid loading the same images
                    image = load_image(loaded_image_key)
                    images[loaded_image_key] = image
                    if context:
                        # store the image data into context
                        sample[Fields.context][loaded_image_key] = image

        text = sample[self.text_key]
        offset = 0
        recalls = []
        model, processor = get_model(self.model_key,
                                     model_type=self.model_type)

        for chunk in text.split(SpecialTokens.eoc):
            count = chunk.count(SpecialTokens.image)

            # no image or no text
            if count == 0 or len(chunk) == 0:
                continue
            else:
                text_this_chunk = remove_special_tokens(chunk)
                ners_this_chunk = run_ner(text_this_chunk)
                num_ners = len(ners_this_chunk)
                if num_ners <= 0:
                    # no ners found, just skip this chunk
                    recalls.append(1.0)
                    continue
                images_this_chunk = []
                for image_key in loaded_image_keys[offset:offset + count]:
                    image = images[image_key]
                    if self.horizontal_flip:
                        image = ImageOps.mirror(image)
                    if self.vertical_flip:
                        image = ImageOps.flip(image)
                    images_this_chunk.append(image)

                ners_batch = [ners_this_chunk] * len(images_this_chunk)
                inputs = processor(text=ners_batch,
                                   images=images_this_chunk,
                                   return_tensors='pt',
                                   padding=True,
                                   truncation=True)

                with torch.no_grad():
                    outputs = model(**inputs)
                    target_sizes = torch.tensor(
                        [img.size[::-1] for img in images_this_chunk])
                    results = processor.post_process_object_detection(
                        outputs,
                        threshold=self.conf_thr,
                        target_sizes=target_sizes)

                image_recalls = []
                for idx, result in enumerate(results):
                    scores = result['scores']
                    labels = result['labels']
                    boxes = result['boxes']

                    # sort by the confidence scores
                    # and only keep the first num_ners predictions
                    order_idx = scores.argsort(descending=True)
                    scores = scores[order_idx].tolist()[:num_ners]
                    labels = labels[order_idx].tolist()[:num_ners]
                    boxes = boxes[order_idx].tolist()[:num_ners]

                    image_area = target_sizes[idx].prod()
                    hit = {}
                    for box, label, score in zip(boxes, labels, scores):
                        # this ner is already hit
                        if ners_this_chunk[label] in hit:
                            continue
                        # skip boxes that nearly cover the whole image
                        xmin, ymin, xmax, ymax = box
                        box_area = (xmax - xmin) * (ymax - ymin)
                        if 1.0 * box_area / image_area > \
                                self.large_area_ratio_thr:
                            continue
                        # skip overlapped boxes with an nms-like method
                        suppressed = False
                        for ner in hit:
                            if iou(box, hit[ner][0]) > self.iou_thr:
                                suppressed = True
                                break
                        if suppressed:
                            continue

                        # record the new hit box
                        hit[ners_this_chunk[label]] = (box, score)

                    recall = 1.0 * len(hit) / num_ners
                    image_recalls.append(recall)

                if self.reduce_mode == 'avg':
                    image_recall = sum(image_recalls) / len(image_recalls)
                elif self.reduce_mode == 'max':
                    image_recall = max(image_recalls)
                else:
                    image_recall = min(image_recalls)

                recalls.append(image_recall)
                offset += count
        sample[Fields.stats][StatsKeys.phrase_grounding_recall] = recalls

        return sample

    def process(self, sample):
        recalls = sample[Fields.stats][StatsKeys.phrase_grounding_recall]
        if len(recalls) <= 0:
            return True

        keep_bools = np.array([
            self.min_recall <= recall <= self.max_recall for recall in recalls
        ])

        # different strategies
        if self.any:
            return keep_bools.any()
        else:
            return keep_bools.all()
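A minimal usage sketch of the new filter (illustrative only: 'text' and 'images' are assumed to be the default sample keys of the Filter base class, and the image path is hypothetical):

op = PhraseGroundingRecallFilter(hf_owlvit='google/owlvit-base-patch32',
                                 min_recall=0.3,
                                 max_recall=1.0)
sample = {
    'text': f'{SpecialTokens.image} a cat sits on the wooden table'
            f'{SpecialTokens.eoc}',
    'images': ['path/to/image.jpg'],  # hypothetical path
    Fields.stats: {},
}
sample = op.compute_stats(sample)
print(op.process(sample))  # True if the recall falls within [0.3, 1.0]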