netrans python api

This commit is contained in:
xujiao 2025-04-07 11:29:05 +08:00
parent 50714682d5
commit 99c9a6fcde
12 changed files with 1208 additions and 0 deletions

171
netrans_py/README.md Normal file
View File

@ -0,0 +1,171 @@
# Python api netrans_py 使用介绍
netrans_py 支持通过 python api 灵活地将模型转换成pnna 支持的格式。
使用 netrans_py 完成模型转换的步骤如下:
1. 导入模型
2. 生成并修改前处理配置文件 *_inputmeta.yml
3. 量化模型
4. 导出模型
## 安装
在使用netrans_py之前需要安装netrans_py。
设置环境变量 NETRANS_PATH 并指向该 bin 目录。
<font color="#dd0000">注意:</font> 在该项目中,项目下载目录为 `/home/nudt_dps/netrans`,在您应用的过程中,可以使用 `pwd` 来确认您的项目目录。
```bash
export NETRANS_PATH=/home/nudt_dps/netrans/bin
```
同时设置LD_LIBRARY_PATH(Ubuntu其他系统根据具体情况设置)
```bash
export LD_LIBRARY_PATH=/home/nudt_dps/netrans/bin:$LD_LIBRARY_PATH
```
注意这一步每次使用前都需要执行,或者您可以写入 .bashrc (路径为 `~/.bashrc` )。
然后进入目录 netrans_py 进行安装。
```bash
cd /home/nudt_dps/netrans/netrans_py
pip3 install -e .
```
## netrans_py api
### Netrans 导入api及创建实例
创建 Netrans
描述: 实例化 Netrans 类。
代码示例:
```py3
from netrans import Netrans
yolo_netrans = Netrans("../examples/darknet/yolov4_tiny")
```
参数
| 参数名 | 类型 | 说明 |
|:---| -- | -- |
|model_path| str| 第一位置参数,模型文件的路径|
|netrans| str | 如果 NETRANS_PATH 没有设置,可通过该参数指定 netrans 的路径|
输出返回:
无。
<font color="#dd0000">注意:</font> 模型目录准备需要和netrans_cli一致具体数据准备要求见[introduction](./introduction.md)。
### Netrans.load_model 模型导入
描述: 将模型转换成 pnna 支持的格式。
代码示例:
```py3
yolo_netrans.load_model()
```
参数:
无。
输出返回:
无。
在工程目录下生成 pnna 支持的模型格式,以.json结尾的模型文件和 .data结尾的权重文件。
### Netrans.gen_inputmeta 预处理配置文件生成
描述: 生成预处理配置文件 *_inputmeta.yml。
代码示例:
```py3
yolo_netrans.gen_inputmeta()
```
参数:
无。
输出返回:
无。
### Netrans.quantize 量化模型
描述: 对模型生成量化配置文件。
代码示例:
```py3
yolo_netrans.quantize("uint8")
```
参数:
| 参数名 | 类型 | 说明 |
|:---| -- | -- |
|quantize_type| str| 第一位置参数,模型量化类型,仅支持 "uint8", "int8", "int16"|
输出返回:
无。
### Netrans.export 模型导出
描述: 导出模型,在 wksp 目录下生成 nbg 文件。
代码示例:
```py3
yolo_netrans.export()
```
参数:
无。
输出返回:
无。请在目录 “wksp/*/” 下检查是否生成nbg文件。
### Netrans.model2nbg 一键生成nbg文件
描述: 模型导入、量化及 nbg 文件生成
代码示例:
```py3
# 无预处理
yolo_netrans.model2nbg(quantize_type='uint8')
# 需要对数据进行 normalize, mean 为 128, scale 为 0.0039
yolo_netrans.model2nbg(quantize_type='uint8',mean=128, scale = 0.0039)
# 需要对数据分通道进行 normalize, mean 为 128,127,125, scale 为 0.0039, 且 reverse_channel 为 True
yolo_netrans.model2nbg(quantize_type='uint8', mean=[128, 127, 125], scale = 0.0039, reverse_channel= True)
# 已经进行初始化设置
yolo_netrans.model2nbg(quantize_type='uint8', inputmeta=True)
```
参数
| 参数名 | 类型 | 说明 |
|:---| -- | -- |
|quantize_type| str, ["uint8", "int8", "int16" ] | 量化类型,将模型量化成该参数指定的类型 |
|inputmeta| bool,str, [False, True, "inputmeta_filepath"] | 指定 inputmeta, 默认为False。 <br/> 如果为False则会生成inputmeta模板,可使用mean、scale、reverse_channel 配合修改常用参数。<br/>如果已有现成的 inputmeta 文件,则可通过该参数进行指定;也可使用True, 则会自动索引 model_name_inputmeta.yml |
|mean| float, int, list | 设置预处理中 normalize 的 mean 参数 |
|scale| float, int, list | 设置预处理中 normalize 的 scale 参数 |
|reverse_channel | bool | 设置预处理中的 reverse_channel 参数 |
<!-- |||| -->
输出返回:
请在目录 “wksp/*/” 下检查是否生成nbg文件。
## 使用实例
```
from netrans import Netrans
model_path = 'example/darknet/yolov4_tiny'
netrans_path = "netrans/bin" # 如果进行了export定义申明这一步可以不用
# 初始化netrans
net = Netrans(model_path,netrans=netrans_path)
# 模型载入
net.load_model()
# 生成 inputmeta 文件
net.gen_inputmeta()
# 配置预处理 normlize 的参数
net.config(scale=1,mean=0)
# 模型量化
net.quantize("uint8")
# 模型导出
net.export()
# 模型直接量化成 int16 并导出, 直接复用刚配置好的 inputmeta
net.model2nbg(quantize_type = "int16", inputmeta=True)
```

58
netrans_py/example.py Normal file
View File

