feat(ml): ML on Rockchip NPUs #15241

Open

yoni13 wants to merge 78 commits into main from rknn-toolkit-lite2

Changes from 55 commits

Commits (78)
8ef3e49
untested
yoni13 Nov 29, 2024
6ffc227
test
yoni13 Nov 29, 2024
7fddf28
lowercase
yoni13 Nov 30, 2024
bc849e2
ViT-B-32__openai/textual/ Runs with emulator now.
yoni13 Dec 1, 2024
b6c4b37
Merge branch 'immich-app:main' into rknn-toolkit2
yoni13 Dec 3, 2024
4140e93
Merge branch 'immich-app:main' into rknn-toolkit-lite2
yoni13 Dec 4, 2024
257cc6c
Init commit for using rknn, RecognitionFormDataLoadTest doesnt work
yoni13 Dec 4, 2024
da152bd
Merge branch 'immich-app:main' into rknn-toolkit-lite2
yoni13 Dec 13, 2024
082c426
Merge branch 'immich-app:main' into rknn-toolkit-lite2
yoni13 Dec 25, 2024
a94fad5
all infrencing works with 1 max job concurrency
yoni13 Jan 9, 2025
8608b9c
Merge branch 'immich-app:main' into rknn-toolkit-lite2
yoni13 Jan 9, 2025
9bc3e5b
Update rknn.py
yoni13 Jan 10, 2025
4d704e9
fix inf,-inf with 2 concurrency
yoni13 Jan 10, 2025
a2722e1
Revert my changes to dockerfiles
yoni13 Jan 11, 2025
c20d110
support for rknn.rknnpool.is_available
yoni13 Jan 11, 2025
66004e3
Merge branch 'immich-app:main' into rknn-toolkit-lite2
yoni13 Jan 11, 2025
d10147f
Handling Import and file not found Error for non-arm devices.
yoni13 Jan 11, 2025
d5ef821
Set group RKNN to optional
yoni13 Jan 11, 2025
506ca0d
Dockerfile for rknn
yoni13 Jan 11, 2025
7aaf3aa
Remove unused imports.
yoni13 Jan 11, 2025
f4671f4
Indentation issue
yoni13 Jan 11, 2025
7f2af6f
Fix typo: rknnlite.api
yoni13 Jan 11, 2025
d5e453a
ruff format
yoni13 Jan 11, 2025
23d0ea0
ruff
yoni13 Jan 11, 2025
4162119
Check if NPU drivers is loaded or not.
yoni13 Jan 11, 2025
815ed1a
Install onnxruntime
yoni13 Jan 11, 2025
807111e
Should Fix No module named 'rknn'
yoni13 Jan 11, 2025
665718b
add rknn to src
yoni13 Jan 11, 2025
efaf70e
Set running threads from env
yoni13 Jan 11, 2025
19ee48f
fix path
yoni13 Jan 11, 2025
c72cf61
support core_mask for specfic socs
yoni13 Jan 11, 2025
c665fd2
Fix Please do not set this parameter on other platforms.
yoni13 Jan 11, 2025
e6ff21b
set default thread num to 2, not everyone has 8 gigs of ram
yoni13 Jan 12, 2025
c109e28
DOCS
yoni13 Jan 12, 2025
bb67a9d
fix formatting
yoni13 Jan 12, 2025
68fccad
Fix docs.
yoni13 Jan 12, 2025
1775397
Sort them by alphablet
yoni13 Jan 12, 2025
7ae4b71
format be happy
yoni13 Jan 12, 2025
8965a9f
Merge branch 'main' into rknn-toolkit-lite2
yoni13 Jan 12, 2025
4c7ac14
only load knnx model when required
yoni13 Jan 12, 2025
daf8860
Add export script
yoni13 Jan 13, 2025
ebdfe1b
Load model by SOC name
yoni13 Jan 13, 2025
2f7e44a
typing be happy.
yoni13 Jan 13, 2025
f328104
Merge branch 'main' into rknn-toolkit-lite2
yoni13 Jan 13, 2025
b6cc205
ignore rknn model if not using it
yoni13 Jan 13, 2025
6c4e6cb
reformat
yoni13 Jan 13, 2025
4b0f93c
add test,founds bugs, fix it tomorrow
yoni13 Jan 13, 2025
8b80d03
fixed some bugs
yoni13 Jan 14, 2025
5244ed6
black app export
yoni13 Jan 14, 2025
cb01a11
Merge branch 'main' into rknn-toolkit-lite2
yoni13 Jan 14, 2025
c21ce40
switch to Runtime error instead of exit()
yoni13 Jan 14, 2025
0f03f77
remove non implemented tests
yoni13 Jan 14, 2025
b5a4ed5
this duplicated?
yoni13 Jan 14, 2025
01eb095
trying to fix pytest
yoni13 Jan 14, 2025
9882b83
Should FIx the quote that made mypy unhappy
yoni13 Jan 14, 2025
f32d991
changes some cases
yoni13 Jan 17, 2025
26d5fb0
add checksum for libnnrt.so
yoni13 Jan 17, 2025
bc48b67
switch to sha256
yoni13 Jan 17, 2025
f067212
tpe
yoni13 Jan 17, 2025
0567592
remove unrequired devices
yoni13 Jan 17, 2025
3634ae1
fix granularity
yoni13 Jan 18, 2025
f5de3de
fix typo and add a propper var name
yoni13 Jan 18, 2025
be76857
make these functions snake case.
yoni13 Jan 18, 2025
87a46dc
remove unnecessary print
yoni13 Jan 18, 2025
d7381ab
refactor ignore_patterns
yoni13 Jan 18, 2025
9926045
add a simple script to notify user if some op is not supported
yoni13 Jan 18, 2025
4e42fbc
Merge branch 'main' into rknn-toolkit-lite2
yoni13 Jan 18, 2025
b3ae5d3
fix typo in tests
yoni13 Jan 18, 2025
d2b7e10
shellcheck happy
yoni13 Jan 18, 2025
58f1cc9
prettier happy
yoni13 Jan 18, 2025
32f3707
fix types and ignored pattern
yoni13 Jan 18, 2025
1653cd9
update supported SOCs
yoni13 Jan 18, 2025
2b967ca
raise NotImplementedError for now
yoni13 Jan 19, 2025
ac4ce3e
add input outputs
yoni13 Jan 19, 2025
20ba9f9
update mapping
yoni13 Jan 19, 2025
59e4b65
Merge branch 'main' into rknn-toolkit-lite2
yoni13 Jan 19, 2025
dd52c2d
Update permission
yoni13 Jan 19, 2025
794da29
Merge branch 'main' into rknn-toolkit-lite2
yoni13 Jan 22, 2025
7 changes: 5 additions & 2 deletions .github/workflows/docker.yml
@@ -48,7 +48,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
suffix: ["", "-cuda", "-openvino", "-armnn"]
suffix: ["", "-cuda", "-openvino", "-armnn","-rknn"]
steps:
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
@@ -116,6 +116,9 @@ jobs:
- platforms: linux/arm64
device: armnn
suffix: -armnn
- platforms: linux/arm64
device: rknn
suffix: -rknn

steps:
- name: Checkout
@@ -307,4 +310,4 @@ jobs:
run: exit 1
- name: All jobs passed or skipped
if: ${{ !(contains(needs.*.result, 'failure')) }}
run: echo "All jobs passed or skipped" && echo "${{ toJSON(needs.*.result) }}"
run: echo "All jobs passed or skipped" && echo "${{ toJSON(needs.*.result) }}"
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -390,7 +390,7 @@ jobs:
poetry run black --check app export
- name: Run mypy type checking
run: |
poetry run mypy --install-types --non-interactive --strict app/
mkdir .mypy_cache && poetry run mypy --install-types --non-interactive --strict app/
- name: Run tests and coverage
run: |
poetry run pytest app --cov=app --cov-report term-missing
4 changes: 2 additions & 2 deletions docker/docker-compose.dev.yml
@@ -85,12 +85,12 @@ services:
image: immich-machine-learning-dev:latest
# extends:
# file: hwaccel.ml.yml
# service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference
# service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl, rknn] for accelerated inference
build:
context: ../machine-learning
dockerfile: Dockerfile
args:
- DEVICE=cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference
- DEVICE=cpu # set to one of [armnn, cuda, openvino, openvino-wsl, rknn] for accelerated inference
ports:
- 3003:3003
volumes:
20 changes: 5 additions & 15 deletions docker/docker-compose.prod.yml
@@ -29,12 +29,12 @@ services:
image: immich-machine-learning:latest
# extends:
# file: hwaccel.ml.yml
# service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference
# service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl, rknn] for accelerated inference
build:
context: ../machine-learning
dockerfile: Dockerfile
args:
- DEVICE=cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference
- DEVICE=cpu # set to one of [armnn, cuda, openvino, openvino-wsl, rknn] for accelerated inference
ports:
- 3003:3003
volumes:
@@ -68,22 +68,12 @@ services:
- 5432:5432
healthcheck:
test: >-
pg_isready --dbname="$${POSTGRES_DB}" --username="$${POSTGRES_USER}" || exit 1;
Chksum="$$(psql --dbname="$${POSTGRES_DB}" --username="$${POSTGRES_USER}" --tuples-only --no-align
--command='SELECT COALESCE(SUM(checksum_failures), 0) FROM pg_stat_database')";
echo "checksum failure count is $$Chksum";
[ "$$Chksum" = '0' ] || exit 1
pg_isready --dbname="$${POSTGRES_DB}" --username="$${POSTGRES_USER}" || exit 1; Chksum="$$(psql --dbname="$${POSTGRES_DB}" --username="$${POSTGRES_USER}" --tuples-only --no-align --command='SELECT COALESCE(SUM(checksum_failures), 0) FROM pg_stat_database')"; echo "checksum failure count is $$Chksum"; [ "$$Chksum" = '0' ] || exit 1
interval: 5m
start_interval: 30s
start_period: 5m
command: >-
postgres
-c shared_preload_libraries=vectors.so
-c 'search_path="$$user", public, vectors'
-c logging_collector=on
-c max_wal_size=2GB
-c shared_buffers=512MB
-c wal_compression=on
postgres -c shared_preload_libraries=vectors.so -c 'search_path="$$user", public, vectors' -c logging_collector=on -c max_wal_size=2GB -c shared_buffers=512MB -c wal_compression=on
restart: always

# set IMMICH_TELEMETRY_INCLUDE=all in .env to enable metrics
@@ -100,7 +90,7 @@ services:
# add data source for http://immich-prometheus:9090 to get started
immich-grafana:
container_name: immich_grafana
command: ['./run.sh', '-disable-reporting']
command: [ './run.sh', '-disable-reporting' ]
ports:
- 3000:3000
image: grafana/grafana:11.4.0-ubuntu@sha256:afccec22ba0e4815cca1d2bf3836e414322390dc78d77f1851976ffa8d61051c
18 changes: 4 additions & 14 deletions docker/docker-compose.yml
@@ -32,12 +32,12 @@ services:

immich-machine-learning:
container_name: immich_machine_learning
# For hardware acceleration, add one of -[armnn, cuda, openvino] to the image tag.
# For hardware acceleration, add one of -[armnn, cuda, openvino, rknn] to the image tag.
# Example tag: ${IMMICH_VERSION:-release}-cuda
image: ghcr.io/immich-app/immich-machine-learning:${IMMICH_VERSION:-release}
# extends: # uncomment this section for hardware acceleration - see https://immich.app/docs/features/ml-hardware-acceleration
# file: hwaccel.ml.yml
# service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl] for accelerated inference - use the `-wsl` version for WSL2 where applicable
# service: cpu # set to one of [armnn, cuda, openvino, openvino-wsl, rknn] for accelerated inference - use the `-wsl` version for WSL2 where applicable
volumes:
- model-cache:/cache
env_file:
@@ -66,22 +66,12 @@ services:
- ${DB_DATA_LOCATION}:/var/lib/postgresql/data
healthcheck:
test: >-
pg_isready --dbname="$${POSTGRES_DB}" --username="$${POSTGRES_USER}" || exit 1;
Chksum="$$(psql --dbname="$${POSTGRES_DB}" --username="$${POSTGRES_USER}" --tuples-only --no-align
--command='SELECT COALESCE(SUM(checksum_failures), 0) FROM pg_stat_database')";
echo "checksum failure count is $$Chksum";
[ "$$Chksum" = '0' ] || exit 1
pg_isready --dbname="$${POSTGRES_DB}" --username="$${POSTGRES_USER}" || exit 1; Chksum="$$(psql --dbname="$${POSTGRES_DB}" --username="$${POSTGRES_USER}" --tuples-only --no-align --command='SELECT COALESCE(SUM(checksum_failures), 0) FROM pg_stat_database')"; echo "checksum failure count is $$Chksum"; [ "$$Chksum" = '0' ] || exit 1
interval: 5m
start_interval: 30s
start_period: 5m
command: >-
postgres
-c shared_preload_libraries=vectors.so
-c 'search_path="$$user", public, vectors'
-c logging_collector=on
-c max_wal_size=2GB
-c shared_buffers=512MB
-c wal_compression=on
postgres -c shared_preload_libraries=vectors.so -c 'search_path="$$user", public, vectors' -c logging_collector=on -c max_wal_size=2GB -c shared_buffers=512MB -c wal_compression=on
restart: always

volumes:
12 changes: 12 additions & 0 deletions docker/hwaccel.ml.yml
@@ -13,6 +13,18 @@ services:
volumes:
- /lib/firmware/mali_csffw.bin:/lib/firmware/mali_csffw.bin:ro # Mali firmware for your chipset (not always required depending on the driver)
- /usr/lib/libmali.so:/usr/lib/libmali.so:ro # Mali driver for your chipset (always required)

rknn:
security_opt:
- systempaths=unconfined
- apparmor=unconfined
devices:
- /dev/rga:/dev/rga
- /dev/dri:/dev/dri
- /dev/dma_heap:/dev/dma_heap
- /dev/mpp_service:/dev/mpp_service
volumes:
- /sys/kernel/debug/:/sys/kernel/debug/:ro

cpu: {}

10 changes: 10 additions & 0 deletions docs/docs/features/ml-hardware-acceleration.md
@@ -12,6 +12,7 @@ You do not need to redo any machine learning jobs after enabling hardware acceleration.
- ARM NN (Mali)
- CUDA (NVIDIA GPUs with [compute capability](https://developer.nvidia.com/cuda-gpus) 5.2 or higher)
- OpenVINO (Intel discrete GPUs such as Iris Xe and Arc)
- RKNN (Rockchip)

## Limitations

@@ -46,6 +47,15 @@ You do not need to redo any machine learning jobs after enabling hardware acceleration.
- The server must have a discrete GPU, i.e. Iris Xe or Arc. Expect issues when attempting to use integrated graphics.
- Ensure the server's kernel version is new enough to use the device for hardware acceleration.

#### RKNN

- You must have a supported Rockchip SoC; only the RK3566 and RK3588 are supported at this time.
- Make sure you have the appropriate Linux kernel driver installed
- This is usually pre-installed on the device vendor's Linux images
- RKNPU driver V0.9.8 or later must be available on the host server
- You can confirm this by running `cat /sys/kernel/debug/rknpu/version` (a quick verification sketch follows this list)
- Optional: Configure your `.env` file; see [environment variables](/docs/install/environment-variables) for RKNN-specific settings
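As a quick way to check the last two prerequisites, here is a minimal, hedged sketch in Python. The sysfs path and the 0.9.8 minimum come from the list above; the exact format of the version string is an assumption and may differ between vendor images.

```python
# Minimal prerequisite check for the RKNN backend. The version-string
# format ("... v0.9.8") is an assumption; adjust parsing for your image.
from pathlib import Path

RKNPU_VERSION_PATH = Path("/sys/kernel/debug/rknpu/version")
MIN_VERSION = (0, 9, 8)


def rknpu_driver_ok() -> bool:
    try:
        text = RKNPU_VERSION_PATH.read_text()
    except OSError:  # driver missing, or debugfs not mounted/readable
        return False
    version = text.rsplit("v", 1)[-1].strip()  # e.g. "RKNPU driver: v0.9.8" -> "0.9.8"
    try:
        parts = tuple(int(p) for p in version.split("."))
    except ValueError:
        return False
    return parts >= MIN_VERSION


if __name__ == "__main__":
    print("RKNPU driver OK" if rknpu_driver_ok() else "RKNPU driver missing or too old")
```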

## Setup

1. If you do not already have it, download the latest [`hwaccel.ml.yml`][hw-file] file and ensure it's in the same folder as the `docker-compose.yml`.
4 changes: 4 additions & 0 deletions docs/docs/install/environment-variables.md
@@ -166,6 +166,10 @@ Redis (Sentinel) URL example JSON before encoding:
| `MACHINE_LEARNING_ANN_TUNING_LEVEL` | ARM-NN GPU tuning level (1: rapid, 2: normal, 3: exhaustive) | `2` | machine learning |
| `MACHINE_LEARNING_DEVICE_IDS`<sup>\*4</sup> | Device IDs to use in multi-GPU environments | `0` | machine learning |
| `MACHINE_LEARNING_MAX_BATCH_SIZE__FACIAL_RECOGNITION` | Set the maximum number of faces that will be processed at once by the facial recognition model | None (`1` if using OpenVINO) | machine learning |
| `MACHINE_LEARNING_RKNN` | Enable RKNN hardware acceleration if supported | `True` | machine learning |
| `MACHINE_LEARNING_RKNN_TEXTUAL_THREADS` | Number of RKNN runtime threads to spin up for textual model inference | `1` | machine learning |
| `MACHINE_LEARNING_RKNN_VISUAL_THREADS` | Number of RKNN runtime threads to spin up for visual model inference | `1` | machine learning |
| `MACHINE_LEARNING_RKNN_FACIAL_DETECTION_THREADS` | Number of RKNN runtime threads to spin up for facial detection model inference | `1` | machine learning |
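For reference, a minimal sketch of how these variables map onto the machine-learning service settings. The field names mirror `machine-learning/app/config.py` in this PR; the `pydantic-settings` import and the `MACHINE_LEARNING_` env prefix are assumptions about how the service binds environment variables.

```python
# Hedged sketch: field names match app/config.py in this PR; the env prefix
# and pydantic-settings usage are assumptions.
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="MACHINE_LEARNING_")

    rknn: bool = True
    rknn_textual_threads: int = 1
    rknn_visual_threads: int = 1
    rknn_facial_detection_threads: int = 1


# With MACHINE_LEARNING_RKNN_VISUAL_THREADS=2 in the environment:
# Settings().rknn_visual_threads == 2
```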

\*1: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones.

9 changes: 8 additions & 1 deletion machine-learning/Dockerfile
@@ -15,6 +15,8 @@ RUN mkdir /opt/armnn && \
cd /opt/ann && \
sh build.sh

FROM builder-cpu AS builder-rknn

FROM builder-${DEVICE} AS builder

ARG DEVICE
@@ -80,6 +82,10 @@ COPY --from=builder-armnn \
/opt/ann/build.sh \
/opt/armnn/

FROM prod-cpu AS prod-rknn

ADD https://github.com/airockchip/rknn-toolkit2/raw/refs/tags/v2.3.0/rknpu2/runtime/Linux/librknn_api/aarch64/librknnrt.so /usr/lib/
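The commit history adds a sha256 checksum for this runtime library ("add checksum", "switch to sha256"). Below is a hedged sketch of such an integrity check in Python; the digest is a placeholder, not the real value.

```python
# Hedged sketch of a sha256 integrity check for librknnrt.so;
# EXPECTED_SHA256 is a placeholder, not the published digest.
import hashlib
from pathlib import Path

EXPECTED_SHA256 = "0000000000000000000000000000000000000000000000000000000000000000"


def verify_librknnrt(path: str = "/usr/lib/librknnrt.so") -> bool:
    digest = hashlib.sha256(Path(path).read_bytes()).hexdigest()
    return digest == EXPECTED_SHA256
```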

FROM prod-${DEVICE} AS prod
ARG DEVICE

@@ -104,9 +110,10 @@ RUN echo "hard core 0" >> /etc/security/limits.conf && \

COPY --from=builder /opt/venv /opt/venv
COPY ann/ann.py /usr/src/ann/ann.py
COPY rknn/rknnpool.py /usr/src/rknn/rknnpool.py
COPY start.sh log_conf.json gunicorn_conf.py ./
COPY app .
ENTRYPOINT ["tini", "--"]
CMD ["./start.sh"]

HEALTHCHECK CMD python3 healthcheck.py
4 changes: 4 additions & 0 deletions machine-learning/app/config.py
@@ -44,6 +44,10 @@ class Settings(BaseSettings):
ann: bool = True
ann_fp16_turbo: bool = False
ann_tuning_level: int = 2
rknn: bool = True
rknn_textual_threads: int = 1
rknn_visual_threads: int = 1
rknn_facial_detection_threads: int = 1
preload: PreloadModelData | None = None
max_batch_size: MaxBatchSize | None = None

6 changes: 6 additions & 0 deletions machine-learning/app/conftest.py
@@ -136,6 +136,12 @@ def ann_session() -> Iterator[mock.Mock]:
yield mocked


@pytest.fixture(scope="function")
def rknn_session() -> Iterator[mock.Mock]:
with mock.patch("app.sessions.rknn.rknnPoolExecutor") as mocked:
yield mocked
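
A hypothetical example of a test consuming this fixture; the test name and assertion are illustrative only and simply show that the pool class is patched for the duration of the test.

```python
# Hypothetical usage of the rknn_session fixture above.
from unittest import mock


def test_rknn_pool_is_mocked(rknn_session: mock.Mock) -> None:
    from app.sessions import rknn

    # Inside the fixture's patch context, the executor class is the mock.
    assert rknn.rknnPoolExecutor is rknn_session
```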


@pytest.fixture(scope="function")
def rmtree() -> Iterator[mock.Mock]:
with mock.patch("app.models.base.rmtree", autospec=True) as mocked:
13 changes: 12 additions & 1 deletion machine-learning/app/models/base.py
@@ -8,7 +8,9 @@
from huggingface_hub import snapshot_download

import ann.ann
import rknn.rknnpool
from app.sessions.ort import OrtSession
from app.sessions.rknn import RknnSession

from ..config import clean_name, log, settings
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
@@ -67,6 +69,8 @@ def configure(self, **kwargs: Any) -> None:

def _download(self) -> None:
ignore_patterns = [] if self.model_format == ModelFormat.ARMNN else ["*.armnn"]
if self.model_format != ModelFormat.RKNN:
ignore_patterns.append("*.rknn")
snapshot_download(
f"immich-app/{clean_name(self.model_name)}",
cache_dir=self.cache_dir,
@@ -108,6 +112,8 @@ def _make_session(self, model_path: Path) -> ModelSession:
session: ModelSession = AnnSession(model_path)
case ".onnx":
session = OrtSession(model_path)
case ".rknn":
session = RknnSession(model_path)
case _:
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
return session
@@ -155,4 +161,9 @@ def model_format(self, model_format: ModelFormat) -> None:

@property
def _model_format_default(self) -> ModelFormat:
return ModelFormat.ARMNN if ann.ann.is_available and settings.ann else ModelFormat.ONNX
if rknn.rknnpool.is_available and settings.rknn:
return ModelFormat.RKNN
elif ann.ann.is_available and settings.ann:
return ModelFormat.ARMNN
else:
return ModelFormat.ONNX
1 change: 1 addition & 0 deletions machine-learning/app/schemas.py
@@ -35,6 +35,7 @@ class ModelType(StrEnum):
class ModelFormat(StrEnum):
ARMNN = "armnn"
ONNX = "onnx"
RKNN = "rknn"


class ModelSource(StrEnum):
Expand Down
72 changes: 72 additions & 0 deletions machine-learning/app/sessions/rknn.py
@@ -0,0 +1,72 @@
from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import onnxruntime as ort
from numpy.typing import NDArray

from app.schemas import SessionNode
from rknn.rknnpool import rknnPoolExecutor, soc_name

from ..config import log, settings


def run_inference(rknn_lite: Any, inputs: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
# Run a single forward pass on one pooled RKNNLite instance (NCHW inputs).
outputs: list[NDArray[np.float32]] = rknn_lite.inference(inputs=inputs, data_format="nchw")

return outputs


class RknnSession:
def __init__(self, model_path: Path | str):
self.model_path = Path(str(model_path).replace("model", soc_name))
self.ort_model_path = Path(str(self.model_path).replace(f"{soc_name}.rknn", "model.onnx"))

if "textual" in str(self.model_path):
self.tpe = settings.rknn_textual_threads
elif "visual" in str(self.model_path):
self.tpe = settings.rknn_visual_threads
else:
self.tpe = settings.rknn_facial_detection_threads

log.info(f"Loading RKNN model from {self.model_path} with {self.tpe} threads.")
self.rknnpool = rknnPoolExecutor(rknnModel=self.model_path.as_posix(), TPEs=self.tpe, func=run_inference)
log.info(f"Loaded RKNN model from {self.model_path} with {self.tpe} threads.")

def __del__(self) -> None:
self.rknnpool.release()

def _load_ort_session(self) -> None:
self.ort_session = ort.InferenceSession(
self.ort_model_path.as_posix(),
)
self.inputs: list[SessionNode] = self.ort_session.get_inputs()
self.outputs: list[SessionNode] = self.ort_session.get_outputs()
del self.ort_session

def get_inputs(self) -> list[SessionNode]:
try:
return self.inputs
except AttributeError:
self._load_ort_session()
return self.inputs

def get_outputs(self) -> list[SessionNode]:
try:
return self.outputs
except AttributeError:
self._load_ort_session()
return self.outputs

def run(
self,
output_names: list[str] | None,
input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
run_options: Any = None,
) -> list[NDArray[np.float32]]:
input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
self.rknnpool.put(input_data)
outputs: list[NDArray[np.float32]] = self.rknnpool.get()
return outputs
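A hypothetical usage sketch: `RknnSession.run` mirrors the onnxruntime `InferenceSession.run` signature used elsewhere in the service. The model path, input name, and input shape below are assumptions; the comment about the SoC-specific filename rewrite reflects the constructor's `replace("model", soc_name)` logic.

```python
# Hypothetical usage; model path, input name, and shape are assumptions.
# On an RK3588, __init__ rewrites ".../visual/model.rknn" to ".../visual/rk3588.rknn".
import numpy as np

from app.sessions.rknn import RknnSession

session = RknnSession("/cache/clip/ViT-B-32__openai/visual/model.rknn")
image = np.random.rand(1, 3, 224, 224).astype(np.float32)  # NCHW float32
outputs = session.run(None, {"image": image})
print(outputs[0].shape)
```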