50 lines
1.6 KiB
C++
50 lines
1.6 KiB
C++
//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Quant/QuantOps.h"
|
|
#include "QuantDialectBytecode.h"
|
|
#include "TypeDetail.h"
|
|
|
|
#include "mlir/Dialect/Quant/QuantTypes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/ADT/Twine.h"
|
|
#include "llvm/Support/MathExtras.h"
|
|
#include <numeric>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::quant;
|
|
using namespace mlir::quant::detail;
|
|
|
|
#include "mlir/Dialect/Quant/QuantOpsDialect.cpp.inc"
|
|
|
|
// Registers the quant dialect's types, its ODS-generated operations, and its
// bytecode interface with this dialect instance. Presumably called from the
// generated dialect constructor pulled in via QuantOpsDialect.cpp.inc — confirm.
void QuantizationDialect::initialize() {
  // The quantized type kinds this dialect provides.
  addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
           UniformQuantizedPerAxisType>();
  // The generated .inc file, when included with GET_OP_LIST defined, expands
  // to the op class list used as the template arguments here. The #define
  // and #include must stay directly inside the angle brackets. The same .inc
  // file is included again at the end of this file with GET_OP_CLASSES for
  // the op class definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
      >();
  // Bytecode interface registration; the helper is presumably declared in
  // QuantDialectBytecode.h (included above) — confirm.
  addBytecodeInterface(this);
}
|
|
|
|
// Folds a pair of mutually-inverse storage casts:
//   x -> scast -> scast -> y   ==>   x
// The fold applies only when this cast's operand is itself produced by a
// StorageCastOp whose input type equals this op's result type. Otherwise the
// op is left unchanged (null OpFoldResult). `operands` is not consulted.
OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
  if (auto producer = getArg().getDefiningOp<StorageCastOp>()) {
    auto original = producer.getArg();
    // The outer cast restores the inner cast's input type: both casts cancel.
    if (original.getType() == getType())
      return original;
  }
  // No inverse cast pair found; report "no fold".
  return {};
}
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
|