@ -0,0 +1,58 @@
import argparse
from netrans import Netrans
def main():
    """Command-line entry: convert a model into an NBG file with Netrans.

    Positional argument: model_path. Options: -q/--quantize_type,
    -m/--mean, -s/--scale. Exits with status 1 when the model path
    does not exist.
    """
    # Build the argument parser.
    parser = argparse.ArgumentParser(
        description='神经网络模型转换工具',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter  # show defaults in --help
    )
    # Required positional argument: the model directory.
    parser.add_argument(
        'model_path',
        type=str,
        help='输入模型路径(必须参数)'
    )
    # Optional quantization parameters.
    quant_group = parser.add_argument_group('量化参数')
    quant_group.add_argument(
        '-q', '--quantize_type',
        type=str,
        choices=['uint8', 'int8', 'int16', 'float'],
        default='uint8',
        metavar='TYPE',
        help='量化类型(可选值:%(choices)s)'  # fixed: closing parenthesis was missing
    )
    quant_group.add_argument(
        '-m', '--mean',
        type=int,
        default=0,
        help='归一化均值(默认:%(default)s)'  # fixed: closing parenthesis was missing
    )
    quant_group.add_argument(
        '-s', '--scale',
        type=float,
        default=1.0,
        help='量化缩放系数(默认:%(default)s)'  # fixed: closing parenthesis was missing
    )
    # Parse the CLI arguments.
    args = parser.parse_args()
    # Run the conversion pipeline.
    try:
        model = Netrans(model_path=args.model_path)
        model.model2nbg(
            quantize_type=args.quantize_type,
            mean=args.mean,
            scale=args.scale
        )
        print(f"模型 {args.model_path} 转换成功")
    except FileNotFoundError:
        print(f"错误:模型文件 {args.model_path} 不存在")
        # Fixed: builtin exit() is meant for the interactive REPL; raise
        # SystemExit for a proper exit status without importing sys.
        raise SystemExit(1)

if __name__ == "__main__":
    main()

150
netrans_py/export.py Normal file
View File

@ -0,0 +1,150 @@
import os
import sys
import subprocess
import shutil
from utils import check_path, AttributeCopier, creat_cla
# Path of the calibration dataset list file.
# NOTE(review): `dataset` is never referenced in this module — confirm it
# is still needed before removing.
dataset = 'dataset.txt'
class Export(AttributeCopier):
    """Export an imported (and optionally quantized) model as an NBG file.

    The settings it needs (netrans, netrans_path, model_name,
    quantize_type) are copied from the source object by AttributeCopier.
    """

    def __init__(self, source_obj) -> None:
        super().__init__(source_obj)

    @check_path  # validates the netrans binary and chdirs into the model dir
    def export_network(self):
        """Run `pnnacc export ovxlib` and collect network_binary.nb.

        Output lands under ./wksp/<quantization_flavor>/. Exits with
        status 1 on an unknown quantize type or a missing .quantize file.
        """
        netrans = self.netrans
        quantized = self.quantize_type
        name = self.model_name
        netrans_path = self.netrans_path
        ovxgenerator = netrans + " export ovxlib"
        # Entering the model directory is handled by @check_path.
        # os.chdir(name)
        # Map the requested type to the pnnacc dtype / quantization flavor.
        if quantized == 'float':
            type_ = 'float'
            quantization_type = 'none_quantized'
            generate_path = './wksp/none_quantized'
        elif quantized == 'uint8':
            type_ = 'quantized'
            quantization_type = 'asymmetric_affine'
            generate_path = './wksp/asymmetric_affine'
        elif quantized == 'int8':
            type_ = 'quantized'
            quantization_type = 'dynamic_fixed_point-8'
            generate_path = './wksp/dynamic_fixed_point-8'
        elif quantized == 'int16':
            type_ = 'quantized'
            quantization_type = 'dynamic_fixed_point-16'
            generate_path = './wksp/dynamic_fixed_point-16'
        else:
            print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========")
            sys.exit(1)
        # Create the output directory.
        os.makedirs(generate_path, exist_ok=True)
        # Build the export command. Float export needs no .quantize file;
        # quantized export optionally adds a postprocess file.
        # NOTE(review): the float branch names its output
        # {name}_{quantization_type} while the quantized branches use only
        # {quantization_type} — confirm this asymmetry is intended.
        if quantized == 'float':
            cmd = f"{ovxgenerator} \
                --model {name}.json \
                --model-data {name}.data \
                --dtype {type_} \
                --pack-nbg-viplite \
                --optimize 'VIP8000NANOQI_PLUS_PID0XB1'\
                --target-ide-project 'linux64' \
                --viv-sdk {netrans_path}/pnna_sdk \
                --output-path {generate_path}/{name}_{quantization_type}"
        else:
            if not os.path.exists(f"{name}_{quantization_type}.quantize"):
                print(f"\033[31m Can not find {name}_{quantization_type}.quantize \033[0m")
                sys.exit(1)
            if not os.path.exists(f"{name}_postprocess_file.yml"):
                cmd = f"{ovxgenerator} \
                    --model {name}.json \
                    --model-data {name}.data \
                    --dtype {type_} \
                    --pack-nbg-viplite \
                    --optimize 'VIP8000NANOQI_PLUS_PID0XB1'\
                    --viv-sdk {netrans_path}/pnna_sdk \
                    --model-quantize {name}_{quantization_type}.quantize \
                    --with-input-meta {name}_inputmeta.yml \
                    --target-ide-project 'linux64' \
                    --output-path {generate_path}/{quantization_type}"
            else:
                cmd = f"{ovxgenerator} \
                    --model {name}.json \
                    --model-data {name}.data \
                    --dtype {type_} \
                    --pack-nbg-viplite \
                    --optimize 'VIP8000NANOQI_PLUS_PID0XB1'\
                    --viv-sdk {netrans_path}/pnna_sdk \
                    --model-quantize {name}_{quantization_type}.quantize \
                    --with-input-meta {name}_inputmeta.yml \
                    --target-ide-project 'linux64' \
                    --postprocess-file {name}_postprocess_file.yml \
                    --output-path {generate_path}/{quantization_type}"
        # Run the export command.
        # print(cmd)
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        # Report the outcome.
        # NOTE(review): SUCCESS is printed with the red escape \033[31m here
        # while infer.py uses green \033[32m — likely a copy-paste slip.
        if result.returncode == 0:
            print("\033[31m SUCCESS \033[0m")
        else:
            print(f"\033[31m ERROR ! {result.stderr} \033[0m")
        # temp='wksp/temp'
        # os.makedirs(temp, exist_ok=True)
        # Copy the generated network_binary.nb up into generate_path and
        # drop the intermediate *_nbg_viplite directory.
        src_ngb = f'{generate_path}_nbg_viplite/network_binary.nb'
        try :
            shutil.copy(src_ngb, generate_path)
        except FileNotFoundError:
            print(f"Error: {src_ngb} is not found")
        except Exception as e :
            print(f"a error occurred : {e}")
        try:
            shutil.rmtree(f"{generate_path}_nbg_viplite")
        except:  # NOTE(review): bare except exits silently — consider narrowing
            sys.exit()
        # try :
        #     shutil.move(temp, generate_path )
        # except:
        #     sys.exit()
        # Return to the original directory.
        # os.chdir('..')
