Skip to content

Commit

Permalink
#282 Add a progress bar to show progress during upload (#292)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Jul 16, 2024
1 parent fc5cee2 commit c7d57c3
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 15 deletions.
30 changes: 22 additions & 8 deletions netspresso/clients/compressor/v2/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from dataclasses import asdict

import requests
from requests_toolbelt import MultipartEncoderMonitor
from requests_toolbelt.multipart.encoder import MultipartEncoder
from tqdm import tqdm

from netspresso.clients.compressor.v2.schemas.common import RequestPagination, UploadFile
from netspresso.clients.compressor.v2.schemas.compression import (
RequestAutomaticCompressionParams,
Expand All @@ -21,7 +26,7 @@
ResponseModelUrl,
)
from netspresso.clients.config import Config, Module
from netspresso.clients.utils.common import get_headers
from netspresso.clients.utils.common import create_multipart_data, create_progress_func, get_headers, progress_callback
from netspresso.clients.utils.requester import Requester


Expand All @@ -45,14 +50,23 @@ def create_model(

return ResponseModelUrl(**response.json())

def upload_model(self, request_data: RequestUploadModel, file: UploadFile, access_token: str, verify_ssl: bool = True) -> bool:
def upload_model(
self, request_data: RequestUploadModel, file: UploadFile, access_token: str, verify_ssl: bool = True
) -> bool:
url = f"{self.url}/models/upload"
response = Requester.post_as_form(
url=url,
binary=file.files,
request_body=asdict(request_data),
headers=get_headers(access_token),
)

file_info = file.files[0][1]

multipart_data = create_multipart_data(request_data.url, file_info)
progress = create_progress_func(multipart_data)

# Wrap the encoder with MultipartEncoderMonitor
monitor = MultipartEncoderMonitor(multipart_data, lambda monitor: progress_callback(monitor, progress))

headers = get_headers(access_token)
headers["Content-Type"] = monitor.content_type

response = requests.post(url=url, data=monitor, headers=headers, verify=verify_ssl)

return response.text

Expand Down
26 changes: 19 additions & 7 deletions netspresso/clients/launcher/v2/implements/model/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from dataclasses import asdict

import requests
from requests_toolbelt import MultipartEncoderMonitor

from netspresso.clients.launcher.v2.interfaces import ModelInterface
from netspresso.clients.launcher.v2.schemas import (
AuthorizationHeader,
Expand All @@ -14,6 +17,7 @@
ResponseModelUploadUrl,
UploadFile,
)
from netspresso.clients.utils.common import create_multipart_data, create_progress_func, progress_callback
from netspresso.clients.utils.requester import Requester
from netspresso.enums import LauncherTask

Expand Down Expand Up @@ -44,13 +48,21 @@ def upload(
file: UploadFile,
headers: AuthorizationHeader,
) -> str:
endpoint = f"{self.model_base_url}/upload"
response = Requester().post_as_form(
url=endpoint,
headers=headers.to_dict(),
binary=file.files,
request_body=asdict(request_body),
)
url = f"{self.model_base_url}/upload"

file_info = file.files[0][1]

multipart_data = create_multipart_data(request_body.url, file_info)
progress = create_progress_func(multipart_data)

# Wrap the encoder with MultipartEncoderMonitor
monitor = MultipartEncoderMonitor(multipart_data, lambda monitor: progress_callback(monitor, progress))

headers = headers.to_dict()
headers["Content-Type"] = monitor.content_type

response = requests.post(url=url, data=monitor, headers=headers)

return response.text

def validate(
Expand Down
35 changes: 35 additions & 0 deletions netspresso/clients/utils/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from pathlib import Path

from requests_toolbelt.multipart.encoder import MultipartEncoder
from tqdm import tqdm

from netspresso.clients.utils.system import ENV_STR

version = (Path(__file__).parent.parent.parent / "VERSION").read_text().strip()
Expand Down Expand Up @@ -31,3 +34,35 @@ def get_files(file_path):
(Path(file_path).name, open(file_path, "rb"), "application/octet-stream"),
)
]


def create_multipart_data(url, file_info):
# Prepare the multipart form data
file_name = file_info[0]
file_content = file_info[1]
multipart_data = MultipartEncoder(
fields={
"url": (None, url, "application/json"),
"file": (file_name, file_content, "application/octet-stream"),
}
)

return multipart_data


def create_progress_func(multipart_data):
# Progress callback function
progress = tqdm(
total=multipart_data.len,
unit="B",
unit_scale=True,
unit_divisor=1024,
colour="#1BBFD6",
desc="Uploading model",
)

return progress


def progress_callback(monitor, progress):
progress.update(monitor.bytes_read - progress.n)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ netspresso_trainer==0.2.2
PyGithub>=2.1.1
matplotlib>=3.7.4
aenum==3.1.15
requests-toolbelt>=1.0.0

0 comments on commit c7d57c3

Please sign in to comment.