81 lines
2.3 KiB
Python
81 lines
2.3 KiB
Python
import sys
|
|
import os
|
|
# from functools import wraps
|
|
|
|
# def check_path(netrans, model_path):
|
|
# def decorator(func):
|
|
# @wraps(func)
|
|
# def wrapper(netrans, model_path, *args, **kargs):
|
|
# check_dir(model_path)
|
|
# check_netrans(netrans)
|
|
# if os.getcwd() != model_path :
|
|
# os.chdir(model_path)
|
|
# return func(netrans, model_path, *args, **kargs)
|
|
# return wrapper
|
|
# return decorator
|
|
|
|
def check_path(func):
    """Decorate a callable so it runs from the model directory.

    The wrapped callable's first positional argument (``cla``) must expose
    ``netrans`` and ``model_path`` attributes. Before delegating, the wrapper
    validates ``cla.netrans`` with ``check_netrans`` (which may terminate the
    process via ``sys.exit``). It then changes the working directory to
    ``cla.model_path`` if the process is not already there.

    Args:
        func: Callable taking ``cla`` as its first argument.

    Returns:
        The wrapper. It carries ``func``'s name, docstring and ``__wrapped__``.
    """
    from functools import wraps

    # Without wraps every decorated method would report itself as "wrapper".
    @wraps(func)
    def wrapper(cla, *args, **kargs):
        check_netrans(cla.netrans)
        # os.getcwd() is absolute, so this comparison only short-circuits when
        # model_path is absolute too (creat_cla stores an abspath).
        if os.getcwd() != cla.model_path:
            os.chdir(cla.model_path)
        return func(cla, *args, **kargs)

    return wrapper
|
|
|
|
|
|
def check_dir(network_name):
    """Change the working directory to *network_name*, exiting if that is impossible.

    Args:
        network_name: Path of the model directory, either absolute or relative
            to the current working directory.

    If *network_name* is not an existing directory, a message is printed to
    stderr and the process exits with status 1.
    """
    # isdir rather than exists: a regular file of that name would otherwise
    # pass the check and make os.chdir raise NotADirectoryError.
    if not os.path.isdir(network_name):
        print(f"Directory {network_name} does not exist !", file=sys.stderr)
        # Status 1 is consistent with check_netrans; -1 is reported as 255 on POSIX.
        sys.exit(1)
    os.chdir(network_name)
|
|
|
|
def check_netrans(netrans):
    """Validate the netrans path, but only when NETRANS_PATH is set.

    Args:
        netrans: Path to the netrans executable, or ``None``.

    Behavior:
        - If NETRANS_PATH is absent from the environment, return ``None``
          without checking anything.
        - Otherwise return ``None`` when *netrans* is an existing path.
        - Otherwise print a message to stderr and exit with status 1.
    """
    if 'NETRANS_PATH' not in os.environ:
        # NOTE(review): the check is skipped when the variable is unset, yet the
        # failure message below asks the user to set it. This guard looks
        # inverted; confirm the intended behaviour before changing it.
        return
    if netrans is not None and os.path.exists(netrans):
        return
    print("Need to set enviroment variable NETRANS_PATH", file=sys.stderr)
    sys.exit(1)
|
|
|
|
def remove_history_file(name):
    """Delete ``<name>.json`` and ``<name>.data`` from inside directory *name*.

    The function temporarily changes into *name*, so both file names are
    resolved relative to that directory. Files that are not present are
    skipped. A missing directory raises from ``os.chdir`` before anything is
    touched. The caller's working directory is always restored, even for
    nested or absolute *name* and even if a removal fails.
    """
    previous_cwd = os.getcwd()
    os.chdir(name)
    try:
        for suffix in (".json", ".data"):
            candidate = f"{name}{suffix}"
            if os.path.isfile(candidate):
                os.remove(candidate)
    finally:
        # chdir('..') only undoes a single-component relative path, so return
        # to the exact directory we started from instead.
        os.chdir(previous_cwd)
|
|
|
|
def check_env(name):
    """Prepare the environment for model *name* by delegating to ``check_dir``.

    ``check_dir`` either switches the working directory into *name* or exits
    the process. No netrans check and no history-file cleanup happen here;
    both were disabled in earlier revisions.
    """
    check_dir(name)
|
|
|
|
|
|
class AttributeCopier:
    """Object populated with a shallow copy of another object's instance attributes.

    Only names found in the source's ``__dict__`` are copied, so class-level
    attributes are not. Values are shared with the source, not duplicated.
    """

    def __init__(self, source_obj) -> None:
        self.copy_attribute_name(source_obj)

    def copy_attribute_name(self, source_obj):
        """Set on ``self`` each instance attribute of *source_obj*, read via ``getattr``."""
        names = self._get_attribute_names(source_obj)
        for attr in names:
            value = getattr(source_obj, attr)
            setattr(self, attr, value)

    @staticmethod
    def _get_attribute_names(source_obj):
        """Return a live keys view of *source_obj*'s instance ``__dict__``."""
        instance_dict = source_obj.__dict__
        return instance_dict.keys()
|
|
|
|
class creat_cla:  # dataclass @netrans_params
    """Plain settings holder for one model.

    It exposes ``netrans`` and ``model_path``, the two attributes that
    ``check_path``'s wrapper reads from its first argument.
    """

    def __init__(self, netrans_path, name, quantized_type='uint8', verbose=False) -> None:
        # Toolchain root as given, plus the 'pnnacc' entry joined beneath it.
        self.netrans_path = netrans_path
        self.netrans = os.path.join(netrans_path, 'pnnacc')
        # The name is kept verbatim; the path is resolved against the current
        # working directory at construction time.
        self.model_name = name
        self.model_path = os.path.abspath(name)
        self.verbose = verbose
        # Stored as 'quantize_type' although the parameter is 'quantized_type'.
        self.quantize_type = quantized_type
|
|
|
|
if __name__ == "__main__":
    # Demo: make sure a "yolo" model directory exists, then switch into it.
    dir_name = "yolo"
    # exist_ok=True: plain os.mkdir raised FileExistsError on every run after the first.
    os.makedirs(dir_name, exist_ok=True)
    check_dir(dir_name)
|
|
|
|
|