Skip to content

Commit

Permalink
Merge pull request #509 from wri/feature/user_role
Browse files Browse the repository at this point in the history
Pass user details when authenticating instead of a bool, check for manager on dataset creation
  • Loading branch information
jterry64 authored May 9, 2024
2 parents 3843ed5 + 4b5b704 commit 7213ac6
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 65 deletions.
42 changes: 25 additions & 17 deletions app/authentication/token.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Tuple, cast
from typing import cast

from fastapi import Depends, HTTPException
from fastapi.logger import logger
from fastapi.security import OAuth2PasswordBearer
from httpx import Response

from ..models.pydantic.authentication import User
from ..routes import dataset_dependency
from ..settings.globals import PROTECTED_QUERY_DATASETS
from ..utils.rw_api import who_am_i
Expand Down Expand Up @@ -42,12 +43,6 @@ async def is_admin(token: str = Depends(oauth2_scheme)) -> bool:
return await is_app_admin(token, "gfw", "Unauthorized")


async def rw_user_id(token: str = Depends(oauth2_scheme)) -> str:
"""Gets user ID from token."""

return await who_am_i(token).json()["id"]


async def is_gfwpro_admin_for_query(
dataset: str = Depends(dataset_dependency),
token: str | None = Depends(oauth2_scheme_no_auto),
Expand Down Expand Up @@ -93,19 +88,32 @@ async def is_app_admin(token: str, app: str, error_str: str) -> bool:
return True


async def get_user(token: str = Depends(oauth2_scheme)) -> Tuple[str, str]:
"""Calls GFW API to authorize user.
This functions check is user of any level is associated with the GFW
app and returns the user ID
"""
async def get_user(token: str = Depends(oauth2_scheme)) -> User:
"""Get the details for authenticated user."""

response: Response = await who_am_i(token)

if response.status_code == 401 or not (
"gfw" in response.json()["extraUserData"]["apps"]
):
if response.status_code == 401:
logger.info("Unauthorized user")
raise HTTPException(status_code=401, detail="Unauthorized")
else:
return response.json()["id"], response.json()["role"]
return User(**response.json())


async def get_admin(user: User = Depends(get_user)) -> User:
"""Get the details for authenticated ADMIN user."""

if user.role != "ADMIN":
raise HTTPException(status_code=401, detail="Unauthorized")

return user


async def get_manager(user: User = Depends(get_user)) -> User:
"""Get the details for authenticated MANAGER for data-api application or
ADMIN user."""

if user.role != "ADMIN" or user.role != "MANAGER":
raise HTTPException(status_code=401, detail="Unauthorized")

return user
5 changes: 3 additions & 2 deletions app/models/pydantic/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@ class SignUpRequestIn(StrictBaseModel):
email: EmailStr = Query(..., description="User's email address")


class SignUp(StrictBaseModel):
class User(StrictBaseModel):
id: str
name: str
email: EmailStr
createdAt: datetime
role: str
applications: List[str]
extraUserData: Dict[str, Any]


class SignUpResponse(Response):
data: SignUp
data: User


class APIKeyRequestIn(StrictBaseModel):
Expand Down
33 changes: 15 additions & 18 deletions app/routes/authentication/authentication.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request
Expand All @@ -18,6 +18,7 @@
ApiKeyValidationResponse,
SignUpRequestIn,
SignUpResponse,
User,
)
from ...models.pydantic.responses import Response
from ...settings.globals import (
Expand Down Expand Up @@ -58,19 +59,17 @@ async def get_token(form_data: OAuth2PasswordRequestForm = Depends()):
async def create_api_key(
api_key_data: APIKeyRequestIn,
request: Request,
user: Tuple[str, str] = Depends(get_user),
user: User = Depends(get_user),
):
"""Request a new API key.
Default keys are valid for one year
"""

user_id, user_role = user

if api_key_data.never_expires and user_role != "ADMIN":
if api_key_data.never_expires and user.role != "ADMIN":
raise HTTPException(
status_code=400,
detail=f"Users with role {user_role} cannot set `never_expires` to True.",
detail=f"Users with role {user.role} cannot set `never_expires` to True.",
)

input_data = api_key_data.dict(by_alias=True)
Expand All @@ -85,15 +84,15 @@ async def create_api_key(

# Give a good error code/message if user is specifying an alias that exists for
# another one of his API keys.
prev_keys: List[ORMApiKey] = await api_keys.get_api_keys_from_user(user_id=user_id)
prev_keys: List[ORMApiKey] = await api_keys.get_api_keys_from_user(user_id=user.id)
for key in prev_keys:
if key.alias == api_key_data.alias:
raise HTTPException(
status_code=409,
detail="Key with specified alias already exists; use a different alias"
detail="Key with specified alias already exists; use a different alias",
)

row: ORMApiKey = await api_keys.create_api_key(user_id=user_id, **input_data)
row: ORMApiKey = await api_keys.create_api_key(user_id=user.id, **input_data)

is_internal = api_key_is_internal(
api_key_data.domains, user_id=None, origin=origin, referrer=referrer
Expand All @@ -117,19 +116,19 @@ async def create_api_key(
@router.get("/apikey/{api_key}", tags=["Authentication"])
async def get_api_key(
api_key: UUID = Path(..., description="API Key"),
user: Tuple[str, str] = Depends(get_user),
user: User = Depends(get_user),
):
"""Get details for a specific API Key.
User must own API Key or must be Admin to see details.
"""
user_id, role = user

try:
row: ORMApiKey = await api_keys.get_api_key(api_key)
except RecordNotFoundError:
raise HTTPException(status_code=404, detail="The API Key does not exist.")

if role != "ADMIN" and row.user_id != user_id:
if user.role != "ADMIN" and row.user_id != user.id:
raise HTTPException(
status_code=403, detail="API Key is not associated with current user."
)
Expand All @@ -141,14 +140,13 @@ async def get_api_key(

@router.get("/apikeys", tags=["Authentication"])
async def get_api_keys(
user: Tuple[str, str] = Depends(get_user),
user: User = Depends(get_user),
):
"""Request a new API key.
Default keys are valid for one year
"""
user_id, _ = user
rows: List[ORMApiKey] = await api_keys.get_api_keys_from_user(user_id)
rows: List[ORMApiKey] = await api_keys.get_api_keys_from_user(user.id)
data = [ApiKey.from_orm(row) for row in rows]

return ApiKeysResponse(data=data)
Expand Down Expand Up @@ -184,13 +182,12 @@ async def delete_api_key(
api_key: UUID = Path(
..., description="Api Key to delete. Must be owned by authenticated user."
),
user: Tuple[str, str] = Depends(get_user),
user: User = Depends(get_user),
):
"""Delete existing API key.
API Key must belong to user.
"""
user_id, _ = user
try:
row: ORMApiKey = await api_keys.get_api_key(api_key)
except RecordNotFoundError:
Expand All @@ -199,7 +196,7 @@ async def delete_api_key(
)

# TODO: we might want to allow admins to delete api keys of other users?
if not row.user_id == user_id:
if not row.user_id == user.id:
raise HTTPException(
status_code=403,
detail="The requested API key does not belong to the current user.",
Expand Down
19 changes: 10 additions & 9 deletions app/routes/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from sqlalchemy.schema import CreateSchema, DropSchema

from ...application import db
from ...authentication.token import is_admin, rw_user_id
from ...authentication.token import get_manager, is_admin
from ...crud import datasets, versions
from ...errors import RecordAlreadyExistsError, RecordNotFoundError
from ...models.orm.datasets import Dataset as ORMDataset
from ...models.orm.versions import Version as ORMVersion
from ...models.pydantic.authentication import User
from ...models.pydantic.datasets import (
Dataset,
DatasetCreateIn,
Expand Down Expand Up @@ -52,21 +53,21 @@ async def create_dataset(
*,
dataset: str = Depends(dataset_dependency),
request: DatasetCreateIn,
is_authorized: bool = Depends(is_admin),
owner_id: str = Depends(rw_user_id),
user: User = Depends(get_manager),
response: Response,
) -> DatasetResponse:
"""Create a dataset. A “dataset” is largely a metadata concept: it represents
a data product that may have multiple versions or file formats over time.
The user that creates a dataset using this operation becomes the owner of
the dataset, which provides the user with the privileges to do further write
operations on the dataset, including creating and modifying versions and assets.
"""Create a dataset. A “dataset” is largely a metadata concept: it
represents a data product that may have multiple versions or file formats
over time. The user that creates a dataset using this operation becomes the
owner of the dataset, which provides the user with the privileges to do
further write operations on the dataset, including creating and modifying
versions and assets.
This operation requires a `MANAGER` or an `ADMIN` user role.
"""

input_data: Dict = request.dict(exclude_none=True, by_alias=True)
input_data["owner_id"] = owner_id
input_data["owner_id"] = user.id

try:
new_dataset: ORMDataset = await datasets.create_dataset(dataset, **input_data)
Expand Down
6 changes: 3 additions & 3 deletions app/utils/rw_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
RecordNotFoundError,
UnauthorizedError,
)
from ..models.pydantic.authentication import SignUp
from ..models.pydantic.authentication import User
from ..models.pydantic.geostore import Geometry, GeostoreCommon
from ..settings.globals import RW_API_URL

Expand Down Expand Up @@ -116,7 +116,7 @@ async def login(user_name: str, password: str) -> str:
return response.json()["data"]["token"]


async def signup(name: str, email: str) -> SignUp:
async def signup(name: str, email: str) -> User:
"""Obtain a token form RW API using given user name and password."""

headers = {"Content-Type": "application/json"}
Expand Down Expand Up @@ -153,4 +153,4 @@ async def signup(name: str, email: str) -> SignUp:
detail="An error occurred while trying to create a new user account. Please try again.",
)

return SignUp(**response.json()["data"])
return User(**response.json()["data"])
15 changes: 11 additions & 4 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

from app.models.pydantic.authentication import User
from app.settings.globals import (
AWS_REGION,
DATA_LAKE_BUCKET,
Expand Down Expand Up @@ -51,8 +52,6 @@
BUCKET = "test-bucket"
PORT = 9000

RW_USER_ID = "5874bfcca049b7a56ad42771" # pragma: allowlist secret

SessionLocal: Optional[Session] = None


Expand Down Expand Up @@ -315,8 +314,16 @@ async def get_api_key_mocked() -> Tuple[Optional[str], Optional[str]]:
return str(uuid.uuid4()), "localhost"


async def get_rw_user_id() -> str:
return RW_USER_ID
async def get_manager_mocked() -> User:
return User(
id="mr_manager123",
name="Mr. Manager",
email="[email protected]",
createdAt="2021-06-13T03:18:23.000Z",
role="MANAGER",
applications=[],
extraUserData={},
)


def setup_clients(ec2_client, iam_client):
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from httpx import AsyncClient

from app.authentication.api_keys import get_api_key
from app.authentication.token import is_admin, is_service_account, rw_user_id
from app.authentication.token import get_manager, is_admin, is_service_account
from app.settings.globals import (
AURORA_JOB_QUEUE,
AURORA_JOB_QUEUE_FAST,
Expand Down Expand Up @@ -52,7 +52,7 @@
AWSMock,
MemoryServer,
get_api_key_mocked,
get_rw_user_id,
get_manager_mocked,
is_admin_mocked,
is_service_account_mocked,
setup_clients,
Expand Down Expand Up @@ -252,7 +252,7 @@ async def async_client():
app.dependency_overrides[is_admin] = is_admin_mocked
app.dependency_overrides[is_service_account] = is_service_account_mocked
app.dependency_overrides[get_api_key] = get_api_key_mocked
app.dependency_overrides[rw_user_id] = get_rw_user_id
app.dependency_overrides[get_manager] = get_manager_mocked

async with AsyncClient(app=app, base_url="http://test", trust_env=False) as client:
yield client
Expand Down
3 changes: 1 addition & 2 deletions tests/routes/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pytest

from app.application import ContextEngine, db
from tests import RW_USER_ID
from tests.utils import create_default_asset, dataset_metadata

payload = {"metadata": dataset_metadata}
Expand Down Expand Up @@ -69,7 +68,7 @@ async def test_datasets(async_client):
)

assert len(rows) == 1
assert rows[0][0] == RW_USER_ID
assert rows[0][0] == "mr_manager123"

new_payload = {"metadata": {"title": "New Title"}}
response = await async_client.patch(f"/dataset/{dataset}", json=new_payload)
Expand Down
6 changes: 3 additions & 3 deletions tests_v2/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fastapi.testclient import TestClient
from httpx import AsyncClient

from app.authentication.token import get_user, is_admin, is_service_account, rw_user_id
from app.authentication.token import get_manager, get_user, is_admin, is_service_account
from app.crud import api_keys
from app.models.enum.change_log import ChangeLogStatus
from app.models.pydantic.change_log import ChangeLog
Expand All @@ -31,7 +31,7 @@
dict_function_closure,
get_admin_mocked,
get_extent_mocked,
get_rw_user_id_mocked,
get_manager_mocked,
get_user_mocked,
int_function_closure,
void_coroutine,
Expand Down Expand Up @@ -81,7 +81,7 @@ async def async_client(db, init_db) -> AsyncGenerator[AsyncClient, None]:
True, with_args=False
)
app.dependency_overrides[get_user] = get_admin_mocked
app.dependency_overrides[rw_user_id] = get_rw_user_id_mocked
app.dependency_overrides[get_manager] = get_manager_mocked

async with AsyncClient(
app=app,
Expand Down
Loading

0 comments on commit 7213ac6

Please sign in to comment.