-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feat] Support saving and loading models in different formats #3758
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
README_CN.md | ||
README_CN.md | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
import os | ||
|
||
import numpy as np | ||
import paddle | ||
from paddle.inference import create_predictor | ||
from paddle.inference import Config as PredictConfig | ||
|
||
|
@@ -29,22 +30,20 @@ | |
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description='Test') | ||
parser.add_argument( | ||
"--config", | ||
help="The deploy config generated by exporting model.", | ||
type=str, | ||
required=True) | ||
parser.add_argument("--config", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. formatter自动更新,与主题无关。 |
||
help="The deploy config generated by exporting model.", | ||
type=str, | ||
required=True) | ||
parser.add_argument( | ||
'--image_path', | ||
help='The directory or path or file list of the images to be predicted.', | ||
type=str, | ||
required=True) | ||
|
||
parser.add_argument( | ||
'--dynamic_shape_path', | ||
type=str, | ||
default="./dynamic_shape.pbtxt", | ||
help='The path to save dynamic shape.') | ||
parser.add_argument('--dynamic_shape_path', | ||
type=str, | ||
default="./dynamic_shape.pbtxt", | ||
help='The path to save dynamic shape.') | ||
|
||
return parser.parse_args() | ||
|
||
|
@@ -62,7 +61,10 @@ def collect_dynamic_shape(args): | |
|
||
# prepare config | ||
cfg = DeployConfig(args.config) | ||
pred_cfg = PredictConfig(cfg.model, cfg.params) | ||
if paddle.__version__.split('.')[0] == '2': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 兼容paddle 2.x |
||
pred_cfg = PredictConfig(cfg.model, cfg.params) | ||
else: | ||
pred_cfg = PredictConfig(cfg.model_dir, cfg.model_prefix) | ||
pred_cfg.enable_use_gpu(1000, 0) | ||
pred_cfg.collect_shape_range_info(args.dynamic_shape_path) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,25 +21,45 @@ | |
|
||
|
||
class DeployConfig: | ||
|
||
def __init__(self, path): | ||
with codecs.open(path, 'r', 'utf-8') as file: | ||
self.dic = yaml.load(file, Loader=yaml.FullLoader) | ||
|
||
self._transforms = self.load_transforms(self.dic['Deploy'][ | ||
'transforms']) | ||
self._transforms = self.load_transforms( | ||
self.dic['Deploy']['transforms']) | ||
self._dir = os.path.dirname(path) | ||
self._is_old_format = 'model_prefix' not in self.dic['Deploy'] | ||
|
||
@property | ||
def transforms(self): | ||
return self._transforms | ||
|
||
@property | ||
def model(self): | ||
return os.path.join(self._dir, self.dic['Deploy']['model']) | ||
if self._is_old_format: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 兼容旧版本导出模型格式。 |
||
return os.path.join(self._dir, self.dic['Deploy']['model']) | ||
else: | ||
return os.path.join(self._dir, | ||
self.dic['Deploy']['model_prefix'] + '.pdmodel') | ||
|
||
@property | ||
def params(self): | ||
return os.path.join(self._dir, self.dic['Deploy']['params']) | ||
if self._is_old_format: | ||
return os.path.join(self._dir, self.dic['Deploy']['params']) | ||
else: | ||
return os.path.join( | ||
self._dir, self.dic['Deploy']['model_prefix'] + '.pdiparams') | ||
|
||
@property | ||
def model_dir(self): | ||
return self._dir | ||
|
||
@property | ||
def model_prefix(self): | ||
if self._is_old_format: | ||
return self.dic['Deploy']['model'][:-8] | ||
return self.dic['Deploy']['model_prefix'] | ||
|
||
@staticmethod | ||
def load_transforms(t_list): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -203,7 +203,7 @@ def forward_tokens(self, x): | |
def forward(self, x): | ||
x = self.patch_embed(x) | ||
x = self.forward_tokens(x) | ||
if self.mode is not 'multi_scale': | ||
if self.mode != 'multi_scale': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 顺带修复bug,与PR主题无关。 |
||
x = [ | ||
paddle.concat( | ||
[ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
顺带修复bug,与PR主题无关。