netrans/bin/tools/netrans_ruler_generate_assi...

111 lines
5.1 KiB
Python

from argparse import ArgumentParser
import json
from collections import OrderedDict
from netranslib.netrans_smart_graph import smart_graph
def gen_input_map(src_in_anchor, acu_layers):
    """Bind each distinct source input tensor to an input port of the first ACU layer.

    Args:
        src_in_anchor: iterable of flows; element ``[0]`` of each flow is taken
            as the source tensor name.
        acu_layers: sequence of ACU layer aliases; only the first one is used.

    Returns:
        A list of ``[src_tensor, '<first_layer>:in<port>']`` pairs. Ports are
        numbered 0, 1, 2, ... in the order each tensor is first seen;
        repeated tensors are skipped.
    """
    # dict keeps insertion order, so port numbers follow first appearance.
    port_of = {}
    for anchor in src_in_anchor:
        tensor = anchor[0]
        if tensor not in port_of:
            port_of[tensor] = len(port_of)
    return [[tensor, acu_layers[0] + ':in' + str(port)]
            for tensor, port in port_of.items()]
def gen_output_map(dst_out_tensor, acu_layers):
    """Bind each destination output tensor to an output port of the last ACU layer.

    Args:
        dst_out_tensor: iterable of output tensor names, in port order.
        acu_layers: sequence of ACU layer aliases; only the last one is used.

    Returns:
        A list of ``[out_tensor, '<last_layer>:out<index>']`` pairs, where
        ``index`` is the tensor's position in ``dst_out_tensor``.
    """
    return [[tensor, acu_layers[-1] + ':out' + str(index)]
            for index, tensor in enumerate(dst_out_tensor)]
def gen_acu_internal_flow(acu_layers):
    """Chain consecutive ACU layers into a linear pipeline.

    Each layer's ``out0`` is wired to the next layer's ``in0``.

    Args:
        acu_layers: sequence of ACU layer aliases, in pipeline order.

    Returns:
        A list of ``['<layer_i>:out0', '<layer_i+1>:in0']`` edges. It is empty
        when there are fewer than two layers.
    """
    # Pairing each layer with its successor covers the 0- and 1-layer cases
    # naturally (no pairs), so no special case is needed.
    return [[src + ':out0', dst + ':in0']
            for src, dst in zip(acu_layers, acu_layers[1:])]
def _str2bool(value):
    """Convert a command-line string to a bool, for use as an argparse ``type=``.

    Without this converter argparse keeps the raw string. Then
    ``--ruler-db-style True`` yields ``"True"``, which is not ``True``, and
    the wrong output format is silently chosen.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean word.
    """
    from argparse import ArgumentTypeError
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', '1', 'yes', 'y', 'on'):
        return True
    if lowered in ('false', '0', 'no', 'n', 'off'):
        return False
    raise ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def _parse_args():
    """Define and parse the command-line options of the ruler generate assistant."""
    options = ArgumentParser(description='Netranslayer Ruler Generate Assistant tool.')
    options.add_argument('--model',
                         required=True,
                         help='To generate the netransnet, should be .json')
    options.add_argument('--model-data',
                         required=True,
                         help='To generate the netransnet, should be .data')
    options.add_argument('--ruler-name',
                         default=None,
                         help='Provide a name for ruler')
    options.add_argument('--inputs',
                         required=True,
                         help='User give input nodes, include these select nodes, its name should choose from .json '
                              'file. eg. if in the json file the inputs tensor name is "@xxx:out0, @xxx:out1", and '
                              'the params value should be "xxx:out0 xxx:out1", split by space.')
    options.add_argument('--outputs',
                         required=True,
                         help='User give output nodes, include these select nodes, its name should choose from .json '
                              'file. eg. if in the json file the outputs tensor name is "@yyy:out0, @zzz:out0", and '
                              'the params value should be "yyy:out0 zzz:out0", split by space.')
    options.add_argument('--netrans-alias',
                         default='noop',
                         help='Set netrans layer alias in select graph')
    options.add_argument('--ruler-file',
                         default=None,
                         help='For ruler generate provide a predef file for this pb')
    options.add_argument('--ruler-db-style',
                         default=True,
                         type=_str2bool,
                         help='Specify to "False" if generate ruler for *_ruler_generate.py')
    args = options.parse_args()
    # Every ACU-side field indexes into this list; fail early with a clear message.
    if not args.netrans_alias.split():
        options.error('--netrans-alias must name at least one layer')
    return args


