Support sending URI
jterry64 committed Jul 11, 2024
1 parent 66d2cf7 commit 6e2c1e6
Showing 4 changed files with 141 additions and 82 deletions.
77 changes: 76 additions & 1 deletion app/routes/datasets/__init__.py
@@ -1,4 +1,6 @@
from typing import Any, Dict, List
from collections import defaultdict
from typing import Any, Dict, List, Sequence
from urllib.parse import urlparse

from fastapi import HTTPException

@@ -11,6 +13,19 @@
from ...tasks.raster_tile_set_assets.raster_tile_set_assets import (
raster_tile_set_validator,
)
from ...utils.aws import get_aws_files
from ...utils.google import get_gs_files

SUPPORTED_FILE_EXTENSIONS: Sequence[str] = (
".csv",
".geojson",
".gpkg",
".ndjson",
".shp",
".tif",
".tsv",
".zip",
)


async def verify_version_status(dataset, version):
@@ -82,3 +97,63 @@ async def validate_creation_options(
await validator[input_data["asset_type"]](dataset, version, input_data)
except KeyError:
pass


# I cannot seem to satisfy mypy WRT the type of this default dict. Last thing I tried:
# DefaultDict[str, Callable[[str, str, int, int, ...], List[str]]]
source_uri_lister_constructor = defaultdict((lambda: lambda w, x, limit=None, exit_after_max=None, extensions=None: list())) # type: ignore
source_uri_lister_constructor.update(**{"gs": get_gs_files, "s3": get_aws_files}) # type: ignore
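# Illustration only (not part of the diff): the defaultdict dispatches on the
# URI scheme, and any scheme other than "gs" or "s3" falls back to the no-op
# lambda, which lists nothing and is therefore flagged as invalid further down.
#
#   source_uri_lister_constructor["s3"]    -> get_aws_files
#   source_uri_lister_constructor["gs"]    -> get_gs_files
#   source_uri_lister_constructor["http"]  -> lambda returning an empty list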


def _verify_source_file_access(sources: List[str]) -> None:

# TODO:
# 1. Making the list functions asynchronous and using asyncio.gather
# to check for valid sources in a non-blocking fashion would be good.
# Perhaps use the aioboto3 package for aws, gcloud-aio-storage for gcs.
# 2. It would be nice if the acceptable file extensions were passed
# into this function so we could say, for example, that there must be
# TIFFs found for a new raster tile set, but a CSV is required for a new
# vector tile set version. Even better would be to specify whether
# paths to individual files or "folders" (prefixes) are allowed.

invalid_sources: List[str] = list()

for source in sources:
url_parts = urlparse(source, allow_fragments=False)
list_func = source_uri_lister_constructor[url_parts.scheme.lower()]
bucket = url_parts.netloc
prefix = url_parts.path.lstrip("/")

# Allow pseudo-globbing: Tolerate a "*" at the end of a
# src_uri entry to allow partial prefixes (for example
# /bucket/prefix_part_1/prefix_fragment* will match
# /bucket/prefix_part_1/prefix_fragment_1.tif and
# /bucket/prefix_part_1/prefix_fragment_2.tif, etc.)
# If the prefix doesn't end in "*" or an acceptable file extension
# add a "/" to the end of the prefix to enforce it being a "folder".
new_prefix: str = prefix
if new_prefix.endswith("*"):
new_prefix = new_prefix[:-1]
elif not new_prefix.endswith("/") and not any(
[new_prefix.endswith(suffix) for suffix in SUPPORTED_FILE_EXTENSIONS]
):
new_prefix += "/"
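        # Illustrative normalization results (prefixes are hypothetical):
        #   "rasters/tiles*"    -> "rasters/tiles"      (trailing "*" stripped; partial prefix kept)
        #   "rasters/tiles.tif" -> "rasters/tiles.tif"  (recognized file extension, left as a file path)
        #   "rasters/tiles"     -> "rasters/tiles/"     (forced to a "folder"-style prefix)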

if not list_func(
bucket,
new_prefix,
limit=10,
exit_after_max=1,
extensions=SUPPORTED_FILE_EXTENSIONS,
):
invalid_sources.append(source)

if invalid_sources:
raise HTTPException(
status_code=400,
detail=(
"Cannot access all of the source files (non-existent or access denied). "
f"Invalid sources: {invalid_sources}"
),
)
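
A minimal usage sketch of the relocated helper (bucket and prefixes below are hypothetical): every source URI must be listable through get_aws_files or get_gs_files, otherwise an HTTPException with status 400 is raised naming the offending entries.

from app.routes.datasets import _verify_source_file_access

_verify_source_file_access(
    [
        "s3://my-bucket/rasters/",  # "folder" prefix; must contain at least one supported file
        "gs://my-bucket/vectors/*",  # trailing "*" is treated as a partial prefix
    ]
)  # raises HTTPException(status_code=400) if any source cannot be listed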
13 changes: 12 additions & 1 deletion app/routes/datasets/queries.py
@@ -93,6 +93,7 @@
from ...utils.aws import get_sfn_client, invoke_lambda
from ...utils.geostore import get_geostore
from .. import dataset_version_dependency
from . import _verify_source_file_access

router = APIRouter()

@@ -374,11 +375,21 @@ async def query_dataset_list_post(
data_environment = await _get_data_environment(grid)

input = {
"feature_collection": jsonable_encoder(request.feature_collection),
"query": sql,
"environment": data_environment.dict()["layers"],
}

if request.feature_collection:
input["feature_collection"] = jsonable_encoder(request.feature_collection)
elif request.uri:
_verify_source_file_access([request.uri])
input["uri"] = request.uri
else:
raise HTTPException(
status_code=400,
detail="Must provide valid feature collection or URI.",
)

try:
await _start_batch_execution(job_id, input)
except botocore.exceptions.ClientError as error:
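
With this change the batch query endpoint accepts either an inline feature collection or a URI pointing at the source files; a sketch of the two request body shapes (values are illustrative, and the URI form mirrors the new test below):

payload_with_features = {
    "sql": "select count(*) from data",
    "feature_collection": {"type": "FeatureCollection", "features": []},  # inline geometries
}

payload_with_uri = {
    "sql": "select count(*) from data",
    "uri": "s3://path/to/files",  # checked via _verify_source_file_access before dispatch
}

If neither field is provided, the endpoint responds with a 400 ("Must provide valid feature collection or URI.").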
86 changes: 7 additions & 79 deletions app/routes/datasets/versions.py
@@ -9,10 +9,8 @@
Available assets and endpoints to choose from depend on the source type.
"""

from collections import defaultdict
from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from urllib.parse import urlparse
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from fastapi import (
APIRouter,
@@ -43,8 +41,8 @@
from ...models.pydantic.creation_options import (
CreationOptions,
CreationOptionsResponse,
TableDrivers,
creation_option_factory,
TableDrivers
)
from ...models.pydantic.extent import Extent, ExtentResponse
from ...models.pydantic.metadata import (
@@ -67,29 +65,12 @@
from ...tasks.aws_tasks import flush_cloudfront_cache
from ...tasks.default_assets import append_default_asset, create_default_asset
from ...tasks.delete_assets import delete_all_assets
from ...utils.aws import get_aws_files
from ...utils.google import get_gs_files
from . import _verify_source_file_access
from .dataset import get_owner
from .queries import _get_data_environment

router = APIRouter()

SUPPORTED_FILE_EXTENSIONS: Sequence[str] = (
".csv",
".geojson",
".gpkg",
".ndjson",
".shp",
".tif",
".tsv",
".zip",
)

# I cannot seem to satisfy mypy WRT the type of this default dict. Last thing I tried:
# DefaultDict[str, Callable[[str, str, int, int, ...], List[str]]]
source_uri_lister_constructor = defaultdict((lambda: lambda w, x, limit=None, exit_after_max=None, extensions=None: list())) # type: ignore
source_uri_lister_constructor.update(**{"gs": get_gs_files, "s3": get_aws_files}) # type: ignore


@router.get(
"/{dataset}/{version}",
@@ -212,7 +193,7 @@ async def update_version(
"/{dataset}/{version}/append",
response_class=ORJSONResponse,
tags=["Versions"],
response_model=VersionResponse
response_model=VersionResponse,
)
async def append_to_version(
*,
@@ -243,14 +224,14 @@ async def append_to_version(
input_data["creation_options"]["source_uri"] = request.source_uri

# If source_driver is "text", this is a datapump request
if input_data["creation_options"]["source_driver"] != TableDrivers.text:
if input_data["creation_options"]["source_driver"] != TableDrivers.text:
# Verify that source_driver matches the original source_driver
        # TODO: Ideally the append source_driver should not need to match the original source_driver,
# but this would break other operations that expect only one source_driver
if input_data["creation_options"]["source_driver"] != request.source_driver:
raise HTTPException(
status_code=400,
detail="source_driver must match the original source_driver"
detail="source_driver must match the original source_driver",
)

# Use layers from request if provided, else set to None if layers are in version creation_options
@@ -267,7 +248,7 @@

# Now update the version's creation_options to reflect the changes from the append request
update_data = {"creation_options": deepcopy(default_asset.creation_options)}
update_data["creation_options"]["source_uri"] += request.source_uri
update_data["creation_options"]["source_uri"] += request.source_uri
if request.layers is not None:
if update_data["creation_options"]["layers"] is not None:
update_data["creation_options"]["layers"] += request.layers
@@ -561,56 +542,3 @@ async def _version_response(
data["assets"] = [(asset[0], asset[1], str(asset[2])) for asset in assets]

return VersionResponse(data=Version(**data))

def _verify_source_file_access(sources: List[str]) -> None:

# TODO:
# 1. Making the list functions asynchronous and using asyncio.gather
# to check for valid sources in a non-blocking fashion would be good.
# Perhaps use the aioboto3 package for aws, gcloud-aio-storage for gcs.
# 2. It would be nice if the acceptable file extensions were passed
# into this function so we could say, for example, that there must be
# TIFFs found for a new raster tile set, but a CSV is required for a new
# vector tile set version. Even better would be to specify whether
# paths to individual files or "folders" (prefixes) are allowed.

invalid_sources: List[str] = list()

for source in sources:
url_parts = urlparse(source, allow_fragments=False)
list_func = source_uri_lister_constructor[url_parts.scheme.lower()]
bucket = url_parts.netloc
prefix = url_parts.path.lstrip("/")

# Allow pseudo-globbing: Tolerate a "*" at the end of a
# src_uri entry to allow partial prefixes (for example
# /bucket/prefix_part_1/prefix_fragment* will match
# /bucket/prefix_part_1/prefix_fragment_1.tif and
# /bucket/prefix_part_1/prefix_fragment_2.tif, etc.)
# If the prefix doesn't end in "*" or an acceptable file extension
# add a "/" to the end of the prefix to enforce it being a "folder".
new_prefix: str = prefix
if new_prefix.endswith("*"):
new_prefix = new_prefix[:-1]
elif not new_prefix.endswith("/") and not any(
[new_prefix.endswith(suffix) for suffix in SUPPORTED_FILE_EXTENSIONS]
):
new_prefix += "/"

if not list_func(
bucket,
new_prefix,
limit=10,
exit_after_max=1,
extensions=SUPPORTED_FILE_EXTENSIONS,
):
invalid_sources.append(source)

if invalid_sources:
raise HTTPException(
status_code=400,
detail=(
"Cannot access all of the source files (non-existent or access denied). "
f"Invalid sources: {invalid_sources}"
),
)
47 changes: 46 additions & 1 deletion tests_v2/unit/app/routes/datasets/test_query.py
@@ -666,7 +666,7 @@ async def test__get_data_environment_helper_called(


@pytest.mark.asyncio
async def test_query_batch(
async def test_query_batch_feature_collection(
generic_raster_version,
apikey,
monkeypatch: MonkeyPatch,
@@ -708,6 +708,51 @@ async def test_query_batch(
assert response.json()["status"] == "success"


@pytest.mark.asyncio
async def test_query_batch_uri(
generic_raster_version,
apikey,
monkeypatch: MonkeyPatch,
async_client: AsyncClient,
):
dataset_name, version_name, _ = generic_raster_version
api_key, payload = apikey
origin = "https://" + payload["domains"][0]

headers = {"origin": origin, "x-api-key": api_key}

monkeypatch.setattr(queries, "_start_batch_execution", start_batch_execution_mocked)
monkeypatch.setattr(queries, "_verify_source_file_access", lambda source_uris: True)

payload = {
"sql": "select count(*) from data",
"uri": "s3://path/to/files",
}

response = await async_client.post(
f"/dataset/{dataset_name}/{version_name}/query/batch",
json=payload,
headers=headers,
)

print(response.json())
assert response.status_code == 202

data = response.json()["data"]

# assert valid UUID
try:
uuid = UUID(data["job_id"])
except ValueError:
assert False

assert str(uuid) == data["job_id"]

assert data["status"] == "pending"

assert response.json()["status"] == "success"


FEATURE_COLLECTION = {
"type": "FeatureCollection",
"features": [