Skip to content

Commit

Permalink
#37 Add verification options (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Jan 31, 2024
1 parent 1bc26a7 commit d77ab52
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 38 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,16 @@ session = SessionClient(email='YOUR_EMAIL', password='YOUR_PASSWORD')
compressor = ModelCompressor(user_session=session)
```

If you face some ssl verification error, please use the following codes.

```python
from netspresso.client import SessionClient
from netspresso.compressor import ModelCompressor

session = SessionClient(email='YOUR_EMAIL', password='YOUR_PASSWORD', verify_ssl=False)
compressor = ModelCompressor(user_session=session)
```

### Upload Model

To upload your trained model, simply enter the required information.
Expand Down Expand Up @@ -214,8 +224,10 @@ Convert an ONNX model into a TensorRT model, and benchmark the TensorRT model on

```python
from loguru import logger
from netspresso.client import SessionClient
from netspresso.launcher import ModelConverter, ModelBenchmarker, ModelFramework, TaskStatus, DeviceName, SoftwareVersion

session = SessionClient(email='YOUR_EMAIL', password='YOUR_PASSWORD')
converter = ModelConverter(user_session=session)

model = converter.upload_model("./examples/sample_models/test.onnx")
Expand Down
23 changes: 12 additions & 11 deletions netspresso/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ def validate_token(func) -> None:
@wraps(func)
def wrapper(self, *args, **kwargs):
if not check_jwt_exp(self.user_session.access_token):
self.user_session.__reissue_token()
self.user_session._reissue_token()
return func(self, *args, **kwargs)

return wrapper

class SessionClient:
def __init__(self, email: str, password: str, config: Config = None):
def __init__(self, email: str, password: str, config: Config = EndPoint.GENERAL, verify_ssl: bool = True):
"""Initialize the UserSession.
Args:
Expand All @@ -27,20 +27,21 @@ def __init__(self, email: str, password: str, config: Config = None):

self.email = email
self.password = password
self.config = config if config is not None else Config(EndPoint.GENRAL)
self.config = Config(config)
self.host = self.config.HOST
self.port = self.config.PORT
self.uri_prefix = self.config.URI_PREFIX
self.base_url = f"{self.host}:{self.port}{self.uri_prefix}"
self.user_id = None
self.verify_ssl = verify_ssl
self.__login()
self.__get_user_info()

def __login(self) -> None:
try:
url = f"{self.base_url}/auth/local/login"
data = LoginRequest(username=self.email, password=self.password)
response = requests.post(url, json=data.dict(), headers=get_headers())
response = requests.post(url, json=data.dict(), headers=get_headers(), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200 or response.status_code == 201:
Expand All @@ -58,7 +59,7 @@ def __login(self) -> None:
def __get_user_info(self):
try:
url = f"{self.base_url}/user"
response = requests.get(url, headers=get_headers(access_token=self.access_token))
response = requests.get(url, headers=get_headers(access_token=self.access_token), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200 or response.status_code == 201:
Expand All @@ -72,15 +73,15 @@ def __get_user_info(self):
logger.error(f"Failed to get user information. Error: {e}")
raise e

def __reissue_token(self) -> None:
def _reissue_token(self) -> None:
try:
url = f"{self.base_url}/token"
url = f"{self.base_url}/auth/token"
data = RefreshTokenRequest(access_token=self.access_token, refresh_token=self.refresh_token)
response = requests.post(url, data=data.json(), headers=get_headers(json_type=True))
response = requests.post(url, data=data.json(), headers=get_headers(json_type=True), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200 or response.status_code == 201:
session = RefreshTokenResponse(**response_body)
session = RefreshTokenResponse(**response_body["tokens"])
self.access_token = session.access_token
self.refresh_token = session.refresh_token
else:
Expand All @@ -91,7 +92,7 @@ def __reissue_token(self) -> None:

class BaseClient:
user_session: SessionClient = None
def __init__(self, email=None, password=None, user_session=None):
def __init__(self, email=None, password=None, user_session=None, verify_ssl=True):
"""Initialize the Model Compressor.
Args:
Expand All @@ -109,7 +110,7 @@ def __init__(self, email=None, password=None, user_session=None):
self.user_session = user_session
elif email and password:
# Case 2: Creating from email and password
self.user_session = SessionClient(email=email, password=password)
self.user_session = SessionClient(email=email, password=password, verify_ssl=verify_ssl)
else:
raise NotImplementedError("There is no avaliable constructors for given paremeters.")

