diff --git a/app/authentication/api_keys.py b/app/authentication/api_keys.py index 0b80878dd..c05a2b909 100644 --- a/app/authentication/api_keys.py +++ b/app/authentication/api_keys.py @@ -93,32 +93,14 @@ def api_key_is_valid( return is_valid -def api_key_is_internal( - domains: List[str], - user_id: Optional[str] = None, - origin: Optional[str] = None, - referrer: Optional[str] = None, -) -> bool: - - is_internal: bool = False - if origin and domains: - is_internal = any( - [ - re.search(_to_regex(internal_domain.strip()), domain) - for domain in domains - for internal_domain in INTERNAL_DOMAINS.split(",") - ] - ) - elif referrer and domains: - is_internal = any( - [ - re.search(_to_regex(domain), internal_domain) - for domain in domains - for internal_domain in INTERNAL_DOMAINS.split(",") - ] - ) - - return is_internal +def api_key_is_internal(domains: List[str]) -> bool: + return any( + [ + re.search(_to_regex(internal_domain.strip()), domain) + for domain in domains + for internal_domain in INTERNAL_DOMAINS.split(",") + ] + ) def _api_key_origin_auto_error( @@ -139,7 +121,7 @@ def _api_key_origin_auto_error( def _to_regex(domain): result = domain.replace(".", r"\.").replace("*", ".*") - return fr"^{result}$" + return rf"^{result}$" def _extract_domain(url: str) -> str: diff --git a/app/crud/versions.py b/app/crud/versions.py index 82cfc0d3b..e3de5b4bb 100644 --- a/app/crud/versions.py +++ b/app/crud/versions.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional +from async_lru import alru_cache from asyncpg import UniqueViolationError from ..errors import RecordAlreadyExistsError, RecordNotFoundError @@ -11,7 +12,6 @@ from . import datasets, update_data from .metadata import ( create_version_metadata, - update_all_metadata, update_version_metadata, ) @@ -52,6 +52,7 @@ async def get_version(dataset: str, version: str) -> ORMVersion: return row +@alru_cache(maxsize=64, ttl=3600.0) async def get_latest_version(dataset) -> str: """Fetch latest version number.""" @@ -80,9 +81,6 @@ async def create_version(dataset: str, version: str, **data) -> ORMVersion: if data.get("is_downloadable") is None: data["is_downloadable"] = d.is_downloadable - if data.get("is_latest"): - await _reset_is_latest(dataset, version) - metadata_data = data.pop("metadata", None) try: new_version: ORMVersion = await ORMVersion.create( @@ -100,6 +98,13 @@ async def create_version(dataset: str, version: str, **data) -> ORMVersion: ) new_version.metadata = metadata + # NOTE: We disallow specifying a new version as latest on creation via + # the VersionCreateIn model in order to prevent requests temporarily going + # to an incompletely-imported asset, however it's technically allowed in + # this function to facilitate testing. + if data.get("is_latest"): + await _reset_is_latest(dataset, version) + return new_version @@ -155,8 +160,15 @@ async def _update_is_downloadable( async def _reset_is_latest(dataset: str, version: str) -> None: + """Set is_latest to False for all other versions of a dataset.""" + # NOTE: Please remember to only call after setting the provided version to + # latest to avoid setting nothing to latest + # FIXME: This will get slower and more DB-intensive the more versions + # there are for a dataset. Could be re-written to use a single DB call, + # no? versions = await get_versions(dataset) version_gen = list_to_async_generator(versions) async for version_orm in version_gen: if version_orm.version != version: await update_version(dataset, version_orm.version, is_latest=False) + _: bool = get_latest_version.cache_invalidate(dataset) diff --git a/app/responses.py b/app/responses.py index f906a6588..050f4ea94 100644 --- a/app/responses.py +++ b/app/responses.py @@ -1,6 +1,7 @@ import decimal import io from typing import Any +import asyncpg import orjson from fastapi.responses import Response, StreamingResponse @@ -74,7 +75,7 @@ def jsonencoder_lite(obj): encoding large lists. This encoder only encodes the bare necessities needed to work with serializers like ORJSON. """ - if isinstance(obj, decimal.Decimal): + if isinstance(obj, decimal.Decimal) or isinstance(obj, asyncpg.pgproto.pgproto.UUID): return str(obj) raise TypeError( f"Unknown type for value {obj} with class type {type(obj).__name__}" diff --git a/app/routes/analysis/analysis.py b/app/routes/analysis/analysis.py index c49c17c1f..4843542e6 100644 --- a/app/routes/analysis/analysis.py +++ b/app/routes/analysis/analysis.py @@ -2,13 +2,13 @@ from typing import Any, Dict, List, Optional from uuid import UUID -from fastapi import APIRouter, Path, Query +from fastapi import APIRouter, Depends, Path, Query from fastapi.exceptions import HTTPException from fastapi.logger import logger -# from fastapi.openapi.models import APIKey +from fastapi.openapi.models import APIKey from fastapi.responses import ORJSONResponse -# from ...authentication.api_keys import get_api_key +from ...authentication.api_keys import get_api_key from ...models.enum.analysis import RasterLayer from ...models.enum.geostore import GeostoreOrigin from ...models.pydantic.analysis import ZonalAnalysisRequestIn @@ -50,7 +50,7 @@ async def zonal_statistics_get( description="Must be either year or YYYY-MM-DD date format.", regex=DATE_REGEX, ), - # api_key: APIKey = Depends(get_api_key), + api_key: APIKey = Depends(get_api_key), ): """Calculate zonal statistics on any registered raster layers in a geostore.""" @@ -80,8 +80,7 @@ async def zonal_statistics_get( deprecated=True, ) async def zonal_statistics_post( - request: ZonalAnalysisRequestIn, - # api_key: APIKey = Depends(get_api_key) + request: ZonalAnalysisRequestIn, api_key: APIKey = Depends(get_api_key) ): return await _zonal_statistics( request.geometry, @@ -104,7 +103,7 @@ async def _zonal_statistics( if geometry.type != "Polygon" and geometry.type != "MultiPolygon": raise HTTPException( status_code=400, - detail=f"Geometry must be a Polygon or MultiPolygon for raster analysis" + detail="Geometry must be a Polygon or MultiPolygon for raster analysis", ) # OTF will just not apply a base filter diff --git a/app/routes/authentication/authentication.py b/app/routes/authentication/authentication.py index 7e9090798..9c47971a2 100644 --- a/app/routes/authentication/authentication.py +++ b/app/routes/authentication/authentication.py @@ -74,16 +74,8 @@ async def create_api_key( input_data = api_key_data.dict(by_alias=True) - origin = request.headers.get("origin") - referrer = request.headers.get("referer") - if not api_key_is_valid(input_data["domains"], origin=origin, referrer=referrer): - raise HTTPException( - status_code=400, - detail="Domain name did not match the request origin or referrer.", - ) - # Give a good error code/message if user is specifying an alias that exists for - # another one of his API keys. + # another one of their API keys. prev_keys: List[ORMApiKey] = await api_keys.get_api_keys_from_user(user_id=user.id) for key in prev_keys: if key.alias == api_key_data.alias: @@ -94,9 +86,7 @@ async def create_api_key( row: ORMApiKey = await api_keys.create_api_key(user_id=user.id, **input_data) - is_internal = api_key_is_internal( - api_key_data.domains, user_id=None, origin=origin, referrer=referrer - ) + is_internal = api_key_is_internal(api_key_data.domains) usage_plan_id = ( API_GATEWAY_INTERNAL_USAGE_PLAN if is_internal is True diff --git a/app/routes/datasets/queries.py b/app/routes/datasets/queries.py index 4e39c107c..42aebe2be 100755 --- a/app/routes/datasets/queries.py +++ b/app/routes/datasets/queries.py @@ -14,7 +14,7 @@ from fastapi import Response as FastApiResponse from fastapi.encoders import jsonable_encoder from fastapi.logger import logger -# from fastapi.openapi.models import APIKey +from fastapi.openapi.models import APIKey from fastapi.responses import RedirectResponse from pglast import printers # noqa from pglast import Node, parse_sql @@ -24,7 +24,7 @@ from ...authentication.token import is_gfwpro_admin_for_query from ...application import db -# from ...authentication.api_keys import get_api_key +from ...authentication.api_keys import get_api_key from ...crud import assets from ...models.enum.assets import AssetType from ...models.enum.creation_options import Delimiters @@ -128,7 +128,7 @@ async def query_dataset_json( GeostoreOrigin.gfw, description="Service to search first for geostore." ), is_authorized: bool = Depends(is_gfwpro_admin_for_query), - # api_key: APIKey = Depends(get_api_key), + api_key: APIKey = Depends(get_api_key), ): """Execute a READ-ONLY SQL query on the given dataset version (if implemented) and return response in JSON format. @@ -190,7 +190,7 @@ async def query_dataset_csv( Delimiters.comma, description="Delimiter to use for CSV file." ), is_authorized: bool = Depends(is_gfwpro_admin_for_query), - # api_key: APIKey = Depends(get_api_key), + api_key: APIKey = Depends(get_api_key), ): """Execute a READ-ONLY SQL query on the given dataset version (if implemented) and return response in CSV format. @@ -253,7 +253,7 @@ async def query_dataset_json_post( dataset_version: Tuple[str, str] = Depends(dataset_version_dependency), request: QueryRequestIn, is_authorized: bool = Depends(is_gfwpro_admin_for_query), - # api_key: APIKey = Depends(get_api_key), + api_key: APIKey = Depends(get_api_key), ): """Execute a READ-ONLY SQL query on the given dataset version (if implemented).""" @@ -284,7 +284,7 @@ async def query_dataset_csv_post( dataset_version: Tuple[str, str] = Depends(dataset_version_dependency), request: CsvQueryRequestIn, is_authorized: bool = Depends(is_gfwpro_admin_for_query), - # api_key: APIKey = Depends(get_api_key), + api_key: APIKey = Depends(get_api_key), ): """Execute a READ-ONLY SQL query on the given dataset version (if implemented).""" @@ -595,7 +595,7 @@ async def _query_raster( if geostore.geojson.type != "Polygon" and geostore.geojson.type != "MultiPolygon": raise HTTPException( status_code=400, - detail=f"Geostore must be a Polygon or MultiPolygon for raster analysis" + detail="Geostore must be a Polygon or MultiPolygon for raster analysis", ) # use default data type to get default raster layer for dataset diff --git a/app/settings/globals.py b/app/settings/globals.py index 018daa267..9210e261f 100644 --- a/app/settings/globals.py +++ b/app/settings/globals.py @@ -178,6 +178,7 @@ "api.resourcewatch.org", "my.gfw-mapbuilder.org", "resourcewatch.org", + "*.wri.org", ] ) diff --git a/terraform/modules/api_gateway/gateway/main.tf b/terraform/modules/api_gateway/gateway/main.tf index 2dcb9d63e..ffbed1516 100644 --- a/terraform/modules/api_gateway/gateway/main.tf +++ b/terraform/modules/api_gateway/gateway/main.tf @@ -46,9 +46,9 @@ module "query_get" { authorizer_id = aws_api_gateway_authorizer.api_key.id api_resource = module.query_resource.aws_api_gateway_resource - require_api_key = false + require_api_key = true http_method = "GET" - authorization = "NONE" + authorization = "CUSTOM" integration_parameters = { "integration.request.path.version" = "method.request.path.version" @@ -73,9 +73,9 @@ module "query_post" { authorizer_id = aws_api_gateway_authorizer.api_key.id api_resource = module.query_resource.aws_api_gateway_resource - require_api_key = false + require_api_key = true http_method = "POST" - authorization = "NONE" + authorization = "CUSTOM" integration_parameters = { "integration.request.path.version" = "method.request.path.version" @@ -180,13 +180,6 @@ resource "aws_api_gateway_usage_plan" "internal" { burst_limit = var.api_gateway_usage_plans.internal_apps.burst_limit rate_limit = var.api_gateway_usage_plans.internal_apps.rate_limit } - - # terraform doesn't expose API Gateway's method level throttling so will do that - # manually and this will stop terraform from destroying the manual changes - # Open PR to add the feature to terraform: https://github.com/hashicorp/terraform-provider-aws/pull/20672 - lifecycle { - ignore_changes = all - } } resource "aws_api_gateway_usage_plan" "external" { @@ -206,14 +199,6 @@ resource "aws_api_gateway_usage_plan" "external" { burst_limit = var.api_gateway_usage_plans.external_apps.burst_limit rate_limit = var.api_gateway_usage_plans.external_apps.rate_limit } - - # terraform doesn't expose API Gateway's method level throttling so will do that - # manually and this will stop terraform from destroying the manual changes - # Open PR to add the feature to terraform: https://github.com/hashicorp/terraform-provider-aws/pull/20672 - lifecycle { - ignore_changes = all - } - } resource "aws_api_gateway_deployment" "api_gw_dep" { diff --git a/terraform/modules/api_gateway/gateway/variables.tf b/terraform/modules/api_gateway/gateway/variables.tf index 3f820a5cf..ae8a468b0 100644 --- a/terraform/modules/api_gateway/gateway/variables.tf +++ b/terraform/modules/api_gateway/gateway/variables.tf @@ -51,14 +51,14 @@ variable "api_gateway_usage_plans" { description = "Throttling limits for API Gateway" default = { internal_apps = { - quota_limit = 10000 # per day - burst_limit = 100 # per second - rate_limit = 200 + quota_limit = 500000 # per day + burst_limit = 1000 + rate_limit = 200 # per second } external_apps = { - quota_limit = 500 - burst_limit = 10 - rate_limit = 20 + quota_limit = 10000 + burst_limit = 20 + rate_limit = 10 } } } diff --git a/terraform/modules/api_gateway/resource/main.tf b/terraform/modules/api_gateway/resource/main.tf index 6eba66851..7c1023ae6 100644 --- a/terraform/modules/api_gateway/resource/main.tf +++ b/terraform/modules/api_gateway/resource/main.tf @@ -17,8 +17,8 @@ resource "aws_api_gateway_integration" "get_endpoint_integration" { http_method = aws_api_gateway_method.get_endpoint_method.http_method type = "MOCK" - passthrough_behavior = "WHEN_NO_MATCH" - request_templates = { + passthrough_behavior = "WHEN_NO_MATCH" + request_templates = { "application/json" : <