netrans/netrans_py/import_model.py

226 lines
6.8 KiB
Python

import os
import sys
import subprocess
from utils import check_path, AttributeCopier, creat_cla
def check_status(result):
    """Print whether a netrans import subprocess succeeded.

    Args:
        result: ``subprocess.CompletedProcess`` returned by ``subprocess.run``
            called with ``capture_output=True, text=True``.

    Returns:
        bool: ``True`` if the process exited with code 0, otherwise ``False``.
        Existing callers ignore the return value; it lets new callers react
        to a failed conversion.
    """
    if result.returncode == 0:
        print("\033[31m LOAD MODEL SUCCESS \033[0m")
        return True
    # Tools do not always write diagnostics to stderr; fall back to stdout and
    # finally to the exit code so the user never sees an empty "ERROR:" line.
    detail = (
        (result.stderr or "").strip()
        or (result.stdout or "").strip()
        or f"command exited with status {result.returncode}"
    )
    print(f"\033[31m ERROR: {detail} \033[0m")
    return False
def import_caffe_network(name, netrans_path):
    """Convert a Caffe model (``<name>.prototxt`` [+ ``<name>.caffemodel``]).

    Runs ``<netrans_path> import caffe`` in the current working directory and
    asks it to write ``<name>.json`` (``--output-model``) and ``<name>.data``
    (``--output-data``). When ``<name>.caffemodel`` is absent, the tool is run
    without ``--weights`` (the log message calls this a "fake" data file).
    The outcome is reported by ``check_status``.

    Args:
        name: Model base name without extension, relative to the CWD.
        netrans_path: Command prefix used to invoke the netrans tool.
    """
    model_json_path = f"{name}.json"
    model_data_path = f"{name}.data"
    model_prototxt_path = f"{name}.prototxt"
    model_caffemodel_path = f"{name}.caffemodel"

    print(f"=========== Converting {name} Caffe model ===========")

    cmd_parts = [f"{netrans_path} import caffe", f"--model {model_prototxt_path}"]
    if os.path.isfile(model_caffemodel_path):
        cmd_parts.append(f"--weights {model_caffemodel_path}")
    else:
        print("=========== fake Caffe model data file =============")
    cmd_parts.append(f"--output-model {model_json_path}")
    cmd_parts.append(f"--output-data {model_data_path}")

    # Use subprocess + check_status like the other importers: os.system's
    # return code was discarded, so a failed conversion went unreported.
    result = subprocess.run(" ".join(cmd_parts), shell=True, capture_output=True, text=True)
    check_status(result)
def import_tensorflow_network(name, netrans_path):
    """Convert a frozen TensorFlow graph ``<name>.pb``.

    Extra command-line options (e.g. input/output tensor names) are read from
    ``inputs_outputs.txt`` in the current working directory and appended
    verbatim to ``<netrans_path> import tensorflow``. The tool is asked to
    write ``<name>.data`` and ``<name>.json``; the outcome is reported by
    ``check_status``.

    Args:
        name: Model base name without extension, relative to the CWD.
        netrans_path: Command prefix used to invoke the netrans tool.

    Exits with status 1 if ``inputs_outputs.txt`` is missing.
    """
    print(f"=========== Converting {name} Tensorflow model ===========")
    try:
        with open('inputs_outputs.txt', 'r', encoding='utf-8') as f:
            # Join non-empty lines with spaces: the command runs with
            # shell=True, where a raw newline would terminate it and execute
            # the following line as a separate shell command.
            inputs_outputs_params = " ".join(
                line.strip() for line in f if line.strip()
            )
    except FileNotFoundError:
        # Same handling as the pytorch importer's missing input_size.txt.
        print("Error: inputs_outputs.txt not found.")
        sys.exit(1)

    cmd = " ".join([
        f"{netrans_path} import tensorflow",
        f"--model {name}.pb",
        f"--output-data {name}.data",
        f"--output-model {name}.json",
        inputs_outputs_params,
    ])
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    check_status(result)
def import_onnx_network(name, netrans_path):
    """Convert an ONNX model ``<name>.onnx``.

    If ``<name>_outputs.txt`` exists in the current working directory, its
    first line (stripped) is passed to the tool as ``--outputs '<line>'``.
    The tool is asked to write ``<name>.json`` and ``<name>.data``; the
    outcome is reported by ``check_status``.

    Args:
        name: Model base name without extension, relative to the CWD.
        netrans_path: Command prefix used to invoke the netrans tool.
    """
    print(f"=========== Converting {name} ONNX model ===========")

    pieces = [
        f"{netrans_path} import onnx",
        f"--model {name}.onnx",
        f"--output-model {name}.json",
        f"--output-data {name}.data",
    ]

    outputs_file = f"{name}_outputs.txt"
    if os.path.exists(outputs_file):
        with open(os.path.join(os.getcwd(), outputs_file), 'r', encoding='utf-8') as handle:
            selected_outputs = handle.readline().strip()
        pieces.append(f"--outputs '{selected_outputs}'")

    completed = subprocess.run(" ".join(pieces), shell=True, capture_output=True, text=True)
    check_status(completed)
