# (File-viewer header residue: 226 lines, 6.8 KiB, Python)
import os
|
|
import sys
|
|
import subprocess
|
|
from utils import check_path, AttributeCopier, creat_cla
|
|
|
|
def check_status(result):
    """Print the outcome of a finished netrans command and report it.

    Args:
        result: the ``subprocess.CompletedProcess`` produced by
            ``subprocess.run(..., capture_output=True, text=True)``; its
            ``returncode`` and ``stderr`` are read.

    Returns:
        ``True`` if the command exited with status 0, otherwise ``False``.
        Returning the outcome lets callers react to a failed conversion;
        callers that ignore the return value behave exactly as before.
    """
    # NOTE(review): the success banner uses ANSI red (31), the same colour as
    # the error banner; green (32) may have been intended.
    if result.returncode == 0:
        print("\033[31m LOAD MODEL SUCCESS \033[0m")
        return True
    print(f"\033[31m ERROR: {result.stderr} \033[0m")
    return False
def import_caffe_network(name, netrans_path):
    """Convert a Caffe model with ``<netrans_path> import caffe``.

    Reads ``<name>.prototxt`` (and ``<name>.caffemodel`` when present) from the
    current directory and asks netrans to write ``<name>.json`` and
    ``<name>.data``. When ``<name>.caffemodel`` is missing, ``--weights`` is
    omitted and a "fake Caffe model data file" notice is printed.

    Args:
        name: model base name (file stem, resolved relative to the cwd).
        netrans_path: netrans command prefix; ``" import caffe"`` is appended.
    """
    convert_caffe = netrans_path + " import caffe"

    model_json_path = f"{name}.json"
    model_data_path = f"{name}.data"
    model_prototxt_path = f"{name}.prototxt"
    model_caffemodel_path = f"{name}.caffemodel"

    print(f"=========== Converting {name} Caffe model ===========")

    # Build the command once; only the optional --weights argument differs.
    cmd = f"{convert_caffe} --model {model_prototxt_path}"
    if os.path.isfile(model_caffemodel_path):
        cmd += f" --weights {model_caffemodel_path}"
    else:
        print("=========== fake Caffe model data file =============")
    cmd += f" --output-model {model_json_path} --output-data {model_data_path}"

    # Run and check the command the same way as the other importers, so a
    # non-zero exit status is reported instead of being silently discarded.
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    check_status(result)
def import_tensorflow_network(name, netrans_path):
    """Convert ``<name>.pb`` with ``<netrans_path> import tensorflow``.

    Extra command-line arguments for netrans are read from
    ``inputs_outputs.txt`` in the current directory and appended to the
    command. The process exits with status 1 if that file is missing.
    netrans is asked to write ``<name>.json`` and ``<name>.data``; the
    outcome is reported via ``check_status``.

    Args:
        name: model base name (file stem, resolved relative to the cwd).
        netrans_path: netrans command prefix; ``" import tensorflow"`` is appended.
    """
    convertf_cmd = f"{netrans_path} import tensorflow"

    print(f"=========== Converting {name} Tensorflow model ===========")

    # Join the file's non-blank lines with single spaces. The command runs
    # with shell=True, so an embedded newline would end it early and execute
    # the remaining lines as separate shell commands.
    try:
        with open('inputs_outputs.txt', 'r') as f:
            inputs_outputs_params = ' '.join(
                line.strip() for line in f if line.strip()
            )
    except FileNotFoundError:
        # Same handling as the pytorch importer's missing input_size.txt.
        print("Error: inputs_outputs.txt not found.")
        sys.exit(1)

    cmd = (
        f"{convertf_cmd}"
        f" --model {name}.pb"
        f" --output-data {name}.data"
        f" --output-model {name}.json"
        f" {inputs_outputs_params}"
    )

    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    check_status(result)
def import_onnx_network(name, netrans_path):
    """Convert ``<name>.onnx`` with ``<netrans_path> import onnx``.

    netrans is asked to write ``<name>.json`` and ``<name>.data``. If
    ``<name>_outputs.txt`` exists, its first line (whitespace-stripped) is
    forwarded as ``--outputs '<line>'``. The outcome is reported through
    ``check_status``.

    Args:
        name: model base name (file stem, resolved relative to the cwd).
        netrans_path: netrans command prefix; ``" import onnx"`` is appended.
    """
    print(f"=========== Converting {name} ONNX model ===========")

    pieces = [
        f"{netrans_path} import onnx",
        f"--model {name}.onnx",
        f"--output-model {name}.json",
        f"--output-data {name}.data",
    ]

    outputs_file = f"{name}_outputs.txt"
    if os.path.exists(outputs_file):
        # Only the first line of the file is used as the --outputs value.
        with open(os.path.join(os.getcwd(), outputs_file), 'r', encoding='utf-8') as handle:
            first_line = handle.readline().strip()
        pieces.append(f"--outputs '{first_line}'")

    completed = subprocess.run(" ".join(pieces), shell=True, capture_output=True, text=True)
    check_status(completed)
