# Python source - 132 lines, 5.2 KiB.
import torch
|
|
import tensorflow as tf
|
|
import numpy as np
|
|
import os
|
|
import sys
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
def model_optimize(model):
    """Rewrite the ``forward`` graph of a TorchScript module in place.

    Two TorchScript JIT passes are applied to ``model.graph``, in this order:

    1. ``torch._C._jit_pass_inline`` -- inline calls into submodule methods and
       functions, so their nodes (and the intermediate values they produce)
       appear directly in the top-level graph.
    2. ``torch._C._jit_pass_remove_inplace_ops`` -- replace the in-place ops it
       recognises with out-of-place equivalents (presumably so that a value
       registered later as an extra output is not overwritten by a following
       in-place op).

    Inlining must come first: the in-place pass only rewrites nodes already
    present in the graph it is given, so running it before inlining would
    leave in-place ops originating from submodules untouched.

    Args:
        model: a ``torch.jit.ScriptModule`` (e.g. the result of
            ``torch.jit.load``).

    Returns:
        None. ``model.graph`` is modified in place.
    """
    graph = model.graph
    for graph_pass in (torch._C._jit_pass_inline,
                       torch._C._jit_pass_remove_inplace_ops):
        graph_pass(graph)
|
|
|
|
|
|
def load_image(image, shape):
    """Load one model input from a file and return it as a float32 array.

    The loader is chosen by the file extension (case-insensitive):

    * ``.npy``    -- read with ``np.load``; the stored shape is kept as-is
      (``shape`` is not applied).
    * ``.tensor`` -- whitespace-separated text read with ``np.loadtxt`` and
      reshaped to ``shape``.
    * ``.jpg``/``.jpeg``/``.png``/``.bmp``/``.gif`` -- decoded with
      TensorFlow, resized to ``shape[2] x shape[3]`` when the size differs,
      and returned as a ``1 x C x H x W`` (NCHW) array. For animated GIFs
      only the first frame is used.

    Args:
        image: path of the input file.
        shape: target input shape; for image files it is read as NCHW.

    Returns:
        np.ndarray with dtype float32.

    Raises:
        ValueError: an image file is given but ``shape`` is not 4-D.
        NotImplementedError: the file extension is not supported.
    """
    extension = os.path.splitext(image)[1].lower()

    if extension == '.npy':
        return np.load(image).astype(np.float32)

    if extension == '.tensor':
        return np.loadtxt(image).reshape(shape).astype(np.float32)

    if extension in ('.jpg', '.jpeg', '.png', '.bmp', '.gif'):
        if len(shape) != 4:
            raise ValueError(
                "image input {} needs a 4-D NCHW shape, got {}".format(image, shape))
        height, width = shape[2], shape[3]  # shape is under NCHW layout

        # expand_animations=False makes a GIF decode to one 3-D HWC frame
        # (the first) instead of a 4-D stack of frames, so every supported
        # format yields an HWC tensor here.
        img_tensor = tf.io.decode_image(tf.io.read_file(image),
                                        expand_animations=False)
        img = img_tensor.numpy()
        if img.shape[0] != height or img.shape[1] != width:
            img = tf.image.resize(img_tensor, (height, width)).numpy()

        # HWC -> NHWC -> NCHW; make the transposed view a contiguous float32 copy.
        nchw = np.transpose(img[np.newaxis, ...], (0, 3, 1, 2))
        return np.ascontiguousarray(nchw, dtype=np.float32)

    raise NotImplementedError("DO NOT support input *{}".format(os.path.splitext(image)[-1]))
|
|
|
|
|
|
def _register_dump_outputs(graph, detect_tensor):
    """Append the values to be dumped to the outputs of ``graph``.

    The graph's existing outputs stay first. Which values are added depends
    on ``detect_tensor``:

    * empty or ``None``: every tensor-typed output of every node is added,
      except outputs of ``prim::Constant``, ``prim::GetAttr`` and
      ``prim::ListConstruct`` nodes;
    * otherwise: only node outputs whose ``debugName()`` is listed.

    ``detect_tensor`` itself is not modified.

    Raises:
        AssertionError: a requested name is not the output of any node.
    """
    # Debug names the graph already returns; registering them again would only
    # make the same tensor be dumped twice.
    registered = {out.debugName() for out in graph.outputs()}

    def _add(value):
        if value.debugName() not in registered:
            graph.registerOutput(value)
            registered.add(value.debugName())

    if not detect_tensor:
        skip_kinds = ('prim::Constant', 'prim::GetAttr', 'prim::ListConstruct')
        for node in graph.nodes():
            if node.kind() in skip_kinds:
                continue
            for value in node.outputs():
                # Only tensors can be converted and written by the dump step;
                # other values (e.g. ints, lists, tuples) would make it fail.
                if isinstance(value.type(), torch._C.TensorType):
                    _add(value)
        return

    # De-duplicated working copy, so the caller's list is left untouched.
    pending = list(dict.fromkeys(detect_tensor))
    for node in graph.nodes():
        for value in node.outputs():
            if value.debugName() in pending:
                _add(value)
                pending.remove(value.debugName())
        if not pending:
            break
    assert len(pending) == 0, '{} not in the graph, please have a check!'.format(','.join(pending))


