[mlir][tosa] Flip accessors used to prefixed form (NFC)
Follow-up from the dialect flip; this just flips the accessors used at call sites. Both forms are still generated.
parent 475a39fbc3
commit 13448db06a
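The change is mechanical: for each TOSA operand and attribute, ODS generates both the legacy snake_case accessor and the prefixed get-form, and this commit only switches call sites to the prefixed form. As a rough illustration (a hypothetical pattern written for this note, not code from the commit), both accessor spellings below refer to the same operand and attribute of tosa::ReshapeOp and both still compile during the transition:

// Illustrative sketch only; assumes the usual MLIR build setup and headers.
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"

struct ExampleReshapePattern : mlir::OpRewritePattern<mlir::tosa::ReshapeOp> {
  using OpRewritePattern::OpRewritePattern;
  mlir::LogicalResult
  matchAndRewrite(mlir::tosa::ReshapeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Legacy, un-prefixed accessors (still generated, no longer used below).
    mlir::Value oldStyleInput = op.input1();
    mlir::ArrayAttr oldStyleShape = op.new_shape();
    // Prefixed accessors that this commit switches call sites to.
    mlir::Value input = op.getInput1();
    mlir::ArrayAttr newShape = op.getNewShape();
    (void)oldStyleInput; (void)oldStyleShape; (void)input; (void)newShape;
    return mlir::failure(); // sketch only: no rewrite performed
  }
};

The hunks below are reproduced with `-` marking the removed (un-prefixed) lines and `+` marking the added (prefixed) lines; leading indentation was lost in extraction.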
@@ -28,7 +28,7 @@ public:
LogicalResult matchAndRewrite(tosa::ConstOp op,
PatternRewriter &rewriter) const final {
-rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.value());
+rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
return success();
}
};

@@ -66,8 +66,8 @@ public:
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
-Value value = op.value();
-Value multiplier32 = op.multiplier();
+Value value = op.getValue();
+Value multiplier32 = op.getMultiplier();

Type resultTy = op.getType();
Type valueTy = value.getType();

@@ -78,7 +78,7 @@ public:
Value one64 = getConstantValue(loc, i64Ty, 1, rewriter);
Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter);

-Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.shift());
+Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());

// Compute the multiplication in 64-bits then select the high / low parts.
Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);

@@ -94,7 +94,7 @@ public:
multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);

// Apply double rounding if necessary.
-if (op.double_round()) {
+if (op.getDoubleRound()) {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);

@@ -129,14 +129,14 @@ public:
Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);

-Value value = op.value();
+Value value = op.getValue();
if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) {
return failure();
}

-Value value32 = op.value();
-Value multiplier32 = op.multiplier();
-Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.shift());
+Value value32 = op.getValue();
+Value multiplier32 = op.getMultiplier();
+Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());

// Constants used during the scaling operation.
Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter);

@@ -176,7 +176,7 @@ public:
rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);

// Conditionally perform our double round.
-if (op.double_round()) {
+if (op.getDoubleRound()) {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
Value valuePositive = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, value32, zero32);

@@ -77,7 +77,7 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,

// tosa::MulOp
if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
-if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
+if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
(void)rewriter.notifyMatchFailure(op,
"Cannot have shift value for float");
return nullptr;

@@ -137,15 +137,15 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
-!cast<tosa::NegateOp>(op).quantization_info()) {
+!cast<tosa::NegateOp>(op).getQuantizationInfo()) {
auto constant =
rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
}

if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
-cast<tosa::NegateOp>(op).quantization_info()) {
-auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
+cast<tosa::NegateOp>(op).getQuantizationInfo()) {
+auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
int64_t inZp = quantizationInfo.value().getInputZp();
int64_t outZp = quantizationInfo.value().getOutputZp();

@@ -978,7 +978,7 @@ public:
LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
-ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
+ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
bool isDynamic = !operandTy.hasStaticShape();

@@ -1021,7 +1021,7 @@ public:
LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
-ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
+ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
bool isDynamic = !operandTy.hasStaticShape();

@@ -1065,7 +1065,7 @@ public:
LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
-ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
+ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
ShapedType resultTy = reshape.getType().template cast<ShapedType>();
bool isDynamic = !operandTy.hasStaticShape();

@@ -1086,7 +1086,7 @@ public:
reshape.getLoc(),
RankedTensorType::get(intermediateShape,
reshape.getType().getElementType()),
-adaptor.input1());
+adaptor.getInput1());
Value expand =
rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
rewriter.replaceOp(reshape, expand);

@@ -1102,7 +1102,7 @@ public:
LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const final {
DenseIntElementsAttr perms;
-if (!matchPattern(op.perms(), m_Constant(&perms))) {
+if (!matchPattern(op.getPerms(), m_Constant(&perms))) {
return failure();
}

@@ -1136,7 +1136,7 @@ public:
rewriter.getMultiDimIdentityMap(resultTy.getRank())};

rewriter.replaceOpWithNewOp<linalg::GenericOp>(
-op, resultTy, op.input1(), ValueRange{initTensor}, affineMaps,
+op, resultTy, op.getInput1(), ValueRange{initTensor}, affineMaps,
getNParallelLoopsAttrs(resultTy.getRank()),
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());

@@ -1152,28 +1152,28 @@ public:
LogicalResult matchAndRewrite(tosa::RescaleOp op,
PatternRewriter &rewriter) const final {
auto loc = op.getLoc();
-auto input = op.input();
-auto inputTy = op.input().getType().cast<ShapedType>();
-auto outputTy = op.output().getType().cast<ShapedType>();
+auto input = op.getInput();
+auto inputTy = op.getInput().getType().cast<ShapedType>();
+auto outputTy = op.getOutput().getType().cast<ShapedType>();
unsigned rank = inputTy.getRank();

// This is an illegal configuration. terminate and log an error
-if (op.double_round() && !op.scale32())
+if (op.getDoubleRound() && !op.getScale32())
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");

auto dynamicDimsOr =
-checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
+checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
if (!dynamicDimsOr.has_value())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.value();

// The shift and multiplier values.
SmallVector<int32_t> multiplierValues;
-getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues);
+getValuesFromIntArrayAttribute(op.getMultiplier(), multiplierValues);

SmallVector<int8_t> shiftValues;
-getValuesFromIntArrayAttribute(op.shift(), shiftValues);
+getValuesFromIntArrayAttribute(op.getShift(), shiftValues);

// If we shift by more than the bitwidth, this just sets to 0.
for (int i = 0, s = multiplierValues.size(); i < s; i++) {

@@ -1186,7 +1186,7 @@ public:
// Double round only occurs if shift is greater than 31, check that this
// is ever true.
bool doubleRound =
-op.double_round() &&
+op.getDoubleRound() &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

SmallVector<AffineMap> indexingMaps = {

@@ -1346,7 +1346,7 @@ public:
LogicalResult matchAndRewrite(tosa::ResizeOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
-auto input = op.input();
+auto input = op.getInput();
auto inputTy = input.getType().cast<ShapedType>();
auto resultTy = op.getType().cast<ShapedType>();
auto resultElementTy = resultTy.getElementType();

@@ -1355,12 +1355,12 @@ public:
auto imageW = inputTy.getShape()[2];

auto dynamicDimsOr =
-checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
+checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
if (!dynamicDimsOr.has_value())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.value();

-if (op.mode() != "NEAREST_NEIGHBOR" && op.mode() != "BILINEAR")
+if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
return failure();

auto initTensor = rewriter.create<linalg::InitTensorOp>(

@@ -1394,19 +1394,19 @@ public:
Value inX =
rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x);

-int32_t shift = op.shift();
+int32_t shift = op.getShift();
bool floatingPointMode = shift == 0;

Value yStride, xStride, yOffset, xOffset;
if (floatingPointMode) {
-yStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[0]);
-xStride = rewriter.create<arith::ConstantOp>(loc, op.stride_fp()[1]);
-yOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[0]);
-xOffset = rewriter.create<arith::ConstantOp>(loc, op.offset_fp()[1]);
+yStride = rewriter.create<arith::ConstantOp>(loc, op.getStrideFp()[0]);
+xStride = rewriter.create<arith::ConstantOp>(loc, op.getStrideFp()[1]);
+yOffset = rewriter.create<arith::ConstantOp>(loc, op.getOffsetFp()[0]);
+xOffset = rewriter.create<arith::ConstantOp>(loc, op.getOffsetFp()[1]);
} else {
SmallVector<int32_t> stride, offset;
-getValuesFromIntArrayAttribute(op.stride(), stride);
-getValuesFromIntArrayAttribute(op.offset(), offset);
+getValuesFromIntArrayAttribute(op.getStride(), stride);
+getValuesFromIntArrayAttribute(op.getOffset(), offset);

yStride = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(stride[0]));

@@ -1463,7 +1463,7 @@ public:
dx = rewriter.create<arith::SubIOp>(loc, x, xTrunc);
}

-if (op.mode() == "NEAREST_NEIGHBOR") {
+if (op.getMode() == "NEAREST_NEIGHBOR") {
Value yPred, xPred;
// Round the index position towards the closest pixel location.
if (floatingPointMode) {

@@ -1516,7 +1516,7 @@ public:
return success();
}

-if (op.mode() == "BILINEAR") {
+if (op.getMode() == "BILINEAR") {
Value y0 = iy;
Value x0 = ix;

@@ -1634,7 +1634,7 @@ public:

LogicalResult matchAndRewrite(SrcOp reduceOp,
PatternRewriter &rewriter) const final {
-return reduceMatchAndRewriteHelper(reduceOp, reduceOp.axis(), rewriter);
+return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
}
};

@@ -1648,7 +1648,7 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
auto resultType = op.getType().dyn_cast<RankedTensorType>();

Location loc = op.getLoc();
-int axis = op.axis();
+int axis = op.getAxis();
Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
loc, rewriter.getIndexAttr(axis));
int rank = resultType.getRank();

@@ -1713,10 +1713,10 @@ public:
LogicalResult matchAndRewrite(tosa::ReverseOp op,
PatternRewriter &rewriter) const final {
auto loc = op.getLoc();
-Value input = op.input();
+Value input = op.getInput();
auto inputTy = input.getType().template cast<ShapedType>();
auto resultTy = op.getType().template cast<ShapedType>();
-auto axis = op.axis();
+auto axis = op.getAxis();

SmallVector<Value> dynDims;
for (int i = 0; i < inputTy.getRank(); i++) {

@@ -1775,7 +1775,7 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
-auto input = op.input1();
+auto input = op.getInput1();
auto inputTy = input.getType().cast<ShapedType>();
auto inputShape = inputTy.getShape();
auto resultTy = op.getType().cast<ShapedType>();

@@ -1783,7 +1783,7 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
int64_t rank = inputTy.getRank();

SmallVector<int64_t> multiples;
-getValuesFromIntArrayAttribute(op.multiples(), multiples);
+getValuesFromIntArrayAttribute(op.getMultiples(), multiples);

// Broadcast the newly added dimensions to their appropriate multiple.
SmallVector<int64_t, 2> genericShape;

@@ -1837,8 +1837,8 @@ public:
LogicalResult matchAndRewrite(tosa::PadOp padOp,
PatternRewriter &rewriter) const final {
auto loc = padOp.getLoc();
-auto input = padOp.input1();
-auto padding = padOp.padding();
+auto input = padOp.getInput1();
+auto padding = padOp.getPadding();

ShapedType inputTy = input.getType().cast<ShapedType>();
Type elementTy = inputTy.getElementType();

@@ -1848,17 +1848,17 @@ public:

Value padConstant;

-if (padOp.pad_const()) {
+if (padOp.getPadConst()) {
padConstant = rewriter.createOrFold<tensor::ExtractOp>(
-loc, padOp.pad_const(), ValueRange({}));
+loc, padOp.getPadConst(), ValueRange({}));
} else {
Attribute constantAttr;
if (elementTy.isa<FloatType>()) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-} else if (elementTy.isa<IntegerType>() && !padOp.quantization_info()) {
+} else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-} else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
-int64_t value = padOp.quantization_info()->getInputZp();
+} else if (elementTy.isa<IntegerType>() && padOp.getQuantizationInfo()) {
+int64_t value = padOp.getQuantizationInfo()->getInputZp();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}
if (constantAttr)

@@ -1926,12 +1926,12 @@ public:
LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
PatternRewriter &rewriter) const final {
auto loc = argmaxOp.getLoc();
-Value input = argmaxOp.input();
+Value input = argmaxOp.getInput();
auto inputTy = input.getType().cast<ShapedType>();
-auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
+auto resultTy = argmaxOp.getOutput().getType().cast<ShapedType>();
auto inElementTy = inputTy.getElementType();
auto outElementTy = resultTy.getElementType();
-int axis = argmaxOp.axis();
+int axis = argmaxOp.getAxis();
auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

if (!outElementTy.isa<IntegerType>())

@@ -2049,8 +2049,8 @@ public:

auto resultTy = op.getType().cast<ShapedType>();

-auto dynamicDimsOr =
-checkHasDynamicBatchDims(rewriter, op, {input, indices, op.output()});
+auto dynamicDimsOr = checkHasDynamicBatchDims(
+rewriter, op, {input, indices, op.getOutput()});
if (!dynamicDimsOr.has_value())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.value();

@@ -2101,8 +2101,8 @@ public:
LogicalResult matchAndRewrite(tosa::TableOp op,
PatternRewriter &rewriter) const final {
auto loc = op.getLoc();
-Value input = op.input();
-Value table = op.table();
+Value input = op.getInput();
+Value table = op.getTable();
auto inputTy = input.getType().cast<ShapedType>();
auto tableTy = table.getType().cast<ShapedType>();
auto resultTy = op.getType().cast<ShapedType>();

@@ -533,21 +533,21 @@ public:
.create<linalg::FillOp>(loc, ValueRange{zero},
ValueRange{initTensor})
.result();
-if (!op.quantization_info()) {
+if (!op.getQuantizationInfo()) {
rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
-op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
-ValueRange{zeroTensor});
+op, TypeRange{op.getType()},
+ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
return success();
}

-auto quantizationInfo = *op.quantization_info();
+auto quantizationInfo = *op.getQuantizationInfo();
auto aZp = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
auto bZp = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
op, TypeRange{op.getType()},
-ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);
+ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

return success();
}

@@ -562,12 +562,12 @@ public:
ConversionPatternRewriter &rewriter) const final {
Location loc = op.getLoc();
auto outputTy = op.getType().cast<ShapedType>();
-auto input = op.input();
+auto input = op.getInput();
auto inputTy = input.getType().cast<ShapedType>();

-auto bias = op.bias();
+auto bias = op.getBias();

-auto weight = op.weight();
+auto weight = op.getWeight();
auto weightTy = weight.getType().cast<ShapedType>();
auto weightShape = weightTy.getShape();

@@ -627,7 +627,7 @@ public:
outputTy.getShape(), outputETy)
->getResults();

-if (!op.quantization_info()) {
+if (!op.getQuantizationInfo()) {
Value matmul = rewriter
.create<linalg::MatmulOp>(
loc, TypeRange{op.getType()},

@@ -650,7 +650,7 @@ public:
return success();
}

-auto quantizationInfo = *op.quantization_info();
+auto quantizationInfo = *op.getQuantizationInfo();
auto inputZp = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
auto outputZp = rewriter.create<arith::ConstantOp>(

@@ -686,14 +686,14 @@ public:
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
-Value input = op.input();
+Value input = op.getInput();
ShapedType inputTy = input.getType().cast<ShapedType>();

ShapedType resultTy = op.getType().template cast<ShapedType>();
Type resultETy = inputTy.getElementType();

auto dynamicDimsOr =
-checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
+checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
if (!dynamicDimsOr.has_value())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.value();

@@ -718,15 +718,15 @@ public:
// Apply padding as necessary.
llvm::SmallVector<int64_t> pad;
pad.resize(2, 0);
-getValuesFromIntArrayAttribute(op.pad(), pad);
+getValuesFromIntArrayAttribute(op.getPad(), pad);
pad.resize(pad.size() + 2, 0);
Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

SmallVector<int64_t> kernel, stride;
-getValuesFromIntArrayAttribute(op.kernel(), kernel);
-getValuesFromIntArrayAttribute(op.stride(), stride);
+getValuesFromIntArrayAttribute(op.getKernel(), kernel);
+getValuesFromIntArrayAttribute(op.getStride(), stride);

Attribute strideAttr = rewriter.getI64VectorAttr(stride);
Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

@@ -758,7 +758,7 @@ public:
LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
-Value input = op.input();
+Value input = op.getInput();
ShapedType inputTy = input.getType().cast<ShapedType>();
Type inElementTy = inputTy.getElementType();

@@ -770,7 +770,7 @@ public:
ShapedType accTy = resultTy.clone(accETy);

auto dynamicDimsOr =
-checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
+checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
if (!dynamicDimsOr.has_value())
return failure();
SmallVector<Value> dynamicDims = dynamicDimsOr.value();

@@ -778,7 +778,7 @@ public:
// Apply padding as necessary.
llvm::SmallVector<int64_t> pad;
pad.resize(2, 0);
-getValuesFromIntArrayAttribute(op.pad(), pad);
+getValuesFromIntArrayAttribute(op.getPad(), pad);
pad.resize(pad.size() + 2, 0);
Attribute padAttr = rewriter.getZeroAttr(inElementTy);
Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

@@ -787,8 +787,8 @@ public:
Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

SmallVector<int64_t> kernel, stride;
-getValuesFromIntArrayAttribute(op.kernel(), kernel);
-getValuesFromIntArrayAttribute(op.stride(), stride);
+getValuesFromIntArrayAttribute(op.getKernel(), kernel);
+getValuesFromIntArrayAttribute(op.getStride(), stride);

Attribute strideAttr = rewriter.getI64VectorAttr(stride);
Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

@@ -889,8 +889,8 @@ public:

// If we have quantization information we need to apply an offset
// for the input zp value.
-if (op.quantization_info()) {
-auto quantizationInfo = *op.quantization_info();
+if (op.getQuantizationInfo()) {
+auto quantizationInfo = *op.getQuantizationInfo();
auto inputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
Value offset =

@@ -925,8 +925,8 @@ public:

// If we have quantization information we need to apply output
// zeropoint.
-if (op.quantization_info()) {
-auto quantizationInfo = *op.quantization_info();
+if (op.getQuantizationInfo()) {
+auto quantizationInfo = *op.getQuantizationInfo();
auto outputZp = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(scaled.getType(),
quantizationInfo.getOutputZp()));

@@ -32,7 +32,7 @@ static void inlineIfCase(Region &srcRegion, Region &dstRegion,

auto yield = cast<YieldOp>(headBlock->getTerminator());
rewriter.setInsertionPoint(yield);
-rewriter.create<scf::YieldOp>(yield.getLoc(), yield.inputs());
+rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
rewriter.eraseOp(yield);

headBlock->eraseArguments(

@@ -55,7 +55,7 @@ static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
headBlock->getArguments());
} else {
rewriter.setInsertionPoint(yield);
-rewriter.create<scf::YieldOp>(yield.getLoc(), yield.inputs());
+rewriter.create<scf::YieldOp>(yield.getLoc(), yield.getInputs());
}
rewriter.eraseOp(yield);
}

@@ -68,13 +68,14 @@ public:

LogicalResult matchAndRewrite(tosa::IfOp op,
PatternRewriter &rewriter) const final {
-auto condition = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.cond());
+auto condition =
+rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
condition, true);

-inlineIfCase(op.then_branch(), newIf.getThenRegion(), op.inputs(),
+inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
rewriter);
-inlineIfCase(op.else_branch(), newIf.getElseRegion(), op.inputs(),
+inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
rewriter);

rewriter.replaceOp(op, newIf.getResults());

@@ -89,12 +90,12 @@ public:
LogicalResult matchAndRewrite(tosa::WhileOp op,
PatternRewriter &rewriter) const final {
auto newWhile = rewriter.create<scf::WhileOp>(
-op.getLoc(), op.getResultTypes(), op.inputs());
+op.getLoc(), op.getResultTypes(), op.getInputs());
rewriter.createBlock(&newWhile.getBefore());
rewriter.createBlock(&newWhile.getAfter());

-inlineWhileCase(op.cond(), newWhile.getBefore(), rewriter, true);
-inlineWhileCase(op.body(), newWhile.getAfter(), rewriter, false);
+inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
+inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);

rewriter.replaceOp(op, newWhile.getResults());

@@ -29,10 +29,10 @@ public:
LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
PatternRewriter &rewriter) const final {
Location loc = sliceOp.getLoc();
-Value input = sliceOp.input();
+Value input = sliceOp.getInput();
SmallVector<int64_t> strides;
-auto starts = sliceOp.start();
-auto sizes = sliceOp.size();
+auto starts = sliceOp.getStart();
+auto sizes = sliceOp.getSize();
strides.resize(sliceOp.getType().template cast<ShapedType>().getRank(), 1);

SmallVector<Value> dynSizes;

@@ -38,17 +38,17 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {

LogicalResult matchAndRewrite(tosa::ConcatOp op,
PatternRewriter &rewriter) const override {
-if (op.input1().size() != 1)
+if (op.getInput1().size() != 1)
return failure();
-if (op.input1().front().getType() != op.getType()) {
+if (op.getInput1().front().getType() != op.getType()) {
rewriter
.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
-op.input1().front())
+op.getInput1().front())
.getResult();
return success();
}

-rewriter.replaceOp(op, op.input1().front());
+rewriter.replaceOp(op, op.getInput1().front());
return success();
}
};

@@ -63,14 +63,14 @@ struct ReshapeReshapeOptimization : public OpRewritePattern<tosa::ReshapeOp> {

LogicalResult matchAndRewrite(tosa::ReshapeOp op,
PatternRewriter &rewriter) const override {
-Value input = op.input1();
+Value input = op.getInput1();
Operation *definingOp = input.getDefiningOp();
if (!definingOp)
return failure();

if (tosa::ReshapeOp reshapeOp = dyn_cast<tosa::ReshapeOp>(definingOp)) {
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
-op, op.getType(), reshapeOp.input1(), op.new_shape());
+op, op.getType(), reshapeOp.getInput1(), op.getNewShape());
return success();
}

@@ -83,8 +83,8 @@ struct ReshapeConstOptimization : public OpRewritePattern<tosa::ReshapeOp> {

LogicalResult matchAndRewrite(tosa::ReshapeOp op,
PatternRewriter &rewriter) const override {
-Value input = op.input1();
-ArrayAttr newShape = op.new_shape();
+Value input = op.getInput1();
+ArrayAttr newShape = op.getNewShape();

// Check if input is constant
DenseElementsAttr inputAttr;

@@ -118,12 +118,12 @@ void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
}

LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
-auto notOp = op.pred().getDefiningOp<tosa::LogicalNotOp>();
+auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>();
if (!notOp)
return failure();
rewriter.updateRootInPlace(op, [&]() {
op.getOperation()->setOperands(
-{notOp.input1(), op.on_false(), op.on_true()});
+{notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
});
return success();
}

@@ -133,7 +133,7 @@ struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {

LogicalResult matchAndRewrite(tosa::TransposeOp op,
PatternRewriter &rewriter) const override {
-auto perm = op.perms();
+auto perm = op.getPerms();

DenseIntElementsAttr permAttr;
if (!matchPattern(perm, m_Constant(&permAttr))) {

@@ -150,7 +150,7 @@ struct NoOpOptimization : public OpRewritePattern<tosa::TransposeOp> {
}
}

-rewriter.replaceOp(op, op.input1());
+rewriter.replaceOp(op, op.getInput1());
return success();
}
};

@@ -165,15 +165,15 @@ struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {

LogicalResult matchAndRewrite(tosa::AddOp op,
PatternRewriter &rewriter) const override {
-auto input1 = op.input1();
-auto input2 = op.input2();
+auto input1 = op.getInput1();
+auto input2 = op.getInput2();

DenseElementsAttr input1Attr;
if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
input2.getType() == op.getType()) {
if (input1Attr.getType().getElementType().isa<IntegerType>() &&
input1Attr.getSplatValue<APInt>().isZero()) {
-rewriter.replaceOp(op, op.input2());
+rewriter.replaceOp(op, op.getInput2());
return success();
}
}

@@ -183,7 +183,7 @@ struct AddZeroOptimization : public OpRewritePattern<tosa::AddOp> {
input1.getType() == op.getType()) {
if (input2Attr.getType().getElementType().isa<IntegerType>() &&
input2Attr.getSplatValue<APInt>().isZero()) {
-rewriter.replaceOp(op, op.input1());
+rewriter.replaceOp(op, op.getInput1());
return success();
}
}

@@ -202,21 +202,21 @@ struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {

LogicalResult matchAndRewrite(tosa::MulOp op,
PatternRewriter &rewriter) const override {
-auto input1 = op.input1();
-auto input2 = op.input2();
+auto input1 = op.getInput1();
+auto input2 = op.getInput2();

DenseElementsAttr input1Attr;
if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() &&
input2.getType() == op.getType()) {
if (input1Attr.getType().getElementType().isa<FloatType>() &&
input1Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
-rewriter.replaceOp(op, op.input2());
+rewriter.replaceOp(op, op.getInput2());
return success();
}

if (input1Attr.getType().getElementType().isa<IntegerType>() &&
matchPattern(input1, m_One())) {
-rewriter.replaceOp(op, op.input2());
+rewriter.replaceOp(op, op.getInput2());
return success();
}
}

@@ -226,13 +226,13 @@ struct MulOneOptimization : public OpRewritePattern<tosa::MulOp> {
input1.getType() == op.getType()) {
if (input2Attr.getType().getElementType().isa<FloatType>() &&
input2Attr.getSplatValue<APFloat>().isExactlyValue(1)) {
-rewriter.replaceOp(op, op.input1());
+rewriter.replaceOp(op, op.getInput1());
return success();
}

if (input2Attr.getType().getElementType().isa<IntegerType>() &&
matchPattern(input2, m_One())) {
-rewriter.replaceOp(op, op.input1());
+rewriter.replaceOp(op, op.getInput1());
return success();
}
}

@@ -251,11 +251,11 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {

LogicalResult matchAndRewrite(tosa::PadOp op,
PatternRewriter &rewriter) const override {
-if (op.pad_const())
+if (op.getPadConst())
return failure();

-auto input = op.input1();
-auto padding = op.padding();
+auto input = op.getInput1();
+auto padding = op.getPadding();

ShapedType inputTy = input.getType().cast<ShapedType>();
Type elementTy = inputTy.getElementType();

@@ -263,10 +263,10 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
Attribute constantAttr;
if (elementTy.isa<FloatType>()) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-} else if (elementTy.isa<IntegerType>() && !op.quantization_info()) {
+} else if (elementTy.isa<IntegerType>() && !op.getQuantizationInfo()) {
constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-} else if (elementTy.isa<IntegerType>() && op.quantization_info()) {
-auto value = op.quantization_info()->getInputZp();
+} else if (elementTy.isa<IntegerType>() && op.getQuantizationInfo()) {
+auto value = op.getQuantizationInfo()->getInputZp();
constantAttr = rewriter.getIntegerAttr(elementTy, value);
}

@@ -298,8 +298,8 @@ struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {

LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
PatternRewriter &rewriter) const override {
-Value input = op.input();
-Value output = op.output();
+Value input = op.getInput();
+Value output = op.getOutput();
ShapedType inputType = input.getType().cast<ShapedType>();
ShapedType outputType = output.getType().cast<ShapedType>();

@@ -333,8 +333,9 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {

LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
-Value input = op.input();
-auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
+Value input = op.getInput();
+auto inputType =
+op.getInput().getType().template dyn_cast<RankedTensorType>();
auto inputElementType = inputType.getElementType();

if (!inputType.hasStaticShape()) {

@@ -342,8 +343,8 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}

if (inputElementType.isF32()) {
-auto minClamp = op.min_fp();
-auto maxClamp = op.max_fp();
+auto minClamp = op.getMinFp();
+auto maxClamp = op.getMaxFp();
bool isMin = (minClamp.isLargest() || minClamp.isInfinity()) &&
minClamp.isNegative();
bool isMax = (maxClamp.isLargest() || maxClamp.isInfinity()) &&

@@ -357,8 +358,8 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}

if (inputElementType.isUnsignedInteger()) {
-int64_t minClamp = op.min_int();
-int64_t maxClamp = op.max_int();
+int64_t minClamp = op.getMinInt();
+int64_t maxClamp = op.getMaxInt();

int64_t intMin =
APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())

@@ -375,8 +376,8 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}

if (inputElementType.isa<IntegerType>()) {
-int64_t minClamp = op.min_int();
-int64_t maxClamp = op.max_int();
+int64_t minClamp = op.getMinInt();
+int64_t maxClamp = op.getMaxInt();

int64_t intMin =
APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())

@@ -401,21 +402,22 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {

LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
-Value input = op.input();
+Value input = op.getInput();

Operation *definingOp = input.getDefiningOp();
if (!definingOp)
return failure();

if (tosa::ClampOp clampOp = dyn_cast<tosa::ClampOp>(definingOp)) {
-auto minFp = std::max(op.min_fp(), clampOp.min_fp()).convertToFloat();
-auto maxFp = std::min(op.max_fp(), clampOp.max_fp()).convertToFloat();
+auto minFp = std::max(op.getMinFp(), clampOp.getMinFp()).convertToFloat();
+auto maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp()).convertToFloat();

-auto minInt = std::max(op.min_int(), clampOp.min_int());
-auto maxInt = std::min(op.max_int(), clampOp.max_int());
+auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
+auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
-op, op.getType(), clampOp.input(), rewriter.getI64IntegerAttr(minInt),
+op, op.getType(), clampOp.getInput(),
+rewriter.getI64IntegerAttr(minInt),
rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
rewriter.getF32FloatAttr(maxFp));
return success();

@@ -436,14 +438,14 @@ void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
-if (input().getType() == getType())
-return input();
+if (getInput().getType() == getType())
+return getInput();
return {};
}

OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
-return valueAttr();
+return getValueAttr();
}

#define REDUCE_FOLDER(OP) \

@@ -465,12 +467,12 @@ REDUCE_FOLDER(ReduceSumOp)
#undef REDUCE_FOLDER

OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
-auto inputTy = input1().getType().dyn_cast<RankedTensorType>();
+auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();

if (!inputTy || !outputTy || inputTy != outputTy)
return {};
-return input1();
+return getInput1();
}

OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {

@@ -478,7 +480,7 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
if (operands[1]) {
auto densePad = operands[1].cast<DenseElementsAttr>();
if (densePad.isSplat() && densePad.getSplatValue<APInt>().isZero()) {
-return input1();
+return getInput1();
}
}

@@ -486,20 +488,20 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
}

OpFoldResult SliceOp::fold(ArrayRef<Attribute> operands) {
-auto inputTy = input().getType().dyn_cast<RankedTensorType>();
+auto inputTy = getInput().getType().dyn_cast<RankedTensorType>();
auto outputTy = getType().dyn_cast<RankedTensorType>();

if (!inputTy || !outputTy || inputTy != outputTy)
return {};
if (inputTy.hasStaticShape())
-return input();
+return getInput();

return {};
}

OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {
-if (on_true() == on_false())
-return on_true();
+if (getOnTrue() == getOnFalse())
+return getOnTrue();

auto predicate = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!predicate)

@@ -507,18 +509,18 @@ OpFoldResult tosa::SelectOp::fold(ArrayRef<Attribute> operands) {

if (!predicate.isSplat())
return {};
-return predicate.getSplatValue<APInt>().getBoolValue() ? on_true()
-: on_false();
+return predicate.getSplatValue<APInt>().getBoolValue() ? getOnTrue()
+: getOnFalse();
}

OpFoldResult TileOp::fold(ArrayRef<Attribute> operands) {
bool allOnes = true;
-for (Attribute val : multiples().getValue()) {
+for (Attribute val : getMultiples().getValue()) {
allOnes = allOnes && val.cast<IntegerAttr>().getValue().getSExtValue() == 1;
}

-if (allOnes && input1().getType() == getType())
-return input1();
+if (allOnes && getInput1().getType() == getType())
+return getInput1();
return {};
}

@@ -537,7 +539,7 @@ OpFoldResult TransposeOp::fold(ArrayRef<Attribute> operands) {
[](const APInt &val) { return val.getSExtValue(); }));

if (llvm::equal(llvm::seq<int64_t>(0, perms.size()), perms) &&
-input1().getType() == getType())
-return input1();
+getInput1().getType() == getType())
+return getInput1();
return {};
}

@@ -67,7 +67,7 @@ struct TosaInlinerInterface : public DialectInlinerInterface {
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
-Region &tosa::WhileOp::getLoopBody() { return body(); }
+Region &tosa::WhileOp::getLoopBody() { return getBody(); }

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.

@@ -101,16 +101,18 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
template <typename T>
static LogicalResult verifyConvOp(T op) {
// All TOSA conv ops have an input() and weight().
-auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
-auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
+auto inputType =
+op.getInput().getType().template dyn_cast<RankedTensorType>();
+auto weightType =
+op.getWeight().getType().template dyn_cast<RankedTensorType>();

// Must be ranked tensor types
if (!inputType) {
-op.emitOpError("expect a ranked tensor for input, got ") << op.input();
+op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
return failure();
}
if (!weightType) {
-op.emitOpError("expect a ranked tensor for weight, got ") << op.weight();
+op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
return failure();
}

@@ -130,8 +132,8 @@ static LogicalResult verifyConvOp(T op) {

// Quantized type must have constructed the quantizationattr, and unquantized
// types should not have a quantizationattr.
-if ((inputIsQuant && !op.quantization_info()) ||
-(!inputIsQuant && op.quantization_info())) {
+if ((inputIsQuant && !op.getQuantizationInfo()) ||
+(!inputIsQuant && op.getQuantizationInfo())) {
op.emitOpError("quantizationattr is required for quantized type, and not "
"allowed for float type");
return failure();

@@ -141,7 +143,7 @@ static LogicalResult verifyConvOp(T op) {
}

LogicalResult tosa::AvgPool2dOp::verify() {
-auto inputETy = input().getType().cast<ShapedType>().getElementType();
+auto inputETy = getInput().getType().cast<ShapedType>().getElementType();
auto resultETy = getType().cast<ShapedType>().getElementType();

if (auto quantType = inputETy.dyn_cast<mlir::quant::UniformQuantizedType>())

@@ -538,7 +540,7 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
MLIRContext *context, ::llvm::Optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-ArrayAttr sizes = SliceOpAdaptor(operands, attributes).size();
+ArrayAttr sizes = SliceOpAdaptor(operands, attributes).getSize();
SmallVector<int64_t> outputShape;
outputShape.reserve(sizes.size());
for (auto val : sizes) {

@@ -570,7 +572,7 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
TileOpAdaptor adaptor(operands, attributes);
-ArrayAttr multiples = adaptor.multiples();
+ArrayAttr multiples = adaptor.getMultiples();
ShapeAdaptor inputShape = operands.getShape(0);
SmallVector<int64_t> outputShape;
if (!inputShape.hasRank()) {

@@ -606,7 +608,7 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
ReshapeOpAdaptor adaptor(operands, attributes);
ShapeAdaptor inputShape = operands.getShape(0);

-ArrayAttr newShape = adaptor.new_shape();
+ArrayAttr newShape = adaptor.getNewShape();
llvm::SmallVector<int64_t> newShapeValue;
getI64Values(newShape, newShapeValue);

@@ -741,7 +743,7 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
int32_t inHeight = ShapedType::kDynamicSize;
int32_t inWidth = ShapedType::kDynamicSize;

-ShapeAdaptor inputShape = operands.getShape(adaptor.input());
+ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
outputShape[3] = inputShape.getDimSize(3);

@@ -750,9 +752,9 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
inWidth = inputShape.getDimSize(2);
}

-int32_t shift = adaptor.shift();
+int32_t shift = adaptor.getShift();
llvm::SmallVector<int64_t> newShape;
-getI64Values(adaptor.output_size(), newShape);
+getI64Values(adaptor.getOutputSize(), newShape);
outputShape[1] = newShape[0];
outputShape[2] = newShape[1];

@@ -760,10 +762,10 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
llvm::SmallVector<int64_t> offsetInt;
llvm::SmallVector<double> strideFp;
llvm::SmallVector<double> offsetFp;
-getI64Values(adaptor.offset(), offsetInt);
-getF64Values(adaptor.offset_fp(), offsetFp);
-getI64Values(adaptor.stride(), strideInt);
-getF64Values(adaptor.stride_fp(), strideFp);
+getI64Values(adaptor.getOffset(), offsetInt);
+getF64Values(adaptor.getOffsetFp(), offsetFp);
+getI64Values(adaptor.getStride(), strideInt);
+getF64Values(adaptor.getStrideFp(), strideFp);

// If we have a 0 zero in integers we know that the resize indexing needs to
// be performed in floating point. Use the floating point varient to compute

@@ -1022,7 +1024,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(

// Input shape describes input width/height and batch.

-ShapeAdaptor inputShape = operands.getShape(adaptor.input());
+ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
inputHeight = inputShape.getDimSize(1);

@@ -1030,7 +1032,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
}

// Weight shapes describes the filter width/height and the output channels.
-ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
+ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
if (weightShape.hasRank()) {
outputShape[3] = weightShape.getDimSize(0);
weightHeight = weightShape.getDimSize(1);

@@ -1038,7 +1040,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
}

// Bias shape can describe the output channels.
-ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
+ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
if (biasShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? biasShape.getDimSize(0)

@@ -1049,9 +1051,9 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
llvm::SmallVector<int64_t> padding;
llvm::SmallVector<int64_t> stride;

-getI64Values(adaptor.dilation(), dilation);
-getI64Values(adaptor.pad(), padding);
-getI64Values(adaptor.stride(), stride);
+getI64Values(adaptor.getDilation(), dilation);
+getI64Values(adaptor.getPad(), padding);
+getI64Values(adaptor.getStride(), stride);

if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(weightHeight)) {

@@ -1091,7 +1093,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
int32_t weightDepth = ShapedType::kDynamicSize;

// Input shape describes input width/height and batch.
-ShapeAdaptor inputShape = operands.getShape(adaptor.input());
+ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
inputHeight = inputShape.getDimSize(1);

@@ -1100,7 +1102,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
}

// Weight shapes describes the filter width/height and the output channels.
-ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
+ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
if (weightShape.hasRank()) {
outputShape[4] = weightShape.getDimSize(0);
weightHeight = weightShape.getDimSize(1);

@@ -1109,7 +1111,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
}

// Bias shape can describe the output channels.
-ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
+ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
if (biasShape.hasRank()) {
outputShape[4] =
(outputShape[4] == -1) ? biasShape.getDimSize(0) : outputShape[4];

@@ -1119,9 +1121,9 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
llvm::SmallVector<int64_t> padding;
llvm::SmallVector<int64_t> stride;

-getI64Values(adaptor.dilation(), dilation);
-getI64Values(adaptor.pad(), padding);
-getI64Values(adaptor.stride(), stride);
+getI64Values(adaptor.getDilation(), dilation);
+getI64Values(adaptor.getPad(), padding);
+getI64Values(adaptor.getStride(), stride);

if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(weightHeight)) {

@@ -1183,7 +1185,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
int32_t depthChannels = ShapedType::kDynamicSize;

// Input shape describes input width/height and batch.
-ShapeAdaptor inputShape = operands.getShape(adaptor.input());
+ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
if (inputShape.hasRank()) {
outputShape[0] = inputShape.getDimSize(0);
inputHeight = inputShape.getDimSize(1);

@@ -1192,7 +1194,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
}

// Weight shapes describes the filter width/height and the output channels.
-ShapeAdaptor weightShape = operands.getShape(adaptor.weight());
+ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
if (weightShape.hasRank()) {
weightHeight = weightShape.getDimSize(0);
weightWidth = weightShape.getDimSize(1);

@@ -1210,7 +1212,7 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
}

// Bias shape can describe the output channels.
-ShapeAdaptor biasShape = operands.getShape(adaptor.bias());
+ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
if (biasShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? biasShape.getDimSize(0)

@@ -1221,9 +1223,9 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
llvm::SmallVector<int64_t> padding;
llvm::SmallVector<int64_t> stride;

-getI64Values(adaptor.dilation(), dilation);
-getI64Values(adaptor.pad(), padding);
-getI64Values(adaptor.stride(), stride);
+getI64Values(adaptor.getDilation(), dilation);
+getI64Values(adaptor.getPad(), padding);
+getI64Values(adaptor.getStride(), stride);

if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(weightHeight)) {

@@ -1253,7 +1255,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
llvm::SmallVector<int64_t> outputShape;
-getI64Values(adaptor.out_shape(), outputShape);
+getI64Values(adaptor.getOutShape(), outputShape);

int32_t inputWidth = ShapedType::kDynamicSize;
int32_t inputHeight = ShapedType::kDynamicSize;

@@ -1261,7 +1263,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
int32_t weightHeight = ShapedType::kDynamicSize;

// Input shape describes input width/height and batch.
-ShapeAdaptor inputShape = operands.getShape(adaptor.input());
+ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
if (inputShape.hasRank()) {
outputShape[0] = ShapedType::isDynamic(outputShape[0])
? inputShape.getDimSize(0)

@@ -1271,7 +1273,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
}

// Weight shapes describes the filter width/height and the output channels.
-ShapeAdaptor weightShape = operands.getShape(adaptor.filter());
+ShapeAdaptor weightShape = operands.getShape(adaptor.getFilter());
if (weightShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? weightShape.getDimSize(0)

@@ -1281,7 +1283,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
}

// Bias shape can describe the output channels.
-ShapeAdaptor biasShape = operands.getShape(adaptor.input());
+ShapeAdaptor biasShape = operands.getShape(adaptor.getInput());
if (biasShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? biasShape.getDimSize(0)

@@ -1291,8 +1293,8 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
llvm::SmallVector<int64_t> padding;
llvm::SmallVector<int64_t> stride;

-getI64Values(adaptor.out_pad(), padding);
-getI64Values(adaptor.stride(), stride);
+getI64Values(adaptor.getOutPad(), padding);
+getI64Values(adaptor.getStride(), stride);

if (!ShapedType::isDynamic(inputHeight) &&
!ShapedType::isDynamic(weightHeight)) {

@@ -25,8 +25,8 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {

LogicalResult matchAndRewrite(tosa::Conv2DOp op,
PatternRewriter &rewriter) const override {
-Value input = op.input();
-Value weight = op.weight();
+Value input = op.getInput();
+Value weight = op.getWeight();
ShapedType inputType = input.getType().cast<ShapedType>();
ShapedType weightType = weight.getType().cast<ShapedType>();
ShapedType resultType = op.getType().cast<ShapedType>();

@@ -41,7 +41,7 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
return rewriter.notifyMatchFailure(op, "unranked weight input");

// Stride must be 1 for this optimization.
-for (APInt stride : op.stride().getAsValueRange<IntegerAttr>()) {
+for (APInt stride : op.getStride().getAsValueRange<IntegerAttr>()) {
if (!stride.isOne())
return failure();
}

@@ -83,18 +83,18 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
RankedTensorType::get(fullyConnectedShape, resultType.getElementType());

Value fullyConnectedValue;
-if (op.quantization_info()) {
+if (op.getQuantizationInfo()) {
fullyConnectedValue =
rewriter
.create<tosa::FullyConnectedOp>(
op.getLoc(), fullyConnectedShapeType, reshapedInput,
-reshapedWeight, op.bias(), *op.quantization_info())
+reshapedWeight, op.getBias(), *op.getQuantizationInfo())
.getResult();
} else {
fullyConnectedValue = rewriter
.create<tosa::FullyConnectedOp>(
op.getLoc(), fullyConnectedShapeType,
-reshapedInput, reshapedWeight, op.bias())
+reshapedInput, reshapedWeight, op.getBias())
.getResult();
}

@@ -26,11 +26,11 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
 
   LogicalResult matchAndRewrite(tosa::DepthwiseConv2DOp op,
                                 PatternRewriter &rewriter) const override {
-    Value input = op.input();
-    Value weight = op.weight();
+    Value input = op.getInput();
+    Value weight = op.getWeight();
     ShapedType inputType = input.getType().cast<ShapedType>();
     ShapedType weightType = weight.getType().cast<ShapedType>();
-    ShapedType resultType = op.output().getType().cast<ShapedType>();
+    ShapedType resultType = op.getOutput().getType().cast<ShapedType>();
     Type inputEType = inputType.getElementType();
 
     if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
@@ -39,12 +39,12 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     }
 
     // Quantization information needs to still be performed.
-    if (op.quantization_info() || !inputEType.isa<FloatType>()) {
+    if (op.getQuantizationInfo() || !inputEType.isa<FloatType>()) {
       return failure();
     }
 
     // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
+    for (Attribute stride : op.getStride().getValue()) {
       if (!stride.cast<IntegerAttr>().getValue().isOne()) {
         return failure();
       }
@@ -95,7 +95,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
             .getResult();
 
     // Reshape output to [N, H, W, C * M].
-    auto outputShape = op.output().getType().cast<ShapedType>().getShape();
+    auto outputShape = op.getOutput().getType().cast<ShapedType>().getShape();
     auto outputShapeType = RankedTensorType::get(
         outputShape,
         input.getType().dyn_cast<RankedTensorType>().getElementType());
@@ -106,7 +106,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     // Add in the bias.
     rewriter
         .replaceOpWithNewOp<tosa::AddOp>(op, outputShapeType, outputValue,
-                                         op.bias())
+                                         op.getBias())
        .getResult();
     return success();
   }

@@ -98,8 +98,8 @@ public:
     llvm::SmallVector<int64_t> pad;
     llvm::SmallVector<int64_t> stride;
 
-    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
-    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
+    getValuesFromIntArrayAttribute(op.getOutPad().cast<ArrayAttr>(), pad);
+    getValuesFromIntArrayAttribute(op.getStride().cast<ArrayAttr>(), stride);
 
     // If striding is all 1 we can modify padding and reverse the kernel along
     // the x/y direction to make it a regular convolution. This is much simpler
@@ -126,11 +126,11 @@ public:
         loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
 
     Value conv2d;
-    if (op.quantization_info()) {
+    if (op.getQuantizationInfo()) {
       conv2d = rewriter.create<tosa::Conv2DOp>(
           loc, resultTy, input, reverse2, bias,
           rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
-          rewriter.getI64ArrayAttr({1, 1}), *op.quantization_info());
+          rewriter.getI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
     } else {
       conv2d = rewriter.create<tosa::Conv2DOp>(
           loc, resultTy, input, reverse2, bias,
@@ -167,8 +167,8 @@ public:
     llvm::SmallVector<int64_t> pad;
     llvm::SmallVector<int64_t> stride;
 
-    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
-    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
+    getValuesFromIntArrayAttribute(op.getOutPad().cast<ArrayAttr>(), pad);
+    getValuesFromIntArrayAttribute(op.getStride().cast<ArrayAttr>(), stride);
 
     // If striding is all 1 we can modify padding and reverse the kernel along
     // the x/y direction to make it a regular convolution. This is much simpler
@@ -200,8 +200,8 @@ public:
     Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
         rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
 
-    if (op.quantization_info().has_value()) {
-      auto quantInfo = op.quantization_info().value();
+    if (op.getQuantizationInfo().has_value()) {
+      auto quantInfo = op.getQuantizationInfo().value();
       weight = createOpAndInfer<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
           weightPaddingVal, nullptr,
@@ -264,8 +264,8 @@ public:
     Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
         rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
 
-    if (op.quantization_info().has_value()) {
-      auto quantInfo = op.quantization_info().value();
+    if (op.getQuantizationInfo().has_value()) {
+      auto quantInfo = op.getQuantizationInfo().value();
       input = createOpAndInfer<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(inputETy), input,
           inputPaddingVal, nullptr,
@@ -288,14 +288,14 @@ public:
 
     // Perform the convolution using the zero bias.
     Value conv2d;
-    if (op.quantization_info()) {
+    if (op.getQuantizationInfo()) {
       conv2d = createOpAndInfer<tosa::Conv2DOp>(
                    rewriter, loc, UnrankedTensorType::get(resultETy), input,
                    weight, zeroBias,
                    /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
                    /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
                    /*dilation=*/rewriter.getI64ArrayAttr({1, 1}),
-                   *op.quantization_info())
+                   *op.getQuantizationInfo())
                    .getResult();
     } else {
       conv2d = createOpAndInfer<tosa::Conv2DOp>(

@@ -31,21 +31,21 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
       return failure();
 
     DenseElementsAttr inputValues;
-    if (!matchPattern(op.input1(), m_Constant(&inputValues)))
+    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
       return failure();
     // Make sure the input is a constant that has a single user.
-    if (!llvm::hasSingleElement(op.input1().getDefiningOp()->getUsers()))
+    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
       return failure();
 
     DenseIntElementsAttr permAttr;
-    if (!matchPattern(op.perms(), m_Constant(&permAttr)))
+    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
       return failure();
     auto permValues = llvm::to_vector<6>(llvm::map_range(
         // TOSA allows both 32- and 64-bit integer tensors here.
         permAttr.getValues<APInt>(),
         [](const APInt &val) { return val.getZExtValue(); }));
 
-    auto inputType = op.input1().getType().cast<ShapedType>();
+    auto inputType = op.getInput1().getType().cast<ShapedType>();
     ArrayRef<int64_t> inputShape = inputType.getShape();
     int64_t numElements = inputType.getNumElements();
 
@@ -144,8 +144,8 @@ struct ConvertTosaOp : public OpRewritePattern<OpTy> {
   LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
                                 PatternRewriter &rewriter) const override {
 
-    Value input1 = tosaBinaryOp.input1();
-    Value input2 = tosaBinaryOp.input2();
+    Value input1 = tosaBinaryOp.getInput1();
+    Value input2 = tosaBinaryOp.getInput2();
     Value output = tosaBinaryOp.getResult();
 
     auto outputType = output.getType().dyn_cast<RankedTensorType>();
@@ -174,9 +174,9 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
   LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
                                 PatternRewriter &rewriter) const override {
 
-    Value input1 = tosaBinaryOp.input1();
-    Value input2 = tosaBinaryOp.input2();
-    int32_t shift = tosaBinaryOp.shift();
+    Value input1 = tosaBinaryOp.getInput1();
+    Value input2 = tosaBinaryOp.getInput2();
+    int32_t shift = tosaBinaryOp.getShift();
     Value output = tosaBinaryOp.getResult();
     auto outputType = output.getType().dyn_cast<RankedTensorType>();
     if (!outputType)
@@ -206,9 +206,9 @@ struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
   LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
                                 PatternRewriter &rewriter) const override {
 
-    Value input1 = tosaBinaryOp.input1();
-    Value input2 = tosaBinaryOp.input2();
-    int32_t round = tosaBinaryOp.round();
+    Value input1 = tosaBinaryOp.getInput1();
+    Value input2 = tosaBinaryOp.getInput2();
+    int32_t round = tosaBinaryOp.getRound();
     Value output = tosaBinaryOp.getResult();
     auto outputType = output.getType().dyn_cast<RankedTensorType>();
     if (!outputType)

@@ -42,7 +42,7 @@ ConvertTosaNegateOp::matchAndRewrite(Operation *op,
   auto tosaNegateOp = cast<tosa::NegateOp>(op);
 
   auto inputType =
-      tosaNegateOp.input1().getType().dyn_cast<mlir::RankedTensorType>();
+      tosaNegateOp.getInput1().getType().dyn_cast<mlir::RankedTensorType>();
   // skip if input is not ranked tensor type
   if (!inputType)
     return failure();
@@ -83,7 +83,7 @@ ConvertTosaNegateOp::matchAndRewrite(Operation *op,
           rewriter.getBoolAttr(narrowRange)));
 
   ElementsAttr inputElems;
-  if (!matchPattern(tosaNegateOp.input1(), m_Constant(&inputElems)))
+  if (!matchPattern(tosaNegateOp.getInput1(), m_Constant(&inputElems)))
     return failure();
 
   auto newConstOp =
@@ -112,14 +112,14 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
   auto tosaConv2DOp = cast<tosa::Conv2DOp>(op);
 
   auto inputType =
-      tosaConv2DOp.input().getType().dyn_cast<mlir::RankedTensorType>();
+      tosaConv2DOp.getInput().getType().dyn_cast<mlir::RankedTensorType>();
 
   // skip if input is not ranked tensor type
   if (!inputType)
     return failure();
 
   auto weightType =
-      tosaConv2DOp.weight().getType().dyn_cast<mlir::RankedTensorType>();
+      tosaConv2DOp.getWeight().getType().dyn_cast<mlir::RankedTensorType>();
 
   // skip if wt is not ranked tensor type
   if (!weightType)
@@ -146,9 +146,9 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
       RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32));
 
   auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>(
-      op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.input(),
-      tosaConv2DOp.weight(), tosaConv2DOp.bias(), tosaConv2DOp.pad(),
-      tosaConv2DOp.stride(), tosaConv2DOp.dilation());
+      op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(),
+      tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(), tosaConv2DOp.getPad(),
+      tosaConv2DOp.getStride(), tosaConv2DOp.getDilation());
 
   // Create rescale to quantized type
   double inputScale = inputQType.getScale();