def main():
    """CLI: export.py NETWORK QUANTIZE_TYPE — build an NBG for the model."""
    # Require a network name and a quantize type.
    if len(sys.argv) < 3:
        print("Input a network name and quantized type ( float / uint8 / int8 / int16 )")
        sys.exit(1)
    network_name = sys.argv[1]
    # Bug fix: this check previously called os.path.exists twice
    # (os.path.exists(os.path.exists(name))), so a missing directory was
    # never detected.
    if not os.path.exists(network_name):
        print(f"Directory {network_name} does not exist !")
        sys.exit(2)
    netrans_path = os.environ['NETRANS_PATH']
    # Build the settings object and run the export step.
    cla = creat_cla(netrans_path, network_name, sys.argv[2])
    func = Export(cla)
    func.export_network()

if __name__ == '__main__':
    main()

49
netrans_py/file_model.py Normal file
View File

@ -0,0 +1,49 @@
# Only the shared `extensions` record is this module's public API.
__all__ = ['extensions']
class model_extensions:
    """Read-only record of the file-name suffixes of a converted model."""

    def __init__(self, model, model_data, model_quantize, input_meta, output_meta):
        # Keep every suffix in one tuple; the properties below expose them.
        self._suffixes = (model, model_data, model_quantize, input_meta, output_meta)

    @property
    def model(self):
        """Topology file suffix."""
        return self._suffixes[0]

    @property
    def model_data(self):
        """Weight file suffix."""
        return self._suffixes[1]

    @property
    def model_quantize(self):
        """Quantization file suffix."""
        return self._suffixes[2]

    @property
    def input_meta(self):
        """Preprocess (inputmeta) file suffix."""
        return self._suffixes[3]

    @property
    def output_meta(self):
        """Postprocess file suffix."""
        return self._suffixes[4]
class file_model:
    """Thin holder exposing a model_extensions record via `.extensions`."""

    def __init__(self, extensions):
        self._ext = extensions

    @property
    def extensions(self):
        """The wrapped model_extensions record."""
        return self._ext
# Shared suffix table for all netrans artifacts:
#   topology '.json', weights '.data', quantization '.quantize',
#   preprocess meta '_inputmeta.yml', postprocess meta '.yml'.
x_extensions = model_extensions(
    '.json',
    '.data',
    '.quantize',
    '_inputmeta.yml',
    '.yml'
)
# Module-level singleton; consumers import `extensions` directly.
_file_model = file_model(x_extensions)
extensions = _file_model.extensions

View File

@ -0,0 +1,38 @@
import os
import sys
from utils import check_path, AttributeCopier, creat_cla
class InputmetaGen(AttributeCopier):
    """Generate the <model>_inputmeta.yml preprocess template via pnnacc."""

    def __init__(self, source_obj) -> None:
        super().__init__(source_obj)

    @check_path  # validates the netrans binary and chdirs into the model dir
    def inputmeta_gen(self):
        """Run `pnnacc generate inputmeta` for the model's .json topology."""
        # NOTE(review): despite its name this local holds the pnnacc
        # executable path (self.netrans), not the directory (self.netrans_path).
        netrans_path = self.netrans
        network_name = self.model_name
        # Runs inside the model directory (ensured by @check_path).
        # os.chdir(network_name)
        # check_env(network_name)
        # Writes <model>_inputmeta.yml next to the model files.
        os.system(f"{netrans_path} generate inputmeta --model {network_name}.json --separated-database")
        # os.chdir("..")
def main():
    """CLI: gen_inputmeta.py NETWORK — generate the inputmeta template."""
    # Exactly one argument (the network name) is required.
    if len(sys.argv) != 2:
        print("Enter a network name!")
        sys.exit(2)
    network_name = sys.argv[1]
    # Resolve the netrans installation from the environment.
    netrans_path = os.getenv('NETRANS_PATH')
    generator = InputmetaGen(creat_cla(netrans_path, network_name))
    generator.inputmeta_gen()

if __name__ == '__main__':
    main()

225
netrans_py/import_model.py Normal file
View File

