netrans/netrans_py/netrans.py

143 lines
4.7 KiB
Python

import sys, os
import subprocess
# import yaml
from ruamel.yaml import YAML
from ruamel import yaml
import file_model
from import_model import ImportModel
from quantize import Quantize
from export import Export
from config import Config
# from utils import check_path
import warnings
# Ignore ruamel.yaml's UnsafeLoaderWarning for the whole process
# (warnings.simplefilter installs a global filter, not a module-local one).
warnings.simplefilter('ignore', yaml.error.UnsafeLoaderWarning)
class Netrans():
    """Drive the netrans (``pnnacc``) toolchain for a single model directory.

    The usual flow is :meth:`model2nbg`: import the model, prepare the
    input-meta YAML, quantize, then export.
    """

    def __init__(self, model_path, netrans=None, verbose=False):
        """Bind the instance to a model directory and locate the toolkit.

        Args:
            model_path: Model directory; stored as an absolute path. Its last
                path component becomes ``self.model_name``.
            netrans: Toolkit directory containing ``pnnacc``. When None the
                ``NETRANS_PATH`` environment variable is used.
            verbose: Stored as ``self.verbose``.
        """
        self.verbose = verbose
        self.model_path = os.path.abspath(model_path)
        self.set_netrans(netrans)
        _, self.model_name = os.path.split(self.model_path)

    # ------------------------------------------------------------------
    # pipeline
    # ------------------------------------------------------------------
    def model2nbg(self, quantize_type, inputmeta=False, **kargs):
        """Run the whole pipeline: import, configure, quantize and export.

        Args:
            quantize_type: Forwarded to :meth:`quantize`.
            inputmeta: Forwarded to :meth:`config`.
            **kargs: Preprocessing overrides forwarded to :meth:`config`.
        """
        self.load_model()
        self.config(inputmeta, **kargs)
        self.quantize(quantize_type)
        self.export()

    # ------------------------------------------------------------------
    # toolkit location
    # ------------------------------------------------------------------
    def get_os_netrans_path(self):
        """Return the ``NETRANS_PATH`` environment variable, or None if unset."""
        return os.environ.get('NETRANS_PATH')

    def check_netarans(self):
        """Verify that the ``pnnacc`` executable can be run; exit otherwise.

        Runs ``self.netrans`` without arguments. A missing toolkit path, a
        binary that cannot be launched, or a non-zero exit status counts as
        failure: a hint is printed and the process exits with status 1.
        """
        netrans = getattr(self, 'netrans', None)
        ok = False
        if netrans is not None:
            try:
                ok = subprocess.run([netrans], text=True).returncode == 0
            except OSError:
                # e.g. FileNotFoundError / PermissionError for a bad binary.
                ok = False
        if not ok:
            print("pleace check the netrans")
            # A bare sys.exit() would report success (status 0).
            sys.exit(1)

    def set_netrans(self, netrans_path=None):
        """Locate the toolkit directory and the ``pnnacc`` executable in it.

        Args:
            netrans_path: Toolkit directory. When None, falls back to the
                ``NETRANS_PATH`` environment variable.

        On success sets ``self.netrans`` (path to ``pnnacc``) and
        ``self.netrans_path``. Otherwise prints a warning and leaves any
        previously configured values in place (None if there were none).
        """
        if netrans_path is not None:
            netrans_path = os.path.abspath(netrans_path)
        else:
            netrans_path = self.get_os_netrans_path()
        # The env var may be unset (None); os.path.exists(None) raises TypeError.
        if netrans_path and os.path.exists(netrans_path):
            self.netrans = os.path.join(netrans_path, 'pnnacc')
            self.netrans_path = netrans_path
        else:
            print('NETRANS_PATH NOT BEEN SETTED')
            # Make sure the attributes exist so later checks fail cleanly
            # instead of raising AttributeError.
            self.netrans = getattr(self, 'netrans', None)
            self.netrans_path = getattr(self, 'netrans_path', None)

    # ------------------------------------------------------------------
    # input-meta configuration
    # ------------------------------------------------------------------
    def config(self, inputmeta=False, **kargs):
        """Select (or generate) the input-meta YAML and apply overrides.

        Args:
            inputmeta: A path (str) to an existing input-meta file, or a bool.
                For a bool the file is ``<model_path>/<model_name><ext>``
                (``ext`` from ``file_model.extensions.input_meta``). ``False``
                regenerates it via :meth:`inputmeta_gen`; ``True`` reuses it.
            **kargs: Optional ``mean``, ``scale`` and ``reverse_channel``
                overrides (plus ``channel``, forwarded to :meth:`upload_cfg`).

        Exits the process when ``inputmeta`` is neither a str nor a bool.
        """
        if isinstance(inputmeta, str):
            self.input_meta = inputmeta
        elif isinstance(inputmeta, bool):
            self.input_meta = os.path.join(
                self.model_path,
                '%s%s' % (self.model_name, file_model.extensions.input_meta))
            if inputmeta is False:
                self.inputmeta_gen()
        else:
            sys.exit("check inputmeta file")

        mean = kargs.get('mean')
        scale = kargs.get('scale')
        reverse_channel = kargs.get('reverse_channel')
        if mean is None and scale is None and reverse_channel is None:
            return
        # mean=0 / scale=1 is the identity normalization: skip the rewrite,
        # but only when there is no reverse_channel override to apply.
        if mean == 0 and scale == 1 and reverse_channel is None:
            return

        yaml_rt = YAML()
        with open(self.input_meta, 'r') as f:
            data = yaml_rt.load(f)
        data = self.upload_cfg(data, **kargs)
        with open(self.input_meta, 'w') as f:
            yaml_rt.dump(data, f)

    def upload_cfg(self, data, channel=3, **kargs):
        """Apply mean / scale / reverse_channel overrides to input-meta data.

        Args:
            data: Parsed input-meta mapping; values are written under
                ``data['input_meta']['databases'][i]['ports'][0]['preprocess']``.
            channel: Currently unused; kept for backward compatibility.
            **kargs: ``mean`` and ``scale`` (scalar or per-channel list) and
                ``reverse_channel`` (bool). Missing or None values are skipped;
                a non-bool ``reverse_channel`` is ignored.

        Returns:
            ``data``, modified in place.
        """
        grey = self._is_grey(data)
        if kargs.get('mean') is not None:
            self.upload_cfg_mean(data, handel_param(kargs['mean'], grey))
        if kargs.get('scale') is not None:
            self.upload_cfg_scale(data, handel_param(kargs['scale'], grey))
        if isinstance(kargs.get('reverse_channel'), bool):
            self.upload_cfg_reverse_channel(data, kargs['reverse_channel'])
        return data

    @staticmethod
    def _is_grey(data):
        """Return True if the first database's first port is a grayscale input."""
        databases = data['input_meta']['databases']
        if not databases:
            return False
        preprocess = databases[0]['ports'][0]['preprocess']
        node_params = preprocess.get('preproc_node_params')
        # NOTE(review): preproc_node_params appears to be a mapping whose
        # 'preproc_type' holds values such as IMAGE_GRAY -- confirm against
        # generated inputmeta files. A bare string value is also accepted.
        if isinstance(node_params, dict):
            node_params = node_params.get('preproc_type')
        return node_params == 'IMAGE_GRAY'

    def _set_preprocess(self, data, key, value):
        """Write ``value`` to ``preprocess[key]`` of every database's first port."""
        for db in data['input_meta']['databases']:
            # Give every port its own list so the YAML dumper does not
            # serialise one shared object as an anchor/alias pair.
            db['ports'][0]['preprocess'][key] = (
                list(value) if isinstance(value, list) else value)

    def upload_cfg_mean(self, data, mean):
        """Set ``preprocess.mean`` on the first port of every database."""
        self._set_preprocess(data, 'mean', mean)

    def upload_cfg_scale(self, data, scale):
        """Set ``preprocess.scale`` on the first port of every database."""
        self._set_preprocess(data, 'scale', scale)

    def upload_cfg_reverse_channel(self, data, reverse_channel):
        """Set ``preprocess.reverse_channel`` on the first port of every database."""
        self._set_preprocess(data, 'reverse_channel', reverse_channel)

    # ------------------------------------------------------------------
    # pipeline stages
    # ------------------------------------------------------------------
    def load_model(self):
        """Import the model with ``ImportModel(self).import_network()``."""
        func = ImportModel(self)
        func.import_network()

    def inputmeta_gen(self):
        """Generate the input-meta file with ``Config(self).inputmeta_gen()``."""
        func = Config(self)
        func.inputmeta_gen()

    def quantize(self, quantize_type):
        """Store ``quantize_type`` and run ``Quantize(self).quantize_network()``."""
        self.quantize_type = quantize_type
        func = Quantize(self)
        func.quantize_network()

    def export(self, **kargs):
        """Run ``Export(self).export_network()``.

        A truthy ``quantize_type`` keyword overrides the value stored by
        :meth:`quantize`.
        """
        if kargs.get('quantize_type'):
            self.quantize_type = kargs['quantize_type']
        func = Export(self)
        func.export_network()
