Skip to content

Commit

Permalink
#112 Remove validate_token decorator & Add token_handler (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Jan 26, 2024
1 parent 4b7f781 commit 4b5b477
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 97 deletions.
24 changes: 14 additions & 10 deletions netspresso/benchmarker/benchmarker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from loguru import logger

from netspresso.clients.auth import validate_token, auth_client
from netspresso.clients.auth import auth_client, TokenHandler
from netspresso.clients.launcher import launcher_client
from netspresso.clients.launcher.schemas import TargetDeviceFilter
from netspresso.clients.launcher.schemas.model import BenchmarkTask
Expand All @@ -25,13 +25,12 @@


class Benchmarker:
def __init__(self, tokens, user_info):
def __init__(self, token_handler: TokenHandler, user_info):
"""Initialize the Model Benchmarker."""

self.tokens = tokens
self.token_handler = token_handler
self.user_info = user_info

@validate_token
def benchmark_model(
self,
input_model_path: Union[Path, str],
Expand All @@ -57,6 +56,9 @@ def benchmark_model(
Returns:
Dict: model benchmark task dict.
"""

self.token_handler.validate_token()

try:
folder_path = Path(input_model_path).parent

Expand All @@ -68,10 +70,10 @@ def benchmark_model(
metadatas = [metadata.asdict()]
MetadataHandler.save_json(metadatas, folder_path, file_name="benchmark")

current_credit = auth_client.get_credit(self.tokens.access_token)
current_credit = auth_client.get_credit(self.token_handler.tokens.access_token)
check_credit_balance(user_credit=current_credit, service_credit=ServiceCredit.MODEL_BENCHMARK)
model = launcher_client.upload_model(
model_file_path=input_model_path, target_function=Module.BENCHMARK, access_token=self.tokens.access_token
model_file_path=input_model_path, target_function=Module.BENCHMARK, access_token=self.token_handler.tokens.access_token
)
model_uuid = model.model_uuid

Expand Down Expand Up @@ -127,7 +129,7 @@ def benchmark_model(
data_type=target_data_type,
software_version=target_software_version,
hardware_type=target_hardware_type,
access_token=self.tokens.access_token,
access_token=self.token_handler.tokens.access_token,
)
model_benchmark = self.get_benchmark_task(benchmark_task=model_benchmark)

Expand Down Expand Up @@ -163,7 +165,7 @@ def benchmark_model(
metadatas[-1] = metadata.asdict()
MetadataHandler.save_json(data=metadatas, folder_path=folder_path, file_name="benchmark")

remaining_credit = auth_client.get_credit(self.tokens.access_token)
remaining_credit = auth_client.get_credit(self.token_handler.tokens.access_token)
logger.info(
f"{ServiceCredit.MODEL_BENCHMARK} credits have been consumed. Remaining Credit: {remaining_credit}"
)
Expand All @@ -182,7 +184,6 @@ def benchmark_model(

return metadata.asdict()

@validate_token
def get_benchmark_task(self, benchmark_task: Union[str, BenchmarkTask]) -> BenchmarkTask:
"""Get the benchmark task information with given benchmark task or benchmark task uuid.
Expand All @@ -195,6 +196,9 @@ def get_benchmark_task(self, benchmark_task: Union[str, BenchmarkTask]) -> Bench
Returns:
BenchmarkTask: model benchmark task object.
"""

self.token_handler.validate_token()

try:
task_uuid = None
if type(benchmark_task) is str:
Expand All @@ -206,7 +210,7 @@ def get_benchmark_task(self, benchmark_task: Union[str, BenchmarkTask]) -> Bench
"There is no available function for the given parameter. The 'benchmark_task' should be a UUID string or a ModelBenchmark object."
)

return launcher_client.get_benchmark(benchmark_task_uuid=task_uuid, access_token=self.tokens.access_token)
return launcher_client.get_benchmark(benchmark_task_uuid=task_uuid, access_token=self.token_handler.tokens.access_token)

except Exception as e:
logger.error(f"Get benchmark failed. Error: {e}")
4 changes: 2 additions & 2 deletions netspresso/clients/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .main import auth_client, validate_token
from .main import auth_client, TokenHandler

__all__ = ["auth_client", "validate_token"]
__all__ = ["auth_client", "TokenHandler"]
30 changes: 18 additions & 12 deletions netspresso/clients/auth/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
from functools import wraps
from datetime import datetime

import requests
from loguru import logger
import jwt
import pytz

from netspresso.clients.auth.schemas import (
LoginRequest,
Expand All @@ -11,17 +13,7 @@
UserResponse,
)
from netspresso.clients.config import Config, Module
from netspresso.clients.utils import check_jwt_exp, get_headers


def validate_token(func) -> None:
@wraps(func)
def wrapper(self, *args, **kwargs):
if not check_jwt_exp(self.tokens.access_token):
self.reissue_token()
return func(self, *args, **kwargs)

return wrapper
from netspresso.clients.utils import get_headers


class AuthClient:
Expand Down Expand Up @@ -94,6 +86,7 @@ def reissue_token(self, access_token, refresh_token) -> Tokens:

if response.status_code == 200 or response.status_code == 201:
tokens = Tokens(**response_body["tokens"])
logger.info("Successfully reissued token")
return tokens
else:
raise Exception(response_body["detail"])
Expand All @@ -102,4 +95,17 @@ def reissue_token(self, access_token, refresh_token) -> Tokens:
raise e


class TokenHandler:
def __init__(self, tokens) -> None:
self.tokens = tokens

def check_jwt_exp(self):
payload = jwt.decode(self.tokens.access_token, options={"verify_signature": False})
return datetime.now(pytz.utc).timestamp() <= payload["exp"]

def validate_token(self):
if not self.check_jwt_exp():
self.tokens = auth_client.reissue_token(self.tokens.access_token, self.tokens.refresh_token)


auth_client = AuthClient()
3 changes: 1 addition & 2 deletions netspresso/clients/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from .common import get_files, get_headers
from .system import ENV_STR
from .token import check_jwt_exp


__all__ = [
"get_files",
"get_headers",
"check_jwt_exp",
"ENV_STR",
]
9 changes: 0 additions & 9 deletions netspresso/clients/utils/token.py

This file was deleted.

Loading

0 comments on commit 4b5b477

Please sign in to comment.