From b1f06b833884fb341c58a69353cbb6963344337c Mon Sep 17 00:00:00 2001
From: Stefaan Lippens
Date: Fri, 21 Jun 2024 17:25:14 +0200
Subject: [PATCH] Issue #4 implement env/secret driven auth in benchmarks

---
 qa/benchmarks/requirements.txt               |  1 +
 qa/benchmarks/tests/conftest.py              | 75 ++++++++++++++++++++
 qa/benchmarks/tests/test_benchmarks.py       | 16 +++--
 qa/tools/apex_algorithm_qa_tools/usecases.py |  1 +
 4 files changed, 88 insertions(+), 5 deletions(-)

diff --git a/qa/benchmarks/requirements.txt b/qa/benchmarks/requirements.txt
index eee02b6..5b65f4c 100644
--- a/qa/benchmarks/requirements.txt
+++ b/qa/benchmarks/requirements.txt
@@ -1,3 +1,4 @@
 apex-algorithm-qa-tools
 openeo>=0.30.0
 pytest>=8.2.0
+requests>=2.32.0
diff --git a/qa/benchmarks/tests/conftest.py b/qa/benchmarks/tests/conftest.py
index 9a0d2f5..e066157 100644
--- a/qa/benchmarks/tests/conftest.py
+++ b/qa/benchmarks/tests/conftest.py
@@ -1,5 +1,11 @@
 import logging
+import os
 import random
+from typing import Callable
+
+import openeo
+import pytest
+import requests
 
 # TODO: how to make sure the logging/printing from this plugin is actually visible by default?
 _log = logging.getLogger(__name__)
@@ -29,3 +35,72 @@ def pytest_collection_modifyitems(session, config, items):
             f"Selecting random subset of {subset_size} from {len(items)} benchmarks."
         )
         items[:] = random.sample(items, k=subset_size)
+
+
+def _get_client_credentials_env_var(url: str) -> str:
+    """
+    Get client credentials env var name for a given backend URL.
+    """
+    # TODO: parse url to more reliably extract hostname
+    if url == "openeofed.dataspace.copernicus.eu":
+        return "OPENEO_AUTH_CLIENT_CREDENTIALS_CDSEFED"
+    else:
+        raise ValueError(f"Unsupported backend: {url}")
+
+
+@pytest.fixture
+def connection_factory(request, capfd) -> Callable[[str], openeo.Connection]:
+    """
+    Fixture for a function that sets up an authenticated connection to an openEO backend.
+
+    This is implemented as a fixture to have access to other fixtures that allow
+    deeper integration with the pytest framework.
+    For example, the `request` fixture allows identifying the currently running test/benchmark.
+    """
+
+    # Identifier for the current test/benchmark, to be injected automatically
+    # into requests to the backend for tracking/cross-referencing purposes.
+    origin = f"apex-algorithms/benchmarks/{request.session.name}/{request.node.name}"
+
+    def get_connection(url: str) -> openeo.Connection:
+        session = requests.Session()
+        session.params["_origin"] = origin
+
+        _log.info(f"Connecting to {url!r}")
+        connection = openeo.connect(url, auto_validate=False, session=session)
+        connection.default_headers["X-OpenEO-Client-Context"] = (
+            "APEx Algorithm Benchmarks"
+        )
+
+        # Authentication:
+        # In production CI context, we want to extract client credentials
+        # from environment variables (based on the backend url).
+        # In the absence of such environment variables, to allow local development,
+        # we fall back on a traditional `authenticate_oidc()`,
+        # which automatically supports various authentication flows (device code, refresh token, client credentials, etc.).
+        auth_env_var = _get_client_credentials_env_var(url)
+        _log.info(f"Checking for {auth_env_var=} to drive auth against {url=}.")
+        if auth_env_var in os.environ:
+            client_credentials = os.environ[auth_env_var]
+            provider_id, client_id, client_secret = client_credentials.split("/")
+            connection.authenticate_oidc_client_credentials(
+                provider_id=provider_id,
+                client_id=client_id,
+                client_secret=client_secret,
+            )
+        else:
+            # Temporarily disable output capturing,
+            # to make sure that the OIDC device code instructions are shown
+            # to the user running the benchmarks interactively.
+            with capfd.disabled():
+                # Use a shorter max poll time by default,
+                # to alleviate the impression that the test hangs
+                # because of the OIDC device code poll loop.
+                max_poll_time = int(
+                    os.environ.get("OPENEO_OIDC_DEVICE_CODE_MAX_POLL_TIME") or 30
+                )
+                connection.authenticate_oidc(max_poll_time=max_poll_time)
+
+        return connection
+
+    return get_connection
diff --git a/qa/benchmarks/tests/test_benchmarks.py b/qa/benchmarks/tests/test_benchmarks.py
index ce5445b..22f946a 100644
--- a/qa/benchmarks/tests/test_benchmarks.py
+++ b/qa/benchmarks/tests/test_benchmarks.py
@@ -4,16 +4,22 @@
 
 
 @pytest.mark.parametrize(
-    "use_case", [pytest.param(uc, id=uc.id) for uc in get_use_cases()]
+    "use_case",
+    [
+        # Use the use case id as parameterization id to give nicer test names.
+        pytest.param(uc, id=uc.id)
+        for uc in get_use_cases()
+    ],
 )
-def test_run_benchmark(use_case: UseCase):
-    # TODO: cache connection?
-    # TODO: authentication
-    connection = openeo.connect(use_case.backend)
+def test_run_benchmark(use_case: UseCase, connection_factory):
+    connection: openeo.Connection = connection_factory(url=use_case.backend)
+
+    # TODO: scenario option to use synchronous instead of batch job mode?
     job = connection.create_job(
         process_graph=use_case.process_graph,
         title=f"APEx benchmark {use_case.id}",
     )
     job.start_and_wait()
+
+    # TODO: download job results and inspect
diff --git a/qa/tools/apex_algorithm_qa_tools/usecases.py b/qa/tools/apex_algorithm_qa_tools/usecases.py
index eaa5a98..e54b306 100644
--- a/qa/tools/apex_algorithm_qa_tools/usecases.py
+++ b/qa/tools/apex_algorithm_qa_tools/usecases.py
@@ -67,6 +67,7 @@ def get_algorithm_invocation_root() -> Path:
 
 def get_use_cases() -> List[UseCase]:
     # TODO: instead of flat list, keep original grouping/structure of "algorithm_invocations" files?
+    # TODO: check for uniqueness of scenario IDs? Also make this a pre-commit lint tool?
     use_cases = []
     for path in get_algorithm_invocation_root().glob("*.json"):
         with open(path) as f:
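
Usage note: the snippet below is a minimal standalone sketch (not part of the patch) of the client-credentials flow that the `connection_factory` fixture applies in CI. It assumes the `OPENEO_AUTH_CLIENT_CREDENTIALS_CDSEFED` environment variable holds a single "provider_id/client_id/client_secret" string, as the patch expects, and uses only standard openeo Python client calls:

    import os

    import openeo

    # Client credentials are packed into one env var as "provider_id/client_id/client_secret".
    provider_id, client_id, client_secret = os.environ[
        "OPENEO_AUTH_CLIENT_CREDENTIALS_CDSEFED"
    ].split("/")

    connection = openeo.connect("openeofed.dataspace.copernicus.eu")
    connection.authenticate_oidc_client_credentials(
        provider_id=provider_id,
        client_id=client_id,
        client_secret=client_secret,
    )
    # Simple sanity check that authentication worked.
    print(connection.describe_account())

Without that environment variable, the fixture instead falls back to the interactive `authenticate_oidc()` device-code flow, as shown in conftest.py above.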