netrans/netrans_py/export.py

164 lines
6.1 KiB
Python

import os
import shlex
import shutil
import subprocess
import sys

from utils import check_path, AttributeCopier, create_cls
# Module-level name of the dataset file; it is not referenced by the code in
# this module itself (NOTE(review): confirm whether any importer still uses it).
dataset = "dataset.txt"
# Per quantization type: (value passed to ``--dtype``, name used in output
# paths and, for quantized types, in the ``<model>_<name>.quantize`` file).
_QUANT_SETTINGS = {
    'float': ('float', 'none_quantized'),
    'uint8': ('quantized', 'asymmetric_affine'),
    'int8': ('quantized', 'dynamic_fixed_point-8'),
    'int16': ('quantized', 'dynamic_fixed_point-16'),
}
# Value passed to the exporter's ``--optimize`` option.
_OPTIMIZE_TARGET = 'VIP8000NANOQI_PLUS_PID0XB1'
# Value passed to the exporter's ``--target-ide-project`` option.
_IDE_PROJECT = 'linux64'


class Export(AttributeCopier):
    """Export a float or quantized network through ``<netrans> export ovxlib``.

    The instance is expected to expose ``netrans``, ``quantize_type``,
    ``model_name``, ``netrans_path`` and ``profile``; presumably these are
    copied from ``source_obj`` by ``AttributeCopier`` (TODO confirm in utils).
    """

    def __init__(self, source_obj) -> None:
        super().__init__(source_obj)

    @check_path
    def export_network(self):
        """Run the ovxlib exporter and collect the resulting network binary.

        Input files are resolved relative to the current working directory:
        ``<name>.json`` and ``<name>.data`` always; for quantized types also
        ``<name>_<quantization>.quantize``, ``<name>_inputmeta.yml`` and, if
        present, ``<name>_postprocess_file.yml``.

        Output goes to ``./wksp/<quantization>``. In profile mode the whole
        ``<generate_path>_nbg_viplite`` directory is moved there; otherwise
        only its ``network_binary.nb`` is kept.

        Exits the process with status 1 on an unknown quantization type,
        a missing ``.quantize`` file, a failed or unlaunchable export, or a
        failed file operation.
        """
        settings = _QUANT_SETTINGS.get(self.quantize_type)
        if settings is None:
            print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========")
            sys.exit(1)
        type_, quantization_type = settings
        generate_path = f'./wksp/{quantization_type}'

        os.makedirs(generate_path, exist_ok=True)

        cmd = self._build_export_cmd(type_, quantization_type, generate_path)
        try:
            # Argument list instead of a shell string: no shell quoting or
            # injection issues with model names / paths.
            result = subprocess.run(cmd, capture_output=True, text=True)
        except OSError as e:
            print(f"\033[31m ERROR ! {e} \033[0m")
            sys.exit(1)

        if result.returncode != 0:
            print(f"\033[31m ERROR ! {result.stderr} \033[0m")
            # Stop here: continuing would move/copy a missing or stale
            # *_nbg_viplite directory left over from a previous run.
            sys.exit(1)
        print("\033[32m SUCCESS \033[0m")

        self._collect_binary(generate_path)

    def _build_export_cmd(self, type_, quantization_type, generate_path):
        """Return the exporter argument list for the selected quantization."""
        name = self.model_name
        sdk = f"{self.netrans_path}/pnna_sdk"
        # shlex.split keeps the whitespace splitting the old shell string had.
        cmd = shlex.split(self.netrans) + [
            'export', 'ovxlib',
            '--model', f'{name}.json',
            '--model-data', f'{name}.data',
            '--dtype', type_,
            '--pack-nbg-viplite',
            '--optimize', _OPTIMIZE_TARGET,
        ]

        if type_ == 'float':
            cmd += [
                '--target-ide-project', _IDE_PROJECT,
                '--viv-sdk', sdk,
                '--output-path', f'{generate_path}/{name}_{quantization_type}',
            ]
            return cmd

        quantize_file = f'{name}_{quantization_type}.quantize'
        if not os.path.exists(quantize_file):
            print(f"\033[31m Can not find {quantize_file} \033[0m")
            sys.exit(1)

        cmd += [
            '--viv-sdk', sdk,
            '--model-quantize', quantize_file,
            '--with-input-meta', f'{name}_inputmeta.yml',
            '--target-ide-project', _IDE_PROJECT,
        ]
        postprocess_file = f'{name}_postprocess_file.yml'
        if os.path.exists(postprocess_file):
            cmd += ['--postprocess-file', postprocess_file]
        cmd += ['--output-path', f'{generate_path}/{quantization_type}']
        return cmd

    def _collect_binary(self, generate_path):
        """Move or copy the exporter's NBG output into ``generate_path``."""
        source_dir = f"{generate_path}_nbg_viplite"

        if self.profile:
            # Profile mode keeps the whole NBG directory.
            try:
                # Remove an existing target first so the move cannot fail on it.
                if os.path.exists(generate_path):
                    shutil.rmtree(generate_path)
                shutil.move(source_dir, generate_path)
            except OSError as e:
                print(f"Error moving directory: {e}")
                sys.exit(1)
            return

        # Otherwise keep only network_binary.nb in a fresh generate_path.
        src_nb = f"{source_dir}/network_binary.nb"
        try:
            shutil.rmtree(generate_path)
            os.mkdir(generate_path)
            shutil.copy(src_nb, generate_path)
        except FileNotFoundError:
            print(f"Error: {src_nb} is not found")
            sys.exit(1)
        except OSError as e:
            print(f"Error occurred: {e}")
            sys.exit(1)

        try:
            shutil.rmtree(source_dir)
        except OSError as e:
            print(f"Error removing directory: {e}")
            sys.exit(1)
def main():
    """Command-line entry point: ``export.py <network_name> <quantize_type>``.

    Exits with status 1 on wrong usage or when NETRANS_PATH is not set, and
    with status 2 when ``<network_name>`` is not an existing directory.
    """
    if len(sys.argv) < 3:
        print("Input a network name and quantized type ( float / uint8 / int8 / int16 )")
        sys.exit(1)

    network_name = sys.argv[1]
    # The argument must name the model's directory.
    if not os.path.isdir(network_name):
        print(f"Directory {network_name} does not exist !")
        sys.exit(2)

    # Report a missing NETRANS_PATH clearly instead of a KeyError traceback.
    netrans_path = os.environ.get('NETRANS_PATH')
    if not netrans_path:
        print("NETRANS_PATH environment variable is not set !")
        sys.exit(1)

    cla = create_cls(netrans_path, network_name, sys.argv[2])
    Export(cla).export_network()
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()