Skip to content

Commit

Permalink
adjust openvino2024 new api
Browse files Browse the repository at this point in the history
  • Loading branch information
rofgmd committed Dec 18, 2024
1 parent 3f8604b commit bb7fed3
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions mmdeploy/backend/openvino/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,20 @@ def __init__(self,
output_names: Optional[Sequence[str]] = None,
**kwargs):

from openvino.inference_engine import IECore
from openvino.runtime import Core as IECore
self.ie = IECore()
bin_path = osp.splitext(ir_model_file)[0] + '.bin'
self.net = self.ie.read_network(ir_model_file, bin_path)
for input in self.net.input_info.values():
batch_size = input.input_data.shape[0]
dims = len(input.input_data.shape)
self.net = self.ie.read_model(model=ir_model_file, weights=bin_path)
for input in self.net.inputs:
batch_size = input.get_partial_shape()[0]
dims = len(input.get_partial_shape())
# if input is a image, it has (B,C,H,W) channels,
# need batch_size==1
assert not dims == 4 or batch_size == 1, \
'Only batch 1 is supported.'
self.device = 'cpu'
self.sess = self.ie.load_network(
network=self.net, device_name=self.device.upper(), num_requests=1)
self.sess = self.ie.compile_model(
model=self.net, device_name=self.device.upper())

# TODO: Check if output_names can be read
if output_names is None:
Expand Down Expand Up @@ -84,7 +84,8 @@ def __reshape(self, inputs: Dict[str, torch.Tensor]):
input_shapes = {name: data.shape for name, data in inputs.items()}
reshape_needed = False
for input_name, input_shape in input_shapes.items():
blob_shape = self.net.input_info[input_name].input_data.shape
input_node = next(input for input in self.net.inputs if input.get_any_name() == input_name)
blob_shape = input_node.get_partial_shape()
if not np.array_equal(input_shape, blob_shape):
reshape_needed = True
break
Expand Down Expand Up @@ -147,5 +148,8 @@ def __openvino_execute(
Returns:
Dict[str, numpy.ndarray]: The output name and tensor pairs.
"""
outputs = self.sess.infer(inputs)
infer_request = self.sess.create_infer_request()
infer_request.infer(inputs)
outputs = {output.get_any_name(): infer_request.get_tensor(output).data
for output in self.net.outputs}
return outputs

0 comments on commit bb7fed3

Please sign in to comment.