Skip to content

Commit

Permalink
#416 Add API for retrieving user information (#428)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Dec 6, 2024
1 parent a10e8c4 commit 354bf4c
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 29 deletions.
7 changes: 7 additions & 0 deletions app/api/deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from fastapi.security import APIKeyHeader

# Define the header key
API_KEY_NAME = "X-API-Key"

# Create a security dependency for API key
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True)
19 changes: 4 additions & 15 deletions app/api/v1/endpoints/user.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from app.api.v1.schemas.user import ApiKeyCreate, ApiKeyResponse, CreditInfo, DetailData, UserPayload, UserResponse
from app.api.deps import api_key_header
from app.api.v1.schemas.user import ApiKeyCreate, ApiKeyResponse, UserResponse
from app.services.user import user_service
from netspresso.utils.db.session import get_db

Expand All @@ -16,19 +17,7 @@ def generate_api_key(*, db: Session = Depends(get_db), request_body: ApiKeyCreat


@router.get("/me", response_model=UserResponse)
def get_user() -> UserResponse:
user = UserPayload(
user_id="e8e8df79-2a62-4562-8e4d-06f51dd795b2",
email="[email protected]",
detail_data=DetailData(
first_name="Byeongman",
last_name="Lee",
company="Nota AI",
),
credit_info=CreditInfo(
free=1000,
total=1000,
),
)
def get_user(*, db: Session = Depends(get_db), api_key: str = Depends(api_key_header)) -> UserResponse:
user = user_service.get_user_info(db=db, api_key=api_key)

return UserResponse(data=user)
38 changes: 30 additions & 8 deletions app/services/user.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
from sqlalchemy.orm import Session

from app.api.v1.schemas.user import ApiKeyPayload
from app.utils import generate_id, hash_password
from app.api.v1.schemas.user import ApiKeyPayload, CreditInfo, DetailData, UserPayload
from app.utils import generate_id
from netspresso.netspresso import NetsPresso
from netspresso.utils.db.models.user import User
from netspresso.utils.db.repositories.user import user_repository


class UserService:
def create_user(self, db: Session, email: str, password: str, api_key: str):
hashed_password = hash_password(password)

user = User(
email=email,
password=hashed_password,
password=password,
api_key=api_key,
)
user = user_repository.save(db=db, model=user)
Expand All @@ -25,9 +24,8 @@ def generate_api_key(self, db: Session, email: str, password: str) -> ApiKeyPayl
user = user_repository.get_by_email(db=db, email=email)

if user:
hashed_password = hash_password(password)
if user.password != hashed_password:
user.password = hashed_password
if user.password != password:
user.password = password
user.api_key = generated_id
elif user.api_key != generated_id:
user.api_key = generated_id
Expand All @@ -44,5 +42,29 @@ def generate_api_key(self, db: Session, email: str, password: str) -> ApiKeyPayl

return api_key

def get_user_info(self, db: Session, api_key: str) -> UserPayload:
user = user_repository.get_by_api_key(db=db, api_key=api_key)

netspresso = NetsPresso(email=user.email, password=user.password)

user = UserPayload(
user_id=netspresso.user_info.user_id,
email=netspresso.user_info.email,
detail_data=DetailData(
first_name=netspresso.user_info.detail_data.first_name,
last_name=netspresso.user_info.detail_data.last_name,
company=netspresso.user_info.detail_data.company,
),
credit_info=CreditInfo(
free=netspresso.user_info.credit_info.free,
reward=netspresso.user_info.credit_info.reward,
contract=netspresso.user_info.credit_info.contract,
paid=netspresso.user_info.credit_info.paid,
total=netspresso.user_info.credit_info.total,
),
)

return user


user_service = UserService()
6 changes: 0 additions & 6 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import hashlib

from nanoid import generate


def generate_id(entity: str, size: int = 10) -> str:
nano_id = generate(size=size)
return f"{entity}_{nano_id}"


def hash_password(password: str) -> str:
return hashlib.sha256(password.encode()).hexdigest()

0 comments on commit 354bf4c

Please sign in to comment.