diff --git a/.github/workflows/test_accuracy.yml b/.github/workflows/test_accuracy.yml
index f16ea295..862b6ac8 100644
--- a/.github/workflows/test_accuracy.yml
+++ b/.github/workflows/test_accuracy.yml
@@ -29,6 +29,7 @@ jobs:
         run: |
           source venv/bin/activate
           pytest --data=./data tests/python/accuracy/test_accuracy.py
+          DATA=data pytest --data=./data tests/python/accuracy/test_YOLOv8.py
       - name: Install CPP ependencies
         run: |
           sudo bash model_api/cpp/install_dependencies.sh
@@ -40,3 +41,4 @@ jobs:
       - name: Run CPP Test
         run: |
           build/test_accuracy -d data -p tests/python/accuracy/public_scope.json
+          DATA=data build/test_YOLOv8
diff --git a/docs/model-configuration.md b/docs/model-configuration.md
index 6eee3260..89d16d2c 100644
--- a/docs/model-configuration.md
+++ b/docs/model-configuration.md
@@ -49,6 +49,9 @@ The list features only model wrappers which intoduce new configuration values in
 ###### `YoloV4`
 1. `anchors`: List - list of custom anchor values
 1. `masks`: List - list of mask, applied to anchors for each output layer
+###### `YOLOv5`, `YOLOv8`
+1. `agnostic_nms`: bool - if True, the model is agnostic to the number of classes, and all classes are considered as one
+1. `iou_threshold`: float - threshold for non-maximum suppression (NMS) intersection over union (IOU) filtering
 ###### `YOLOX`
 1. `iou_threshold`: float - threshold for non-maximum suppression (NMS) intersection over union (IOU) filtering
 #### `HpeAssociativeEmbedding`
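For illustration, the two new configuration values could be passed from user code like this. This is a minimal sketch: "yolov8n.xml" and "sample.jpg" are placeholder paths, and it assumes the `configuration` keyword of the Python `DetectionModel.create_model` factory, which the wrappers read these values from.

    import cv2
    from openvino.model_api.models import DetectionModel

    # Placeholder model path; any YOLOv5/YOLOv8 OpenVINO export works here.
    model = DetectionModel.create_model(
        "yolov8n.xml",
        configuration={"iou_threshold": 0.5, "agnostic_nms": True},
    )
    detections = model(cv2.imread("sample.jpg"))  # placeholder image path
    print(detections)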
diff --git a/model_api/cpp/models/include/models/detection_model_yolo.h b/model_api/cpp/models/include/models/detection_model_yolo.h
index 0993bb7a..ed821326 100644
--- a/model_api/cpp/models/include/models/detection_model_yolo.h
+++ b/model_api/cpp/models/include/models/detection_model_yolo.h
@@ -83,3 +83,24 @@ class ModelYolo : public DetectionModelExt {
     std::vector<int64_t> presetMasks;
     ov::Layout yoloRegionLayout = "NCHW";
 };
+
+class YOLOv5 : public DetectionModelExt {
+    // Reimplementation of ultralytics.YOLO
+    void prepareInputsOutputs(std::shared_ptr<ov::Model>& model) override;
+    void updateModelInfo() override;
+    void init_from_config(const ov::AnyMap& top_priority, const ov::AnyMap& mid_priority);
+    bool agnostic_nms = false;
+public:
+    YOLOv5(std::shared_ptr<ov::Model>& model, const ov::AnyMap& configuration);
+    YOLOv5(std::shared_ptr<InferenceAdapter>& adapter);
+    std::unique_ptr<ResultBase> postprocess(InferenceResult& infResult) override;
+    static std::string ModelType;
+};
+
+class YOLOv8 : public YOLOv5 {
+public:
+    // YOLOv5 and YOLOv8 are identical in terms of inference
+    YOLOv8(std::shared_ptr<ov::Model>& model, const ov::AnyMap& configuration) : YOLOv5{model, configuration} {}
+    YOLOv8(std::shared_ptr<InferenceAdapter>& adapter) : YOLOv5{adapter} {}
+    static std::string ModelType;
+};
diff --git a/model_api/cpp/models/src/detection_model.cpp b/model_api/cpp/models/src/detection_model.cpp
index 31ae0788..8ed9a39a 100644
--- a/model_api/cpp/models/src/detection_model.cpp
+++ b/model_api/cpp/models/src/detection_model.cpp
@@ -91,6 +91,10 @@ std::unique_ptr<DetectionModel> DetectionModel::create_model(const std::string&
         detectionModel = std::unique_ptr<DetectionModel>(new ModelYoloX(model, configuration));
     } else if (model_type == ModelCenterNet::ModelType) {
         detectionModel = std::unique_ptr<DetectionModel>(new ModelCenterNet(model, configuration));
+    } else if (model_type == YOLOv5::ModelType) {
+        detectionModel = std::unique_ptr<DetectionModel>(new YOLOv5(model, configuration));
+    } else if (model_type == YOLOv8::ModelType) {
+        detectionModel = std::unique_ptr<DetectionModel>(new YOLOv8(model, configuration));
     } else {
         throw std::runtime_error("Incorrect or unsupported model_type is provided in the model_info section: " + model_type);
     }
diff --git a/model_api/cpp/models/src/detection_model_faceboxes.cpp b/model_api/cpp/models/src/detection_model_faceboxes.cpp
index d6f9bdea..71daf165 100644
--- a/model_api/cpp/models/src/detection_model_faceboxes.cpp
+++ b/model_api/cpp/models/src/detection_model_faceboxes.cpp
@@ -243,7 +243,7 @@ std::unique_ptr<ResultBase> ModelFaceBoxes::postprocess(InferenceResult& infResult) {
     std::vector<Anchor> boxes = filterBoxes(boxesTensor, anchors, scores.first, variance);
 
     // Apply Non-maximum Suppression
-    const std::vector<int> keep = nms(boxes, scores.second, iou_threshold);
+    const std::vector<size_t>& keep = nms(boxes, scores.second, iou_threshold);
 
     // Create detection result objects
     DetectionResult* result = new DetectionResult(infResult.frameId, infResult.metaData);
diff --git a/model_api/cpp/models/src/detection_model_ssd.cpp b/model_api/cpp/models/src/detection_model_ssd.cpp
index dd646302..c30fef56 100644
--- a/model_api/cpp/models/src/detection_model_ssd.cpp
+++ b/model_api/cpp/models/src/detection_model_ssd.cpp
@@ -162,12 +162,13 @@ std::unique_ptr<ResultBase> ModelSSD::postprocessSingleOutput(InferenceResult& infResult) {
                 0.f,
                 floatInputImgHeight);
             desc.width = clamp(
-                round((detections[i * numAndStep.objectSize + 5] * netInputWidth - padLeft) * invertedScaleX - desc.x),
+                round((detections[i * numAndStep.objectSize + 5] * netInputWidth - padLeft) * invertedScaleX),
                 0.f,
-                floatInputImgWidth);
+                floatInputImgWidth) - desc.x;
             desc.height = clamp(
-                round((detections[i * numAndStep.objectSize + 6] * netInputHeight - padTop) * invertedScaleY - desc.y),
-                0.f, floatInputImgHeight);
+                round((detections[i * numAndStep.objectSize + 6] * netInputHeight - padTop) * invertedScaleY),
+                0.f,
+                floatInputImgHeight) - desc.y;
             result->objects.push_back(desc);
         }
     }
@@ -223,12 +224,13 @@ std::unique_ptr<ResultBase> ModelSSD::postprocessMultipleOutputs(InferenceResult& infResult) {
                 0.f,
                 floatInputImgHeight);
             desc.width = clamp(
-                round((boxes[i * numAndStep.objectSize + 2] * widthScale - padLeft) * invertedScaleX - desc.x),
+                round((boxes[i * numAndStep.objectSize + 2] * widthScale - padLeft) * invertedScaleX),
                 0.f,
-                floatInputImgWidth);
+                floatInputImgWidth) - desc.x;
             desc.height = clamp(
-                round((boxes[i * numAndStep.objectSize + 3] * heightScale - padTop) * invertedScaleY - desc.y),
-                0.f, floatInputImgHeight);
+                round((boxes[i * numAndStep.objectSize + 3] * heightScale - padTop) * invertedScaleY),
+                0.f,
+                floatInputImgHeight) - desc.y;
             result->objects.push_back(desc);
         }
     }
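The `desc.width`/`desc.height` rewrite in both SSD hunks changes the evaluation order: the right/bottom corner is clamped to the image first, and the already-clamped left/top coordinate is subtracted afterwards, so a box can no longer extend past the image border. A small numeric sketch of the difference (illustrative values only):

    def clamp(value, low, high):
        # Same semantics as the C++ clamp helper used above.
        return max(low, min(value, high))

    # Assume a 100 px wide image and a rescaled box spanning x = 90..120.
    x1, x2, img_w = 90.0, 120.0, 100.0
    old_width = clamp(x2 - x1, 0.0, img_w)   # 30.0: box would reach x = 120, outside the image
    new_width = clamp(x2, 0.0, img_w) - x1   # 10.0: box now ends at the border, x = 100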
diff --git a/model_api/cpp/models/src/detection_model_yolo.cpp b/model_api/cpp/models/src/detection_model_yolo.cpp
index 8b1d928d..e2552dda 100644
--- a/model_api/cpp/models/src/detection_model_yolo.cpp
+++ b/model_api/cpp/models/src/detection_model_yolo.cpp
@@ -27,6 +27,7 @@
 #include <openvino/openvino.hpp>
 
 #include <utils/common.hpp>
+#include <utils/nms.hpp>
 #include <utils/slog.hpp>
 
 #include "models/internal_model_data.h"
@@ -504,3 +505,171 @@ ModelYolo::Region::Region(size_t classes,
         num = anchors.size() / 2;
     }
 }
+
+std::string YOLOv5::ModelType = "YOLOv5";
+
+void YOLOv5::prepareInputsOutputs(std::shared_ptr<ov::Model>& model) {
+    const ov::Output<ov::Node>& input = model->input();
+    const ov::Shape& in_shape = input.get_partial_shape().get_max_shape();
+    if (in_shape.size() != 4) {
+        throw std::runtime_error("YOLO: the rank of the input must be 4");
+    }
+    inputNames.push_back(input.get_any_name());
+    const ov::Layout& inputLayout = getInputLayout(input);
+    if (!embedded_processing) {
+        model = ImageModel::embedProcessing(model,
+                                            inputNames[0],
+                                            inputLayout,
+                                            resizeMode,
+                                            interpolationMode,
+                                            ov::Shape{
+                                                in_shape[ov::layout::width_idx(inputLayout)],
+                                                in_shape[ov::layout::height_idx(inputLayout)]
+                                            },
+                                            pad_value,
+                                            reverse_input_channels,
+                                            {},
+                                            scale_values);
+
+        netInputWidth = in_shape[ov::layout::width_idx(inputLayout)];
+        netInputHeight = in_shape[ov::layout::height_idx(inputLayout)];
+
+        embedded_processing = true;
+    }
+
+    const ov::Output<const ov::Node>& output = model->output();
+    if (ov::element::Type_t::f32 != output.get_element_type()) {
+        throw std::runtime_error("YOLO: the output must be of precision f32");
+    }
+    const ov::Shape& out_shape = output.get_partial_shape().get_max_shape();
+    if (3 != out_shape.size()) {
+        throw std::runtime_error("YOLO: the output must be of rank 3");
+    }
+    if (!labels.empty() && labels.size() + 4 != out_shape[1]) {
+        throw std::runtime_error("YOLO: number of labels must be smaller than out_shape[1] by 4");
+    }
+}
+
+void YOLOv5::updateModelInfo() {
+    DetectionModelExt::updateModelInfo();
+    model->set_rt_info(YOLOv5::ModelType, "model_info", "model_type");
+    model->set_rt_info(agnostic_nms, "model_info", "agnostic_nms");
+    model->set_rt_info(iou_threshold, "model_info", "iou_threshold");
+}
+
+void YOLOv5::init_from_config(const ov::AnyMap& top_priority, const ov::AnyMap& mid_priority) {
+    pad_value = get_from_any_maps("pad_value", top_priority, mid_priority, 114);
+    if (top_priority.find("resize_type") == top_priority.end() && mid_priority.find("resize_type") == mid_priority.end()) {
+        interpolationMode = cv::INTER_LINEAR;
+        resizeMode = RESIZE_KEEP_ASPECT_LETTERBOX;
+    }
+    reverse_input_channels = get_from_any_maps("reverse_input_channels", top_priority, mid_priority, true);
+    scale_values = get_from_any_maps("scale_values", top_priority, mid_priority, std::vector<float>{255.0f});
+    confidence_threshold = get_from_any_maps("confidence_threshold", top_priority, mid_priority, 0.25f);
+    agnostic_nms = get_from_any_maps("agnostic_nms", top_priority, mid_priority, agnostic_nms);
+    iou_threshold = get_from_any_maps("iou_threshold", top_priority, mid_priority, 0.7f);
+}
+
+YOLOv5::YOLOv5(std::shared_ptr<ov::Model>& model, const ov::AnyMap& configuration)
+    : DetectionModelExt(model, configuration) {
+    init_from_config(configuration, model->get_rt_info<ov::AnyMap>("model_info"));
+}
+
+YOLOv5::YOLOv5(std::shared_ptr<InferenceAdapter>& adapter)
+    : DetectionModelExt(adapter) {
+    init_from_config(adapter->getModelConfig(), ov::AnyMap{});
+}
+
+std::unique_ptr<ResultBase> YOLOv5::postprocess(InferenceResult& infResult) {
+    if (1 != infResult.outputsData.size()) {
+        throw std::runtime_error("YOLO: expect 1 output");
+    }
+    const ov::Tensor& detectionsTensor = infResult.getFirstOutputTensor();
+    const ov::Shape& out_shape = detectionsTensor.get_shape();
+    if (3 != out_shape.size()) {
+        throw std::runtime_error("YOLO: the output must be of rank 3");
+    }
+    if (1 != out_shape[0]) {
+        throw std::runtime_error("YOLO: the first dim of the output must be 1");
+    }
+    size_t num_proposals = out_shape[2];
+    std::vector<Anchor> boxes;
+    std::vector<float> confidences;
+    std::vector<size_t> labelIDs;
+    const float* const detections = detectionsTensor.data<float>();
+    for (size_t i = 0; i < num_proposals; ++i) {
+        float confidence = 0.0f;
+        size_t max_id = 0;
+        constexpr size_t LABELS_START = 4;
+        for (size_t j = LABELS_START; j < out_shape[1]; ++j) {
+            if (detections[j * num_proposals + i] > confidence) {
+                confidence = detections[j * num_proposals + i];
+                max_id = j;
+            }
+        }
+        if (confidence > confidence_threshold) {
+            boxes.push_back(Anchor{
+                detections[0 * num_proposals + i] - detections[2 * num_proposals + i] / 2.0f,
+                detections[1 * num_proposals + i] - detections[3 * num_proposals + i] / 2.0f,
+                detections[0 * num_proposals + i] + detections[2 * num_proposals + i] / 2.0f,
+                detections[1 * num_proposals + i] + detections[3 * num_proposals + i] / 2.0f,
+            });
+            confidences.push_back(confidence);
+            labelIDs.push_back(max_id - LABELS_START);
+        }
+    }
+    constexpr bool includeBoundaries = false;
+    constexpr size_t keep_top_k = 30000;
+    std::vector<size_t> keep;
+    if (agnostic_nms) {
+        keep = nms(boxes, confidences, iou_threshold, includeBoundaries, keep_top_k);
+    } else {
+        std::vector<AnchorLabeled> boxes_with_class;
+        boxes_with_class.reserve(boxes.size());
+        for (size_t i = 0; i < boxes.size(); ++i) {
+            boxes_with_class.emplace_back(boxes[i], int(labelIDs[i]));
+        }
+        keep = multiclass_nms(boxes_with_class, confidences, iou_threshold, includeBoundaries, keep_top_k);
+    }
+    DetectionResult* result = new DetectionResult(infResult.frameId, infResult.metaData);
+    auto base = std::unique_ptr<ResultBase>(result);
+    const auto& internalData = infResult.internalModelData->asRef<InternalImageModelData>();
+    float floatInputImgWidth = float(internalData.inputImgWidth),
+          floatInputImgHeight = float(internalData.inputImgHeight);
+    float invertedScaleX = floatInputImgWidth / netInputWidth,
+          invertedScaleY = floatInputImgHeight / netInputHeight;
+    int padLeft = 0, padTop = 0;
+    if (RESIZE_KEEP_ASPECT == resizeMode || RESIZE_KEEP_ASPECT_LETTERBOX == resizeMode) {
+        invertedScaleX = invertedScaleY = std::max(invertedScaleX, invertedScaleY);
+        if (RESIZE_KEEP_ASPECT_LETTERBOX == resizeMode) {
+            padLeft = (netInputWidth - int(std::round(floatInputImgWidth / invertedScaleX))) / 2;
+            padTop = (netInputHeight - int(std::round(floatInputImgHeight / invertedScaleY))) / 2;
+        }
+    }
+    for (size_t idx : keep) {
+        DetectedObject desc;
+        desc.x = clamp(
+            round((boxes[idx].left - padLeft) * invertedScaleX),
+            0.f,
+            floatInputImgWidth);
+        desc.y = clamp(
+            round((boxes[idx].top - padTop) * invertedScaleY),
+            0.f,
+            floatInputImgHeight);
+        desc.width = clamp(
+            round((boxes[idx].right - padLeft) * invertedScaleX),
+            0.f,
+            floatInputImgWidth) - desc.x;
+        desc.height = clamp(
+            round((boxes[idx].bottom - padTop) * invertedScaleY),
+            0.f,
+            floatInputImgHeight) - desc.y;
+        desc.confidence = confidences[idx];
+        desc.labelID = static_cast<size_t>(labelIDs[idx]);
+        desc.label = getLabelName(desc.labelID);
+        result->objects.push_back(desc);
+    }
+    return base;
+}
+
+std::string YOLOv8::ModelType = "YOLOv8";
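The tail of `YOLOv5::postprocess` maps boxes from letterboxed network coordinates back to the original image: both axes share one scale, and half of the leftover space becomes padding. A Python sketch of the same arithmetic for one axis (a hypothetical helper, not part of the PR):

    def to_original_x(x, net_w, net_h, img_w, img_h):
        # invertedScaleX == invertedScaleY for the keep-aspect letterbox resize above.
        scale = max(img_w / net_w, img_h / net_h)
        pad_left = (net_w - round(img_w / scale)) // 2
        # Shift by the padding, rescale, then clamp into the image.
        # Note: Python's round() ties-to-even differs from C++ round() on .5;
        # close enough for a sketch.
        return min(max(round((x - pad_left) * scale), 0), img_w)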
diff --git a/model_api/cpp/models/src/detection_model_yolox.cpp b/model_api/cpp/models/src/detection_model_yolox.cpp
index f46d4203..df47955e 100644
--- a/model_api/cpp/models/src/detection_model_yolox.cpp
+++ b/model_api/cpp/models/src/detection_model_yolox.cpp
@@ -190,8 +190,8 @@ std::unique_ptr<ResultBase> ModelYoloX::postprocess(InferenceResult& infResult) {
     }
 
     // NMS for valid boxes
-    std::vector<int> keep = nms(validBoxes, scores, iou_threshold, true);
-    for (auto& index: keep) {
+    const std::vector<size_t>& keep = nms(validBoxes, scores, iou_threshold, true);
+    for (size_t index: keep) {
         // Create new detected box
         DetectedObject obj;
         obj.x = clamp(validBoxes[index].left, 0.f, static_cast<float>(scale.inputImgWidth));
diff --git a/model_api/cpp/utils/include/utils/nms.hpp b/model_api/cpp/utils/include/utils/nms.hpp
index e6a23b51..ffa0cade 100644
--- a/model_api/cpp/utils/include/utils/nms.hpp
+++ b/model_api/cpp/utils/include/utils/nms.hpp
@@ -50,13 +50,13 @@ struct AnchorLabeled : public Anchor {
     AnchorLabeled() = default;
     AnchorLabeled(float _left, float _top, float _right, float _bottom, int _labelID) :
         Anchor(_left, _top, _right, _bottom), labelID(_labelID) {}
+    AnchorLabeled(const Anchor& coords, int labelID) : Anchor{coords}, labelID{labelID} {}
 };
 
 template <typename Anchor>
-std::vector<int> nms(const std::vector<Anchor>& boxes, const std::vector<float>& scores,
-                     const float thresh, bool includeBoundaries=false, size_t maxNum=0) {
-    if (maxNum == 0) {
-        maxNum = boxes.size();
+std::vector<size_t> nms(const std::vector<Anchor>& boxes, const std::vector<float>& scores, const float thresh, bool includeBoundaries=false, size_t keep_top_k=0) {
+    if (keep_top_k == 0) {
+        keep_top_k = boxes.size();
     }
     std::vector<float> areas(boxes.size());
     for (size_t i = 0; i < boxes.size(); ++i) {
@@ -67,25 +67,24 @@ std::vector<int> nms(const std::vector<Anchor>& boxes, const std::vector<float>&
     std::sort(order.begin(), order.end(), [&scores](int o1, int o2) { return scores[o1] > scores[o2]; });
     size_t ordersNum = 0;
-    for (; ordersNum < order.size() && scores[order[ordersNum]] >= 0 && ordersNum < maxNum; ordersNum++);
+    for (; ordersNum < order.size() && scores[order[ordersNum]] >= 0 && ordersNum < keep_top_k; ordersNum++);
 
-    std::vector<int> keep;
+    std::vector<size_t> keep;
     bool shouldContinue = true;
     for (size_t i = 0; shouldContinue && i < ordersNum; ++i) {
-        auto idx1 = order[i];
+        int idx1 = order[i];
         if (idx1 >= 0) {
             keep.push_back(idx1);
             shouldContinue = false;
             for (size_t j = i + 1; j < ordersNum; ++j) {
-                auto idx2 = order[j];
+                int idx2 = order[j];
                 if (idx2 >= 0) {
                     shouldContinue = true;
-                    auto overlappingWidth = std::fminf(boxes[idx1].right, boxes[idx2].right) - std::fmaxf(boxes[idx1].left, boxes[idx2].left);
-                    auto overlappingHeight = std::fminf(boxes[idx1].bottom, boxes[idx2].bottom) - std::fmaxf(boxes[idx1].top, boxes[idx2].top);
-                    auto intersection = overlappingWidth > 0 && overlappingHeight > 0 ? overlappingWidth * overlappingHeight : 0;
-                    auto overlap = intersection / (areas[idx1] + areas[idx2] - intersection);
-
-                    if (overlap >= thresh) {
+                    float overlappingWidth = std::fminf(boxes[idx1].right, boxes[idx2].right) - std::fmaxf(boxes[idx1].left, boxes[idx2].left);
+                    float overlappingHeight = std::fminf(boxes[idx1].bottom, boxes[idx2].bottom) - std::fmaxf(boxes[idx1].top, boxes[idx2].top);
+                    float intersection = overlappingWidth > 0 && overlappingHeight > 0 ? overlappingWidth * overlappingHeight : 0;
+                    float union_area = areas[idx1] + areas[idx2] - intersection;
+                    if (0.0f == union_area || intersection / union_area > thresh) {
                         order[j] = -1;
                     }
                 }
@@ -95,5 +94,5 @@ std::vector<int> nms(const std::vector<Anchor>& boxes, const std::vector<float>&
     return keep;
 }
 
-std::vector<int> multiclass_nms(const std::vector<AnchorLabeled>& boxes, const std::vector<float>& scores,
+std::vector<size_t> multiclass_nms(const std::vector<AnchorLabeled>& boxes, const std::vector<float>& scores,
                                 const float iou_threshold=0.45f, bool includeBoundaries=false, size_t maxNum=200);
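Two behavioral details of the rewritten loop are worth calling out: the comparison is now strict (`> thresh` instead of `>= thresh`), and a pair whose union area is zero is suppressed explicitly rather than producing a NaN overlap. A standalone restatement of the new predicate, with boxes given as (left, top, right, bottom):

    def should_suppress(box1, box2, area1, area2, thresh):
        """True if box2 duplicates box1 under the updated rule."""
        overlapping_width = min(box1[2], box2[2]) - max(box1[0], box2[0])
        overlapping_height = min(box1[3], box2[3]) - max(box1[1], box2[1])
        intersection = (
            overlapping_width * overlapping_height
            if overlapping_width > 0 and overlapping_height > 0
            else 0.0
        )
        union_area = area1 + area2 - intersection
        # Degenerate (zero-area) pairs are dropped instead of dividing by zero.
        return union_area == 0.0 or intersection / union_area > thresh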
diff --git a/model_api/cpp/utils/src/nms.cpp b/model_api/cpp/utils/src/nms.cpp
index 16444906..e77f30f2 100644
--- a/model_api/cpp/utils/src/nms.cpp
+++ b/model_api/cpp/utils/src/nms.cpp
@@ -19,7 +19,7 @@
 
 #include "utils/nms.hpp"
 
-std::vector<int> multiclass_nms(const std::vector<AnchorLabeled>& boxes, const std::vector<float>& scores,
+std::vector<size_t> multiclass_nms(const std::vector<AnchorLabeled>& boxes, const std::vector<float>& scores,
                                 const float iou_threshold, bool includeBoundaries, size_t maxNum) {
     std::vector<Anchor> boxes_copy;
     boxes_copy.reserve(boxes.size());
diff --git a/model_api/python/openvino/model_api/models/__init__.py b/model_api/python/openvino/model_api/models/__init__.py
index 247acb52..6d6af150 100644
--- a/model_api/python/openvino/model_api/models/__init__.py
+++ b/model_api/python/openvino/model_api/models/__init__.py
@@ -56,7 +56,7 @@
     add_rotated_rects,
     get_contours,
 )
-from .yolo import YOLO, YOLOF, YOLOX, YoloV3ONNX, YoloV4
+from .yolo import YOLO, YOLOF, YOLOX, YoloV3ONNX, YoloV4, YOLOv5, YOLOv8
 
 classification_models = [
     "resnet-18-pytorch",
@@ -118,6 +118,8 @@
     "YOLO",
     "YoloV3ONNX",
     "YoloV4",
+    "YOLOv5",
+    "YOLOv8",
     "YOLOF",
     "YOLOX",
     "ClassificationResult",
diff --git a/model_api/python/openvino/model_api/models/image_model.py b/model_api/python/openvino/model_api/models/image_model.py
index c8850278..5db8cd5d 100644
--- a/model_api/python/openvino/model_api/models/image_model.py
+++ b/model_api/python/openvino/model_api/models/image_model.py
@@ -91,6 +91,20 @@ def parameters(cls):
         parameters = super().parameters()
         parameters.update(
             {
+                "embedded_processing": BooleanValue(
+                    description="Flag that pre/postprocessing embedded",
+                    default_value=False,
+                ),
+                "mean_values": ListValue(
+                    description="Normalization values, which will be subtracted from image channels for image-input layer during preprocessing",
+                    default_value=[],
+                ),
+                "orig_height": NumericalValue(
+                    int, description="Model input height before embedding processing"
+                ),
+                "orig_width": NumericalValue(
+                    int, description="Model input width before embedding processing"
+                ),
                 "pad_value": NumericalValue(
                     int,
                     min=0,
@@ -98,31 +112,17 @@ def parameters(cls):
                     description="Pad value for resize_image_letterbox embedded into a model",
                     default_value=0,
                 ),
-                "mean_values": ListValue(
-                    default_value=[],
-                    description="Normalization values, which will be subtracted from image channels for image-input layer during preprocessing",
-                ),
-                "scale_values": ListValue(
-                    default_value=[],
-                    description="Normalization values, which will divide the image channels for image-input layer",
-                ),
-                "reverse_input_channels": BooleanValue(
-                    default_value=False, description="Reverse the input channel order"
-                ),
                 "resize_type": StringValue(
                     default_value="standard",
                     choices=tuple(RESIZE_TYPES.keys()),
                     description="Type of input image resizing",
                 ),
-                "embedded_processing": BooleanValue(
-                    default_value=False,
-                    description="Flag that pre/postprocessing embedded",
-                ),
-                "orig_width": NumericalValue(
-                    int, description="Model input width before embedding processing"
+                "reverse_input_channels": BooleanValue(
+                    default_value=False, description="Reverse the input channel order"
                 ),
-                "orig_height": NumericalValue(
-                    int, description="Model input height before embedding processing"
+                "scale_values": ListValue(
+                    default_value=[],
+                    description="Normalization values, which will divide the image channels for image-input layer",
                 ),
             }
         )
diff --git a/model_api/python/openvino/model_api/models/model.py b/model_api/python/openvino/model_api/models/model.py
index 1f704b51..5ba50d09 100644
--- a/model_api/python/openvino/model_api/models/model.py
+++ b/model_api/python/openvino/model_api/models/model.py
@@ -26,8 +26,8 @@
 from openvino.model_api.adapters.ovms_adapter import OVMSAdapter
 
 
-class WrapperError(RuntimeError):
-    """Special class for errors occurred in Model API wrappers"""
+class WrapperError(Exception):
+    """The class for errors that occur in Model API wrappers"""
 
     def __init__(self, wrapper_name, message):
         super().__init__(f"{wrapper_name}: {message}")
diff --git a/model_api/python/openvino/model_api/models/utils.py b/model_api/python/openvino/model_api/models/utils.py
index b79901fb..8e9d84ac 100644
--- a/model_api/python/openvino/model_api/models/utils.py
+++ b/model_api/python/openvino/model_api/models/utils.py
@@ -87,7 +87,9 @@ class DetectionResult(
 ):
     def __str__(self):
         obj_str = "; ".join(str(obj) for obj in self.objects)
-        return f"{obj_str}; [{','.join(str(i) for i in self.saliency_map.shape)}]; [{','.join(str(i) for i in self.feature_vector.shape)}]"
+        if obj_str:
+            obj_str += "; "
+        return f"{obj_str}[{','.join(str(i) for i in self.saliency_map.shape)}]; [{','.join(str(i) for i in self.feature_vector.shape)}]"
 
 
 class SegmentedObject(Detection):
@@ -340,7 +342,7 @@ def crop_resize(image, size):
 }
 
 
-def nms(x1, y1, x2, y2, scores, thresh, include_boundaries=False, keep_top_k=None):
+def nms(x1, y1, x2, y2, scores, thresh, include_boundaries=False, keep_top_k=0):
     b = 1 if include_boundaries else 0
     areas = (x2 - x1 + b) * (y2 - y1 + b)
     order = scores.argsort()[::-1]
@@ -362,12 +364,12 @@ def nms(x1, y1, x2, y2, scores, thresh, include_boundaries=False, keep_top_k=None):
         h = np.maximum(0.0, yy2 - yy1 + b)
         intersection = w * h
 
-        union = areas[i] + areas[order[1:]] - intersection
+        union_areas = areas[i] + areas[order[1:]] - intersection
         overlap = np.divide(
             intersection,
-            union,
+            union_areas,
             out=np.zeros_like(intersection, dtype=float),
-            where=union != 0,
+            where=union_areas != 0,
         )
 
         order = order[np.where(overlap <= thresh)[0] + 1]
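The renamed `union_areas` keeps the existing `where=` guard, which skips the division for zero-union pairs and leaves those overlap entries at the `out=` value of zero. A quick standalone illustration of that numpy pattern:

    import numpy as np

    intersection = np.array([0.0, 2.0])
    union_areas = np.array([0.0, 4.0])
    overlap = np.divide(
        intersection,
        union_areas,
        out=np.zeros_like(intersection, dtype=float),
        where=union_areas != 0,  # the zero-union entry stays 0.0 instead of NaN
    )
    print(overlap)  # [0.  0.5]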
diff --git a/model_api/python/openvino/model_api/models/yolo.py b/model_api/python/openvino/model_api/models/yolo.py
index 83307d1c..ce6c031a 100644
--- a/model_api/python/openvino/model_api/models/yolo.py
+++ b/model_api/python/openvino/model_api/models/yolo.py
@@ -16,8 +16,15 @@
 import numpy as np
 
 from .detection_model import DetectionModel
-from .types import ListValue, NumericalValue
-from .utils import INTERPOLATION_TYPES, Detection, clip_detections, nms, resize_image
+from .types import BooleanValue, ListValue, NumericalValue
+from .utils import (
+    INTERPOLATION_TYPES,
+    Detection,
+    DetectionResult,
+    clip_detections,
+    nms,
+    resize_image,
+)
 
 DetectionBox = namedtuple("DetectionBox", ["x", "y", "w", "h"])
 
@@ -111,6 +118,19 @@ def sigmoid(x):
     return 1.0 / (1.0 + np.exp(-x))
 
 
+def xywh2xyxy(xywh):
+    return np.stack(
+        (
+            xywh[:, 0] - xywh[:, 2] / 2.0,
+            xywh[:, 1] - xywh[:, 3] / 2.0,
+            xywh[:, 0] + xywh[:, 2] / 2.0,
+            xywh[:, 1] + xywh[:, 3] / 2.0,
+        ),
+        1,
+        xywh,
+    )
+
+
 class YOLO(DetectionModel):
     __model__ = "YOLO"
 
@@ -513,7 +533,7 @@ def postprocess(self, outputs, meta):
         valid_predictions = output[output[..., 4] > self.confidence_threshold]
         valid_predictions[:, 5:] *= valid_predictions[:, 4:5]
 
-        boxes = self.xywh2xyxy(valid_predictions[:, :4]) / meta["scale"]
+        boxes = xywh2xyxy(valid_predictions[:, :4]) / meta["scale"]
 
         i, j = (valid_predictions[:, 5:] > self.confidence_threshold).nonzero()
         x_mins, y_mins, x_maxs, y_maxs = boxes[i].T
         scores = valid_predictions[i, j + 5]
@@ -560,15 +580,6 @@ def set_strides_grids(self):
         self.grids = np.concatenate(grids, 1)
         self.expanded_strides = np.concatenate(expanded_strides, 1)
 
-    @staticmethod
-    def xywh2xyxy(x):
-        y = np.copy(x)
-        y[:, 0] = x[:, 0] - x[:, 2] / 2
-        y[:, 1] = x[:, 1] - x[:, 3] / 2
-        y[:, 2] = x[:, 0] + x[:, 2] / 2
-        y[:, 3] = x[:, 1] + x[:, 3] / 2
-        return y
-
 
 class YoloV3ONNX(DetectionModel):
     __model__ = "YOLOv3-ONNX"
@@ -707,3 +718,130 @@ def _parse_outputs(self, outputs):
         ]
 
         return detections
+
+
+class YOLOv5(DetectionModel):
+    """
+    Reimplementation of ultralytics.YOLO
+    """
+
+    __model__ = "YOLOv5"
+
+    def __init__(self, inference_adapter, configuration, preload=False):
+        super().__init__(inference_adapter, configuration, preload)
+        self._check_io_number(1, 1)
+        output = next(iter(self.outputs.values()))
+        if "f32" != output.precision:
+            self.raise_error("the output must be of precision f32")
+        out_shape = output.shape
+        if 3 != len(out_shape):
+            self.raise_error("the output must be of rank 3")
+        if self.labels and len(self.labels) + 4 != out_shape[1]:
+            self.raise_error("number of labels must be smaller than out_shape[1] by 4")
+
+    @classmethod
+    def parameters(cls):
+        parameters = super().parameters()
+        parameters["pad_value"].update_default_value(114)
+        parameters["resize_type"].update_default_value("fit_to_window_letterbox")
+        parameters["reverse_input_channels"].update_default_value(True)
+        parameters["scale_values"].update_default_value([255.0])
+        parameters["confidence_threshold"].update_default_value(0.25)
+        parameters.update(
+            {
+                "agnostic_nms": BooleanValue(
+                    description="If True, the model is agnostic to the number of classes, and all classes are considered as one",
+                    default_value=False,
+                ),
+                "iou_threshold": NumericalValue(
+                    float,
+                    min=0.0,
+                    max=1.0,
+                    default_value=0.7,
+                    description="Threshold for non-maximum suppression (NMS) intersection over union (IOU) filtering",
+                ),
+            }
+        )
+        return parameters
+
+    def postprocess(self, outputs, meta):
+        if 1 != len(outputs):
+            self.raise_error("expect 1 output")
+        prediction = next(iter(outputs.values()))
+        if np.float32 != prediction.dtype:
+            self.raise_error("the output must be of precision f32")
+        out_shape = prediction.shape
+        if 3 != len(out_shape):
+            raise RuntimeError("the output must be of rank 3")
+        if 1 != out_shape[0]:
+            raise RuntimeError("the first dim of the output must be 1")
+        LABELS_START = 4
+        xc = (
+            np.amax(prediction[:, LABELS_START:], 1) > self.confidence_threshold
+        )  # Candidates
+        x = prediction[0]
+        x = x.transpose(1, 0)[xc[0]]
+        box, cls = x[:, :LABELS_START], x[:, LABELS_START:]
+        box = xywh2xyxy(box)
+        j = cls.argmax(1, keepdims=True)
+        conf = np.take_along_axis(cls, j, 1)
+        x = np.concatenate((box, conf, j.astype(np.float32)), 1)
+        max_wh = 0 if self.agnostic_nms else 7680
+        c = x[:, 5:6] * max_wh
+        boxes = x[:, :LABELS_START] + c
+        boxes = x[
+            nms(
+                boxes[:, 0],
+                boxes[:, 1],
+                boxes[:, 2],
+                boxes[:, 3],
+                x[:, LABELS_START],
+                self.iou_threshold,
+                keep_top_k=30000,
+            )
+        ]
+        inputImgWidth = meta["original_shape"][1]
+        inputImgHeight = meta["original_shape"][0]
+        invertedScaleX, invertedScaleY = (
+            inputImgWidth / self.orig_width,
+            inputImgHeight / self.orig_height,
+        )
+        padLeft, padTop = 0, 0
+        if (
+            "fit_to_window" == self.resize_type
+            or "fit_to_window_letterbox" == self.resize_type
+        ):
+            invertedScaleX = invertedScaleY = max(invertedScaleX, invertedScaleY)
+            if "fit_to_window_letterbox" == self.resize_type:
+                padLeft = (self.orig_width - round(inputImgWidth / invertedScaleX)) // 2
+                padTop = (
+                    self.orig_height - round(inputImgHeight / invertedScaleY)
+                ) // 2
+        coords = boxes[:, :LABELS_START]
+        coords -= (padLeft, padTop, padLeft, padTop)
+        coords *= (invertedScaleX, invertedScaleY, invertedScaleX, invertedScaleY)
+
+        intboxes = np.round(coords, out=coords).astype(np.int32)
+        np.clip(
+            intboxes,
+            0,
+            [inputImgWidth, inputImgHeight, inputImgWidth, inputImgHeight],
+            intboxes,
+        )
+        intid = boxes[:, 5].astype(np.int32)
+        return DetectionResult(
+            [
+                Detection(
+                    *intboxes[i], boxes[i, 4], intid[i], self.get_label_name(intid[i])
+                )
+                for i in range(len(boxes))
+            ],
+            np.ndarray(0),
+            np.ndarray(0),
+        )
+
+
+class YOLOv8(YOLOv5):
+    """YOLOv5 and YOLOv8 are identical in terms of inference"""
+
+    __model__ = "YOLOv8"
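The `max_wh` offset in `YOLOv5.postprocess` is the usual trick for running class-aware NMS with a class-agnostic routine: every box is translated by `class_id * max_wh`, so boxes of different classes stop overlapping and are never suppressed against each other. A compact sketch with made-up boxes:

    import numpy as np

    boxes = np.array([[10.0, 10.0, 50.0, 50.0],   # class 0
                      [12.0, 11.0, 49.0, 52.0]])  # class 1, heavily overlaps class 0
    class_ids = np.array([[0.0], [1.0]])
    max_wh = 7680  # chosen larger than any network input side
    shifted = boxes + class_ids * max_wh
    # shifted[1] now starts at x,y ~ 7690, so IoU(shifted[0], shifted[1]) == 0
    # and NMS keeps both boxes; with agnostic_nms, max_wh is 0 and one is dropped.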
diff --git a/model_api/python/setup.py b/model_api/python/setup.py
index 678438ca..9040da03 100755
--- a/model_api/python/setup.py
+++ b/model_api/python/setup.py
@@ -38,7 +38,12 @@
     install_requires=(SETUP_DIR / "requirements.txt").read_text(),
     extras_require={
         "ovms": (SETUP_DIR / "requirements_ovms.txt").read_text(),
-        "tests": ["httpx", "pytest", "openvino-dev[onnx,pytorch,tensorflow2]"],
+        "tests": [
+            "httpx",
+            "pytest",
+            "openvino-dev[onnx,pytorch,tensorflow2]",
+            "ultralytics>=8.0.114",
+        ],
     },
     long_description=(SETUP_DIR.parents[1] / "README.md").read_text(),
     long_description_content_type="text/markdown",
diff --git a/tests/cpp/accuracy/CMakeLists.txt b/tests/cpp/accuracy/CMakeLists.txt
index ac0ee387..d1bf2447 100644
--- a/tests/cpp/accuracy/CMakeLists.txt
+++ b/tests/cpp/accuracy/CMakeLists.txt
@@ -69,3 +69,4 @@
 find_package(OpenVINO REQUIRED COMPONENTS Runtime)
 add_subdirectory(../../../model_api/cpp ${tests_BINARY_DIR}/model_api/cpp)
 add_test(NAME test_accuracy SOURCES test_accuracy.cpp DEPENDENCIES model_api)
+add_test(NAME test_YOLOv8 SOURCES test_YOLOv8.cpp DEPENDENCIES model_api)
diff --git a/tests/cpp/accuracy/test_YOLOv8.cpp b/tests/cpp/accuracy/test_YOLOv8.cpp
new file mode 100644
index 00000000..4e681ee5
--- /dev/null
+++ b/tests/cpp/accuracy/test_YOLOv8.cpp
@@ -0,0 +1,43 @@
+#include <filesystem>
+#include <fstream>
+#include <sstream>
+
+#include <gtest/gtest.h>
+
+#include <models/detection_model.h>
+#include <models/results.h>
+
+using namespace std;
+
+namespace {
+TEST(YOLOv8, Detector) {
+    // Get data from env var, not from cmd arg, to stay aligned with the Python version
+    const char* const data = getenv("DATA");
+    ASSERT_NE(data, nullptr);
+    const string& exported_path = string{data} + "/ultralytics/detectors/";
+    for (const string model_name : {"yolov5mu_openvino_model", "yolov8l_openvino_model"}) {
+        filesystem::path xml;
+        for (auto const& dir_entry : filesystem::directory_iterator{exported_path + model_name}) {
+            const filesystem::path& path = dir_entry.path();
+            if (".xml" == path.extension()) {
+                ASSERT_TRUE(xml.empty());
+                xml = path;
+            }
+        }
+        bool preload = true;
+        unique_ptr<DetectionModel> yoloV8 = DetectionModel::create_model(xml.string(), {}, "", preload, "CPU");
+        vector<filesystem::path> refpaths;
+        for (auto const& dir_entry : filesystem::directory_iterator{exported_path + model_name + "/ref/"}) {
+            refpaths.push_back(dir_entry.path());
+        }
+        ASSERT_GT(refpaths.size(), 0);
+        sort(refpaths.begin(), refpaths.end());
+        for (filesystem::path refpath : refpaths) {
+            ifstream file{refpath};
+            stringstream ss;
+            ss << file.rdbuf();
+            EXPECT_EQ(ss.str(), std::string{*yoloV8->infer(cv::imread(string{data} + "/coco128/images/train2017/" + refpath.stem().string() + ".jpg"))});
+        }
+    }
+}
+}
diff --git a/tests/python/accuracy/test_YOLOv8.py b/tests/python/accuracy/test_YOLOv8.py
new file mode 100644
index 00000000..9db44e9d
--- /dev/null
+++ b/tests/python/accuracy/test_YOLOv8.py
@@ -0,0 +1,92 @@
+import functools
+import os
+from pathlib import Path
+
+import cv2
+import numpy as np
+import openvino.runtime as ov
+import pytest
+from openvino.model_api.models import YOLOv5
+from ultralytics import YOLO
+
+
+def _init_predictor(yolo):
+    yolo.predict(np.empty([1, 1, 3], np.uint8))
+
+
+@functools.lru_cache(maxsize=1)
+def _cached_models(folder, pt):
+    export_dir = Path(
+        YOLO(folder / "ultralytics/detectors" / pt, "detect").export(format="openvino")
+    )
+    impl_wrapper = YOLOv5.create_model(export_dir / (pt.stem + ".xml"), device="CPU")
+    ref_wrapper = YOLO(export_dir, "detect")
+    ref_wrapper.overrides["imgsz"] = (impl_wrapper.w, impl_wrapper.h)
+    _init_predictor(ref_wrapper)
+    ref_wrapper.predictor.model.ov_compiled_model = ov.Core().compile_model(
+        ref_wrapper.predictor.model.ov_model, "CPU"
+    )
+    ref_dir = export_dir / "ref"
+    ref_dir.mkdir(exist_ok=True)
+    return impl_wrapper, ref_wrapper, ref_dir
+
+
+def _impaths():
+    """
+    It's impossible to pass a fixture as an argument to
+    @pytest.mark.parametrize, so a cmd arg can't be used here. Use an env var
+    instead. Another option is to define
+    pytest_generate_tests(metafunc) in conftest.py
+    """
+    impaths = sorted(
+        file
+        for file in (Path(os.environ["DATA"]) / "coco128/images/train2017/").iterdir()
+        if file.name
+        not in {  # These images fail because image preprocessing is embedded into the model
+            "000000000143.jpg",
+            "000000000491.jpg",
+            "000000000536.jpg",
+            "000000000581.jpg",
+        }
+    )
+    if not impaths:
+        raise RuntimeError(
+            f"{Path(os.environ['DATA']) / 'coco128/images/train2017/'} is empty"
+        )
+    return impaths
+
+
+@pytest.mark.parametrize("impath", _impaths())
+@pytest.mark.parametrize("pt", [Path("yolov5mu.pt"), Path("yolov8l.pt")])
+def test_detector(impath, pt):
+    impl_wrapper, ref_wrapper, ref_dir = _cached_models(Path(os.environ["DATA"]), pt)
+    im = cv2.imread(str(impath))
+    assert im is not None
+    impl_preds = impl_wrapper(im)
+    pred_boxes = np.array(
+        [
+            [
+                impl_pred.xmin,
+                impl_pred.ymin,
+                impl_pred.xmax,
+                impl_pred.ymax,
+                impl_pred.score,
+                impl_pred.id,
+            ]
+            for impl_pred in impl_preds.objects
+        ],
+        dtype=np.float32,
+    )
+    ref_predictions = ref_wrapper.predict(im)
+    assert 1 == len(ref_predictions)
+    ref_boxes = ref_predictions[0].boxes.data.numpy()
+    with open(ref_dir / impath.with_suffix(".txt").name, "w") as file:
+        print(impl_preds, end="", file=file)
+    if 0 == pred_boxes.size == ref_boxes.size:
+        return  # np.isclose() doesn't work for empty arrays
+    ref_boxes[:, :4] = np.round(ref_boxes[:, :4], out=ref_boxes[:, :4])
+    assert np.isclose(
+        pred_boxes[:, :4], ref_boxes[:, :4], 0, 1
+    ).all()  # Allow one pixel deviation because image preprocessing is embedded into the model
+    assert np.isclose(pred_boxes[:, 4], ref_boxes[:, 4], 0.0, 0.02).all()
+    assert (pred_boxes[:, 5] == ref_boxes[:, 5]).all()
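Putting the pieces together, the flow the new tests exercise might look like this from user code (a sketch following test_YOLOv8.py; the weights name and image path are placeholders):

    from pathlib import Path

    import cv2
    from openvino.model_api.models import YOLOv8
    from ultralytics import YOLO

    # Export a detector with ultralytics, then load the IR with the new wrapper.
    export_dir = Path(YOLO("yolov8l.pt", "detect").export(format="openvino"))
    model = YOLOv8.create_model(export_dir / "yolov8l.xml", device="CPU")
    detections = model(cv2.imread("sample.jpg"))  # placeholder image
    print(detections)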