Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
jterry64 committed Jul 11, 2024
1 parent 2ae397d commit 66d2cf7
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 45 deletions.
2 changes: 1 addition & 1 deletion app/models/pydantic/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class QueryRequestIn(StrictBaseModel):
sql: str


class QueryListRequestIn(StrictBaseModel):
class QueryBatchRequestIn(StrictBaseModel):
feature_collection: Optional[FeatureCollection]
uri: Optional[str]
sql: str
Expand Down
51 changes: 28 additions & 23 deletions app/routes/datasets/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

# from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKey
from fastapi.responses import RedirectResponse
from fastapi.responses import ORJSONResponse, RedirectResponse
from pglast import printers # noqa
from pglast import Node, parse_sql
from pglast.parser import ParseError
Expand Down Expand Up @@ -73,7 +73,7 @@
from ...models.pydantic.geostore import Geometry, GeostoreCommon
from ...models.pydantic.query import (
CsvQueryRequestIn,
QueryListRequestIn,
QueryBatchRequestIn,
QueryRequestIn,
)
from ...models.pydantic.raster_analysis import (
Expand Down Expand Up @@ -328,15 +328,16 @@ async def query_dataset_csv_post(


@router.post(
"/{dataset}/{version}/query/list",
response_class=ORJSONLiteResponse,
"/{dataset}/{version}/query/batch",
response_class=ORJSONResponse,
response_model=UserJobResponse,
tags=["Query"],
status_code=202,
)
async def query_dataset_list_post(
*,
dataset_version: Tuple[str, str] = Depends(dataset_version_dependency),
request: QueryListRequestIn,
request: QueryBatchRequestIn,
api_key: APIKey = Depends(get_api_key),
):
"""Execute a READ-ONLY SQL query on the given dataset version (if
Expand All @@ -347,18 +348,20 @@ async def query_dataset_list_post(
default_asset: AssetORM = await assets.get_default_asset(dataset, version)
if default_asset.asset_type != AssetType.raster_tile_set:
raise HTTPException(
status_code=422,
status_code=400,
detail="Querying on lists is only available for raster tile sets.",
)

if (
request.feature_collection.type != "Polygon"
and request.feature_collection.type != "MultiPolygon"
):
raise HTTPException(
status_code=422,
detail="Feature collection must be a Polygon or MultiPolygon for raster analysis",
)
if request.feature_collection:
for feature in request.feature_collection.features:
if (
feature.geometry.type != "Polygon"
and feature.geometry.type != "MultiPolygon"
):
raise HTTPException(
status_code=400,
detail="Feature collection must only contain Polygons or MultiPolygons for raster analysis",
)

job_id = uuid.uuid4()

Expand All @@ -367,30 +370,32 @@ async def query_dataset_list_post(
dataset, default_asset.creation_options["pixel_meaning"]
)
grid = default_asset.creation_options["grid"]
sql = re.sub(
"from \w+", f"from {default_layer}", request.query, flags=re.IGNORECASE
)
sql = re.sub("from \w+", f"from {default_layer}", request.sql, flags=re.IGNORECASE)
data_environment = await _get_data_environment(grid)

payload = {
input = {
"feature_collection": jsonable_encoder(request.feature_collection),
"query": sql,
"environment": data_environment.dict()["layers"],
}

try:
get_sfn_client().start_execution(
stateMachineArn=STATE_MACHINE_ARN,
name=job_id,
input=json.dumps(payload),
)
await _start_batch_execution(job_id, input)
except botocore.exceptions.ClientError as error:
logger.error(error)
return HTTPException(500, "There was an error starting your job.")

return UserJobResponse(data=UserJob(job_id=job_id))


async def _start_batch_execution(job_id: UUID, input: Dict[str, Any]) -> None:
    """Start the Step Functions state machine that runs a batch query.

    Args:
        job_id: Unique identifier for the user job; used as the Step
            Functions execution name so the run can be correlated with
            the job returned to the client.
        input: JSON-serializable execution input (feature collection,
            rewritten SQL query, and data environment layers).

    Raises:
        botocore.exceptions.ClientError: Propagated unchanged if the
            execution cannot be started; the route handler catches it.
    """
    get_sfn_client().start_execution(
        stateMachineArn=STATE_MACHINE_ARN,
        # The StartExecution `name` parameter must be a string; passing the
        # raw UUID object would fail boto3 parameter validation.
        name=str(job_id),
        input=json.dumps(input),
    )


async def _query_dataset_json(
dataset: str,
version: str,
Expand Down
4 changes: 2 additions & 2 deletions scripts/test_v2
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ fi
set +e

# Everything from "--cov-report on" become the arguments to the pytest run inside the docker.
docker-compose -f docker-compose.test.yml --project-name gfw-data-api_test run --rm --name app_test app_test --cov-report xml:/app/tests_v2/cobertura.xml $DO_COV $DISABLE_WARNINGS $SHOW_STDOUT $args
docker compose -f docker-compose.test.yml --project-name gfw-data-api_test run --rm --name app_test app_test --cov-report xml:/app/tests_v2/cobertura.xml $DO_COV $DISABLE_WARNINGS $SHOW_STDOUT $args
exit_code=$?
docker-compose -f docker-compose.test.yml --project-name gfw-data-api_test down --remove-orphans
docker compose -f docker-compose.test.yml --project-name gfw-data-api_test down --remove-orphans
exit $exit_code
150 changes: 131 additions & 19 deletions tests_v2/unit/app/routes/datasets/test_query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Tuple
from unittest.mock import Mock
from urllib.parse import parse_qsl, urlparse
from uuid import UUID

import pytest
from _pytest.monkeypatch import MonkeyPatch
Expand All @@ -11,7 +12,11 @@
from app.routes.datasets import queries
from app.routes.datasets.queries import _get_data_environment
from tests_v2.fixtures.creation_options.versions import RASTER_CREATION_OPTIONS
from tests_v2.utils import custom_raster_version, invoke_lambda_mocked
from tests_v2.utils import (
custom_raster_version,
invoke_lambda_mocked,
start_batch_execution_mocked,
)


@pytest.mark.skip("Temporarily skip until we require API keys")
Expand Down Expand Up @@ -75,23 +80,22 @@ async def test_query_dataset_with_unrestricted_api_key(


@pytest.mark.asyncio
async def test_fields_dataset_raster(
generic_raster_version, async_client: AsyncClient
):
async def test_fields_dataset_raster(generic_raster_version, async_client: AsyncClient):
dataset_name, version_name, _ = generic_raster_version
response = await async_client.get(f"/dataset/{dataset_name}/{version_name}/fields")

assert response.status_code == 200
data = response.json()["data"]
assert len(data) == 4
assert data[0]["pixel_meaning"] == 'area__ha'
assert data[0]["values_table"] == None
assert data[1]["pixel_meaning"] == 'latitude'
assert data[1]["values_table"] == None
assert data[2]["pixel_meaning"] == 'longitude'
assert data[2]["values_table"] == None
assert data[3]["pixel_meaning"] == 'my_first_dataset__year'
assert data[3]["values_table"] == None
assert data[0]["pixel_meaning"] == "area__ha"
assert data[0]["values_table"] is None
assert data[1]["pixel_meaning"] == "latitude"
assert data[1]["values_table"] is None
assert data[2]["pixel_meaning"] == "longitude"
assert data[2]["values_table"] is None
assert data[3]["pixel_meaning"] == "my_first_dataset__year"
assert data[3]["values_table"] is None


@pytest.mark.asyncio
async def test_query_dataset_raster_bad_get(
Expand Down Expand Up @@ -237,9 +241,13 @@ async def test_redirect_get_query(
follow_redirects=False,
)
assert response.status_code == 308
assert (
parse_qsl(urlparse(response.headers["location"]).query, strict_parsing=True)
== parse_qsl(urlparse(f"/dataset/{dataset_name}/{version_name}/query/json?{response.request.url.query.decode('utf-8')}").query, strict_parsing=True)
assert parse_qsl(
urlparse(response.headers["location"]).query, strict_parsing=True
) == parse_qsl(
urlparse(
f"/dataset/{dataset_name}/{version_name}/query/json?{response.request.url.query.decode('utf-8')}"
).query,
strict_parsing=True,
)


Expand Down Expand Up @@ -483,9 +491,10 @@ async def test_query_vector_asset_disallowed_10(
"You might need to add explicit type casts."
)


@pytest.mark.asyncio()
async def test_query_licensed_disallowed_11(
licensed_version, apikey, async_client: AsyncClient
licensed_version, apikey, async_client: AsyncClient
):
dataset, version, _ = licensed_version

Expand All @@ -499,9 +508,8 @@ async def test_query_licensed_disallowed_11(
follow_redirects=True,
)
assert response.status_code == 401
assert response.json()["message"] == (
"Unauthorized query on a restricted dataset"
)
assert response.json()["message"] == ("Unauthorized query on a restricted dataset")


@pytest.mark.asyncio
@pytest.mark.skip("Temporarily skip while _get_data_environment is being cached")
Expand Down Expand Up @@ -655,3 +663,107 @@ async def test__get_data_environment_helper_called(
no_data_value,
None,
)


@pytest.mark.asyncio
async def test_query_batch(
    generic_raster_version,
    apikey,
    monkeypatch: MonkeyPatch,
    async_client: AsyncClient,
):
    """POST /query/batch returns 202 with a pending job carrying a valid UUID."""
    dataset_name, version_name, _ = generic_raster_version
    key, key_payload = apikey
    headers = {
        "origin": "https://" + key_payload["domains"][0],
        "x-api-key": key,
    }

    # Stub out the Step Functions call so no real execution is started.
    monkeypatch.setattr(queries, "_start_batch_execution", start_batch_execution_mocked)

    response = await async_client.post(
        f"/dataset/{dataset_name}/{version_name}/query/batch",
        json={
            "sql": "select count(*) from data",
            "feature_collection": FEATURE_COLLECTION,
        },
        headers=headers,
    )

    print(response.json())
    assert response.status_code == 202

    body = response.json()
    data = body["data"]

    # The returned job_id must parse as a UUID and already be in
    # canonical string form.
    try:
        job_uuid = UUID(data["job_id"])
    except ValueError:
        assert False
    assert str(job_uuid) == data["job_id"]

    assert data["status"] == "pending"
    assert body["status"] == "success"


# GeoJSON FeatureCollection used as the request payload in test_query_batch:
# three Polygon features (properties.id 1-3), each a closed 5-point ring of
# [lon, lat] coordinates. All geometries are Polygons, satisfying the batch
# endpoint's Polygon/MultiPolygon requirement.
FEATURE_COLLECTION = {
    "type": "FeatureCollection",
    "features": [
        {
            "type": "Feature",
            "properties": {
                "id": 1,
            },
            "geometry": {
                "coordinates": [
                    [
                        [-57.43488539218248, -11.378524299779286],
                        [-57.43488539218248, -11.871111619666053],
                        [-56.950732779425806, -11.871111619666053],
                        [-56.950732779425806, -11.378524299779286],
                        [-57.43488539218248, -11.378524299779286],
                    ]
                ],
                "type": "Polygon",
            },
        },
        {
            "type": "Feature",
            "properties": {
                "id": 2,
            },
            "geometry": {
                "coordinates": [
                    [
                        [-55.84751191303597, -11.845408946893727],
                        [-55.84751191303597, -12.293066281588139],
                        [-55.32975635387763, -12.293066281588139],
                        [-55.32975635387763, -11.845408946893727],
                        [-55.84751191303597, -11.845408946893727],
                    ]
                ],
                "type": "Polygon",
            },
        },
        {
            "type": "Feature",
            "properties": {
                "id": 3,
            },
            "geometry": {
                "coordinates": [
                    [
                        [-58.36172075077614, -12.835185539172727],
                        [-58.36172075077614, -13.153322454532116],
                        [-57.98648069126074, -13.153322454532116],
                        [-57.98648069126074, -12.835185539172727],
                        [-58.36172075077614, -12.835185539172727],
                    ]
                ],
                "type": "Polygon",
            },
        },
    ],
}
4 changes: 4 additions & 0 deletions tests_v2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ async def invoke_lambda_mocked(
return httpx.Response(200, json={"status": "success", "data": []})


async def start_batch_execution_mocked(job_id: uuid.UUID, input: Dict[str, Any]) -> None:
    """Test stand-in for the batch-execution starter; accepts the same
    (job_id, input) arguments and does nothing."""
    pass


def void_function(*args, **kwargs) -> None:
return

Expand Down

0 comments on commit 66d2cf7

Please sign in to comment.