[mlir][arith] Add `index_cast` and `index_castui` support to WIE
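
Wide integer emulation (WIE) splits a wide integer such as i64 into a vector of narrower halves (e.g., vector<2xi32> with a 32-bit target). For int-to-index casts, the new pattern keeps only the low half and casts that; for index-to-int casts, it casts over the narrow type first and then sign- or zero-extends, letting the existing extension patterns legalize the result. A minimal sketch of the int-to-index direction, based on the new test cases below (function and value names are illustrative):

    // Before emulation.
    func.func @cast_to_index(%a : i64) -> index {
      %r = arith.index_cast %a : i64 to index
      return %r : index
    }

    // After emulation with a 32-bit maximum target width (i64 -> vector<2xi32>).
    func.func @cast_to_index(%a : vector<2xi32>) -> index {
      %low = vector.extract %a[0] : vector<2xi32>
      %r = arith.index_cast %low : i32 to index
      return %r : index
    }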

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D138225
Jakub Kuderski 2022-11-17 14:01:50 -05:00
parent a542d5422a
commit 92bcb8ccbb
2 changed files with 188 additions and 6 deletions


@@ -598,6 +598,86 @@ struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
  }
};

//===----------------------------------------------------------------------===//
// Convert IndexCast ops
//===----------------------------------------------------------------------===//

/// Returns true iff the type is `index` or `vector<...index>`.
static bool isIndexOrIndexVector(Type type) {
  if (type.isa<IndexType>())
    return true;

  if (auto vectorTy = type.dyn_cast<VectorType>())
    if (vectorTy.getElementType().isa<IndexType>())
      return true;

  return false;
}

template <typename CastOp>
struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
  using OpConversionPattern<CastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = op.getType();
    if (!isIndexOrIndexVector(resultType))
      return failure();

    Location loc = op.getLoc();
    Type inType = op.getIn().getType();
    auto newInTy = this->getTypeConverter()
                       ->convertType(inType)
                       .template dyn_cast_or_null<VectorType>();
    if (!newInTy)
      return rewriter.notifyMatchFailure(
          loc, llvm::formatv("unsupported type: {0}", inType));

    // Discard the high half of the input truncating the original value.
    Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
    extracted = dropTrailingX1Dim(rewriter, loc, extracted);
    rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
    return success();
  }
};

template <typename CastOp, typename ExtensionOp>
struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
  using OpConversionPattern<CastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type inType = op.getIn().getType();
    if (!isIndexOrIndexVector(inType))
      return failure();

    Location loc = op.getLoc();
    auto *typeConverter =
        this->template getTypeConverter<arith::WideIntEmulationConverter>();

    Type resultType = op.getType();
    auto newTy = typeConverter->convertType(resultType)
                     .template dyn_cast_or_null<VectorType>();
    if (!newTy)
      return rewriter.notifyMatchFailure(
          loc, llvm::formatv("unsupported type: {0}", resultType));

    // Emit an index cast over the matching narrow type.
    Type narrowTy =
        rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
    if (auto vecTy = resultType.dyn_cast<VectorType>())
      narrowTy = VectorType::get(vecTy.getShape(), narrowTy);

    // Sign or zero-extend the result. Let the matching conversion pattern
    // legalize the extension op.
    Value underlyingVal =
        rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
    rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertSelect
//===----------------------------------------------------------------------===//
@@ -841,8 +921,7 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
    // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
    // Perform as many ops over the narrow integer type as possible and let the
    // other emulation patterns convert the rest.
    Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
    Value signBit = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
    signBit = dropTrailingX1Dim(rewriter, loc, signBit);
@@ -862,7 +941,8 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
        rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);

    // Use original arguments to create the right shift.
    Value shrui =
        rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
    Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);

    // Handle shifting by zero. This is necessary when the `signBits` shift is
@@ -870,7 +950,8 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
    Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  rhsElem0, elemZero);
    isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
                                                 shrsi);

    return success();
  }
@@ -1045,6 +1126,11 @@ void arith::populateArithWideIntEmulationPatterns(
      ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
      ConvertBitwiseBinary<arith::XOrIOp>,
      // Extension and truncation ops.
      ConvertExtSI, ConvertExtUI, ConvertTruncI,
      // Cast ops.
      ConvertIndexCastIntToIndex<arith::IndexCastOp>,
      ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
      ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
      ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>>(
      typeConverter, patterns.getContext());
}


@@ -365,6 +365,102 @@ func.func @extui_vector(%a : vector<3xi16>) -> vector<3xi64> {
  return %r : vector<3xi64>
}

// CHECK-LABEL: func @index_cast_int_to_index_scalar
// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> index
// CHECK-NEXT: [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
// CHECK-NEXT: [[RES:%.+]] = arith.index_cast [[EXT]] : i32 to index
// CHECK-NEXT: return [[RES]] : index
func.func @index_cast_int_to_index_scalar(%a : i64) -> index {
  %r = arith.index_cast %a : i64 to index
  return %r : index
}

// CHECK-LABEL: func @index_cast_int_to_index_vector
// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xindex>
// CHECK-NEXT: [[EXT:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[EXT]] : vector<3x1xi32> to vector<3xi32>
// CHECK-NEXT: [[RES:%.+]] = arith.index_cast [[SHAPE]] : vector<3xi32> to vector<3xindex>
// CHECK-NEXT: return [[RES]] : vector<3xindex>
func.func @index_cast_int_to_index_vector(%a : vector<3xi64>) -> vector<3xindex> {
  %r = arith.index_cast %a : vector<3xi64> to vector<3xindex>
  return %r : vector<3xindex>
}

// CHECK-LABEL: func @index_castui_int_to_index_scalar
// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> index
// CHECK-NEXT: [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
// CHECK-NEXT: [[RES:%.+]] = arith.index_castui [[EXT]] : i32 to index
// CHECK-NEXT: return [[RES]] : index
func.func @index_castui_int_to_index_scalar(%a : i64) -> index {
  %r = arith.index_castui %a : i64 to index
  return %r : index
}

// CHECK-LABEL: func @index_castui_int_to_index_vector
// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xindex>
// CHECK-NEXT: [[EXT:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[EXT]] : vector<3x1xi32> to vector<3xi32>
// CHECK-NEXT: [[RES:%.+]] = arith.index_castui [[SHAPE]] : vector<3xi32> to vector<3xindex>
// CHECK-NEXT: return [[RES]] : vector<3xindex>
func.func @index_castui_int_to_index_vector(%a : vector<3xi64>) -> vector<3xindex> {
  %r = arith.index_castui %a : vector<3xi64> to vector<3xindex>
  return %r : vector<3xindex>
}

// CHECK-LABEL: func @index_cast_index_to_int_scalar
// CHECK-SAME: ([[ARG:%.+]]: index) -> vector<2xi32>
// CHECK-NEXT: [[CAST:%.+]] = arith.index_cast [[ARG]] : index to i32
// CHECK-NEXT: [[C0I32:%.+]] = arith.constant 0 : i32
// CHECK-NEXT: [[NEG:%.+]] = arith.cmpi slt, [[CAST]], [[C0I32]] : i32
// CHECK-NEXT: [[EXT:%.+]] = arith.extsi [[NEG]] : i1 to i32
// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[CAST]], [[VZ]] [0] : i32 into vector<2xi32>
// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[EXT]], [[INS0]] [1] : i32 into vector<2xi32>
// CHECK-NEXT: return [[INS1]] : vector<2xi32>
func.func @index_cast_index_to_int_scalar(%a : index) -> i64 {
  %r = arith.index_cast %a : index to i64
  return %r : i64
}

// CHECK-LABEL: func @index_cast_index_to_int_vector
// CHECK-SAME: ([[ARG:%.+]]: vector<3xindex>) -> vector<3x2xi32>
// CHECK-NEXT: arith.index_cast [[ARG]] : vector<3xindex> to vector<3xi32>
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: arith.constant dense<0> : vector<3x1xi32>
// CHECK-NEXT: arith.cmpi slt
// CHECK-NEXT: arith.extsi
// CHECK-NEXT: arith.constant dense<0> : vector<3x2xi32>
// CHECK-NEXT: vector.insert_strided_slice
// CHECK-NEXT: vector.insert_strided_slice
// CHECK-NEXT: return {{%.+}} : vector<3x2xi32>
func.func @index_cast_index_to_int_vector(%a : vector<3xindex>) -> vector<3xi64> {
  %r = arith.index_cast %a : vector<3xindex> to vector<3xi64>
  return %r : vector<3xi64>
}

// CHECK-LABEL: func @index_castui_index_to_int_scalar
// CHECK-SAME: ([[ARG:%.+]]: index) -> vector<2xi32>
// CHECK-NEXT: [[CAST:%.+]] = arith.index_castui [[ARG]] : index to i32
// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.insert [[CAST]], [[VZ]] [0] : i32 into vector<2xi32>
// CHECK-NEXT: return [[RES]] : vector<2xi32>
func.func @index_castui_index_to_int_scalar(%a : index) -> i64 {
  %r = arith.index_castui %a : index to i64
  return %r : i64
}

// CHECK-LABEL: func @index_castui_index_to_int_vector
// CHECK-SAME: ([[ARG:%.+]]: vector<3xindex>) -> vector<3x2xi32>
// CHECK-NEXT: [[CAST:%.+]] = arith.index_castui [[ARG]] : vector<3xindex> to vector<3xi32>
// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[CAST]] : vector<3xi32> to vector<3x1xi32>
// CHECK-NEXT: [[CST:%.+]] = arith.constant dense<0> : vector<3x2xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.insert_strided_slice [[SHAPE]], [[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
// CHECK-NEXT: return [[RES]] : vector<3x2xi32>
func.func @index_castui_index_to_int_vector(%a : vector<3xindex>) -> vector<3xi64> {
  %r = arith.index_castui %a : vector<3xindex> to vector<3xi64>
  return %r : vector<3xi64>
}

// CHECK-LABEL: func @trunci_scalar1
// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> i32
// CHECK-NEXT: [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>