Skip to content

Commit

Permalink
feat: add embedding sort
Browse files Browse the repository at this point in the history
Requires a CUDA-enabled GPU.
  • Loading branch information
kevinsbarnard committed May 10, 2024
1 parent 7187b41 commit 60798b2
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 39 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ qdarkstyle = "^3.0.3"
beholder-client = "^0.1.0"
sharktopoda-client = "^0.4.5"
platformdirs = "^4.0.0"
dreamsim = "^0.1.3"

[tool.poetry.scripts]
vars-gridview = "vars_gridview.scripts.run:main"
Expand Down
5 changes: 4 additions & 1 deletion vars_gridview/lib/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,7 @@
CACHE_DIR_DEFAULT = Path(user_cache_dir(APP_NAME))

# Sharktopoda
SHARKTOPODA_APP_NAME = "Sharktopoda"
SHARKTOPODA_APP_NAME = "Sharktopoda"

# Embeddings
EMBEDDINGS_ENABLED_DEFAULT = False
55 changes: 55 additions & 0 deletions vars_gridview/lib/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from abc import ABC
from pathlib import Path

import numpy as np
from dreamsim import dreamsim
from PIL import Image

from vars_gridview.lib.settings import SettingsManager


class Embedding(ABC):
"""
Embedding abstract base class. Produces embedding vectors for images.
"""

def embed(self, image: np.ndarray) -> np.ndarray:
"""
Embed an image.
Args:
image (np.ndarray): Image as an RGB (h,w,3) Numpy array.
Returns:
np.ndarray: Vector embedding (n,) for the image.
"""
raise NotImplementedError()


class DreamSimEmbedding(Embedding):
"""
DreamSim embedding.
"""

CACHE_SUBDIR_NAME = "dreamsim"

def __init__(self) -> None:
settings = SettingsManager.get_instance()
base_cache_dir = Path(settings.cache_dir.value)
dreamsim_cache_dir = base_cache_dir / DreamSimEmbedding.CACHE_SUBDIR_NAME

# Download / load the models
self._model, self._preprocess = dreamsim(
pretrained=True,
device="cuda",
cache_dir=str(dreamsim_cache_dir.resolve().absolute()),
)

def embed(self, image: np.ndarray) -> np.ndarray:
# Preprocess the image
image_pil = Image.fromarray(image)
image_tensor = self._preprocess(image_pil).cuda()

# Compute the embedding
embedding = self._model.embed(image_tensor).cpu().detach().numpy().flatten()
return embedding
27 changes: 26 additions & 1 deletion vars_gridview/lib/image_mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from vars_gridview.lib import m3
from vars_gridview.lib.annotation import VARSLocalization
from vars_gridview.lib.cache import CacheController
from vars_gridview.lib.embedding import Embedding
from vars_gridview.lib.log import LOGGER
from vars_gridview.lib.m3 import operations
from vars_gridview.lib.sort_methods import SortMethod
Expand All @@ -42,6 +43,7 @@ def __init__(
rect_clicked_slot: callable,
verifier: str,
zoom: float = 1.0,
embedding_model: Optional[Embedding] = None,
):
super().__init__()

Expand All @@ -62,6 +64,8 @@ def __init__(

self.cache_controller = cache_controller

self._embedding_model = embedding_model

self.verifier = verifier

self.image_reference_urls = {}
Expand Down Expand Up @@ -511,16 +515,36 @@ def __init__(
video_data,
observer,
len(other_locs),
embedding_model=self._embedding_model,
)
rw.text_label = localization.text_label
rw.update_zoom(zoom)
rw.clicked.connect(rect_clicked_slot)
rw.similaritySort.connect(self._similarity_sort_slot)
self._rect_widgets.append(rw)

localization.rect = rw # Back reference

self.n_localizations += 1

def _similarity_sort_slot(self, clicked_rect: RectWidget, same_class_only: bool):
def key(rect_widget: RectWidget) -> float:
if same_class_only and clicked_rect.text_label != rect_widget.text_label:
return float("inf")
return clicked_rect.embedding_distance(rect_widget)

# Sort the rects by distance
self._rect_widgets.sort(key=key)

# Re-render the mosaic
self.render_mosaic()

def update_embedding_model(self, embedding_model: Embedding):
self._embedding_model = embedding_model
for rect_widget in self._rect_widgets:
rect_widget.update_embedding_model(embedding_model)
rect_widget.update_embedding()

def find_mp4_video_data(
self, video_sequence_name: str, timestamp: datetime
) -> Optional[dict]:
Expand Down Expand Up @@ -651,7 +675,8 @@ def render_mosaic(self):
rect_widgets_to_render = [
rw
for rw in self._rect_widgets
if (not rw.is_verified and not self._hide_unlabeled) or (rw.is_verified and not self._hide_labeled)
if (not rw.is_verified and not self._hide_unlabeled)
or (rw.is_verified and not self._hide_labeled)
]

# Hide all rect widgets that we aren't rendering
Expand Down
Loading

0 comments on commit 60798b2

Please sign in to comment.