def handel_param(param, grey=False, channels=3):
    """Normalize a preprocessing value (mean or scale) to per-channel form.

    Args:
        param: A scalar, or a list/tuple holding one value per channel.
        grey: When True, ``param`` is returned unchanged (no broadcasting).
        channels: How many times a scalar is repeated for non-grey input.
            Defaults to 3.

    Returns:
        ``param`` itself when ``grey`` is true. Otherwise a list: a copy of
        ``param`` if it is a list/tuple, else ``param`` repeated
        ``channels`` times.
    """
    if grey:
        return param
    if isinstance(param, (list, tuple)):
        # Tuples were previously wrapped as [tuple] * 3; treat them as
        # per-channel values like lists.
        return list(param)
    return [param] * channels
if __name__ == '__main__':
    # Manual smoke test: regenerate the input-meta file for the bundled
    # yolov4_tiny model. The model path is resolved relative to this script
    # (not the current working directory) so it works from any location.
    here = os.path.dirname(os.path.abspath(__file__))
    network = os.path.join(here, '..', '..', 'model_zoo', 'yolov4_tiny')
    yolo = Netrans(network)
    yolo.inputmeta_gen()
    # Whole pipeline in one call:
    #     yolo.model2nbg("uint8")
    # or step by step:
    #     yolo.load_model()
    #     yolo.config(mean=[0, 0, 0], scale=1)
    #     yolo.quantize('uint8')
    #     yolo.export()