def _save_dump_outputs(output_names, output_values, output_dir):
    """Write each tensor result to ``<name>_shape_<d0_d1_...>.tensor`` in ``output_dir``.

    Files are newline-separated text written with ``ndarray.tofile``.
    Results that are not tensors are reported and skipped.
    """
    for out_name, value in zip(output_names, output_values):
        if not isinstance(value, torch.Tensor):
            # e.g. the model's own return value when that is a tuple
            print("Skip", out_name, "(not a tensor)")
            continue
        array = value.detach().numpy()
        print("Dump", out_name, array.shape)
        shape_tag = '_'.join(str(dim) for dim in array.shape)
        file_name = "{}_shape_{}.tensor".format(out_name.replace('/', '_'), shape_tag)
        array.tofile(os.path.join(output_dir, file_name), '\n')


def dump(model_path, input_image, input_shapes, detect_tensor, output_dir):
    """Run a TorchScript model and dump its intermediate tensors to text files.

    Args:
        model_path: file passed to ``torch.jit.load``.
        input_image: one input file per ``forward`` argument (``self``
            excluded), in schema order; each is read with ``load_image``.
        input_shapes: one shape per input file, passed to ``load_image``.
        detect_tensor: debug names of the values to dump; empty or ``None``
            dumps every tensor produced by a graph node.
        output_dir: existing directory (it is not created here) that receives
            ``temp.pt`` (the model with the extra outputs) and the
            ``.tensor`` dump files.

    Raises:
        AssertionError: a requested tensor is missing from the graph, the
            number of input files does not match the model inputs, or fewer
            shapes than input files are given.
    """
    model = torch.jit.load(model_path)
    model_optimize(model)

    # insert temp output tensors into the graph
    _register_dump_outputs(model.graph, detect_tensor)
    output_names = [out.debugName() for out in model.graph.outputs()]
    # pack all (original + registered) outputs into a single returned tuple
    model.graph.makeMultiOutputIntoTuple()
    model.save(os.path.join(output_dir, 'temp.pt'))

    # load input data; the first schema argument is ``self``
    input_names = [arg.name for arg in model.forward.schema.arguments[1:]]
    assert len(input_names) == len(input_image), "input images are not provided enough"
    assert len(input_shapes) >= len(input_image), "input shapes are not provided enough"
    fwd_args = {
        name: torch.Tensor(load_image(path, shape))
        for name, path, shape in zip(input_names, input_image, input_shapes)
    }

    # run model
    output_values = model(**fwd_args)
    _save_dump_outputs(output_names, output_values, output_dir)

    print("Dump success, please visit '{}' to get dump result.".format(output_dir))
|
|
|
|
def main():
    """Command-line entry point: parse the arguments and call ``dump``.

    ``--input-size-list`` holds one comma-separated shape per input, with
    shapes separated by '#', e.g. ``1,3,224,224#1,10``.
    """
    options = ArgumentParser(description='Dump pytorch model golden.')
    options.add_argument('--model',
                         required=True,
                         help="Pytorch model file.")
    options.add_argument('--input-image',
                         required=True,
                         help="Pytorch model input image file. image order follow pytorch model input tensors,\n"
                              "separated by space, such as \"a.npy b.npy\"")
    options.add_argument('--input-size-list',
                         required=True,
                         help="Pytorch model input shapes, seperated by '#'")
    options.add_argument('--detect-tensor',
                         default=None,
                         help="Specify the detect tensor in graph except inputs and outputs, separated by space,\n"
                              "such as \"tensor_x tensor_y\". Default dump all tensors")
    options.add_argument('--output-dir',
                         default='./pytorch_dump',
                         help="Specify the output dir of dump result, default by './pytorch_dump'")
    args = options.parse_args()

    # split() with no argument ignores repeated, leading and trailing spaces,
    # which split(' ') would turn into empty file / tensor names.
    input_image = args.input_image.split()
    input_shapes = [
        [int(dim) for dim in shape.split(',')]
        for shape in args.input_size_list.split('#')
        if shape.strip()  # tolerate an empty segment, e.g. a trailing '#'
    ]
    # An empty or whitespace-only value yields [], i.e. dump all tensors.
    detect_tensor = args.detect_tensor.split() if args.detect_tensor is not None else []

    output_dir = args.output_dir
    # makedirs also creates missing parents; exist_ok keeps reruns working.
    os.makedirs(output_dir, exist_ok=True)

    dump(args.model, input_image, input_shapes, detect_tensor, output_dir)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()
|