netrans/bin/tools/onnx_golden_dump2.py

151 lines
5.5 KiB
Python

#!/usr/bin/env python
# coding: utf-8
import os
import onnx
from onnx import helper
from onnx import TensorProto as OnnxType
import onnxruntime as ort
import numpy as np
from argparse import ArgumentParser
import tensorflow
# Pick the TensorFlow namespace bound to ``tf`` from the installed major
# version: on 2.x the ``tensorflow.compat.v1`` module is used, otherwise the
# package root is used directly.
main_version = tensorflow.__version__.split('.')[0]
if int(main_version) == 2:
    import tensorflow.compat.v1 as tf
else:
    import tensorflow as tf
# Maps an ONNX TensorProto element type to the NumPy dtype that input data is
# converted to before being fed to the model.
_onnx_to_np_dtype_mapping = {
    # NOTE(review): NumPy has no bfloat16 dtype, so float16 is used as a
    # stand-in here; the two formats differ in range/precision - confirm
    # this is acceptable for bfloat16 inputs.
    OnnxType.BFLOAT16: np.float16,
    OnnxType.FLOAT16: np.float16,
    OnnxType.FLOAT: np.float32,
    OnnxType.DOUBLE: np.float64,
    OnnxType.BOOL: np.bool_,
    OnnxType.INT8: np.int8,
    OnnxType.INT16: np.int16,
    OnnxType.INT32: np.int32,
    OnnxType.INT64: np.int64,
    OnnxType.UINT8: np.uint8,
    OnnxType.UINT16: np.uint16,
    OnnxType.UINT32: np.uint32,
    OnnxType.UINT64: np.uint64,
}
def get_dtype(t_info):
    """Return the element type recorded on a tensor value-info entry.

    Args:
        t_info: Object exposing ``type.tensor_type.elem_type`` (in this file,
            an entry of the ONNX graph's ``input`` list).

    Returns:
        The ``elem_type`` value stored on the entry's tensor type.
    """
    return t_info.type.tensor_type.elem_type
def get_shape(t_info):
    """Return the shape of a tensor value-info entry as a list of ints.

    Args:
        t_info: Object exposing ``type.tensor_type.shape.dim``, a sequence
            whose items each carry a ``dim_value``.

    Returns:
        A new list with ``int(dim_value)`` for every dimension, in order.
        NOTE(review): symbolic/dynamic dimensions presumably show up as 0
        here because only ``dim_value`` is read - confirm against the models
        this tool is used with.
    """
    return [int(d.dim_value) for d in t_info.type.tensor_type.shape.dim]
def load_image(image, onnx_dtype, shape):
    """Load the data for one model input from a ``.npy`` or image file.

    Args:
        image: Path of the input file. The (lower-cased) extension selects
            the loader: ``.npy`` or one of ``.jpg/.jpeg/.png/.bmp/.gif``.
        onnx_dtype: ONNX element type of the target input; used as a key
            into ``_onnx_to_np_dtype_mapping``.
        shape: Shape of the target input. For image files it is read as
            NCHW: ``shape[2]`` is the height and ``shape[3]`` the width the
            image is resized to.

    Returns:
        A NumPy array. ``.npy`` data is returned as stored, cast to the
        mapped dtype. Image data is returned as float32 in NCHW layout with
        a leading batch axis of 1.

    Raises:
        KeyError: ``onnx_dtype`` has no entry in the dtype mapping.
        NotImplementedError: The file extension is not supported, or an
            image file is given for an input whose dtype is not float32.
    """
    extension = os.path.splitext(image)[1].lower()

    if extension == '.npy':
        return np.load(image).astype(_onnx_to_np_dtype_mapping[onnx_dtype])

    if extension in ('.jpg', '.jpeg', '.png', '.bmp', '.gif'):
        np_dtype = _onnx_to_np_dtype_mapping[onnx_dtype]
        # Explicit check instead of ``assert`` so it still runs under -O,
        # and done before the (comparatively costly) decode.
        if np_dtype != np.float32:
            raise NotImplementedError(
                "Image input is only supported for float32 tensors, "
                "got ONNX dtype {}".format(onnx_dtype))

        height, width = shape[2], shape[3]  # shape is under NCHW layout

        # Decode the image file; the result is under HWC layout.
        img = tf.io.decode_image(tf.io.read_file(image)).numpy()
        if img.ndim == 4:
            # GIF files decode to (frames, H, W, C); keep the first frame so
            # the HWC handling below applies.
            img = img[0]
        if img.shape[0] != height or img.shape[1] != width:
            img = tf.image.resize(img, (height, width)).numpy()

        nhwc = np.expand_dims(img, axis=0).astype(np.float32)  # HWC -> NHWC
        nchw = np.transpose(nhwc, [0, 3, 1, 2])  # NHWC -> NCHW
        return nchw.astype(np_dtype)

    raise NotImplementedError("DO NOT support input *{}".format(os.path.splitext(image)[-1]))
