diff --git a/app/routes/datasets/__init__.py b/app/routes/datasets/__init__.py
index 663acb2c1..2013324e3 100644
--- a/app/routes/datasets/__init__.py
+++ b/app/routes/datasets/__init__.py
@@ -1,4 +1,6 @@
-from typing import Any, Dict, List
+from collections import defaultdict
+from typing import Any, Dict, List, Sequence
+from urllib.parse import urlparse
 
 from fastapi import HTTPException
 
@@ -11,6 +13,19 @@ from ...tasks.raster_tile_set_assets.raster_tile_set_assets import (
     raster_tile_set_validator,
 )
 
+from ...utils.aws import get_aws_files
+from ...utils.google import get_gs_files
+
+SUPPORTED_FILE_EXTENSIONS: Sequence[str] = (
+    ".csv",
+    ".geojson",
+    ".gpkg",
+    ".ndjson",
+    ".shp",
+    ".tif",
+    ".tsv",
+    ".zip",
+)
 
 
 async def verify_version_status(dataset, version):
@@ -82,3 +97,63 @@ async def validate_creation_options(
         await validator[input_data["asset_type"]](dataset, version, input_data)
     except KeyError:
         pass
+
+
+# I cannot seem to satisfy mypy WRT the type of this default dict. Last thing I tried:
+# DefaultDict[str, Callable[[str, str, int, int, ...], List[str]]]
+source_uri_lister_constructor = defaultdict((lambda: lambda w, x, limit=None, exit_after_max=None, extensions=None: list()))  # type: ignore
+source_uri_lister_constructor.update(**{"gs": get_gs_files, "s3": get_aws_files})  # type: ignore
+
+
+def _verify_source_file_access(sources: List[str]) -> None:
+
+    # TODO:
+    # 1. Making the list functions asynchronous and using asyncio.gather
+    # to check for valid sources in a non-blocking fashion would be good.
+    # Perhaps use the aioboto3 package for aws, gcloud-aio-storage for gcs.
+    # 2. It would be nice if the acceptable file extensions were passed
+    # into this function so we could say, for example, that there must be
+    # TIFFs found for a new raster tile set, but a CSV is required for a new
+    # vector tile set version. Even better would be to specify whether
+    # paths to individual files or "folders" (prefixes) are allowed.
+
+    invalid_sources: List[str] = list()
+
+    for source in sources:
+        url_parts = urlparse(source, allow_fragments=False)
+        list_func = source_uri_lister_constructor[url_parts.scheme.lower()]
+        bucket = url_parts.netloc
+        prefix = url_parts.path.lstrip("/")
+
+        # Allow pseudo-globbing: Tolerate a "*" at the end of a
+        # src_uri entry to allow partial prefixes (for example
+        # /bucket/prefix_part_1/prefix_fragment* will match
+        # /bucket/prefix_part_1/prefix_fragment_1.tif and
+        # /bucket/prefix_part_1/prefix_fragment_2.tif, etc.)
+        # If the prefix doesn't end in "*" or an acceptable file extension
+        # add a "/" to the end of the prefix to enforce it being a "folder".
+        new_prefix: str = prefix
+        if new_prefix.endswith("*"):
+            new_prefix = new_prefix[:-1]
+        elif not new_prefix.endswith("/") and not any(
+            [new_prefix.endswith(suffix) for suffix in SUPPORTED_FILE_EXTENSIONS]
+        ):
+            new_prefix += "/"
+
+        if not list_func(
+            bucket,
+            new_prefix,
+            limit=10,
+            exit_after_max=1,
+            extensions=SUPPORTED_FILE_EXTENSIONS,
+        ):
+            invalid_sources.append(source)
+
+    if invalid_sources:
+        raise HTTPException(
+            status_code=400,
+            detail=(
+                "Cannot access all of the source files (non-existent or access denied). "
+                f"Invalid sources: {invalid_sources}"
+            ),
+        )
diff --git a/app/routes/datasets/queries.py b/app/routes/datasets/queries.py
index 8f7437105..5e53b9423 100755
--- a/app/routes/datasets/queries.py
+++ b/app/routes/datasets/queries.py
@@ -93,6 +93,7 @@
 from ...utils.aws import get_sfn_client, invoke_lambda
 from ...utils.geostore import get_geostore
 from .. import dataset_version_dependency
+from . import _verify_source_file_access
 
 router = APIRouter()
 
@@ -374,11 +375,21 @@ async def query_dataset_list_post(
     data_environment = await _get_data_environment(grid)
 
     input = {
-        "feature_collection": jsonable_encoder(request.feature_collection),
         "query": sql,
         "environment": data_environment.dict()["layers"],
     }
 
+    if request.feature_collection:
+        input["feature_collection"] = jsonable_encoder(request.feature_collection)
+    elif request.uri:
+        _verify_source_file_access([request.uri])
+        input["uri"] = request.uri
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail="Must provide valid feature collection or URI.",
+        )
+
     try:
         await _start_batch_execution(job_id, input)
     except botocore.exceptions.ClientError as error:
diff --git a/app/routes/datasets/versions.py b/app/routes/datasets/versions.py
index 6a09530f7..89a2f1722 100644
--- a/app/routes/datasets/versions.py
+++ b/app/routes/datasets/versions.py
@@ -9,10 +9,8 @@
 Available assets and endpoints to choose from depend on the source type.
 """
 
-from collections import defaultdict
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
-from urllib.parse import urlparse
+from typing import Any, Dict, List, Optional, Tuple, Union, cast
 
 from fastapi import (
     APIRouter,
@@ -43,8 +41,8 @@
 from ...models.pydantic.creation_options import (
     CreationOptions,
     CreationOptionsResponse,
+    TableDrivers,
     creation_option_factory,
-    TableDrivers
 )
 from ...models.pydantic.extent import Extent, ExtentResponse
 from ...models.pydantic.metadata import (
@@ -67,29 +65,12 @@
 from ...tasks.aws_tasks import flush_cloudfront_cache
 from ...tasks.default_assets import append_default_asset, create_default_asset
 from ...tasks.delete_assets import delete_all_assets
-from ...utils.aws import get_aws_files
-from ...utils.google import get_gs_files
+from . import _verify_source_file_access
 from .dataset import get_owner
 from .queries import _get_data_environment
 
 router = APIRouter()
 
-SUPPORTED_FILE_EXTENSIONS: Sequence[str] = (
-    ".csv",
-    ".geojson",
-    ".gpkg",
-    ".ndjson",
-    ".shp",
-    ".tif",
-    ".tsv",
-    ".zip",
-)
-
-# I cannot seem to satisfy mypy WRT the type of this default dict. Last thing I tried:
-# DefaultDict[str, Callable[[str, str, int, int, ...], List[str]]]
-source_uri_lister_constructor = defaultdict((lambda: lambda w, x, limit=None, exit_after_max=None, extensions=None: list()))  # type: ignore
-source_uri_lister_constructor.update(**{"gs": get_gs_files, "s3": get_aws_files})  # type: ignore
-
 
 @router.get(
     "/{dataset}/{version}",
@@ -212,7 +193,7 @@ async def update_version(
     "/{dataset}/{version}/append",
     response_class=ORJSONResponse,
     tags=["Versions"],
-    response_model=VersionResponse
+    response_model=VersionResponse,
 )
 async def append_to_version(
     *,
@@ -243,14 +224,14 @@ async def append_to_version(
     input_data["creation_options"]["source_uri"] = request.source_uri
 
     # If source_driver is "text", this is a datapump request
-    if input_data["creation_options"]["source_driver"] != TableDrivers.text: 
+    if input_data["creation_options"]["source_driver"] != TableDrivers.text:
         # Verify that source_driver matches the original source_driver
         # TODO: Ideally append source_driver should not need to match the original source_driver,
         # but this would break other operations that expect only one source_driver
         if input_data["creation_options"]["source_driver"] != request.source_driver:
             raise HTTPException(
                 status_code=400,
-                detail="source_driver must match the original source_driver"
+                detail="source_driver must match the original source_driver",
             )
 
     # Use layers from request if provided, else set to None if layers are in version creation_options
@@ -267,7 +248,7 @@ async def append_to_version(
 
     # Now update the version's creation_options to reflect the changes from the append request
     update_data = {"creation_options": deepcopy(default_asset.creation_options)}
-    update_data["creation_options"]["source_uri"] += request.source_uri 
+    update_data["creation_options"]["source_uri"] += request.source_uri
     if request.layers is not None:
         if update_data["creation_options"]["layers"] is not None:
             update_data["creation_options"]["layers"] += request.layers
@@ -561,56 +542,3 @@ async def _version_response(
     data["assets"] = [(asset[0], asset[1], str(asset[2])) for asset in assets]
 
     return VersionResponse(data=Version(**data))
-
-def _verify_source_file_access(sources: List[str]) -> None:
-
-    # TODO:
-    # 1. Making the list functions asynchronous and using asyncio.gather
-    # to check for valid sources in a non-blocking fashion would be good.
-    # Perhaps use the aioboto3 package for aws, gcloud-aio-storage for gcs.
-    # 2. It would be nice if the acceptable file extensions were passed
-    # into this function so we could say, for example, that there must be
-    # TIFFs found for a new raster tile set, but a CSV is required for a new
-    # vector tile set version. Even better would be to specify whether
-    # paths to individual files or "folders" (prefixes) are allowed.
-
-    invalid_sources: List[str] = list()
-
-    for source in sources:
-        url_parts = urlparse(source, allow_fragments=False)
-        list_func = source_uri_lister_constructor[url_parts.scheme.lower()]
-        bucket = url_parts.netloc
-        prefix = url_parts.path.lstrip("/")
-
-        # Allow pseudo-globbing: Tolerate a "*" at the end of a
-        # src_uri entry to allow partial prefixes (for example
-        # /bucket/prefix_part_1/prefix_fragment* will match
-        # /bucket/prefix_part_1/prefix_fragment_1.tif and
-        # /bucket/prefix_part_1/prefix_fragment_2.tif, etc.)
-        # If the prefix doesn't end in "*" or an acceptable file extension
-        # add a "/" to the end of the prefix to enforce it being a "folder".
-        new_prefix: str = prefix
-        if new_prefix.endswith("*"):
-            new_prefix = new_prefix[:-1]
-        elif not new_prefix.endswith("/") and not any(
-            [new_prefix.endswith(suffix) for suffix in SUPPORTED_FILE_EXTENSIONS]
-        ):
-            new_prefix += "/"
-
-        if not list_func(
-            bucket,
-            new_prefix,
-            limit=10,
-            exit_after_max=1,
-            extensions=SUPPORTED_FILE_EXTENSIONS,
-        ):
-            invalid_sources.append(source)
-
-    if invalid_sources:
-        raise HTTPException(
-            status_code=400,
-            detail=(
-                "Cannot access all of the source files (non-existent or access denied). "
-                f"Invalid sources: {invalid_sources}"
-            ),
-        )
diff --git a/tests_v2/unit/app/routes/datasets/test_query.py b/tests_v2/unit/app/routes/datasets/test_query.py
index 6d0962bc1..c5d8d0377 100755
--- a/tests_v2/unit/app/routes/datasets/test_query.py
+++ b/tests_v2/unit/app/routes/datasets/test_query.py
@@ -666,7 +666,7 @@ async def test__get_data_environment_helper_called(
 
 
 @pytest.mark.asyncio
-async def test_query_batch(
+async def test_query_batch_feature_collection(
     generic_raster_version,
     apikey,
     monkeypatch: MonkeyPatch,
@@ -708,6 +708,51 @@ async def test_query_batch(
     assert response.json()["status"] == "success"
 
 
+@pytest.mark.asyncio
+async def test_query_batch_uri(
+    generic_raster_version,
+    apikey,
+    monkeypatch: MonkeyPatch,
+    async_client: AsyncClient,
+):
+    dataset_name, version_name, _ = generic_raster_version
+    api_key, payload = apikey
+    origin = "https://" + payload["domains"][0]
+
+    headers = {"origin": origin, "x-api-key": api_key}
+
+    monkeypatch.setattr(queries, "_start_batch_execution", start_batch_execution_mocked)
+    monkeypatch.setattr(queries, "_verify_source_file_access", lambda source_uris: True)
+
+    payload = {
+        "sql": "select count(*) from data",
+        "uri": "s3://path/to/files",
+    }
+
+    response = await async_client.post(
+        f"/dataset/{dataset_name}/{version_name}/query/batch",
+        json=payload,
+        headers=headers,
+    )
+
+    print(response.json())
+    assert response.status_code == 202
+
+    data = response.json()["data"]
+
+    # assert valid UUID
+    try:
+        uuid = UUID(data["job_id"])
+    except ValueError:
+        assert False
+
+    assert str(uuid) == data["job_id"]
+
+    assert data["status"] == "pending"
+
+    assert response.json()["status"] == "success"
+
+
 FEATURE_COLLECTION = {
     "type": "FeatureCollection",
     "features": [
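
Reviewer note (not part of the patch): the "I cannot seem to satisfy mypy" comment kept in app/routes/datasets/__init__.py leaves the typing of source_uri_lister_constructor open. A minimal sketch of one possible approach, assuming the listers can be typed loosely as "any call signature returning List[str]": Callable[..., List[str]] is a legal mypy type and sidesteps annotating the individual parameters that the tried DefaultDict[str, Callable[[str, str, int, int, ...], List[str]]] could not express. The names below mirror the patch but the snippet itself is hypothetical.

    # Typing sketch only; not part of this diff.
    from collections import defaultdict
    from typing import Callable, DefaultDict, List

    # Callable[..., List[str]] accepts any argument list, so the differing
    # keyword signatures of the concrete listers do not need to be spelled out.
    SourceLister = Callable[..., List[str]]


    def _no_files(*args: object, **kwargs: object) -> List[str]:
        # Fallback lister for unrecognized URI schemes: report no matching files.
        return []


    source_uri_lister_constructor: DefaultDict[str, SourceLister] = defaultdict(lambda: _no_files)
    # The real listers would then be registered as in the patch, e.g.:
    # source_uri_lister_constructor.update({"gs": get_gs_files, "s3": get_aws_files})

If this annotation passes the project's mypy configuration, the two "# type: ignore" markers on those lines could likely be dropped.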