diff --git a/cli/medperf/commands/benchmark/benchmark.py b/cli/medperf/commands/benchmark/benchmark.py index aa35a859c..0140a3cc1 100644 --- a/cli/medperf/commands/benchmark/benchmark.py +++ b/cli/medperf/commands/benchmark/benchmark.py @@ -18,6 +18,7 @@ def list( local: bool = typer.Option(False, "--local", help="Get local benchmarks"), mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"), + valid: bool = typer.Option(False, "--valid", help="List only valid benchmarks"), ): """List benchmarks stored locally and remotely from the user""" EntityList.run( @@ -25,6 +26,7 @@ def list( fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"], local_only=local, mine_only=mine, + valid_only=valid ) diff --git a/cli/medperf/commands/dataset/dataset.py b/cli/medperf/commands/dataset/dataset.py index 07c97153c..16bdaf8ef 100644 --- a/cli/medperf/commands/dataset/dataset.py +++ b/cli/medperf/commands/dataset/dataset.py @@ -18,6 +18,7 @@ def list( local: bool = typer.Option(False, "--local", help="Get local datasets"), mine: bool = typer.Option(False, "--mine", help="Get current-user datasets"), + valid: bool = typer.Option(False, "--valid", help="List only valid datasets"), ): """List datasets stored locally and remotely from the user""" EntityList.run( @@ -25,6 +26,7 @@ def list( fields=["UID", "Name", "Data Preparation Cube UID", "Registered"], local_only=local, mine_only=mine, + valid_only=valid ) diff --git a/cli/medperf/commands/list.py b/cli/medperf/commands/list.py index 5fd462bf7..828e6a626 100644 --- a/cli/medperf/commands/list.py +++ b/cli/medperf/commands/list.py @@ -1,37 +1,43 @@ from medperf.exceptions import InvalidArgumentError from tabulate import tabulate +from typing import Type from medperf import config from medperf.account_management import get_medperf_user_data +from medperf.entities.schemas import DeployableEntity class EntityList: @staticmethod def run( - entity_class, - fields, + entity_class: Type[DeployableEntity], + fields: list[str], local_only: bool = False, mine_only: bool = False, + valid_only: bool = False, **kwargs, ): """Lists all local datasets Args: + entity_class (class): entity to list. Has to be Entity + DeployableSchema local_only (bool, optional): Display all local results. Defaults to False. mine_only (bool, optional): Display all current-user results. Defaults to False. + valid_only: (bool, optional): Show only valid results. Defaults to False. kwargs (dict): Additional parameters for filtering entity lists. """ - entity_list = EntityList(entity_class, fields, local_only, mine_only, **kwargs) + entity_list = EntityList(entity_class, fields, local_only, mine_only, valid_only, **kwargs) entity_list.prepare() entity_list.validate() entity_list.filter() entity_list.display() - def __init__(self, entity_class, fields, local_only, mine_only, **kwargs): + def __init__(self, entity_class, fields, local_only, mine_only, valid_only, **kwargs): self.entity_class = entity_class self.fields = fields self.local_only = local_only self.mine_only = mine_only + self.valid_only = valid_only self.filters = kwargs self.data = [] @@ -42,6 +48,10 @@ def prepare(self): entities = self.entity_class.all( local_only=self.local_only, filters=self.filters ) + + if self.valid_only: + entities = [entity for entity in entities if entity.is_valid] + self.data = [entity.display_dict() for entity in entities] def validate(self): diff --git a/cli/medperf/commands/mlcube/mlcube.py b/cli/medperf/commands/mlcube/mlcube.py index e5b0253ee..ba6d56517 100644 --- a/cli/medperf/commands/mlcube/mlcube.py +++ b/cli/medperf/commands/mlcube/mlcube.py @@ -18,6 +18,7 @@ def list( local: bool = typer.Option(False, "--local", help="Get local mlcubes"), mine: bool = typer.Option(False, "--mine", help="Get current-user mlcubes"), + valid: bool = typer.Option(False, "--valid", help="List only valid mlcubes"), ): """List mlcubes stored locally and remotely from the user""" EntityList.run( @@ -25,6 +26,7 @@ def list( fields=["UID", "Name", "State", "Registered"], local_only=local, mine_only=mine, + valid_only=valid ) diff --git a/cli/medperf/commands/result/result.py b/cli/medperf/commands/result/result.py index 7253b205b..6a7c23376 100644 --- a/cli/medperf/commands/result/result.py +++ b/cli/medperf/commands/result/result.py @@ -63,6 +63,7 @@ def submit( def list( local: bool = typer.Option(False, "--local", help="Get local results"), mine: bool = typer.Option(False, "--mine", help="Get current-user results"), + valid: bool = typer.Option(False, "--valid", help="Get only valid results"), benchmark: int = typer.Option( None, "--benchmark", "-b", help="Get results for a given benchmark" ), @@ -73,6 +74,7 @@ def list( fields=["UID", "Benchmark", "Model", "Dataset", "Registered"], local_only=local, mine_only=mine, + valid_only=valid, benchmark=benchmark, ) diff --git a/cli/medperf/entities/benchmark.py b/cli/medperf/entities/benchmark.py index b3b1a90f7..04029bac0 100644 --- a/cli/medperf/entities/benchmark.py +++ b/cli/medperf/entities/benchmark.py @@ -61,11 +61,11 @@ def __init__(self, *args, **kwargs): self.path = path @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Benchmark"]: + def all(cls, local_only: bool = False, filters: dict = None) -> List["Benchmark"]: """Gets and creates instances of all retrievable benchmarks Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. + local_only (bool, optional): Whether to retrieve only local entities. Defaults to False. filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. Returns: @@ -73,6 +73,7 @@ def all(cls, local_only: bool = False, filters: dict = {}) -> List["Benchmark"]: """ logging.info("Retrieving all benchmarks") benchmarks = [] + filters = filters or {} if not local_only: benchmarks = cls.__remote_all(filters=filters) diff --git a/cli/medperf/entities/cube.py b/cli/medperf/entities/cube.py index ac2d58b55..0cb7f0992 100644 --- a/cli/medperf/entities/cube.py +++ b/cli/medperf/entities/cube.py @@ -73,11 +73,11 @@ def __init__(self, *args, **kwargs): self.params_path = os.path.join(path, config.params_filename) @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Cube"]: + def all(cls, local_only: bool = False, filters: dict = None) -> List["Cube"]: """Class method for retrieving all retrievable MLCubes Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. + local_only (bool, optional): Whether to retrieve only local entities. Defaults to False. filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. Returns: @@ -85,6 +85,8 @@ def all(cls, local_only: bool = False, filters: dict = {}) -> List["Cube"]: """ logging.info("Retrieving all cubes") cubes = [] + filters = filters or {} + if not local_only: cubes = cls.__remote_all(filters=filters) diff --git a/cli/medperf/entities/dataset.py b/cli/medperf/entities/dataset.py index c68099e5c..d0aa9ff60 100644 --- a/cli/medperf/entities/dataset.py +++ b/cli/medperf/entities/dataset.py @@ -71,11 +71,11 @@ def todict(self): return self.extended_dict() @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Dataset"]: + def all(cls, local_only: bool = False, filters: dict = None) -> List["Dataset"]: """Gets and creates instances of all the locally prepared datasets Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. + local_only (bool, optional): Whether to retrieve only local entities. Defaults to False. filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. Returns: @@ -83,6 +83,8 @@ def all(cls, local_only: bool = False, filters: dict = {}) -> List["Dataset"]: """ logging.info("Retrieving all datasets") dsets = [] + filters = filters or {} + if not local_only: dsets = cls.__remote_all(filters=filters) diff --git a/cli/medperf/entities/interface.py b/cli/medperf/entities/interface.py index af2afabd7..d8d965491 100644 --- a/cli/medperf/entities/interface.py +++ b/cli/medperf/entities/interface.py @@ -5,15 +5,14 @@ class Entity(ABC): @abstractmethod def all( - cls, local_only: bool = False, comms_func: callable = None + cls, local_only: bool = False, filters: dict = None ) -> List["Entity"]: """Gets a list of all instances of the respective entity. - Wether the list is local or remote depends on the implementation. + Whether the list is local or remote depends on the implementation. Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. - comms_func (callable, optional): Function to use to retrieve remote entities. - If not provided, will use the default entrypoint. + local_only (bool, optional): Whether to retrieve only local entities. Defaults to False. + filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. Returns: List[Entity]: a list of entities. diff --git a/cli/medperf/entities/report.py b/cli/medperf/entities/report.py index 629d2806d..c44f18c87 100644 --- a/cli/medperf/entities/report.py +++ b/cli/medperf/entities/report.py @@ -54,7 +54,8 @@ def set_results(self, results): @classmethod def all( - cls, local_only: bool = False, mine_only: bool = False + # TODO: `mine_only` is never used. In other entities filtering by `mine_only` is implemented with `filter` field + cls, local_only: bool = False, mine_only: bool = False, filters: dict = None ) -> List["TestReport"]: """Gets and creates instances of test reports. Arguments are only specified for compatibility with @@ -66,6 +67,7 @@ def all( """ logging.info("Retrieving all reports") reports = [] + filters = filters or {} test_storage = storage_path(config.test_storage) try: uids = next(os.walk(test_storage))[1] diff --git a/cli/medperf/entities/result.py b/cli/medperf/entities/result.py index 1eaeaeadf..7dd47bafb 100644 --- a/cli/medperf/entities/result.py +++ b/cli/medperf/entities/result.py @@ -42,11 +42,11 @@ def __init__(self, *args, **kwargs): self.path = path @classmethod - def all(cls, local_only: bool = False, filters: dict = {}) -> List["Result"]: + def all(cls, local_only: bool = False, filters: dict = None) -> List["Result"]: """Gets and creates instances of all the user's results Args: - local_only (bool, optional): Wether to retrieve only local entities. Defaults to False. + local_only (bool, optional): Whether to retrieve only local entities. Defaults to False. filters (dict, optional): key-value pairs specifying filters to apply to the list of entities. Returns: @@ -54,6 +54,8 @@ def all(cls, local_only: bool = False, filters: dict = {}) -> List["Result"]: """ logging.info("Retrieving all results") results = [] + filters = filters or {} + if not local_only: results = cls.__remote_all(filters=filters) diff --git a/cli/medperf/entities/schemas.py b/cli/medperf/entities/schemas.py index ad0f5f596..517bd605f 100644 --- a/cli/medperf/entities/schemas.py +++ b/cli/medperf/entities/schemas.py @@ -3,6 +3,7 @@ from typing import Optional from collections import defaultdict +from medperf.entities.interface import Entity from medperf.enums import Status from medperf.exceptions import MedperfException from medperf.utils import format_errors_dict @@ -105,3 +106,7 @@ def default_status(cls, v): if v is not None: status = Status(v) return status + + +class DeployableEntity(DeployableSchema, Entity): + pass diff --git a/cli/medperf/tests/commands/mlcube/test_list.py b/cli/medperf/tests/commands/mlcube/test_list.py new file mode 100644 index 000000000..46ae90b1d --- /dev/null +++ b/cli/medperf/tests/commands/mlcube/test_list.py @@ -0,0 +1,101 @@ +from typing import Any + +import pytest + +from medperf.entities.cube import Cube +from medperf.commands.list import EntityList + +PATCH_CUBE = "medperf.entities.cube.Cube.{}" + + +def generate_cube(id: int, is_valid: bool, owner: int) -> dict[str, Any]: + git_mlcube_url = f"{id}-{is_valid}-{owner}" + name = git_mlcube_url + return { + 'id': id, + 'is_valid': is_valid, + 'owner': owner, + 'git_mlcube_url': git_mlcube_url, + 'name': name + } + + +def cls_local_cubes(*args, **kwargs) -> list[Cube]: + return [ + Cube(**generate_cube(id=101, is_valid=True, owner=1)), + Cube(**generate_cube(id=102, is_valid=False, owner=1)), + # Intended: for local mlcubes owner is never checked. + # All local cubes are supposed to be owned by current user + + # generate_cube(id=103, is_valid=True, owner=12345), + # generate_cube(id=104, is_valid=False, owner=12345), + ] + + +def comms_remote_cubes_dict_mine_only() -> list[dict[str, Any]]: + return [ + generate_cube(id=201, is_valid=True, owner=1), + generate_cube(id=202, is_valid=False, owner=1), + ] + + +def comms_remote_cubes_dict() -> list[dict[str, Any]]: + mine_only = comms_remote_cubes_dict_mine_only() + someone_else = [ + generate_cube(id=203, is_valid=True, owner=12345), + generate_cube(id=204, is_valid=False, owner=12345), + ] + return mine_only + someone_else + + +def cls_remote_cubes(*args, **kwargs) -> list[Cube]: + return [Cube(**d) for d in comms_remote_cubes_dict()] + + +@pytest.mark.parametrize("local_only", [False, True]) +@pytest.mark.parametrize("mine_only", [False, True]) +@pytest.mark.parametrize("valid_only", [False, True]) +def test_run_list_mlcubes(mocker, comms, ui, local_only, mine_only, valid_only): + # Arrange + mocker.patch("medperf.commands.list.get_medperf_user_data", return_value={"id": 1}) + mocker.patch("medperf.entities.cube.get_medperf_user_data", return_value={"id": 1}) + + # Implementation-specific: for local cubes there is a private classmethod. + mocker.patch(PATCH_CUBE.format("_Cube__local_all"), new=cls_local_cubes) + # For remote cubes there are two different endpoints - for all cubes and for mine only + mocker.patch.object(comms, 'get_user_cubes', new=comms_remote_cubes_dict_mine_only) + mocker.patch.object(comms, 'get_cubes', new=comms_remote_cubes_dict) + + tab_spy = mocker.patch("medperf.commands.list.tabulate", return_value="") + + local_cubes = cls_local_cubes() + remote_cubes = cls_remote_cubes() + cubes = local_cubes + remote_cubes + + # Act + EntityList.run(Cube, fields=['UID'], local_only=local_only, mine_only=mine_only, valid_only=valid_only) + + # Assert + tab_call = tab_spy.call_args_list[0] + received_cubes: list[list[Any]] = tab_call[0][0] + received_ids = {cube_fields[0] for cube_fields in received_cubes} + + local_ids = {c.id for c in local_cubes} + + expected_ids = set() + for c in cubes: + if local_only: + if c.id not in local_ids: + continue + + if mine_only: + if c.owner != 1: + continue + + if valid_only: + if not c.is_valid: + continue + + expected_ids.add(c.id) + + assert received_ids == expected_ids diff --git a/cli/medperf/tests/commands/test_list.py b/cli/medperf/tests/commands/test_list.py index 1c2dc3267..76ff8b003 100644 --- a/cli/medperf/tests/commands/test_list.py +++ b/cli/medperf/tests/commands/test_list.py @@ -3,6 +3,7 @@ from medperf.exceptions import InvalidArgumentError import pytest from medperf.entities.interface import Entity +from medperf.entities.schemas import DeployableEntity def generate_display_dicts(): @@ -21,13 +22,18 @@ def setup(request, mocker, ui): # mocks entity_object = mocker.create_autospec(spec=Entity) + # As object has to be Entity + DeployableSchema, it requires `is_valid` field. + # autospec does not create such a field as in schema is declared as class attribute + # rather than instance attribute. Thus, for testing purposes we have to create attr manually + entity_object.is_valid = True mocker.patch.object(entity_object, "display_dict", side_effect=display_dicts) mocker.patch("medperf.commands.list.get_medperf_user_data", return_value={"id": 1}) # spies generated_entities = [entity_object for _ in display_dicts] + print([e.is_valid for e in generated_entities]) all_spy = mocker.patch( - "medperf.entities.interface.Entity.all", return_value=generated_entities + "medperf.entities.schemas.DeployableEntity.all", return_value=generated_entities ) ui_spy = mocker.patch.object(ui, "print") tabulate_spy = mocker.spy(list_module, "tabulate") @@ -49,12 +55,13 @@ def set_common_attributes(self, setup): @pytest.mark.parametrize("local_only", [False, True]) @pytest.mark.parametrize("mine_only", [False, True]) - def test_entity_all_is_called_properly(self, mocker, local_only, mine_only): + @pytest.mark.parametrize("valid_only", [False, True]) + def test_entity_all_is_called_properly(self, mocker, local_only, mine_only, valid_only): # Arrange filters = {"owner": 1} if mine_only else {} # Act - EntityList.run(Entity, [], local_only, mine_only) + EntityList.run(DeployableEntity, [], local_only, mine_only, valid_only) # Assert self.spies["all"].assert_called_once_with( @@ -65,7 +72,7 @@ def test_entity_all_is_called_properly(self, mocker, local_only, mine_only): def test_exception_raised_for_invalid_input(self, fields): # Act & Assert with pytest.raises(InvalidArgumentError): - EntityList.run(Entity, fields) + EntityList.run(DeployableEntity, fields) @pytest.mark.parametrize("fields", [["UID", "Is Valid"], ["Registered"]]) def test_display_calls_tabulate_and_ui_as_expected(self, fields): @@ -76,7 +83,7 @@ def test_display_calls_tabulate_and_ui_as_expected(self, fields): ] # Act - EntityList.run(Entity, fields) + EntityList.run(DeployableEntity, fields) # Assert self.spies["tabulate"].assert_called_once_with(expected_list, headers=fields)