Expand Down
6 changes: 3 additions & 3 deletions netspresso/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def create_literal(cls):
]

class EndPoint(str, Enum):
GENRAL = "general"
GENERAL = "general"
COMPRESSOR = "compressor"
LAUNCHER = "launcher"
@classmethod
Expand All @@ -49,12 +49,12 @@ def create_literal(cls):

class Config:
ENVIRONMENT_TYPE: EnvironmentType = EnvironmentType.PROD
CONFIG_SESSION: str = f"{EndPoint.GENRAL}.{ENVIRONMENT_TYPE}"
CONFIG_SESSION: str = f"{EndPoint.GENERAL}.{ENVIRONMENT_TYPE}"
HOST: str = config[CONFIG_SESSION][EndPointProperty.HOST]
PORT: int = int(config[CONFIG_SESSION][EndPointProperty.PORT])
URI_PREFIX: str = config[CONFIG_SESSION][EndPointProperty.URI_PREFIX]

def __init__(self, endpoint: EndPoint = EndPoint.GENRAL):
def __init__(self, endpoint: EndPoint = EndPoint.GENERAL):
self.ENVIRONMENT_TYPE = EnvironmentType(DEPLOYMENT_MODE.lower())
self.CONFIG_SESSION = f"{endpoint}.{self.ENVIRONMENT_TYPE}"
self.HOST = config[self.CONFIG_SESSION][EndPointProperty.HOST]
Expand Down
6 changes: 3 additions & 3 deletions netspresso/compressor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


class ModelCompressor(BaseClient):
def __init__(self, email=None, password=None, user_session=None):
def __init__(self, email=None, password=None, user_session=None, verify_ssl=True):
"""Initialize the Model Compressor.
Args:
Expand All @@ -44,8 +44,8 @@ def __init__(self, email=None, password=None, user_session=None):
ModelCompressor(email='USER_EMAIL',password='PASSWORD')
ModelCompressor(user_session=SessionClient(email='USER_EMAIL',password='PASSWORD')
"""
super().__init__(email=email, password=password, user_session=user_session)
self.client = ModelCompressorAPIClient()
super().__init__(email=email, password=password, user_session=user_session, verify_ssl=verify_ssl)
self.client = ModelCompressorAPIClient(verify_ssl=self.user_session.verify_ssl)
self.model_factory = ModelFactory()

@validate_token
Expand Down
29 changes: 15 additions & 14 deletions netspresso/compressor/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,18 @@
from netspresso.client.config import Config, EndPoint

class ModelCompressorAPIClient:
def __init__(self):
def __init__(self, verify_ssl: bool = True):
self.config = Config(EndPoint.COMPRESSOR)
self.host = self.config.HOST
self.port = self.config.PORT
self.prefix = self.config.URI_PREFIX
self.verify_ssl = verify_ssl
self.url = f"{self.host}:{self.port}{self.prefix}"