def dump(model_path, input_image, detect_tensor, output_dir):
    """Run an ONNX model and write inputs and selected tensors to text files.

    The model is loaded, the requested tensors are appended to the graph
    outputs, and the modified model is saved as ``temp.onnx`` inside
    ``output_dir`` and executed with onnxruntime. Every fed input is written
    to ``<name>.tensor`` and every graph output to
    ``<name>_shape_<d0>_<d1>....tensor`` (``/``, ``:`` and ``@`` in output
    names replaced by ``_``), one value per line.

    Args:
        model_path: Path of the ONNX model file.
        input_image: Sequence of input file paths, consumed in the order of
            the graph inputs that are not initializers.
        detect_tensor: Names of tensors to expose as extra outputs. When
            empty or ``None``, the outputs of every node are used. The
            sequence itself is not modified.
        output_dir: Existing directory that receives ``temp.onnx`` and the
            ``.tensor`` files.

    Raises:
        IndexError: ``input_image`` has fewer entries than the model has
            non-initializer inputs.
    """
    temp_model = onnx.load(model_path)
    graph = temp_model.graph

    # Work on a private copy so the caller's list is never mutated.
    tensor_names = list(detect_tensor) if detect_tensor else []
    if not tensor_names:
        # Detect the output tensors of all nodes.
        tensor_names = [name for node in graph.node for name in node.output]

    # Insert the requested tensors as additional graph outputs. Empty names
    # (used by nodes for omitted optional outputs) and names that are
    # already graph outputs are skipped so no output is declared twice.
    exposed = {out.name for out in graph.output}
    for name in tensor_names:
        if not name or name in exposed:
            continue
        exposed.add(name)
        value_info = helper.ValueInfoProto()
        value_info.name = name
        graph.output.append(value_info)

    temp_model_path = os.path.join(output_dir, 'temp.onnx')
    onnx.save(temp_model, temp_model_path)

    # Load input data and build the feed dict. Graph inputs that are backed
    # by an initializer are not fed.
    initializer_names = {initializer.name for initializer in graph.initializer}
    input_feed = {}
    image_index = 0
    for in_info in graph.input:
        if in_info.name in initializer_names:
            continue
        if image_index >= len(input_image):
            raise IndexError(
                "No input file given for model input '{}' (got {} file(s))".format(
                    in_info.name, len(input_image)))
        in_shape = get_shape(in_info)
        data = load_image(input_image[image_index], get_dtype(in_info), in_shape)
        print("Dump", in_info.name, in_shape)
        data.tofile(os.path.join(output_dir, "{}.tensor".format(in_info.name)), '\n')
        input_feed[in_info.name] = data
        image_index += 1

    # Run the temp model on onnxruntime; passing None requests all outputs.
    ort_model = ort.InferenceSession(temp_model_path)
    outputs = ort_model.run(None, input_feed)

    for out_info, value in zip(graph.output, outputs):
        print("Dump", out_info.name, value.shape)
        shape_tag = '_'.join(str(s) for s in value.shape)
        # Sanitise into a local name; the model proto is left untouched.
        safe_name = out_info.name.replace('/', '_').replace(':', '_').replace('@', '_')
        value.tofile(os.path.join(output_dir, "{}_shape_{}.tensor".format(safe_name, shape_tag)), '\n')

    print("Dump success, please visit '{}' to get dump result.".format(output_dir))
def main():
    """Parse the command line and dump the golden tensors of an ONNX model.

    Reads ``--model``, ``--input-image``, ``--detect-tensor`` and
    ``--output-dir``, creates the output directory if needed and calls
    ``dump``. List-valued options are whitespace separated inside a single
    argument.
    """
    options = ArgumentParser(description='Dump onnx model golden.')
    options.add_argument('--model',
                         required=True,
                         help="ONNX model file.")
    options.add_argument('--input-image',
                         required=True,
                         help="ONNX model input image file. image order follow onnx model input tensors,\n"
                              "separated by space, such as \"a.npy b.npy\"")
    options.add_argument('--detect-tensor',
                         default=None,
                         help="Specify the detect tensor in graph except inputs and outputs, separated by space,\n"
                              "such as \"tensor_x tensor_y\". Default dump all tensors")
    options.add_argument('--output-dir',
                         default='./onnx_dump',
                         help="Specify the output dir of dump result, default by './onnx_dump'")
    args = options.parse_args()

    # split() without a separator ignores repeated/leading/trailing spaces,
    # so no empty file or tensor names are produced.
    input_image = args.input_image.split()
    if args.detect_tensor is not None:
        detect_tensor = args.detect_tensor.split()
    else:
        detect_tensor = []

    output_dir = args.output_dir
    # Creates missing parent directories too and is a no-op when the
    # directory already exists.
    os.makedirs(output_dir, exist_ok=True)

    dump(args.model, input_image, detect_tensor, output_dir)
# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()