# netrans/bin/tools/pytorch_golden_dump.py
# (132 lines, 5.2 KiB, Python)

import torch
import tensorflow as tf
import numpy as np
import os
import sys
from argparse import ArgumentParser
def model_optimize(model):
    """Apply the TorchScript graph passes to ``model.graph`` in place.

    Two passes run, in this order:

    1. ``torch._C._jit_pass_remove_inplace_ops``
    2. ``torch._C._jit_pass_inline``

    Args:
        model: Object exposing a TorchScript graph through its ``graph``
            attribute (the caller passes the result of ``torch.jit.load``).

    Returns:
        None. The graph is modified in place.
    """
    # ``model.graph`` is looked up once per pass, as each pass is invoked.
    torch._C._jit_pass_remove_inplace_ops(model.graph)
    torch._C._jit_pass_inline(model.graph)
def load_image(image, shape):
    """Load one model input from disk as a float32 NumPy array.

    The loader is selected from the file extension (case-insensitive):

    * ``.npy``    -- read with ``np.load``; the stored shape is kept and
      ``shape`` is not used.
    * ``.tensor`` -- text file of numbers read with ``np.loadtxt`` and
      reshaped to ``shape``.
    * ``.jpg`` / ``.jpeg`` / ``.png`` / ``.bmp`` / ``.gif`` -- decoded with
      TensorFlow, resized to ``(shape[2], shape[3])`` when the decoded size
      differs, and returned in NCHW layout with a batch dimension of 1.

    Args:
        image: Path of the input file.
        shape: Expected input shape. For image files it is read as NCHW
            (``shape[2]`` is the height, ``shape[3]`` the width).

    Returns:
        ``numpy.ndarray`` with dtype ``float32``.

    Raises:
        NotImplementedError: If the file extension is not supported.
    """
    extension = os.path.splitext(image)[1].lower()
    if extension == '.npy':
        data = np.load(image).astype(np.float32)
    elif extension == '.tensor':
        data = np.loadtxt(image).reshape(shape).astype(np.float32)
    elif extension in ('.jpg', '.jpeg', '.png', '.bmp', '.gif'):
        height, width = shape[2], shape[3]  # shape is under NCHW layout
        # By default decode_image returns a 4-D (frames, H, W, C) tensor for
        # every GIF, which would break the HWC handling below.
        # expand_animations=False always yields a 3-D HWC tensor (first frame
        # only for animated GIFs).
        img_tensor = tf.io.decode_image(tf.io.read_file(image),
                                        expand_animations=False)
        img = img_tensor.numpy()
        img_height, img_width = img.shape[0], img.shape[1]  # HWC layout
        if img_height != height or img_width != width:
            img = tf.image.resize(img_tensor, (height, width)).numpy()
        batched = np.expand_dims(img, axis=0).astype(np.float32)  # HWC -> NHWC
        data = np.transpose(batched, [0, 3, 1, 2])  # NHWC -> NCHW
    else:
        raise NotImplementedError("DO NOT support input *{}".format(os.path.splitext(image)[-1]))
    return data
