OmDet Turbo processor standardization (huggingface#34937)
* Fix docstring

* Fix docstring

* Add `classes_structure` to model output

* Update omdet postprocessing

* Adjust tests

* Update code example in docs

* Add deprecation to "classes" key in output

* Types, docs

* Fixing test

* Fix missed clip_boxes

* [run-slow] omdet_turbo

* Apply suggestions from code review

Co-authored-by: Yoni Gozlan <[email protected]>

* Make CamelCase class

---------

Co-authored-by: Yoni Gozlan <[email protected]>
qubvel and yonigozlan authored Jan 17, 2025
1 parent 94ae9a8 commit 42b2857
Showing 5 changed files with 254 additions and 188 deletions.
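
One of the commit-message items above deprecates the `"classes"` key in the post-processed output in favor of `"text_labels"`. A minimal sketch of how such a key rename can stay backward compatible is shown below; the helper class and warning text are illustrative, not the code from this commit.

```python
import warnings


class RenamedKeyDict(dict):
    """Dict that warns when a renamed key is accessed under its old name (illustrative only)."""

    # hypothetical mapping of deprecated key -> new key
    _renamed = {"classes": "text_labels"}

    def __getitem__(self, key):
        if key in self._renamed:
            new_key = self._renamed[key]
            warnings.warn(
                f"The '{key}' key is deprecated and will be removed; use '{new_key}' instead.",
                FutureWarning,
            )
            key = new_key
        return super().__getitem__(key)


result = RenamedKeyDict({"text_labels": ["cat"], "scores": [0.72], "boxes": [[11.6, 54.2, 314.8, 474.0]]})
print(result["classes"])  # emits a FutureWarning, then returns the "text_labels" value
```
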
87 changes: 45 additions & 42 deletions docs/source/en/model_doc/omdet-turbo.md
@@ -44,37 +44,40 @@ One unique property of OmDet-Turbo compared to other zero-shot object detection
Here's how to load the model and prepare the inputs to perform zero-shot object detection on a single image:

```diff
-import requests
-from PIL import Image
-
-from transformers import AutoProcessor, OmDetTurboForObjectDetection
-
-processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
-model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
-
-url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-image = Image.open(requests.get(url, stream=True).raw)
-classes = ["cat", "remote"]
-inputs = processor(image, text=classes, return_tensors="pt")
-
-outputs = model(**inputs)
-
-# convert outputs (bounding boxes and class logits)
-results = processor.post_process_grounded_object_detection(
-    outputs,
-    classes=classes,
-    target_sizes=[image.size[::-1]],
-    score_threshold=0.3,
-    nms_threshold=0.3,
-)[0]
-for score, class_name, box in zip(
-    results["scores"], results["classes"], results["boxes"]
-):
-    box = [round(i, 1) for i in box.tolist()]
-    print(
-        f"Detected {class_name} with confidence "
-        f"{round(score.item(), 2)} at location {box}"
-    )
+>>> import torch
+>>> import requests
+>>> from PIL import Image
+
+>>> from transformers import AutoProcessor, OmDetTurboForObjectDetection
+
+>>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
+>>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
+
+>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+>>> image = Image.open(requests.get(url, stream=True).raw)
+>>> text_labels = ["cat", "remote"]
+>>> inputs = processor(image, text=text_labels, return_tensors="pt")
+
+>>> with torch.no_grad():
+...     outputs = model(**inputs)
+
+>>> # convert outputs (bounding boxes and class logits)
+>>> results = processor.post_process_grounded_object_detection(
+...     outputs,
+...     target_sizes=[(image.height, image.width)],
+...     text_labels=text_labels,
+...     threshold=0.3,
+...     nms_threshold=0.3,
+... )
+>>> result = results[0]
+>>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"]
+>>> for box, score, text_label in zip(boxes, scores, text_labels):
+...     box = [round(i, 2) for i in box.tolist()]
+...     print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
+Detected remote with confidence 0.768 at location [39.89, 70.35, 176.74, 118.04]
+Detected cat with confidence 0.72 at location [11.6, 54.19, 314.8, 473.95]
+Detected remote with confidence 0.563 at location [333.38, 75.77, 370.7, 187.03]
+Detected cat with confidence 0.552 at location [345.15, 23.95, 639.75, 371.67]
```
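
The updated example filters predictions with `threshold` and then prunes overlapping boxes of the same label with `nms_threshold`. The sketch below shows roughly what that combination amounts to, using `torchvision.ops.batched_nms` on toy tensors; the processor's actual post-processing may differ in details.

```python
import torch
from torchvision.ops import batched_nms

# toy predictions: absolute (x1, y1, x2, y2) boxes, confidence scores, integer label ids
boxes = torch.tensor([[10.0, 10.0, 100.0, 100.0], [12.0, 12.0, 98.0, 99.0], [200.0, 50.0, 300.0, 150.0]])
scores = torch.tensor([0.9, 0.6, 0.2])
labels = torch.tensor([0, 0, 1])

threshold, nms_threshold = 0.3, 0.3

# 1) drop low-confidence predictions
keep = scores > threshold
boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

# 2) label-wise NMS: overlapping boxes with the same label collapse to the highest-scoring one
kept = batched_nms(boxes, scores, labels, iou_threshold=nms_threshold)
print(boxes[kept], scores[kept], labels[kept])
```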

### Multi image inference
@@ -93,22 +96,22 @@ OmDet-Turbo can perform batched multi-image inference, with support for differen

```diff
 >>> url1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
 >>> image1 = Image.open(BytesIO(requests.get(url1).content)).convert("RGB")
->>> classes1 = ["cat", "remote"]
->>> task1 = "Detect {}.".format(", ".join(classes1))
+>>> text_labels1 = ["cat", "remote"]
+>>> task1 = "Detect {}.".format(", ".join(text_labels1))
 
 >>> url2 = "http://images.cocodataset.org/train2017/000000257813.jpg"
 >>> image2 = Image.open(BytesIO(requests.get(url2).content)).convert("RGB")
->>> classes2 = ["boat"]
+>>> text_labels2 = ["boat"]
 >>> task2 = "Detect everything that looks like a boat."
 
 >>> url3 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
 >>> image3 = Image.open(BytesIO(requests.get(url3).content)).convert("RGB")
->>> classes3 = ["statue", "trees"]
+>>> text_labels3 = ["statue", "trees"]
 >>> task3 = "Focus on the foreground, detect statue and trees."
 
 >>> inputs = processor(
 ...     images=[image1, image2, image3],
-...     text=[classes1, classes2, classes3],
+...     text=[text_labels1, text_labels2, text_labels3],
 ...     task=[task1, task2, task3],
 ...     return_tensors="pt",
 ... )
```
@@ -119,19 +122,19 @@ OmDet-Turbo can perform batched multi-image inference, with support for differen
```diff
 >>> # convert outputs (bounding boxes and class logits)
 >>> results = processor.post_process_grounded_object_detection(
 ...     outputs,
-...     classes=[classes1, classes2, classes3],
-...     target_sizes=[image1.size[::-1], image2.size[::-1], image3.size[::-1]],
-...     score_threshold=0.2,
+...     text_labels=[text_labels1, text_labels2, text_labels3],
+...     target_sizes=[(image.height, image.width) for image in [image1, image2, image3]],
+...     threshold=0.2,
 ...     nms_threshold=0.3,
 ... )
 
 >>> for i, result in enumerate(results):
-...     for score, class_name, box in zip(
-...         result["scores"], result["classes"], result["boxes"]
+...     for score, text_label, box in zip(
+...         result["scores"], result["text_labels"], result["boxes"]
 ...     ):
 ...         box = [round(i, 1) for i in box.tolist()]
 ...         print(
-...             f"Detected {class_name} with confidence "
+...             f"Detected {text_label} with confidence "
 ...             f"{round(score.item(), 2)} at location {box} in image {i}"
 ...         )
 Detected remote with confidence 0.77 at location [39.9, 70.4, 176.7, 118.0] in image 0
```
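
Both examples pass `target_sizes` as `(height, width)` pairs, which is what post-processing uses to map normalized boxes back to pixel coordinates. Assuming the usual DETR-style normalized `(center_x, center_y, width, height)` box format, the rescaling step looks roughly like this sketch (not the literal library code):

```python
import torch


def rescale_boxes(norm_cxcywh: torch.Tensor, target_size: tuple) -> torch.Tensor:
    """Convert normalized (cx, cy, w, h) boxes to absolute (x1, y1, x2, y2) for one image."""
    height, width = target_size  # note the (height, width) order, as in the examples above
    cx, cy, w, h = norm_cxcywh.unbind(-1)
    corners = torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)
    scale = torch.tensor([width, height, width, height], dtype=corners.dtype)
    return corners * scale


print(rescale_boxes(torch.tensor([[0.5, 0.5, 0.2, 0.4]]), target_size=(480, 640)))
# tensor([[256., 144., 384., 336.]])
```
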
33 changes: 19 additions & 14 deletions src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
@@ -143,22 +143,24 @@ class OmDetTurboObjectDetectionOutput(ModelOutput):
```diff
             The predicted class of the objects from the encoder.
         encoder_extracted_states (`torch.FloatTensor`):
             The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder.
-        decoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`):
+        decoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
             `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
             plus the initial embedding outputs.
-        decoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`):
+        decoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*):
             Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
             sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
             weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
-        encoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`):
+        encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape
             `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
             plus the initial embedding outputs.
-        encoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`):
+        encoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*):
             Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads,
             sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the
             weighted average in the self-attention, cross-attention and multi-scale deformable attention heads.
+        classes_structure (`torch.LongTensor`, *optional*):
+            The number of queried classes for each image.
     """
 
     loss: torch.FloatTensor = None
```
@@ -173,6 +175,7 @@ class OmDetTurboObjectDetectionOutput(ModelOutput):
```diff
     decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    classes_structure: Optional[torch.LongTensor] = None
 
 
 # Copied from models.deformable_detr.load_cuda_kernels
```
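
`classes_structure` records how many class prompts were queried for each image, so outputs that are stacked over all class prompts in a batch can be split back into per-image chunks. A small, self-contained sketch of that bookkeeping (not the model's own code):

```python
import torch

# suppose three images were queried with 2, 1, and 2 class prompts respectively
classes_structure = torch.tensor([2, 1, 2])

# a flat list of class prompts, concatenated over the whole batch
flat_labels = ["cat", "remote", "boat", "statue", "trees"]

# split the flat list back into one chunk per image
per_image, start = [], 0
for count in classes_structure.tolist():
    per_image.append(flat_labels[start : start + count])
    start += count

print(per_image)  # [['cat', 'remote'], ['boat'], ['statue', 'trees']]
```
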
@@ -1667,16 +1670,16 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
```diff
     @replace_return_docstrings(output_type=OmDetTurboObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values: Tensor,
-        classes_input_ids: Tensor,
-        classes_attention_mask: Tensor,
-        tasks_input_ids: Tensor,
-        tasks_attention_mask: Tensor,
-        classes_structure: Tensor,
-        labels: Optional[Tensor] = None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        pixel_values: torch.FloatTensor,
+        classes_input_ids: torch.LongTensor,
+        classes_attention_mask: torch.LongTensor,
+        tasks_input_ids: torch.LongTensor,
+        tasks_attention_mask: torch.LongTensor,
+        classes_structure: torch.LongTensor,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.FloatTensor], OmDetTurboObjectDetectionOutput]:
         r"""
         Returns:
```
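
The typed signature above mirrors the tensors the processor produces, which is why the documented `model(**inputs)` call lines up argument-for-argument. A quick sanity check, assuming the checkpoint from the docs example is available; the exact key set is an expectation, not verified here:

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
image = Image.new("RGB", (640, 480))  # blank placeholder image; only the produced keys matter here
inputs = processor(image, text=["cat", "remote"], return_tensors="pt")
print(sorted(inputs.keys()))
# expected to include: classes_attention_mask, classes_input_ids, classes_structure,
# pixel_values, tasks_attention_mask, tasks_input_ids
```
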
@@ -1770,6 +1773,7 @@ def forward(
```diff
                 decoder_outputs[2],
                 encoder_outputs[1],
                 encoder_outputs[2],
+                classes_structure,
             ]
             if output is not None
         )
```
@@ -1787,6 +1791,7 @@ def forward(
```diff
             decoder_attentions=decoder_outputs.attentions,
             encoder_hidden_states=encoder_outputs.hidden_states,
             encoder_attentions=encoder_outputs.attentions,
+            classes_structure=classes_structure,
         )
 
 
```
(The remaining 3 changed files in this commit are not shown.)
