92 lines
3.0 KiB
Python
92 lines
3.0 KiB
Python
import os
|
|
import sys
|
|
from utils import check_path, AttributeCopier, creat_cla
|
|
|
|
class Quantize(AttributeCopier):
    """Run the netrans ``quantize`` sub-command for a model.

    ``quantize_network`` reads ``self.netrans``, ``self.quantize_type`` and
    ``self.model_name``. These are presumably populated by
    ``AttributeCopier.__init__`` from ``source_obj`` -- confirm against
    ``utils``.
    """

    def __init__(self, source_obj) -> None:
        """Delegate initialisation to ``AttributeCopier`` with ``source_obj``."""
        super().__init__(source_obj)

    @check_path
    def quantize_network(self):
        """Build and run the quantize command, then report the outcome.

        The quantize type selects the quantizer:

        * ``'float'``  -- nothing to do; prints a notice and returns.
        * ``'uint8'``  -- ``asymmetric_affine``
        * ``'int8'``   -- ``dynamic_fixed_point-8``
        * ``'int16'``  -- ``dynamic_fixed_point-16``

        Any other value prints an error and returns without running anything.

        The number of lines in ``dataset.txt`` (looked up in the current
        working directory) is passed as ``--iterations``. Every line is
        counted, including blank ones.

        Returns:
            None. The result is reported on stdout only.

        Raises:
            OSError: if ``dataset.txt`` cannot be opened (for example
                ``FileNotFoundError`` when it is missing).
        """
        netrans = self.netrans
        quantized_type = self.quantize_type
        name = self.model_name
        netrans += " quantize"

        # Pick the quantizer that corresponds to the requested type.
        if quantized_type == 'float':
            print("=========== do not need quantized===========")
            return

        quantizers = {
            'uint8': "asymmetric_affine",
            'int8': "dynamic_fixed_point-8",
            'int16': "dynamic_fixed_point-16",
        }
        quantization_type = quantizers.get(quantized_type)
        if quantization_type is None:
            print("=========== wrong quantization_type ! ( uint8 / int8 / int16 )===========")
            return

        # Announce what is about to run.
        print(" =======================================================================")
        print(f" ==== Start Quantizing {name} model with type of {quantization_type} ===")
        print(" =======================================================================")

        quantize_file = f"{name}_{quantization_type}.quantize"

        # One iteration per line of the dataset list in the working directory.
        txt_path = os.path.join(os.getcwd(), "dataset.txt")
        with open(txt_path, 'r', encoding='utf-8') as file:
            num_lines = sum(1 for _ in file)

        # Build and run the quantize command. The quantizer name is the part
        # of the scheme before the bit-width suffix ("dynamic_fixed_point-8"
        # -> "dynamic_fixed_point").
        cmd = " ".join([
            netrans,
            f"--qtype {quantized_type}",
            "--hybrid",
            f"--quantizer {quantization_type.split('-')[0]}",
            f"--model-quantize {quantize_file}",
            f"--model {name}.json",
            f"--model-data {name}.data",
            f"--with-input-meta {name}_inputmeta.yml",
            "--device CPU",
            "--algorithm kl_divergence",
            "--divergence-nbins 2048",
            f"--iterations {num_lines}",
        ])
        status = os.system(cmd)

        # The quantize file may be left over from an earlier run, so its
        # existence alone does not prove that this run succeeded: the command
        # must also have exited cleanly.
        if status == 0 and os.path.exists(quantize_file):
            print("\033[31m QUANTIZED SUCCESS \033[0m")
        else:
            print("\033[31m ERROR ! \033[0m")
            if status != 0:
                print(f" quantize command exited with status {status}")
|
|
|
|
|
|
def main():
    """Command-line entry point: ``<script> <network_name> <quantize_type>``.

    Reads the network name and quantize type from ``sys.argv`` and the
    netrans location from the ``NETRANS_PATH`` environment variable, builds
    the argument object with ``creat_cla`` and runs
    ``Quantize.quantize_network`` on it.

    Exits with status -1 (after printing a message) when fewer than two
    arguments are given or when ``NETRANS_PATH`` is not set.
    """
    # Both the network name and the quantize type are required.
    if len(sys.argv) < 3:
        print("Input a network name and quantized type ( uint8 / int8 / int16 )")
        sys.exit(-1)

    network_name = sys.argv[1]
    quantize_type = sys.argv[2]

    # Fail with a readable message instead of a KeyError traceback when the
    # environment is not configured.
    netrans_path = os.environ.get('NETRANS_PATH')
    if netrans_path is None:
        print("NETRANS_PATH environment variable is not set")
        sys.exit(-1)

    cla = creat_cla(netrans_path, network_name, quantize_type)

    # Run the quantization step.
    run = Quantize(cla)
    run.quantize_network()
|
|
|
|
# Run the command-line entry point only when this file is executed as a
# script; importing the module has no side effects.
if __name__ == "__main__":
    main()
|