diff --git a/docs/source/en/model_doc/omdet-turbo.md b/docs/source/en/model_doc/omdet-turbo.md index 1e9e05a898d230..91419919b6e02c 100644 --- a/docs/source/en/model_doc/omdet-turbo.md +++ b/docs/source/en/model_doc/omdet-turbo.md @@ -44,37 +44,40 @@ One unique property of OmDet-Turbo compared to other zero-shot object detection Here's how to load the model and prepare the inputs to perform zero-shot object detection on a single image: ```python -import requests -from PIL import Image - -from transformers import AutoProcessor, OmDetTurboForObjectDetection - -processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") -model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") - -url = "http://images.cocodataset.org/val2017/000000039769.jpg" -image = Image.open(requests.get(url, stream=True).raw) -classes = ["cat", "remote"] -inputs = processor(image, text=classes, return_tensors="pt") - -outputs = model(**inputs) - -# convert outputs (bounding boxes and class logits) -results = processor.post_process_grounded_object_detection( - outputs, - classes=classes, - target_sizes=[image.size[::-1]], - score_threshold=0.3, - nms_threshold=0.3, -)[0] -for score, class_name, box in zip( - results["scores"], results["classes"], results["boxes"] -): - box = [round(i, 1) for i in box.tolist()] - print( - f"Detected {class_name} with confidence " - f"{round(score.item(), 2)} at location {box}" - ) +>>> import torch +>>> import requests +>>> from PIL import Image + +>>> from transformers import AutoProcessor, OmDetTurboForObjectDetection + +>>> processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") +>>> model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") + +>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" +>>> image = Image.open(requests.get(url, stream=True).raw) +>>> text_labels = ["cat", "remote"] +>>> inputs = processor(image, text=text_labels, return_tensors="pt") + +>>> with torch.no_grad(): +... outputs = model(**inputs) + +>>> # convert outputs (bounding boxes and class logits) +>>> results = processor.post_process_grounded_object_detection( +... outputs, +... target_sizes=[(image.height, image.width)], +... text_labels=text_labels, +... threshold=0.3, +... nms_threshold=0.3, +... ) +>>> result = results[0] +>>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"] +>>> for box, score, text_label in zip(boxes, scores, text_labels): +... box = [round(i, 2) for i in box.tolist()] +... 
+...     print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
+Detected remote with confidence 0.768 at location [39.89, 70.35, 176.74, 118.04]
+Detected cat with confidence 0.72 at location [11.6, 54.19, 314.8, 473.95]
+Detected remote with confidence 0.563 at location [333.38, 75.77, 370.7, 187.03]
+Detected cat with confidence 0.552 at location [345.15, 23.95, 639.75, 371.67]
```

### Multi image inference

@@ -93,22 +96,22 @@ OmDet-Turbo can perform batched multi-image inference, with support for differen

>>> url1 = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image1 = Image.open(BytesIO(requests.get(url1).content)).convert("RGB")
->>> classes1 = ["cat", "remote"]
->>> task1 = "Detect {}.".format(", ".join(classes1))
+>>> text_labels1 = ["cat", "remote"]
+>>> task1 = "Detect {}.".format(", ".join(text_labels1))

>>> url2 = "http://images.cocodataset.org/train2017/000000257813.jpg"
>>> image2 = Image.open(BytesIO(requests.get(url2).content)).convert("RGB")
->>> classes2 = ["boat"]
+>>> text_labels2 = ["boat"]
>>> task2 = "Detect everything that looks like a boat."

>>> url3 = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
>>> image3 = Image.open(BytesIO(requests.get(url3).content)).convert("RGB")
->>> classes3 = ["statue", "trees"]
+>>> text_labels3 = ["statue", "trees"]
>>> task3 = "Focus on the foreground, detect statue and trees."

>>> inputs = processor(
...     images=[image1, image2, image3],
-...     text=[classes1, classes2, classes3],
+...     text=[text_labels1, text_labels2, text_labels3],
...     task=[task1, task2, task3],
...     return_tensors="pt",
... )
@@ -119,19 +122,19 @@ OmDet-Turbo can perform batched multi-image inference, with support for differen
>>> # convert outputs (bounding boxes and class logits)
>>> results = processor.post_process_grounded_object_detection(
...     outputs,
-...     classes=[classes1, classes2, classes3],
-...     target_sizes=[image1.size[::-1], image2.size[::-1], image3.size[::-1]],
-...     score_threshold=0.2,
+...     text_labels=[text_labels1, text_labels2, text_labels3],
+...     target_sizes=[(image.height, image.width) for image in [image1, image2, image3]],
+...     threshold=0.2,
...     nms_threshold=0.3,
... )

>>> for i, result in enumerate(results):
-...     for score, class_name, box in zip(
-...         result["scores"], result["classes"], result["boxes"]
+...     for score, text_label, box in zip(
+...         result["scores"], result["text_labels"], result["boxes"]
...     ):
...         box = [round(i, 1) for i in box.tolist()]
...         print(
-...             f"Detected {class_name} with confidence "
+...             f"Detected {text_label} with confidence "
...             f"{round(score.item(), 2)} at location {box} in image {i}"
...         )
Detected remote with confidence 0.77 at location [39.9, 70.4, 176.7, 118.0] in image 0
diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
index 2680bc714d9980..15d7c1f05916ac 100644
--- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
+++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
@@ -143,22 +143,24 @@ class OmDetTurboObjectDetectionOutput(ModelOutput):
            The predicted class of the objects from the encoder.
        encoder_extracted_states (`torch.FloatTensor`):
            The extracted states from the Feature Pyramid Network (FPN) and Path Aggregation Network (PAN) of the encoder.
- decoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`): + decoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - decoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`): + decoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*): Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention, cross-attention and multi-scale deformable attention heads. - encoder_hidden_states (`Optional[Tuple[torch.FloatTensor]]`): + encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. - encoder_attentions (`Optional[Tuple[Tuple[torch.FloatTensor]]]`): + encoder_attentions (`Tuple[Tuple[torch.FloatTensor]]`, *optional*): Tuple of tuples of `torch.FloatTensor` (one for attention for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention, cross-attention and multi-scale deformable attention heads. + classes_structure (`torch.LongTensor`, *optional*): + The number of queried classes for each image. """ loss: torch.FloatTensor = None @@ -173,6 +175,7 @@ class OmDetTurboObjectDetectionOutput(ModelOutput): decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + classes_structure: Optional[torch.LongTensor] = None # Copied from models.deformable_detr.load_cuda_kernels @@ -1667,16 +1670,16 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m @replace_return_docstrings(output_type=OmDetTurboObjectDetectionOutput, config_class=_CONFIG_FOR_DOC) def forward( self, - pixel_values: Tensor, - classes_input_ids: Tensor, - classes_attention_mask: Tensor, - tasks_input_ids: Tensor, - tasks_attention_mask: Tensor, - classes_structure: Tensor, - labels: Optional[Tensor] = None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + pixel_values: torch.FloatTensor, + classes_input_ids: torch.LongTensor, + classes_attention_mask: torch.LongTensor, + tasks_input_ids: torch.LongTensor, + tasks_attention_mask: torch.LongTensor, + classes_structure: torch.LongTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.FloatTensor], OmDetTurboObjectDetectionOutput]: r""" Returns: @@ -1770,6 +1773,7 @@ def forward( decoder_outputs[2], encoder_outputs[1], encoder_outputs[2], + classes_structure, ] if output is not None ) @@ -1787,6 +1791,7 @@ def forward( decoder_attentions=decoder_outputs.attentions, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, + classes_structure=classes_structure, ) 
diff --git a/src/transformers/models/omdet_turbo/processing_omdet_turbo.py b/src/transformers/models/omdet_turbo/processing_omdet_turbo.py index 7e8d0af8a10d16..f52840e1d0b662 100644 --- a/src/transformers/models/omdet_turbo/processing_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/processing_omdet_turbo.py @@ -16,7 +16,8 @@ Processor class for OmDet-Turbo. """ -from typing import List, Optional, Tuple, Union +import warnings +from typing import TYPE_CHECKING, List, Optional, Tuple, Union from ...feature_extraction_utils import BatchFeature from ...image_transforms import center_to_corners_format @@ -28,12 +29,25 @@ is_torch_available, is_torchvision_available, ) +from ...utils.deprecation import deprecate_kwarg + + +if TYPE_CHECKING: + from .modeling_omdet_turbo import OmDetTurboObjectDetectionOutput class OmDetTurboTextKwargs(TextKwargs, total=False): task: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] +if is_torch_available(): + import torch + + +if is_torchvision_available(): + from torchvision.ops.boxes import batched_nms + + class OmDetTurboProcessorKwargs(ProcessingKwargs, total=False): text_kwargs: OmDetTurboTextKwargs _defaults = { @@ -55,11 +69,23 @@ class OmDetTurboProcessorKwargs(ProcessingKwargs, total=False): } -if is_torch_available(): - import torch +class DictWithDeprecationWarning(dict): + message = ( + "The `classes` key is deprecated for `OmDetTurboProcessor.post_process_grounded_object_detection` " + "output dict and will be removed in a 4.51.0 version. Please use `text_labels` instead." + ) -if is_torchvision_available(): - from torchvision.ops.boxes import batched_nms + def __getitem__(self, key): + if key == "classes": + warnings.warn(self.message, FutureWarning) + return super().__getitem__("text_labels") + return super().__getitem__(key) + + def get(self, key, *args, **kwargs): + if key == "classes": + warnings.warn(self.message, FutureWarning) + return super().get("text_labels", *args, **kwargs) + return super().get(key, *args, **kwargs) def clip_boxes(box, box_size: Tuple[int, int]): @@ -97,76 +123,80 @@ def compute_score(boxes): def _post_process_boxes_for_image( - boxes: TensorType, - scores: TensorType, - predicted_classes: TensorType, - classes: List[str], + boxes: "torch.Tensor", + scores: "torch.Tensor", + labels: "torch.Tensor", + image_num_classes: int, image_size: Tuple[int, int], - num_classes: int, - score_threshold: float, + threshold: float, nms_threshold: float, - max_num_det: int = None, -) -> dict: + max_num_det: Optional[int] = None, +) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: """ Filter predicted results using given thresholds and NMS. + Args: - boxes (torch.Tensor): A Tensor of predicted class-specific or class-agnostic - boxes for the image. Shape : (num_queries, max_num_classes_in_batch * 4) if doing - class-specific regression, or (num_queries, 4) if doing class-agnostic - regression. - scores (torch.Tensor): A Tensor of predicted class scores for the image. - Shape : (num_queries, max_num_classes_in_batch + 1) - predicted_classes (torch.Tensor): A Tensor of predicted classes for the image. - Shape : (num_queries * (max_num_classes_in_batch + 1),) - classes (List[str]): The input classes names. - image_size (tuple): A tuple of (height, width) for the image. - num_classes (int): The number of classes given for this image. - score_threshold (float): Only return detections with a confidence score exceeding this - threshold. - nms_threshold (float): The threshold to use for box non-maximum suppression. 
Value in [0, 1]. - max_num_det (int, optional): The maximum number of detections to return. Default is None. + boxes (`torch.Tensor`): + A Tensor of predicted class-specific or class-agnostic boxes for the image. + Shape (num_queries, max_num_classes_in_batch * 4) if doing class-specific regression, + or (num_queries, 4) if doing class-agnostic regression. + scores (`torch.Tensor` of shape (num_queries, max_num_classes_in_batch + 1)): + A Tensor of predicted class scores for the image. + labels (`torch.Tensor` of shape (num_queries * (max_num_classes_in_batch + 1),)): + A Tensor of predicted labels for the image. + image_num_classes (`int`): + The number of classes queried for detection on the image. + image_size (`Tuple[int, int]`): + A tuple of (height, width) for the image. + threshold (`float`): + Only return detections with a confidence score exceeding this threshold. + nms_threshold (`float`): + The threshold to use for box non-maximum suppression. Value in [0, 1]. + max_num_det (`int`, *optional*): + The maximum number of detections to return. Default is None. + Returns: - dict: A dictionary the following keys: + Tuple: A tuple with the following: "boxes" (Tensor): A tensor of shape (num_filtered_objects, 4), containing the predicted boxes in (x1, y1, x2, y2) format. "scores" (Tensor): A tensor of shape (num_filtered_objects,), containing the predicted confidence scores for each detection. - "classes" (List[str]): A list of strings, where each string is the predicted class for the - corresponding detection + "labels" (Tensor): A tensor of ids, where each id is the predicted class id for the corresponding detection """ + + # Filter by max number of detections proposal_num = len(boxes) if max_num_det is None else max_num_det scores_per_image, topk_indices = scores.flatten(0, 1).topk(proposal_num, sorted=False) - classes_per_image = predicted_classes[topk_indices] - box_pred_per_image = boxes.view(-1, 1, 4).repeat(1, num_classes, 1).view(-1, 4) - box_pred_per_image = box_pred_per_image[topk_indices] - - # Score filtering - box_pred_per_image = center_to_corners_format(box_pred_per_image) - box_pred_per_image = box_pred_per_image * torch.tensor(image_size[::-1]).repeat(2).to(box_pred_per_image.device) - filter_mask = scores_per_image > score_threshold # R x K + labels_per_image = labels[topk_indices] + boxes_per_image = boxes.view(-1, 1, 4).repeat(1, scores.shape[1], 1).view(-1, 4) + boxes_per_image = boxes_per_image[topk_indices] + + # Convert and scale boxes to original image size + boxes_per_image = center_to_corners_format(boxes_per_image) + boxes_per_image = boxes_per_image * torch.tensor(image_size[::-1]).repeat(2).to(boxes_per_image.device) + + # Filtering by confidence score + filter_mask = scores_per_image > threshold # R x K score_keep = filter_mask.nonzero(as_tuple=False).view(-1) - box_pred_per_image = box_pred_per_image[score_keep] + boxes_per_image = boxes_per_image[score_keep] scores_per_image = scores_per_image[score_keep] - classes_per_image = classes_per_image[score_keep] + labels_per_image = labels_per_image[score_keep] - filter_classes_mask = classes_per_image < len(classes) + # Ensure we did not overflow to non existing classes + filter_classes_mask = labels_per_image < image_num_classes classes_keep = filter_classes_mask.nonzero(as_tuple=False).view(-1) - box_pred_per_image = box_pred_per_image[classes_keep] + boxes_per_image = boxes_per_image[classes_keep] scores_per_image = scores_per_image[classes_keep] - classes_per_image = classes_per_image[classes_keep] + 
+    labels_per_image = labels_per_image[classes_keep]

     # NMS
-    keep = batched_nms(box_pred_per_image, scores_per_image, classes_per_image, nms_threshold)
-    box_pred_per_image = box_pred_per_image[keep]
+    keep = batched_nms(boxes_per_image, scores_per_image, labels_per_image, nms_threshold)
+    boxes_per_image = boxes_per_image[keep]
     scores_per_image = scores_per_image[keep]
-    classes_per_image = classes_per_image[keep]
-    classes_per_image = [classes[i] for i in classes_per_image]
+    labels_per_image = labels_per_image[keep]

-    # create an instance
-    result = {}
-    result["boxes"] = clip_boxes(box_pred_per_image, image_size)
-    result["scores"] = scores_per_image
-    result["classes"] = classes_per_image
+    # Clip to image size
+    boxes_per_image = clip_boxes(boxes_per_image, image_size)

-    return result
+    return boxes_per_image, scores_per_image, labels_per_image


 class OmDetTurboProcessor(ProcessorMixin):
@@ -274,11 +304,26 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)

+    def _get_default_image_size(self) -> Tuple[int, int]:
+        height = (
+            self.image_processor.size["height"]
+            if "height" in self.image_processor.size
+            else self.image_processor.size["shortest_edge"]
+        )
+        width = (
+            self.image_processor.size["width"]
+            if "width" in self.image_processor.size
+            else self.image_processor.size["longest_edge"]
+        )
+        return height, width
+
+    @deprecate_kwarg("score_threshold", new_name="threshold", version="4.51.0")
+    @deprecate_kwarg("classes", new_name="text_labels", version="4.51.0")
     def post_process_grounded_object_detection(
         self,
-        outputs,
-        classes: Union[List[str], List[List[str]]],
-        score_threshold: float = 0.3,
+        outputs: "OmDetTurboObjectDetectionOutput",
+        text_labels: Optional[Union[List[str], List[List[str]]]] = None,
+        threshold: float = 0.3,
         nms_threshold: float = 0.5,
         target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
         max_num_det: Optional[int] = None,
@@ -290,67 +335,77 @@
         Args:
             outputs ([`OmDetTurboObjectDetectionOutput`]):
                 Raw outputs of the model.
-            classes (Union[List[str], List[List[str]]]): The input classes names.
-            score_threshold (float, defaults to 0.3): Only return detections with a confidence score exceeding this
-                threshold.
-            nms_threshold (float, defaults to 0.5): The threshold to use for box non-maximum suppression. Value in [0, 1].
-            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*, defaults to None):
+            text_labels (Union[List[str], List[List[str]]], *optional*):
+                The input class names. If not provided, `text_labels` will be set to `None` in the returned results.
+            threshold (float, defaults to 0.3):
+                Only return detections with a confidence score exceeding this threshold.
+            nms_threshold (float, defaults to 0.5):
+                The threshold to use for box non-maximum suppression. Value in [0, 1].
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                 Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                 `(height, width)` of each image in the batch. If unset, predictions will not be resized.
-            max_num_det (int, *optional*, defaults to None): The maximum number of detections to return.
+            max_num_det (`int`, *optional*):
+                The maximum number of detections to return.
         Returns:
             `List[Dict]`: A list of dictionaries, each dictionary containing the scores, classes and boxes for an image
             in the batch as predicted by the model.
""" - if isinstance(classes[0], str): - classes = [classes] - boxes_logits = outputs.decoder_coord_logits - scores_logits = outputs.decoder_class_logits + batch_size = len(outputs.decoder_coord_logits) - # Inputs consistency check + # Inputs consistency check for target sizes if target_sizes is None: - height = ( - self.image_processor.size["height"] - if "height" in self.image_processor.size - else self.image_processor.size["shortest_edge"] - ) - width = ( - self.image_processor.size["width"] - if "width" in self.image_processor.size - else self.image_processor.size["longest_edge"] - ) - target_sizes = ((height, width),) * len(boxes_logits) - elif len(target_sizes[0]) != 2: + height, width = self._get_default_image_size() + target_sizes = [(height, width)] * batch_size + + if any(len(image_size) != 2 for image_size in target_sizes): raise ValueError( "Each element of target_sizes must contain the size (height, width) of each image of the batch" ) - if len(target_sizes) != len(boxes_logits): + + if len(target_sizes) != batch_size: raise ValueError("Make sure that you pass in as many target sizes as output sequences") - if len(classes) != len(boxes_logits): + + # Inputs consistency check for text labels + if text_labels is not None and isinstance(text_labels[0], str): + text_labels = [text_labels] + + if text_labels is not None and len(text_labels) != batch_size: raise ValueError("Make sure that you pass in as many classes group as output sequences") # Convert target_sizes to list for easier handling if isinstance(target_sizes, torch.Tensor): target_sizes = target_sizes.tolist() - scores, predicted_classes = compute_score(scores_logits) - num_classes = scores_logits.shape[2] + batch_boxes = outputs.decoder_coord_logits + batch_logits = outputs.decoder_class_logits + batch_num_classes = outputs.classes_structure + + batch_scores, batch_labels = compute_score(batch_logits) + results = [] - for scores_img, box_per_img, image_size, class_names in zip(scores, boxes_logits, target_sizes, classes): - results.append( - _post_process_boxes_for_image( - box_per_img, - scores_img, - predicted_classes, - class_names, - image_size, - num_classes, - score_threshold=score_threshold, - nms_threshold=nms_threshold, - max_num_det=max_num_det, - ) + for boxes, scores, image_size, image_num_classes in zip( + batch_boxes, batch_scores, target_sizes, batch_num_classes + ): + boxes, scores, labels = _post_process_boxes_for_image( + boxes=boxes, + scores=scores, + labels=batch_labels, + image_num_classes=image_num_classes, + image_size=image_size, + threshold=threshold, + nms_threshold=nms_threshold, + max_num_det=max_num_det, ) + result = DictWithDeprecationWarning( + {"boxes": boxes, "scores": scores, "labels": labels, "text_labels": None} + ) + results.append(result) + + # Add text labels + if text_labels is not None: + for result, image_text_labels in zip(results, text_labels): + result["text_labels"] = [image_text_labels[idx] for idx in result["labels"]] return results diff --git a/tests/models/omdet_turbo/test_modeling_omdet_turbo.py b/tests/models/omdet_turbo/test_modeling_omdet_turbo.py index 75c0e6f1c78d58..d057b35006d3ee 100644 --- a/tests/models/omdet_turbo/test_modeling_omdet_turbo.py +++ b/tests/models/omdet_turbo/test_modeling_omdet_turbo.py @@ -646,9 +646,9 @@ def prepare_img(): def prepare_text(): - classes = ["cat", "remote"] - task = "Detect {}.".format(", ".join(classes)) - return classes, task + text_labels = ["cat", "remote"] + task = "Detect {}.".format(", ".join(text_labels)) + return 
text_labels, task def prepare_img_batched(): @@ -660,14 +660,14 @@ def prepare_img_batched(): def prepare_text_batched(): - classes1 = ["cat", "remote"] - classes2 = ["boat"] - classes3 = ["statue", "trees", "torch"] + text_labels1 = ["cat", "remote"] + text_labels2 = ["boat"] + text_labels3 = ["statue", "trees", "torch"] - task1 = "Detect {}.".format(", ".join(classes1)) + task1 = "Detect {}.".format(", ".join(text_labels1)) task2 = "Detect all the boat in the image." task3 = "Focus on the foreground, detect statue, torch and trees." - return [classes1, classes2, classes3], [task1, task2, task3] + return [text_labels1, text_labels2, text_labels3], [task1, task2, task3] @require_timm @@ -683,8 +683,8 @@ def test_inference_object_detection_head(self): processor = self.default_processor image = prepare_img() - classes, task = prepare_text() - encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to(torch_device) + text_labels, task = prepare_text() + encoding = processor(images=image, text=text_labels, task=task, return_tensors="pt").to(torch_device) with torch.no_grad(): outputs = model(**encoding) @@ -706,7 +706,7 @@ def test_inference_object_detection_head(self): # verify grounded postprocessing results = processor.post_process_grounded_object_detection( - outputs, classes=[classes], target_sizes=[image.size[::-1]] + outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]] )[0] expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device) expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device) @@ -715,8 +715,8 @@ def test_inference_object_detection_head(self): self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2)) self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2)) - expected_classes = ["remote", "cat", "remote", "cat"] - self.assertListEqual(results["classes"], expected_classes) + expected_text_labels = ["remote", "cat", "remote", "cat"] + self.assertListEqual(results["text_labels"], expected_text_labels) def test_inference_object_detection_head_fp16(self): model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to( @@ -725,8 +725,8 @@ def test_inference_object_detection_head_fp16(self): processor = self.default_processor image = prepare_img() - classes, task = prepare_text() - encoding = processor(images=image, text=classes, task=task, return_tensors="pt").to( + text_labels, task = prepare_text() + encoding = processor(images=image, text=text_labels, task=task, return_tensors="pt").to( torch_device, dtype=torch.float16 ) @@ -750,7 +750,7 @@ def test_inference_object_detection_head_fp16(self): # verify grounded postprocessing results = processor.post_process_grounded_object_detection( - outputs, classes=[classes], target_sizes=[image.size[::-1]] + outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]] )[0] expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device, dtype=torch.float16) expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to( @@ -761,16 +761,16 @@ def test_inference_object_detection_head_fp16(self): self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2)) self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-1)) - expected_classes = ["remote", "cat", "remote", "cat"] - self.assertListEqual(results["classes"], expected_classes) + expected_text_labels = ["remote", "cat", "remote", 
"cat"] + self.assertListEqual(results["text_labels"], expected_text_labels) def test_inference_object_detection_head_no_task(self): model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf").to(torch_device) processor = self.default_processor image = prepare_img() - classes, _ = prepare_text() - encoding = processor(images=image, text=classes, return_tensors="pt").to(torch_device) + text_labels, _ = prepare_text() + encoding = processor(images=image, text=text_labels, return_tensors="pt").to(torch_device) with torch.no_grad(): outputs = model(**encoding) @@ -792,7 +792,7 @@ def test_inference_object_detection_head_no_task(self): # verify grounded postprocessing results = processor.post_process_grounded_object_detection( - outputs, classes=[classes], target_sizes=[image.size[::-1]] + outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]] )[0] expected_scores = torch.tensor([0.7675, 0.7196, 0.5634, 0.5524]).to(torch_device) expected_slice_boxes = torch.tensor([39.8870, 70.3522, 176.7424, 118.0354]).to(torch_device) @@ -801,8 +801,8 @@ def test_inference_object_detection_head_no_task(self): self.assertTrue(torch.allclose(results["scores"], expected_scores, atol=1e-2)) self.assertTrue(torch.allclose(results["boxes"][0, :], expected_slice_boxes, atol=1e-2)) - expected_classes = ["remote", "cat", "remote", "cat"] - self.assertListEqual(results["classes"], expected_classes) + expected_text_labels = ["remote", "cat", "remote", "cat"] + self.assertListEqual(results["text_labels"], expected_text_labels) def test_inference_object_detection_head_batched(self): torch_device = "cpu" @@ -810,10 +810,10 @@ def test_inference_object_detection_head_batched(self): processor = self.default_processor images_batched = prepare_img_batched() - classes_batched, tasks_batched = prepare_text_batched() - encoding = processor(images=images_batched, text=classes_batched, task=tasks_batched, return_tensors="pt").to( - torch_device - ) + text_labels_batched, tasks_batched = prepare_text_batched() + encoding = processor( + images=images_batched, text=text_labels_batched, task=tasks_batched, return_tensors="pt" + ).to(torch_device) with torch.no_grad(): outputs = model(**encoding) @@ -837,7 +837,7 @@ def test_inference_object_detection_head_batched(self): # verify grounded postprocessing results = processor.post_process_grounded_object_detection( outputs, - classes=classes_batched, + text_labels=text_labels_batched, target_sizes=[image.size[::-1] for image in images_batched], score_threshold=0.2, ) @@ -858,19 +858,19 @@ def test_inference_object_detection_head_batched(self): torch.allclose(torch.stack([result["boxes"][0, :] for result in results]), expected_slice_boxes, atol=1e-2) ) - expected_classes = [ + expected_text_labels = [ ["remote", "cat", "remote", "cat"], ["boat", "boat", "boat", "boat"], ["statue", "trees", "trees", "torch", "statue", "statue"], ] - self.assertListEqual([result["classes"] for result in results], expected_classes) + self.assertListEqual([result["text_labels"] for result in results], expected_text_labels) @require_torch_accelerator def test_inference_object_detection_head_equivalence_cpu_gpu(self): processor = self.default_processor image = prepare_img() - classes, task = prepare_text() - encoding = processor(images=image, text=classes, task=task, return_tensors="pt") + text_labels, task = prepare_text() + encoding = processor(images=image, text=text_labels, task=task, return_tensors="pt") # 1. 
run model on CPU model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf") @@ -894,10 +894,10 @@ def test_inference_object_detection_head_equivalence_cpu_gpu(self): # verify grounded postprocessing results_cpu = processor.post_process_grounded_object_detection( - cpu_outputs, classes=[classes], target_sizes=[image.size[::-1]] + cpu_outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]] )[0] result_gpu = processor.post_process_grounded_object_detection( - gpu_outputs, classes=[classes], target_sizes=[image.size[::-1]] + gpu_outputs, text_labels=[text_labels], target_sizes=[image.size[::-1]] )[0] self.assertTrue(torch.allclose(results_cpu["scores"], result_gpu["scores"].cpu(), atol=1e-2)) diff --git a/tests/models/omdet_turbo/test_processor_omdet_turbo.py b/tests/models/omdet_turbo/test_processor_omdet_turbo.py index 52e1926e50b22f..341c7a1d9a1365 100644 --- a/tests/models/omdet_turbo/test_processor_omdet_turbo.py +++ b/tests/models/omdet_turbo/test_processor_omdet_turbo.py @@ -76,10 +76,13 @@ def tearDown(self): shutil.rmtree(self.tmpdirname) def get_fake_omdet_turbo_output(self): + classes = self.get_fake_omdet_turbo_classes() + classes_structure = torch.tensor([len(sublist) for sublist in classes]) torch.manual_seed(42) return OmDetTurboObjectDetectionOutput( decoder_coord_logits=torch.rand(self.batch_size, self.num_queries, 4), decoder_class_logits=torch.rand(self.batch_size, self.num_queries, self.embed_dim), + classes_structure=classes_structure, ) def get_fake_omdet_turbo_classes(self): @@ -99,7 +102,7 @@ def test_post_process_grounded_object_detection(self): ) self.assertEqual(len(post_processed), self.batch_size) - self.assertEqual(list(post_processed[0].keys()), ["boxes", "scores", "classes"]) + self.assertEqual(list(post_processed[0].keys()), ["boxes", "scores", "labels", "text_labels"]) self.assertEqual(post_processed[0]["boxes"].shape, (self.num_queries, 4)) self.assertEqual(post_processed[0]["scores"].shape, (self.num_queries,)) expected_scores = torch.tensor([0.7310, 0.6579, 0.6513, 0.6444, 0.6252])