####### TFLITE
|
|
def import_tflite_network(name, netrans_path):
    """Convert ``<name>.tflite`` with ``<netrans_path> import tflite``.

    netrans is asked to write ``<name>.json`` and ``<name>.data``; the
    command runs through the shell and its outcome is reported by
    ``check_status``.

    Args:
        name: model base name (file stem, resolved relative to the cwd).
        netrans_path: netrans command prefix; ``" import tflite"`` is appended.
    """
    tflite_file, json_file, data_file = (
        f"{name}{suffix}" for suffix in (".tflite", ".json", ".data")
    )

    print(f"=========== Converting {name} TFLite model ===========")

    command = " ".join((
        f"{netrans_path} import tflite",
        "--model", tflite_file,
        "--output-model", json_file,
        "--output-data", data_file,
    ))

    completed = subprocess.run(command, shell=True, capture_output=True, text=True)
    check_status(completed)
def import_darknet_network(name, netrans_path):
    """Convert a Darknet model with ``<netrans_path> import darknet``.

    Passes ``<name>.cfg`` as the model and ``<name>.weights`` as the weights,
    and asks netrans to write ``<name>.json`` and ``<name>.data``. The outcome
    is reported by ``check_status``.

    Args:
        name: model base name (file stem, resolved relative to the cwd).
        netrans_path: netrans command prefix; ``" import darknet"`` is appended.
    """
    print(f"=========== Converting {name} darknet model ===========")

    # NOTE(review): the Caffe importer spells this flag "--weights"; confirm
    # that netrans accepts "--weight" here (e.g. via option-prefix matching).
    pieces = [
        f"{netrans_path} import darknet",
        "--model", f"{name}.cfg",
        "--weight", f"{name}.weights",
        "--output-model", f"{name}.json",
        "--output-data", f"{name}.data",
    ]

    completed = subprocess.run(" ".join(pieces), shell=True, capture_output=True, text=True)
    check_status(completed)
def import_pytorch_network(name, netrans_path):
    """Convert ``<name>.pt`` with ``<netrans_path> import pytorch``.

    Extra command-line arguments for netrans are read from ``input_size.txt``
    in the current directory and appended to the command. The process exits
    with status 1 if that file is missing. netrans is asked to write
    ``<name>.json`` and ``<name>.data``; the outcome is reported via
    ``check_status``.

    Args:
        name: model base name (file stem, resolved relative to the cwd).
        netrans_path: netrans command prefix; ``" import pytorch"`` is appended.
    """
    convert_pytorch_cmd = f"{netrans_path} import pytorch"

    print(f"=========== Converting {name} pytorch model ===========")

    # Join the file's non-blank lines with single spaces. The command runs
    # with shell=True, so an embedded newline (readlines() keeps them) would
    # end it early and execute the remaining lines as separate shell commands.
    try:
        with open('input_size.txt', 'r') as file:
            input_size_params = ' '.join(
                line.strip() for line in file if line.strip()
            )
    except FileNotFoundError:
        print("Error: input_size.txt not found.")
        sys.exit(1)

    cmd = (
        f"{convert_pytorch_cmd}"
        f" --model {name}.pt"
        f" --output-model {name}.json"
        f" --output-data {name}.data"
        f" {input_size_params}"
    )

    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    check_status(result)
# 使用示例
|
|
# import_tensorflow_network('model_name', '/path/to/NETRANS_PATH')
|
|
class ImportModel(AttributeCopier):
    """Choose and run the netrans importer that matches the model files on disk.

    Construction is delegated to ``AttributeCopier``. ``import_network`` reads
    ``self.verbose``, ``self.model_name`` and ``self.netrans``; presumably
    these are copied from ``source_obj`` by the base class (TODO confirm).
    """

    def __init__(self, source_obj) -> None:
        super().__init__(source_obj)

    @check_path
    def import_network(self):
        """Detect the model format in the current directory and convert it.

        Candidate files ``<model_name><ext>`` are probed in this order:
        .prototxt (Caffe), .pb (TensorFlow), .onnx, .tflite, .weights
        (Darknet), .pt (PyTorch). The first one found selects the importer.
        If none exists, a message is printed and the process exits with -3.
        """
        if self.verbose is True:
            print("begin load model")
            print(os.getcwd())
            print(f"{self.model_name}.weights")

        model_name = self.model_name
        netrans_cmd = self.netrans

        # Probe order matters: the first matching extension wins.
        probes = (
            (".prototxt", import_caffe_network),
            (".pb", import_tensorflow_network),
            (".onnx", import_onnx_network),
            (".tflite", import_tflite_network),
            (".weights", import_darknet_network),
            (".pt", import_pytorch_network),
        )
        for extension, importer in probes:
            if os.path.isfile(f"{model_name}{extension}"):
                importer(model_name, netrans_cmd)
                return

        print("=========== can not find suitable model files ===========")
        sys.exit(-3)
def main():
    """Command-line entry point: convert the network named by ``sys.argv[1]``.

    The netrans location is taken from the ``NETRANS_PATH`` environment
    variable. A configuration object is built with ``creat_cla`` (verbose
    off), and ``ImportModel(...).import_network()`` is run on it.

    Calls ``sys.exit(-1)`` when the argument count is wrong or when
    ``NETRANS_PATH`` is unset or empty.
    """
    if len(sys.argv) != 2:
        print("Input a network")
        sys.exit(-1)

    network_name = sys.argv[1]

    # Fail with a clear message rather than a KeyError traceback. An empty
    # value is rejected too: every importer prefixes its command with it.
    netrans_path = os.environ.get('NETRANS_PATH')
    if not netrans_path:
        print("Error: NETRANS_PATH environment variable is not set.")
        sys.exit(-1)

    clas = creat_cla(netrans_path, network_name, verbose=False)
    func = ImportModel(clas)
    func.import_network()


if __name__ == "__main__":
    main()