@ -0,0 +1,225 @@
import os
import sys
import subprocess
from utils import check_path, AttributeCopier, creat_cla
def check_status(result):
    """Report the outcome of a completed model-conversion subprocess.

    result: a subprocess.CompletedProcess (only .returncode and .stderr
    are read).
    """
    if result.returncode == 0:
        # Bug fix: success used to be printed in red (\033[31m); use green
        # (\033[32m) like the other tools in this package.
        print("\033[32m LOAD MODEL SUCCESS \033[0m")
    else:
        print(f"\033[31m ERROR: {result.stderr} \033[0m")
def import_caffe_network(name, netrans_path):
    """Convert a Caffe model (<name>.prototxt [+ <name>.caffemodel]) into
    the pnna format (<name>.json + <name>.data).

    When no .caffemodel exists, the converter fabricates fake weights.
    """
    # pnnacc sub-command for Caffe import.
    convert_caffe = netrans_path + " import caffe"
    # Source/target file paths.
    model_json_path = f"{name}.json"
    model_data_path = f"{name}.data"
    model_prototxt_path = f"{name}.prototxt"
    model_caffemodel_path = f"{name}.caffemodel"
    print(f"=========== Converting {name} Caffe model ===========")
    # Build the conversion command (with or without real weights).
    if os.path.isfile(model_caffemodel_path):
        cmd = f"{convert_caffe} \
            --model {model_prototxt_path} \
            --weights {model_caffemodel_path} \
            --output-model {model_json_path} \
            --output-data {model_data_path}"
    else:
        print("=========== fake Caffe model data file =============")
        cmd = f"{convert_caffe} \
            --model {model_prototxt_path} \
            --output-model {model_json_path} \
            --output-data {model_data_path}"
    # Consistency fix: every other importer in this module runs its command
    # through subprocess.run and reports via check_status; this one used
    # os.system and never checked the exit code.
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    check_status(result)
def import_tensorflow_network(name, netrans_path):
    """Convert a TensorFlow frozen graph (<name>.pb) to pnna format.

    Extra converter flags (input/output tensors) are read verbatim from
    inputs_outputs.txt in the current directory.
    NOTE(review): unlike import_pytorch_network, a missing
    inputs_outputs.txt raises FileNotFoundError instead of a clean exit.
    """
    # pnnacc sub-command for TensorFlow import.
    convertf_cmd = f"{netrans_path} import tensorflow"
    print(f"=========== Converting {name} Tensorflow model ===========")
    # User-prepared CLI flags (e.g. --inputs/--outputs ...).
    with open('inputs_outputs.txt', 'r') as f:
        inputs_outputs_params = f.read().strip()
    # Build the conversion command.
    cmd = f"{convertf_cmd} \
        --model {name}.pb \
        --output-data {name}.data \
        --output-model {name}.json \
        {inputs_outputs_params}"
    # Run it and report success/failure.
    # print(cmd)
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    check_status(result)
def import_onnx_network(name, netrans_path):
    """Convert an ONNX model (<name>.onnx) to pnna format.

    If <name>_outputs.txt exists, its first line lists the output tensors
    passed to the converter through --outputs.
    """
    # pnnacc sub-command for ONNX import.
    convert_onnx_cmd = f"{netrans_path} import onnx"
    print(f"=========== Converting {name} ONNX model ===========")
    if os.path.exists(f"{name}_outputs.txt"):
        # Optional output-tensor override (first line of <name>_outputs.txt).
        output_path = os.path.join(os.getcwd(), name+"_outputs.txt")
        with open(output_path, 'r', encoding='utf-8') as file:
            outputs = str(file.readline().strip())
        cmd = f"{convert_onnx_cmd} \
            --model {name}.onnx \
            --output-model {name}.json \
            --output-data {name}.data \
            --outputs '{outputs}'"
    else:
        # Default conversion: let the converter infer the outputs.
        cmd = f"{convert_onnx_cmd} \
            --model {name}.onnx \
            --output-model {name}.json \
            --output-data {name}.data"
    # Run the conversion and report status.
    # print(cmd)
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    check_status(result)
# ------------------------------------------------------------- TFLite
def import_tflite_network(name, netrans_path):
    """Convert a TensorFlow Lite model (<name>.tflite) to pnna format."""
    # pnnacc sub-command for TFLite import.
    convert_tflite = f"{netrans_path} import tflite"
    # Source/target file paths.
    model_json_path = f"{name}.json"
    model_data_path = f"{name}.data"
    model_tflite_path = f"{name}.tflite"
    print(f"=========== Converting {name} TFLite model ===========")
    # Build the conversion command.
    cmd = f"{convert_tflite} \
        --model {model_tflite_path} \
        --output-model {model_json_path} \
        --output-data {model_data_path}"
    # Run it and report success/failure.
    # print(cmd)
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    check_status(result)
def import_darknet_network(name, netrans_path):
    """Convert a darknet model (<name>.cfg + <name>.weights) to pnna format."""
    # pnnacc sub-command for darknet import.
    convert_darknet_cmd = f"{netrans_path} import darknet"
    print(f"=========== Converting {name} darknet model ===========")
    # Topology comes from the .cfg file, weights from the .weights file.
    cmd = f"{convert_darknet_cmd} \
        --model {name}.cfg \
        --weight {name}.weights \
        --output-model {name}.json \
        --output-data {name}.data"
    # Run the conversion and report status.
    # print(cmd)
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    check_status(result)
def import_pytorch_network(name, netrans_path):
    """Convert a TorchScript model (<name>.pt) to pnna format.

    Extra converter flags (input sizes) come from input_size.txt; a
    missing file is a hard sys.exit(1).
    """
    # pnnacc sub-command for pytorch import.
    convert_pytorch_cmd = f"{netrans_path} import pytorch"
    print(f"=========== Converting {name} pytorch model ===========")
    # Read the input-size flags prepared by the user.
    try:
        with open('input_size.txt', 'r') as file:
            input_size_params = ' '.join(file.readlines())
            # NOTE(review): readlines() keeps trailing newlines, so the
            # joined string may embed '\n' inside the shell command —
            # confirm input_size.txt is single-line or that this is safe.
    except FileNotFoundError:
        print("Error: input_size.txt not found.")
        sys.exit(1)
    # Build the conversion command.
    cmd = f"{convert_pytorch_cmd} \
        --model {name}.pt \
        --output-model {name}.json \
        --output-data {name}.data \
        {input_size_params}"
    # Run the conversion and report status.
    # print(cmd)
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    check_status(result)
# Usage example:
# import_tensorflow_network('model_name', '/path/to/NETRANS_PATH')
class ImportModel(AttributeCopier):
    """Detect the source framework from the files on disk and import."""

    def __init__(self, source_obj) -> None:
        super().__init__(source_obj)
        # print(source_obj.__dict__)

    @check_path
    def import_network(self):
        """Dispatch to the importer whose model file exists; exit(-3) if none."""
        if self.verbose is True :
            print("begin load model")
            print(os.getcwd())
            print(f"{self.model_name}.weights")
        name = self.model_name
        netrans_path = self.netrans
        # Detection order matters: the first matching suffix wins.
        importers = (
            ('.prototxt', import_caffe_network),
            ('.pb', import_tensorflow_network),
            ('.onnx', import_onnx_network),
            ('.tflite', import_tflite_network),
            ('.weights', import_darknet_network),
            ('.pt', import_pytorch_network),
        )
        for suffix, importer in importers:
            if os.path.isfile(name + suffix):
                importer(name, netrans_path)
                return
        print("=========== can not find suitable model files ===========")
        sys.exit(-3)
