[mlir][arith] Add `index_cast` and `index_castui` support to WIE

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D138225

parent a542d5422a
commit 92bcb8ccbb
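
For context, here is a minimal before/after sketch of the int-to-index direction, derived from the tests added below. It assumes the emulation is configured with a 32-bit maximum width, so i64 values are represented as vector<2xi32>; the function name is purely illustrative.

  // Before emulation:
  func.func @cast(%a : i64) -> index {
    %r = arith.index_cast %a : i64 to index
    return %r : index
  }

  // After emulation, only the low half of the emulated i64 feeds the cast:
  //   %low = vector.extract %arg0[0] : vector<2xi32>
  //   %res = arith.index_cast %low : i32 to index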

@@ -598,6 +598,86 @@ struct ConvertMaxMin final : OpConversionPattern<SourceOp> {
  }
};

// Convert IndexCast ops
//===----------------------------------------------------------------------===//

/// Returns true iff the type is `index` or `vector<...index>`.
static bool isIndexOrIndexVector(Type type) {
  if (type.isa<IndexType>())
    return true;

  if (auto vectorTy = type.dyn_cast<VectorType>())
    if (vectorTy.getElementType().isa<IndexType>())
      return true;

  return false;
}

template <typename CastOp>
struct ConvertIndexCastIntToIndex final : OpConversionPattern<CastOp> {
  using OpConversionPattern<CastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type resultType = op.getType();
    if (!isIndexOrIndexVector(resultType))
      return failure();

    Location loc = op.getLoc();
    Type inType = op.getIn().getType();
    auto newInTy = this->getTypeConverter()
                       ->convertType(inType)
                       .template dyn_cast_or_null<VectorType>();
    if (!newInTy)
      return rewriter.notifyMatchFailure(
          loc, llvm::formatv("unsupported type: {0}", inType));

    // Discard the high half of the input truncating the original value.
    Value extracted = extractLastDimSlice(rewriter, loc, adaptor.getIn(), 0);
    extracted = dropTrailingX1Dim(rewriter, loc, extracted);
    rewriter.replaceOpWithNewOp<CastOp>(op, resultType, extracted);
    return success();
  }
};

template <typename CastOp, typename ExtensionOp>
struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
  using OpConversionPattern<CastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type inType = op.getIn().getType();
    if (!isIndexOrIndexVector(inType))
      return failure();

    Location loc = op.getLoc();
    auto *typeConverter =
        this->template getTypeConverter<arith::WideIntEmulationConverter>();

    Type resultType = op.getType();
    auto newTy = typeConverter->convertType(resultType)
                     .template dyn_cast_or_null<VectorType>();
    if (!newTy)
      return rewriter.notifyMatchFailure(
          loc, llvm::formatv("unsupported type: {0}", resultType));

    // Emit an index cast over the matching narrow type.
    Type narrowTy =
        rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
    if (auto vecTy = resultType.dyn_cast<VectorType>())
      narrowTy = VectorType::get(vecTy.getShape(), narrowTy);

    // Sign or zero-extend the result. Let the matching conversion pattern
    // legalize the extension op.
    Value underlyingVal =
        rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn());
    rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertSelect
//===----------------------------------------------------------------------===//
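
For the opposite, index-to-int direction, ConvertIndexCastIndexToInt first casts to the narrow integer type and then emits a sign- or zero-extension op that the existing extension patterns legalize. A rough sketch of the final IR for the signed scalar case, assuming a 32-bit narrow type (this mirrors the @index_cast_index_to_int_scalar test added below):

  // %a : index, the i64 result is emulated as vector<2xi32>
  //   %cast = arith.index_cast %a : index to i32
  //   %zero = arith.constant 0 : i32
  //   %neg  = arith.cmpi slt, %cast, %zero : i32        // sign of the low half
  //   %hi   = arith.extsi %neg : i1 to i32              // replicate the sign bit
  //   %vec  = arith.constant dense<0> : vector<2xi32>
  //   %v0   = vector.insert %cast, %vec [0] : i32 into vector<2xi32>
  //   %res  = vector.insert %hi, %v0 [1] : i32 into vector<2xi32>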

@@ -841,8 +921,7 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
     // Rewrite this as an bitwise or of `arith.shrui` and sign extension bits.
     // Perform as many ops over the narrow integer type as possible and let the
     // other emulation patterns convert the rest.
-    Value elemZero =
-        createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
+    Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0);
     Value signBit = rewriter.create<arith::CmpIOp>(
         loc, arith::CmpIPredicate::slt, lhsElem1, elemZero);
     signBit = dropTrailingX1Dim(rewriter, loc, signBit);

@@ -862,7 +941,8 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
         rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits);
 
     // Use original arguments to create the right shift.
-    Value shrui = rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
+    Value shrui =
+        rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs());
     Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits);
 
     // Handle shifting by zero. This is necessary when the `signBits` shift is

@@ -870,7 +950,8 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
     Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                   rhsElem0, elemZero);
     isNoop = dropTrailingX1Dim(rewriter, loc, isNoop);
-    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(), shrsi);
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(),
+                                                 shrsi);
 
     return success();
   }

@@ -1045,6 +1126,11 @@ void arith::populateArithWideIntEmulationPatterns(
       ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
       ConvertBitwiseBinary<arith::XOrIOp>,
       // Extension and truncation ops.
-      ConvertExtSI, ConvertExtUI, ConvertTruncI>(typeConverter,
-                                                 patterns.getContext());
+      ConvertExtSI, ConvertExtUI, ConvertTruncI,
+      // Cast ops.
+      ConvertIndexCastIntToIndex<arith::IndexCastOp>,
+      ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
+      ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
+      ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>>(
+      typeConverter, patterns.getContext());
 }
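
The patterns above are exercised by the emulation test updated below. For reference, such tests are driven by an mlir-opt invocation along the lines of the following RUN line (the exact flag spelling is an assumption here, not part of this hunk):

  // RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s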

@@ -365,6 +365,102 @@ func.func @extui_vector(%a : vector<3xi16>) -> vector<3xi64> {
  return %r : vector<3xi64>
}

// CHECK-LABEL: func @index_cast_int_to_index_scalar
// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> index
// CHECK-NEXT: [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
// CHECK-NEXT: [[RES:%.+]] = arith.index_cast [[EXT]] : i32 to index
// CHECK-NEXT: return [[RES]] : index
func.func @index_cast_int_to_index_scalar(%a : i64) -> index {
  %r = arith.index_cast %a : i64 to index
  return %r : index
}

// CHECK-LABEL: func @index_cast_int_to_index_vector
// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xindex>
// CHECK-NEXT: [[EXT:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[EXT]] : vector<3x1xi32> to vector<3xi32>
// CHECK-NEXT: [[RES:%.+]] = arith.index_cast [[SHAPE]] : vector<3xi32> to vector<3xindex>
// CHECK-NEXT: return [[RES]] : vector<3xindex>
func.func @index_cast_int_to_index_vector(%a : vector<3xi64>) -> vector<3xindex> {
  %r = arith.index_cast %a : vector<3xi64> to vector<3xindex>
  return %r : vector<3xindex>
}

// CHECK-LABEL: func @index_castui_int_to_index_scalar
// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> index
// CHECK-NEXT: [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>
// CHECK-NEXT: [[RES:%.+]] = arith.index_castui [[EXT]] : i32 to index
// CHECK-NEXT: return [[RES]] : index
func.func @index_castui_int_to_index_scalar(%a : i64) -> index {
  %r = arith.index_castui %a : i64 to index
  return %r : index
}

// CHECK-LABEL: func @index_castui_int_to_index_vector
// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xindex>
// CHECK-NEXT: [[EXT:%.+]] = vector.extract_strided_slice [[ARG]] {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[EXT]] : vector<3x1xi32> to vector<3xi32>
// CHECK-NEXT: [[RES:%.+]] = arith.index_castui [[SHAPE]] : vector<3xi32> to vector<3xindex>
// CHECK-NEXT: return [[RES]] : vector<3xindex>
func.func @index_castui_int_to_index_vector(%a : vector<3xi64>) -> vector<3xindex> {
  %r = arith.index_castui %a : vector<3xi64> to vector<3xindex>
  return %r : vector<3xindex>
}

// CHECK-LABEL: func @index_cast_index_to_int_scalar
// CHECK-SAME: ([[ARG:%.+]]: index) -> vector<2xi32>
// CHECK-NEXT: [[CAST:%.+]] = arith.index_cast [[ARG]] : index to i32
// CHECK-NEXT: [[C0I32:%.+]] = arith.constant 0 : i32
// CHECK-NEXT: [[NEG:%.+]] = arith.cmpi slt, [[CAST]], [[C0I32]] : i32
// CHECK-NEXT: [[EXT:%.+]] = arith.extsi [[NEG]] : i1 to i32
// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
// CHECK-NEXT: [[INS0:%.+]] = vector.insert [[CAST]], [[VZ]] [0] : i32 into vector<2xi32>
// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[EXT]], [[INS0]] [1] : i32 into vector<2xi32>
// CHECK-NEXT: return [[INS1]] : vector<2xi32>
func.func @index_cast_index_to_int_scalar(%a : index) -> i64 {
  %r = arith.index_cast %a : index to i64
  return %r : i64
}

// CHECK-LABEL: func @index_cast_index_to_int_vector
// CHECK-SAME: ([[ARG:%.+]]: vector<3xindex>) -> vector<3x2xi32>
// CHECK-NEXT: arith.index_cast [[ARG]] : vector<3xindex> to vector<3xi32>
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: arith.constant dense<0> : vector<3x1xi32>
// CHECK-NEXT: arith.cmpi slt
// CHECK-NEXT: arith.extsi
// CHECK-NEXT: arith.constant dense<0> : vector<3x2xi32>
// CHECK-NEXT: vector.insert_strided_slice
// CHECK-NEXT: vector.insert_strided_slice
// CHECK-NEXT: return {{%.+}} : vector<3x2xi32>
func.func @index_cast_index_to_int_vector(%a : vector<3xindex>) -> vector<3xi64> {
  %r = arith.index_cast %a : vector<3xindex> to vector<3xi64>
  return %r : vector<3xi64>
}

// CHECK-LABEL: func @index_castui_index_to_int_scalar
// CHECK-SAME: ([[ARG:%.+]]: index) -> vector<2xi32>
// CHECK-NEXT: [[CAST:%.+]] = arith.index_castui [[ARG]] : index to i32
// CHECK-NEXT: [[VZ:%.+]] = arith.constant dense<0> : vector<2xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.insert [[CAST]], [[VZ]] [0] : i32 into vector<2xi32>
// CHECK-NEXT: return [[RES]] : vector<2xi32>
func.func @index_castui_index_to_int_scalar(%a : index) -> i64 {
  %r = arith.index_castui %a : index to i64
  return %r : i64
}

// CHECK-LABEL: func @index_castui_index_to_int_vector
// CHECK-SAME: ([[ARG:%.+]]: vector<3xindex>) -> vector<3x2xi32>
// CHECK-NEXT: [[CAST:%.+]] = arith.index_castui [[ARG]] : vector<3xindex> to vector<3xi32>
// CHECK-NEXT: [[SHAPE:%.+]] = vector.shape_cast [[CAST]] : vector<3xi32> to vector<3x1xi32>
// CHECK-NEXT: [[CST:%.+]] = arith.constant dense<0> : vector<3x2xi32>
// CHECK-NEXT: [[RES:%.+]] = vector.insert_strided_slice [[SHAPE]], [[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
// CHECK-NEXT: return [[RES]] : vector<3x2xi32>
func.func @index_castui_index_to_int_vector(%a : vector<3xindex>) -> vector<3xi64> {
  %r = arith.index_castui %a : vector<3xindex> to vector<3xi64>
  return %r : vector<3xi64>
}

// CHECK-LABEL: func @trunci_scalar1
// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> i32
// CHECK-NEXT: [[EXT:%.+]] = vector.extract [[ARG]][0] : vector<2xi32>