netrans/netrans_cli/example.py

68 lines
1.8 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import argparse
from netrans import Netrans
def main():
# 创建参数解析器
parser = argparse.ArgumentParser(
description='神经网络模型转换工具',
formatter_class=argparse.ArgumentDefaultsHelpFormatter # 自动显示默认值
)
# 必填位置参数
parser.add_argument(
'model_path',
type=str,
help='输入模型路径(必须参数)'
)
# 可选参数组
quant_group = parser.add_argument_group('量化参数')
quant_group.add_argument(
'-q', '--quantize_type',
type=str,
choices=['uint8', 'int8', 'int16', 'float'],
default='uint8',
metavar='TYPE',
help='量化类型(可选值:%(choices)s'
)
quant_group.add_argument(
'-m', '--mean',
type=int,
default=0,
help='归一化均值(默认:%(default)s'
)
quant_group.add_argument(
'-s', '--scale',
type=float,
default=1.0,
help='量化缩放系数(默认:%(default)s'
)
parser.add_argument(
'-p', '--profile',
action='store_true', # 设置为True当参数存在时
help='启用性能分析模式(默认:%(default)s'
)
# 解析参数
args = parser.parse_args()
# 执行模型转换
try:
model = Netrans(model_path=args.model_path)
model.model2nbg(
quantize_type=args.quantize_type,
mean=args.mean,
scale=args.scale,
profile=args.profile
)
print(f"模型 {args.model_path} 转换成功")
except FileNotFoundError:
print(f"错误:模型文件 {args.model_path} 不存在")
exit(1)
if __name__ == "__main__":
main()