def main():
    """CLI: import_model.py NETWORK — convert the model to pnna format."""
    # Exactly one argument (the network name) is required.
    if len(sys.argv) != 2 :
        print("Input a network")
        sys.exit(-1)
    network_name = sys.argv[1]
    netrans_path = os.environ['NETRANS_PATH']
    importer = ImportModel(creat_cla(netrans_path, network_name, verbose=False))
    importer.import_network()

if __name__ == "__main__":
    main()

95
netrans_py/infer.py Normal file
View File

@ -0,0 +1,95 @@
import os
import sys
import subprocess
from utils import check_path, AttributeCopier, creat_cla
class Infer(AttributeCopier):
    """Run pnnacc inference on a converted model for verification."""

    def __init__(self, source_obj) -> None:
        super().__init__(source_obj)

    @check_path  # validates the netrans binary and chdirs into the model dir
    def inference_network(self):
        """Invoke `pnnacc inference` on CPU; outputs are written to ./inf.

        Exits with -1 on an unknown quantize type.
        """
        netrans = self.netrans
        quantized = self.quantize_type
        name = self.model_name
        # print(self.__dict__)
        netrans += " inference"
        # Map the requested type to the pnnacc dtype / quantization flavor.
        if quantized == 'float':
            type_ = 'float32'
            quantization_type = 'float32'
        elif quantized == 'uint8':
            quantization_type = 'asymmetric_affine'
            type_ = 'quantized'
        elif quantized == 'int8':
            quantization_type = 'dynamic_fixed_point-8'
            type_ = 'quantized'
        elif quantized == 'int16':
            quantization_type = 'dynamic_fixed_point-16'
            type_ = 'quantized'
        else:
            print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========")
            sys.exit(-1)
        # Build the inference command.
        # NOTE(review): the float branch still passes
        # --model-quantize {name}_float32.quantize, which may not exist —
        # confirm pnnacc ignores it for float inference.
        inf_path = './inf'
        cmd = f"{netrans} \
            --dtype {type_} \
            --batch-size 1 \
            --model-quantize {name}_{quantization_type}.quantize \
            --model {name}.json \
            --model-data {name}.data \
            --output-dir {inf_path} \
            --with-input-meta {name}_inputmeta.yml \
            --device CPU"
        # Run the inference command (echo it when verbose).
        if self.verbose is True:
            print(cmd)
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        # Report the outcome: green on success, red on failure.
        if result.returncode == 0:
            print("\033[32m SUCCESS \033[0m")
        else:
            print(f"\033[31m ERROR: {result.stderr} \033[0m")
# 返回原始目录
def main():
    """CLI: infer.py NETWORK QUANTIZE_TYPE — run verification inference."""
    # Require a network name and a quantize type.
    if len(sys.argv) < 3:
        print("Input a network name and quantized type ( float / uint8 / int8 / int16 )")
        sys.exit(-1)
    # Cleanup: network_name was previously assigned twice from sys.argv[1].
    network_name = sys.argv[1]
    if not os.path.exists(network_name):
        print(f"Directory {network_name} does not exist !")
        sys.exit(-2)
    netrans_path = os.environ['NETRANS_PATH']
    quantize_type = sys.argv[2]
    cla = creat_cla(netrans_path, network_name, quantize_type, False)
    # Run the inference helper.
    func = Infer(cla)
    func.inference_network()

if __name__ == '__main__':
    main()

142
netrans_py/netrans.py Normal file
View File

