diff --git a/app/authentication/token.py b/app/authentication/token.py index 5e7cc24de..c622a82d7 100644 --- a/app/authentication/token.py +++ b/app/authentication/token.py @@ -1,10 +1,11 @@ -from typing import Tuple, cast +from typing import cast from fastapi import Depends, HTTPException from fastapi.logger import logger from fastapi.security import OAuth2PasswordBearer from httpx import Response +from ..models.pydantic.authentication import User from ..routes import dataset_dependency from ..settings.globals import PROTECTED_QUERY_DATASETS from ..utils.rw_api import who_am_i @@ -42,12 +43,6 @@ async def is_admin(token: str = Depends(oauth2_scheme)) -> bool: return await is_app_admin(token, "gfw", "Unauthorized") -async def rw_user_id(token: str = Depends(oauth2_scheme)) -> str: - """Gets user ID from token.""" - - return await who_am_i(token).json()["id"] - - async def is_gfwpro_admin_for_query( dataset: str = Depends(dataset_dependency), token: str | None = Depends(oauth2_scheme_no_auto), @@ -93,19 +88,34 @@ async def is_app_admin(token: str, app: str, error_str: str) -> bool: return True -async def get_user(token: str = Depends(oauth2_scheme)) -> Tuple[str, str]: - """Calls GFW API to authorize user. - - This functions check is user of any level is associated with the GFW - app and returns the user ID - """ +async def get_user(token: str = Depends(oauth2_scheme)) -> User: + """Get the details for authenticated user.""" response: Response = await who_am_i(token) - if response.status_code == 401 or not ( - "gfw" in response.json()["extraUserData"]["apps"] - ): + if response.status_code == 401: logger.info("Unauthorized user") raise HTTPException(status_code=401, detail="Unauthorized") else: - return response.json()["id"], response.json()["role"] + return User(**response.json()) + + +async def get_admin(user: User = Depends(get_user)) -> User: + """Get the details for authenticated ADMIN user.""" + + if user.role != "ADMIN": + raise HTTPException(status_code=401, detail="Unauthorized") + + return user + + +async def get_manager(user: User = Depends(get_user)) -> User: + """Get the details for authenticated MANAGER for data-api application or + ADMIN user.""" + + if user.role != "ADMIN" or not ( + user.role == "MANAGER" and "data-api" in user.applications + ): + raise HTTPException(status_code=401, detail="Unauthorized") + + return user diff --git a/app/models/pydantic/authentication.py b/app/models/pydantic/authentication.py index e218bc872..61be0cf5c 100644 --- a/app/models/pydantic/authentication.py +++ b/app/models/pydantic/authentication.py @@ -14,17 +14,18 @@ class SignUpRequestIn(StrictBaseModel): email: EmailStr = Query(..., description="User's email address") -class SignUp(StrictBaseModel): +class User(StrictBaseModel): id: str name: str email: EmailStr createdAt: datetime role: str + applications: List[str] extraUserData: Dict[str, Any] class SignUpResponse(Response): - data: SignUp + data: User class APIKeyRequestIn(StrictBaseModel): diff --git a/app/routes/authentication/authentication.py b/app/routes/authentication/authentication.py index 1014868c8..7e9090798 100644 --- a/app/routes/authentication/authentication.py +++ b/app/routes/authentication/authentication.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request @@ -18,6 +18,7 @@ ApiKeyValidationResponse, SignUpRequestIn, SignUpResponse, + User, ) from ...models.pydantic.responses import Response from ...settings.globals import ( @@ -58,19 +59,17 @@ async def get_token(form_data: OAuth2PasswordRequestForm = Depends()): async def create_api_key( api_key_data: APIKeyRequestIn, request: Request, - user: Tuple[str, str] = Depends(get_user), + user: User = Depends(get_user), ): """Request a new API key. Default keys are valid for one year """ - user_id, user_role = user - - if api_key_data.never_expires and user_role != "ADMIN": + if api_key_data.never_expires and user.role != "ADMIN": raise HTTPException( status_code=400, - detail=f"Users with role {user_role} cannot set `never_expires` to True.", + detail=f"Users with role {user.role} cannot set `never_expires` to True.", ) input_data = api_key_data.dict(by_alias=True) @@ -85,15 +84,15 @@ async def create_api_key( # Give a good error code/message if user is specifying an alias that exists for # another one of his API keys. - prev_keys: List[ORMApiKey] = await api_keys.get_api_keys_from_user(user_id=user_id) + prev_keys: List[ORMApiKey] = await api_keys.get_api_keys_from_user(user_id=user.id) for key in prev_keys: if key.alias == api_key_data.alias: raise HTTPException( status_code=409, - detail="Key with specified alias already exists; use a different alias" + detail="Key with specified alias already exists; use a different alias", ) - row: ORMApiKey = await api_keys.create_api_key(user_id=user_id, **input_data) + row: ORMApiKey = await api_keys.create_api_key(user_id=user.id, **input_data) is_internal = api_key_is_internal( api_key_data.domains, user_id=None, origin=origin, referrer=referrer @@ -117,19 +116,19 @@ async def create_api_key( @router.get("/apikey/{api_key}", tags=["Authentication"]) async def get_api_key( api_key: UUID = Path(..., description="API Key"), - user: Tuple[str, str] = Depends(get_user), + user: User = Depends(get_user), ): """Get details for a specific API Key. User must own API Key or must be Admin to see details. """ - user_id, role = user + try: row: ORMApiKey = await api_keys.get_api_key(api_key) except RecordNotFoundError: raise HTTPException(status_code=404, detail="The API Key does not exist.") - if role != "ADMIN" and row.user_id != user_id: + if user.role != "ADMIN" and row.user_id != user.id: raise HTTPException( status_code=403, detail="API Key is not associated with current user." ) @@ -141,14 +140,13 @@ async def get_api_key( @router.get("/apikeys", tags=["Authentication"]) async def get_api_keys( - user: Tuple[str, str] = Depends(get_user), + user: User = Depends(get_user), ): """Request a new API key. Default keys are valid for one year """ - user_id, _ = user - rows: List[ORMApiKey] = await api_keys.get_api_keys_from_user(user_id) + rows: List[ORMApiKey] = await api_keys.get_api_keys_from_user(user.id) data = [ApiKey.from_orm(row) for row in rows] return ApiKeysResponse(data=data) @@ -184,13 +182,12 @@ async def delete_api_key( api_key: UUID = Path( ..., description="Api Key to delete. Must be owned by authenticated user." ), - user: Tuple[str, str] = Depends(get_user), + user: User = Depends(get_user), ): """Delete existing API key. API Key must belong to user. """ - user_id, _ = user try: row: ORMApiKey = await api_keys.get_api_key(api_key) except RecordNotFoundError: @@ -199,7 +196,7 @@ async def delete_api_key( ) # TODO: we might want to allow admins to delete api keys of other users? - if not row.user_id == user_id: + if not row.user_id == user.id: raise HTTPException( status_code=403, detail="The requested API key does not belong to the current user.", diff --git a/app/routes/datasets/dataset.py b/app/routes/datasets/dataset.py index d56d7d621..d3ee06d8d 100644 --- a/app/routes/datasets/dataset.py +++ b/app/routes/datasets/dataset.py @@ -1,6 +1,5 @@ """Datasets are just a bucket, for datasets which share the same core metadata.""" - from typing import Any, Dict, List from fastapi import APIRouter, Depends, HTTPException, Response @@ -8,11 +7,12 @@ from sqlalchemy.schema import CreateSchema, DropSchema from ...application import db -from ...authentication.token import is_admin, rw_user_id +from ...authentication.token import get_manager from ...crud import datasets, versions from ...errors import RecordAlreadyExistsError, RecordNotFoundError from ...models.orm.datasets import Dataset as ORMDataset from ...models.orm.versions import Version as ORMVersion +from ...models.pydantic.authentication import User from ...models.pydantic.datasets import ( Dataset, DatasetCreateIn, @@ -25,6 +25,19 @@ router = APIRouter() +async def get_owner( + dataset: str = Depends(dataset_dependency), user: User = Depends(get_manager) +) -> User: + """Retrieves the user object that owns the dataset if that user is the one + making the request, otherwise raises a 401.""" + + dataset_row: ORMDataset = await datasets.get_dataset(dataset) + owner: str = dataset_row.owner_id + if owner != user.id: + raise HTTPException(status_code=401, detail="Unauthorized") + return user + + @router.get( "/{dataset}", response_class=ORJSONResponse, @@ -52,21 +65,21 @@ async def create_dataset( *, dataset: str = Depends(dataset_dependency), request: DatasetCreateIn, - is_authorized: bool = Depends(is_admin), - owner_id: str = Depends(rw_user_id), + user: User = Depends(get_manager), response: Response, ) -> DatasetResponse: - """Create a dataset. A “dataset” is largely a metadata concept: it represents - a data product that may have multiple versions or file formats over time. - The user that creates a dataset using this operation becomes the owner of - the dataset, which provides the user with the privileges to do further write - operations on the dataset, including creating and modifying versions and assets. + """Create a dataset. A “dataset” is largely a metadata concept: it + represents a data product that may have multiple versions or file formats + over time. The user that creates a dataset using this operation becomes the + owner of the dataset, which provides the user with the privileges to do + further write operations on the dataset, including creating and modifying + versions and assets. This operation requires a `MANAGER` or an `ADMIN` user role. """ input_data: Dict = request.dict(exclude_none=True, by_alias=True) - input_data["owner_id"] = owner_id + input_data["owner_id"] = user.id try: new_dataset: ORMDataset = await datasets.create_dataset(dataset, **input_data) @@ -93,7 +106,7 @@ async def update_dataset( *, dataset: str = Depends(dataset_dependency), request: DatasetUpdateIn, - is_authorized: bool = Depends(is_admin), + is_authorized: User = Depends(get_owner), ) -> DatasetResponse: """Partially update a dataset. @@ -117,7 +130,7 @@ async def update_dataset( async def delete_dataset( *, dataset: str = Depends(dataset_dependency), - is_authorized: bool = Depends(is_admin), + is_authorized: User = Depends(get_owner), ) -> DatasetResponse: """Delete a dataset. diff --git a/app/utils/rw_api.py b/app/utils/rw_api.py index 22ff7effc..d580b7a69 100644 --- a/app/utils/rw_api.py +++ b/app/utils/rw_api.py @@ -12,7 +12,7 @@ RecordNotFoundError, UnauthorizedError, ) -from ..models.pydantic.authentication import SignUp +from ..models.pydantic.authentication import User from ..models.pydantic.geostore import Geometry, GeostoreCommon from ..settings.globals import RW_API_URL @@ -116,7 +116,7 @@ async def login(user_name: str, password: str) -> str: return response.json()["data"]["token"] -async def signup(name: str, email: str) -> SignUp: +async def signup(name: str, email: str) -> User: """Obtain a token form RW API using given user name and password.""" headers = {"Content-Type": "application/json"} @@ -153,4 +153,4 @@ async def signup(name: str, email: str) -> SignUp: detail="An error occurred while trying to create a new user account. Please try again.", ) - return SignUp(**response.json()["data"]) + return User(**response.json()["data"]) diff --git a/tests/__init__.py b/tests/__init__.py index e268bda67..16b498565 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -12,6 +12,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import Session, sessionmaker +from app.models.pydantic.authentication import User from app.settings.globals import ( AWS_REGION, DATA_LAKE_BUCKET, @@ -51,8 +52,6 @@ BUCKET = "test-bucket" PORT = 9000 -RW_USER_ID = "5874bfcca049b7a56ad42771" # pragma: allowlist secret - SessionLocal: Optional[Session] = None @@ -315,8 +314,16 @@ async def get_api_key_mocked() -> Tuple[Optional[str], Optional[str]]: return str(uuid.uuid4()), "localhost" -async def get_rw_user_id() -> str: - return RW_USER_ID +async def get_manager_mocked() -> User: + return User( + id="mr_manager123", + name="Mr. Manager", + email="mr_manager@management.com", + createdAt="2021-06-13T03:18:23.000Z", + role="MANAGER", + applications=["data-api"], + extraUserData={}, + ) def setup_clients(ec2_client, iam_client): diff --git a/tests/conftest.py b/tests/conftest.py index efb5801f6..8d66ad9bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ from httpx import AsyncClient from app.authentication.api_keys import get_api_key -from app.authentication.token import is_admin, is_service_account, rw_user_id +from app.authentication.token import get_manager, is_admin, is_service_account from app.settings.globals import ( AURORA_JOB_QUEUE, AURORA_JOB_QUEUE_FAST, @@ -52,7 +52,7 @@ AWSMock, MemoryServer, get_api_key_mocked, - get_rw_user_id, + get_manager_mocked, is_admin_mocked, is_service_account_mocked, setup_clients, @@ -252,7 +252,7 @@ async def async_client(): app.dependency_overrides[is_admin] = is_admin_mocked app.dependency_overrides[is_service_account] = is_service_account_mocked app.dependency_overrides[get_api_key] = get_api_key_mocked - app.dependency_overrides[rw_user_id] = get_rw_user_id + app.dependency_overrides[get_manager] = get_manager_mocked async with AsyncClient(app=app, base_url="http://test", trust_env=False) as client: yield client diff --git a/tests/routes/datasets/test_datasets.py b/tests/routes/datasets/test_datasets.py index f823a9798..5e149ec1e 100644 --- a/tests/routes/datasets/test_datasets.py +++ b/tests/routes/datasets/test_datasets.py @@ -3,7 +3,6 @@ import pytest from app.application import ContextEngine, db -from tests import RW_USER_ID from tests.utils import create_default_asset, dataset_metadata payload = {"metadata": dataset_metadata} @@ -69,7 +68,7 @@ async def test_datasets(async_client): ) assert len(rows) == 1 - assert rows[0][0] == RW_USER_ID + assert rows[0][0] == "mr_manager123" new_payload = {"metadata": {"title": "New Title"}} response = await async_client.patch(f"/dataset/{dataset}", json=new_payload) diff --git a/tests_v2/conftest.py b/tests_v2/conftest.py index bd58a3def..4a9fefce0 100755 --- a/tests_v2/conftest.py +++ b/tests_v2/conftest.py @@ -11,7 +11,7 @@ from fastapi.testclient import TestClient from httpx import AsyncClient -from app.authentication.token import get_user, is_admin, is_service_account, rw_user_id +from app.authentication.token import get_manager, get_user, is_admin, is_service_account from app.crud import api_keys from app.models.enum.change_log import ChangeLogStatus from app.models.pydantic.change_log import ChangeLog @@ -31,7 +31,7 @@ dict_function_closure, get_admin_mocked, get_extent_mocked, - get_rw_user_id_mocked, + get_manager_mocked, get_user_mocked, int_function_closure, void_coroutine, @@ -81,7 +81,7 @@ async def async_client(db, init_db) -> AsyncGenerator[AsyncClient, None]: True, with_args=False ) app.dependency_overrides[get_user] = get_admin_mocked - app.dependency_overrides[rw_user_id] = get_rw_user_id_mocked + app.dependency_overrides[get_manager] = get_manager_mocked async with AsyncClient( app=app, diff --git a/tests_v2/unit/app/routes/datasets/test_dataset.py b/tests_v2/unit/app/routes/datasets/test_dataset.py index 38a59038a..5b0e97941 100644 --- a/tests_v2/unit/app/routes/datasets/test_dataset.py +++ b/tests_v2/unit/app/routes/datasets/test_dataset.py @@ -1,16 +1,138 @@ -from typing import Any, Dict, Tuple +from typing import Tuple import pytest +from fastapi.exceptions import HTTPException from httpx import AsyncClient -from app.models.pydantic.datasets import DatasetResponse, Dataset -from app.models.pydantic.metadata import DatasetMetadata -from tests_v2.unit.app.routes.utils import assert_jsend +from app.authentication.token import get_manager +from app.models.pydantic.datasets import DatasetResponse +from app.routes.datasets.dataset import get_owner from tests_v2.fixtures.metadata.dataset import DATASET_METADATA +from tests_v2.unit.app.routes.utils import assert_jsend +from tests_v2.utils import ( + get_admin_mocked, + get_manager_mocked, + get_user_mocked, + raises_401, +) + + +@pytest.mark.asyncio +async def test_get_owner_fail(db, init_db, monkeypatch) -> None: + dataset_name: str = "my_first_dataset" + + from app.main import app + + # Create a dataset + app.dependency_overrides[get_manager] = get_manager_mocked + + async with AsyncClient( + app=app, + base_url="http://test", + trust_env=False, + headers={"Origin": "https://www.globalforestwatch.org"}, + ) as async_client: + create_resp = await async_client.put( + f"/dataset/{dataset_name}", json={"metadata": {}} + ) + assert create_resp.status_code == 201 + + app.dependency_overrides = {} + + some_user = await get_user_mocked() + + try: + _ = await get_owner(dataset_name, some_user) + except HTTPException as e: + assert e.status_code == 401 + assert e.detail == "Unauthorized" + + +async def test_get_owner_manager_success(db, init_db, monkeypatch) -> None: + dataset_name: str = "my_first_dataset" + + from app.main import app + + # Create a dataset + app.dependency_overrides[get_manager] = get_manager_mocked + + async with AsyncClient( + app=app, + base_url="http://test", + trust_env=False, + headers={"Origin": "https://www.globalforestwatch.org"}, + ) as async_client: + create_resp = await async_client.put( + f"/dataset/{dataset_name}", json={"metadata": {}} + ) + assert create_resp.status_code == 201 + + app.dependency_overrides = {} + + some_manager = await get_manager_mocked() + + _ = await get_owner(dataset_name, some_manager) + + +async def test_get_owner_different_manager_fail(db, init_db, monkeypatch) -> None: + dataset_name: str = "my_first_dataset" + + from app.main import app + + # Create a dataset + app.dependency_overrides[get_manager] = get_manager_mocked + + async with AsyncClient( + app=app, + base_url="http://test", + trust_env=False, + headers={"Origin": "https://www.globalforestwatch.org"}, + ) as async_client: + create_resp = await async_client.put( + f"/dataset/{dataset_name}", json={"metadata": {}} + ) + assert create_resp.status_code == 201 + + app.dependency_overrides = {} + + some_manager = await get_manager_mocked() + some_manager.id = "Some other manager" + + try: + _ = await get_owner(dataset_name, some_manager) + except HTTPException as e: + assert e.status_code == 401 + assert e.detail == "Unauthorized" + + +async def test_get_owner_admin_success(db, init_db, monkeypatch) -> None: + dataset_name: str = "my_first_dataset" + + from app.main import app + + # Create a dataset + app.dependency_overrides[get_manager] = get_admin_mocked + + async with AsyncClient( + app=app, + base_url="http://test", + trust_env=False, + headers={"Origin": "https://www.globalforestwatch.org"}, + ) as async_client: + create_resp = await async_client.put( + f"/dataset/{dataset_name}", json={"metadata": {}} + ) + assert create_resp.status_code == 201 + + app.dependency_overrides = {} + + some_admin = await get_admin_mocked() + + _ = await get_owner(dataset_name, some_admin) @pytest.mark.asyncio -async def test_get_dataset( +async def test_get_dataset_success( async_client: AsyncClient, generic_dataset: Tuple[str, str] ) -> None: dataset_name, _ = generic_dataset @@ -19,6 +141,15 @@ async def test_get_dataset( _validate_dataset_response(resp.json(), dataset_name) +@pytest.mark.asyncio +async def test_get_dataset_fail( + async_client: AsyncClient, generic_dataset: Tuple[str, str] +) -> None: + dataset_name: str = "not_a_real_dataset" + resp = await async_client.get(f"/dataset/{dataset_name}") + assert resp.status_code == 404 + + # TODO: Use mark.parameterize to test variations @pytest.mark.asyncio async def test_create_dataset(async_client: AsyncClient) -> None: @@ -35,8 +166,75 @@ def test_update_dataset(): pass -def test_delete_dataset(): - pass +@pytest.mark.asyncio +async def test_delete_dataset_requires_creds_fail(db, init_db) -> None: + dataset_name: str = "my_first_dataset" + + from app.main import app + + # Create a dataset + app.dependency_overrides[get_manager] = get_manager_mocked + + async with AsyncClient( + app=app, + base_url="http://test", + trust_env=False, + headers={"Origin": "https://www.globalforestwatch.org"}, + ) as async_client: + create_resp = await async_client.put( + f"/dataset/{dataset_name}", json={"metadata": DATASET_METADATA} + ) + assert create_resp.status_code == 201 + + app.dependency_overrides = {} + app.dependency_overrides[get_owner] = raises_401 + + async with AsyncClient( + app=app, + base_url="http://test", + trust_env=False, + headers={"Origin": "https://www.globalforestwatch.org"}, + ) as async_client: + delete_resp = await async_client.delete(f"/dataset/{dataset_name}") + assert delete_resp.json()["message"] == "Unauthorized" + assert delete_resp.status_code == 401 + + app.dependency_overrides = {} + + +@pytest.mark.asyncio +async def test_delete_dataset_requires_creds_succeed(db, init_db, monkeypatch) -> None: + dataset_name: str = "my_first_dataset" + + from app.main import app + + # Create a dataset + app.dependency_overrides[get_manager] = get_manager_mocked + + async with AsyncClient( + app=app, + base_url="http://test", + trust_env=False, + headers={"Origin": "https://www.globalforestwatch.org"}, + ) as async_client: + create_resp = await async_client.put( + f"/dataset/{dataset_name}", json={"metadata": DATASET_METADATA} + ) + assert create_resp.status_code == 201 + + app.dependency_overrides = {} + app.dependency_overrides[get_owner] = get_manager_mocked + + async with AsyncClient( + app=app, + base_url="http://test", + trust_env=False, + headers={"Origin": "https://www.globalforestwatch.org"}, + ) as async_client: + delete_resp = await async_client.delete(f"/dataset/{dataset_name}") + assert delete_resp.status_code == 200 + + app.dependency_overrides = {} def test__dataset_response(): diff --git a/tests_v2/utils.py b/tests_v2/utils.py index 42a6b34ae..943e7cdd5 100644 --- a/tests_v2/utils.py +++ b/tests_v2/utils.py @@ -6,8 +6,10 @@ import httpx from _pytest.monkeypatch import MonkeyPatch +from fastapi.exceptions import HTTPException from app.application import ContextEngine +from app.models.pydantic.authentication import User from app.models.pydantic.extent import Extent from app.routes.datasets import versions from app.tasks import batch, delete_assets @@ -31,16 +33,40 @@ def submit_batch_job(self, *args, **kwargs) -> uuid.UUID: return job_id -async def get_user_mocked() -> Tuple[str, str]: - return "userid_123", "USER" +async def get_user_mocked() -> User: + return User( + id="userid_123", + name="Ms. User", + email="ms_user@user.com", + createdAt="2021-06-13T03:18:23.000Z", + role="USER", + applications=[], + extraUserData={}, + ) -async def get_admin_mocked() -> Tuple[str, str]: - return "adminid_123", "ADMIN" +async def get_admin_mocked() -> User: + return User( + id="adminid_123", + name="Sir Admin", + email="sir_admin@admin.com", + createdAt="2021-06-13T03:18:23.000Z", + role="ADMIN", + applications=[], + extraUserData={}, + ) -async def get_rw_user_id_mocked() -> str: - return "userid_123" +async def get_manager_mocked() -> User: + return User( + id="mr_manager123", + name="Mr. Manager", + email="mr_manager@management.com", + createdAt="2021-06-13T03:18:23.000Z", + role="MANAGER", + applications=["data-api"], + extraUserData={}, + ) async def get_api_key_mocked() -> Tuple[Optional[str], Optional[str]]: @@ -148,3 +174,7 @@ async def custom_raster_version( yield version_name finally: pass + + +async def raises_401() -> None: + raise HTTPException(status_code=401, detail="Unauthorized")