138 lines
4.6 KiB
Python
138 lines
4.6 KiB
Python
import sys, os
|
|
sys.path.append("/path/to/pnna_sdk/netrans/netrans_py")
|
|
import subprocess
|
|
from ruamel.yaml import YAML
|
|
import file_model
|
|
from import_model import ImportModel
|
|
from quantize import Quantize
|
|
from export import Export
|
|
from inputmeta_gen import InputmetaGen
|
|
# from utils import check_path
|
|
import warnings

import ruamel.yaml

# Silence ruamel.yaml's "unsafe loader" warning for the whole process.
# The original statement referenced a bare ``yaml`` name that is never bound
# in this module (only ``YAML`` is imported from ``ruamel.yaml``), so it
# raised NameError at import time.  The class is looked up defensively so a
# ruamel.yaml release without it does not break the import of this module.
_unsafe_loader_warning = getattr(ruamel.yaml.error, 'UnsafeLoaderWarning', None)
if _unsafe_loader_warning is not None:
    warnings.simplefilter('ignore', _unsafe_loader_warning)
|
|
class Netrans:
    """Front end for the netrans ``pnnacc`` tool chain for one model directory.

    An instance remembers the absolute model directory, the model name (the
    last path component of that directory) and the location of the ``pnnacc``
    executable.  The actual work is delegated to the ``ImportModel``,
    ``InputmetaGen``, ``Quantize`` and ``Export`` helpers, each of which
    receives this instance.

    Attributes set by the methods below:
        verbose:        flag passed to ``__init__``.
        model_path:     absolute path of the model directory.
        model_name:     last component of ``model_path``.
        netrans_path:   directory that contains ``pnnacc`` (only set when the
                        directory exists).
        netrans:        ``<netrans_path>/pnnacc`` (only set when the directory
                        exists).
        input_meta:     path of the input-meta YAML file (set by ``config``).
        quantize_type:  set by ``quantize`` / ``export``.
    """

    def __init__(self, model_path, netrans=None, verbose=False):
        """Record the model location and resolve the netrans directory.

        Args:
            model_path: path of the model directory (made absolute).
            netrans: directory containing ``pnnacc``; when ``None`` the
                ``NETRANS_PATH`` environment variable is used instead.
            verbose: stored on the instance as ``self.verbose``.
        """
        self.verbose = verbose
        self.model_path = os.path.abspath(model_path)
        self.set_netrans(netrans)
        _, self.model_name = os.path.split(self.model_path)

    # ------------------------------------------------------------------
    # pipeline
    # ------------------------------------------------------------------
    def model2nb(self, quantize_type, inputmeta=False, **kargs):
        """Run the whole pipeline: import, configure, quantize, export.

        Args:
            quantize_type: forwarded to ``quantize``.
            inputmeta: forwarded to ``config``.
            **kargs: forwarded to ``config`` (``mean``, ``scale``,
                ``reverse_channel``).
        """
        self.load_model()
        self.config(inputmeta, **kargs)
        self.quantize(quantize_type)
        self.export()

    # ------------------------------------------------------------------
    # netrans location
    # ------------------------------------------------------------------
    def get_os_netrans_path(self):
        """Return the ``NETRANS_PATH`` environment variable, or ``None``."""
        return os.environ.get('NETRANS_PATH')

    def check_netarans(self):
        """Run the ``pnnacc`` executable once and exit if it is unusable.

        The process is terminated via ``sys.exit`` when the executable has
        not been resolved, cannot be started, or returns a non-zero status.
        The exit status is 1 so that the failure is visible to the calling
        shell (a bare ``sys.exit()`` would report success).
        """
        executable = getattr(self, 'netrans', None)
        usable = False
        if executable:
            try:
                usable = subprocess.run([executable], text=True).returncode == 0
            except OSError:
                # Missing or non-executable file: treat like a failed run
                # instead of letting the traceback escape.
                usable = False
        if not usable:
            print("pleace check the netrans")
            sys.exit(1)

    def set_netrans(self, netrans_path=None):
        """Resolve the netrans directory and the ``pnnacc`` path inside it.

        Args:
            netrans_path: directory containing ``pnnacc``.  When ``None`` the
                ``NETRANS_PATH`` environment variable is consulted.

        When no existing directory can be determined a message is printed
        and ``self.netrans`` / ``self.netrans_path`` are left untouched.
        """
        if netrans_path is not None:
            netrans_path = os.path.abspath(netrans_path)
        else:
            netrans_path = self.get_os_netrans_path()

        # ``os.path.exists(None)`` raises TypeError, so an unset environment
        # variable has to be caught before the existence check.
        if netrans_path is not None and os.path.exists(netrans_path):
            self.netrans = os.path.join(netrans_path, 'pnnacc')
            self.netrans_path = netrans_path
        else:
            print('NETRANS_PATH NOT BEEN SETTED')

    # ------------------------------------------------------------------
    # input-meta configuration
    # ------------------------------------------------------------------
    def config(self, inputmeta=False, **kargs):
        """Select (or generate) the input-meta file and optionally patch it.

        Args:
            inputmeta: a ``str`` is used directly as the input-meta path.
                A ``bool`` selects the default path inside the model
                directory; ``False`` additionally generates the file via
                ``inputmeta_gen``.  Any other type terminates the process.
            **kargs: optional ``mean``, ``scale`` and ``reverse_channel``
                values.  When at least one of them is given (not ``None``)
                the YAML file is loaded, updated through ``upload_cfg`` and
                written back.
        """
        if isinstance(inputmeta, str):
            self.input_meta = inputmeta
        elif isinstance(inputmeta, bool):
            self.input_meta = os.path.join(
                self.model_path,
                '%s%s' % (self.model_name, file_model.extensions.input_meta),
            )
            if inputmeta is False:
                self.inputmeta_gen()
        else:
            sys.exit("check inputmeta file")

        # Use ``.get`` so that passing only some of the options (for example
        # ``config(scale=1)``) no longer raises KeyError.
        options = ('mean', 'scale', 'reverse_channel')
        if all(kargs.get(key) is None for key in options):
            return

        yaml = YAML()
        with open(self.input_meta, 'r') as f:
            data = yaml.load(f)
        data = self.upload_cfg(data, **kargs)
        with open(self.input_meta, 'w') as f:
            # No extra keyword arguments: an unexpected keyword here would
            # fail only after the file had already been truncated.
            yaml.dump(data, f)

    def upload_cfg(self, data, channel=3, **kargs):
        """Apply the given preprocessing options to ``data`` and return it.

        Args:
            data: loaded input-meta mapping (modified in place).
            channel: accepted for interface compatibility; not used here.
            **kargs: ``mean`` and ``scale`` are normalised with
                ``handel_param``; ``reverse_channel`` is applied only when
                it is a ``bool``.

        Returns:
            The same ``data`` object.
        """
        if kargs.get('mean') is not None:
            self.upload_cfg_mean(data, handel_param(kargs['mean']))
        if kargs.get('scale') is not None:
            self.upload_cfg_scale(data, handel_param(kargs['scale']))
        if isinstance(kargs.get('reverse_channel'), bool):
            self.upload_cfg_reverse_channel(data, kargs['reverse_channel'])
        return data

    @staticmethod
    def _set_preprocess_field(data, key, value):
        """Set ``preprocess[key]`` on the first port of every database."""
        for db in data['input_meta']['databases']:
            db['ports'][0]['preprocess'][key] = value

    def upload_cfg_mean(self, data, mean):
        """Write ``mean`` into the first port's preprocess section."""
        self._set_preprocess_field(data, 'mean', mean)

    def upload_cfg_scale(self, data, scale):
        """Write ``scale`` into the first port's preprocess section."""
        self._set_preprocess_field(data, 'scale', scale)

    def upload_cfg_reverse_channel(self, data, reverse_channel):
        """Write ``reverse_channel`` into the first port's preprocess section."""
        self._set_preprocess_field(data, 'reverse_channel', reverse_channel)

    # ------------------------------------------------------------------
    # individual pipeline steps
    # ------------------------------------------------------------------
    def load_model(self):
        """Import the network through ``ImportModel``."""
        ImportModel(self).import_network()

    def inputmeta_gen(self):
        """Generate the input-meta file through ``InputmetaGen``."""
        InputmetaGen(self).inputmeta_gen()

    def quantize(self, quantize_type):
        """Remember ``quantize_type`` and quantize through ``Quantize``."""
        self.quantize_type = quantize_type
        Quantize(self).quantize_network()

    def export(self, **kargs):
        """Export the network through ``Export``.

        Args:
            **kargs: a truthy ``quantize_type`` overrides the value stored
                on the instance before exporting.
        """
        if kargs.get('quantize_type'):
            self.quantize_type = kargs['quantize_type']
        Export(self).export_network()
|
|
|
|
|
|
|
|
def handel_param(param, channel=3):
    """Normalise a preprocessing parameter to a per-channel list.

    Args:
        param: either a sequence of per-channel values (``list`` or
            ``tuple``) or a single value that applies to every channel.
        channel: number of times a single value is repeated (default 3).

    Returns:
        ``param`` itself when it is already a list, a list copy when it is a
        tuple, otherwise ``[param] * channel``.
    """
    if isinstance(param, list):
        return param
    if isinstance(param, tuple):
        # A tuple is already per-channel data; repeating it would produce a
        # nested ``[(..), (..), (..)]`` structure.
        return list(param)
    return [param] * channel
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke run: generate the input-meta file for the sample model.
    yolo = Netrans('../../model_zoo/yolov4_tiny')
    yolo.inputmeta_gen()

    # The remaining steps are kept disabled; enable them one at a time, or
    # use the one-shot call, to run the rest of the pipeline.
    # yolo.model2nb("uint8")
    # yolo.load_model()
    # yolo.config(mean=[0,0,0],scale=1)
    # yolo.quantize('uint8')
    # yolo.export()
|