Merge remote-tracking branch 'refs/remotes/origin/master'
This commit is contained in:
commit
bacb26b28a
|
@ -0,0 +1,95 @@
|
|||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from utils import check_path, AttributeCopier, creat_cla
|
||||
|
||||
class Infer(AttributeCopier):
|
||||
def __init__(self, source_obj) -> None:
|
||||
super().__init__(source_obj)
|
||||
|
||||
@check_path
|
||||
def inference_network(self):
|
||||
netrans = self.netrans
|
||||
quantized = self.quantize_type
|
||||
name = self.model_name
|
||||
# print(self.__dict__)
|
||||
|
||||
netrans += " dump"
|
||||
# 进入模型目录
|
||||
|
||||
# 定义类型和量化类型
|
||||
if quantized == 'float':
|
||||
type_ = 'float32'
|
||||
quantization_type = 'float32'
|
||||
elif quantized == 'uint8':
|
||||
quantization_type = 'asymmetric_affine'
|
||||
type_ = 'quantized'
|
||||
elif quantized == 'int8':
|
||||
quantization_type = 'dynamic_fixed_point-8'
|
||||
type_ = 'quantized'
|
||||
elif quantized == 'int16':
|
||||
quantization_type = 'dynamic_fixed_point-16'
|
||||
type_ = 'quantized'
|
||||
else:
|
||||
print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========")
|
||||
sys.exit(-1)
|
||||
|
||||
# 构建推理命令
|
||||
inf_path = './inf'
|
||||
cmd = f"{netrans} \
|
||||
--dtype {type_} \
|
||||
--batch-size 1 \
|
||||
--model-quantize {name}_{quantization_type}.quantize \
|
||||
--model {name}.json \
|
||||
--model-data {name}.data \
|
||||
--output-dir {inf_path} \
|
||||
--with-input-meta {name}_inputmeta.yml \
|
||||
--device CPU"
|
||||
|
||||
# 执行推理命令
|
||||
if self.verbose is True:
|
||||
print(cmd)
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# 检查执行结果
|
||||
if result.returncode == 0:
|
||||
print("\033[32m SUCCESS \033[0m")
|
||||
else:
|
||||
print(f"\033[31m ERROR: {result.stderr} \033[0m")
|
||||
|
||||
# 返回原始目录
|
||||
|
||||
def main():
|
||||
# 检查命令行参数数量
|
||||
if len(sys.argv) < 3:
|
||||
print("Input a network name and quantized type ( float / uint8 / int8 / int16 )")
|
||||
sys.exit(-1)
|
||||
|
||||
# 检查网络目录是否存在
|
||||
network_name = sys.argv[1]
|
||||
if not os.path.exists(network_name):
|
||||
print(f"Directory {network_name} does not exist !")
|
||||
sys.exit(-2)
|
||||
# print("here")
|
||||
# 定义 netrans 路径
|
||||
# netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
|
||||
network_name = sys.argv[1]
|
||||
# check_env(network_name)
|
||||
|
||||
netrans_path = os.environ['NETRANS_PATH']
|
||||
# netrans = os.path.join(netrans_path, 'pnnacc')
|
||||
quantize_type = sys.argv[2]
|
||||
cla = creat_cla(netrans_path, network_name,quantize_type,False)
|
||||
|
||||
# 调用量化函数
|
||||
func = Infer(cla)
|
||||
func.inference_network()
|
||||
|
||||
# 定义数据集文件路径
|
||||
# dataset_path = './dataset.txt'
|
||||
# 调用推理函数
|
||||
# inference_network(network_name, sys.argv[2])
|
||||
|
||||
if __name__ == '__main__':
|
||||
# print("main")
|
||||
main()
|
Loading…
Reference in New Issue