Skip to content

Commit

Permalink
#426 Add generating api key for be (#427)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Dec 5, 2024
1 parent 9d904b0 commit a10e8c4
Show file tree
Hide file tree
Showing 17 changed files with 300 additions and 6 deletions.
21 changes: 15 additions & 6 deletions app/api/v1/endpoints/user.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from app.api.v1.schemas.user import CreditInfo, DetailData, UserPayload, UserResponse
from app.api.v1.schemas.user import ApiKeyCreate, ApiKeyResponse, CreditInfo, DetailData, UserPayload, UserResponse
from app.services.user import user_service
from netspresso.utils.db.session import get_db

router = APIRouter()


@router.post("/me", response_model=UserResponse)
def get_user() -> UserResponse:
@router.post("/api-key", response_model=ApiKeyResponse)
def generate_api_key(*, db: Session = Depends(get_db), request_body: ApiKeyCreate) -> ApiKeyResponse:
api_key = user_service.generate_api_key(db=db, email=request_body.email, password=request_body.password)

return ApiKeyResponse(data=api_key)

project = UserPayload(

@router.get("/me", response_model=UserResponse)
def get_user() -> UserResponse:
user = UserPayload(
user_id="e8e8df79-2a62-4562-8e4d-06f51dd795b2",
email="[email protected]",
detail_data=DetailData(
Expand All @@ -22,4 +31,4 @@ def get_user() -> UserResponse:
),
)

return UserResponse(data=project)
return UserResponse(data=user)
13 changes: 13 additions & 0 deletions app/api/v1/schemas/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,19 @@
from app.api.v1.schemas.base import ResponseItem


class ApiKeyCreate(BaseModel):
email: str = Field(..., description="Email of the user.")
password: str = Field(..., description="Password of the user.")


class ApiKeyPayload(BaseModel):
api_key: str = Field(..., description="API key of the user.")


class ApiKeyResponse(ResponseItem):
data: ApiKeyPayload


class CreditInfo(BaseModel):
free: int = Field(default=0, description="Free credits available.")
reward: int = Field(default=0, description="Reward credits available.")
Expand Down
Empty file added app/exceptions/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions app/exceptions/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Optional

from fastapi import HTTPException

from app.exceptions.schema import AdditionalData, ExceptionDetail


class ExceptionBase(HTTPException):
def __init__(
self,
data: Optional[AdditionalData],
error_code: str,
status_code: int,
name: str,
message: str,
):
detail = ExceptionDetail(
data=data,
error_code=error_code,
name=name,
message=message,
)
super().__init__(status_code=status_code, detail=detail.model_dump())
39 changes: 39 additions & 0 deletions app/exceptions/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field


class Origin(str, Enum):
ROUTER = "router"
SERVICE = "service"
REPOSITORY = "repository"
CLIENT = "client"
LIBRARY = "library"


class AdditionalData(BaseModel):
origin: Optional[Origin] = Field(default="", description="Error origin")
error_log: Optional[str] = Field(default="", description="Error log")


class ExceptionDetail(BaseModel):
data: Optional[AdditionalData] = Field(default={}, description="Additional data")
error_code: str = Field(..., description="Error code")
name: str = Field(..., description="Error name")
message: str = Field(..., description="Error message")

class Config:
json_schema_extra = {
"examples": [
{
"data": {
"origin": Origin.ROUTER,
"error_log": "AttributeError(\"module 'np_compressor_core.torch.pruning' has no attribute 'VBMF'\")"
},
"error_code": "CS40020",
"name": "NotFoundMethodClassException",
"message": "Not found VBMF method class.",
}
]
}
15 changes: 15 additions & 0 deletions app/exceptions/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from fastapi import status

from app.exceptions.base import ExceptionBase
from app.exceptions.schema import AdditionalData, Origin


class IncorrectEmailOrPasswordException(ExceptionBase):
def __init__(self, origin: Origin = Origin.SERVICE):
super().__init__(
data=AdditionalData(origin=origin),
error_code="US40101",
status_code=status.HTTP_401_UNAUTHORIZED,
name=self.__class__.__name__,
message="The email or password provided is incorrect. Please check your email and password and try again.",
)
48 changes: 48 additions & 0 deletions app/services/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from sqlalchemy.orm import Session

from app.api.v1.schemas.user import ApiKeyPayload
from app.utils import generate_id, hash_password
from netspresso.utils.db.models.user import User
from netspresso.utils.db.repositories.user import user_repository


class UserService:
def create_user(self, db: Session, email: str, password: str, api_key: str):
hashed_password = hash_password(password)

user = User(
email=email,
password=hashed_password,
api_key=api_key,
)
user = user_repository.save(db=db, model=user)

return user

def generate_api_key(self, db: Session, email: str, password: str) -> ApiKeyPayload:
generated_id = generate_id(entity="user")

user = user_repository.get_by_email(db=db, email=email)

if user:
hashed_password = hash_password(password)
if user.password != hashed_password:
user.password = hashed_password
user.api_key = generated_id
elif user.api_key != generated_id:
user.api_key = generated_id
user = user_repository.save(db=db, model=user)
else:
user = self.create_user(
db=db,
email=email,
password=password,
api_key=generated_id,
)

api_key = ApiKeyPayload(api_key=user.api_key)

return api_key


user_service = UserService()
12 changes: 12 additions & 0 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import hashlib

from nanoid import generate


def generate_id(entity: str, size: int = 10) -> str:
nano_id = generate(size=size)
return f"{entity}_{nano_id}"


def hash_password(password: str) -> str:
return hashlib.sha256(password.encode()).hexdigest()
Empty file added netspresso/utils/db/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions netspresso/utils/db/mixins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .timestamp_mixin import TimestampMixin

__all__ = ["TimestampMixin"]
17 changes: 17 additions & 0 deletions netspresso/utils/db/mixins/timestamp_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from sqlalchemy import Column, DateTime, func
from sqlalchemy.ext.declarative import declared_attr


class TimestampMixin:
@declared_attr
def created_at(cls):
return Column(DateTime(timezone=True), server_default=func.now(), nullable=False)

@declared_attr
def updated_at(cls):
return Column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
9 changes: 9 additions & 0 deletions netspresso/utils/db/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from netspresso.utils.db.models.user import User
from netspresso.utils.db.session import Base, engine

Base.metadata.create_all(engine)


__all__ = [
"User",
]
13 changes: 13 additions & 0 deletions netspresso/utils/db/models/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from sqlalchemy import Boolean, Column, Integer, String

from netspresso.utils.db.mixins import TimestampMixin
from netspresso.utils.db.session import Base


class User(Base, TimestampMixin):
__tablename__ = "user"

id = Column(Integer, primary_key=True, index=True, unique=True, autoincrement=True, nullable=False)
email = Column(String(36), nullable=False)
password = Column(String(36), nullable=False)
api_key = Column(String(36), nullable=False)
Empty file.
34 changes: 34 additions & 0 deletions netspresso/utils/db/repositories/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from enum import Enum
from typing import Generic, Type, TypeVar

from sqlalchemy import asc, desc
from sqlalchemy.orm import Session

from netspresso.utils.db.session import Base

ModelType = TypeVar("ModelType", bound=Base) # type: ignore


class Order(str, Enum):
DESC = "desc"
ASC = "asc"


class BaseRepository(Generic[ModelType]):
def __init__(self, model: Type[ModelType]):
self.model = model

def choose_order_func(self, order):
if order == Order.DESC:
return desc
return asc

def save(self, db: Session, model: ModelType) -> ModelType:
db.add(model)
db.commit()
db.refresh(model)

return model

def update(self, db: Session, model: ModelType) -> ModelType:
return self.save(db, model)
26 changes: 26 additions & 0 deletions netspresso/utils/db/repositories/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional

from sqlalchemy.orm import Session

from netspresso.utils.db.models.user import User
from netspresso.utils.db.repositories.base import BaseRepository


class UserRepository(BaseRepository[User]):
def get_by_email(self, db: Session, email: str) -> Optional[User]:
user = db.query(User).filter(User.email == email).first()

return user

def get_by_user_id(self, db: Session, user_id: str) -> Optional[User]:
user = db.query(self.model).filter(self.model.user_id == user_id).first()

return user

def get_by_api_key(self, db: Session, api_key: str) -> Optional[User]:
user = db.query(self.model).filter(self.model.api_key == api_key).first()

return user


user_repository = UserRepository(User)
33 changes: 33 additions & 0 deletions netspresso/utils/db/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Generator

from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy_utils import create_database, database_exists

DB_URL = "sqlite:///netspresso.db"
engine = create_engine(
f"{DB_URL}",
pool_pre_ping=True,
pool_use_lifo=True,
pool_recycle=3600,
)
SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine,
expire_on_commit=False,
)

Base = declarative_base()


def get_db() -> Generator:
try:
db = SessionLocal()
yield db
finally:
db.close()


if not database_exists(engine.url):
create_database(engine.url)

0 comments on commit a10e8c4

Please sign in to comment.