diff --git a/app/routes/datasets/downloads.py b/app/routes/datasets/downloads.py index 1645a5397..72a8a1f0a 100644 --- a/app/routes/datasets/downloads.py +++ b/app/routes/datasets/downloads.py @@ -24,7 +24,7 @@ from ...utils.geostore import get_geostore from ...utils.path import split_s3_path from .. import dataset_version_dependency -from .queries import _query_dataset_csv, _query_dataset_json +from ..utils.downloads import _query_dataset_csv, _query_dataset_json router: APIRouter = APIRouter() diff --git a/app/routes/datasets/queries.py b/app/routes/datasets/queries.py index 4e39c107c..f14a3a365 100755 --- a/app/routes/datasets/queries.py +++ b/app/routes/datasets/queries.py @@ -2,7 +2,7 @@ import csv import re from io import StringIO -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union from urllib.parse import unquote from uuid import UUID, uuid4 @@ -25,7 +25,6 @@ from ...authentication.token import is_gfwpro_admin_for_query from ...application import db # from ...authentication.api_keys import get_api_key -from ...crud import assets from ...models.enum.assets import AssetType from ...models.enum.creation_options import Delimiters from ...models.enum.geostore import GeostoreOrigin @@ -76,6 +75,7 @@ from ...utils.aws import invoke_lambda from ...utils.geostore import get_geostore from .. 
import dataset_version_dependency +from ..utils.downloads import _query_dataset_csv, _query_dataset_json router = APIRouter() @@ -305,56 +305,6 @@ async def query_dataset_csv_post( return CSVStreamingResponse(iter([csv_data.getvalue()]), download=False) -async def _query_dataset_json( - dataset: str, - version: str, - sql: str, - geostore: Optional[GeostoreCommon], -) -> List[Dict[str, Any]]: - # Make sure we can query the dataset - default_asset: AssetORM = await assets.get_default_asset(dataset, version) - query_type = _get_query_type(default_asset, geostore) - if query_type == QueryType.table: - geometry = geostore.geojson if geostore else None - return await _query_table(dataset, version, sql, geometry) - elif query_type == QueryType.raster: - geostore = cast(GeostoreCommon, geostore) - results = await _query_raster(dataset, default_asset, sql, geostore) - return results["data"] - else: - raise HTTPException( - status_code=501, - detail="This endpoint is not implemented for the given dataset.", - ) - - -async def _query_dataset_csv( - dataset: str, - version: str, - sql: str, - geostore: Optional[GeostoreCommon], - delimiter: Delimiters = Delimiters.comma, -) -> StringIO: - # Make sure we can query the dataset - default_asset: AssetORM = await assets.get_default_asset(dataset, version) - query_type = _get_query_type(default_asset, geostore) - if query_type == QueryType.table: - geometry = geostore.geojson if geostore else None - response = await _query_table(dataset, version, sql, geometry) - return _orm_to_csv(response, delimiter=delimiter) - elif query_type == QueryType.raster: - geostore = cast(GeostoreCommon, geostore) - results = await _query_raster( - dataset, default_asset, sql, geostore, QueryFormat.csv, delimiter - ) - return StringIO(results["data"]) - else: - raise HTTPException( - status_code=501, - detail="This endpoint is not implemented for the given dataset.", - ) - - def _get_query_type(default_asset: AssetORM, geostore: 
Optional[GeostoreCommon]): if default_asset.asset_type in [ AssetType.geo_database_table,
diff --git a/app/routes/utils/__init__.py b/app/routes/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/app/routes/utils/downloads.py b/app/routes/utils/downloads.py
new file mode 100644
--- /dev/null
+++ b/app/routes/utils/downloads.py
@@ -0,0 +1,113 @@
+"""Query helpers shared by the dataset query and download routes.
+
+This module is imported by ``..datasets.queries`` at import time, so the
+helpers it needs from that module are imported inside the functions below
+rather than at module level; a top-level import would be circular.
+"""
+import csv
+from io import StringIO
+from typing import Any, Dict, List, Optional, cast
+
+from fastapi import HTTPException
+
+from ...crud import assets
+from ...models.enum.creation_options import Delimiters
+from ...models.enum.queries import QueryFormat, QueryType
+from ...models.orm.assets import Asset as AssetORM
+from ...models.pydantic.geostore import GeostoreCommon
+
+
+async def _query_dataset_json(
+    dataset: str,
+    version: str,
+    sql: str,
+    geostore: Optional[GeostoreCommon],
+) -> List[Dict[str, Any]]:
+    """Run ``sql`` against a dataset version and return the result rows.
+
+    The default asset of the dataset version decides whether the query is
+    run as a table query or as a raster query.
+
+    Raises:
+        HTTPException: 501 if the query type is neither table nor raster.
+    """
+    # Imported here to avoid a circular import with ``..datasets.queries``.
+    from ..datasets.queries import _get_query_type, _query_raster, _query_table
+
+    # Make sure we can query the dataset
+    default_asset: AssetORM = await assets.get_default_asset(dataset, version)
+    query_type = _get_query_type(default_asset, geostore)
+    if query_type == QueryType.table:
+        geometry = geostore.geojson if geostore else None
+        return await _query_table(dataset, version, sql, geometry)
+    elif query_type == QueryType.raster:
+        # cast() only narrows the static type; it does no runtime check.
+        geostore = cast(GeostoreCommon, geostore)
+        results = await _query_raster(dataset, default_asset, sql, geostore)
+        return results["data"]
+    else:
+        raise HTTPException(
+            status_code=501,
+            detail="This endpoint is not implemented for the given dataset.",
+        )
+
+
+async def _query_dataset_csv(
+    dataset: str,
+    version: str,
+    sql: str,
+    geostore: Optional[GeostoreCommon],
+    delimiter: Delimiters = Delimiters.comma,
+) -> StringIO:
+    """Run ``sql`` against a dataset version and return the result as CSV.
+
+    Table results are serialized with ``_orm_to_csv``; raster results are
+    requested in CSV format and wrapped in a ``StringIO``.
+
+    Raises:
+        HTTPException: 501 if the query type is neither table nor raster.
+    """
+    # Imported here to avoid a circular import with ``..datasets.queries``.
+    from ..datasets.queries import _get_query_type, _query_raster, _query_table
+
+    # Make sure we can query the dataset
+    default_asset: AssetORM = await assets.get_default_asset(dataset, version)
+    query_type = _get_query_type(default_asset, geostore)
+    if query_type == QueryType.table:
+        geometry = geostore.geojson if geostore else None
+        response = await _query_table(dataset, version, sql, geometry)
+        return _orm_to_csv(response, delimiter=delimiter)
+    elif query_type == QueryType.raster:
+        # cast() only narrows the static type; it does no runtime check.
+        geostore = cast(GeostoreCommon, geostore)
+        results = await _query_raster(
+            dataset, default_asset, sql, geostore, QueryFormat.csv, delimiter
+        )
+        return StringIO(results["data"])
+    else:
+        raise HTTPException(
+            status_code=501,
+            detail="This endpoint is not implemented for the given dataset.",
+        )
+
+
+def _orm_to_csv(
+    data: List[Dict[str, Any]], delimiter: Delimiters = Delimiters.comma
+) -> StringIO:
+    """Serialize a list of row dicts into an in-memory CSV file.
+
+    The header row is taken from the keys of the first row, so every row is
+    expected to carry the same keys in the same order. An empty ``data``
+    yields an empty buffer.
+    """
+    csv_file = StringIO()
+
+    if data:
+        wr = csv.writer(csv_file, quoting=csv.QUOTE_NONNUMERIC, delimiter=delimiter)
+        field_names = data[0].keys()
+        wr.writerow(field_names)
+        for row in data:
+            wr.writerow(row.values())
+        csv_file.seek(0)
+
+    return csv_file