diff --git a/app/api/deps.py b/app/api/deps.py new file mode 100644 index 00000000..305a689e --- /dev/null +++ b/app/api/deps.py @@ -0,0 +1,7 @@ +from fastapi.security import APIKeyHeader + +# Define the header key +API_KEY_NAME = "X-API-Key" + +# Create a security dependency for API key +api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=True) diff --git a/app/api/v1/endpoints/user.py b/app/api/v1/endpoints/user.py index 722e41af..bb47a09c 100644 --- a/app/api/v1/endpoints/user.py +++ b/app/api/v1/endpoints/user.py @@ -1,7 +1,8 @@ from fastapi import APIRouter, Depends from sqlalchemy.orm import Session -from app.api.v1.schemas.user import ApiKeyCreate, ApiKeyResponse, CreditInfo, DetailData, UserPayload, UserResponse +from app.api.deps import api_key_header +from app.api.v1.schemas.user import ApiKeyCreate, ApiKeyResponse, UserResponse from app.services.user import user_service from netspresso.utils.db.session import get_db @@ -16,19 +17,7 @@ def generate_api_key(*, db: Session = Depends(get_db), request_body: ApiKeyCreat @router.get("/me", response_model=UserResponse) -def get_user() -> UserResponse: - user = UserPayload( - user_id="e8e8df79-2a62-4562-8e4d-06f51dd795b2", - email="nppd_test_001@nota.ai", - detail_data=DetailData( - first_name="Byeongman", - last_name="Lee", - company="Nota AI", - ), - credit_info=CreditInfo( - free=1000, - total=1000, - ), - ) +def get_user(*, db: Session = Depends(get_db), api_key: str = Depends(api_key_header)) -> UserResponse: + user = user_service.get_user_info(db=db, api_key=api_key) return UserResponse(data=user) diff --git a/app/services/user.py b/app/services/user.py index ce6041dc..b0ea12c6 100644 --- a/app/services/user.py +++ b/app/services/user.py @@ -1,18 +1,17 @@ from sqlalchemy.orm import Session -from app.api.v1.schemas.user import ApiKeyPayload -from app.utils import generate_id, hash_password +from app.api.v1.schemas.user import ApiKeyPayload, CreditInfo, DetailData, UserPayload +from app.utils import generate_id +from netspresso.netspresso import NetsPresso from netspresso.utils.db.models.user import User from netspresso.utils.db.repositories.user import user_repository class UserService: def create_user(self, db: Session, email: str, password: str, api_key: str): - hashed_password = hash_password(password) - user = User( email=email, - password=hashed_password, + password=password, api_key=api_key, ) user = user_repository.save(db=db, model=user) @@ -25,9 +24,8 @@ def generate_api_key(self, db: Session, email: str, password: str) -> ApiKeyPayl user = user_repository.get_by_email(db=db, email=email) if user: - hashed_password = hash_password(password) - if user.password != hashed_password: - user.password = hashed_password + if user.password != password: + user.password = password user.api_key = generated_id elif user.api_key != generated_id: user.api_key = generated_id @@ -44,5 +42,29 @@ def generate_api_key(self, db: Session, email: str, password: str) -> ApiKeyPayl return api_key + def get_user_info(self, db: Session, api_key: str) -> UserPayload: + user = user_repository.get_by_api_key(db=db, api_key=api_key) + + netspresso = NetsPresso(email=user.email, password=user.password) + + user = UserPayload( + user_id=netspresso.user_info.user_id, + email=netspresso.user_info.email, + detail_data=DetailData( + first_name=netspresso.user_info.detail_data.first_name, + last_name=netspresso.user_info.detail_data.last_name, + company=netspresso.user_info.detail_data.company, + ), + credit_info=CreditInfo( + free=netspresso.user_info.credit_info.free, + reward=netspresso.user_info.credit_info.reward, + contract=netspresso.user_info.credit_info.contract, + paid=netspresso.user_info.credit_info.paid, + total=netspresso.user_info.credit_info.total, + ), + ) + + return user + user_service = UserService() diff --git a/app/utils.py b/app/utils.py index eaf3585c..eb04b7de 100644 --- a/app/utils.py +++ b/app/utils.py @@ -1,12 +1,6 @@ -import hashlib - from nanoid import generate def generate_id(entity: str, size: int = 10) -> str: nano_id = generate(size=size) return f"{entity}_{nano_id}" - - -def hash_password(password: str) -> str: - return hashlib.sha256(password.encode()).hexdigest()