def dump(model_path, input_image, input_shapes, detect_tensor, output_dir):
    """Run a TorchScript model and write selected graph values to text files.

    The graph of the loaded model is extended with extra outputs, the modified
    model is saved as ``temp.pt`` inside ``output_dir``, the model is executed
    on the given inputs, and every tensor output is written to
    ``<name>_shape_<d0>_<d1>_....tensor`` (one value per line).

    Args:
        model_path: Path of a model loadable with ``torch.jit.load``.
        input_image: List of input file paths, in the order of the model's
            ``forward`` arguments.
        input_shapes: List of shapes, one per entry of ``input_image``.
        detect_tensor: Debug names of the graph values to dump. When empty,
            the outputs of every top-level node are dumped, except for
            ``prim::Constant``, ``prim::GetAttr`` and ``prim::ListConstruct``
            nodes. The list is not modified.
        output_dir: Existing directory that receives ``temp.pt`` and the
            dumped files.

    Raises:
        ValueError: If a requested tensor name is not in the graph, or if the
            number of input files / shapes does not match the model inputs.
    """
    model = torch.jit.load(model_path)
    model_optimize(model)
    graph = model.graph

    # Insert temporary outputs into the graph.
    if not detect_tensor:
        # Detect every output of every node.
        skipped_kinds = ('prim::Constant', 'prim::GetAttr', 'prim::ListConstruct')
        for node in graph.nodes():
            if node.kind() in skipped_kinds:
                continue
            for value in node.outputs():
                graph.registerOutput(value)
    else:
        # Track the names in a local set so the caller's list stays untouched.
        pending = set(detect_tensor)
        for node in graph.nodes():
            for value in node.outputs():
                name = value.debugName()
                if name in pending:
                    graph.registerOutput(value)
                    pending.discard(name)
            if not pending:
                break
        if pending:
            missing = [name for name in detect_tensor if name in pending]
            raise ValueError(
                '{} not in the graph, please have a check!'.format(','.join(missing)))

    output_names = [out.debugName() for out in graph.outputs()]
    graph.makeMultiOutputIntoTuple()
    model.save(os.path.join(output_dir, 'temp.pt'))

    # Load input data.
    # The first schema argument is self.
    input_names = [arg.name for arg in model.forward.schema.arguments[1:]]
    if len(input_names) != len(input_image):
        raise ValueError("input images are not provided enough")
    if len(input_shapes) < len(input_image):
        raise ValueError("input shapes are not provided enough")
    fwd_args = {
        name: torch.Tensor(load_image(path, shape))
        for name, path, shape in zip(input_names, input_image, input_shapes)
    }

    # Run model.
    output_tensors = model(**fwd_args)
    if isinstance(output_tensors, torch.Tensor):
        # A graph with a single output returns the bare tensor, not a tuple;
        # iterating it would walk its first dimension instead.
        output_tensors = (output_tensors,)

    for out_name, out_value in zip(output_names, output_tensors):
        if not isinstance(out_value, torch.Tensor):
            # Registered graph values may be tuples, lists or scalars, which
            # cannot be written as a tensor file.
            print("Skip", out_name, "(not a tensor)")
            continue
        array = out_value.detach().cpu().numpy()
        print("Dump", out_name, array.shape)
        shape_str = '_'.join(str(s) for s in array.shape)
        file_name = "{}_shape_{}.tensor".format(out_name.replace('/', '_'), shape_str)
        array.tofile(os.path.join(output_dir, file_name), '\n')
    print("Dump success, please visit '{}' to get dump result.".format(output_dir))
def main():
    """Parse the command line and dump golden tensors of a PyTorch model.

    Reads ``--model``, ``--input-image``, ``--input-size-list``,
    ``--detect-tensor`` and ``--output-dir``, creates the output directory if
    needed, and forwards everything to ``dump``.

    Exits through ``ArgumentParser.error`` (status 2) when fewer shapes than
    input files are given.
    """
    options = ArgumentParser(description='Dump pytorch model golden.')
    options.add_argument('--model',
                         required=True,
                         help="Pytorch model file.")
    options.add_argument('--input-image',
                         required=True,
                         help="Pytorch model input image file. image order follow pytorch model input tensors,\n"
                              "separated by space, such as \"a.npy b.npy\"")
    options.add_argument('--input-size-list',
                         required=True,
                         help="Pytorch model input shapes, seperated by '#'")
    options.add_argument('--detect-tensor',
                         default=None,
                         help="Specify the detect tensor in graph except inputs and outputs, separated by space,\n"
                              "such as \"tensor_x tensor_y\". Default dump all tensors")
    options.add_argument('--output-dir',
                         default='./pytorch_dump',
                         help="Specify the output dir of dump result, default by './pytorch_dump'")
    args = options.parse_args()

    # split() with no argument ignores repeated, leading and trailing blanks,
    # so "a.npy  b.npy" does not produce an empty file name.
    input_image = args.input_image.split()
    # Shapes are separated by '#', dimensions inside a shape by ','.
    input_shapes = [
        [int(dim) for dim in shape.split(',')]
        for shape in args.input_size_list.split('#')
        if shape.strip()
    ]
    # Fail before touching the file system when a shape is missing.
    if len(input_shapes) < len(input_image):
        options.error("--input-size-list provides {} shape(s) for {} input image(s)".format(
            len(input_shapes), len(input_image)))

    if args.detect_tensor is not None:
        detect_tensor = args.detect_tensor.split()
    else:
        detect_tensor = []

    # makedirs also creates missing parents and tolerates an existing dir.
    os.makedirs(args.output_dir, exist_ok=True)
    dump(args.model, input_image, input_shapes, detect_tensor, args.output_dir)


if __name__ == '__main__':
    main()