netrans/netrans_py/utils.py

81 lines
2.3 KiB
Python

import sys
import os
# from functools import wraps
# def check_path(netrans, model_path):
# def decorator(func):
# @wraps(func)
# def wrapper(netrans, model_path, *args, **kargs):
# check_dir(model_path)
# check_netrans(netrans)
# if os.getcwd() != model_path :
# os.chdir(model_path)
# return func(netrans, model_path, *args, **kargs)
# return wrapper
# return decorator
def check_path(func):
    """Decorate a method so it runs from inside the model directory.

    The wrapped callable's first argument (``cla``) must expose ``netrans``
    and ``model_path`` attributes. Before each call the wrapper:

    1. validates the netrans tool via ``check_netrans(cla.netrans)``;
    2. changes the process working directory to ``cla.model_path`` if it is
       not already the current directory.

    The directory change is not undone after the call returns.

    Args:
        func: The method to wrap.

    Returns:
        The wrapper, with ``func``'s name, docstring and ``__wrapped__``
        preserved via ``functools.wraps``.
    """
    import functools

    @functools.wraps(func)  # keep __name__/__doc__ so introspection and logs stay meaningful
    def wrapper(cla, *args, **kargs):
        check_netrans(cla.netrans)
        if os.getcwd() != cla.model_path:
            os.chdir(cla.model_path)
        return func(cla, *args, **kargs)

    return wrapper
def check_dir(network_name):
    """Enter the directory ``network_name``, terminating the program if it is absent.

    Args:
        network_name: Path of the directory to change into.

    If the path does not exist, a message is printed and ``sys.exit(-1)``
    is called. Otherwise the process working directory becomes
    ``network_name``.
    """
    if os.path.exists(network_name):
        os.chdir(network_name)
        return
    print(f"Directory {network_name} does not exist !")
    sys.exit(-1)
def check_netrans(netrans):
    """Ensure the netrans tool path exists, otherwise exit with status 1.

    Args:
        netrans: Path to the netrans tool, or ``None``.

    Returns silently when ``netrans`` is not ``None`` and the path exists.
    Otherwise it prints a message and calls ``sys.exit(1)``. The message
    distinguishes two cases:

    - ``NETRANS_PATH`` is not set in the environment;
    - ``NETRANS_PATH`` is set but ``netrans`` does not exist.

    The original version returned silently whenever ``NETRANS_PATH`` was
    unset. It printed "Need to set ... NETRANS_PATH" only when the variable
    *was* set, so the check was inverted relative to its own message.
    """
    if netrans is not None and os.path.exists(netrans):
        return
    if 'NETRANS_PATH' not in os.environ:
        print("Need to set enviroment variable NETRANS_PATH")
    else:
        print(f"netrans tool {netrans} does not exist, please check environment variable NETRANS_PATH")
    sys.exit(1)
def remove_history_file(name):
    """Delete ``<name>.json`` and ``<name>.data`` from inside directory ``name``.

    Args:
        name: Directory to enter. The files looked up are ``f"{name}.json"``
            and ``f"{name}.data"``, resolved relative to that directory.

    Missing files are skipped. If ``name`` cannot be entered, the error from
    ``os.chdir`` propagates. The caller's working directory is always
    restored afterwards. The previous ``os.chdir('..')`` only did that for
    single-component, non-symlink names, and not at all when an exception
    occurred.
    """
    prev_cwd = os.getcwd()
    os.chdir(name)
    try:
        for suffix in (".json", ".data"):
            path = f"{name}{suffix}"
            if os.path.isfile(path):
                os.remove(path)
    finally:
        os.chdir(prev_cwd)
def check_env(name):
    """Prepare the working environment for model ``name``.

    Currently this only delegates to ``check_dir(name)``, which changes into
    the directory or exits the program if it does not exist.
    """
    # Steps that are disabled in this function, kept for reference:
    #   check_netrans()
    #   remove_history_file(name)
    check_dir(name)
class AttributeCopier:
    """Mirror the instance attributes of another object onto a new object.

    Every name stored in ``source_obj.__dict__`` is read from ``source_obj``
    with ``getattr`` and assigned onto ``self`` under the same name, in the
    dictionary's iteration order.
    """

    def __init__(self, source_obj) -> None:
        self.copy_attribute_name(source_obj)

    def copy_attribute_name(self, source_obj):
        """Copy each attribute listed in ``source_obj.__dict__`` onto ``self``."""
        names = self._get_attribute_names(source_obj)
        for name in names:
            value = getattr(source_obj, name)
            setattr(self, name, value)

    @staticmethod
    def _get_attribute_names(source_obj):
        """Return the keys view of ``source_obj``'s instance ``__dict__``."""
        instance_dict = source_obj.__dict__
        return instance_dict.keys()
class creat_cla():  # dataclass @netrans_params
    """Bundle of netrans run parameters for a single model.

    Attributes (assigned in this order):
        netrans_path: Directory passed in as ``netrans_path``.
        netrans: ``os.path.join(netrans_path, 'pnnacc')``.
        model_name: The ``name`` argument as given.
        model_path: ``os.path.abspath(name)``.
        verbose: The ``verbose`` flag.
        quantize_type: The ``quantized_type`` argument (default ``'uint8'``).
    """

    def __init__(self, netrans_path, name, quantized_type = 'uint8', verbose=False) -> None:
        self.netrans_path = netrans_path
        self.netrans = os.path.join(netrans_path, 'pnnacc')
        # model_name keeps the raw name; model_path is its absolute form.
        self.model_name = name
        self.model_path = os.path.abspath(name)
        self.verbose = verbose
        self.quantize_type = quantized_type
if __name__ == "__main__":
    # Manual smoke run: ensure a "yolo" directory exists, then enter it.
    dir_name = "yolo"
    # exist_ok=True: plain os.mkdir raised FileExistsError on every re-run.
    os.makedirs(dir_name, exist_ok=True)
    check_dir(dir_name)