netrans/netrans_py/quantize_hb.py

92 lines
3.0 KiB
Python

import os
import shlex
import sys

from utils import check_path, AttributeCopier, creat_cla
class Quantize(AttributeCopier):
    """Run hybrid quantization of a network through the netrans ``quantize`` command.

    Attributes are copied from ``source_obj`` by ``AttributeCopier``; this class
    reads ``netrans`` (command prefix), ``quantize_type`` and ``model_name``.
    """

    def __init__(self, source_obj) -> None:
        super().__init__(source_obj)

    @check_path  # from utils; presumably validates/enters the model directory — confirm
    def quantize_network(self):
        """Quantize ``self.model_name`` with the type given by ``self.quantize_type``.

        Supported types are ``uint8``, ``int8`` and ``int16``; ``float`` needs no
        quantization and returns immediately, and any other value prints an error
        and returns.

        The number of calibration iterations equals the number of lines in
        ``dataset.txt`` in the current working directory. The method reads
        ``<name>.json``, ``<name>.data`` and ``<name>_inputmeta.yml``, and the
        command is expected to write ``<name>_<quantizer>.quantize``. Success or
        failure is printed; nothing is returned.

        Raises:
            FileNotFoundError: if ``dataset.txt`` does not exist in the cwd.
        """
        netrans = self.netrans
        quantized_type = self.quantize_type
        name = self.model_name

        if quantized_type == 'float':
            print("=========== do not need quantized===========")
            return

        # Map the user-facing type to the netrans quantizer spec.
        quantizer_specs = {
            'uint8': "asymmetric_affine",
            'int8': "dynamic_fixed_point-8",
            'int16': "dynamic_fixed_point-16",
        }
        quantization_type = quantizer_specs.get(quantized_type)
        if quantization_type is None:
            print("=========== wrong quantization_type ! ( uint8 / int8 / int16 )===========")
            return

        print(" =======================================================================")
        print(f" ==== Start Quantizing {name} model with type of {quantization_type} ===")
        print(" =======================================================================")

        quantize_file = f"{name}_{quantization_type}.quantize"
        # Remove any output left by a previous run; otherwise the existence check
        # below would report success even if this run fails.
        try:
            os.remove(quantize_file)
        except FileNotFoundError:
            pass

        # One calibration iteration per line of the dataset list.
        txt_path = os.path.join(os.getcwd(), "dataset.txt")
        with open(txt_path, 'r', encoding='utf-8') as file:
            num_lines = sum(1 for _ in file)

        args = [
            "--qtype", quantized_type,
            "--hybrid",
            # e.g. "dynamic_fixed_point-8" -> "dynamic_fixed_point"
            "--quantizer", quantization_type.split('-')[0],
            "--model-quantize", quantize_file,
            "--model", f"{name}.json",
            "--model-data", f"{name}.data",
            "--with-input-meta", f"{name}_inputmeta.yml",
            "--device", "CPU",
            "--algorithm", "kl_divergence",
            "--divergence-nbins", "2048",
            "--iterations", str(num_lines),
        ]
        # `netrans` is kept verbatim as the command prefix (it may carry its own
        # arguments); every argument derived from the model name is shell-quoted.
        cmd = f"{netrans} quantize " + " ".join(shlex.quote(arg) for arg in args)
        os.system(cmd)

        # The quantize file only exists now if this run produced it.
        if os.path.exists(quantize_file):
            print("\033[31m QUANTIZED SUCCESS \033[0m")
        else:
            print("\033[31m ERROR ! \033[0m")
def main():
    """Command-line entry point: ``<script> <network_name> <quantize_type>``.

    Reads the netrans installation path from the ``NETRANS_PATH`` environment
    variable, builds the configuration object with ``creat_cla``, and runs
    hybrid quantization for the named network.

    Exits with status -1 if arguments are missing or ``NETRANS_PATH`` is unset.
    """
    # Need at least a network name and a quantization type.
    if len(sys.argv) < 3:
        print("Input a network name and quantized type ( uint8 / int8 / int16 )")
        sys.exit(-1)
    network_name = sys.argv[1]
    quantize_type = sys.argv[2]

    # Report a missing environment variable clearly instead of a KeyError traceback.
    netrans_path = os.environ.get('NETRANS_PATH')
    if not netrans_path:
        print("NETRANS_PATH environment variable is not set")
        sys.exit(-1)

    cla = creat_cla(netrans_path, network_name, quantize_type)
    run = Quantize(cla)
    run.quantize_network()
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()