def upload_model(self, data: UploadModelRequest, access_token) -> ModelResponse:
url = f"{self.url}/models"
files = get_files(data.file_path)
response = requests.post(url, data=data.dict(), files=files, headers=get_headers(access_token))
response = requests.post(url, data=data.dict(), files=files, headers=get_headers(access_token), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -48,7 +49,7 @@ def upload_model(self, data: UploadModelRequest, access_token) -> ModelResponse:

def get_parent_models(self, is_simple, access_token):
url = f"{self.url}/models/parents?is_simple={is_simple}"
response = requests.get(url, headers=get_headers(access_token))
response = requests.get(url, headers=get_headers(access_token), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -58,7 +59,7 @@ def get_parent_models(self, is_simple, access_token):

def get_children_models(self, model_id, access_token):
url = f"{self.url}/models/{model_id}/children"
response = requests.get(url, headers=get_headers(access_token))
response = requests.get(url, headers=get_headers(access_token), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -68,7 +69,7 @@ def get_children_models(self, model_id, access_token):

def get_model_info(self, model_id, access_token) -> ModelResponse:
url = f"{self.url}/models/{model_id}"
response = requests.get(url, headers=get_headers(access_token))
response = requests.get(url, headers=get_headers(access_token), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -78,7 +79,7 @@ def get_model_info(self, model_id, access_token) -> ModelResponse:

def get_download_model_link(self, model_id, access_token) -> GetDownloadLinkResponse:
url = f"{self.url}/models/{model_id}/download"
response = requests.post(url, headers=get_headers(access_token))
response = requests.post(url, headers=get_headers(access_token), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -88,7 +89,7 @@ def get_download_model_link(self, model_id, access_token) -> GetDownloadLinkResp

def delete_model(self, model_id, access_token):
url = f"{self.url}/models/{model_id}"
response = requests.delete(url, headers=get_headers(access_token))
response = requests.delete(url, headers=get_headers(access_token), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -98,7 +99,7 @@ def delete_model(self, model_id, access_token):

def get_available_layers(self, data, access_token) -> GetAvailableLayersReponse:
url = f"{self.url}/models/{data.model_id}/get_available_layers"
response = requests.post(url, data=data.json(), headers=get_headers(access_token, json_type=True))
response = requests.post(url, data=data.json(), headers=get_headers(access_token, json_type=True), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -108,7 +109,7 @@ def get_available_layers(self, data, access_token) -> GetAvailableLayersReponse:

def create_compression(self, data, access_token) -> CompressionResponse:
url = f"{self.url}/compressions"
response = requests.post(url, data=data.json(), headers=get_headers(access_token, json_type=True))
response = requests.post(url, data=data.json(), headers=get_headers(access_token, json_type=True), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -118,7 +119,7 @@ def create_compression(self, data, access_token) -> CompressionResponse:

def get_recommendation(self, data, access_token) -> RecommendationResponse:
url = f"{self.url}/models/{data.model_id}/recommendation"
response = requests.post(url, data=data.json(), headers=get_headers(access_token, json_type=True))
response = requests.post(url, data=data.json(), headers=get_headers(access_token, json_type=True), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -129,7 +130,7 @@ def get_recommendation(self, data, access_token) -> RecommendationResponse:

def compress_model(self, data, access_token):
url = f"{self.url}/compressions/{data.compression_id}"
response = requests.put(url, data=data.json(), headers=get_headers(access_token, json_type=True))
response = requests.put(url, data=data.json(), headers=get_headers(access_token, json_type=True), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -139,7 +140,7 @@ def compress_model(self, data, access_token):

def auto_compression(self, data, access_token):
url = f"{self.url}/models/{data.model_id}/auto_compress"
response = requests.post(url, data=data.json(), headers=get_headers(access_token, json_type=True))
response = requests.post(url, data=data.json(), headers=get_headers(access_token, json_type=True), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -151,7 +152,7 @@ def auto_compression(self, data, access_token):
def get_compression_info(self, compression_id, access_token):
url = f"{self.url}/compressions/{compression_id}"

response = requests.get(url, headers=get_headers(access_token))
response = requests.get(url, headers=get_headers(access_token), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand All @@ -162,7 +163,7 @@ def get_compression_info(self, compression_id, access_token):
def upload_dataset(self, data, access_token):
url = f"{self.url}/models/{data.model_id}/datasets"
files = get_files(data.file_path)
response = requests.post(url, files=files, headers=get_headers(access_token))
response = requests.post(url, files=files, headers=get_headers(access_token), verify=self.verify_ssl)
response_body = json.loads(response.text)

if response.status_code == 200:
Expand Down
12 changes: 6 additions & 6 deletions netspresso/launcher/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def upload_model(self, model_file_path: str, target_function: LauncherFunction)
url = f"{self.url}/{target_function.value.lower()}/upload_model"
files = get_files(model_file_path)
# files = {"file": open(model_file_path, "rb")}
response = requests.post(url, files=files, headers=get_headers(self.user_session.access_token))
response = requests.post(url, files=files, headers=get_headers(self.user_session.access_token), verify=self.user_session.verify_ssl)
response_body = json.loads(response.text)
if response.status_code < 300:
return Model(**response_body)
Expand Down Expand Up @@ -85,7 +85,7 @@ def convert_model(
if software_version is not None:
request_data.software_version = software_version

response = requests.post(url, data=request_data.dict(), headers=get_headers(self.user_session.access_token))
response = requests.post(url, data=request_data.dict(), headers=get_headers(self.user_session.access_token), verify=self.user_session.verify_ssl)
response_body = json.loads(response.text)
if response.status_code < 300:
return ConversionTask(**response_body)
Expand All @@ -106,7 +106,7 @@ def get_conversion_task(self, conversion_task_uuid: str) -> ConversionTask:
"""

url = f"{self.url}/convert/{conversion_task_uuid}"
response = requests.get(url, headers=get_headers(self.user_session.access_token))
response = requests.get(url, headers=get_headers(self.user_session.access_token), verify=self.user_session.verify_ssl)
response_body = json.loads(response.text)
if response.status_code < 300:
return ConversionTask(**response_body)
Expand All @@ -126,7 +126,7 @@ def get_converted_model(self, conversion_task_uuid: str):
ConversionTask: model conversion task object.
"""
url = f"{self.url}/convert/{conversion_task_uuid}/download"
response = requests.get(url, headers=get_headers(self.user_session.access_token))
response = requests.get(url, headers=get_headers(self.user_session.access_token), verify=self.user_session.verify_ssl)
response_body = json.loads(response.text)
if response.status_code < 300:
return response_body
Expand Down Expand Up @@ -164,7 +164,7 @@ def benchmark_model(
software_version=software_version,
hardware_type=hardware_type,
)
response = requests.post(url, json=request_data.dict(), headers=get_headers(self.user_session.access_token))
response = requests.post(url, json=request_data.dict(), headers=get_headers(self.user_session.access_token), verify=self.user_session.verify_ssl)
response_body = json.loads(response.text)
if response.status_code < 300:
return BenchmarkTask(**response_body)
Expand All @@ -184,7 +184,7 @@ def get_benchmark(self, benchmark_task_uuid: str) -> BenchmarkTask:
BenchmarkTask: model benchmark task object.
"""
url = f"{self.url}/benchmark/{benchmark_task_uuid}"
response = requests.get(url, headers=get_headers(self.user_session.access_token))
response = requests.get(url, headers=get_headers(self.user_session.access_token), verify=self.user_session.verify_ssl)
response_body = json.loads(response.text)
if response.status_code < 300:
return BenchmarkTask(**response_body)
Expand Down
4 changes: 4 additions & 0 deletions netspresso/launcher/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class DeviceName(StrEnumBase):
Intel_XEON_W_2233 = "Intel-Xeon"
ALIF_ENSEMBLE_E7_DEVKIT_GEN2 = "Ensemble-E7-DevKit-Gen2"
RENESAS_RA8D1 = "Renesas-RA8D1"
Renesas_RA8D1 = "Renesas-RA8D1"
Ensemble_E7_DevKit_Gen2 = "Ensemble-E7-DevKit-Gen2"

@classmethod
def create_literal(cls):
Expand All @@ -94,6 +96,8 @@ def create_literal(cls):
"Intel-Xeon",
"Ensemble-E7-DevKit-Gen2",
"Renesas-RA8D1",
"Renesas-RA8D1",
"Ensemble-E7-DevKit-Gen2"
]


Expand Down
2 changes: 1 addition & 1 deletion netspresso/launcher/schemas/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class TargetDevice(BaseModel):

display_name: str = Field(default=None)
display_brand_name: str = Field(default=None)
device_name: DeviceName = Field(default=None)
device_name: Union[DeviceName, str] = Field(default=None, union_mode='left_to_right')
software_version: Optional[str] = Field(default=None)
software_version_display_name: Optional[str] = Field(default=None)
hardware_type: Optional[HardwareType] = Field(default=None)
Expand Down

0 comments on commit d77ab52

Please sign in to comment.