@ -0,0 +1,142 @@
import sys, os
import subprocess
# import yaml
from ruamel.yaml import YAML
from ruamel import yaml
import file_model
from import_model import ImportModel
from quantize import Quantize
from export import Export
from gen_inputmeta import InputmetaGen
# from utils import check_path
import warnings
warnings.simplefilter('ignore', yaml.error.UnsafeLoaderWarning)
class Netrans():
    """High-level pipeline around pnnacc: import -> inputmeta -> quantize -> export.

    model_path: path of the model directory (the directory name is the
    model name). netrans: optional netrans bin directory; otherwise the
    NETRANS_PATH environment variable is used.
    """

    def __init__(self, model_path, netrans=None, verbose=False):
        self.verbose = verbose
        self.model_path = os.path.abspath(model_path)
        self.set_netrans(netrans)
        # The model name is the last component of the model directory.
        _, self.model_name = os.path.split(self.model_path)
        # self.model_name,_ = os.path.splitext(self.model_name)

    # ---- pipeline --------------------------------------------------------
    def model2nbg(self, quantize_type, inputmeta=False, **kargs):
        """One shot: load the model, configure preprocessing, quantize, export."""
        self.load_model()
        self.config(inputmeta, **kargs)
        self.quantize(quantize_type)
        self.export()

    # ---- netrans location ------------------------------------------------
    def get_os_netrans_path(self):
        """Return NETRANS_PATH from the environment (or None if unset)."""
        return os.environ.get('NETRANS_PATH')

    def check_netarans(self):
        """Probe the pnnacc executable; exit when it cannot be run."""
        res = subprocess.run([self.netrans], text=True)
        if res.returncode != 0:
            print("please check the netrans")  # fixed typo: 'pleace'
            sys.exit()

    def set_netrans(self, netrans_path=None):
        """Resolve and store the pnnacc location from an argument or NETRANS_PATH."""
        if netrans_path is not None:
            netrans_path = os.path.abspath(netrans_path)
        else:
            netrans_path = self.get_os_netrans_path()
        # Bug fix: when neither the argument nor the environment variable is
        # set, netrans_path is None and os.path.exists(None) raises TypeError.
        if netrans_path is not None and os.path.exists(netrans_path):
            self.netrans = os.path.join(netrans_path, 'pnnacc')
            self.netrans_path = netrans_path
        else:
            print('NETRANS_PATH NOT BEEN SETTED')

    # ---- preprocess configuration ---------------------------------------
    def config(self, inputmeta=False, **kargs):
        """Create or update the <model>_inputmeta.yml preprocess file.

        inputmeta: False -> generate a fresh template; True -> reuse the
        default-named file; str -> use that file path.
        Optional kargs: mean, scale, reverse_channel.
        """
        if isinstance(inputmeta, str):
            self.input_meta = inputmeta
        elif isinstance(inputmeta, bool):
            self.input_meta = os.path.join(self.model_path, '%s%s' % (self.model_name, file_model.extensions.input_meta))
            if inputmeta is False:
                self.inputmeta_gen()
        else:
            sys.exit("check inputmeta file")
        if len(kargs) == 0:
            return
        # Bug fix: the previous code indexed kargs['mean'] / kargs['scale'] /
        # kargs['reverse_channel'] directly and raised KeyError whenever the
        # caller supplied only a subset of them.
        mean = kargs.get('mean', 0)
        scale = kargs.get('scale', 1)
        reverse_channel = kargs.get('reverse_channel')
        # Nothing to change: identity normalize and no channel flip requested.
        # (Previously a reverse_channel-only request was silently dropped.)
        if mean == 0 and scale == 1 and reverse_channel is None:
            return
        with open(self.input_meta, 'r') as f:
            yaml = YAML()
            data = yaml.load(f)
        data = self.upload_cfg(data, **kargs)
        with open(self.input_meta, 'w') as f:
            yaml = YAML()
            yaml.dump(data, f)

    def upload_cfg(self, data, channel=3, **kargs):
        """Apply mean/scale/reverse_channel values onto a loaded inputmeta dict."""
        # Bug fix: this previously read from the undefined name `config`
        # (NameError); the loaded document is `data`.
        # NOTE(review): comparing preproc_node_params to the string
        # 'IMAGE_GRAY' — confirm against the inputmeta schema.
        grey = data['input_meta']['databases'][0]['ports'][0]['preprocess']['preproc_node_params'] == 'IMAGE_GRAY'
        if kargs.get('mean') is not None:
            mean = handel_param(kargs['mean'], grey)
            self.upload_cfg_mean(data, mean)
        if kargs.get('scale') is not None:
            scale = handel_param(kargs['scale'], grey)
            self.upload_cfg_scale(data, scale)
        if kargs.get('reverse_channel') is not None:
            if isinstance(kargs['reverse_channel'], bool):
                self.upload_cfg_reverse_channel(data, kargs['reverse_channel'])
        return data

    def upload_cfg_mean(self, data, mean):
        """Set preprocess.mean on the first port of every database."""
        for db in data['input_meta']['databases']:
            db['ports'][0]['preprocess']['mean'] = mean

    def upload_cfg_scale(self, data, scale):
        """Set preprocess.scale on the first port of every database."""
        for db in data['input_meta']['databases']:
            db['ports'][0]['preprocess']['scale'] = scale

    def upload_cfg_reverse_channel(self, data, reverse_channel):
        """Set preprocess.reverse_channel on the first port of every database."""
        for db in data['input_meta']['databases']:
            db['ports'][0]['preprocess']['reverse_channel'] = reverse_channel

    # ---- pipeline steps --------------------------------------------------
    def load_model(self):
        """Import the model into the pnna format (.json + .data)."""
        func = ImportModel(self)
        func.import_network()

    def inputmeta_gen(self):
        """Generate the inputmeta template for the model."""
        func = InputmetaGen(self)
        func.inputmeta_gen()

    def quantize(self, quantize_type):
        """Quantize the model ('uint8' / 'int8' / 'int16')."""
        self.quantize_type = quantize_type
        func = Quantize(self)
        func.quantize_network()

    def export(self, **kargs):
        """Export the NBG; an optional quantize_type kwarg overrides the stored one."""
        if kargs.get('quantize_type'):
            self.quantize_type = kargs['quantize_type']
        func = Export(self)
        func.export_network()
def handel_param(param, grey=False):
    """Normalize a preprocess parameter into per-channel form.

    Grey images keep the value untouched; otherwise a scalar is replicated
    into a 3-element list, while lists pass through unchanged.
    """
    if grey:
        return param
    if isinstance(param, list):
        return param
    return [param] * 3
