diff --git a/pyproject.toml b/pyproject.toml index 6ff93ca..1e970ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ qdarkstyle = "^3.0.3" beholder-client = "^0.1.0" sharktopoda-client = "^0.4.5" platformdirs = "^4.0.0" +dreamsim = "^0.1.3" [tool.poetry.scripts] vars-gridview = "vars_gridview.scripts.run:main" diff --git a/vars_gridview/lib/constants.py b/vars_gridview/lib/constants.py index 5eb814b..5c3f095 100644 --- a/vars_gridview/lib/constants.py +++ b/vars_gridview/lib/constants.py @@ -42,4 +42,7 @@ CACHE_DIR_DEFAULT = Path(user_cache_dir(APP_NAME)) # Sharktopoda -SHARKTOPODA_APP_NAME = "Sharktopoda" \ No newline at end of file +SHARKTOPODA_APP_NAME = "Sharktopoda" + +# Embeddings +EMBEDDINGS_ENABLED_DEFAULT = False diff --git a/vars_gridview/lib/embedding.py b/vars_gridview/lib/embedding.py new file mode 100644 index 0000000..63e2108 --- /dev/null +++ b/vars_gridview/lib/embedding.py @@ -0,0 +1,55 @@ +from abc import ABC +from pathlib import Path + +import numpy as np +from dreamsim import dreamsim +from PIL import Image + +from vars_gridview.lib.settings import SettingsManager + + +class Embedding(ABC): + """ + Embedding abstract base class. Produces embedding vectors for images. + """ + + def embed(self, image: np.ndarray) -> np.ndarray: + """ + Embed an image. + + Args: + image (np.ndarray): Image as an RGB (h,w,3) Numpy array. + + Returns: + np.ndarray: Vector embedding (n,) for the image. + """ + raise NotImplementedError() + + +class DreamSimEmbedding(Embedding): + """ + DreamSim embedding. + """ + + CACHE_SUBDIR_NAME = "dreamsim" + + def __init__(self) -> None: + settings = SettingsManager.get_instance() + base_cache_dir = Path(settings.cache_dir.value) + dreamsim_cache_dir = base_cache_dir / DreamSimEmbedding.CACHE_SUBDIR_NAME + + # Download / load the models + self._model, self._preprocess = dreamsim( + pretrained=True, + device="cuda", + cache_dir=str(dreamsim_cache_dir.resolve().absolute()), + ) + + def embed(self, image: np.ndarray) -> np.ndarray: + # Preprocess the image + image_pil = Image.fromarray(image) + image_tensor = self._preprocess(image_pil).cuda() + + # Compute the embedding + embedding = self._model.embed(image_tensor).cpu().detach().numpy().flatten() + return embedding diff --git a/vars_gridview/lib/image_mosaic.py b/vars_gridview/lib/image_mosaic.py index 1f502b0..19d14f7 100644 --- a/vars_gridview/lib/image_mosaic.py +++ b/vars_gridview/lib/image_mosaic.py @@ -19,6 +19,7 @@ from vars_gridview.lib import m3 from vars_gridview.lib.annotation import VARSLocalization from vars_gridview.lib.cache import CacheController +from vars_gridview.lib.embedding import Embedding from vars_gridview.lib.log import LOGGER from vars_gridview.lib.m3 import operations from vars_gridview.lib.sort_methods import SortMethod @@ -42,6 +43,7 @@ def __init__( rect_clicked_slot: callable, verifier: str, zoom: float = 1.0, + embedding_model: Optional[Embedding] = None, ): super().__init__() @@ -62,6 +64,8 @@ def __init__( self.cache_controller = cache_controller + self._embedding_model = embedding_model + self.verifier = verifier self.image_reference_urls = {} @@ -511,16 +515,36 @@ def __init__( video_data, observer, len(other_locs), + embedding_model=self._embedding_model, ) rw.text_label = localization.text_label rw.update_zoom(zoom) rw.clicked.connect(rect_clicked_slot) + rw.similaritySort.connect(self._similarity_sort_slot) self._rect_widgets.append(rw) localization.rect = rw # Back reference self.n_localizations += 1 + def _similarity_sort_slot(self, clicked_rect: RectWidget, same_class_only: bool): + def key(rect_widget: RectWidget) -> float: + if same_class_only and clicked_rect.text_label != rect_widget.text_label: + return float("inf") + return clicked_rect.embedding_distance(rect_widget) + + # Sort the rects by distance + self._rect_widgets.sort(key=key) + + # Re-render the mosaic + self.render_mosaic() + + def update_embedding_model(self, embedding_model: Embedding): + self._embedding_model = embedding_model + for rect_widget in self._rect_widgets: + rect_widget.update_embedding_model(embedding_model) + rect_widget.update_embedding() + def find_mp4_video_data( self, video_sequence_name: str, timestamp: datetime ) -> Optional[dict]: @@ -651,7 +675,8 @@ def render_mosaic(self): rect_widgets_to_render = [ rw for rw in self._rect_widgets - if (not rw.is_verified and not self._hide_unlabeled) or (rw.is_verified and not self._hide_labeled) + if (not rw.is_verified and not self._hide_unlabeled) + or (rw.is_verified and not self._hide_labeled) ] # Hide all rect widgets that we aren't rendering diff --git a/vars_gridview/lib/widgets.py b/vars_gridview/lib/widgets.py index a58e0d1..ccba44f 100644 --- a/vars_gridview/lib/widgets.py +++ b/vars_gridview/lib/widgets.py @@ -12,8 +12,10 @@ import cv2 import numpy as np from PyQt6 import QtCore, QtGui, QtWidgets +from scipy.spatial.distance import cosine from vars_gridview.lib.annotation import VARSLocalization +from vars_gridview.lib.embedding import Embedding from vars_gridview.lib.log import LOGGER from vars_gridview.lib.m3 import operations from vars_gridview.lib.settings import SettingsManager @@ -24,6 +26,7 @@ class RectWidget(QtWidgets.QGraphicsWidget): rectHover = QtCore.pyqtSignal(object) clicked = QtCore.pyqtSignal(object, object) # self, event + similaritySort = QtCore.pyqtSignal(object, bool) # self, same_class_only def __init__( self, @@ -33,6 +36,7 @@ def __init__( video_data: dict, observer: str, localization_index: int, + embedding_model: Optional[Embedding] = None, parent=None, text_label="rect widget", ): @@ -58,8 +62,11 @@ def __init__( self.is_last_selected = False self.is_selected = False + self._embedding_model = embedding_model + self.roi = None self.pic = None + self._embedding = None self.update_roi_pic() self._deleted = False # Flag to indicate if this rect widget has been deleted. Used to prevent double deletion. @@ -132,11 +139,50 @@ def association_uuid(self) -> str: """ return self.localization.association_uuid + @property + def embedding(self): + if self._embedding is None: + self.update_embedding() + return self._embedding + + def update_embedding(self): + """ + Update the embedding value. + + Raises: + ValueError: If the embedding model is None. + """ + if self._embedding_model is None: + raise ValueError( + "Embedding model is not provided; cannot compute embedding" + ) + + self._embedding = self._embedding_model.embed( + self.localization.get_roi(self.image)[::-1] + ) + def update_roi_pic(self): self.roi = self.localization.get_roi(self.image) self.pic = self.getpic(self.roi) + if self._embedding_model is not None: + self.update_embedding() self.update() + def embedding_distance(self, other: "RectWidget") -> float: + """ + Calculate the embedding distance between this rect widget and another. + + Args: + other: The other rect widget to compare to. + + Returns: + The embedding distance between the two rect widgets. + """ + return cosine(self.embedding, other.embedding) + + def update_embedding_model(self, embedding_model: Embedding): + self._embedding_model = embedding_model + @property def is_verified(self) -> bool: return self.localizations[self.localization_index].verified @@ -192,71 +238,76 @@ def get_full_image(self): # self._boundingRect = thumb_widget_rect # return thumb_widget_rect - + @property def outline_x(self): return 0 - + @property def outline_y(self): return 0 - + @property def outline_width(self): return self.picdims[0] + self.bordersize * 2 + self.outlinesize * 2 - + @property def outline_height(self): - return self.picdims[1] + self.labelheight + self.bordersize * 2 + self.outlinesize * 2 - + return ( + self.picdims[1] + + self.labelheight + + self.bordersize * 2 + + self.outlinesize * 2 + ) + @property def border_x(self): return self.outline_x + self.outlinesize - + @property def border_y(self): return self.outline_y + self.outlinesize - + @property def border_width(self): return self.outline_width - self.outlinesize * 2 - + @property def border_height(self): return self.outline_height - self.outlinesize * 2 - + @property def pic_x(self): return self.border_x + self.bordersize - + @property def pic_y(self): return self.border_y + self.bordersize - + @property def pic_width(self): return self.picdims[0] - + @property def pic_height(self): return self.picdims[1] - + @property def label_x(self): return self.pic_x - + @property def label_y(self): return self.pic_y + self.pic_height - + @property def label_width(self): return self.pic_width - + @property def label_height(self): return self.labelheight - + def scale_rect(self, rect: QtCore.QRectF) -> QtCore.QRect: return QtCore.QRect( round(rect.x() * self.zoom), @@ -264,7 +315,7 @@ def scale_rect(self, rect: QtCore.QRectF) -> QtCore.QRect: round(rect.width() * self.zoom), round(rect.height() * self.zoom), ) - + @property def outline_rect(self): rect = QtCore.QRectF( @@ -274,7 +325,7 @@ def outline_rect(self): self.outline_height, ) return self.scale_rect(rect) - + @property def border_rect(self): rect = QtCore.QRectF( @@ -284,7 +335,7 @@ def border_rect(self): self.border_height, ) return self.scale_rect(rect) - + @property def pic_rect(self): rect = QtCore.QRectF( @@ -294,7 +345,7 @@ def pic_rect(self): self.pic_height, ) return self.scale_rect(rect) - + @property def label_rect(self): rect = QtCore.QRectF( @@ -304,7 +355,7 @@ def label_rect(self): self.label_height, ) return self.scale_rect(rect) - + def boundingRect(self): return QtCore.QRectF( self.zoom * self.outline_x, @@ -319,13 +370,13 @@ def sizeHint(self, which, constraint=QtCore.QSizeF()): def getpic(self, roi: np.ndarray) -> QtGui.QPixmap: """ Get the scaled and padded pixmap for the given ROI. - + Fits the ROI into a square of size picdims, scaling it up or down as necessary. Then, pads the ROI with a border to fit the square. - + Args: roi: The ROI to get the pixmap for. - + Returns: The scaled and padded pixmap. """ @@ -333,11 +384,11 @@ def getpic(self, roi: np.ndarray) -> QtGui.QPixmap: roi_height, roi_width, _ = roi.shape max_width = self.pic_width max_height = self.pic_height - + # Scale the ROI to fit the square scale = min(max_width / roi_width, max_height / roi_height) roi = cv2.resize(roi, (0, 0), fx=scale, fy=scale) - + # Pad the image with a border pad_x = (max_width - roi.shape[1]) // 2 pad_y = (max_height - roi.shape[0]) // 2 @@ -350,7 +401,7 @@ def getpic(self, roi: np.ndarray) -> QtGui.QPixmap: cv2.BORDER_CONSTANT, value=(45, 35, 25), ) - + # Convert to Qt pixmap qimg = self.toqimage(roi_padded) orpixmap = QtGui.QPixmap.fromImage(qimg) @@ -412,10 +463,27 @@ def color_for_concept(concept: str): # Draw label text painter.drawText( - self.label_rect, - QtCore.Qt.AlignmentFlag.AlignCenter, - self.text_label + self.label_rect, QtCore.Qt.AlignmentFlag.AlignCenter, self.text_label ) def mousePressEvent(self, event): - self.clicked.emit(self, event) + if event.button() == QtCore.Qt.MouseButton.LeftButton: + self.clicked.emit(self, event) + else: + self.handle_right_click(event) + + def handle_right_click(self, event): + """ + Handle a right click event. Open a context menu with options. + """ + menu = QtWidgets.QMenu() + similarity_sort = menu.addAction("Find similar") + similarity_sort.triggered.connect(lambda: self.similaritySort.emit(self, False)) + similarity_sort_same_label = menu.addAction("Find similar with same label") + similarity_sort_same_label.triggered.connect( + lambda: self.similaritySort.emit(self, True) + ) + no_embedding_model = self._embedding_model is None + similarity_sort.setDisabled(no_embedding_model) + similarity_sort_same_label.setDisabled(no_embedding_model) + menu.exec(event.screenPos()) diff --git a/vars_gridview/scripts/run.py b/vars_gridview/scripts/run.py index 8010bb6..4138e57 100644 --- a/vars_gridview/scripts/run.py +++ b/vars_gridview/scripts/run.py @@ -39,6 +39,7 @@ from vars_gridview.lib import constants, m3, raziel, sql from vars_gridview.lib.boxes import BoxHandler from vars_gridview.lib.cache import CacheController +from vars_gridview.lib.embedding import DreamSimEmbedding, Embedding from vars_gridview.lib.image_mosaic import ImageMosaic from vars_gridview.lib.log import LOGGER, AppLogger from vars_gridview.lib.m3.operations import get_kb_concepts, get_kb_name, get_kb_parts @@ -99,10 +100,14 @@ def __init__(self, app): self.last_selected_rect = None # Last selected ROI - self.image_mosaic = ( - None # Image mosaic (holds the thumbnails as a grid of RectWidgets) - ) - self.box_handler = None # Box handler (handles the ROIs and annotations) + # Image mosaic (holds the thumbnails as a grid of RectWidgets) + self.image_mosaic: Optional[ImageMosaic] = None + + # Box handler (handles the ROIs and annotations) + self.box_handler: Optional[BoxHandler] = None + + # Embedding model + self._embedding_model: Optional[Embedding] = None self.cached_moment_concepts = ( {} @@ -133,6 +138,9 @@ def __init__(self, app): self._settings = SettingsManager.get_instance() self._settings.label_font_size.valueChanged.connect(self.update_layout) + self._settings.embeddings_enabled.valueChanged.connect( + self.update_embeddings_enabled + ) self.settings_dialog = SettingsDialog( self._setup_sharktopoda_client, @@ -177,6 +185,9 @@ def _launch(self): # Set up the menu bar self._setup_menu_bar() + # Set up embeddings + self.update_embeddings_enabled(self._settings.embeddings_enabled.value) + # Set up Sharktopoda client if self._settings.sharktopoda_autoconnect.value: try: @@ -439,6 +450,7 @@ def _do_query(self): self.rect_clicked, self.verifier, zoom=self.ui.zoomSpinBox.value() / 100, + embedding_model=self._embedding_model, ) self.image_mosaic.hide_discarded = False @@ -714,6 +726,14 @@ def update_zoom(self, zoom): self.image_mosaic.update_zoom(zoom / 100) + @QtCore.pyqtSlot(object) + def update_embeddings_enabled(self, embeddings_enabled: bool): + if embeddings_enabled: + if self._embedding_model is None: + self._embedding_model = DreamSimEmbedding() + if self.image_mosaic is not None: + self.image_mosaic.update_embedding_model(self._embedding_model) + @QtCore.pyqtSlot(object, object) def rect_clicked(self, rect: RectWidget, event: Optional[QtGui.QMouseEvent]): if not self.loaded: @@ -1032,6 +1052,12 @@ def init_settings(): 1000, ) + settings.embeddings_enabled = ( + "embeddings/enabled", + bool, + constants.EMBEDDINGS_ENABLED_DEFAULT, + ) + def parse_args(): """ diff --git a/vars_gridview/ui/settings/SettingsDialog.py b/vars_gridview/ui/settings/SettingsDialog.py index d68a170..ff21790 100644 --- a/vars_gridview/ui/settings/SettingsDialog.py +++ b/vars_gridview/ui/settings/SettingsDialog.py @@ -3,6 +3,7 @@ from vars_gridview.ui.settings.tabs.AbstractSettingsTab import AbstractSettingsTab from vars_gridview.ui.settings.tabs.AppearanceTab import AppearanceTab from vars_gridview.ui.settings.tabs.CacheTab import CacheTab +from vars_gridview.ui.settings.tabs.EmbeddingsTab import EmbeddingsTab from vars_gridview.ui.settings.tabs.M3Tab import M3Tab from vars_gridview.ui.settings.tabs.SQLTab import SQLTab from vars_gridview.ui.settings.tabs.VideoPlayerTab import VideoPlayerTab @@ -97,3 +98,4 @@ def _add_tabs(self, connect_slot, connected_signal, clear_cache_slot): self._register_tab(AppearanceTab()) self._register_tab(VideoPlayerTab(connect_slot, connected_signal)) self._register_tab(CacheTab(clear_cache_slot)) + self._register_tab(EmbeddingsTab()) diff --git a/vars_gridview/ui/settings/tabs/EmbeddingsTab.py b/vars_gridview/ui/settings/tabs/EmbeddingsTab.py new file mode 100644 index 0000000..acf449a --- /dev/null +++ b/vars_gridview/ui/settings/tabs/EmbeddingsTab.py @@ -0,0 +1,38 @@ +from PyQt6 import QtWidgets + +from vars_gridview.ui.settings.tabs.AbstractSettingsTab import AbstractSettingsTab + + +class EmbeddingsTab(AbstractSettingsTab): + """ + Embeddings tab. + """ + + def __init__(self, parent=None): + super().__init__("Embeddings", parent=parent) + + self._embeddings_enabled_toggle = QtWidgets.QCheckBox() + self._embeddings_enabled_toggle.setChecked( + self._settings.embeddings_enabled.value + ) + self._embeddings_enabled_toggle.stateChanged.connect(self.settingsChanged.emit) + self._settings.embeddings_enabled.valueChanged.connect( + self._embeddings_enabled_toggle.setChecked + ) + + self.arrange() + + def arrange(self): + layout = QtWidgets.QFormLayout() + layout.setFieldGrowthPolicy( + QtWidgets.QFormLayout.FieldGrowthPolicy.ExpandingFieldsGrow + ) + + layout.addRow("Embeddings enabled", self._embeddings_enabled_toggle) + + self.setLayout(layout) + + def apply_settings(self): + self._settings.embeddings_enabled.value = ( + self._embeddings_enabled_toggle.isChecked() + )