diff --git a/netrans_py/dump.py b/netrans_py/dump.py new file mode 100644 index 0000000..5c42070 --- /dev/null +++ b/netrans_py/dump.py @@ -0,0 +1,95 @@ +import os +import sys +import subprocess +from utils import check_path, AttributeCopier, creat_cla + +class Infer(AttributeCopier): + def __init__(self, source_obj) -> None: + super().__init__(source_obj) + + @check_path + def inference_network(self): + netrans = self.netrans + quantized = self.quantize_type + name = self.model_name + # print(self.__dict__) + + netrans += " dump" + # 进入模型目录 + + # 定义类型和量化类型 + if quantized == 'float': + type_ = 'float32' + quantization_type = 'float32' + elif quantized == 'uint8': + quantization_type = 'asymmetric_affine' + type_ = 'quantized' + elif quantized == 'int8': + quantization_type = 'dynamic_fixed_point-8' + type_ = 'quantized' + elif quantized == 'int16': + quantization_type = 'dynamic_fixed_point-16' + type_ = 'quantized' + else: + print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========") + sys.exit(-1) + + # 构建推理命令 + inf_path = './inf' + cmd = f"{netrans} \ + --dtype {type_} \ + --batch-size 1 \ + --model-quantize {name}_{quantization_type}.quantize \ + --model {name}.json \ + --model-data {name}.data \ + --output-dir {inf_path} \ + --with-input-meta {name}_inputmeta.yml \ + --device CPU" + + # 执行推理命令 + if self.verbose is True: + print(cmd) + result = subprocess.run(cmd, shell=True, capture_output=True, text=True) + + # 检查执行结果 + if result.returncode == 0: + print("\033[32m SUCCESS \033[0m") + else: + print(f"\033[31m ERROR: {result.stderr} \033[0m") + + # 返回原始目录 + +def main(): + # 检查命令行参数数量 + if len(sys.argv) < 3: + print("Input a network name and quantized type ( float / uint8 / int8 / int16 )") + sys.exit(-1) + + # 检查网络目录是否存在 + network_name = sys.argv[1] + if not os.path.exists(network_name): + print(f"Directory {network_name} does not exist !") + sys.exit(-2) + # print("here") + # 定义 netrans 路径 + # netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc') + network_name = sys.argv[1] + # check_env(network_name) + + netrans_path = os.environ['NETRANS_PATH'] + # netrans = os.path.join(netrans_path, 'pnnacc') + quantize_type = sys.argv[2] + cla = creat_cla(netrans_path, network_name,quantize_type,False) + + # 调用量化函数 + func = Infer(cla) + func.inference_network() + + # 定义数据集文件路径 + # dataset_path = './dataset.txt' + # 调用推理函数 + # inference_network(network_name, sys.argv[2]) + +if __name__ == '__main__': + # print("main") + main()