68 lines
1.8 KiB
Python
Executable File
68 lines
1.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
|
||
import argparse
|
||
from netrans import Netrans
|
||
|
||
def main():
|
||
# 创建参数解析器
|
||
parser = argparse.ArgumentParser(
|
||
description='神经网络模型转换工具',
|
||
formatter_class=argparse.ArgumentDefaultsHelpFormatter # 自动显示默认值
|
||
)
|
||
|
||
# 必填位置参数
|
||
parser.add_argument(
|
||
'model_path',
|
||
type=str,
|
||
help='输入模型路径(必须参数)'
|
||
)
|
||
|
||
# 可选参数组
|
||
quant_group = parser.add_argument_group('量化参数')
|
||
quant_group.add_argument(
|
||
'-q', '--quantize_type',
|
||
type=str,
|
||
choices=['uint8', 'int8', 'int16', 'float'],
|
||
default='uint8',
|
||
metavar='TYPE',
|
||
help='量化类型(可选值:%(choices)s)'
|
||
)
|
||
quant_group.add_argument(
|
||
'-m', '--mean',
|
||
type=int,
|
||
default=0,
|
||
help='归一化均值(默认:%(default)s)'
|
||
)
|
||
quant_group.add_argument(
|
||
'-s', '--scale',
|
||
type=float,
|
||
default=1.0,
|
||
help='量化缩放系数(默认:%(default)s)'
|
||
)
|
||
parser.add_argument(
|
||
'-p', '--profile',
|
||
action='store_true', # 设置为True当参数存在时
|
||
help='启用性能分析模式(默认:%(default)s)'
|
||
)
|
||
|
||
|
||
# 解析参数
|
||
args = parser.parse_args()
|
||
|
||
# 执行模型转换
|
||
try:
|
||
model = Netrans(model_path=args.model_path)
|
||
model.model2nbg(
|
||
quantize_type=args.quantize_type,
|
||
mean=args.mean,
|
||
scale=args.scale,
|
||
profile=args.profile
|
||
)
|
||
print(f"模型 {args.model_path} 转换成功")
|
||
except FileNotFoundError:
|
||
print(f"错误:模型文件 {args.model_path} 不存在")
|
||
exit(1)
|
||
|
||
if __name__ == "__main__":
|
||
main()
|