# Ad-hoc smoke test: run the inputmeta generation step on a local model.
# Requires NETRANS_PATH to be set and the model directory to exist.
if __name__ == '__main__':
    network = '../../model_zoo/yolov4_tiny'
    yolo = Netrans(network)
    yolo.inputmeta_gen()
    # yolo.model2nb("uint8")
    # yolo.load_model()
    # yolo.config(mean=[0,0,0],scale=1)
    # yolo.quantize('uint8')
    # yolo.export()

93
netrans_py/quantize.py Normal file
View File

@ -0,0 +1,93 @@
import os
import sys
from utils import check_path, AttributeCopier, creat_cla
class Quantize(AttributeCopier):
    """Quantize an imported model with pnnacc (KL-divergence calibration)."""

    def __init__(self, source_obj) -> None:
        super().__init__(source_obj)

    @check_path  # validates the netrans binary and chdirs into the model dir
    def quantize_network(self):
        """Produce <name>_<flavor>.quantize for the configured quantize_type.

        'float' needs no quantization; unknown types print an error and
        return. Calibration iterates once per line of dataset.txt.
        """
        netrans = self.netrans
        quantized_type = self.quantize_type
        name = self.model_name
        netrans += " quantize"
        # Map the requested type to pnnacc's quantizer flavor.
        if quantized_type == 'float':
            print("=========== do not need quantized===========")
            return
        elif quantized_type == 'uint8':
            quantization_type = "asymmetric_affine"
        elif quantized_type == 'int8':
            quantization_type = "dynamic_fixed_point-8"
        elif quantized_type == 'int16':
            quantization_type = "dynamic_fixed_point-16"
        else:
            print("=========== wrong quantization_type ! ( uint8 / int8 / int16 )===========")
            return
        print(" =======================================================================")
        print(f" ==== Start Quantizing {name} model with type of {quantization_type} ===")
        print(" =======================================================================")
        # One calibration iteration per dataset entry.
        current_directory = os.getcwd()
        txt_path = current_directory + "/dataset.txt"
        with open(txt_path, 'r', encoding='utf-8') as file:
            num_lines = len(file.readlines())
        # Remove a stale .quantize file so pnnacc rebuilds it.
        quantize_file = f"{name}_{quantization_type}.quantize"
        if os.path.exists(quantize_file):
            print(f"\033[31m rm {quantize_file} \033[0m")
            os.remove(quantize_file)
        # Build and run the quantization command.
        cmd = f"{netrans} \
            --batch-size 1 \
            --qtype {quantized_type} \
            --rebuild \
            --quantizer {quantization_type.split('-')[0]} \
            --model-quantize {quantize_file} \
            --model {name}.json \
            --model-data {name}.data \
            --with-input-meta {name}_inputmeta.yml \
            --device CPU \
            --algorithm kl_divergence \
            --iterations {num_lines}"
        os.system(cmd)
        # Success iff the .quantize file now exists.
        if os.path.exists(quantize_file):
            # Bug fix: success used to be printed in red (\033[31m); use
            # green like infer.py.
            print("\033[32m QUANTIZED SUCCESS \033[0m")
        else:
            print("\033[31m ERROR ! \033[0m")
def main():
    """CLI: quantize.py NETWORK QUANTIZE_TYPE."""
    # Require a network name and a quantize type.
    if len(sys.argv) < 3:
        print("Input a network name and quantized type ( uint8 / int8 / int16 )")
        sys.exit(-1)
    network_name = sys.argv[1]
    netrans_path = os.environ['NETRANS_PATH']
    quantize_type = sys.argv[2]
    worker = Quantize(creat_cla(netrans_path, network_name, quantize_type))
    worker.quantize_network()

if __name__ == "__main__":
    main()

91
netrans_py/quantize_hb.py Normal file
View File

@ -0,0 +1,91 @@
import os
import sys
from utils import check_path, AttributeCopier, creat_cla
class Quantize(AttributeCopier):
    """Hybrid-quantization variant of quantize.py (pnnacc --hybrid)."""

    def __init__(self, source_obj) -> None:
        super().__init__(source_obj)

    @check_path  # validates the netrans binary and chdirs into the model dir
    def quantize_network(self):
        """Produce <name>_<flavor>.quantize via hybrid KL-divergence calibration.

        'float' needs no quantization; unknown types print an error and
        return. Calibration iterates once per line of dataset.txt.
        """
        netrans = self.netrans
        quantized_type = self.quantize_type
        name = self.model_name
        netrans += " quantize"
        # Map the requested type to pnnacc's quantizer flavor.
        if quantized_type == 'float':
            print("=========== do not need quantized===========")
            return
        elif quantized_type == 'uint8':
            quantization_type = "asymmetric_affine"
        elif quantized_type == 'int8':
            quantization_type = "dynamic_fixed_point-8"
        elif quantized_type == 'int16':
            quantization_type = "dynamic_fixed_point-16"
        else:
            print("=========== wrong quantization_type ! ( uint8 / int8 / int16 )===========")
            return
        print(" =======================================================================")
        print(f" ==== Start Quantizing {name} model with type of {quantization_type} ===")
        print(" =======================================================================")
        # NOTE(review): unlike quantize.py, a stale .quantize file is NOT
        # removed here — confirm whether --hybrid is meant to reuse it.
        quantize_file = f"{name}_{quantization_type}.quantize"
        # One calibration iteration per dataset entry.
        current_directory = os.getcwd()
        txt_path = current_directory + "/dataset.txt"
        with open(txt_path, 'r', encoding='utf-8') as file:
            num_lines = len(file.readlines())
        # Build and run the hybrid quantization command.
        cmd = f"{netrans} \
            --qtype {quantized_type} \
            --hybrid \
            --quantizer {quantization_type.split('-')[0]} \
            --model-quantize {quantize_file} \
            --model {name}.json \
            --model-data {name}.data \
            --with-input-meta {name}_inputmeta.yml \
            --device CPU \
            --algorithm kl_divergence \
            --divergence-nbins 2048 \
            --iterations {num_lines}"
        os.system(cmd)
        # Success iff the .quantize file now exists.
        if os.path.exists(quantize_file):
            # Bug fix: success used to be printed in red (\033[31m); use
            # green like infer.py.
            print("\033[32m QUANTIZED SUCCESS \033[0m")
        else:
            print("\033[31m ERROR ! \033[0m")
