164 lines
6.1 KiB
Python
164 lines
6.1 KiB
Python
import os
|
|
import sys
|
|
import subprocess
|
|
import shutil
|
|
from utils import check_path, AttributeCopier, create_cls
|
|
# NETRANS_PATH environment variable: no check is performed at module level; main() reads it.
|
|
|
|
# Path of the dataset list file.
dataset: str = "dataset.txt"
|
|
|
|
class Export(AttributeCopier):
    """Drive the ``export ovxlib`` step of the netrans tool for one model.

    The methods read ``netrans``, ``quantize_type``, ``model_name``,
    ``netrans_path`` and ``profile`` from ``self``.  These are presumably
    copied from ``source_obj`` by ``AttributeCopier`` (defined in ``utils``,
    not visible here -- confirm against that module).
    """

    # quantize_type -> (value passed to --dtype, quantization tag).  The tag
    # is used in input/output file names and as the ./wksp/<tag> directory.
    _QUANT_SETTINGS = {
        'float': ('float', 'none_quantized'),
        'uint8': ('quantized', 'asymmetric_affine'),
        'int8': ('quantized', 'dynamic_fixed_point-8'),
        'int16': ('quantized', 'dynamic_fixed_point-16'),
    }

    def __init__(self, source_obj) -> None:
        """Initialise the exporter from ``source_obj`` via the base class."""
        super().__init__(source_obj)

    @check_path
    def export_network(self):
        """Export the model and collect the result under ``./wksp/<tag>``.

        Steps:
          1. Resolve ``self.quantize_type`` to the ``--dtype`` value and the
             quantization tag.
          2. Create ``./wksp/<tag>`` and run ``<netrans> export ovxlib ...``
             through the shell.
          3. On success, post-process ``./wksp/<tag>_nbg_viplite``: when
             ``self.profile`` is truthy the whole directory replaces
             ``./wksp/<tag>``; otherwise only ``network_binary.nb`` is kept
             and the generated directory is removed.

        Raises:
            SystemExit: with code 1 when the quantize type is unknown, the
                ``.quantize`` file is missing, the export command fails, or
                the generated directory cannot be moved / removed.
        """
        settings = self._QUANT_SETTINGS.get(self.quantize_type)
        if settings is None:
            print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========")
            sys.exit(1)
        dtype, quantization_type = settings
        generate_path = f'./wksp/{quantization_type}'

        # Create the output directory before running the tool.
        os.makedirs(generate_path, exist_ok=True)

        cmd = self._build_command(dtype, quantization_type, generate_path)
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)

        if result.returncode != 0:
            print(f"\033[31m ERROR ! {result.stderr} \033[0m")
            # Stop here: post-processing a failed export would delete the
            # existing output directory and/or pick up stale artifacts.
            sys.exit(1)
        print("\033[31m SUCCESS \033[0m")

        self._collect_output(generate_path)

    def _build_command(self, dtype, quantization_type, generate_path):
        """Return the shell command line for ``<netrans> export ovxlib``.

        Exits with code 1 when a quantized export is requested but
        ``<model_name>_<tag>.quantize`` does not exist.
        """
        name = self.model_name
        parts = [
            f"{self.netrans} export ovxlib",
            f"--model {name}.json",
            f"--model-data {name}.data",
            f"--dtype {dtype}",
            "--pack-nbg-viplite",
            "--optimize 'VIP8000NANOQI_PLUS_PID0XB1'",
        ]

        if self.quantize_type == 'float':
            parts += [
                "--target-ide-project 'linux64'",
                f"--viv-sdk {self.netrans_path}/pnna_sdk",
                f"--output-path {generate_path}/{name}_{quantization_type}",
            ]
            return " ".join(parts)

        quantize_file = f"{name}_{quantization_type}.quantize"
        if not os.path.exists(quantize_file):
            print(f"\033[31m Can not find {name}_{quantization_type}.quantize \033[0m")
            sys.exit(1)

        parts += [
            f"--viv-sdk {self.netrans_path}/pnna_sdk",
            f"--model-quantize {quantize_file}",
            f"--with-input-meta {name}_inputmeta.yml",
            "--target-ide-project 'linux64'",
        ]
        # The post-process file is optional; pass it only when present.
        postprocess_file = f"{name}_postprocess_file.yml"
        if os.path.exists(postprocess_file):
            parts.append(f"--postprocess-file {postprocess_file}")
        parts.append(f"--output-path {generate_path}/{quantization_type}")
        return " ".join(parts)

    def _collect_output(self, generate_path):
        """Move or copy the generated NBG artifacts into ``generate_path``."""
        source_dir = f"{generate_path}_nbg_viplite"
        src_ngb = f"{source_dir}/network_binary.nb"

        if self.profile:
            # Keep the whole generated directory.  Verify it exists *before*
            # deleting the target so a missing source cannot destroy output.
            if not os.path.isdir(source_dir):
                print(f"Error moving directory: {source_dir} is not found")
                sys.exit(1)
            try:
                if os.path.exists(generate_path):
                    shutil.rmtree(generate_path)
                shutil.move(source_dir, generate_path)
            except OSError as e:
                print(f"Error moving directory: {e}")
                sys.exit(1)
            # source_dir has been moved away, so there is nothing to clean up.
            return

        # Keep only network_binary.nb in a freshly emptied output directory.
        try:
            shutil.rmtree(generate_path)
            os.mkdir(generate_path)
            shutil.copy(src_ngb, generate_path)
        except FileNotFoundError:
            print(f"Error: {src_ngb} is not found")
        except OSError as e:
            print(f"Error occurred: {e}")

        # Remove the generated directory now that the binary has been copied.
        try:
            shutil.rmtree(source_dir)
        except OSError as e:
            print(f"Error removing directory: {e}")
            sys.exit(1)
|
|
|
|
def main():
    """Command-line entry point: ``<script> <network_name> <quantize_type>``.

    Validates the arguments and environment, builds the configuration object
    with ``create_cls`` and runs ``Export.export_network`` on it.

    Exit codes:
        1 -- fewer than two arguments were given, or ``NETRANS_PATH`` is not
             set in the environment.
        2 -- ``network_name`` does not exist on disk.
    """
    # Both the network name and the quantize type are required.
    if len(sys.argv) < 3:
        print("Input a network name and quantized type ( float / uint8 / int8 / int16 )")
        sys.exit(1)

    network_name = sys.argv[1]
    quantize_type = sys.argv[2]

    # Check the path itself.  (Passing the boolean result of a nested
    # os.path.exists() call would be interpreted as a file descriptor and
    # the check would practically never fail.)
    if not os.path.exists(network_name):
        print(f"Directory {network_name} does not exist !")
        sys.exit(2)

    # Report a missing NETRANS_PATH explicitly instead of a KeyError traceback.
    netrans_path = os.environ.get('NETRANS_PATH')
    if not netrans_path:
        print("NETRANS_PATH environment variable is not set !")
        sys.exit(1)

    config = create_cls(netrans_path, network_name, quantize_type)
    exporter = Export(config)
    exporter.export_network()
|
|
|
|
|
|
# Run the exporter only when this file is executed as a script.
if __name__ == "__main__":
    main()
|