diff --git a/.isort.cfg b/.isort.cfg index d4fdf72a8..bb0418102 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -2,4 +2,4 @@ line_length = 88 multi_line_output = 3 include_trailing_comma = True -known_third_party = _pytest,aenum,affine,aiohttp,alembic,asgi_lifespan,async_lru,asyncpg,aws_utils,boto3,botocore,click,docker,ee,errors,fastapi,fiona,gdal_utils,geoalchemy2,geojson,gfw_pixetl,gino,gino_starlette,google,httpx,httpx_auth,logger,logging_utils,moto,numpy,orjson,osgeo,pandas,pendulum,pglast,psutil,psycopg2,pydantic,pyproj,pytest,pytest_asyncio,rasterio,shapely,sqlalchemy,sqlalchemy_utils,starlette,tileputty,typer +known_third_party = _pytest,aenum,affine,alembic,asgi_lifespan,async_lru,asyncpg,aws_utils,boto3,botocore,click,docker,ee,errors,fastapi,fiona,gdal_utils,geoalchemy2,geojson,gfw_pixetl,gino,gino_starlette,google,httpx,httpx_auth,logger,logging_utils,moto,numpy,orjson,osgeo,pandas,pendulum,pglast,psutil,psycopg2,pydantic,pyproj,pytest,pytest_asyncio,rasterio,shapely,sqlalchemy,sqlalchemy_utils,starlette,tileputty,typer diff --git a/app/main.py b/app/main.py index 4391045d8..ef1a5004a 100644 --- a/app/main.py +++ b/app/main.py @@ -31,6 +31,7 @@ versions, ) from .routes.geostore import geostore as geostore_top +from .routes.jobs import job from .routes.tasks import task ################ @@ -161,6 +162,16 @@ async def rve_error_handler( for r in analysis_routers: app.include_router(r, prefix="/analysis") + +############### +# JOB API +############### + +job_routes = (job.router,) +for r in job_routes: + app.include_router(r, prefix="/job") + + ############### # HEALTH API ############### @@ -185,6 +196,7 @@ async def rve_error_handler( {"name": "Geostore", "description": geostore.__doc__}, {"name": "Tasks", "description": task.__doc__}, {"name": "Analysis", "description": analysis.__doc__}, + {"name": "Job", "description": job.__doc__}, {"name": "Health", "description": health.__doc__}, ] diff --git a/app/models/pydantic/query.py 
b/app/models/pydantic/query.py index db67a6bf9..b5ee05ded 100644 --- a/app/models/pydantic/query.py +++ b/app/models/pydantic/query.py @@ -2,7 +2,8 @@ from app.models.enum.creation_options import Delimiters from app.models.pydantic.base import StrictBaseModel -from app.models.pydantic.geostore import Geometry +from app.models.pydantic.geostore import FeatureCollection, Geometry +from pydantic import Field class QueryRequestIn(StrictBaseModel): @@ -10,5 +11,18 @@ class QueryRequestIn(StrictBaseModel): sql: str +class QueryBatchRequestIn(StrictBaseModel): + feature_collection: Optional[FeatureCollection] = Field( + None, description="An inline collection of GeoJson features on which to do the same query" + ) + uri: Optional[str] = Field( + None, description="URI to a vector file in a variety of formats supported by Geopandas, including GeoJson and CSV format, giving a list of features on which to do the same query. For a CSV file, the column with the geometry in WKB format should be named 'WKT' (not 'WKB')" + ) + id_field: str = Field( + "fid", description="Name of field with the feature id, for use in labeling the results for each feature. This field must contain a unique value for each feature." 
+ ) + sql: str + + class CsvQueryRequestIn(QueryRequestIn): delimiter: Delimiters = Delimiters.comma diff --git a/app/models/pydantic/user_job.py b/app/models/pydantic/user_job.py new file mode 100644 index 000000000..7936deaf9 --- /dev/null +++ b/app/models/pydantic/user_job.py @@ -0,0 +1,20 @@ +from typing import Optional +from uuid import UUID + +from pydantic import BaseModel + +from .responses import Response + + +class UserJob(BaseModel): + job_id: UUID + job_link: Optional[str] # Full URL to check the job status + status: str = "pending" # Can be pending, success, partial_success, failure, and error + message: Optional[str] # Error message when status is "error" + download_link: Optional[str] = None + failed_geometries_link: Optional[str] = None + progress: Optional[str] = "0%" + + +class UserJobResponse(Response): + data: UserJob diff --git a/app/routes/assets/asset.py b/app/routes/assets/asset.py index 1a80cd959..c3bdece96 100644 --- a/app/routes/assets/asset.py +++ b/app/routes/assets/asset.py @@ -29,8 +29,8 @@ from app.models.pydantic.responses import Response from app.settings.globals import API_URL -from ..datasets.downloads import _get_presigned_url +from ...authentication.token import get_manager from ...crud import assets from ...crud import metadata as metadata_crud from ...crud import tasks @@ -49,6 +49,7 @@ asset_metadata_factory, ) from ...models.pydantic.assets import AssetResponse, AssetType, AssetUpdateIn +from ...models.pydantic.authentication import User from ...models.pydantic.change_log import ChangeLog, ChangeLogResponse from ...models.pydantic.creation_options import ( CreationOptions, @@ -69,11 +70,9 @@ from ...utils.paginate import paginate_collection from ...utils.path import infer_srid_from_grid, split_s3_path from ..assets import asset_response -from ..tasks import paginated_tasks_response, tasks_response - -from ...authentication.token import get_manager -from ...models.pydantic.authentication import User +from ..datasets import 
_get_presigned_url from ..datasets.dataset import get_owner +from ..tasks import paginated_tasks_response, tasks_response router = APIRouter() @@ -111,7 +110,8 @@ async def update_asset( ) -> AssetResponse: """Update Asset metadata. - Only the dataset's owner or a user with `ADMIN` user role can do this operation. + Only the dataset's owner or a user with `ADMIN` user role can do + this operation. """ try: @@ -322,7 +322,7 @@ async def get_tiles_info(asset_id: UUID = Path(...)): if asset.asset_type != AssetType.raster_tile_set: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Tiles information only available for raster tile sets" + detail="Tiles information only available for raster tile sets", ) bucket, asset_key = split_s3_path(asset.asset_uri) @@ -383,12 +383,16 @@ async def get_field_metadata(*, asset_id: UUID = Path(...), field_name: str): response_model=FieldMetadataResponse, ) async def update_field_metadata( - *, asset_id: UUID = Path(...), field_name: str, request: FieldMetadataUpdate, + *, + asset_id: UUID = Path(...), + field_name: str, + request: FieldMetadataUpdate, user: User = Depends(get_manager), ): """Update the field metadata for an asset. - Only the dataset's owner or a user with `ADMIN` user role can do this operation. + Only the dataset's owner or a user with `ADMIN` user role can do + this operation. """ try: @@ -434,7 +438,8 @@ async def get_metadata(asset_id: UUID = Path(...)): async def create_metadata(*, asset_id: UUID = Path(...), request: AssetMetadata): """Create metadata record for an asset. - Only the dataset's owner or a user with `ADMIN` user role can do this operation. + Only the dataset's owner or a user with `ADMIN` user role can do + this operation. 
""" input_data = request.dict(exclude_none=True, by_alias=True) asset = await assets.get_asset(asset_id) @@ -457,11 +462,16 @@ async def create_metadata(*, asset_id: UUID = Path(...), request: AssetMetadata) tags=["Assets"], response_model=AssetMetadataResponse, ) -async def update_metadata(*, asset_id: UUID = Path(...), request: AssetMetadataUpdate, - user: User = Depends(get_manager)): +async def update_metadata( + *, + asset_id: UUID = Path(...), + request: AssetMetadataUpdate, + user: User = Depends(get_manager), +): """Update metadata record for an asset. - Only the dataset's owner or a user with `ADMIN` user role can do this operation. + Only the dataset's owner or a user with `ADMIN` user role can do + this operation. """ input_data = request.dict(exclude_none=True, by_alias=True) @@ -488,11 +498,13 @@ async def update_metadata(*, asset_id: UUID = Path(...), request: AssetMetadataU tags=["Assets"], response_model=AssetMetadataResponse, ) -async def delete_metadata(asset_id: UUID = Path(...), - user: User = Depends(get_manager)): +async def delete_metadata( + asset_id: UUID = Path(...), user: User = Depends(get_manager) +): """Delete an asset's metadata record. - Only the dataset's owner or a user with `ADMIN` user role can do this operation. + Only the dataset's owner or a user with `ADMIN` user role can do + this operation. 
""" try: asset = await assets.get_asset(asset_id) diff --git a/app/routes/datasets/__init__.py b/app/routes/datasets/__init__.py index 663acb2c1..f9dfc0aa7 100644 --- a/app/routes/datasets/__init__.py +++ b/app/routes/datasets/__init__.py @@ -1,5 +1,8 @@ -from typing import Any, Dict, List +from collections import defaultdict +from typing import Any, Dict, List, Sequence +from urllib.parse import urlparse +from botocore.exceptions import ClientError from fastapi import HTTPException from ...crud import assets @@ -11,6 +14,20 @@ from ...tasks.raster_tile_set_assets.raster_tile_set_assets import ( raster_tile_set_validator, ) +from ...utils.aws import get_aws_files, get_s3_client +from ...utils.google import get_gs_files +from ...utils.path import split_s3_path + +SUPPORTED_FILE_EXTENSIONS: Sequence[str] = ( + ".csv", + ".geojson", + ".gpkg", + ".ndjson", + ".shp", + ".tif", + ".tsv", + ".zip", +) async def verify_version_status(dataset, version): @@ -82,3 +99,81 @@ async def validate_creation_options( await validator[input_data["asset_type"]](dataset, version, input_data) except KeyError: pass + + +# I cannot seem to satisfy mypy WRT the type of this default dict. Last thing I tried: +# DefaultDict[str, Callable[[str, str, int, int, ...], List[str]]] +source_uri_lister_constructor = defaultdict((lambda: lambda w, x, limit=None, exit_after_max=None, extensions=None: list())) # type: ignore +source_uri_lister_constructor.update(**{"gs": get_gs_files, "s3": get_aws_files}) # type: ignore + + +def _verify_source_file_access(sources: List[str]) -> None: + + # TODO: + # 1. Making the list functions asynchronous and using asyncio.gather + # to check for valid sources in a non-blocking fashion would be good. + # Perhaps use the aioboto3 package for aws, gcloud-aio-storage for gcs. + # 2. 
It would be nice if the acceptable file extensions were passed + # into this function so we could say, for example, that there must be + # TIFFs found for a new raster tile set, but a CSV is required for a new + # vector tile set version. Even better would be to specify whether + # paths to individual files or "folders" (prefixes) are allowed. + + invalid_sources: List[str] = list() + + for source in sources: + url_parts = urlparse(source, allow_fragments=False) + list_func = source_uri_lister_constructor[url_parts.scheme.lower()] + bucket = url_parts.netloc + prefix = url_parts.path.lstrip("/") + + # Allow pseudo-globbing: Tolerate a "*" at the end of a + # src_uri entry to allow partial prefixes (for example + # /bucket/prefix_part_1/prefix_fragment* will match + # /bucket/prefix_part_1/prefix_fragment_1.tif and + # /bucket/prefix_part_1/prefix_fragment_2.tif, etc.) + # If the prefix doesn't end in "*" or an acceptable file extension + # add a "/" to the end of the prefix to enforce it being a "folder". + new_prefix: str = prefix + if new_prefix.endswith("*"): + new_prefix = new_prefix[:-1] + elif not new_prefix.endswith("/") and not any( + [new_prefix.endswith(suffix) for suffix in SUPPORTED_FILE_EXTENSIONS] + ): + new_prefix += "/" + + if not list_func( + bucket, + new_prefix, + limit=10, + exit_after_max=1, + extensions=SUPPORTED_FILE_EXTENSIONS, + ): + invalid_sources.append(source) + + if invalid_sources: + raise HTTPException( + status_code=400, + detail=( + "Cannot access all of the source files (non-existent or access denied). 
" + f"Invalid sources: {invalid_sources}" + ), + ) + + +async def _get_presigned_url_from_path(path): + bucket, key = split_s3_path(path) + return await _get_presigned_url(bucket, key) + + +async def _get_presigned_url(bucket, key): + s3_client = get_s3_client() + try: + presigned_url = s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=900 + ) + except ClientError: + raise HTTPException( + status_code=404, detail="Requested resources does not exist." + ) + return presigned_url diff --git a/app/routes/datasets/downloads.py b/app/routes/datasets/downloads.py index 1645a5397..1e3d66315 100644 --- a/app/routes/datasets/downloads.py +++ b/app/routes/datasets/downloads.py @@ -3,16 +3,13 @@ from typing import Any, Dict, List, Optional, Tuple from uuid import UUID, uuid4 -from aiohttp import ClientError from fastapi import APIRouter, Depends, HTTPException, Query # from fastapi.openapi.models import APIKey from fastapi.responses import RedirectResponse -# from ...authentication.api_keys import get_api_key from ...crud.assets import get_assets_by_filter from ...crud.versions import get_version -from ...main import logger from ...models.enum.assets import AssetType from ...models.enum.creation_options import Delimiters from ...models.enum.geostore import GeostoreOrigin @@ -20,10 +17,12 @@ from ...models.pydantic.downloads import DownloadCSVIn, DownloadJSONIn from ...models.pydantic.geostore import GeostoreCommon from ...responses import CSVStreamingResponse, ORJSONStreamingResponse -from ...utils.aws import get_s3_client from ...utils.geostore import get_geostore from ...utils.path import split_s3_path from .. import dataset_version_dependency + +# from ...authentication.api_keys import get_api_key +from . 
import _get_presigned_url from .queries import _query_dataset_csv, _query_dataset_json router: APIRouter = APIRouter() @@ -37,7 +36,10 @@ async def download_json( dataset_version: Tuple[str, str] = Depends(dataset_version_dependency), sql: str = Query(..., description="SQL query."), - geostore_id: Optional[UUID] = Query(None, description="Geostore ID. The geostore must represent a Polygon or MultiPolygon."), + geostore_id: Optional[UUID] = Query( + None, + description="Geostore ID. The geostore must represent a Polygon or MultiPolygon.", + ), geostore_origin: GeostoreOrigin = Query( GeostoreOrigin.gfw, description="Service to search first for geostore." ), @@ -118,7 +120,10 @@ async def download_json_post( async def download_csv( dataset_version: Tuple[str, str] = Depends(dataset_version_dependency), sql: str = Query(..., description="SQL query."), - geostore_id: Optional[UUID] = Query(None, description="Geostore ID. The geostore must represent a Polygon or MultiPolygon."), + geostore_id: Optional[UUID] = Query( + None, + description="Geostore ID. The geostore must represent a Polygon or MultiPolygon.", + ), geostore_origin: GeostoreOrigin = Query( GeostoreOrigin.gfw, description="Service to search first for geostore." ), @@ -316,20 +321,6 @@ async def _get_asset_url(dataset: str, version: str, asset_type: str) -> str: return assets[0].asset_uri -async def _get_presigned_url(bucket, key): - s3_client = get_s3_client() - try: - presigned_url = s3_client.generate_presigned_url( - "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=900 - ) - except ClientError as e: - logger.error(e) - raise HTTPException( - status_code=404, detail="Requested resources does not exist." 
- ) - return presigned_url - - async def _check_downloadability(dataset, version): v = await get_version(dataset, version) if not v.is_downloadable: diff --git a/app/routes/datasets/queries.py b/app/routes/datasets/queries.py index 42aebe2be..51ebdfe81 100755 --- a/app/routes/datasets/queries.py +++ b/app/routes/datasets/queries.py @@ -1,11 +1,14 @@ """Explore data entries for a given dataset version using standard SQL.""" import csv +import json import re +import uuid from io import StringIO from typing import Any, Dict, List, Optional, Tuple, Union, cast from urllib.parse import unquote from uuid import UUID, uuid4 +import botocore import httpx from async_lru import alru_cache from asyncpg import DataError, InsufficientPrivilegeError, SyntaxOrAccessError @@ -15,16 +18,17 @@ from fastapi.encoders import jsonable_encoder from fastapi.logger import logger from fastapi.openapi.models import APIKey -from fastapi.responses import RedirectResponse +from fastapi.responses import ORJSONResponse, RedirectResponse from pglast import printers # noqa from pglast import Node, parse_sql from pglast.parser import ParseError from pglast.printer import RawStream from pydantic.tools import parse_obj_as +from app.settings.globals import API_URL -from ...authentication.token import is_gfwpro_admin_for_query from ...application import db from ...authentication.api_keys import get_api_key +from ...authentication.token import is_gfwpro_admin_for_query from ...crud import assets from ...models.enum.assets import AssetType from ...models.enum.creation_options import Delimiters @@ -63,7 +67,11 @@ from ...models.pydantic.asset_metadata import RasterTable, RasterTableRow from ...models.pydantic.creation_options import NoDataType from ...models.pydantic.geostore import Geometry, GeostoreCommon -from ...models.pydantic.query import CsvQueryRequestIn, QueryRequestIn +from ...models.pydantic.query import ( + CsvQueryRequestIn, + QueryBatchRequestIn, + QueryRequestIn, +) from 
...models.pydantic.raster_analysis import ( DataEnvironment, DerivedLayer, @@ -71,11 +79,17 @@ SourceLayer, ) from ...models.pydantic.responses import Response +from ...models.pydantic.user_job import UserJob, UserJobResponse from ...responses import CSVStreamingResponse, ORJSONLiteResponse -from ...settings.globals import GEOSTORE_SIZE_LIMIT_OTF, RASTER_ANALYSIS_LAMBDA_NAME -from ...utils.aws import invoke_lambda +from ...settings.globals import ( + GEOSTORE_SIZE_LIMIT_OTF, + RASTER_ANALYSIS_LAMBDA_NAME, + RASTER_ANALYSIS_STATE_MACHINE_ARN, +) +from ...utils.aws import get_sfn_client, invoke_lambda from ...utils.geostore import get_geostore from .. import dataset_version_dependency +from . import _verify_source_file_access router = APIRouter() @@ -83,6 +97,7 @@ # Special suffixes to do an extra area density calculation on the raster data set. AREA_DENSITY_RASTER_SUFFIXES = ["_ha-1", "_ha_yr-1"] + @router.get( "/{dataset}/{version}/query", response_class=RedirectResponse, @@ -123,7 +138,10 @@ async def query_dataset_json( response: FastApiResponse, dataset_version: Tuple[str, str] = Depends(dataset_version_dependency), sql: str = Query(..., description="SQL query."), - geostore_id: Optional[UUID] = Query(None, description="Geostore ID. The geostore must represent a Polygon or MultiPolygon."), + geostore_id: Optional[UUID] = Query( + None, + description="Geostore ID. The geostore must represent a Polygon or MultiPolygon.", + ), geostore_origin: GeostoreOrigin = Query( GeostoreOrigin.gfw, description="Service to search first for geostore." ), @@ -150,7 +168,6 @@ async def query_dataset_json( referenced. There are also several reserved fields with special meaning that can be used, including "area__ha", "latitude", and "longitude". 
- """ dataset, version = dataset_version @@ -182,7 +199,10 @@ async def query_dataset_csv( response: FastApiResponse, dataset_version: Tuple[str, str] = Depends(dataset_version_dependency), sql: str = Query(..., description="SQL query."), - geostore_id: Optional[UUID] = Query(None, description="Geostore ID. The geostore must represent a Polygon or MultiPolygon."), + geostore_id: Optional[UUID] = Query( + None, + description="Geostore ID. The geostore must represent a Polygon or MultiPolygon.", + ), geostore_origin: GeostoreOrigin = Query( GeostoreOrigin.gfw, description="Service to search first for geostore." ), @@ -260,7 +280,6 @@ async def query_dataset_json_post( dataset, version = dataset_version - # create geostore with unknowns as blank if request.geometry: geostore: Optional[GeostoreCommon] = GeostoreCommon( geojson=request.geometry, geostore_id=uuid4(), area__ha=0, bbox=[0, 0, 0, 0] @@ -305,6 +324,114 @@ async def query_dataset_csv_post( return CSVStreamingResponse(iter([csv_data.getvalue()]), download=False) +@router.post( + "/{dataset}/{version}/query/batch", + response_class=ORJSONResponse, + response_model=UserJobResponse, + tags=["Query"], + status_code=202, +) +async def query_dataset_list_post( + *, + dataset_version: Tuple[str, str] = Depends(dataset_version_dependency), + request: QueryBatchRequestIn, + api_key: APIKey = Depends(get_api_key), +): + """Execute a READ-ONLY SQL query on the specified raster-based dataset version + for a potentially large list of features. The features may be specified by an + inline GeoJson feature collection or the URI of vector file that is in any of a + variety of formats supported by GeoPandas, include GeoJson and CSV format. For + CSV files, the geometry column should be named "WKT" (not "WKB") and the geometry + values should be in WKB format. + + The specified sql query will be run on each individual feature, and so may take a + while. Therefore, the results of this query include a job_id. 
The user should + then periodically query the specified job via the /job/{job_id} api. When the + "data.status" indicates "success" or "partial_success", then the successful + results will be available at the specified "data.download_link". When the + "data.status" indicates "partial_success" or "failed", then failed results + (likely because of improper geometries) will be available at + "data.failed_geometries_link". If the "data.status" indicates "error", then there + will be no results available (nothing was able to complete, possibly because of + an infrastructure problem). + + There is currently a five-minute time limit on the entire list query, but up to + 100 individual feature queries proceed in parallel, so lists with several + thousands of features can potentially be processed within that time limit. + + """ + + dataset, version = dataset_version + + default_asset: AssetORM = await assets.get_default_asset(dataset, version) + if default_asset.asset_type != AssetType.raster_tile_set: + raise HTTPException( + status_code=400, + detail="Querying on lists is only available for raster tile sets.", + ) + + if request.feature_collection: + for feature in request.feature_collection.features: + if ( + feature.geometry.type != "Polygon" + and feature.geometry.type != "MultiPolygon" + ): + raise HTTPException( + status_code=400, + detail="Feature collection must only contain Polygons or MultiPolygons for raster analysis", + ) + + job_id = uuid.uuid4() + + # get grid, query and data environment based on default asset + default_layer = _get_default_layer( + dataset, default_asset.creation_options["pixel_meaning"] + ) + grid = default_asset.creation_options["grid"] + sql = re.sub(r"from \w+", f"from {default_layer}", request.sql, flags=re.IGNORECASE) + data_environment = await _get_data_environment(grid) + + input = { + "query": sql, + "id_field": request.id_field, + "environment": data_environment.dict()["layers"], + } + + if request.feature_collection is not None: +
if request.uri is not None: + raise HTTPException( + status_code=400, + detail="Must provide only one of valid feature collection or URI.", + ) + + input["feature_collection"] = jsonable_encoder(request.feature_collection) + elif request.uri is not None: + _verify_source_file_access([request.uri]) + input["uri"] = request.uri + else: + raise HTTPException( + status_code=400, + detail="Must provide valid feature collection or URI.", + ) + + try: + await _start_batch_execution(job_id, input) + except botocore.exceptions.ClientError as error: + logger.error(error) + raise HTTPException(500, "There was an error starting your job.") + + job_link = f"{API_URL}/job/{job_id}" + return UserJobResponse(data=UserJob(job_id=job_id, job_link=job_link)) + + +async def _start_batch_execution(job_id: UUID, input: Dict[str, Any]) -> None: + get_sfn_client().start_execution( + stateMachineArn=RASTER_ANALYSIS_STATE_MACHINE_ARN, + name=str(job_id), + input=json.dumps(input), + ) + + async def _query_dataset_json( dataset: str, version: str, @@ -660,7 +787,7 @@ def _get_area_density_name(nm): return nm with the area-density suffix removed.""" for suffix in AREA_DENSITY_RASTER_SUFFIXES: if nm.endswith(suffix): - return nm[:-len(suffix)] + return nm[: -len(suffix)] return "" diff --git a/app/routes/datasets/versions.py b/app/routes/datasets/versions.py index 6a09530f7..89a2f1722 100644 --- a/app/routes/datasets/versions.py +++ b/app/routes/datasets/versions.py @@ -9,10 +9,8 @@ Available assets and endpoints to choose from depend on the source type.
""" -from collections import defaultdict from copy import deepcopy -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast -from urllib.parse import urlparse +from typing import Any, Dict, List, Optional, Tuple, Union, cast from fastapi import ( APIRouter, @@ -43,8 +41,8 @@ from ...models.pydantic.creation_options import ( CreationOptions, CreationOptionsResponse, + TableDrivers, creation_option_factory, - TableDrivers ) from ...models.pydantic.extent import Extent, ExtentResponse from ...models.pydantic.metadata import ( @@ -67,29 +65,12 @@ from ...tasks.aws_tasks import flush_cloudfront_cache from ...tasks.default_assets import append_default_asset, create_default_asset from ...tasks.delete_assets import delete_all_assets -from ...utils.aws import get_aws_files -from ...utils.google import get_gs_files +from . import _verify_source_file_access from .dataset import get_owner from .queries import _get_data_environment router = APIRouter() -SUPPORTED_FILE_EXTENSIONS: Sequence[str] = ( - ".csv", - ".geojson", - ".gpkg", - ".ndjson", - ".shp", - ".tif", - ".tsv", - ".zip", -) - -# I cannot seem to satisfy mypy WRT the type of this default dict. 
Last thing I tried: -# DefaultDict[str, Callable[[str, str, int, int, ...], List[str]]] -source_uri_lister_constructor = defaultdict((lambda: lambda w, x, limit=None, exit_after_max=None, extensions=None: list())) # type: ignore -source_uri_lister_constructor.update(**{"gs": get_gs_files, "s3": get_aws_files}) # type: ignore - @router.get( "/{dataset}/{version}", @@ -212,7 +193,7 @@ async def update_version( "/{dataset}/{version}/append", response_class=ORJSONResponse, tags=["Versions"], - response_model=VersionResponse + response_model=VersionResponse, ) async def append_to_version( *, @@ -243,14 +224,14 @@ async def append_to_version( input_data["creation_options"]["source_uri"] = request.source_uri # If source_driver is "text", this is a datapump request - if input_data["creation_options"]["source_driver"] != TableDrivers.text: + if input_data["creation_options"]["source_driver"] != TableDrivers.text: # Verify that source_driver matches the original source_driver # TODO: Ideally append source_driver should not need to match the original source_driver, # but this would break other operations that expect only one source_driver if input_data["creation_options"]["source_driver"] != request.source_driver: raise HTTPException( status_code=400, - detail="source_driver must match the original source_driver" + detail="source_driver must match the original source_driver", ) # Use layers from request if provided, else set to None if layers are in version creation_options @@ -267,7 +248,7 @@ async def append_to_version( # Now update the version's creation_options to reflect the changes from the append request update_data = {"creation_options": deepcopy(default_asset.creation_options)} - update_data["creation_options"]["source_uri"] += request.source_uri + update_data["creation_options"]["source_uri"] += request.source_uri if request.layers is not None: if update_data["creation_options"]["layers"] is not None: update_data["creation_options"]["layers"] += request.layers @@ 
-561,56 +542,3 @@ async def _version_response( data["assets"] = [(asset[0], asset[1], str(asset[2])) for asset in assets] return VersionResponse(data=Version(**data)) - -def _verify_source_file_access(sources: List[str]) -> None: - - # TODO: - # 1. Making the list functions asynchronous and using asyncio.gather - # to check for valid sources in a non-blocking fashion would be good. - # Perhaps use the aioboto3 package for aws, gcloud-aio-storage for gcs. - # 2. It would be nice if the acceptable file extensions were passed - # into this function so we could say, for example, that there must be - # TIFFs found for a new raster tile set, but a CSV is required for a new - # vector tile set version. Even better would be to specify whether - # paths to individual files or "folders" (prefixes) are allowed. - - invalid_sources: List[str] = list() - - for source in sources: - url_parts = urlparse(source, allow_fragments=False) - list_func = source_uri_lister_constructor[url_parts.scheme.lower()] - bucket = url_parts.netloc - prefix = url_parts.path.lstrip("/") - - # Allow pseudo-globbing: Tolerate a "*" at the end of a - # src_uri entry to allow partial prefixes (for example - # /bucket/prefix_part_1/prefix_fragment* will match - # /bucket/prefix_part_1/prefix_fragment_1.tif and - # /bucket/prefix_part_1/prefix_fragment_2.tif, etc.) - # If the prefix doesn't end in "*" or an acceptable file extension - # add a "/" to the end of the prefix to enforce it being a "folder". 
- new_prefix: str = prefix - if new_prefix.endswith("*"): - new_prefix = new_prefix[:-1] - elif not new_prefix.endswith("/") and not any( - [new_prefix.endswith(suffix) for suffix in SUPPORTED_FILE_EXTENSIONS] - ): - new_prefix += "/" - - if not list_func( - bucket, - new_prefix, - limit=10, - exit_after_max=1, - extensions=SUPPORTED_FILE_EXTENSIONS, - ): - invalid_sources.append(source) - - if invalid_sources: - raise HTTPException( - status_code=400, - detail=( - "Cannot access all of the source files (non-existent or access denied). " - f"Invalid sources: {invalid_sources}" - ), - ) diff --git a/app/routes/jobs/__init__.py b/app/routes/jobs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/routes/jobs/job.py b/app/routes/jobs/job.py new file mode 100644 index 000000000..1b336c006 --- /dev/null +++ b/app/routes/jobs/job.py @@ -0,0 +1,126 @@ +"""Jobs represent long running analysis tasks. Certain APIs, like querying on +a list, will return immediately with a job_id. You can poll the job until it's +complete, and a download link will be provided. + +Jobs are only saved for 90 days. +""" +import json +from typing import Any, Dict +from uuid import UUID + +import botocore +from fastapi import APIRouter, HTTPException, Path +from fastapi.logger import logger +from fastapi.responses import ORJSONResponse + +from ...models.pydantic.user_job import UserJob, UserJobResponse +from ...settings.globals import RASTER_ANALYSIS_STATE_MACHINE_ARN +from ...utils.aws import get_sfn_client +from ..datasets import _get_presigned_url_from_path + +router = APIRouter() + + +@router.get( + "/{job_id}", + response_class=ORJSONResponse, + tags=["Job"], + response_model=UserJobResponse, +) +async def get_job(*, job_id: UUID = Path(...)) -> UserJobResponse: + """Get job status. + + Jobs expire after 90 days.
+ """ + try: + job = await _get_user_job(job_id) + return UserJobResponse(data=job) + except botocore.exceptions.ClientError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +async def _get_user_job(job_id: UUID) -> UserJob: + execution = await _get_sfn_execution(job_id) + + if execution["status"] == "SUCCEEDED": + output = ( + json.loads(execution["output"]) if execution["output"] is not None else {} + ) + + if output.get("status") == "success": + download_link = await _get_presigned_url_from_path( + output["data"]["download_link"] + ) + failed_geometries_link = None + elif output.get("status") == "partial_success": + download_link = await _get_presigned_url_from_path( + output["data"]["download_link"] + ) + failed_geometries_link = await _get_presigned_url_from_path( + output["data"]["failed_geometries_link"] + ) + elif output.get("status") == "failed": + download_link = None + failed_geometries_link = await _get_presigned_url_from_path( + output["data"]["failed_geometries_link"] + ) + else: + logger.error(f"Analysis service returned an unexpected response: {output}") + return UserJob( + job_id=job_id, + status="error", + message=output.get("message"), + download_link=None, + failed_geometries_link=None, + progress="0%", + ) + + return UserJob( + job_id=job_id, + status=output["status"], + download_link=download_link, + failed_geometries_link=failed_geometries_link, + progress=await _get_progress(execution), + ) + + elif execution["status"] == "RUNNING": + return UserJob( + job_id=job_id, + status="pending", + download_link=None, + failed_geometries_link=None, + progress=await _get_progress(execution), + ) + else: + return UserJob( + job_id=job_id, + status="failed", + download_link=None, + failed_geometries_link=None, + progress=None, + ) + + +async def _get_sfn_execution(job_id: UUID) -> Dict[str, Any]: + execution_arn = f"{RASTER_ANALYSIS_STATE_MACHINE_ARN.replace('stateMachine', 'execution')}:{str(job_id)}" + execution =
get_sfn_client().describe_execution(executionArn=execution_arn) + return execution + + +async def _get_progress(execution: Dict[str, Any]) -> str: + map_run = await _get_map_run(execution) + if not map_run or not map_run["itemCounts"]["total"]: + # No map run has started yet, or it has not counted any items yet (avoid dividing by zero) + return "0%" + success_ratio = map_run["itemCounts"]["succeeded"] / map_run["itemCounts"]["total"] + return f"{round(success_ratio * 100)}%" + + +async def _get_map_run(execution: Dict[str, Any]) -> Dict[str, Any]: + map_runs = get_sfn_client().list_map_runs(executionArn=execution["executionArn"])["mapRuns"] + if len(map_runs) == 0: + # No map runs have started yet, return empty dict + return {} + map_run_arn = map_runs[0]["mapRunArn"] + map_run = get_sfn_client().describe_map_run(mapRunArn=map_run_arn) + return map_run diff --git a/app/settings/globals.py b/app/settings/globals.py index 0ac0d348e..c72037fe7 100644 --- a/app/settings/globals.py +++ b/app/settings/globals.py @@ -8,7 +8,6 @@ from ..models.enum.pixetl import ResamplingMethod from ..models.pydantic.database import DatabaseURL - # Read .env file, if exists p: Path = Path(__file__).parents[2] / ".env" config: Config = Config(p if p.exists() else None) @@ -181,3 +180,7 @@ # Datasets that require admin privileges to do a query. (Extra protection on # commercial datasets which shouldn't be downloaded in any way.)
PROTECTED_QUERY_DATASETS = ["wdpa_licensed_protected_areas"] + +RASTER_ANALYSIS_STATE_MACHINE_ARN = config( + "RASTER_ANALYSIS_STATE_MACHINE_ARN", cast=str, default=None +) diff --git a/app/utils/aws.py b/app/utils/aws.py index 2f9e2da66..5bb1e0a81 100644 --- a/app/utils/aws.py +++ b/app/utils/aws.py @@ -3,8 +3,8 @@ import boto3 import botocore import httpx -from httpx_auth import AWS4Auth from fastapi.logger import logger +from httpx_auth import AWS4Auth from ..settings.globals import ( AWS_REGION, @@ -38,6 +38,7 @@ def client(): get_api_gateway_client = client_constructor("apigateway") get_s3_client = client_constructor("s3", S3_ENTRYPOINT_URL) get_secret_client = client_constructor("secretsmanager", AWS_SECRETSMANAGER_URL) +get_sfn_client = client_constructor("stepfunctions") async def invoke_lambda( diff --git a/terraform/data.tf b/terraform/data.tf index 56b445327..d3748b037 100644 --- a/terraform/data.tf +++ b/terraform/data.tf @@ -69,6 +69,7 @@ data "template_file" "container_definition" { pixetl_job_definition = module.batch_job_queues.pixetl_job_definition_arn pixetl_job_queue = module.batch_job_queues.pixetl_job_queue_arn raster_analysis_lambda_name = "raster-analysis-tiled_raster_analysis-default" + raster_analysis_sfn_arn = data.terraform_remote_state.raster_analysis_lambda.outputs.raster_analysis_state_machine_arn service_url = local.service_url rw_api_url = var.rw_api_url api_token_secret_arn = data.terraform_remote_state.core.outputs.secrets_read-gfw-api-token_arn @@ -182,4 +183,11 @@ data "template_file" "tile_cache_bucket_policy" { vars = { bucket_arn = data.terraform_remote_state.tile_cache.outputs.tile_cache_bucket_arn } +} + +data "template_file" "step_function_policy" { + template = file("${path.root}/templates/step_function_policy.json.tmpl") + vars = { + raster_analysis_state_machine_arn = data.terraform_remote_state.raster_analysis_lambda.outputs.raster_analysis_state_machine_arn + } } \ No newline at end of file diff --git 
a/terraform/iam.tf b/terraform/iam.tf index 6d7691880..e0a137b49 100644 --- a/terraform/iam.tf +++ b/terraform/iam.tf @@ -37,4 +37,9 @@ resource "aws_iam_policy" "read_new_relic_secret" { resource "aws_iam_policy" "tile_cache_bucket_policy" { name = substr("${local.project}-tile_cache_bucket_policy${local.name_suffix}", 0, 64) policy = data.template_file.tile_cache_bucket_policy.rendered -} \ No newline at end of file +} + +resource "aws_iam_policy" "step_function_policy" { + name = substr("${local.project}-step_function_policy${local.name_suffix}", 0, 64) + policy = data.template_file.step_function_policy.rendered +} diff --git a/terraform/main.tf b/terraform/main.tf index 92b41c23f..263d70da0 100644 --- a/terraform/main.tf +++ b/terraform/main.tf @@ -106,7 +106,8 @@ module "fargate_autoscaling" { aws_iam_policy.read_gcs_secret.arn, data.terraform_remote_state.tile_cache.outputs.ecs_update_service_policy_arn, aws_iam_policy.tile_cache_bucket_policy.arn, - data.terraform_remote_state.tile_cache.outputs.cloudfront_invalidation_policy_arn + data.terraform_remote_state.tile_cache.outputs.cloudfront_invalidation_policy_arn, + aws_iam_policy.step_function_policy.arn, ] task_execution_role_policies = [ aws_iam_policy.query_batch_jobs.arn, diff --git a/terraform/modules/api_gateway/gateway/main.tf b/terraform/modules/api_gateway/gateway/main.tf index 61fe4ac26..ffbed1516 100644 --- a/terraform/modules/api_gateway/gateway/main.tf +++ b/terraform/modules/api_gateway/gateway/main.tf @@ -48,7 +48,7 @@ module "query_get" { require_api_key = true http_method = "GET" - authorization = "NONE" + authorization = "CUSTOM" integration_parameters = { "integration.request.path.version" = "method.request.path.version" @@ -75,7 +75,7 @@ module "query_post" { require_api_key = true http_method = "POST" - authorization = "NONE" + authorization = "CUSTOM" integration_parameters = { "integration.request.path.version" = "method.request.path.version" diff --git 
a/terraform/modules/api_gateway/resource/main.tf b/terraform/modules/api_gateway/resource/main.tf index 7c1023ae6..83075a736 100644 --- a/terraform/modules/api_gateway/resource/main.tf +++ b/terraform/modules/api_gateway/resource/main.tf @@ -71,7 +71,7 @@ resource "aws_api_gateway_gateway_response" "exceeded_quota" { response_type = "QUOTA_EXCEEDED" response_templates = { - "application/json" = "{\"status\":\"failed\",\"message\":\"Exceeded the daily quota for this resource. Please email us at gfw@wri.org to see if your use case may qualify for higher quota.\"}" + "application/json" = "{\"status\":\"failed\",\"message\":\"Exceeded the daily quota for this resource. If you're running analysis on a list of areas of interest, please use the batch analysis endpoint to avoid this error: https://staging-data-api.globalforestwatch.org/#tag/Query/operation/query_dataset_list_post_dataset__dataset___version__query_batch_post. Otherwise, email us at gfw@wri.org to see if your use case may qualify for higher quota.\"}" } } @@ -81,7 +81,7 @@ resource "aws_api_gateway_gateway_response" "throttled" { response_type = "THROTTLED" response_templates = { - "application/json" = "{\"status\":\"failed\",\"message\":\"Exceeded the rate limit for this resource. Please try again later. Also email us at gfw@wri.org to see if your use case may qualify for higher rate limit.\"}" + "application/json" = "{\"status\":\"failed\",\"message\":\"Exceeded the rate limit for this resource. If you're running analysis on a list of areas of interest, please use the batch analysis endpoint to avoid this error: https://staging-data-api.globalforestwatch.org/#tag/Query/operation/query_dataset_list_post_dataset__dataset___version__query_batch_post. 
Otherwise, email us at gfw@wri.org to see if your use case may qualify for higher rate limit.\"}" } } diff --git a/terraform/templates/container_definition.json.tmpl b/terraform/templates/container_definition.json.tmpl index 1b5713fd7..f031b29ca 100644 --- a/terraform/templates/container_definition.json.tmpl +++ b/terraform/templates/container_definition.json.tmpl @@ -112,6 +112,10 @@ { "name": "NAME_SUFFIX", "value": "${name_suffix}" + }, + { + "name": "RASTER_ANALYSIS_STATE_MACHINE_ARN", + "value": "${raster_analysis_sfn_arn}" } ], "secrets": [ diff --git a/terraform/templates/step_function_policy.json.tmpl b/terraform/templates/step_function_policy.json.tmpl new file mode 100644 index 000000000..41c402842 --- /dev/null +++ b/terraform/templates/step_function_policy.json.tmpl @@ -0,0 +1,23 @@ +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "states:StartExecution" + ], + "Resource": [ + "${raster_analysis_state_machine_arn}" + ] + }, + { + "Effect": "Allow", + "Action": [ + "states:DescribeExecution", + "states:DescribeMapRun", + "states:ListMapRuns" + ], + "Resource": "*" + } + ] +} \ No newline at end of file diff --git a/terraform/vars/terraform-dev.tfvars b/terraform/vars/terraform-dev.tfvars index 92024ec8d..5f9477624 100644 --- a/terraform/vars/terraform-dev.tfvars +++ b/terraform/vars/terraform-dev.tfvars @@ -5,7 +5,7 @@ rw_api_url = "https://api.resourcewatch.org" desired_count = 1 auto_scaling_min_capacity = 1 auto_scaling_max_capacity = 5 -lambda_analysis_workspace = "features-lat_lon" +lambda_analysis_workspace = "feature-otf_lists" key_pair = "dmannarino_gfw" create_cloudfront_distribution = false new_relic_license_key_arn = "arn:aws:secretsmanager:us-east-1:563860007740:secret:newrelic/license_key-lolw24" diff --git a/tests_v2/unit/app/routes/datasets/test_query.py b/tests_v2/unit/app/routes/datasets/test_query.py index 3078e8293..254bc9b90 100755 --- a/tests_v2/unit/app/routes/datasets/test_query.py +++ 
b/tests_v2/unit/app/routes/datasets/test_query.py @@ -1,6 +1,7 @@ from typing import Tuple from unittest.mock import Mock from urllib.parse import parse_qsl, urlparse +from uuid import UUID import pytest from _pytest.monkeypatch import MonkeyPatch @@ -11,7 +12,11 @@ from app.routes.datasets import queries from app.routes.datasets.queries import _get_data_environment from tests_v2.fixtures.creation_options.versions import RASTER_CREATION_OPTIONS -from tests_v2.utils import custom_raster_version, invoke_lambda_mocked +from tests_v2.utils import ( + custom_raster_version, + invoke_lambda_mocked, + start_batch_execution_mocked, +) @pytest.mark.skip("Temporarily skip until we require API keys") @@ -75,23 +80,22 @@ async def test_query_dataset_with_unrestricted_api_key( @pytest.mark.asyncio -async def test_fields_dataset_raster( - generic_raster_version, async_client: AsyncClient -): +async def test_fields_dataset_raster(generic_raster_version, async_client: AsyncClient): dataset_name, version_name, _ = generic_raster_version response = await async_client.get(f"/dataset/{dataset_name}/{version_name}/fields") assert response.status_code == 200 data = response.json()["data"] assert len(data) == 4 - assert data[0]["pixel_meaning"] == 'area__ha' - assert data[0]["values_table"] == None - assert data[1]["pixel_meaning"] == 'latitude' - assert data[1]["values_table"] == None - assert data[2]["pixel_meaning"] == 'longitude' - assert data[2]["values_table"] == None - assert data[3]["pixel_meaning"] == 'my_first_dataset__year' - assert data[3]["values_table"] == None + assert data[0]["pixel_meaning"] == "area__ha" + assert data[0]["values_table"] is None + assert data[1]["pixel_meaning"] == "latitude" + assert data[1]["values_table"] is None + assert data[2]["pixel_meaning"] == "longitude" + assert data[2]["values_table"] is None + assert data[3]["pixel_meaning"] == "my_first_dataset__year" + assert data[3]["values_table"] is None + @pytest.mark.asyncio async def 
test_query_dataset_raster_bad_get( @@ -237,9 +241,13 @@ async def test_redirect_get_query( follow_redirects=False, ) assert response.status_code == 308 - assert ( - parse_qsl(urlparse(response.headers["location"]).query, strict_parsing=True) - == parse_qsl(urlparse(f"/dataset/{dataset_name}/{version_name}/query/json?{response.request.url.query.decode('utf-8')}").query, strict_parsing=True) + assert parse_qsl( + urlparse(response.headers["location"]).query, strict_parsing=True + ) == parse_qsl( + urlparse( + f"/dataset/{dataset_name}/{version_name}/query/json?{response.request.url.query.decode('utf-8')}" + ).query, + strict_parsing=True, ) @@ -483,9 +491,10 @@ async def test_query_vector_asset_disallowed_10( "You might need to add explicit type casts." ) + @pytest.mark.asyncio() async def test_query_licensed_disallowed_11( - licensed_version, apikey, async_client: AsyncClient + licensed_version, apikey, async_client: AsyncClient ): dataset, version, _ = licensed_version @@ -499,9 +508,8 @@ async def test_query_licensed_disallowed_11( follow_redirects=True, ) assert response.status_code == 401 - assert response.json()["message"] == ( - "Unauthorized query on a restricted dataset" - ) + assert response.json()["message"] == ("Unauthorized query on a restricted dataset") + @pytest.mark.asyncio @pytest.mark.skip("Temporarily skip while _get_data_environment is being cached") @@ -655,3 +663,155 @@ async def test__get_data_environment_helper_called( no_data_value, None, ) + + +@pytest.mark.asyncio +async def test_query_batch_feature_collection( + generic_raster_version, + apikey, + monkeypatch: MonkeyPatch, + async_client: AsyncClient, +): + dataset_name, version_name, _ = generic_raster_version + api_key, payload = apikey + origin = "https://" + payload["domains"][0] + + headers = {"origin": origin, "x-api-key": api_key} + + monkeypatch.setattr(queries, "_start_batch_execution", start_batch_execution_mocked) + payload = { + "sql": "select count(*) from data", + 
"feature_collection": FEATURE_COLLECTION, + "id_field": "id", + } + + response = await async_client.post( + f"/dataset/{dataset_name}/{version_name}/query/batch", + json=payload, + headers=headers, + ) + + print(response.json()) + assert response.status_code == 202 + + data = response.json()["data"] + + # assert valid UUID + try: + uuid = UUID(data["job_id"]) + except ValueError: + assert False + + assert str(uuid) == data["job_id"] + assert data["job_link"].endswith(f"/job/{data['job_id']}") + + assert data["status"] == "pending" + + assert response.json()["status"] == "success" + + +@pytest.mark.asyncio +async def test_query_batch_uri( + generic_raster_version, + apikey, + monkeypatch: MonkeyPatch, + async_client: AsyncClient, +): + dataset_name, version_name, _ = generic_raster_version + api_key, payload = apikey + origin = "https://" + payload["domains"][0] + + headers = {"origin": origin, "x-api-key": api_key} + + monkeypatch.setattr(queries, "_start_batch_execution", start_batch_execution_mocked) + monkeypatch.setattr(queries, "_verify_source_file_access", lambda source_uris: True) + + payload = { + "sql": "select count(*) from data", + "uri": "s3://path/to/files", + "id_field": "id", + } + + response = await async_client.post( + f"/dataset/{dataset_name}/{version_name}/query/batch", + json=payload, + headers=headers, + ) + + print(response.json()) + assert response.status_code == 202 + + data = response.json()["data"] + + # assert valid UUID + try: + uuid = UUID(data["job_id"]) + except ValueError: + assert False + + assert str(uuid) == data["job_id"] + + assert data["status"] == "pending" + + assert response.json()["status"] == "success" + + +FEATURE_COLLECTION = { + "type": "FeatureCollection", + "features": [ + { + "type": "Feature", + "properties": { + "id": 1, + }, + "geometry": { + "coordinates": [ + [ + [-57.43488539218248, -11.378524299779286], + [-57.43488539218248, -11.871111619666053], + [-56.950732779425806, -11.871111619666053], + 
[-56.950732779425806, -11.378524299779286], + [-57.43488539218248, -11.378524299779286], + ] + ], + "type": "Polygon", + }, + }, + { + "type": "Feature", + "properties": { + "id": 2, + }, + "geometry": { + "coordinates": [ + [ + [-55.84751191303597, -11.845408946893727], + [-55.84751191303597, -12.293066281588139], + [-55.32975635387763, -12.293066281588139], + [-55.32975635387763, -11.845408946893727], + [-55.84751191303597, -11.845408946893727], + ] + ], + "type": "Polygon", + }, + }, + { + "type": "Feature", + "properties": { + "id": 3, + }, + "geometry": { + "coordinates": [ + [ + [-58.36172075077614, -12.835185539172727], + [-58.36172075077614, -13.153322454532116], + [-57.98648069126074, -13.153322454532116], + [-57.98648069126074, -12.835185539172727], + [-58.36172075077614, -12.835185539172727], + ] + ], + "type": "Polygon", + }, + }, + ], +} diff --git a/tests_v2/unit/app/routes/jobs/__init__.py b/tests_v2/unit/app/routes/jobs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_v2/unit/app/routes/jobs/test_job.py b/tests_v2/unit/app/routes/jobs/test_job.py new file mode 100644 index 000000000..4adab7633 --- /dev/null +++ b/tests_v2/unit/app/routes/jobs/test_job.py @@ -0,0 +1,191 @@ +import json + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from httpx import AsyncClient + +from app.routes.jobs import job + +TEST_JOB_ID = "f3caa6c8-09d7-43a8-823f-e7528344a169" + + +async def _get_sfn_execution_mocked_pending(job_id): + return { + "executionArn": "arn::fake_execution_arn", + "stateMachineArn": "arn::fake_state_machine_arn", + "status": "RUNNING", + "input": json.dumps({"job_id": TEST_JOB_ID}), + "output": None, + "mapRunArn": "arn::fake_map_run_arn", + } + + +async def _get_sfn_execution_mocked_success(job_id): + return { + "executionArn": "arn::fake_execution_arn", + "stateMachineArn": "arn::fake_state_machine_arn", + "status": "SUCCEEDED", + "input": json.dumps({"job_id": TEST_JOB_ID}), + "output": json.dumps( + { 
+ "data": { + "job_id": TEST_JOB_ID, + "download_link": "s3://test/results.csv", + "failed_geometries_link": None, + }, + "status": "success", + } + ), + "mapRunArn": "arn::fake_map_run_arn", + } + + +async def _get_sfn_execution_mocked_failed(job_id): + return { + "executionArn": "arn::fake_execution_arn", + "stateMachineArn": "arn::fake_state_machine_arn", + "status": "FAILED", + "input": json.dumps({"job_id": TEST_JOB_ID}), + "output": None, + "mapRunArn": "arn::fake_map_run_arn", + } + + +async def _get_sfn_execution_mocked_partial_success(job_id): + return { + "executionArn": "arn::fake_execution_arn", + "stateMachineArn": "arn::fake_state_machine_arn", + "status": "SUCCEEDED", + "input": json.dumps({"job_id": TEST_JOB_ID}), + "output": json.dumps( + { + "data": { + "job_id": TEST_JOB_ID, + "download_link": "s3://test/results.csv", + "failed_geometries_link": "s3://test/results_failed.csv", + }, + "status": "partial_success", + } + ), + "mapRunArn": "arn::fake_map_run_arn", + } + + +async def _get_sfn_execution_mocked_failed_geoms(job_id): + return { + "executionArn": "arn::fake_execution_arn", + "stateMachineArn": "arn::fake_state_machine_arn", + "status": "SUCCEEDED", + "input": json.dumps({"job_id": TEST_JOB_ID}), + "output": json.dumps( + { + "data": { + "job_id": TEST_JOB_ID, + "download_link": None, + "failed_geometries_link": "s3://test/results_failed.csv", + }, + "status": "failed", + } + ), + "mapRunArn": "arn::fake_map_run_arn", + } + + +async def _get_map_run_mocked_partial(job_id): + return { + "executionArn": "arn::fake_execution_arn", + "mapRunArn": "arn::fake_map_run_arn", + "itemCounts": { + "succeeded": 100, + "total": 1000, + }, + } + + +async def _get_map_run_mocked_all(job_id): + return { + "executionArn": "arn::fake_execution_arn", + "mapRunArn": "arn::fake_map_run_arn", + "itemCounts": { + "succeeded": 1000, + "total": 1000, + }, + } + + +@pytest.mark.asyncio +async def test_job_pending( + async_client: AsyncClient, + monkeypatch: 
MonkeyPatch, +) -> None: + monkeypatch.setattr(job, "_get_sfn_execution", _get_sfn_execution_mocked_pending) + monkeypatch.setattr(job, "_get_map_run", _get_map_run_mocked_partial) + + resp = await async_client.get(f"job/{TEST_JOB_ID}") + + assert resp.status_code == 200 + data = resp.json()["data"] + + assert data["status"] == "pending" + assert data["download_link"] is None + assert data["progress"] == "10%" + + +@pytest.mark.asyncio +async def test_job_success( + async_client: AsyncClient, + monkeypatch: MonkeyPatch, +) -> None: + monkeypatch.setattr(job, "_get_sfn_execution", _get_sfn_execution_mocked_success) + monkeypatch.setattr(job, "_get_map_run", _get_map_run_mocked_all) + + resp = await async_client.get(f"job/{TEST_JOB_ID}") + + assert resp.status_code == 200 + data = resp.json()["data"] + + assert data["job_id"] == TEST_JOB_ID + assert data["status"] == "success" + assert "test/results.csv" in data["download_link"] + assert data["failed_geometries_link"] is None + assert data["progress"] == "100%" + + +@pytest.mark.asyncio +async def test_job_partial_success( + async_client: AsyncClient, + monkeypatch: MonkeyPatch, +) -> None: + monkeypatch.setattr( + job, "_get_sfn_execution", _get_sfn_execution_mocked_partial_success + ) + monkeypatch.setattr(job, "_get_map_run", _get_map_run_mocked_partial) + + resp = await async_client.get(f"job/{TEST_JOB_ID}") + + assert resp.status_code == 200 + data = resp.json()["data"] + + assert data["job_id"] == TEST_JOB_ID + assert data["status"] == "partial_success" + assert "test/results.csv" in data["download_link"] + assert "test/results_failed.csv" in data["failed_geometries_link"] + assert data["progress"] == "10%" + + +@pytest.mark.asyncio +async def test_job_failed( + async_client: AsyncClient, + monkeypatch: MonkeyPatch, +) -> None: + monkeypatch.setattr(job, "_get_sfn_execution", _get_sfn_execution_mocked_failed) + + resp = await async_client.get(f"job/{TEST_JOB_ID}") + + assert resp.status_code == 200 + data = 
resp.json()["data"] + + assert data["job_id"] == TEST_JOB_ID + assert data["status"] == "failed" + assert data["download_link"] is None + assert data["progress"] is None diff --git a/tests_v2/utils.py b/tests_v2/utils.py index ec2bfb646..f0bef3e3e 100644 --- a/tests_v2/utils.py +++ b/tests_v2/utils.py @@ -86,6 +86,10 @@ async def invoke_lambda_mocked( return httpx.Response(200, json={"status": "success", "data": []}) +async def start_batch_execution_mocked(job_id: uuid.UUID, input: Dict[str, Any]): + pass + + def void_function(*args, **kwargs) -> None: return