def _build_ruler_dict(sg, ruler_name, acu_layers):
    """Assemble the ruler fields, in emission order, from a smart_graph selection.

    The smart_graph methods are called in the same order as before, in case
    the graph object is stateful.
    """
    ruler_dict = OrderedDict()
    ruler_dict['ruler_name'] = ruler_name
    ruler_dict['src_ops_alias'] = sg.gen_op_scan_order(use_alias=True)
    # Build Internal Edge
    ruler_dict['src_inter_flow'] = sg.gen_internal_flow(use_alias=True)
    # TODO: modify it to input tensors
    ruler_dict['src_in_anchor'] = sg.gen_in_flow(use_alias=True)
    ruler_dict['src_out_tensor'] = sg.gen_out_tensor(use_alias=True)
    ruler_dict['acu_lys_alias'] = list(acu_layers)
    ruler_dict['src_acu_in_tensor_map'] = gen_input_map(ruler_dict['src_in_anchor'], acu_layers)
    ruler_dict['src_acu_out_tensor_map'] = gen_output_map(ruler_dict['src_out_tensor'], acu_layers)
    ruler_dict['acu_inter_flow'] = gen_acu_internal_flow(acu_layers)
    ruler_dict['param_map'] = {acu_ly: {} for acu_ly in acu_layers}
    ruler_dict['blob_map'] = {acu_ly: {} for acu_ly in acu_layers}
    ruler_dict['priority_tip'] = 0
    ruler_dict['pre_condition'] = None
    ruler_dict['src_ops_main_version'] = None
    # Key spelling ('minior') is part of the emitted format; presumably matched
    # by ruler consumers, so verify before renaming.
    ruler_dict['src_ops_minior_version'] = [1, -1]
    return ruler_dict


def _render_ruler_body(ruler_dict, python_literals=False):
    """Serialize ``ruler_dict`` as a JSON object with each top-level key on its own line.

    Output looks like ``{\\n"k1": v1, \\n"k2": v2}``. Each item is rendered
    separately, so values that happen to equal a key name are not reformatted.

    Args:
        ruler_dict: ordered mapping of ruler fields.
        python_literals: when True, top-level ``None`` values are written as
            Python ``None`` instead of JSON ``null``, so the text can be pasted
            into a Python ruler file.
    """
    items = []
    for key, value in ruler_dict.items():
        rendered = 'None' if (python_literals and value is None) else json.dumps(value)
        items.append('\n{}: {}'.format(json.dumps(key), rendered))
    return '{' + ', '.join(items) + '}'


def main():
    """Generate a ruler description for a selected subgraph and write it to disk.

    The JSON-style ruler plus an ``#alias:real;...`` comment line is printed
    to stdout. The file written to ``--ruler-file`` (default
    ``./gen_ruler_<ruler-name>.json``) contains either:

    * db style (default): the ruler wrapped in ``[...]``, without the comment; or
    * ``--ruler-db-style False``: the ruler with Python ``None`` literals plus the
      comment line, for use in ``*_ruler_generate.py``.
    """
    args = _parse_args()
    if args.ruler_file is None:
        args.ruler_file = './gen_ruler_{}.json'.format(args.ruler_name)
    acu_layers = args.netrans_alias.split()

    sg = smart_graph(args.inputs.split(), args.outputs.split(), args.model, args.model_data)
    ruler_dict = _build_ruler_dict(sg, args.ruler_name, acu_layers)

    # Trailing comment mapping each op alias to its real op name, in scan order.
    alias_real_pairs = zip(ruler_dict['src_ops_alias'], sg.gen_op_scan_order(use_alias=False))
    comment_line = '\n#' + ';'.join('{}:{}'.format(alias, real) for alias, real in alias_real_pairs)

    json_body = _render_ruler_body(ruler_dict)
    string_result = json_body + comment_line
    if args.ruler_db_style:
        # only for transform_graph API
        string_json = '[' + json_body + '\n]'
    else:
        string_json = _render_ruler_body(ruler_dict, python_literals=True) + comment_line

    with open(args.ruler_file, 'w', encoding='utf-8') as f:
        f.write(string_json)
    print(string_result)


if __name__ == '__main__':
    main()