netrans/netrans_py/gen_inputmeta.py

38 lines
1.0 KiB
Python

import os
import sys
from utils import check_path, AttributeCopier, creat_cla
class InputmetaGen(AttributeCopier):
def __init__(self, source_obj) -> None:
super().__init__(source_obj)
@check_path
def inputmeta_gen(self):
netrans_path = self.netrans
network_name = self.model_name
# 进入网络名称指定的目录
# os.chdir(network_name)
# check_env(network_name)
# 执行 pegasus 命令
os.system(f"{netrans_path} generate inputmeta --model {network_name}.json --separated-database")
# os.chdir("..")
def main():
# 检查命令行参数数量是否正确
if len(sys.argv) != 2:
print("Enter a network name!")
sys.exit(2)
# 检查提供的目录是否存在
network_name = sys.argv[1]
# 构建 netrans 可执行文件的路径
netrans_path =os.getenv('NETRANS_PATH')
cla = creat_cla(netrans_path, network_name)
func = InputmetaGen(cla)
func.inputmeta_gen()
if __name__ == '__main__':
main()