def main():
    """CLI: quantize_hb.py NETWORK QUANTIZE_TYPE (hybrid quantization)."""
    # Require a network name and a quantize type.
    if len(sys.argv) < 3:
        print("Input a network name and quantized type ( uint8 / int8 / int16 )")
        sys.exit(-1)
    network_name = sys.argv[1]
    netrans_path = os.environ['NETRANS_PATH']
    quantize_type = sys.argv[2]
    job = Quantize(creat_cla(netrans_path, network_name, quantize_type))
    job.quantize_network()

if __name__ == "__main__":
    main()

16
netrans_py/setup.py Normal file
View File

@ -0,0 +1,16 @@
from setuptools import setup, find_packages

# The long description for the package index comes straight from the README.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setup(
    name="netrans",
    version="0.1.0",
    author="nudt_dsp",
    url="https://gitlink.org.cn/gwg_xujiao/netrans",
    # Bug fix: long_description was read above but never passed to setup().
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=find_packages(include=["netrans_py"]),
    package_dir={"": "."},  # map the package root onto the repository root
    install_requires=[
        "ruamel.yaml==0.18.6"
    ]
)

80
netrans_py/utils.py Normal file
View File

@ -0,0 +1,80 @@
import sys
import os
# from functools import wraps
# def check_path(netrans, model_path):
# def decorator(func):
# @wraps(func)
# def wrapper(netrans, model_path, *args, **kargs):
# check_dir(model_path)
# check_netrans(netrans)
# if os.getcwd() != model_path :
# os.chdir(model_path)
# return func(netrans, model_path, *args, **kargs)
# return wrapper
# return decorator
def check_path(func):
    """Decorator: validate the netrans binary and cd into the model
    directory before running a tool method.

    Expects the bound object to expose `netrans` and `model_path`.
    """
    from functools import wraps  # local import keeps the module surface unchanged

    @wraps(func)  # fix: preserve the wrapped function's name and docstring
    def wrapper(cla, *args, **kargs):
        check_netrans(cla.netrans)
        # Only chdir when we are not already inside the model directory.
        if os.getcwd() != cla.model_path:
            os.chdir(cla.model_path)
        return func(cla, *args, **kargs)
    return wrapper
def check_dir(network_name):
    """Enter the model directory `network_name`; exit(-1) if it is missing."""
    if os.path.exists(network_name):
        os.chdir(network_name)
        return
    print(f"Directory {network_name} does not exist !")
    sys.exit(-1)
def check_netrans(netrans):
    """Exit with status 1 unless the given pnnacc path is usable.

    NOTE(review): when NETRANS_PATH is absent from the environment this
    returns immediately (treated as "configured explicitly elsewhere") —
    confirm this guard is not inverted.
    """
    if 'NETRANS_PATH' not in os.environ:
        return
    # Idiom fix: `is not None` instead of `!= None`.
    if netrans is not None and os.path.exists(netrans):
        return
    # Fixed typo in the message: 'enviroment' -> 'environment'.
    print("Need to set environment variable NETRANS_PATH")
    sys.exit(1)
def remove_history_file(name):
    """Delete stale converted files (<name>.json / <name>.data) inside the
    model directory `name`, then return to the parent directory."""
    os.chdir(name)
    for ext in ('.json', '.data'):
        stale = f"{name}{ext}"
        if os.path.isfile(stale):
            os.remove(stale)
    os.chdir('..')
def check_env(name):
    """Legacy helper: validate and enter the model directory `name`.

    The netrans and stale-file checks below are currently disabled.
    """
    check_dir(name)
    # check_netrans()
    # remove_history_file(name)
class AttributeCopier:
    """Base class that clones every instance attribute of a source object.

    The tool classes (ImportModel, Quantize, Export, ...) inherit from it
    to pick up the settings held by a Netrans or creat_cla object.
    """

    def __init__(self, source_obj) -> None:
        self.copy_attribute_name(source_obj)

    def copy_attribute_name(self, source_obj):
        """Copy each listed attribute from `source_obj` onto `self`."""
        for attr in self._get_attribute_names(source_obj):
            value = getattr(source_obj, attr)
            setattr(self, attr, value)

    @staticmethod
    def _get_attribute_names(source_obj):
        """Names of the instance attributes to copy."""
        return source_obj.__dict__.keys()
class creat_cla():
    """Lightweight settings object for the CLI entry points.

    Mirrors the attributes Netrans exposes (netrans_path, netrans,
    model_name, model_path, verbose, quantize_type) so the tool classes
    can copy them via AttributeCopier.
    """

    def __init__(self, netrans_path, name, quantized_type='uint8', verbose=False) -> None:
        self.netrans_path = netrans_path
        self.netrans = os.path.join(netrans_path, 'pnnacc')
        # NOTE(review): model_name keeps the raw `name` argument (it may
        # still contain path separators) while model_path is made absolute.
        self.model_name = name
        self.model_path = os.path.abspath(name)
        self.verbose = verbose
        self.quantize_type = quantized_type
if __name__ == "__main__":
dir_name = "yolo"
os.mkdir(dir_name)
check_dir(dir_name)