38 lines
1.0 KiB
Python
38 lines
1.0 KiB
Python
|
|
import os
import subprocess
import sys

from utils import check_path, AttributeCopier, creat_cla
|
|
|
|
class InputmetaGen(AttributeCopier):
    """Generate a model's inputmeta file with the netrans (pegasus) CLI.

    Instances read ``self.netrans`` (path to the netrans executable) and
    ``self.model_name``; presumably these are copied from ``source_obj`` by
    ``AttributeCopier`` — TODO confirm in utils.
    """

    def __init__(self, source_obj) -> None:
        # All attribute setup is delegated to AttributeCopier.
        super().__init__(source_obj)

    @check_path
    def inputmeta_gen(self):
        """Run ``<netrans> generate inputmeta`` for ``<model_name>.json``.

        The command runs in the current working directory. NOTE(review): the
        ``check_path`` decorator may change the directory first — confirm in
        utils.

        Returns:
            The exit status of the netrans process (0 on success), or 127 if
            the executable could not be launched at all (mirroring the shell's
            "command not found" status that the previous ``os.system`` call
            would have produced).
        """
        netrans_path = self.netrans
        network_name = self.model_name

        # Run the pegasus command. The arguments are passed as a list (no
        # shell), so names or paths containing spaces or shell metacharacters
        # are neither split nor interpreted by a shell.
        cmd = [
            netrans_path,
            "generate",
            "inputmeta",
            "--model",
            f"{network_name}.json",
            "--separated-database",
        ]
        try:
            result = subprocess.run(cmd, check=False)
        except OSError as err:
            # Keep best-effort semantics: report instead of crashing.
            print(f"Failed to launch netrans ({netrans_path}): {err}", file=sys.stderr)
            return 127

        if result.returncode != 0:
            # Surface failures instead of silently ignoring the exit status.
            print(
                f"netrans generate inputmeta failed with exit code {result.returncode}",
                file=sys.stderr,
            )
        return result.returncode
|
|
|
|
def main():
    """Command-line entry point: ``<script> <network_name>``.

    Reads the netrans executable path from the ``NETRANS_PATH`` environment
    variable, then runs ``InputmetaGen.inputmeta_gen`` for ``network_name``.

    Exits with status 2 when the argument count is wrong and with status 1
    when ``NETRANS_PATH`` is unset or empty.
    """
    # Exactly one positional argument (the network name) is expected.
    if len(sys.argv) != 2:
        print("Enter a network name!", file=sys.stderr)
        sys.exit(2)

    network_name = sys.argv[1]

    # Path of the netrans executable, supplied through the environment.
    netrans_path = os.getenv('NETRANS_PATH')
    if not netrans_path:
        # Fail early with a clear message; otherwise None would be handed to
        # creat_cla and end up in the generated command line.
        print("NETRANS_PATH environment variable is not set!", file=sys.stderr)
        sys.exit(1)

    cla = creat_cla(netrans_path, network_name)
    func = InputmetaGen(cla)
    func.inputmeta_gen()
|
|
|
|
|
|
# Only run the CLI when this file is executed directly, not on import.
if __name__ == "__main__":
    main()