Update netrans.py
This commit is contained in:
parent
cab9cd5a5c
commit
76adccdf3c
|
@ -1,7 +1,7 @@
|
|||
import sys, os
|
||||
sys.path.append("/home/xj/work_program/1.deving/pnna_sdk/netrans/netrans_py")
|
||||
import subprocess
|
||||
from ruamel import yaml
|
||||
from ruamel.yaml import YAML
|
||||
import file_model
|
||||
from import_model import ImportModel
|
||||
from quantize import Quantize
|
||||
|
@ -71,9 +71,13 @@ class Netrans():
|
|||
|
||||
if len(kargs) == 0 : return
|
||||
if isinstance(kargs['mean'], list) or isinstance(kargs['scale'], (int, float)) or isinstance(kargs['reverse_channel'], bool):
|
||||
with open(self.input_meta,'r') as f :data = yaml.load(f)
|
||||
with open(self.input_meta,'r') as f :
|
||||
yaml = YAML()
|
||||
data = yaml.load(f)
|
||||
data = self.upload_cfg(data,**kargs)
|
||||
with open(self.input_meta,'w') as f :yaml.dump(data, f,default_style=False)
|
||||
with open(self.input_meta,'w') as f :
|
||||
yaml = YAML()
|
||||
yaml.dump(data, f,default_style=False)
|
||||
|
||||
def upload_cfg(self, data, channel=3, **kargs):
|
||||
if kargs.get('mean') is not None:
|
||||
|
|
Loading…
Reference in New Issue