96 lines
2.9 KiB
Python
96 lines
2.9 KiB
Python
import os
|
|
import sys
|
|
import subprocess
|
|
from utils import check_path, AttributeCopier, creat_cla
|
|
|
|
class Infer(AttributeCopier):
    """Run a netrans ``dump`` inference pass for one model.

    Configuration is presumably copied from ``source_obj`` by
    ``AttributeCopier`` (defined in ``utils``). The attributes this class
    reads are ``netrans``, ``quantize_type``, ``model_name`` and ``verbose``.
    """

    # Maps the user-facing quantize type to a pair of
    # (value passed to --dtype, quantization name used in the
    # ``<model>_<quantization>.quantize`` file name).
    _QUANT_TYPES = {
        'float': ('float32', 'float32'),
        'uint8': ('quantized', 'asymmetric_affine'),
        'int8': ('quantized', 'dynamic_fixed_point-8'),
        'int16': ('quantized', 'dynamic_fixed_point-16'),
    }

    def __init__(self, source_obj) -> None:
        super().__init__(source_obj)

    @check_path
    def inference_network(self):
        """Build and run the ``<netrans> dump`` inference command.

        Reads ``self.netrans`` (base command string), ``self.quantize_type``
        (one of ``float`` / ``uint8`` / ``int8`` / ``int16``),
        ``self.model_name`` and ``self.verbose``. The command refers to
        ``<name>.json``, ``<name>.data``, ``<name>_<quantization>.quantize`` and
        ``<name>_inputmeta.yml`` by relative path, and writes its results to
        ``./inf``. All of these paths resolve against the current working
        directory.

        NOTE(review): the original comments mention entering the model
        directory and returning to the original directory afterwards;
        presumably the ``check_path`` decorator (from ``utils``) does that —
        confirm.

        Returns:
            True if the command exited with status 0, False otherwise (unless
            ``check_path`` replaces the return value).

        Raises:
            SystemExit: with status -1 when ``quantize_type`` is unsupported.
        """
        import shlex

        netrans = self.netrans
        quantized = self.quantize_type
        name = self.model_name

        # Resolve the dtype and quantization scheme for the requested type.
        try:
            type_, quantization_type = self._QUANT_TYPES[quantized]
        except KeyError:
            print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========")
            sys.exit(-1)

        # Build the inference command from an argument list. Each argument is
        # shell-quoted so that model names containing spaces or shell
        # metacharacters stay intact. ``netrans`` itself is left unquoted
        # because it is a command string that the shell must split.
        inf_path = './inf'
        args = [
            '--dtype', type_,
            '--batch-size', '1',
            '--model-quantize', f"{name}_{quantization_type}.quantize",
            '--model', f"{name}.json",
            '--model-data', f"{name}.data",
            '--output-dir', inf_path,
            '--with-input-meta', f"{name}_inputmeta.yml",
            '--device', 'CPU',
        ]
        cmd = " ".join([f"{netrans} dump"] + [shlex.quote(arg) for arg in args])

        if self.verbose:
            print(cmd)

        # errors='replace' keeps undecodable tool output from raising while
        # the output is collected.
        result = subprocess.run(
            cmd, shell=True, capture_output=True, text=True, errors='replace'
        )

        if result.returncode == 0:
            print("\033[32m SUCCESS \033[0m")
            return True

        # Some tools report failures on stdout only; fall back to it so the
        # user still sees a diagnostic.
        detail = result.stderr or result.stdout
        print(f"\033[31m ERROR: {detail} \033[0m")
        return False
|
|
|
|
def main():
    """Command-line entry point: ``<script> <network_dir> <quantize_type>``.

    Validates the arguments and reads the netrans installation path from the
    ``NETRANS_PATH`` environment variable. It then builds the configuration
    object with ``creat_cla`` and runs ``Infer.inference_network`` on it.

    Exit statuses:
        -1: fewer than two command-line arguments were given.
        -2: the network directory does not exist.
        -3: the ``NETRANS_PATH`` environment variable is not set (or empty).
    """
    if len(sys.argv) < 3:
        print("Input a network name and quantized type ( float / uint8 / int8 / int16 )")
        sys.exit(-1)

    network_name = sys.argv[1]
    quantize_type = sys.argv[2]

    # The first argument must name an existing network directory.
    if not os.path.exists(network_name):
        print(f"Directory {network_name} does not exist !")
        sys.exit(-2)

    # Report a missing environment variable clearly instead of crashing with
    # a KeyError traceback.
    netrans_path = os.environ.get('NETRANS_PATH')
    if not netrans_path:
        print("Environment variable NETRANS_PATH is not set !")
        sys.exit(-3)

    # NOTE(review): the trailing False is passed positionally to creat_cla;
    # presumably it is the verbose flag — confirm in utils.
    cla = creat_cla(netrans_path, network_name, quantize_type, False)

    # Run inference on the network.
    func = Infer(cla)
    func.inference_network()
|
|
|
|
if __name__ == '__main__':
    # Only run the CLI when executed as a script, never on import.
    main()
|