-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
1 parent
9d904b0
commit a10e8c4
Showing
17 changed files
with
300 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,23 @@ | ||
from fastapi import APIRouter | ||
from fastapi import APIRouter, Depends | ||
from sqlalchemy.orm import Session | ||
|
||
from app.api.v1.schemas.user import CreditInfo, DetailData, UserPayload, UserResponse | ||
from app.api.v1.schemas.user import ApiKeyCreate, ApiKeyResponse, CreditInfo, DetailData, UserPayload, UserResponse | ||
from app.services.user import user_service | ||
from netspresso.utils.db.session import get_db | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.post("/me", response_model=UserResponse) | ||
def get_user() -> UserResponse: | ||
@router.post("/api-key", response_model=ApiKeyResponse) | ||
def generate_api_key(*, db: Session = Depends(get_db), request_body: ApiKeyCreate) -> ApiKeyResponse: | ||
api_key = user_service.generate_api_key(db=db, email=request_body.email, password=request_body.password) | ||
|
||
return ApiKeyResponse(data=api_key) | ||
|
||
project = UserPayload( | ||
|
||
@router.get("/me", response_model=UserResponse) | ||
def get_user() -> UserResponse: | ||
user = UserPayload( | ||
user_id="e8e8df79-2a62-4562-8e4d-06f51dd795b2", | ||
email="[email protected]", | ||
detail_data=DetailData( | ||
|
@@ -22,4 +31,4 @@ def get_user() -> UserResponse: | |
), | ||
) | ||
|
||
return UserResponse(data=project) | ||
return UserResponse(data=user) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from typing import Optional | ||
|
||
from fastapi import HTTPException | ||
|
||
from app.exceptions.schema import AdditionalData, ExceptionDetail | ||
|
||
|
||
class ExceptionBase(HTTPException): | ||
def __init__( | ||
self, | ||
data: Optional[AdditionalData], | ||
error_code: str, | ||
status_code: int, | ||
name: str, | ||
message: str, | ||
): | ||
detail = ExceptionDetail( | ||
data=data, | ||
error_code=error_code, | ||
name=name, | ||
message=message, | ||
) | ||
super().__init__(status_code=status_code, detail=detail.model_dump()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from enum import Enum | ||
from typing import Optional | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class Origin(str, Enum): | ||
ROUTER = "router" | ||
SERVICE = "service" | ||
REPOSITORY = "repository" | ||
CLIENT = "client" | ||
LIBRARY = "library" | ||
|
||
|
||
class AdditionalData(BaseModel): | ||
origin: Optional[Origin] = Field(default="", description="Error origin") | ||
error_log: Optional[str] = Field(default="", description="Error log") | ||
|
||
|
||
class ExceptionDetail(BaseModel): | ||
data: Optional[AdditionalData] = Field(default={}, description="Additional data") | ||
error_code: str = Field(..., description="Error code") | ||
name: str = Field(..., description="Error name") | ||
message: str = Field(..., description="Error message") | ||
|
||
class Config: | ||
json_schema_extra = { | ||
"examples": [ | ||
{ | ||
"data": { | ||
"origin": Origin.ROUTER, | ||
"error_log": "AttributeError(\"module 'np_compressor_core.torch.pruning' has no attribute 'VBMF'\")" | ||
}, | ||
"error_code": "CS40020", | ||
"name": "NotFoundMethodClassException", | ||
"message": "Not found VBMF method class.", | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from fastapi import status | ||
|
||
from app.exceptions.base import ExceptionBase | ||
from app.exceptions.schema import AdditionalData, Origin | ||
|
||
|
||
class IncorrectEmailOrPasswordException(ExceptionBase): | ||
def __init__(self, origin: Origin = Origin.SERVICE): | ||
super().__init__( | ||
data=AdditionalData(origin=origin), | ||
error_code="US40101", | ||
status_code=status.HTTP_401_UNAUTHORIZED, | ||
name=self.__class__.__name__, | ||
message="The email or password provided is incorrect. Please check your email and password and try again.", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from sqlalchemy.orm import Session | ||
|
||
from app.api.v1.schemas.user import ApiKeyPayload | ||
from app.utils import generate_id, hash_password | ||
from netspresso.utils.db.models.user import User | ||
from netspresso.utils.db.repositories.user import user_repository | ||
|
||
|
||
class UserService: | ||
def create_user(self, db: Session, email: str, password: str, api_key: str): | ||
hashed_password = hash_password(password) | ||
|
||
user = User( | ||
email=email, | ||
password=hashed_password, | ||
api_key=api_key, | ||
) | ||
user = user_repository.save(db=db, model=user) | ||
|
||
return user | ||
|
||
def generate_api_key(self, db: Session, email: str, password: str) -> ApiKeyPayload: | ||
generated_id = generate_id(entity="user") | ||
|
||
user = user_repository.get_by_email(db=db, email=email) | ||
|
||
if user: | ||
hashed_password = hash_password(password) | ||
if user.password != hashed_password: | ||
user.password = hashed_password | ||
user.api_key = generated_id | ||
elif user.api_key != generated_id: | ||
user.api_key = generated_id | ||
user = user_repository.save(db=db, model=user) | ||
else: | ||
user = self.create_user( | ||
db=db, | ||
email=email, | ||
password=password, | ||
api_key=generated_id, | ||
) | ||
|
||
api_key = ApiKeyPayload(api_key=user.api_key) | ||
|
||
return api_key | ||
|
||
|
||
user_service = UserService() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import hashlib | ||
|
||
from nanoid import generate | ||
|
||
|
||
def generate_id(entity: str, size: int = 10) -> str: | ||
nano_id = generate(size=size) | ||
return f"{entity}_{nano_id}" | ||
|
||
|
||
def hash_password(password: str) -> str: | ||
return hashlib.sha256(password.encode()).hexdigest() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .timestamp_mixin import TimestampMixin | ||
|
||
__all__ = ["TimestampMixin"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from sqlalchemy import Column, DateTime, func | ||
from sqlalchemy.ext.declarative import declared_attr | ||
|
||
|
||
class TimestampMixin: | ||
@declared_attr | ||
def created_at(cls): | ||
return Column(DateTime(timezone=True), server_default=func.now(), nullable=False) | ||
|
||
@declared_attr | ||
def updated_at(cls): | ||
return Column( | ||
DateTime(timezone=True), | ||
server_default=func.now(), | ||
onupdate=func.now(), | ||
nullable=False, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from netspresso.utils.db.models.user import User | ||
from netspresso.utils.db.session import Base, engine | ||
|
||
Base.metadata.create_all(engine) | ||
|
||
|
||
__all__ = [ | ||
"User", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from sqlalchemy import Boolean, Column, Integer, String | ||
|
||
from netspresso.utils.db.mixins import TimestampMixin | ||
from netspresso.utils.db.session import Base | ||
|
||
|
||
class User(Base, TimestampMixin): | ||
__tablename__ = "user" | ||
|
||
id = Column(Integer, primary_key=True, index=True, unique=True, autoincrement=True, nullable=False) | ||
email = Column(String(36), nullable=False) | ||
password = Column(String(36), nullable=False) | ||
api_key = Column(String(36), nullable=False) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from enum import Enum | ||
from typing import Generic, Type, TypeVar | ||
|
||
from sqlalchemy import asc, desc | ||
from sqlalchemy.orm import Session | ||
|
||
from netspresso.utils.db.session import Base | ||
|
||
ModelType = TypeVar("ModelType", bound=Base) # type: ignore | ||
|
||
|
||
class Order(str, Enum): | ||
DESC = "desc" | ||
ASC = "asc" | ||
|
||
|
||
class BaseRepository(Generic[ModelType]): | ||
def __init__(self, model: Type[ModelType]): | ||
self.model = model | ||
|
||
def choose_order_func(self, order): | ||
if order == Order.DESC: | ||
return desc | ||
return asc | ||
|
||
def save(self, db: Session, model: ModelType) -> ModelType: | ||
db.add(model) | ||
db.commit() | ||
db.refresh(model) | ||
|
||
return model | ||
|
||
def update(self, db: Session, model: ModelType) -> ModelType: | ||
return self.save(db, model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from typing import Optional | ||
|
||
from sqlalchemy.orm import Session | ||
|
||
from netspresso.utils.db.models.user import User | ||
from netspresso.utils.db.repositories.base import BaseRepository | ||
|
||
|
||
class UserRepository(BaseRepository[User]): | ||
def get_by_email(self, db: Session, email: str) -> Optional[User]: | ||
user = db.query(User).filter(User.email == email).first() | ||
|
||
return user | ||
|
||
def get_by_user_id(self, db: Session, user_id: str) -> Optional[User]: | ||
user = db.query(self.model).filter(self.model.user_id == user_id).first() | ||
|
||
return user | ||
|
||
def get_by_api_key(self, db: Session, api_key: str) -> Optional[User]: | ||
user = db.query(self.model).filter(self.model.api_key == api_key).first() | ||
|
||
return user | ||
|
||
|
||
user_repository = UserRepository(User) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from typing import Generator | ||
|
||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import declarative_base, sessionmaker | ||
from sqlalchemy_utils import create_database, database_exists | ||
|
||
DB_URL = "sqlite:///netspresso.db" | ||
engine = create_engine( | ||
f"{DB_URL}", | ||
pool_pre_ping=True, | ||
pool_use_lifo=True, | ||
pool_recycle=3600, | ||
) | ||
SessionLocal = sessionmaker( | ||
autocommit=False, | ||
autoflush=False, | ||
bind=engine, | ||
expire_on_commit=False, | ||
) | ||
|
||
Base = declarative_base() | ||
|
||
|
||
def get_db() -> Generator: | ||
try: | ||
db = SessionLocal() | ||
yield db | ||
finally: | ||
db.close() | ||
|
||
|
||
if not database_exists(engine.url): | ||
create_database(engine.url) |