####### TFLITE
def import_tflite_network(name, netrans_path):
    """Convert a TensorFlow Lite model ``<name>.tflite``.

    Runs ``<netrans_path> import tflite`` and asks it to write
    ``<name>.json`` and ``<name>.data``; the outcome is reported by
    ``check_status``.

    Args:
        name: Model base name without extension, relative to the CWD.
        netrans_path: Command prefix used to invoke the netrans tool.
    """
    print(f"=========== Converting {name} TFLite model ===========")
    command = " ".join((
        f"{netrans_path} import tflite",
        f"--model {name}.tflite",
        f"--output-model {name}.json",
        f"--output-data {name}.data",
    ))
    check_status(subprocess.run(command, shell=True, capture_output=True, text=True))
def import_darknet_network(name, netrans_path):
    """Convert a Darknet model (``<name>.cfg`` + ``<name>.weights``).

    Runs ``<netrans_path> import darknet`` and asks it to write
    ``<name>.json`` and ``<name>.data``; the outcome is reported by
    ``check_status``.

    Args:
        name: Model base name without extension, relative to the CWD.
        netrans_path: Command prefix used to invoke the netrans tool.
    """
    print(f"=========== Converting {name} darknet model ===========")
    arguments = [
        "--model", f"{name}.cfg",
        "--weight", f"{name}.weights",
        "--output-model", f"{name}.json",
        "--output-data", f"{name}.data",
    ]
    command = f"{netrans_path} import darknet " + " ".join(arguments)
    completed = subprocess.run(command, shell=True, capture_output=True, text=True)
    check_status(completed)
def import_pytorch_network(name, netrans_path):
    """Convert a TorchScript/PyTorch model ``<name>.pt``.

    Extra command-line options (e.g. input sizes) are read from
    ``input_size.txt`` in the current working directory and appended to
    ``<netrans_path> import pytorch``. The tool is asked to write
    ``<name>.json`` and ``<name>.data``; the outcome is reported by
    ``check_status``.

    Args:
        name: Model base name without extension, relative to the CWD.
        netrans_path: Command prefix used to invoke the netrans tool.

    Exits with status 1 if ``input_size.txt`` is missing.
    """
    print(f"=========== Converting {name} pytorch model ===========")
    try:
        with open('input_size.txt', 'r', encoding='utf-8') as file:
            # readlines() keeps each '\n'; with shell=True a newline inside the
            # command would end it and run the next line as its own command.
            # Strip every line and join the non-empty ones with spaces.
            input_size_params = " ".join(
                line.strip() for line in file if line.strip()
            )
    except FileNotFoundError:
        print("Error: input_size.txt not found.")
        sys.exit(1)

    cmd = " ".join([
        f"{netrans_path} import pytorch",
        f"--model {name}.pt",
        f"--output-model {name}.json",
        f"--output-data {name}.data",
        input_size_params,
    ])
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    check_status(result)
# Usage example:
# import_tensorflow_network('model_name', '/path/to/NETRANS_PATH')
class ImportModel(AttributeCopier):
    """Pick the framework importer matching a model's files and run it.

    ``import_network`` reads ``verbose``, ``model_name`` and ``netrans`` from
    ``self``; presumably ``AttributeCopier`` (from ``utils``) copies them from
    ``source_obj`` -- confirm against its definition.
    """

    def __init__(self, source_obj) -> None:
        super().__init__(source_obj)

    @check_path
    def import_network(self):
        """Import ``<model_name>.*`` found in the current working directory.

        Candidate files are probed in a fixed priority order and the importer
        for the first existing one is called with ``self.netrans``. If none
        exists, a message is printed and ``sys.exit(-3)`` is called. What the
        ``check_path`` decorator does is defined in ``utils`` (not visible
        here).
        """
        if self.verbose is True:
            print("begin load model")
            print(os.getcwd())
            print(f"{self.model_name}.weights")

        base = self.model_name
        tool = self.netrans

        # (extension, importer) pairs; the first match wins, so e.g. a
        # .prototxt takes precedence over a .pb with the same base name.
        candidates = (
            ("prototxt", import_caffe_network),
            ("pb", import_tensorflow_network),
            ("onnx", import_onnx_network),
            ("tflite", import_tflite_network),
            ("weights", import_darknet_network),
            ("pt", import_pytorch_network),
        )
        for extension, importer in candidates:
            if os.path.isfile(f"{base}.{extension}"):
                importer(base, tool)
                return

        print("=========== can not find suitable model files ===========")
        sys.exit(-3)
def main():
    """Command-line entry point: ``python import_model.py <network_name>``.

    Reads the netrans tool location from the ``NETRANS_PATH`` environment
    variable and imports the model named on the command line.

    Calls ``sys.exit(-1)`` on wrong usage and ``sys.exit(-2)`` when
    ``NETRANS_PATH`` is unset or empty.
    """
    if len(sys.argv) != 2:
        print("Input a network")
        sys.exit(-1)
    network_name = sys.argv[1]

    netrans_path = os.environ.get('NETRANS_PATH')
    if not netrans_path:
        # Previously os.environ[...] raised an uncaught KeyError traceback; an
        # empty value would also have produced a command with no tool prefix.
        print("Error: NETRANS_PATH environment variable is not set.")
        sys.exit(-2)

    clas = creat_cla(netrans_path, network_name, verbose=False)
    func = ImportModel(clas)
    func.import_network()


if __name__ == "__main__":
    main()