[mlir] NFC: Remove Value::operator* and Value::operator-> now that Value is properly value-typed.
Summary: These were temporary methods used to simplify the transition. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D72548
This commit is contained in:
parent
1d641daf26
commit
2bdf33cc4c
|
@ -618,8 +618,7 @@ results. The third parameter to `Pattern` (and `Pat`) is for this purpose.
|
|||
For example, we can write
|
||||
|
||||
```tablegen
|
||||
def HasNoUseOf: Constraint<
|
||||
CPred<"$_self->use_begin() == $_self->use_end()">, "has no use">;
|
||||
def HasNoUseOf: Constraint<CPred<"$_self.use_empty()">, "has no use">;
|
||||
|
||||
def HasSameElementType : Constraint<
|
||||
CPred<"$0.cast<ShapedType>().getElementType() == "
|
||||
|
|
|
@ -734,7 +734,7 @@ is used. They serve as "hooks" to the enclosing environment. This includes
|
|||
we want the constraints on each type definition reads naturally and we want
|
||||
to attach type constraints directly to an operand/result, `$_self` will be
|
||||
replaced by the operand/result's type. E.g., for `F32` in `F32:$operand`, its
|
||||
`$_self` will be expanded as `getOperand(...)->getType()`.
|
||||
`$_self` will be expanded as `getOperand(...).getType()`.
|
||||
|
||||
TODO(b/130663252): Reconsider the leading symbol for special placeholders.
|
||||
Eventually we want to allow referencing operand/result $-names; such $-names
|
||||
|
|
|
@ -121,7 +121,7 @@ replacement:
|
|||
|
||||
```tablegen
|
||||
def createTFLLeakyRelu : NativeCodeCall<
|
||||
"createTFLLeakyRelu($_builder, $0->getDefiningOp(), $1, $2)">;
|
||||
"createTFLLeakyRelu($_builder, $0.getDefiningOp(), $1, $2)">;
|
||||
|
||||
def : Pat<(TF_LeakyReluOp:$old_value, $arg, F32Attr:$a),
|
||||
(createTFLLeakyRelu $old_value, $arg, $a)>;
|
||||
|
@ -131,7 +131,7 @@ def : Pat<(TF_LeakyReluOp:$old_value, $arg, F32Attr:$a),
|
|||
static Value createTFLLeakyRelu(PatternRewriter &rewriter, Operation *op,
|
||||
Value operand, Attribute attr) {
|
||||
return rewriter.create<mlir::TFL::LeakyReluOp>(
|
||||
op->getLoc(), operands[0]->getType(), /*arg=*/operands[0],
|
||||
op->getLoc(), operands[0].getType(), /*arg=*/operands[0],
|
||||
/*alpha=*/attrs[0].cast<FloatAttr>());
|
||||
}
|
||||
```
|
||||
|
@ -177,7 +177,7 @@ struct ConvertTFLeakyRelu : public RewritePattern {
|
|||
|
||||
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
|
||||
op, op->getResult(0)->getType(), op->getOperand(0),
|
||||
op, op->getResult(0).getType(), op->getOperand(0),
|
||||
/*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
|
||||
}
|
||||
};
|
||||
|
@ -191,7 +191,7 @@ struct ConvertTFLeakyRelu : public RewritePattern {
|
|||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
|
||||
op, op->getResult(0)->getType(), op->getOperand(0),
|
||||
op, op->getResult(0).getType(), op->getOperand(0),
|
||||
/*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
|
||||
return matchSuccess();
|
||||
}
|
||||
|
|
|
@ -92,7 +92,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
// Look through the input of the current transpose.
|
||||
mlir::Value transposeInput = op.getOperand();
|
||||
TransposeOp transposeInputOp =
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
|
||||
// If the input is defined by another Transpose, bingo!
|
||||
if (!transposeInputOp)
|
||||
return matchFailure();
|
||||
|
@ -194,7 +194,7 @@ An example is a transformation that eliminates reshapes when they are redundant,
|
|||
i.e. when the input and output shapes are identical.
|
||||
|
||||
```tablegen
|
||||
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
|
||||
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
def RedundantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res $arg), (replaceWithValue $arg),
|
||||
[(TypesAreIdentical $res, $arg)]>;
|
||||
|
@ -208,7 +208,7 @@ optimize Reshape of a constant value by reshaping the constant in place and
|
|||
eliminating the reshape operation.
|
||||
|
||||
```tablegen
|
||||
def ReshapeConstant : NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
|
||||
def ReshapeConstant : NativeCodeCall<"$0.reshape(($1.getType()).cast<ShapedType>())">;
|
||||
def FoldConstantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res (ConstantOp $arg)),
|
||||
(ConstantOp (ReshapeConstant $arg, $res))>;
|
||||
|
|
|
@ -82,7 +82,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
|
|||
// Replace the values directly with the return operands.
|
||||
assert(returnOp.getNumOperands() == valuesToRepl.size());
|
||||
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
|
||||
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
|
||||
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
|
||||
}
|
||||
};
|
||||
```
|
||||
|
@ -310,7 +310,7 @@ inferred as the shape of the inputs.
|
|||
```c++
|
||||
/// Infer the output shape of the MulOp, this is required by the shape inference
|
||||
/// interface.
|
||||
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||
```
|
||||
|
||||
At this point, each of the necessary Toy operations provide a mechanism by which
|
||||
|
|
|
@ -54,8 +54,7 @@ void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
static mlir::LogicalResult verify(ConstantOp op) {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType =
|
||||
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
|
@ -158,7 +157,7 @@ void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
}
|
||||
|
||||
static mlir::LogicalResult verify(TransposeOp op) {
|
||||
auto inputType = op.getOperand()->getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
|
|
@ -54,8 +54,7 @@ void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
static mlir::LogicalResult verify(ConstantOp op) {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType =
|
||||
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
|
@ -158,7 +157,7 @@ void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
}
|
||||
|
||||
static mlir::LogicalResult verify(TransposeOp op) {
|
||||
auto inputType = op.getOperand()->getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
|
|
@ -41,7 +41,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
// Look through the input of the current transpose.
|
||||
mlir::Value transposeInput = op.getOperand();
|
||||
TransposeOp transposeInputOp =
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
|
||||
|
||||
// If the input is defined by another Transpose, bingo!
|
||||
if (!transposeInputOp)
|
||||
|
|
|
@ -41,7 +41,7 @@ def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
|
|||
|
||||
// Reshape(Constant(x)) = x'
|
||||
def ReshapeConstant :
|
||||
NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
|
||||
NativeCodeCall<"$0.reshape(($1.getType()).cast<ShapedType>())">;
|
||||
def FoldConstantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res (ConstantOp $arg)),
|
||||
(ConstantOp (ReshapeConstant $arg, $res))>;
|
||||
|
@ -54,7 +54,7 @@ def FoldConstantReshapeOptPattern : Pat<
|
|||
// on operand properties.
|
||||
|
||||
// Reshape(x) = x, where input and output shapes are identical
|
||||
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
|
||||
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
def RedundantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res $arg), (replaceWithValue $arg),
|
||||
[(TypesAreIdentical $res, $arg)]>;
|
||||
|
|
|
@ -53,7 +53,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
|
|||
// Replace the values directly with the return operands.
|
||||
assert(returnOp.getNumOperands() == valuesToRepl.size());
|
||||
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
|
||||
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
|
||||
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
|
||||
}
|
||||
|
||||
/// Attempts to materialize a conversion for a type mismatch between a call
|
||||
|
@ -104,8 +104,7 @@ void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
static mlir::LogicalResult verify(ConstantOp op) {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType =
|
||||
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
|
@ -142,14 +141,14 @@ void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
|
||||
/// Infer the output shape of the AddOp, this is required by the shape inference
|
||||
/// interface.
|
||||
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||
void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CastOp
|
||||
|
||||
/// Infer the output shape of the CastOp, this is required by the shape
|
||||
/// inference interface.
|
||||
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
|
@ -183,7 +182,7 @@ void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
|
||||
/// Infer the output shape of the MulOp, this is required by the shape inference
|
||||
/// interface.
|
||||
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReturnOp
|
||||
|
@ -233,13 +232,13 @@ void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
}
|
||||
|
||||
void TransposeOp::inferShapes() {
|
||||
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
|
||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
static mlir::LogicalResult verify(TransposeOp op) {
|
||||
auto inputType = op.getOperand()->getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
|
|
@ -46,7 +46,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
// Look through the input of the current transpose.
|
||||
mlir::Value transposeInput = op.getOperand();
|
||||
TransposeOp transposeInputOp =
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
|
||||
|
||||
// If the input is defined by another Transpose, bingo!
|
||||
if (!transposeInputOp)
|
||||
|
|
|
@ -41,7 +41,7 @@ def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
|
|||
|
||||
// Reshape(Constant(x)) = x'
|
||||
def ReshapeConstant :
|
||||
NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
|
||||
NativeCodeCall<"$0.reshape(($1.getType()).cast<ShapedType>())">;
|
||||
def FoldConstantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res (ConstantOp $arg)),
|
||||
(ConstantOp (ReshapeConstant $arg, $res))>;
|
||||
|
@ -54,7 +54,7 @@ def FoldConstantReshapeOptPattern : Pat<
|
|||
// on operand properties.
|
||||
|
||||
// Reshape(x) = x, where input and output shapes are identical
|
||||
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
|
||||
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
def RedundantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res $arg), (replaceWithValue $arg),
|
||||
[(TypesAreIdentical $res, $arg)]>;
|
||||
|
|
|
@ -53,7 +53,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
|
|||
// Replace the values directly with the return operands.
|
||||
assert(returnOp.getNumOperands() == valuesToRepl.size());
|
||||
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
|
||||
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
|
||||
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
|
||||
}
|
||||
|
||||
/// Attempts to materialize a conversion for a type mismatch between a call
|
||||
|
@ -104,8 +104,7 @@ void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
static mlir::LogicalResult verify(ConstantOp op) {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType =
|
||||
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
|
@ -142,14 +141,14 @@ void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
|
||||
/// Infer the output shape of the AddOp, this is required by the shape inference
|
||||
/// interface.
|
||||
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||
void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CastOp
|
||||
|
||||
/// Infer the output shape of the CastOp, this is required by the shape
|
||||
/// inference interface.
|
||||
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
|
@ -183,7 +182,7 @@ void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
|
||||
/// Infer the output shape of the MulOp, this is required by the shape inference
|
||||
/// interface.
|
||||
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReturnOp
|
||||
|
@ -233,13 +232,13 @@ void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
}
|
||||
|
||||
void TransposeOp::inferShapes() {
|
||||
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
|
||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
static mlir::LogicalResult verify(TransposeOp op) {
|
||||
auto inputType = op.getOperand()->getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
|
|
@ -46,7 +46,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
// Look through the input of the current transpose.
|
||||
mlir::Value transposeInput = op.getOperand();
|
||||
TransposeOp transposeInputOp =
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
|
||||
|
||||
// If the input is defined by another Transpose, bingo!
|
||||
if (!transposeInputOp)
|
||||
|
|
|
@ -41,7 +41,7 @@ def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
|
|||
|
||||
// Reshape(Constant(x)) = x'
|
||||
def ReshapeConstant :
|
||||
NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
|
||||
NativeCodeCall<"$0.reshape(($1.getType()).cast<ShapedType>())">;
|
||||
def FoldConstantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res (ConstantOp $arg)),
|
||||
(ConstantOp (ReshapeConstant $arg, $res))>;
|
||||
|
@ -54,7 +54,7 @@ def FoldConstantReshapeOptPattern : Pat<
|
|||
// on operand properties.
|
||||
|
||||
// Reshape(x) = x, where input and output shapes are identical
|
||||
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
|
||||
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
def RedundantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res $arg), (replaceWithValue $arg),
|
||||
[(TypesAreIdentical $res, $arg)]>;
|
||||
|
|
|
@ -53,7 +53,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
|
|||
// Replace the values directly with the return operands.
|
||||
assert(returnOp.getNumOperands() == valuesToRepl.size());
|
||||
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
|
||||
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
|
||||
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
|
||||
}
|
||||
|
||||
/// Attempts to materialize a conversion for a type mismatch between a call
|
||||
|
@ -104,8 +104,7 @@ void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
static mlir::LogicalResult verify(ConstantOp op) {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType =
|
||||
op.getResult()->getType().dyn_cast<mlir::RankedTensorType>();
|
||||
auto resultType = op.getResult().getType().dyn_cast<mlir::RankedTensorType>();
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
|
@ -142,14 +141,14 @@ void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
|
||||
/// Infer the output shape of the AddOp, this is required by the shape inference
|
||||
/// interface.
|
||||
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||
void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CastOp
|
||||
|
||||
/// Infer the output shape of the CastOp, this is required by the shape
|
||||
/// inference interface.
|
||||
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
|
@ -183,7 +182,7 @@ void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
|
||||
/// Infer the output shape of the MulOp, this is required by the shape inference
|
||||
/// interface.
|
||||
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReturnOp
|
||||
|
@ -233,13 +232,13 @@ void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
}
|
||||
|
||||
void TransposeOp::inferShapes() {
|
||||
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
|
||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
static mlir::LogicalResult verify(TransposeOp op) {
|
||||
auto inputType = op.getOperand()->getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
|
|
@ -46,7 +46,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
// Look through the input of the current transpose.
|
||||
mlir::Value transposeInput = op.getOperand();
|
||||
TransposeOp transposeInputOp =
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
|
||||
|
||||
// If the input is defined by another Transpose, bingo!
|
||||
if (!transposeInputOp)
|
||||
|
|
|
@ -41,7 +41,7 @@ def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
|
|||
|
||||
// Reshape(Constant(x)) = x'
|
||||
def ReshapeConstant :
|
||||
NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
|
||||
NativeCodeCall<"$0.reshape(($1.getType()).cast<ShapedType>())">;
|
||||
def FoldConstantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res (ConstantOp $arg)),
|
||||
(ConstantOp (ReshapeConstant $arg, $res))>;
|
||||
|
@ -54,7 +54,7 @@ def FoldConstantReshapeOptPattern : Pat<
|
|||
// on operand properties.
|
||||
|
||||
// Reshape(x) = x, where input and output shapes are identical
|
||||
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
|
||||
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
def RedundantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res $arg), (replaceWithValue $arg),
|
||||
[(TypesAreIdentical $res, $arg)]>;
|
||||
|
|
|
@ -54,7 +54,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
|
|||
// Replace the values directly with the return operands.
|
||||
assert(returnOp.getNumOperands() == valuesToRepl.size());
|
||||
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
|
||||
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
|
||||
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
|
||||
}
|
||||
|
||||
/// Attempts to materialize a conversion for a type mismatch between a call
|
||||
|
@ -171,16 +171,16 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
|
|||
/// Verifier for the constant operation. This corresponds to the `::verify(...)`
|
||||
/// in the op definition.
|
||||
static mlir::LogicalResult verify(ConstantOp op) {
|
||||
return verifyConstantForType(op.getResult()->getType(), op.value(), op);
|
||||
return verifyConstantForType(op.getResult().getType(), op.value(), op);
|
||||
}
|
||||
|
||||
static mlir::LogicalResult verify(StructConstantOp op) {
|
||||
return verifyConstantForType(op.getResult()->getType(), op.value(), op);
|
||||
return verifyConstantForType(op.getResult().getType(), op.value(), op);
|
||||
}
|
||||
|
||||
/// Infer the output shape of the ConstantOp, this is required by the shape
|
||||
/// inference interface.
|
||||
void ConstantOp::inferShapes() { getResult()->setType(value().getType()); }
|
||||
void ConstantOp::inferShapes() { getResult().setType(value().getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AddOp
|
||||
|
@ -193,14 +193,14 @@ void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
|
||||
/// Infer the output shape of the AddOp, this is required by the shape inference
|
||||
/// interface.
|
||||
void AddOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||
void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CastOp
|
||||
|
||||
/// Infer the output shape of the CastOp, this is required by the shape
|
||||
/// inference interface.
|
||||
void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GenericCallOp
|
||||
|
@ -234,7 +234,7 @@ void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
|
||||
/// Infer the output shape of the MulOp, this is required by the shape inference
|
||||
/// interface.
|
||||
void MulOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReturnOp
|
||||
|
@ -280,7 +280,7 @@ static mlir::LogicalResult verify(ReturnOp op) {
|
|||
void StructAccessOp::build(mlir::Builder *b, mlir::OperationState &state,
|
||||
mlir::Value input, size_t index) {
|
||||
// Extract the result type from the input type.
|
||||
StructType structTy = input->getType().cast<StructType>();
|
||||
StructType structTy = input.getType().cast<StructType>();
|
||||
assert(index < structTy.getNumElementTypes());
|
||||
mlir::Type resultType = structTy.getElementTypes()[index];
|
||||
|
||||
|
@ -289,12 +289,12 @@ void StructAccessOp::build(mlir::Builder *b, mlir::OperationState &state,
|
|||
}
|
||||
|
||||
static mlir::LogicalResult verify(StructAccessOp op) {
|
||||
StructType structTy = op.input()->getType().cast<StructType>();
|
||||
StructType structTy = op.input().getType().cast<StructType>();
|
||||
size_t index = op.index().getZExtValue();
|
||||
if (index >= structTy.getNumElementTypes())
|
||||
return op.emitOpError()
|
||||
<< "index should be within the range of the input struct type";
|
||||
mlir::Type resultType = op.getResult()->getType();
|
||||
mlir::Type resultType = op.getResult().getType();
|
||||
if (resultType != structTy.getElementTypes()[index])
|
||||
return op.emitOpError() << "must have the same result type as the struct "
|
||||
"element referred to by the index";
|
||||
|
@ -311,13 +311,13 @@ void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
|||
}
|
||||
|
||||
void TransposeOp::inferShapes() {
|
||||
auto arrayTy = getOperand()->getType().cast<RankedTensorType>();
|
||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
||||
getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
}
|
||||
|
||||
static mlir::LogicalResult verify(TransposeOp op) {
|
||||
auto inputType = op.getOperand()->getType().dyn_cast<RankedTensorType>();
|
||||
auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
|
||||
auto resultType = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!inputType || !resultType)
|
||||
return mlir::success();
|
||||
|
|
|
@ -585,11 +585,11 @@ private:
|
|||
mlir::Type type = getType(varType, vardecl.loc());
|
||||
if (!type)
|
||||
return nullptr;
|
||||
if (type != value->getType()) {
|
||||
if (type != value.getType()) {
|
||||
emitError(loc(vardecl.loc()))
|
||||
<< "struct type of initializer is different than the variable "
|
||||
"declaration. Got "
|
||||
<< value->getType() << ", but expected " << type;
|
||||
<< value.getType() << ", but expected " << type;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
|||
// Look through the input of the current transpose.
|
||||
mlir::Value transposeInput = op.getOperand();
|
||||
TransposeOp transposeInputOp =
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
|
||||
|
||||
// If the input is defined by another Transpose, bingo!
|
||||
if (!transposeInputOp)
|
||||
|
|
|
@ -41,7 +41,7 @@ def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
|
|||
|
||||
// Reshape(Constant(x)) = x'
|
||||
def ReshapeConstant :
|
||||
NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
|
||||
NativeCodeCall<"$0.reshape(($1.getType()).cast<ShapedType>())">;
|
||||
def FoldConstantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res (ConstantOp $arg)),
|
||||
(ConstantOp (ReshapeConstant $arg, $res))>;
|
||||
|
@ -54,7 +54,7 @@ def FoldConstantReshapeOptPattern : Pat<
|
|||
// on operand properties.
|
||||
|
||||
// Reshape(x) = x, where input and output shapes are identical
|
||||
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
|
||||
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
|
||||
def RedundantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res $arg), (replaceWithValue $arg),
|
||||
[(TypesAreIdentical $res, $arg)]>;
|
||||
|
|
|
@ -69,7 +69,7 @@ public:
|
|||
|
||||
/// Return true if operation A dominates operation B.
|
||||
bool dominates(Value a, Operation *b) {
|
||||
return (Operation *)a->getDefiningOp() == b || properlyDominates(a, b);
|
||||
return (Operation *)a.getDefiningOp() == b || properlyDominates(a, b);
|
||||
}
|
||||
|
||||
/// Return true if the specified block A dominates block B.
|
||||
|
|
|
@ -151,7 +151,7 @@ public:
|
|||
/// Returns the source MemRefType for this DMA operation.
|
||||
Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
|
||||
MemRefType getSrcMemRefType() {
|
||||
return getSrcMemRef()->getType().cast<MemRefType>();
|
||||
return getSrcMemRef().getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Returns the rank (number of indices) of the source MemRefType.
|
||||
|
@ -172,7 +172,7 @@ public:
|
|||
|
||||
/// Returns the memory space of the src memref.
|
||||
unsigned getSrcMemorySpace() {
|
||||
return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
|
||||
return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace();
|
||||
}
|
||||
|
||||
/// Returns the operand index of the dst memref.
|
||||
|
@ -183,17 +183,17 @@ public:
|
|||
/// Returns the destination MemRefType for this DMA operations.
|
||||
Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
|
||||
MemRefType getDstMemRefType() {
|
||||
return getDstMemRef()->getType().cast<MemRefType>();
|
||||
return getDstMemRef().getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Returns the rank (number of indices) of the destination MemRefType.
|
||||
unsigned getDstMemRefRank() {
|
||||
return getDstMemRef()->getType().cast<MemRefType>().getRank();
|
||||
return getDstMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
/// Returns the memory space of the src memref.
|
||||
unsigned getDstMemorySpace() {
|
||||
return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
|
||||
return getDstMemRef().getType().cast<MemRefType>().getMemorySpace();
|
||||
}
|
||||
|
||||
/// Returns the affine map used to access the dst memref.
|
||||
|
@ -217,12 +217,12 @@ public:
|
|||
/// Returns the Tag MemRef for this DMA operation.
|
||||
Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
|
||||
MemRefType getTagMemRefType() {
|
||||
return getTagMemRef()->getType().cast<MemRefType>();
|
||||
return getTagMemRef().getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Returns the rank (number of indices) of the tag MemRefType.
|
||||
unsigned getTagMemRefRank() {
|
||||
return getTagMemRef()->getType().cast<MemRefType>().getRank();
|
||||
return getTagMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
/// Returns the affine map used to access the tag memref.
|
||||
|
@ -335,7 +335,7 @@ public:
|
|||
// Returns the Tag MemRef associated with the DMA operation being waited on.
|
||||
Value getTagMemRef() { return getOperand(0); }
|
||||
MemRefType getTagMemRefType() {
|
||||
return getTagMemRef()->getType().cast<MemRefType>();
|
||||
return getTagMemRef().getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Returns the affine map used to access the tag memref.
|
||||
|
@ -352,7 +352,7 @@ public:
|
|||
|
||||
// Returns the rank (number of indices) of the tag memref.
|
||||
unsigned getTagMemRefRank() {
|
||||
return getTagMemRef()->getType().cast<MemRefType>().getRank();
|
||||
return getTagMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
/// Returns the AffineMapAttr associated with 'memref'.
|
||||
|
@ -411,7 +411,7 @@ public:
|
|||
Value getMemRef() { return getOperand(getMemRefOperandIndex()); }
|
||||
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
|
||||
MemRefType getMemRefType() {
|
||||
return getMemRef()->getType().cast<MemRefType>();
|
||||
return getMemRef().getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Get affine map operands.
|
||||
|
@ -482,7 +482,7 @@ public:
|
|||
void setMemRef(Value value) { setOperand(getMemRefOperandIndex(), value); }
|
||||
|
||||
MemRefType getMemRefType() {
|
||||
return getMemRef()->getType().cast<MemRefType>();
|
||||
return getMemRef().getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Get affine map operands.
|
||||
|
|
|
@ -296,7 +296,7 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return memref()->getType().cast<MemRefType>();
|
||||
return memref().getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// Returns the affine map used to index the memref for this operation.
|
||||
|
|
|
@ -178,7 +178,7 @@ def LLVM_ICmpOp : LLVM_OneResultOp<"icmp", [NoSideEffect]>,
|
|||
let builders = [OpBuilder<
|
||||
"Builder *b, OperationState &result, ICmpPredicate predicate, Value lhs, "
|
||||
"Value rhs", [{
|
||||
LLVMDialect *dialect = &lhs->getType().cast<LLVMType>().getDialect();
|
||||
LLVMDialect *dialect = &lhs.getType().cast<LLVMType>().getDialect();
|
||||
build(b, result, LLVMType::getInt1Ty(dialect),
|
||||
b->getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
|
||||
}]>];
|
||||
|
@ -225,7 +225,7 @@ def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]>,
|
|||
let builders = [OpBuilder<
|
||||
"Builder *b, OperationState &result, FCmpPredicate predicate, Value lhs, "
|
||||
"Value rhs", [{
|
||||
LLVMDialect *dialect = &lhs->getType().cast<LLVMType>().getDialect();
|
||||
LLVMDialect *dialect = &lhs.getType().cast<LLVMType>().getDialect();
|
||||
build(b, result, LLVMType::getInt1Ty(dialect),
|
||||
b->getI64IntegerAttr(static_cast<int64_t>(predicate)), lhs, rhs);
|
||||
}]>];
|
||||
|
@ -285,7 +285,7 @@ def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>,
|
|||
let builders = [OpBuilder<
|
||||
"Builder *b, OperationState &result, Value addr",
|
||||
[{
|
||||
auto type = addr->getType().cast<LLVM::LLVMType>().getPointerElementTy();
|
||||
auto type = addr.getType().cast<LLVM::LLVMType>().getPointerElementTy();
|
||||
build(b, result, type, addr);
|
||||
}]>];
|
||||
let parser = [{ return parseLoadOp(parser, result); }];
|
||||
|
@ -378,7 +378,7 @@ def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>,
|
|||
"Builder *b, OperationState &result, Value container, Value value, "
|
||||
"ArrayAttr position",
|
||||
[{
|
||||
build(b, result, container->getType(), container, value, position);
|
||||
build(b, result, container.getType(), container, value, position);
|
||||
}]>];
|
||||
let parser = [{ return parseInsertValueOp(parser, result); }];
|
||||
let printer = [{ printInsertValueOp(p, *this); }];
|
||||
|
@ -392,8 +392,8 @@ def LLVM_ShuffleVectorOp
|
|||
"Builder *b, OperationState &result, Value v1, Value v2, "
|
||||
"ArrayAttr mask, ArrayRef<NamedAttribute> attrs = {}">];
|
||||
let verifier = [{
|
||||
auto wrappedVectorType1 = v1()->getType().cast<LLVM::LLVMType>();
|
||||
auto wrappedVectorType2 = v2()->getType().cast<LLVM::LLVMType>();
|
||||
auto wrappedVectorType1 = v1().getType().cast<LLVM::LLVMType>();
|
||||
auto wrappedVectorType2 = v2().getType().cast<LLVM::LLVMType>();
|
||||
if (!wrappedVectorType2.getUnderlyingType()->isVectorTy())
|
||||
return emitOpError("expected LLVM IR Dialect vector type for operand #2");
|
||||
if (wrappedVectorType1.getVectorElementType() !=
|
||||
|
@ -415,7 +415,7 @@ def LLVM_SelectOp
|
|||
let builders = [OpBuilder<
|
||||
"Builder *b, OperationState &result, Value condition, Value lhs, "
|
||||
"Value rhs", [{
|
||||
build(b, result, lhs->getType(), condition, lhs, rhs);
|
||||
build(b, result, lhs.getType(), condition, lhs, rhs);
|
||||
}]>];
|
||||
let parser = [{ return parseSelectOp(parser, result); }];
|
||||
let printer = [{ printSelectOp(p, *this); }];
|
||||
|
|
|
@ -56,7 +56,7 @@ struct StructuredIndexed {
|
|||
private:
|
||||
StructuredIndexed(Value v, ArrayRef<AffineExpr> indexings)
|
||||
: value(v), exprs(indexings.begin(), indexings.end()) {
|
||||
assert(v->getType().isa<MemRefType>() && "MemRefType expected");
|
||||
assert(v.getType().isa<MemRefType>() && "MemRefType expected");
|
||||
}
|
||||
StructuredIndexed(ValueHandle v, ArrayRef<AffineExpr> indexings)
|
||||
: StructuredIndexed(v.getValue(), indexings) {}
|
||||
|
|
|
@ -165,7 +165,7 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
|
|||
Type getElementType() { return getShapedType().getElementType(); }
|
||||
ShapedType getShapedType() { return getType().cast<ShapedType>(); }
|
||||
unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
|
||||
ShapedType getBaseViewType() { return view()->getType().cast<ShapedType>();}
|
||||
ShapedType getBaseViewType() { return view().getType().cast<ShapedType>();}
|
||||
|
||||
// Get the underlying indexing at a given rank.
|
||||
Value indexing(unsigned rank) { return *(indexings().begin() + rank); }
|
||||
|
@ -174,7 +174,7 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
|
|||
SmallVector<Value, 8> getRanges() {
|
||||
SmallVector<Value, 8> res;
|
||||
for (auto operand : indexings())
|
||||
if (!operand->getType().isa<IndexType>())
|
||||
if (!operand.getType().isa<IndexType>())
|
||||
res.push_back(operand);
|
||||
return res;
|
||||
}
|
||||
|
@ -211,7 +211,7 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getPermutationAttrName() { return "permutation"; }
|
||||
ShapedType getShapedType() { return view()->getType().cast<ShapedType>(); }
|
||||
ShapedType getShapedType() { return view().getType().cast<ShapedType>(); }
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -236,7 +236,7 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> {
|
|||
ArrayAttr indexing_maps();
|
||||
|
||||
ArrayAttr iterator_types() {
|
||||
unsigned nPar = input()->getType().cast<ShapedType>().getRank();
|
||||
unsigned nPar = input().getType().cast<ShapedType>().getRank();
|
||||
MLIRContext *ctx = getContext();
|
||||
SmallVector<Attribute, 8> iters(
|
||||
nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
|
||||
|
@ -253,7 +253,7 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
|
|||
ArrayAttr indexing_maps();
|
||||
|
||||
ArrayAttr iterator_types() {
|
||||
unsigned nPar = input()->getType().cast<ShapedType>().getRank();
|
||||
unsigned nPar = input().getType().cast<ShapedType>().getRank();
|
||||
MLIRContext *ctx = getContext();
|
||||
SmallVector<Attribute, 8> iters(
|
||||
nPar, StringAttr::get(getParallelIteratorTypeName(), ctx));
|
||||
|
|
|
@ -83,7 +83,7 @@ public:
|
|||
}
|
||||
/// Return the `i`-th input buffer type.
|
||||
ShapedType getInputShapedType(unsigned i) {
|
||||
return getInput(i)->getType().template cast<ShapedType>();
|
||||
return getInput(i).getType().template cast<ShapedType>();
|
||||
}
|
||||
/// Return the range over inputs.
|
||||
Operation::operand_range getInputs() {
|
||||
|
@ -104,7 +104,7 @@ public:
|
|||
}
|
||||
/// Return the `i`-th output buffer type.
|
||||
ShapedType getOutputShapedType(unsigned i) {
|
||||
return getOutput(i)->getType().template cast<ShapedType>();
|
||||
return getOutput(i).getType().template cast<ShapedType>();
|
||||
}
|
||||
/// Query whether the op has only MemRef input and outputs.
|
||||
bool hasBufferSemantics() {
|
||||
|
|
|
@ -37,7 +37,7 @@ class AffineMapDomainHasDim<int n> : CPred<[{
|
|||
class HasOperandsOfType<string type>: CPred<[{
|
||||
llvm::any_of(op.getOperands(),
|
||||
[](Value v) {
|
||||
return dyn_cast_or_null<}] # type # [{>(v->getDefiningOp());
|
||||
return dyn_cast_or_null<}] # type # [{>(v.getDefiningOp());
|
||||
})
|
||||
}]>;
|
||||
|
||||
|
|
|
@ -120,7 +120,7 @@ template <typename ConcreteOp>
|
|||
SmallVector<Value, 8> getViewSizes(ConcreteOp linalgOp) {
|
||||
SmallVector<Value, 8> res;
|
||||
for (auto v : linalgOp.getInputsAndOutputs()) {
|
||||
MemRefType t = v->getType().template cast<MemRefType>();
|
||||
MemRefType t = v.getType().template cast<MemRefType>();
|
||||
for (unsigned i = 0; i < t.getRank(); ++i)
|
||||
res.push_back(edsc::intrinsics::dim(v, i));
|
||||
}
|
||||
|
|
|
@ -199,7 +199,7 @@ def quant_StatisticsOp : quant_Op<"stats", [SameOperandsAndResultType]> {
|
|||
let results = (outs quant_RealValueType);
|
||||
|
||||
let verifier = [{
|
||||
auto tensorArg = arg()->getType().dyn_cast<TensorType>();
|
||||
auto tensorArg = arg().getType().dyn_cast<TensorType>();
|
||||
if (!tensorArg) return emitOpError("arg needs to be tensor type.");
|
||||
|
||||
// Verify layerStats attribute.
|
||||
|
|
|
@ -183,7 +183,7 @@ public:
|
|||
Value getSrcMemRef() { return getOperand(0); }
|
||||
// Returns the rank (number of indices) of the source MemRefType.
|
||||
unsigned getSrcMemRefRank() {
|
||||
return getSrcMemRef()->getType().cast<MemRefType>().getRank();
|
||||
return getSrcMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
// Returns the source memref indices for this DMA operation.
|
||||
operand_range getSrcIndices() {
|
||||
|
@ -195,13 +195,13 @@ public:
|
|||
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
|
||||
// Returns the rank (number of indices) of the destination MemRefType.
|
||||
unsigned getDstMemRefRank() {
|
||||
return getDstMemRef()->getType().cast<MemRefType>().getRank();
|
||||
return getDstMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
unsigned getSrcMemorySpace() {
|
||||
return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
|
||||
return getSrcMemRef().getType().cast<MemRefType>().getMemorySpace();
|
||||
}
|
||||
unsigned getDstMemorySpace() {
|
||||
return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
|
||||
return getDstMemRef().getType().cast<MemRefType>().getMemorySpace();
|
||||
}
|
||||
|
||||
// Returns the destination memref indices for this DMA operation.
|
||||
|
@ -222,7 +222,7 @@ public:
|
|||
}
|
||||
// Returns the rank (number of indices) of the tag MemRefType.
|
||||
unsigned getTagMemRefRank() {
|
||||
return getTagMemRef()->getType().cast<MemRefType>().getRank();
|
||||
return getTagMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
// Returns the tag memref index for this DMA operation.
|
||||
|
@ -313,7 +313,7 @@ public:
|
|||
|
||||
// Returns the rank (number of indices) of the tag memref.
|
||||
unsigned getTagMemRefRank() {
|
||||
return getTagMemRef()->getType().cast<MemRefType>().getRank();
|
||||
return getTagMemRef().getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
// Returns the number of elements transferred in the associated DMA operation.
|
||||
|
|
|
@ -192,7 +192,7 @@ def AllocOp : Std_Op<"alloc"> {
|
|||
let extraClassDeclaration = [{
|
||||
static StringRef getAlignmentAttrName() { return "alignment"; }
|
||||
|
||||
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
|
||||
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||
|
||||
/// Returns the number of symbolic operands (the ones in square brackets),
|
||||
/// which bind to the symbols of the memref's layout map.
|
||||
|
@ -325,7 +325,7 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
|
|||
"ValueRange operands = {}", [{
|
||||
result.operands.push_back(callee);
|
||||
result.addOperands(operands);
|
||||
result.addTypes(callee->getType().cast<FunctionType>().getResults());
|
||||
result.addTypes(callee.getType().cast<FunctionType>().getResults());
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -723,7 +723,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
|
|||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value aggregate,"
|
||||
"ValueRange indices = {}", [{
|
||||
auto resType = aggregate->getType().cast<ShapedType>()
|
||||
auto resType = aggregate.getType().cast<ShapedType>()
|
||||
.getElementType();
|
||||
build(builder, result, resType, aggregate, indices);
|
||||
}]>];
|
||||
|
@ -809,7 +809,7 @@ def LoadOp : Std_Op<"load"> {
|
|||
let builders = [OpBuilder<
|
||||
"Builder *, OperationState &result, Value memref,"
|
||||
"ValueRange indices = {}", [{
|
||||
auto memrefType = memref->getType().cast<MemRefType>();
|
||||
auto memrefType = memref.getType().cast<MemRefType>();
|
||||
result.addOperands(memref);
|
||||
result.addOperands(indices);
|
||||
result.types.push_back(memrefType.getElementType());
|
||||
|
@ -819,7 +819,7 @@ def LoadOp : Std_Op<"load"> {
|
|||
Value getMemRef() { return getOperand(0); }
|
||||
void setMemRef(Value value) { setOperand(0, value); }
|
||||
MemRefType getMemRefType() {
|
||||
return getMemRef()->getType().cast<MemRefType>();
|
||||
return getMemRef().getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
|
||||
|
@ -890,7 +890,7 @@ def MemRefCastOp : CastOp<"memref_cast"> {
|
|||
static bool areCastCompatible(Type a, Type b);
|
||||
|
||||
/// The result of a memref_cast is always a memref.
|
||||
Type getType() { return getResult()->getType(); }
|
||||
Type getType() { return getResult().getType(); }
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -958,7 +958,7 @@ def PrefetchOp : Std_Op<"prefetch"> {
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return memref()->getType().cast<MemRefType>();
|
||||
return memref().getType().cast<MemRefType>();
|
||||
}
|
||||
static StringRef getLocalityHintAttrName() { return "localityHint"; }
|
||||
static StringRef getIsWriteAttrName() { return "isWrite"; }
|
||||
|
@ -1046,7 +1046,7 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> {
|
|||
"Builder *builder, OperationState &result, Value condition,"
|
||||
"Value trueValue, Value falseValue", [{
|
||||
result.addOperands({condition, trueValue, falseValue});
|
||||
result.addTypes(trueValue->getType());
|
||||
result.addTypes(trueValue.getType());
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -1215,7 +1215,7 @@ def StoreOp : Std_Op<"store"> {
|
|||
Value getMemRef() { return getOperand(1); }
|
||||
void setMemRef(Value value) { setOperand(1, value); }
|
||||
MemRefType getMemRefType() {
|
||||
return getMemRef()->getType().cast<MemRefType>();
|
||||
return getMemRef().getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
operand_range getIndices() {
|
||||
|
@ -1367,11 +1367,11 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
|
|||
let extraClassDeclaration = [{
|
||||
/// Returns the type of the base memref operand.
|
||||
MemRefType getBaseMemRefType() {
|
||||
return source()->getType().cast<MemRefType>();
|
||||
return source().getType().cast<MemRefType>();
|
||||
}
|
||||
|
||||
/// The result of a subview is always a memref.
|
||||
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
|
||||
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||
|
||||
/// Returns as integer value the number of offset operands.
|
||||
int64_t getNumOffsets() { return llvm::size(offsets()); }
|
||||
|
@ -1434,7 +1434,7 @@ def TensorCastOp : CastOp<"tensor_cast"> {
|
|||
static bool areCastCompatible(Type a, Type b);
|
||||
|
||||
/// The result of a tensor_cast is always a tensor.
|
||||
TensorType getType() { return getResult()->getType().cast<TensorType>(); }
|
||||
TensorType getType() { return getResult().getType().cast<TensorType>(); }
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -1457,7 +1457,7 @@ def TensorLoadOp : Std_Op<"tensor_load",
|
|||
|
||||
let builders = [OpBuilder<
|
||||
"Builder *builder, OperationState &result, Value memref", [{
|
||||
auto memrefType = memref->getType().cast<MemRefType>();
|
||||
auto memrefType = memref.getType().cast<MemRefType>();
|
||||
auto resultType = RankedTensorType::get(memrefType.getShape(),
|
||||
memrefType.getElementType());
|
||||
result.addOperands(memref);
|
||||
|
@ -1467,7 +1467,7 @@ def TensorLoadOp : Std_Op<"tensor_load",
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
/// The result of a tensor_load is always a tensor.
|
||||
TensorType getType() { return getResult()->getType().cast<TensorType>(); }
|
||||
TensorType getType() { return getResult().getType().cast<TensorType>(); }
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -1565,7 +1565,7 @@ def ViewOp : Std_Op<"view", [NoSideEffect]> {
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
/// The result of a view is always a memref.
|
||||
MemRefType getType() { return getResult()->getType().cast<MemRefType>(); }
|
||||
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||
|
||||
/// Returns the dynamic offset for this view operation if specified.
|
||||
/// Returns nullptr if no dynamic offset was specified.
|
||||
|
|
|
@ -123,24 +123,24 @@ def Vector_ContractionOp :
|
|||
"Value acc, ArrayAttr indexingMaps, ArrayAttr iteratorTypes">];
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getLhsType() {
|
||||
return lhs()->getType().cast<VectorType>();
|
||||
return lhs().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getRhsType() {
|
||||
return rhs()->getType().cast<VectorType>();
|
||||
return rhs().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getAccType() {
|
||||
return acc()->getType().cast<VectorType>();
|
||||
return acc().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getLHSVectorMaskType() {
|
||||
if (llvm::size(masks()) != 2) return VectorType();
|
||||
return getOperand(3)->getType().cast<VectorType>();
|
||||
return getOperand(3).getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getRHSVectorMaskType() {
|
||||
if (llvm::size(masks()) != 2) return VectorType();
|
||||
return getOperand(4)->getType().cast<VectorType>();
|
||||
return getOperand(4).getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getResultType() {
|
||||
return getResult()->getType().cast<VectorType>();
|
||||
return getResult().getType().cast<VectorType>();
|
||||
}
|
||||
ArrayRef<StringRef> getTraitAttrNames();
|
||||
SmallVector<AffineMap, 4> getIndexingMaps();
|
||||
|
@ -198,9 +198,9 @@ def Vector_BroadcastOp :
|
|||
```
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
Type getSourceType() { return source()->getType(); }
|
||||
Type getSourceType() { return source().getType(); }
|
||||
VectorType getVectorType() {
|
||||
return vector()->getType().cast<VectorType>();
|
||||
return vector().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -248,13 +248,13 @@ def Vector_ShuffleOp :
|
|||
let extraClassDeclaration = [{
|
||||
static StringRef getMaskAttrName() { return "mask"; }
|
||||
VectorType getV1VectorType() {
|
||||
return v1()->getType().cast<VectorType>();
|
||||
return v1().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getV2VectorType() {
|
||||
return v2()->getType().cast<VectorType>();
|
||||
return v2().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getVectorType() {
|
||||
return vector()->getType().cast<VectorType>();
|
||||
return vector().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -281,7 +281,7 @@ def Vector_ExtractElementOp :
|
|||
}];
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getVectorType() {
|
||||
return vector()->getType().cast<VectorType>();
|
||||
return vector().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -309,7 +309,7 @@ def Vector_ExtractOp :
|
|||
let extraClassDeclaration = [{
|
||||
static StringRef getPositionAttrName() { return "position"; }
|
||||
VectorType getVectorType() {
|
||||
return vector()->getType().cast<VectorType>();
|
||||
return vector().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -354,10 +354,10 @@ def Vector_ExtractSlicesOp :
|
|||
"ArrayRef<int64_t> strides">];
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getSourceVectorType() {
|
||||
return vector()->getType().cast<VectorType>();
|
||||
return vector().getType().cast<VectorType>();
|
||||
}
|
||||
TupleType getResultTupleType() {
|
||||
return getResult()->getType().cast<TupleType>();
|
||||
return getResult().getType().cast<TupleType>();
|
||||
}
|
||||
void getSizes(SmallVectorImpl<int64_t> &results);
|
||||
void getStrides(SmallVectorImpl<int64_t> &results);
|
||||
|
@ -391,9 +391,9 @@ def Vector_InsertElementOp :
|
|||
```
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
Type getSourceType() { return source()->getType(); }
|
||||
Type getSourceType() { return source().getType(); }
|
||||
VectorType getDestVectorType() {
|
||||
return dest()->getType().cast<VectorType>();
|
||||
return dest().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -425,9 +425,9 @@ def Vector_InsertOp :
|
|||
"Value dest, ArrayRef<int64_t>">];
|
||||
let extraClassDeclaration = [{
|
||||
static StringRef getPositionAttrName() { return "position"; }
|
||||
Type getSourceType() { return source()->getType(); }
|
||||
Type getSourceType() { return source().getType(); }
|
||||
VectorType getDestVectorType() {
|
||||
return dest()->getType().cast<VectorType>();
|
||||
return dest().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -472,10 +472,10 @@ def Vector_InsertSlicesOp :
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
TupleType getSourceTupleType() {
|
||||
return vectors()->getType().cast<TupleType>();
|
||||
return vectors().getType().cast<TupleType>();
|
||||
}
|
||||
VectorType getResultVectorType() {
|
||||
return getResult()->getType().cast<VectorType>();
|
||||
return getResult().getType().cast<VectorType>();
|
||||
}
|
||||
void getSizes(SmallVectorImpl<int64_t> &results);
|
||||
void getStrides(SmallVectorImpl<int64_t> &results);
|
||||
|
@ -520,10 +520,10 @@ def Vector_InsertStridedSliceOp :
|
|||
static StringRef getOffsetsAttrName() { return "offsets"; }
|
||||
static StringRef getStridesAttrName() { return "strides"; }
|
||||
VectorType getSourceVectorType() {
|
||||
return source()->getType().cast<VectorType>();
|
||||
return source().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getDestVectorType() {
|
||||
return dest()->getType().cast<VectorType>();
|
||||
return dest().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -552,17 +552,17 @@ def Vector_OuterProductOp :
|
|||
}];
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getOperandVectorTypeLHS() {
|
||||
return lhs()->getType().cast<VectorType>();
|
||||
return lhs().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getOperandVectorTypeRHS() {
|
||||
return rhs()->getType().cast<VectorType>();
|
||||
return rhs().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getOperandVectorTypeACC() {
|
||||
return (llvm::size(acc()) == 0) ? VectorType() :
|
||||
(*acc().begin())->getType().cast<VectorType>();
|
||||
(*acc().begin()).getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getVectorType() {
|
||||
return getResult()->getType().cast<VectorType>();
|
||||
return getResult().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -662,10 +662,10 @@ def Vector_ReshapeOp :
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getInputVectorType() {
|
||||
return vector()->getType().cast<VectorType>();
|
||||
return vector().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getOutputVectorType() {
|
||||
return getResult()->getType().cast<VectorType>();
|
||||
return getResult().getType().cast<VectorType>();
|
||||
}
|
||||
|
||||
/// Returns as integer value the number of input shape operands.
|
||||
|
@ -723,7 +723,7 @@ def Vector_StridedSliceOp :
|
|||
static StringRef getOffsetsAttrName() { return "offsets"; }
|
||||
static StringRef getSizesAttrName() { return "sizes"; }
|
||||
static StringRef getStridesAttrName() { return "strides"; }
|
||||
VectorType getVectorType(){ return vector()->getType().cast<VectorType>(); }
|
||||
VectorType getVectorType(){ return vector().getType().cast<VectorType>(); }
|
||||
void getOffsets(SmallVectorImpl<int64_t> &results);
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
|
@ -862,10 +862,10 @@ def Vector_TransferReadOp :
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return memref()->getType().cast<MemRefType>();
|
||||
return memref().getType().cast<MemRefType>();
|
||||
}
|
||||
VectorType getVectorType() {
|
||||
return vector()->getType().cast<VectorType>();
|
||||
return vector().getType().cast<VectorType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -933,10 +933,10 @@ def Vector_TransferWriteOp :
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getVectorType() {
|
||||
return vector()->getType().cast<VectorType>();
|
||||
return vector().getType().cast<VectorType>();
|
||||
}
|
||||
MemRefType getMemRefType() {
|
||||
return memref()->getType().cast<MemRefType>();
|
||||
return memref().getType().cast<MemRefType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -976,10 +976,10 @@ def Vector_TypeCastOp :
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getMemRefType() {
|
||||
return memref()->getType().cast<MemRefType>();
|
||||
return memref().getType().cast<MemRefType>();
|
||||
}
|
||||
MemRefType getResultMemRefType() {
|
||||
return getResult()->getType().cast<MemRefType>();
|
||||
return getResult().getType().cast<MemRefType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -1078,7 +1078,7 @@ def Vector_TupleOp :
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
TupleType getResultTupleType() {
|
||||
return getResult()->getType().cast<TupleType>();
|
||||
return getResult().getType().cast<TupleType>();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
@ -1108,7 +1108,7 @@ def Vector_TupleGetOp :
|
|||
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getResultVectorType() {
|
||||
return getResult()->getType().cast<VectorType>();
|
||||
return getResult().getType().cast<VectorType>();
|
||||
}
|
||||
int64_t getIndex() {
|
||||
return getAttrOfType<IntegerAttr>("index").getValue().getSExtValue();
|
||||
|
@ -1144,7 +1144,7 @@ def Vector_PrintOp :
|
|||
let verifier = ?;
|
||||
let extraClassDeclaration = [{
|
||||
Type getPrintType() {
|
||||
return source()->getType();
|
||||
return source().getType();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
|
|
@ -16,11 +16,11 @@
|
|||
include "mlir/IR/OpBase.td"
|
||||
|
||||
class HasShape<list<int> shape> :
|
||||
CPred<"$0->getType().cast<ShapedType>().hasStaticShape({" #
|
||||
CPred<"$0.getType().cast<ShapedType>().hasStaticShape({" #
|
||||
StrJoinInt<shape>.result # "})">;
|
||||
|
||||
class UnrollVectorOp<list<int> factors> : NativeCodeCall<
|
||||
"unrollSingleResultOpMatchingType($_builder, $0->getDefiningOp(), " #
|
||||
"unrollSingleResultOpMatchingType($_builder, $0.getDefiningOp(), " #
|
||||
"{" # StrJoinInt<factors>.result # "})">;
|
||||
|
||||
#endif // VECTOR_TRANSFORM_PATTERNS
|
||||
|
|
|
@ -303,7 +303,7 @@ public:
|
|||
/// Value. An eager Value represents both the declaration and the definition
|
||||
/// (in the PL sense) of a placeholder for an mlir::Value that has already
|
||||
/// been constructed in the past and that is captured "now" in the program.
|
||||
explicit ValueHandle(Value v) : t(v->getType()), v(v) {}
|
||||
explicit ValueHandle(Value v) : t(v.getType()), v(v) {}
|
||||
|
||||
/// Builds a ConstantIndexOp of value `cst`. The constant is created at the
|
||||
/// current insertion point.
|
||||
|
@ -365,7 +365,7 @@ public:
|
|||
Operation *getOperation() const {
|
||||
if (!v)
|
||||
return nullptr;
|
||||
return v->getDefiningOp();
|
||||
return v.getDefiningOp();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
|
|
@ -36,7 +36,7 @@ struct IndexHandle : public ValueHandle {
|
|||
: ValueHandle(ScopedContext::getBuilder().getIndexType()) {}
|
||||
explicit IndexHandle(index_t v) : ValueHandle(v) {}
|
||||
explicit IndexHandle(Value v) : ValueHandle(v) {
|
||||
assert(v->getType() == ScopedContext::getBuilder().getIndexType() &&
|
||||
assert(v.getType() == ScopedContext::getBuilder().getIndexType() &&
|
||||
"Expected index type");
|
||||
}
|
||||
explicit IndexHandle(ValueHandle v) : ValueHandle(v) {
|
||||
|
|
|
@ -86,7 +86,7 @@ struct constant_int_op_binder {
|
|||
Attribute attr;
|
||||
if (!constant_op_binder<Attribute>(&attr).match(op))
|
||||
return false;
|
||||
auto type = op->getResult(0)->getType();
|
||||
auto type = op->getResult(0).getType();
|
||||
|
||||
if (type.isIntOrIndex()) {
|
||||
return attr_value_binder<IntegerAttr>(bind_value).match(attr);
|
||||
|
@ -145,7 +145,7 @@ typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
|
|||
MatcherClass, Operation *>::value,
|
||||
bool>
|
||||
matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
|
||||
if (auto defOp = op->getOperand(idx)->getDefiningOp())
|
||||
if (auto defOp = op->getOperand(idx).getDefiningOp())
|
||||
return matcher.match(defOp);
|
||||
return false;
|
||||
}
|
||||
|
@ -228,7 +228,7 @@ inline detail::constant_int_not_value_matcher<0> m_NonZero() {
|
|||
template <typename Pattern>
|
||||
inline bool matchPattern(Value value, const Pattern &pattern) {
|
||||
// TODO: handle other cases
|
||||
if (auto *op = value->getDefiningOp())
|
||||
if (auto *op = value.getDefiningOp())
|
||||
return const_cast<Pattern &>(pattern).match(op);
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -85,7 +85,7 @@ class Pred;
|
|||
// constraints on each type definition reads naturally and we want to attach
|
||||
// type constraints directly to an operand/result, $_self will be replaced
|
||||
// by the operand/result's type. E.g., for `F32` in `F32:$operand`, its
|
||||
// `$_self` will be expanded as `getOperand(...)->getType()`.
|
||||
// `$_self` will be expanded as `getOperand(...).getType()`.
|
||||
class CPred<code pred> : Pred {
|
||||
code predExpr = "(" # pred # ")";
|
||||
}
|
||||
|
@ -1600,7 +1600,7 @@ class Results<dag rets> {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HasNoUseOf: Constraint<
|
||||
CPred<"$_self->use_begin() == $_self->use_end()">, "has no use">;
|
||||
CPred<"$_self.use_empty()">, "has no use">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Common op type constraints
|
||||
|
@ -1661,7 +1661,7 @@ class AllTypesMatch<list<string> names> :
|
|||
// Type Constraint operand `idx`'s Element type is `type`.
|
||||
class TCopVTEtIs<int idx, Type type> : And<[
|
||||
CPred<"$_op.getNumOperands() > " # idx>,
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # idx # ")->getType()",
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # idx # ").getType()",
|
||||
IsShapedTypePred>,
|
||||
SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # idx # "))",
|
||||
type.predicate>]>;
|
||||
|
@ -1688,9 +1688,9 @@ class ElementTypeIs<string name, Type type> : PredOpTrait<
|
|||
// type.
|
||||
class TCopVTEtIsSameAs<int i, int j> : And<[
|
||||
CPred<"$_op.getNumOperands() > std::max(" # i # "u," # j # "u)">,
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # i # ")->getType()",
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # i # ").getType()",
|
||||
IsShapedTypePred>,
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # j # ").getType()",
|
||||
IsShapedTypePred>,
|
||||
CPred<"mlir::getElementTypeOrSelf($_op.getOperand(" # i # ")) == "
|
||||
"mlir::getElementTypeOrSelf($_op.getOperand(" # j # "))">]>;
|
||||
|
@ -1700,16 +1700,16 @@ class TCopVTEtIsSameAs<int i, int j> : And<[
|
|||
class TCOpResIsShapedTypePred<int i, int j> : And<[
|
||||
CPred<"$_op.getNumResults() > " # i>,
|
||||
CPred<"$_op.getNumOperands() > " # j>,
|
||||
SubstLeaves<"$_self", "$_op.getResult(" # i # ")->getType()",
|
||||
SubstLeaves<"$_self", "$_op.getResult(" # i # ").getType()",
|
||||
IsShapedTypePred>,
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # j # ")->getType()",
|
||||
SubstLeaves<"$_self", "$_op.getOperand(" # j # ").getType()",
|
||||
IsShapedTypePred>]>;
|
||||
|
||||
// Predicate to verify that the i'th result and the j'th operand have the same
|
||||
// type.
|
||||
class TCresIsSameAsOpBase<int i, int j> :
|
||||
CPred<"$_op.getResult(" # i # ")->getType() == "
|
||||
"$_op.getOperand(" # j # ")->getType()">;
|
||||
CPred<"$_op.getResult(" # i # ").getType() == "
|
||||
"$_op.getOperand(" # j # ").getType()">;
|
||||
|
||||
// Basic Predicate to verify that the i'th result and the j'th operand have the
|
||||
// same elemental type.
|
||||
|
@ -1730,8 +1730,8 @@ class TCresVTEtIsSameAsOp<int i, int j> : And<[
|
|||
class TCOpIsBroadcastableToRes<int opId, int resId> : And<[
|
||||
TCOpResIsShapedTypePred<opId, resId>,
|
||||
CPred<"OpTrait::util::getBroadcastedType("
|
||||
"$_op.getOperand(" # opId # ")->getType(), "
|
||||
"$_op.getResult(" # resId # ")->getType())">]>;
|
||||
"$_op.getOperand(" # opId # ").getType(), "
|
||||
"$_op.getResult(" # resId # ").getType())">]>;
|
||||
|
||||
// Predicate to verify that all the operands at the given `indices`
|
||||
// have the same element type.
|
||||
|
|
|
@ -550,7 +550,7 @@ struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
|
|||
}
|
||||
|
||||
/// Return the type of the `i`-th result.
|
||||
Type getType(unsigned i) { return getResult(i)->getType(); }
|
||||
Type getType(unsigned i) { return getResult(i).getType(); }
|
||||
|
||||
/// Result iterator access.
|
||||
result_iterator result_begin() {
|
||||
|
@ -578,13 +578,13 @@ template <typename ConcreteType>
|
|||
class OneResult : public TraitBase<ConcreteType, OneResult> {
|
||||
public:
|
||||
Value getResult() { return this->getOperation()->getResult(0); }
|
||||
Type getType() { return getResult()->getType(); }
|
||||
Type getType() { return getResult().getType(); }
|
||||
|
||||
/// Replace all uses of 'this' value with the new value, updating anything in
|
||||
/// the IR that uses 'this' to use the other value instead. When this returns
|
||||
/// there are zero uses of 'this'.
|
||||
void replaceAllUsesWith(Value newValue) {
|
||||
getResult()->replaceAllUsesWith(newValue);
|
||||
getResult().replaceAllUsesWith(newValue);
|
||||
}
|
||||
|
||||
/// Replace all uses of 'this' value with the result of 'op'.
|
||||
|
|
|
@ -114,14 +114,14 @@ public:
|
|||
os << "(";
|
||||
interleaveComma(op->getNonSuccessorOperands(), os, [&](Value operand) {
|
||||
if (operand)
|
||||
printType(operand->getType());
|
||||
printType(operand.getType());
|
||||
else
|
||||
os << "<<NULL>";
|
||||
});
|
||||
os << ") -> ";
|
||||
if (op->getNumResults() == 1 &&
|
||||
!op->getResult(0)->getType().isa<FunctionType>()) {
|
||||
printType(op->getResult(0)->getType());
|
||||
!op->getResult(0).getType().isa<FunctionType>()) {
|
||||
printType(op->getResult(0).getType());
|
||||
} else {
|
||||
os << '(';
|
||||
interleaveComma(op->getResultTypes(), os);
|
||||
|
|
|
@ -149,14 +149,14 @@ public:
|
|||
|
||||
auto valueIt = values.begin();
|
||||
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
|
||||
getResult(i)->replaceAllUsesWith(*(valueIt++));
|
||||
getResult(i).replaceAllUsesWith(*(valueIt++));
|
||||
}
|
||||
|
||||
/// Replace all uses of results of this operation with results of 'op'.
|
||||
void replaceAllUsesWith(Operation *op) {
|
||||
assert(getNumResults() == op->getNumResults());
|
||||
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
|
||||
getResult(i)->replaceAllUsesWith(op->getResult(i));
|
||||
getResult(i).replaceAllUsesWith(op->getResult(i));
|
||||
}
|
||||
|
||||
/// Destroys this operation and its subclass data.
|
||||
|
|
|
@ -90,12 +90,6 @@ public:
|
|||
return U(ownerAndKind);
|
||||
}
|
||||
|
||||
/// Temporary methods to enable transition of Value to being used as a
|
||||
/// value-type.
|
||||
/// TODO(riverriddle) Remove these when all usages have been removed.
|
||||
Value operator*() const { return *this; }
|
||||
Value *operator->() const { return const_cast<Value *>(this); }
|
||||
|
||||
operator bool() const { return ownerAndKind.getPointer(); }
|
||||
bool operator==(const Value &other) const {
|
||||
return ownerAndKind == other.ownerAndKind;
|
||||
|
@ -122,7 +116,7 @@ public:
|
|||
|
||||
/// If this value is the result of an operation, use it as a location,
|
||||
/// otherwise return an unknown location.
|
||||
Location getLoc();
|
||||
Location getLoc() const;
|
||||
|
||||
/// Return the Region in which this Value is defined.
|
||||
Region *getParentRegion();
|
||||
|
@ -236,11 +230,6 @@ class BlockArgument : public Value {
|
|||
public:
|
||||
using Value::Value;
|
||||
|
||||
/// Temporary methods to enable transition of Value to being used as a
|
||||
/// value-type.
|
||||
/// TODO(riverriddle) Remove this when all usages have been removed.
|
||||
BlockArgument *operator->() { return this; }
|
||||
|
||||
static bool classof(Value value) {
|
||||
return value.getKind() == Kind::BlockArgument;
|
||||
}
|
||||
|
@ -288,12 +277,6 @@ class OpResult : public Value {
|
|||
public:
|
||||
using Value::Value;
|
||||
|
||||
/// Temporary methods to enable transition of Value to being used as a
|
||||
/// value-type.
|
||||
/// TODO(riverriddle) Remove these when all usages have been removed.
|
||||
OpResult operator*() { return *this; }
|
||||
OpResult *operator->() { return this; }
|
||||
|
||||
static bool classof(Value value) {
|
||||
return value.getKind() != Kind::BlockArgument;
|
||||
}
|
||||
|
|
|
@ -221,7 +221,7 @@ public:
|
|||
return n->getKind() == Kind::Anchor || n->getKind() == Kind::ResultAnchor;
|
||||
}
|
||||
|
||||
Operation *getOp() const final { return resultValue->getDefiningOp(); }
|
||||
Operation *getOp() const final { return resultValue.getDefiningOp(); }
|
||||
Value getValue() const final { return resultValue; }
|
||||
|
||||
void printLabel(raw_ostream &os) const override;
|
||||
|
|
|
@ -22,7 +22,7 @@ namespace mlir {
|
|||
template <typename Range>
|
||||
bool areValuesDefinedAbove(Range values, Region &limit) {
|
||||
for (Value v : values)
|
||||
if (!v->getParentRegion()->isProperAncestor(&limit))
|
||||
if (!v.getParentRegion()->isProperAncestor(&limit))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -53,7 +53,7 @@ void mlir::getReachableAffineApplyOps(
|
|||
|
||||
while (!worklist.empty()) {
|
||||
State &state = worklist.back();
|
||||
auto *opInst = state.value->getDefiningOp();
|
||||
auto *opInst = state.value.getDefiningOp();
|
||||
// Note: getDefiningOp will return nullptr if the operand is not an
|
||||
// Operation (i.e. block argument), which is a terminator for the search.
|
||||
if (!isa_and_nonnull<AffineApplyOp>(opInst)) {
|
||||
|
@ -455,7 +455,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
|
|||
auto symbol = operands[i];
|
||||
assert(isValidSymbol(symbol));
|
||||
// Check if the symbol is a constant.
|
||||
if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(symbol->getDefiningOp()))
|
||||
if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(symbol.getDefiningOp()))
|
||||
dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
|
||||
cOp.getValue());
|
||||
}
|
||||
|
|
|
@ -585,7 +585,7 @@ static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
|
|||
unsigned d = offset;
|
||||
for (auto aDimValue : aDimValues) {
|
||||
unsigned loc;
|
||||
if (B->findId(*aDimValue, &loc)) {
|
||||
if (B->findId(aDimValue, &loc)) {
|
||||
assert(loc >= offset && "A's dim appears in B's aligned range");
|
||||
assert(loc < B->getNumDimIds() &&
|
||||
"A's dim appears in B's non-dim position");
|
||||
|
@ -608,7 +608,7 @@ static void mergeAndAlignIds(unsigned offset, FlatAffineConstraints *A,
|
|||
unsigned s = B->getNumDimIds();
|
||||
for (auto aSymValue : aSymValues) {
|
||||
unsigned loc;
|
||||
if (B->findId(*aSymValue, &loc)) {
|
||||
if (B->findId(aSymValue, &loc)) {
|
||||
assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
|
||||
"A's symbol appears in B's non-symbol position");
|
||||
swapId(B, s, loc);
|
||||
|
@ -683,7 +683,7 @@ LogicalResult FlatAffineConstraints::composeMap(const AffineValueMap *vMap) {
|
|||
// Dims and symbols.
|
||||
for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
|
||||
unsigned loc;
|
||||
bool ret = findId(*vMap->getOperand(i), &loc);
|
||||
bool ret = findId(vMap->getOperand(i), &loc);
|
||||
assert(ret && "value map's id can't be found");
|
||||
(void)ret;
|
||||
// Negate 'eq[r]' since the newly added dimension will be set to this one.
|
||||
|
@ -804,12 +804,12 @@ void FlatAffineConstraints::convertLoopIVSymbolsToDims() {
|
|||
}
|
||||
// Turn each symbol in 'loopIVs' into a dim identifier.
|
||||
for (auto iv : loopIVs) {
|
||||
turnSymbolIntoDim(this, *iv);
|
||||
turnSymbolIntoDim(this, iv);
|
||||
}
|
||||
}
|
||||
|
||||
void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
|
||||
if (containsId(*id))
|
||||
if (containsId(id))
|
||||
return;
|
||||
|
||||
// Caller is expected to fully compose map/operands if necessary.
|
||||
|
@ -826,14 +826,14 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
|
|||
// Add top level symbol.
|
||||
addSymbolId(getNumSymbolIds(), id);
|
||||
// Check if the symbol is a constant.
|
||||
if (auto constOp = dyn_cast_or_null<ConstantIndexOp>(id->getDefiningOp()))
|
||||
setIdToConstant(*id, constOp.getValue());
|
||||
if (auto constOp = dyn_cast_or_null<ConstantIndexOp>(id.getDefiningOp()))
|
||||
setIdToConstant(id, constOp.getValue());
|
||||
}
|
||||
|
||||
LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
|
||||
unsigned pos;
|
||||
// Pre-condition for this method.
|
||||
if (!findId(*forOp.getInductionVar(), &pos)) {
|
||||
if (!findId(forOp.getInductionVar(), &pos)) {
|
||||
assert(false && "Value not found");
|
||||
return failure();
|
||||
}
|
||||
|
@ -1780,13 +1780,13 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
|
|||
localVarCst.setIdValues(0, localVarCst.getNumDimAndSymbolIds(), operands);
|
||||
for (auto operand : operands) {
|
||||
unsigned pos;
|
||||
if (findId(*operand, &pos)) {
|
||||
if (findId(operand, &pos)) {
|
||||
if (pos >= getNumDimIds() && pos < getNumDimAndSymbolIds()) {
|
||||
// If the local var cst has this as a dim, turn it into its symbol.
|
||||
turnDimIntoSymbol(&localVarCst, *operand);
|
||||
turnDimIntoSymbol(&localVarCst, operand);
|
||||
} else if (pos < getNumDimIds()) {
|
||||
// Or vice versa.
|
||||
turnSymbolIntoDim(&localVarCst, *operand);
|
||||
turnSymbolIntoDim(&localVarCst, operand);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1800,7 +1800,7 @@ FlatAffineConstraints::addLowerOrUpperBound(unsigned pos, AffineMap boundMap,
|
|||
unsigned numOperands = operands.size();
|
||||
for (auto operand : operands) {
|
||||
unsigned pos;
|
||||
if (!findId(*operand, &pos))
|
||||
if (!findId(operand, &pos))
|
||||
assert(0 && "expected to be found");
|
||||
positions.push_back(pos);
|
||||
}
|
||||
|
@ -1847,7 +1847,7 @@ LogicalResult FlatAffineConstraints::addSliceBounds(ArrayRef<Value> values,
|
|||
|
||||
for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
|
||||
unsigned pos;
|
||||
if (!findId(*values[i], &pos))
|
||||
if (!findId(values[i], &pos))
|
||||
continue;
|
||||
|
||||
AffineMap lbMap = lbMaps[i];
|
||||
|
@ -2703,7 +2703,7 @@ void FlatAffineConstraints::projectOut(unsigned pos, unsigned num) {
|
|||
|
||||
void FlatAffineConstraints::projectOut(Value id) {
|
||||
unsigned pos;
|
||||
bool ret = findId(*id, &pos);
|
||||
bool ret = findId(id, &pos);
|
||||
assert(ret);
|
||||
(void)ret;
|
||||
FourierMotzkinEliminate(pos);
|
||||
|
|
|
@ -179,7 +179,7 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
|
|||
callee = SymbolTable::lookupNearestSymbolFrom(from,
|
||||
symbolRef.getRootReference());
|
||||
else
|
||||
callee = callable.get<Value>()->getDefiningOp();
|
||||
callee = callable.get<Value>().getDefiningOp();
|
||||
|
||||
// If the callee is non-null and is a valid callable object, try to get the
|
||||
// called region from it.
|
||||
|
|
|
@ -119,7 +119,7 @@ bool DominanceInfo::properlyDominates(Operation *a, Operation *b) {
|
|||
|
||||
/// Return true if value A properly dominates operation B.
|
||||
bool DominanceInfo::properlyDominates(Value a, Operation *b) {
|
||||
if (auto *aOp = a->getDefiningOp()) {
|
||||
if (auto *aOp = a.getDefiningOp()) {
|
||||
// The values defined by an operation do *not* dominate any nested
|
||||
// operations.
|
||||
if (aOp->getParentRegion() != b->getParentRegion() && aOp->isAncestor(b))
|
||||
|
@ -129,7 +129,7 @@ bool DominanceInfo::properlyDominates(Value a, Operation *b) {
|
|||
|
||||
// block arguments properly dominate all operations in their own block, so
|
||||
// we use a dominates check here, not a properlyDominates check.
|
||||
return dominates(a.cast<BlockArgument>()->getOwner(), b->getBlock());
|
||||
return dominates(a.cast<BlockArgument>().getOwner(), b->getBlock());
|
||||
}
|
||||
|
||||
DominanceInfoNode *DominanceInfo::getNode(Block *a) {
|
||||
|
|
|
@ -45,7 +45,7 @@ struct BlockInfoBuilder {
|
|||
// properties of the program, the uses must occur after
|
||||
// the definition. Therefore, we do not have to check
|
||||
// additional conditions to detect an escaping value.
|
||||
for (OpOperand &use : result->getUses())
|
||||
for (OpOperand &use : result.getUses())
|
||||
if (use.getOwner()->getBlock() != block) {
|
||||
outValues.insert(result);
|
||||
break;
|
||||
|
@ -171,15 +171,15 @@ Liveness::OperationListT Liveness::resolveLiveness(Value value) const {
|
|||
|
||||
// Start with the defining block
|
||||
Block *currentBlock;
|
||||
if (Operation *defOp = value->getDefiningOp())
|
||||
if (Operation *defOp = value.getDefiningOp())
|
||||
currentBlock = defOp->getBlock();
|
||||
else
|
||||
currentBlock = value.cast<BlockArgument>()->getOwner();
|
||||
currentBlock = value.cast<BlockArgument>().getOwner();
|
||||
toProcess.push_back(currentBlock);
|
||||
visited.insert(currentBlock);
|
||||
|
||||
// Start with all associated blocks
|
||||
for (OpOperand &use : value->getUses()) {
|
||||
for (OpOperand &use : value.getUses()) {
|
||||
Block *useBlock = use.getOwner()->getBlock();
|
||||
if (visited.insert(useBlock).second)
|
||||
toProcess.push_back(useBlock);
|
||||
|
@ -269,12 +269,12 @@ void Liveness::print(raw_ostream &os) const {
|
|||
|
||||
// Local printing helpers
|
||||
auto printValueRef = [&](Value value) {
|
||||
if (Operation *defOp = value->getDefiningOp())
|
||||
if (Operation *defOp = value.getDefiningOp())
|
||||
os << "val_" << defOp->getName();
|
||||
else {
|
||||
auto blockArg = value.cast<BlockArgument>();
|
||||
os << "arg" << blockArg->getArgNumber() << "@"
|
||||
<< blockIds[blockArg->getOwner()];
|
||||
os << "arg" << blockArg.getArgNumber() << "@"
|
||||
<< blockIds[blockArg.getOwner()];
|
||||
}
|
||||
os << " ";
|
||||
};
|
||||
|
@ -343,7 +343,7 @@ bool LivenessBlockInfo::isLiveOut(Value value) const {
|
|||
/// Gets the start operation for the given value
|
||||
/// (must be referenced in this block).
|
||||
Operation *LivenessBlockInfo::getStartOperation(Value value) const {
|
||||
Operation *definingOp = value->getDefiningOp();
|
||||
Operation *definingOp = value.getDefiningOp();
|
||||
// The given value is either live-in or is defined
|
||||
// in the scope of this block.
|
||||
if (isLiveIn(value) || !definingOp)
|
||||
|
@ -361,7 +361,7 @@ Operation *LivenessBlockInfo::getEndOperation(Value value,
|
|||
|
||||
// Resolve the last operation (must exist by definition).
|
||||
Operation *endOperation = startOperation;
|
||||
for (OpOperand &use : value->getUses()) {
|
||||
for (OpOperand &use : value.getUses()) {
|
||||
Operation *useOperation = use.getOwner();
|
||||
// Check whether the use is in our block and after
|
||||
// the current end operation.
|
||||
|
|
|
@ -166,7 +166,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(AffineForOp forOp) {
|
|||
/// conservative.
|
||||
static bool isAccessIndexInvariant(Value iv, Value index) {
|
||||
assert(isForInductionVar(iv) && "iv must be a AffineForOp");
|
||||
assert(index->getType().isa<IndexType>() && "index must be of IndexType");
|
||||
assert(index.getType().isa<IndexType>() && "index must be of IndexType");
|
||||
SmallVector<Operation *, 4> affineApplyOps;
|
||||
getReachableAffineApplyOps({index}, affineApplyOps);
|
||||
|
||||
|
@ -373,7 +373,7 @@ bool mlir::isInstwiseShiftValid(AffineForOp forOp, ArrayRef<uint64_t> shifts) {
|
|||
// Validate the results of this operation if it were to be shifted.
|
||||
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||
Value result = op.getResult(i);
|
||||
for (auto *user : result->getUsers()) {
|
||||
for (auto *user : result.getUsers()) {
|
||||
// If an ancestor operation doesn't lie in the block of forOp,
|
||||
// there is no shift to check.
|
||||
if (auto *ancOp = forBody->findAncestorOpInBlock(*user)) {
|
||||
|
|
|
@ -43,18 +43,18 @@ static void getForwardSliceImpl(Operation *op,
|
|||
}
|
||||
|
||||
if (auto forOp = dyn_cast<AffineForOp>(op)) {
|
||||
for (auto *ownerInst : forOp.getInductionVar()->getUsers())
|
||||
for (auto *ownerInst : forOp.getInductionVar().getUsers())
|
||||
if (forwardSlice->count(ownerInst) == 0)
|
||||
getForwardSliceImpl(ownerInst, forwardSlice, filter);
|
||||
} else if (auto forOp = dyn_cast<loop::ForOp>(op)) {
|
||||
for (auto *ownerInst : forOp.getInductionVar()->getUsers())
|
||||
for (auto *ownerInst : forOp.getInductionVar().getUsers())
|
||||
if (forwardSlice->count(ownerInst) == 0)
|
||||
getForwardSliceImpl(ownerInst, forwardSlice, filter);
|
||||
} else {
|
||||
assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
|
||||
assert(op->getNumResults() <= 1 && "unexpected multiple results");
|
||||
if (op->getNumResults() > 0) {
|
||||
for (auto *ownerInst : op->getResult(0)->getUsers())
|
||||
for (auto *ownerInst : op->getResult(0).getUsers())
|
||||
if (forwardSlice->count(ownerInst) == 0)
|
||||
getForwardSliceImpl(ownerInst, forwardSlice, filter);
|
||||
}
|
||||
|
@ -105,14 +105,14 @@ static void getBackwardSliceImpl(Operation *op,
|
|||
auto *loopOp = loopIv.getOperation();
|
||||
if (backwardSlice->count(loopOp) == 0)
|
||||
getBackwardSliceImpl(loopOp, backwardSlice, filter);
|
||||
} else if (blockArg->getOwner() !=
|
||||
} else if (blockArg.getOwner() !=
|
||||
&op->getParentOfType<FuncOp>().getBody().front()) {
|
||||
op->emitError("unsupported CF for operand ") << en.index();
|
||||
llvm_unreachable("Unsupported control flow");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto *op = operand->getDefiningOp();
|
||||
auto *op = operand.getDefiningOp();
|
||||
if (backwardSlice->count(op) == 0) {
|
||||
getBackwardSliceImpl(op, backwardSlice, filter);
|
||||
}
|
||||
|
@ -173,7 +173,7 @@ struct DFSState {
|
|||
static void DFSPostorder(Operation *current, DFSState *state) {
|
||||
assert(current->getNumResults() <= 1 && "NYI: multi-result");
|
||||
if (current->getNumResults() > 0) {
|
||||
for (auto &u : current->getResult(0)->getUses()) {
|
||||
for (auto &u : current->getResult(0).getUses()) {
|
||||
auto *op = u.getOwner();
|
||||
DFSPostorder(op, state);
|
||||
}
|
||||
|
|
|
@ -59,12 +59,12 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
|
|||
// Add loop bound constraints for values which are loop IVs and equality
|
||||
// constraints for symbols which are constants.
|
||||
for (const auto &value : values) {
|
||||
assert(cst->containsId(*value) && "value expected to be present");
|
||||
assert(cst->containsId(value) && "value expected to be present");
|
||||
if (isValidSymbol(value)) {
|
||||
// Check if the symbol is a constant.
|
||||
|
||||
if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(value->getDefiningOp()))
|
||||
cst->setIdToConstant(*value, cOp.getValue());
|
||||
if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(value.getDefiningOp()))
|
||||
cst->setIdToConstant(value, cOp.getValue());
|
||||
} else if (auto loop = getForInductionVarOwner(value)) {
|
||||
if (failed(cst->addAffineForOpDomain(loop)))
|
||||
return failure();
|
||||
|
@ -88,13 +88,13 @@ void ComputationSliceState::clearBounds() {
|
|||
}
|
||||
|
||||
unsigned MemRefRegion::getRank() const {
|
||||
return memref->getType().cast<MemRefType>().getRank();
|
||||
return memref.getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
Optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
|
||||
SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
|
||||
SmallVectorImpl<int64_t> *lbDivisors) const {
|
||||
auto memRefType = memref->getType().cast<MemRefType>();
|
||||
auto memRefType = memref.getType().cast<MemRefType>();
|
||||
unsigned rank = memRefType.getRank();
|
||||
if (shape)
|
||||
shape->reserve(rank);
|
||||
|
@ -228,9 +228,9 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
|
|||
auto symbol = operand;
|
||||
assert(isValidSymbol(symbol));
|
||||
// Check if the symbol is a constant.
|
||||
if (auto *op = symbol->getDefiningOp()) {
|
||||
if (auto *op = symbol.getDefiningOp()) {
|
||||
if (auto constOp = dyn_cast<ConstantIndexOp>(op)) {
|
||||
cst.setIdToConstant(*symbol, constOp.getValue());
|
||||
cst.setIdToConstant(symbol, constOp.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -293,7 +293,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
|
|||
// to guard against potential over-approximation from projection.
|
||||
// TODO(andydavis) Support dynamic memref dimensions.
|
||||
if (addMemRefDimBounds) {
|
||||
auto memRefType = memref->getType().cast<MemRefType>();
|
||||
auto memRefType = memref.getType().cast<MemRefType>();
|
||||
for (unsigned r = 0; r < rank; r++) {
|
||||
cst.addConstantLowerBound(r, 0);
|
||||
int64_t dimSize = memRefType.getDimSize(r);
|
||||
|
@ -325,7 +325,7 @@ static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
|
|||
|
||||
// Returns the size of the region.
|
||||
Optional<int64_t> MemRefRegion::getRegionSize() {
|
||||
auto memRefType = memref->getType().cast<MemRefType>();
|
||||
auto memRefType = memref.getType().cast<MemRefType>();
|
||||
|
||||
auto layoutMaps = memRefType.getAffineMaps();
|
||||
if (layoutMaps.size() > 1 ||
|
||||
|
@ -854,7 +854,7 @@ MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
|
|||
}
|
||||
|
||||
unsigned MemRefAccess::getRank() const {
|
||||
return memref->getType().cast<MemRefType>().getRank();
|
||||
return memref.getType().cast<MemRefType>().getRank();
|
||||
}
|
||||
|
||||
bool MemRefAccess::isStore() const { return isa<AffineStoreOp>(opInst); }
|
||||
|
|
|
@ -198,7 +198,7 @@ bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op,
|
|||
}
|
||||
return false;
|
||||
} else if (op.getNumResults() == 1) {
|
||||
if (auto v = op.getResult(0)->getType().dyn_cast<VectorType>()) {
|
||||
if (auto v = op.getResult(0).getType().dyn_cast<VectorType>()) {
|
||||
superVectorType = v;
|
||||
} else {
|
||||
// Not a vector type.
|
||||
|
|
|
@ -130,7 +130,7 @@ LogicalResult OperationVerifier::verifyRegion(Region ®ion) {
|
|||
|
||||
LogicalResult OperationVerifier::verifyBlock(Block &block) {
|
||||
for (auto arg : block.getArguments())
|
||||
if (arg->getOwner() != &block)
|
||||
if (arg.getOwner() != &block)
|
||||
return emitError(block, "block argument not owned by block");
|
||||
|
||||
// Verify that this block has a terminator.
|
||||
|
@ -241,7 +241,7 @@ LogicalResult OperationVerifier::verifyDominance(Operation &op) {
|
|||
|
||||
auto diag = op.emitError("operand #")
|
||||
<< operandNo << " does not dominate this use";
|
||||
if (auto *useOp = operand->getDefiningOp())
|
||||
if (auto *useOp = operand.getDefiningOp())
|
||||
diag.attachNote(useOp->getLoc()) << "operand defined here";
|
||||
return failure();
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ public:
|
|||
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
|
||||
"expected single result op");
|
||||
|
||||
LLVMType resultType = lowering.convertType(op->getResult(0)->getType())
|
||||
LLVMType resultType = lowering.convertType(op->getResult(0).getType())
|
||||
.template cast<LLVM::LLVMType>();
|
||||
LLVMType funcType = getFunctionType(resultType, operands);
|
||||
StringRef funcName = getFunctionName(resultType);
|
||||
|
@ -64,7 +64,7 @@ private:
|
|||
using LLVM::LLVMType;
|
||||
SmallVector<LLVMType, 1> operandTypes;
|
||||
for (Value operand : operands) {
|
||||
operandTypes.push_back(operand->getType().cast<LLVMType>());
|
||||
operandTypes.push_back(operand.getType().cast<LLVMType>());
|
||||
}
|
||||
return LLVMType::getFunctionTy(resultType, operandTypes,
|
||||
/*isVarArg=*/false);
|
||||
|
|
|
@ -253,7 +253,7 @@ Value GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
|
|||
arraySize, /*alignment=*/0);
|
||||
for (unsigned idx = 0; idx < numKernelOperands; ++idx) {
|
||||
auto operand = launchOp.getKernelOperand(idx);
|
||||
auto llvmType = operand->getType().cast<LLVM::LLVMType>();
|
||||
auto llvmType = operand.getType().cast<LLVM::LLVMType>();
|
||||
Value memLocation = builder.create<LLVM::AllocaOp>(
|
||||
loc, llvmType.getPointerTo(), one, /*alignment=*/1);
|
||||
builder.create<LLVM::StoreOp>(loc, operand, memLocation);
|
||||
|
|
|
@ -66,7 +66,7 @@ struct GPUAllReduceOpLowering : public LLVMOpLowering {
|
|||
Value operand = operands.front();
|
||||
|
||||
// TODO(csigg): Generalize to other types of accumulation.
|
||||
assert(op->getOperand(0)->getType().isIntOrFloat());
|
||||
assert(op->getOperand(0).getType().isIntOrFloat());
|
||||
|
||||
// Create the reduction using an accumulator factory.
|
||||
AccumulatorFactory factory =
|
||||
|
@ -87,7 +87,7 @@ private:
|
|||
return getFactory(allReduce.body());
|
||||
}
|
||||
if (allReduce.op()) {
|
||||
auto type = operand->getType().cast<LLVM::LLVMType>();
|
||||
auto type = operand.getType().cast<LLVM::LLVMType>();
|
||||
return getFactory(*allReduce.op(), type.getUnderlyingType());
|
||||
}
|
||||
return AccumulatorFactory();
|
||||
|
@ -127,7 +127,7 @@ private:
|
|||
|
||||
// Return accumulator result.
|
||||
rewriter.setInsertionPointToStart(split);
|
||||
return split->addArgument(lhs->getType());
|
||||
return split->addArgument(lhs.getType());
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -154,7 +154,7 @@ private:
|
|||
template <typename T> AccumulatorFactory getFactory() const {
|
||||
return [](Location loc, Value lhs, Value rhs,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
return rewriter.create<T>(loc, lhs->getType(), lhs, rhs);
|
||||
return rewriter.create<T>(loc, lhs.getType(), lhs, rhs);
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -197,10 +197,10 @@ private:
|
|||
Value createBlockReduce(Location loc, Value operand,
|
||||
AccumulatorFactory &accumFactory,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto type = operand->getType().cast<LLVM::LLVMType>();
|
||||
auto type = operand.getType().cast<LLVM::LLVMType>();
|
||||
|
||||
// Create shared memory array to store the warp reduction.
|
||||
auto module = operand->getDefiningOp()->getParentOfType<ModuleOp>();
|
||||
auto module = operand.getDefiningOp()->getParentOfType<ModuleOp>();
|
||||
assert(module && "op must belong to a module");
|
||||
Value sharedMemPtr =
|
||||
createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);
|
||||
|
@ -295,7 +295,7 @@ private:
|
|||
assert(thenOperands.size() == elseOperands.size());
|
||||
rewriter.setInsertionPointToStart(continueBlock);
|
||||
for (auto operand : thenOperands)
|
||||
continueBlock->addArgument(operand->getType());
|
||||
continueBlock->addArgument(operand.getType());
|
||||
}
|
||||
|
||||
/// Shortcut for createIf with empty else block and no block operands.
|
||||
|
@ -321,7 +321,7 @@ private:
|
|||
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
|
||||
Value isPartialWarp = rewriter.create<LLVM::ICmpOp>(
|
||||
loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
|
||||
auto type = operand->getType().cast<LLVM::LLVMType>();
|
||||
auto type = operand.getType().cast<LLVM::LLVMType>();
|
||||
|
||||
createIf(
|
||||
loc, rewriter, isPartialWarp,
|
||||
|
@ -453,7 +453,7 @@ private:
|
|||
/// Returns value divided by the warp size (i.e. 32).
|
||||
Value getDivideByWarpSize(Value value,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = value->getLoc();
|
||||
auto loc = value.getLoc();
|
||||
auto warpSize = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
|
||||
return rewriter.create<LLVM::SDivOp>(loc, int32Type, value, warpSize);
|
||||
|
@ -492,7 +492,7 @@ struct GPUShuffleOpLowering : public LLVMOpLowering {
|
|||
gpu::ShuffleOpOperandAdaptor adaptor(operands);
|
||||
|
||||
auto dialect = lowering.getDialect();
|
||||
auto valueTy = adaptor.value()->getType().cast<LLVM::LLVMType>();
|
||||
auto valueTy = adaptor.value().getType().cast<LLVM::LLVMType>();
|
||||
auto int32Type = LLVM::LLVMType::getInt32Ty(dialect);
|
||||
auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
|
||||
auto resultTy = LLVM::LLVMType::getStructTy(dialect, {valueTy, predTy});
|
||||
|
@ -540,7 +540,7 @@ struct GPUFuncOpLowering : LLVMOpLowering {
|
|||
for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
|
||||
Value attribution = en.value();
|
||||
|
||||
auto type = attribution->getType().dyn_cast<MemRefType>();
|
||||
auto type = attribution.getType().dyn_cast<MemRefType>();
|
||||
assert(type && type.hasStaticShape() && "unexpected type in attribution");
|
||||
|
||||
uint64_t numElements = type.getNumElements();
|
||||
|
@ -612,7 +612,7 @@ struct GPUFuncOpLowering : LLVMOpLowering {
|
|||
// otherwise necessary given that memref sizes are fixed, but we can try
|
||||
// and canonicalize that away later.
|
||||
Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
|
||||
auto type = attribution->getType().cast<MemRefType>();
|
||||
auto type = attribution.getType().cast<MemRefType>();
|
||||
auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
|
||||
type, memory);
|
||||
signatureConversion.remapInput(numProperArguments + en.index(), descr);
|
||||
|
@ -624,7 +624,7 @@ struct GPUFuncOpLowering : LLVMOpLowering {
|
|||
auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
|
||||
for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
|
||||
Value attribution = en.value();
|
||||
auto type = attribution->getType().cast<MemRefType>();
|
||||
auto type = attribution.getType().cast<MemRefType>();
|
||||
assert(type && type.hasStaticShape() &&
|
||||
"unexpected type in attribution");
|
||||
|
||||
|
|
|
@ -127,7 +127,7 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
|
|||
|
||||
// Create the new induction variable to use.
|
||||
BlockArgument newIndVar =
|
||||
header->addArgument(forOperands.lowerBound()->getType());
|
||||
header->addArgument(forOperands.lowerBound().getType());
|
||||
Block *body = forOp.getBody();
|
||||
|
||||
// Apply signature conversion to the body of the forOp. It has a single block,
|
||||
|
@ -166,7 +166,7 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value> operands,
|
|||
|
||||
// Add the step to the induction variable and branch to the header.
|
||||
Value updatedIndVar = rewriter.create<spirv::IAddOp>(
|
||||
loc, newIndVar->getType(), newIndVar, forOperands.step());
|
||||
loc, newIndVar.getType(), newIndVar, forOperands.step());
|
||||
rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
|
||||
|
||||
rewriter.eraseOp(forOp);
|
||||
|
|
|
@ -152,7 +152,7 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto rangeOp = cast<RangeOp>(op);
|
||||
auto rangeDescriptorTy =
|
||||
convertLinalgType(rangeOp.getResult()->getType(), lowering);
|
||||
convertLinalgType(rangeOp.getResult().getType(), lowering);
|
||||
|
||||
edsc::ScopedContext context(rewriter, op->getLoc());
|
||||
|
||||
|
@ -251,7 +251,7 @@ public:
|
|||
for (int i = 0, e = memRefType.getRank(); i < e; ++i) {
|
||||
Value indexing = adaptor.indexings()[i];
|
||||
Value min = indexing;
|
||||
if (sliceOp.indexing(i)->getType().isa<RangeType>())
|
||||
if (sliceOp.indexing(i).getType().isa<RangeType>())
|
||||
min = extractvalue(int64Ty, indexing, pos(0));
|
||||
baseOffset = add(baseOffset, mul(min, strides[i]));
|
||||
}
|
||||
|
@ -274,7 +274,7 @@ public:
|
|||
int numNewDims = 0;
|
||||
for (auto en : llvm::enumerate(sliceOp.indexings())) {
|
||||
Value indexing = en.value();
|
||||
if (indexing->getType().isa<RangeType>()) {
|
||||
if (indexing.getType().isa<RangeType>()) {
|
||||
int rank = en.index();
|
||||
Value rangeDescriptor = adaptor.indexings()[rank];
|
||||
Value min = extractvalue(int64Ty, rangeDescriptor, pos(0));
|
||||
|
|
|
@ -215,7 +215,7 @@ struct LoopToGpuConverter {
|
|||
|
||||
// Return true if the value is obviously a constant "one".
|
||||
static bool isConstantOne(Value value) {
|
||||
if (auto def = dyn_cast_or_null<ConstantIndexOp>(value->getDefiningOp()))
|
||||
if (auto def = dyn_cast_or_null<ConstantIndexOp>(value.getDefiningOp()))
|
||||
return def.getValue() == 1;
|
||||
return false;
|
||||
}
|
||||
|
@ -457,7 +457,7 @@ void LoopToGpuConverter::createLaunch(OpTy rootForOp, OpTy innermostForOp,
|
|||
|
||||
Value ivReplacement =
|
||||
builder.create<AddIOp>(rootForOp.getLoc(), *lbArgumentIt, id);
|
||||
en.value()->replaceAllUsesWith(ivReplacement);
|
||||
en.value().replaceAllUsesWith(ivReplacement);
|
||||
replaceAllUsesInRegionWith(steps[en.index()], *stepArgumentIt,
|
||||
launchOp.body());
|
||||
std::advance(lbArgumentIt, 1);
|
||||
|
|
|
@ -249,7 +249,7 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
|
|||
/*============================================================================*/
|
||||
StructBuilder::StructBuilder(Value v) : value(v) {
|
||||
assert(value != nullptr && "value cannot be null");
|
||||
structType = value->getType().cast<LLVM::LLVMType>();
|
||||
structType = value.getType().cast<LLVM::LLVMType>();
|
||||
}
|
||||
|
||||
Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
|
||||
|
@ -272,7 +272,7 @@ void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
|
|||
MemRefDescriptor::MemRefDescriptor(Value descriptor)
|
||||
: StructBuilder(descriptor) {
|
||||
assert(value != nullptr && "value cannot be null");
|
||||
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
|
||||
indexType = value.getType().cast<LLVM::LLVMType>().getStructElementType(
|
||||
kOffsetPosInMemRefDescriptor);
|
||||
}
|
||||
|
||||
|
@ -412,7 +412,7 @@ void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
|
|||
}
|
||||
|
||||
LLVM::LLVMType MemRefDescriptor::getElementType() {
|
||||
return value->getType().cast<LLVM::LLVMType>().getStructElementType(
|
||||
return value.getType().cast<LLVM::LLVMType>().getStructElementType(
|
||||
kAlignedPtrPosInMemRefDescriptor);
|
||||
}
|
||||
|
||||
|
@ -673,7 +673,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
SmallVector<Value, 4> results;
|
||||
results.reserve(numResults);
|
||||
for (unsigned i = 0; i < numResults; ++i) {
|
||||
auto type = this->lowering.convertType(op->getResult(i)->getType());
|
||||
auto type = this->lowering.convertType(op->getResult(i).getType());
|
||||
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
op->getLoc(), type, newOp.getOperation()->getResult(0),
|
||||
rewriter.getI64ArrayAttr(i)));
|
||||
|
@ -723,21 +723,21 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
|
|||
|
||||
// Cannot convert ops if their operands are not of LLVM type.
|
||||
for (Value operand : operands) {
|
||||
if (!operand || !operand->getType().isa<LLVM::LLVMType>())
|
||||
if (!operand || !operand.getType().isa<LLVM::LLVMType>())
|
||||
return this->matchFailure();
|
||||
}
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto llvmArrayTy = operands[0]->getType().cast<LLVM::LLVMType>();
|
||||
auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
|
||||
|
||||
if (!llvmArrayTy.isArrayTy()) {
|
||||
auto newOp = rewriter.create<TargetOp>(
|
||||
op->getLoc(), operands[0]->getType(), operands, op->getAttrs());
|
||||
op->getLoc(), operands[0].getType(), operands, op->getAttrs());
|
||||
rewriter.replaceOp(op, newOp.getResult());
|
||||
return this->matchSuccess();
|
||||
}
|
||||
|
||||
auto vectorType = op->getResult(0)->getType().dyn_cast<VectorType>();
|
||||
auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
|
||||
if (!vectorType)
|
||||
return this->matchFailure();
|
||||
auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, this->lowering);
|
||||
|
@ -1032,7 +1032,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
|
|||
Value subbed =
|
||||
rewriter.create<LLVM::SubOp>(loc, alignmentValue, ptrModAlign);
|
||||
Value offset = rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue);
|
||||
Value aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
|
||||
Value aligned = rewriter.create<LLVM::GEPOp>(loc, allocated.getType(),
|
||||
allocated, offset);
|
||||
bitcastAligned = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, elementPtrType, ArrayRef<Value>(aligned));
|
||||
|
@ -1132,7 +1132,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> {
|
|||
SmallVector<Value, 4> results;
|
||||
results.reserve(numResults);
|
||||
for (unsigned i = 0; i < numResults; ++i) {
|
||||
auto type = this->lowering.convertType(op->getResult(i)->getType());
|
||||
auto type = this->lowering.convertType(op->getResult(i).getType());
|
||||
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
|
||||
op->getLoc(), type, newOp.getOperation()->getResult(0),
|
||||
rewriter.getI64ArrayAttr(i)));
|
||||
|
@ -1207,7 +1207,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> {
|
|||
|
||||
OperandAdaptor<TanhOp> transformed(operands);
|
||||
LLVMTypeT operandType =
|
||||
transformed.operand()->getType().dyn_cast_or_null<LLVM::LLVMType>();
|
||||
transformed.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
|
||||
|
||||
if (!operandType)
|
||||
return matchFailure();
|
||||
|
@ -1249,12 +1249,12 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
|
||||
PatternMatchResult match(Operation *op) const override {
|
||||
auto memRefCastOp = cast<MemRefCastOp>(op);
|
||||
Type srcType = memRefCastOp.getOperand()->getType();
|
||||
Type srcType = memRefCastOp.getOperand().getType();
|
||||
Type dstType = memRefCastOp.getType();
|
||||
|
||||
if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>()) {
|
||||
MemRefType sourceType =
|
||||
memRefCastOp.getOperand()->getType().cast<MemRefType>();
|
||||
memRefCastOp.getOperand().getType().cast<MemRefType>();
|
||||
MemRefType targetType = memRefCastOp.getType().cast<MemRefType>();
|
||||
return (isSupportedMemRefType(targetType) &&
|
||||
isSupportedMemRefType(sourceType))
|
||||
|
@ -1278,7 +1278,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
|
|||
auto memRefCastOp = cast<MemRefCastOp>(op);
|
||||
OperandAdaptor<MemRefCastOp> transformed(operands);
|
||||
|
||||
auto srcType = memRefCastOp.getOperand()->getType();
|
||||
auto srcType = memRefCastOp.getOperand().getType();
|
||||
auto dstType = memRefCastOp.getType();
|
||||
auto targetStructType = lowering.convertType(memRefCastOp.getType());
|
||||
auto loc = op->getLoc();
|
||||
|
@ -1349,7 +1349,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dimOp = cast<DimOp>(op);
|
||||
OperandAdaptor<DimOp> transformed(operands);
|
||||
MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
|
||||
MemRefType type = dimOp.getOperand().getType().cast<MemRefType>();
|
||||
|
||||
auto shape = type.getShape();
|
||||
int64_t index = dimOp.getIndex();
|
||||
|
@ -1529,9 +1529,9 @@ struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> {
|
|||
auto indexCastOp = cast<IndexCastOp>(op);
|
||||
|
||||
auto targetType =
|
||||
this->lowering.convertType(indexCastOp.getResult()->getType())
|
||||
this->lowering.convertType(indexCastOp.getResult().getType())
|
||||
.cast<LLVM::LLVMType>();
|
||||
auto sourceType = transformed.in()->getType().cast<LLVM::LLVMType>();
|
||||
auto sourceType = transformed.in().getType().cast<LLVM::LLVMType>();
|
||||
unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth();
|
||||
unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth();
|
||||
|
||||
|
@ -1564,7 +1564,7 @@ struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> {
|
|||
CmpIOpOperandAdaptor transformed(operands);
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
|
||||
op, lowering.convertType(cmpiOp.getResult()->getType()),
|
||||
op, lowering.convertType(cmpiOp.getResult().getType()),
|
||||
rewriter.getI64IntegerAttr(static_cast<int64_t>(
|
||||
convertCmpPredicate<LLVM::ICmpPredicate>(cmpiOp.getPredicate()))),
|
||||
transformed.lhs(), transformed.rhs());
|
||||
|
@ -1583,7 +1583,7 @@ struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> {
|
|||
CmpFOpOperandAdaptor transformed(operands);
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::FCmpOp>(
|
||||
op, lowering.convertType(cmpfOp.getResult()->getType()),
|
||||
op, lowering.convertType(cmpfOp.getResult().getType()),
|
||||
rewriter.getI64IntegerAttr(static_cast<int64_t>(
|
||||
convertCmpPredicate<LLVM::FCmpPredicate>(cmpfOp.getPredicate()))),
|
||||
transformed.lhs(), transformed.rhs());
|
||||
|
@ -1807,7 +1807,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> {
|
|||
1 + viewOp.getNumOffsets() + viewOp.getNumSizes()),
|
||||
operands.end());
|
||||
|
||||
auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
|
||||
auto sourceMemRefType = viewOp.source().getType().cast<MemRefType>();
|
||||
auto sourceElementTy =
|
||||
lowering.convertType(sourceMemRefType.getElementType())
|
||||
.dyn_cast_or_null<LLVM::LLVMType>();
|
||||
|
@ -2174,7 +2174,7 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
|
|||
auto indexType = IndexType::get(context);
|
||||
// Alloca with proper alignment. We do not expect optimizations of this
|
||||
// alloca op and so we omit allocating at the entry block.
|
||||
auto ptrType = operand->getType().cast<LLVM::LLVMType>().getPointerTo();
|
||||
auto ptrType = operand.getType().cast<LLVM::LLVMType>().getPointerTo();
|
||||
Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
|
||||
IntegerAttr::get(indexType, 1));
|
||||
Value allocated =
|
||||
|
@ -2193,8 +2193,8 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
|
|||
for (auto it : llvm::zip(opOperands, operands)) {
|
||||
auto operand = std::get<0>(it);
|
||||
auto llvmOperand = std::get<1>(it);
|
||||
if (!operand->getType().isa<MemRefType>() &&
|
||||
!operand->getType().isa<UnrankedMemRefType>()) {
|
||||
if (!operand.getType().isa<MemRefType>() &&
|
||||
!operand.getType().isa<UnrankedMemRefType>()) {
|
||||
promotedOperands.push_back(operand);
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -74,7 +74,7 @@ public:
|
|||
matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto resultType =
|
||||
this->typeConverter.convertType(operation.getResult()->getType());
|
||||
this->typeConverter.convertType(operation.getResult().getType());
|
||||
rewriter.template replaceOpWithNewOp<SPIRVOp>(
|
||||
operation, resultType, operands, ArrayRef<NamedAttribute>());
|
||||
return this->matchSuccess();
|
||||
|
@ -178,7 +178,7 @@ spirv::AccessChainOp getElementPtr(OpBuilder &builder,
|
|||
PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
|
||||
ConstantOp constIndexOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
|
||||
if (!constIndexOp.getResult().getType().isa<IndexType>()) {
|
||||
return matchFailure();
|
||||
}
|
||||
// The attribute has index type which is not directly supported in
|
||||
|
@ -197,7 +197,7 @@ PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
|
|||
return matchFailure();
|
||||
}
|
||||
auto spirvConstType =
|
||||
typeConverter.convertType(constIndexOp.getResult()->getType());
|
||||
typeConverter.convertType(constIndexOp.getResult().getType());
|
||||
auto spirvConstVal =
|
||||
rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
|
||||
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
|
||||
|
@ -217,9 +217,9 @@ CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
|
|||
switch (cmpFOp.getPredicate()) {
|
||||
#define DISPATCH(cmpPredicate, spirvOp) \
|
||||
case cmpPredicate: \
|
||||
rewriter.replaceOpWithNewOp<spirvOp>( \
|
||||
cmpFOp, cmpFOp.getResult()->getType(), cmpFOpOperands.lhs(), \
|
||||
cmpFOpOperands.rhs()); \
|
||||
rewriter.replaceOpWithNewOp<spirvOp>(cmpFOp, cmpFOp.getResult().getType(), \
|
||||
cmpFOpOperands.lhs(), \
|
||||
cmpFOpOperands.rhs()); \
|
||||
return matchSuccess();
|
||||
|
||||
// Ordered.
|
||||
|
@ -257,9 +257,9 @@ CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
|
|||
switch (cmpIOp.getPredicate()) {
|
||||
#define DISPATCH(cmpPredicate, spirvOp) \
|
||||
case cmpPredicate: \
|
||||
rewriter.replaceOpWithNewOp<spirvOp>( \
|
||||
cmpIOp, cmpIOp.getResult()->getType(), cmpIOpOperands.lhs(), \
|
||||
cmpIOpOperands.rhs()); \
|
||||
rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
|
||||
cmpIOpOperands.lhs(), \
|
||||
cmpIOpOperands.rhs()); \
|
||||
return matchSuccess();
|
||||
|
||||
DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
|
||||
|
@ -287,7 +287,7 @@ LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
|
|||
ConversionPatternRewriter &rewriter) const {
|
||||
LoadOpOperandAdaptor loadOperands(operands);
|
||||
auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(),
|
||||
loadOp.memref()->getType().cast<MemRefType>(),
|
||||
loadOp.memref().getType().cast<MemRefType>(),
|
||||
loadOperands.memref(), loadOperands.indices());
|
||||
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr,
|
||||
/*memory_access =*/nullptr,
|
||||
|
@ -333,7 +333,7 @@ StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
|
|||
StoreOpOperandAdaptor storeOperands(operands);
|
||||
auto storePtr =
|
||||
getElementPtr(rewriter, typeConverter, storeOp.getLoc(),
|
||||
storeOp.memref()->getType().cast<MemRefType>(),
|
||||
storeOp.memref().getType().cast<MemRefType>(),
|
||||
storeOperands.memref(), storeOperands.indices());
|
||||
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
|
||||
storeOperands.value(),
|
||||
|
|
|
@ -110,8 +110,7 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
|
|||
PatternMatchResult
|
||||
LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto subViewOp =
|
||||
dyn_cast_or_null<SubViewOp>(loadOp.memref()->getDefiningOp());
|
||||
auto subViewOp = dyn_cast_or_null<SubViewOp>(loadOp.memref().getDefiningOp());
|
||||
if (!subViewOp) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
@ -133,7 +132,7 @@ PatternMatchResult
|
|||
StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
auto subViewOp =
|
||||
dyn_cast_or_null<SubViewOp>(storeOp.memref()->getDefiningOp());
|
||||
dyn_cast_or_null<SubViewOp>(storeOp.memref().getDefiningOp());
|
||||
if (!subViewOp) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
|
|
@ -368,7 +368,7 @@ public:
|
|||
auto adaptor = vector::ExtractOpOperandAdaptor(operands);
|
||||
auto extractOp = cast<vector::ExtractOp>(op);
|
||||
auto vectorType = extractOp.getVectorType();
|
||||
auto resultType = extractOp.getResult()->getType();
|
||||
auto resultType = extractOp.getResult().getType();
|
||||
auto llvmResultType = lowering.convertType(resultType);
|
||||
auto positionArrayAttr = extractOp.position();
|
||||
|
||||
|
@ -647,12 +647,12 @@ public:
|
|||
auto loc = op->getLoc();
|
||||
auto adaptor = vector::OuterProductOpOperandAdaptor(operands);
|
||||
auto *ctx = op->getContext();
|
||||
auto vLHS = adaptor.lhs()->getType().cast<LLVM::LLVMType>();
|
||||
auto vRHS = adaptor.rhs()->getType().cast<LLVM::LLVMType>();
|
||||
auto vLHS = adaptor.lhs().getType().cast<LLVM::LLVMType>();
|
||||
auto vRHS = adaptor.rhs().getType().cast<LLVM::LLVMType>();
|
||||
auto rankLHS = vLHS.getUnderlyingType()->getVectorNumElements();
|
||||
auto rankRHS = vRHS.getUnderlyingType()->getVectorNumElements();
|
||||
auto llvmArrayOfVectType = lowering.convertType(
|
||||
cast<vector::OuterProductOp>(op).getResult()->getType());
|
||||
cast<vector::OuterProductOp>(op).getResult().getType());
|
||||
Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayOfVectType);
|
||||
Value a = adaptor.lhs(), b = adaptor.rhs();
|
||||
Value acc = adaptor.acc().empty() ? nullptr : adaptor.acc().front();
|
||||
|
@ -699,9 +699,9 @@ public:
|
|||
auto loc = op->getLoc();
|
||||
vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
|
||||
MemRefType sourceMemRefType =
|
||||
castOp.getOperand()->getType().cast<MemRefType>();
|
||||
castOp.getOperand().getType().cast<MemRefType>();
|
||||
MemRefType targetMemRefType =
|
||||
castOp.getResult()->getType().cast<MemRefType>();
|
||||
castOp.getResult().getType().cast<MemRefType>();
|
||||
|
||||
// Only static shape casts supported atm.
|
||||
if (!sourceMemRefType.hasStaticShape() ||
|
||||
|
@ -709,7 +709,7 @@ public:
|
|||
return matchFailure();
|
||||
|
||||
auto llvmSourceDescriptorTy =
|
||||
operands[0]->getType().dyn_cast<LLVM::LLVMType>();
|
||||
operands[0].getType().dyn_cast<LLVM::LLVMType>();
|
||||
if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
|
||||
return matchFailure();
|
||||
MemRefDescriptor sourceMemRef(operands[0]);
|
||||
|
|
|
@ -108,8 +108,8 @@ static bool isFunctionRegion(Region *region) {
|
|||
/// symbol.
|
||||
bool mlir::isTopLevelValue(Value value) {
|
||||
if (auto arg = value.dyn_cast<BlockArgument>())
|
||||
return isFunctionRegion(arg->getOwner()->getParent());
|
||||
return isFunctionRegion(value->getDefiningOp()->getParentRegion());
|
||||
return isFunctionRegion(arg.getOwner()->getParent());
|
||||
return isFunctionRegion(value.getDefiningOp()->getParentRegion());
|
||||
}
|
||||
|
||||
// Value can be used as a dimension id if it is valid as a symbol, or
|
||||
|
@ -117,10 +117,10 @@ bool mlir::isTopLevelValue(Value value) {
|
|||
// with dimension id arguments.
|
||||
bool mlir::isValidDim(Value value) {
|
||||
// The value must be an index type.
|
||||
if (!value->getType().isIndex())
|
||||
if (!value.getType().isIndex())
|
||||
return false;
|
||||
|
||||
if (auto *op = value->getDefiningOp()) {
|
||||
if (auto *op = value.getDefiningOp()) {
|
||||
// Top level operation or constant operation is ok.
|
||||
if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
|
||||
return true;
|
||||
|
@ -134,7 +134,7 @@ bool mlir::isValidDim(Value value) {
|
|||
return false;
|
||||
}
|
||||
// This value has to be a block argument for a FuncOp or an affine.for.
|
||||
auto *parentOp = value.cast<BlockArgument>()->getOwner()->getParentOp();
|
||||
auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
|
||||
return isa<FuncOp>(parentOp) || isa<AffineForOp>(parentOp);
|
||||
}
|
||||
|
||||
|
@ -162,11 +162,11 @@ static bool isDimOpValidSymbol(DimOp dimOp) {
|
|||
// The dim op is also okay if its operand memref/tensor is a view/subview
|
||||
// whose corresponding size is a valid symbol.
|
||||
unsigned index = dimOp.getIndex();
|
||||
if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand()->getDefiningOp()))
|
||||
if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand().getDefiningOp()))
|
||||
return isMemRefSizeValidSymbol<ViewOp>(viewOp, index);
|
||||
if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand()->getDefiningOp()))
|
||||
if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand().getDefiningOp()))
|
||||
return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index);
|
||||
if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand()->getDefiningOp()))
|
||||
if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand().getDefiningOp()))
|
||||
return isMemRefSizeValidSymbol<AllocOp>(allocOp, index);
|
||||
return false;
|
||||
}
|
||||
|
@ -177,10 +177,10 @@ static bool isDimOpValidSymbol(DimOp dimOp) {
|
|||
// constraints.
|
||||
bool mlir::isValidSymbol(Value value) {
|
||||
// The value must be an index type.
|
||||
if (!value->getType().isIndex())
|
||||
if (!value.getType().isIndex())
|
||||
return false;
|
||||
|
||||
if (auto *op = value->getDefiningOp()) {
|
||||
if (auto *op = value.getDefiningOp()) {
|
||||
// Top level operation or constant operation is ok.
|
||||
if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
|
||||
return true;
|
||||
|
@ -283,7 +283,7 @@ LogicalResult AffineApplyOp::verify() {
|
|||
return emitOpError("operands must be of type 'index'");
|
||||
}
|
||||
|
||||
if (!getResult()->getType().isIndex())
|
||||
if (!getResult().getType().isIndex())
|
||||
return emitOpError("result must be of type 'index'");
|
||||
|
||||
// Verify that the map only produces one result.
|
||||
|
@ -332,7 +332,7 @@ AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
|
|||
if (inserted) {
|
||||
reorderedDims.push_back(v);
|
||||
}
|
||||
return getAffineDimExpr(iterPos->second, v->getContext())
|
||||
return getAffineDimExpr(iterPos->second, v.getContext())
|
||||
.cast<AffineDimExpr>();
|
||||
}
|
||||
|
||||
|
@ -365,7 +365,7 @@ static llvm::SetVector<unsigned>
|
|||
indicesFromAffineApplyOp(ArrayRef<Value> operands) {
|
||||
llvm::SetVector<unsigned> res;
|
||||
for (auto en : llvm::enumerate(operands))
|
||||
if (isa_and_nonnull<AffineApplyOp>(en.value()->getDefiningOp()))
|
||||
if (isa_and_nonnull<AffineApplyOp>(en.value().getDefiningOp()))
|
||||
res.insert(en.index());
|
||||
return res;
|
||||
}
|
||||
|
@ -487,7 +487,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
|
|||
// 1. Only dispatch dims or symbols.
|
||||
for (auto en : llvm::enumerate(operands)) {
|
||||
auto t = en.value();
|
||||
assert(t->getType().isIndex());
|
||||
assert(t.getType().isIndex());
|
||||
bool isDim = (en.index() < map.getNumDims());
|
||||
if (isDim) {
|
||||
// a. The mathematical composition of AffineMap composes dims.
|
||||
|
@ -503,7 +503,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
|
|||
// 2. Compose AffineApplyOps and dispatch dims or symbols.
|
||||
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
|
||||
auto t = operands[i];
|
||||
auto affineApply = dyn_cast_or_null<AffineApplyOp>(t->getDefiningOp());
|
||||
auto affineApply = dyn_cast_or_null<AffineApplyOp>(t.getDefiningOp());
|
||||
if (affineApply) {
|
||||
// a. Compose affine.apply operations.
|
||||
LLVM_DEBUG(affineApply.getOperation()->print(
|
||||
|
@ -588,7 +588,7 @@ static void composeAffineMapAndOperands(AffineMap *map,
|
|||
void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
|
||||
SmallVectorImpl<Value> *operands) {
|
||||
while (llvm::any_of(*operands, [](Value v) {
|
||||
return isa_and_nonnull<AffineApplyOp>(v->getDefiningOp());
|
||||
return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
|
||||
})) {
|
||||
composeAffineMapAndOperands(map, operands);
|
||||
}
|
||||
|
@ -819,8 +819,8 @@ void AffineApplyOp::getCanonicalizationPatterns(
|
|||
static LogicalResult foldMemRefCast(Operation *op) {
|
||||
bool folded = false;
|
||||
for (OpOperand &operand : op->getOpOperands()) {
|
||||
auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
|
||||
if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
|
||||
auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
|
||||
if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
|
||||
operand.set(cast.getOperand());
|
||||
folded = true;
|
||||
}
|
||||
|
@ -856,16 +856,16 @@ void AffineDmaStartOp::build(Builder *builder, OperationState &result,
|
|||
}
|
||||
|
||||
void AffineDmaStartOp::print(OpAsmPrinter &p) {
|
||||
p << "affine.dma_start " << *getSrcMemRef() << '[';
|
||||
p << "affine.dma_start " << getSrcMemRef() << '[';
|
||||
p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
|
||||
p << "], " << *getDstMemRef() << '[';
|
||||
p << "], " << getDstMemRef() << '[';
|
||||
p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
|
||||
p << "], " << *getTagMemRef() << '[';
|
||||
p << "], " << getTagMemRef() << '[';
|
||||
p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
|
||||
p << "], " << *getNumElements();
|
||||
p << "], " << getNumElements();
|
||||
if (isStrided()) {
|
||||
p << ", " << *getStride();
|
||||
p << ", " << *getNumElementsPerStride();
|
||||
p << ", " << getStride();
|
||||
p << ", " << getNumElementsPerStride();
|
||||
}
|
||||
p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
|
||||
<< getTagMemRefType();
|
||||
|
@ -951,11 +951,11 @@ ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
|
|||
}
|
||||
|
||||
LogicalResult AffineDmaStartOp::verify() {
|
||||
if (!getOperand(getSrcMemRefOperandIndex())->getType().isa<MemRefType>())
|
||||
if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
|
||||
return emitOpError("expected DMA source to be of memref type");
|
||||
if (!getOperand(getDstMemRefOperandIndex())->getType().isa<MemRefType>())
|
||||
if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
|
||||
return emitOpError("expected DMA destination to be of memref type");
|
||||
if (!getOperand(getTagMemRefOperandIndex())->getType().isa<MemRefType>())
|
||||
if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
|
||||
return emitOpError("expected DMA tag to be of memref type");
|
||||
|
||||
// DMAs from different memory spaces supported.
|
||||
|
@ -971,19 +971,19 @@ LogicalResult AffineDmaStartOp::verify() {
|
|||
}
|
||||
|
||||
for (auto idx : getSrcIndices()) {
|
||||
if (!idx->getType().isIndex())
|
||||
if (!idx.getType().isIndex())
|
||||
return emitOpError("src index to dma_start must have 'index' type");
|
||||
if (!isValidAffineIndexOperand(idx))
|
||||
return emitOpError("src index must be a dimension or symbol identifier");
|
||||
}
|
||||
for (auto idx : getDstIndices()) {
|
||||
if (!idx->getType().isIndex())
|
||||
if (!idx.getType().isIndex())
|
||||
return emitOpError("dst index to dma_start must have 'index' type");
|
||||
if (!isValidAffineIndexOperand(idx))
|
||||
return emitOpError("dst index must be a dimension or symbol identifier");
|
||||
}
|
||||
for (auto idx : getTagIndices()) {
|
||||
if (!idx->getType().isIndex())
|
||||
if (!idx.getType().isIndex())
|
||||
return emitOpError("tag index to dma_start must have 'index' type");
|
||||
if (!isValidAffineIndexOperand(idx))
|
||||
return emitOpError("tag index must be a dimension or symbol identifier");
|
||||
|
@ -1012,12 +1012,12 @@ void AffineDmaWaitOp::build(Builder *builder, OperationState &result,
|
|||
}
|
||||
|
||||
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
|
||||
p << "affine.dma_wait " << *getTagMemRef() << '[';
|
||||
p << "affine.dma_wait " << getTagMemRef() << '[';
|
||||
SmallVector<Value, 2> operands(getTagIndices());
|
||||
p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
|
||||
p << "], ";
|
||||
p.printOperand(getNumElements());
|
||||
p << " : " << getTagMemRef()->getType();
|
||||
p << " : " << getTagMemRef().getType();
|
||||
}
|
||||
|
||||
// Parse AffineDmaWaitOp.
|
||||
|
@ -1056,10 +1056,10 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
|
|||
}
|
||||
|
||||
LogicalResult AffineDmaWaitOp::verify() {
|
||||
if (!getOperand(0)->getType().isa<MemRefType>())
|
||||
if (!getOperand(0).getType().isa<MemRefType>())
|
||||
return emitOpError("expected DMA tag to be of memref type");
|
||||
for (auto idx : getTagIndices()) {
|
||||
if (!idx->getType().isIndex())
|
||||
if (!idx.getType().isIndex())
|
||||
return emitOpError("index to dma_wait must have 'index' type");
|
||||
if (!isValidAffineIndexOperand(idx))
|
||||
return emitOpError("index must be a dimension or symbol identifier");
|
||||
|
@ -1123,8 +1123,7 @@ static LogicalResult verify(AffineForOp op) {
|
|||
// Check that the body defines as single block argument for the induction
|
||||
// variable.
|
||||
auto *body = op.getBody();
|
||||
if (body->getNumArguments() != 1 ||
|
||||
!body->getArgument(0)->getType().isIndex())
|
||||
if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex())
|
||||
return op.emitOpError(
|
||||
"expected body to have a single index argument for the "
|
||||
"induction variable");
|
||||
|
@ -1553,7 +1552,7 @@ bool AffineForOp::matchingBoundOperandList() {
|
|||
Region &AffineForOp::getLoopBody() { return region(); }
|
||||
|
||||
bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
|
||||
return !region().isAncestor(value->getParentRegion());
|
||||
return !region().isAncestor(value.getParentRegion());
|
||||
}
|
||||
|
||||
LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
|
||||
|
@ -1571,9 +1570,9 @@ bool mlir::isForInductionVar(Value val) {
|
|||
/// not an induction variable, then return nullptr.
|
||||
AffineForOp mlir::getForInductionVarOwner(Value val) {
|
||||
auto ivArg = val.dyn_cast<BlockArgument>();
|
||||
if (!ivArg || !ivArg->getOwner())
|
||||
if (!ivArg || !ivArg.getOwner())
|
||||
return AffineForOp();
|
||||
auto *containingInst = ivArg->getOwner()->getParent()->getParentOp();
|
||||
auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
|
||||
return dyn_cast<AffineForOp>(containingInst);
|
||||
}
|
||||
|
||||
|
@ -1744,7 +1743,7 @@ void AffineLoadOp::build(Builder *builder, OperationState &result,
|
|||
result.addOperands(operands);
|
||||
if (map)
|
||||
result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
|
||||
auto memrefType = operands[0]->getType().cast<MemRefType>();
|
||||
auto memrefType = operands[0].getType().cast<MemRefType>();
|
||||
result.types.push_back(memrefType.getElementType());
|
||||
}
|
||||
|
||||
|
@ -1753,14 +1752,14 @@ void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref,
|
|||
assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
|
||||
result.addOperands(memref);
|
||||
result.addOperands(mapOperands);
|
||||
auto memrefType = memref->getType().cast<MemRefType>();
|
||||
auto memrefType = memref.getType().cast<MemRefType>();
|
||||
result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
|
||||
result.types.push_back(memrefType.getElementType());
|
||||
}
|
||||
|
||||
void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref,
|
||||
ValueRange indices) {
|
||||
auto memrefType = memref->getType().cast<MemRefType>();
|
||||
auto memrefType = memref.getType().cast<MemRefType>();
|
||||
auto rank = memrefType.getRank();
|
||||
// Create identity map for memrefs with at least one dimension or () -> ()
|
||||
// for zero-dimensional memrefs.
|
||||
|
@ -1789,7 +1788,7 @@ ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
void AffineLoadOp::print(OpAsmPrinter &p) {
|
||||
p << "affine.load " << *getMemRef() << '[';
|
||||
p << "affine.load " << getMemRef() << '[';
|
||||
if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()))
|
||||
p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
|
||||
p << ']';
|
||||
|
@ -1816,7 +1815,7 @@ LogicalResult AffineLoadOp::verify() {
|
|||
}
|
||||
|
||||
for (auto idx : getMapOperands()) {
|
||||
if (!idx->getType().isIndex())
|
||||
if (!idx.getType().isIndex())
|
||||
return emitOpError("index to load must have 'index' type");
|
||||
if (!isValidAffineIndexOperand(idx))
|
||||
return emitOpError("index must be a dimension or symbol identifier");
|
||||
|
@ -1854,7 +1853,7 @@ void AffineStoreOp::build(Builder *builder, OperationState &result,
|
|||
void AffineStoreOp::build(Builder *builder, OperationState &result,
|
||||
Value valueToStore, Value memref,
|
||||
ValueRange indices) {
|
||||
auto memrefType = memref->getType().cast<MemRefType>();
|
||||
auto memrefType = memref.getType().cast<MemRefType>();
|
||||
auto rank = memrefType.getRank();
|
||||
// Create identity map for memrefs with at least one dimension or () -> ()
|
||||
// for zero-dimensional memrefs.
|
||||
|
@ -1885,8 +1884,8 @@ ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
void AffineStoreOp::print(OpAsmPrinter &p) {
|
||||
p << "affine.store " << *getValueToStore();
|
||||
p << ", " << *getMemRef() << '[';
|
||||
p << "affine.store " << getValueToStore();
|
||||
p << ", " << getMemRef() << '[';
|
||||
if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()))
|
||||
p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
|
||||
p << ']';
|
||||
|
@ -1896,7 +1895,7 @@ void AffineStoreOp::print(OpAsmPrinter &p) {
|
|||
|
||||
LogicalResult AffineStoreOp::verify() {
|
||||
// First operand must have same type as memref element type.
|
||||
if (getValueToStore()->getType() != getMemRefType().getElementType())
|
||||
if (getValueToStore().getType() != getMemRefType().getElementType())
|
||||
return emitOpError("first operand must have same type memref element type");
|
||||
|
||||
auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
|
||||
|
@ -1914,7 +1913,7 @@ LogicalResult AffineStoreOp::verify() {
|
|||
}
|
||||
|
||||
for (auto idx : getMapOperands()) {
|
||||
if (!idx->getType().isIndex())
|
||||
if (!idx.getType().isIndex())
|
||||
return emitOpError("index to store must have 'index' type");
|
||||
if (!isValidAffineIndexOperand(idx))
|
||||
return emitOpError("index must be a dimension or symbol identifier");
|
||||
|
@ -2059,7 +2058,7 @@ static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
|
|||
}
|
||||
|
||||
void print(OpAsmPrinter &p, AffinePrefetchOp op) {
|
||||
p << AffinePrefetchOp::getOperationName() << " " << *op.memref() << '[';
|
||||
p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
|
||||
AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
|
||||
if (mapAttr) {
|
||||
SmallVector<Value, 2> operands(op.getMapOperands());
|
||||
|
|
|
@ -47,8 +47,8 @@ static Value emitUniformPerLayerDequantize(Location loc, Value input,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
Type storageType = elementType.castToStorageType(input->getType());
|
||||
Type realType = elementType.castToExpressedType(input->getType());
|
||||
Type storageType = elementType.castToStorageType(input.getType());
|
||||
Type realType = elementType.castToExpressedType(input.getType());
|
||||
Type intermediateType =
|
||||
castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
|
||||
assert(storageType && "cannot cast to storage type");
|
||||
|
@ -90,7 +90,7 @@ emitUniformPerAxisDequantize(Location loc, Value input,
|
|||
|
||||
static Value emitDequantize(Location loc, Value input,
|
||||
PatternRewriter &rewriter) {
|
||||
Type inputType = input->getType();
|
||||
Type inputType = input.getType();
|
||||
QuantizedType qElementType =
|
||||
QuantizedType::getQuantizedElementType(inputType);
|
||||
if (auto uperLayerElementType =
|
||||
|
@ -113,8 +113,8 @@ struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
|
|||
|
||||
PatternMatchResult matchAndRewrite(DequantizeCastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Type inputType = op.arg()->getType();
|
||||
Type outputType = op.getResult()->getType();
|
||||
Type inputType = op.arg().getType();
|
||||
Type outputType = op.getResult().getType();
|
||||
|
||||
QuantizedType inputElementType =
|
||||
QuantizedType::getQuantizedElementType(inputType);
|
||||
|
|
|
@ -53,11 +53,11 @@ struct UniformBinaryOpInfo {
|
|||
UniformBinaryOpInfo(Operation *op, Value lhs, Value rhs,
|
||||
Optional<APFloat> clampMin, Optional<APFloat> clampMax)
|
||||
: op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
|
||||
lhsType(getUniformElementType(lhs->getType())),
|
||||
rhsType(getUniformElementType(rhs->getType())),
|
||||
lhsType(getUniformElementType(lhs.getType())),
|
||||
rhsType(getUniformElementType(rhs.getType())),
|
||||
resultType(getUniformElementType(*op->result_type_begin())),
|
||||
lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())),
|
||||
rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())),
|
||||
lhsStorageType(quant::QuantizedType::castToStorageType(lhs.getType())),
|
||||
rhsStorageType(quant::QuantizedType::castToStorageType(rhs.getType())),
|
||||
resultStorageType(
|
||||
quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
|
||||
}
|
||||
|
|
|
@ -110,7 +110,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
|
|||
// to encode target module" has landed.
|
||||
// auto functionType = kernelFunc.getType();
|
||||
// for (unsigned i = 0; i < numKernelFuncArgs; ++i) {
|
||||
// if (getKernelOperand(i)->getType() != functionType.getInput(i)) {
|
||||
// if (getKernelOperand(i).getType() != functionType.getInput(i)) {
|
||||
// return emitOpError("type of function argument ")
|
||||
// << i << " does not match";
|
||||
// }
|
||||
|
@ -137,7 +137,7 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
|
|||
if (allReduce.body().front().getNumArguments() != 2)
|
||||
return allReduce.emitError("expected two region arguments");
|
||||
for (auto argument : allReduce.body().front().getArguments()) {
|
||||
if (argument->getType() != allReduce.getType())
|
||||
if (argument.getType() != allReduce.getType())
|
||||
return allReduce.emitError("incorrect region argument type");
|
||||
}
|
||||
unsigned yieldCount = 0;
|
||||
|
@ -145,7 +145,7 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
|
|||
if (auto yield = dyn_cast<gpu::YieldOp>(block.getTerminator())) {
|
||||
if (yield.getNumOperands() != 1)
|
||||
return allReduce.emitError("expected one gpu.yield operand");
|
||||
if (yield.getOperand(0)->getType() != allReduce.getType())
|
||||
if (yield.getOperand(0).getType() != allReduce.getType())
|
||||
return allReduce.emitError("incorrect gpu.yield type");
|
||||
++yieldCount;
|
||||
}
|
||||
|
@ -157,8 +157,8 @@ static LogicalResult verifyAllReduce(gpu::AllReduceOp allReduce) {
|
|||
}
|
||||
|
||||
static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
|
||||
auto type = shuffleOp.value()->getType();
|
||||
if (shuffleOp.result()->getType() != type) {
|
||||
auto type = shuffleOp.value().getType();
|
||||
if (shuffleOp.result().getType() != type) {
|
||||
return shuffleOp.emitOpError()
|
||||
<< "requires the same type for value operand and result";
|
||||
}
|
||||
|
@ -170,10 +170,8 @@ static LogicalResult verifyShuffleOp(gpu::ShuffleOp shuffleOp) {
|
|||
}
|
||||
|
||||
static void printShuffleOp(OpAsmPrinter &p, ShuffleOp op) {
|
||||
p << ShuffleOp::getOperationName() << ' ';
|
||||
p.printOperands(op.getOperands());
|
||||
p << ' ' << op.mode() << " : ";
|
||||
p.printType(op.value()->getType());
|
||||
p << ShuffleOp::getOperationName() << ' ' << op.getOperands() << ' '
|
||||
<< op.mode() << " : " << op.value().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
|
||||
|
@ -201,14 +199,6 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &state) {
|
|||
// LaunchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static SmallVector<Type, 4> getValueTypes(ValueRange values) {
|
||||
SmallVector<Type, 4> types;
|
||||
types.reserve(values.size());
|
||||
for (Value v : values)
|
||||
types.push_back(v->getType());
|
||||
return types;
|
||||
}
|
||||
|
||||
void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX,
|
||||
Value gridSizeY, Value gridSizeZ, Value blockSizeX,
|
||||
Value blockSizeY, Value blockSizeZ, ValueRange operands) {
|
||||
|
@ -224,7 +214,7 @@ void LaunchOp::build(Builder *builder, OperationState &result, Value gridSizeX,
|
|||
Block *body = new Block();
|
||||
body->addArguments(
|
||||
std::vector<Type>(kNumConfigRegionAttributes, builder->getIndexType()));
|
||||
body->addArguments(getValueTypes(operands));
|
||||
body->addArguments(llvm::to_vector<4>(operands.getTypes()));
|
||||
kernelRegion->push_back(body);
|
||||
}
|
||||
|
||||
|
@ -309,10 +299,10 @@ LogicalResult verify(LaunchOp op) {
|
|||
// where %size-* and %iter-* will correspond to the body region arguments.
|
||||
static void printSizeAssignment(OpAsmPrinter &p, KernelDim3 size,
|
||||
ValueRange operands, KernelDim3 ids) {
|
||||
p << '(' << *ids.x << ", " << *ids.y << ", " << *ids.z << ") in (";
|
||||
p << *size.x << " = " << *operands[0] << ", ";
|
||||
p << *size.y << " = " << *operands[1] << ", ";
|
||||
p << *size.z << " = " << *operands[2] << ')';
|
||||
p << '(' << ids.x << ", " << ids.y << ", " << ids.z << ") in (";
|
||||
p << size.x << " = " << operands[0] << ", ";
|
||||
p << size.y << " = " << operands[1] << ", ";
|
||||
p << size.z << " = " << operands[2] << ')';
|
||||
}
|
||||
|
||||
void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
|
||||
|
@ -335,8 +325,8 @@ void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
|
|||
p << ' ' << op.getArgsKeyword() << '(';
|
||||
Block *entryBlock = &op.body().front();
|
||||
interleaveComma(llvm::seq<int>(0, operands.size()), p, [&](int i) {
|
||||
p << *entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i)
|
||||
<< " = " << *operands[i];
|
||||
p << entryBlock->getArgument(LaunchOp::kNumConfigRegionAttributes + i)
|
||||
<< " = " << operands[i];
|
||||
});
|
||||
p << ") ";
|
||||
}
|
||||
|
@ -486,14 +476,14 @@ class PropagateConstantBounds : public OpRewritePattern<LaunchOp> {
|
|||
for (unsigned i = operands.size(); i > 0; --i) {
|
||||
unsigned index = i - 1;
|
||||
Value operand = operands[index];
|
||||
if (!isa_and_nonnull<ConstantOp>(operand->getDefiningOp()))
|
||||
if (!isa_and_nonnull<ConstantOp>(operand.getDefiningOp()))
|
||||
continue;
|
||||
|
||||
found = true;
|
||||
Value internalConstant =
|
||||
rewriter.clone(*operand->getDefiningOp())->getResult(0);
|
||||
rewriter.clone(*operand.getDefiningOp())->getResult(0);
|
||||
Value kernelArg = *std::next(kernelArgs.begin(), index);
|
||||
kernelArg->replaceAllUsesWith(internalConstant);
|
||||
kernelArg.replaceAllUsesWith(internalConstant);
|
||||
launchOp.eraseKernelArgument(index);
|
||||
}
|
||||
|
||||
|
@ -740,7 +730,7 @@ static void printAttributions(OpAsmPrinter &p, StringRef keyword,
|
|||
|
||||
p << ' ' << keyword << '(';
|
||||
interleaveComma(values, p,
|
||||
[&p](BlockArgument v) { p << *v << " : " << v->getType(); });
|
||||
[&p](BlockArgument v) { p << v << " : " << v.getType(); });
|
||||
p << ')';
|
||||
}
|
||||
|
||||
|
@ -790,7 +780,7 @@ static LogicalResult verifyAttributions(Operation *op,
|
|||
ArrayRef<BlockArgument> attributions,
|
||||
unsigned memorySpace) {
|
||||
for (Value v : attributions) {
|
||||
auto type = v->getType().dyn_cast<MemRefType>();
|
||||
auto type = v.getType().dyn_cast<MemRefType>();
|
||||
if (!type)
|
||||
return op->emitOpError() << "expected memref type in attribution";
|
||||
|
||||
|
@ -814,7 +804,7 @@ LogicalResult GPUFuncOp::verifyBody() {
|
|||
|
||||
ArrayRef<Type> funcArgTypes = getType().getInputs();
|
||||
for (unsigned i = 0; i < numFuncArguments; ++i) {
|
||||
Type blockArgType = front().getArgument(i)->getType();
|
||||
Type blockArgType = front().getArgument(i).getType();
|
||||
if (funcArgTypes[i] != blockArgType)
|
||||
return emitOpError() << "expected body region argument #" << i
|
||||
<< " to be of type " << funcArgTypes[i] << ", got "
|
||||
|
|
|
@ -45,7 +45,7 @@ static void injectGpuIndexOperations(Location loc, Region &body) {
|
|||
// Replace the leading 12 function args with the respective thread/block index
|
||||
// operations. Iterate backwards since args are erased and indices change.
|
||||
for (int i = 11; i >= 0; --i) {
|
||||
firstBlock.getArgument(i)->replaceAllUsesWith(indexOps[i]);
|
||||
firstBlock.getArgument(i).replaceAllUsesWith(indexOps[i]);
|
||||
firstBlock.eraseArgument(i);
|
||||
}
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc,
|
|||
map.map(launch.getKernelOperand(i), kernelFunc.getArgument(i));
|
||||
}
|
||||
for (int i = launch.getNumKernelOperands() - 1; i >= 0; --i) {
|
||||
auto operandOp = launch.getKernelOperand(i)->getDefiningOp();
|
||||
auto operandOp = launch.getKernelOperand(i).getDefiningOp();
|
||||
if (!operandOp || !isInliningBeneficiary(operandOp)) {
|
||||
newLaunchArgs.push_back(launch.getKernelOperand(i));
|
||||
continue;
|
||||
|
@ -77,7 +77,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc,
|
|||
continue;
|
||||
}
|
||||
auto clone = kernelBuilder.clone(*operandOp, map);
|
||||
firstBlock.getArgument(i)->replaceAllUsesWith(clone->getResult(0));
|
||||
firstBlock.getArgument(i).replaceAllUsesWith(clone->getResult(0));
|
||||
firstBlock.eraseArgument(i);
|
||||
}
|
||||
if (newLaunchArgs.size() == launch.getNumKernelOperands())
|
||||
|
@ -88,7 +88,7 @@ static gpu::LaunchFuncOp inlineBeneficiaryOps(gpu::GPUFuncOp kernelFunc,
|
|||
SmallVector<Type, 8> newArgumentTypes;
|
||||
newArgumentTypes.reserve(firstBlock.getNumArguments());
|
||||
for (auto value : firstBlock.getArguments()) {
|
||||
newArgumentTypes.push_back(value->getType());
|
||||
newArgumentTypes.push_back(value.getType());
|
||||
}
|
||||
kernelFunc.setType(LaunchBuilder.getFunctionType(newArgumentTypes, {}));
|
||||
auto newLaunch = LaunchBuilder.create<gpu::LaunchFuncOp>(
|
||||
|
|
|
@ -35,16 +35,16 @@ using namespace mlir::LLVM;
|
|||
//===----------------------------------------------------------------------===//
|
||||
static void printICmpOp(OpAsmPrinter &p, ICmpOp &op) {
|
||||
p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
|
||||
<< "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
|
||||
<< "\" " << op.getOperand(0) << ", " << op.getOperand(1);
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
|
||||
p << " : " << op.lhs()->getType();
|
||||
p << " : " << op.lhs().getType();
|
||||
}
|
||||
|
||||
static void printFCmpOp(OpAsmPrinter &p, FCmpOp &op) {
|
||||
p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
|
||||
<< "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
|
||||
<< "\" " << op.getOperand(0) << ", " << op.getOperand(1);
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"predicate"});
|
||||
p << " : " << op.lhs()->getType();
|
||||
p << " : " << op.lhs().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
|
||||
|
@ -120,10 +120,10 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
|
|||
static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
|
||||
auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
|
||||
|
||||
auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()},
|
||||
auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()},
|
||||
op.getContext());
|
||||
|
||||
p << op.getOperationName() << ' ' << *op.arraySize() << " x " << elemTy;
|
||||
p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
|
||||
if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0)
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
else
|
||||
|
@ -168,7 +168,7 @@ static void printGEPOp(OpAsmPrinter &p, GEPOp &op) {
|
|||
SmallVector<Type, 8> types(op.getOperandTypes());
|
||||
auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
|
||||
|
||||
p << op.getOperationName() << ' ' << *op.base() << '['
|
||||
p << op.getOperationName() << ' ' << op.base() << '['
|
||||
<< op.getOperands().drop_front() << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << funcTy;
|
||||
|
@ -212,9 +212,9 @@ static ParseResult parseGEPOp(OpAsmParser &parser, OperationState &result) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printLoadOp(OpAsmPrinter &p, LoadOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.addr();
|
||||
p << op.getOperationName() << ' ' << op.addr();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.addr()->getType();
|
||||
p << " : " << op.addr().getType();
|
||||
}
|
||||
|
||||
// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
|
||||
|
@ -256,9 +256,9 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printStoreOp(OpAsmPrinter &p, StoreOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.value() << ", " << *op.addr();
|
||||
p << op.getOperationName() << ' ' << op.value() << ", " << op.addr();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.addr()->getType();
|
||||
p << " : " << op.addr().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
|
||||
|
@ -300,7 +300,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
|
|||
if (isDirect)
|
||||
p.printSymbolName(callee.getValue());
|
||||
else
|
||||
p << *op.getOperand(0);
|
||||
p << op.getOperand(0);
|
||||
|
||||
p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
|
||||
|
@ -408,17 +408,17 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
|
|||
void LLVM::ExtractElementOp::build(Builder *b, OperationState &result,
|
||||
Value vector, Value position,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
auto wrappedVectorType = vector->getType().cast<LLVM::LLVMType>();
|
||||
auto wrappedVectorType = vector.getType().cast<LLVM::LLVMType>();
|
||||
auto llvmType = wrappedVectorType.getVectorElementType();
|
||||
build(b, result, llvmType, vector, position);
|
||||
result.addAttributes(attrs);
|
||||
}
|
||||
|
||||
static void printExtractElementOp(OpAsmPrinter &p, ExtractElementOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.vector() << "[" << *op.position()
|
||||
<< " : " << op.position()->getType() << "]";
|
||||
p << op.getOperationName() << ' ' << op.vector() << "[" << op.position()
|
||||
<< " : " << op.position().getType() << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.vector()->getType();
|
||||
p << " : " << op.vector().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
|
||||
|
@ -450,9 +450,9 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.container() << op.position();
|
||||
p << op.getOperationName() << ' ' << op.container() << op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"position"});
|
||||
p << " : " << op.container()->getType();
|
||||
p << " : " << op.container().getType();
|
||||
}
|
||||
|
||||
// Extract the type at `position` in the wrapped LLVM IR aggregate type
|
||||
|
@ -542,10 +542,10 @@ static ParseResult parseExtractValueOp(OpAsmParser &parser,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printInsertElementOp(OpAsmPrinter &p, InsertElementOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.value() << ", " << *op.vector()
|
||||
<< "[" << *op.position() << " : " << op.position()->getType() << "]";
|
||||
p << op.getOperationName() << ' ' << op.value() << ", " << op.vector() << "["
|
||||
<< op.position() << " : " << op.position().getType() << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.vector()->getType();
|
||||
p << " : " << op.vector().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
|
||||
|
@ -586,10 +586,10 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printInsertValueOp(OpAsmPrinter &p, InsertValueOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.value() << ", " << *op.container()
|
||||
p << op.getOperationName() << ' ' << op.value() << ", " << op.container()
|
||||
<< op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"position"});
|
||||
p << " : " << op.container()->getType();
|
||||
p << " : " << op.container().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
|
||||
|
@ -629,10 +629,10 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printSelectOp(OpAsmPrinter &p, SelectOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.condition() << ", "
|
||||
<< *op.trueValue() << ", " << *op.falseValue();
|
||||
p << op.getOperationName() << ' ' << op.condition() << ", " << op.trueValue()
|
||||
<< ", " << op.falseValue();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.condition()->getType() << ", " << op.trueValue()->getType();
|
||||
p << " : " << op.condition().getType() << ", " << op.trueValue().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use
|
||||
|
@ -686,7 +686,7 @@ static ParseResult parseBrOp(OpAsmParser &parser, OperationState &result) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printCondBrOp(OpAsmPrinter &p, CondBrOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.getOperand(0) << ", ";
|
||||
p << op.getOperationName() << ' ' << op.getOperand(0) << ", ";
|
||||
p.printSuccessorAndUseList(op.getOperation(), 0);
|
||||
p << ", ";
|
||||
p.printSuccessorAndUseList(op.getOperation(), 1);
|
||||
|
@ -733,7 +733,7 @@ static void printReturnOp(OpAsmPrinter &p, ReturnOp &op) {
|
|||
if (op.getNumOperands() == 0)
|
||||
return;
|
||||
|
||||
p << ' ' << *op.getOperand(0) << " : " << op.getOperand(0)->getType();
|
||||
p << ' ' << op.getOperand(0) << " : " << op.getOperand(0).getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
|
||||
|
@ -761,7 +761,7 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
|
|||
static void printUndefOp(OpAsmPrinter &p, UndefOp &op) {
|
||||
p << op.getOperationName();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.res()->getType();
|
||||
p << " : " << op.res().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.mlir.undef` attribute-dict? : type
|
||||
|
@ -792,7 +792,7 @@ GlobalOp AddressOfOp::getGlobal() {
|
|||
static void printAddressOfOp(OpAsmPrinter &p, AddressOfOp op) {
|
||||
p << op.getOperationName() << " @" << op.global_name();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"global_name"});
|
||||
p << " : " << op.getResult()->getType();
|
||||
p << " : " << op.getResult().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseAddressOfOp(OpAsmParser &parser,
|
||||
|
@ -816,7 +816,7 @@ static LogicalResult verify(AddressOfOp op) {
|
|||
"must reference a global defined by 'llvm.mlir.global'");
|
||||
|
||||
if (global.getType().getPointerTo(global.addr_space().getZExtValue()) !=
|
||||
op.getResult()->getType())
|
||||
op.getResult().getType())
|
||||
return op.emitOpError(
|
||||
"the type must be a pointer to the type of the referred global");
|
||||
|
||||
|
@ -830,7 +830,7 @@ static LogicalResult verify(AddressOfOp op) {
|
|||
static void printConstantOp(OpAsmPrinter &p, ConstantOp &op) {
|
||||
p << op.getOperationName() << '(' << op.value() << ')';
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"value"});
|
||||
p << " : " << op.res()->getType();
|
||||
p << " : " << op.res().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.mlir.constant` `(` attribute `)` attribute-list? : type
|
||||
|
@ -1060,7 +1060,7 @@ static LogicalResult verify(GlobalOp op) {
|
|||
void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value v1,
|
||||
Value v2, ArrayAttr mask,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
auto wrappedContainerType1 = v1->getType().cast<LLVM::LLVMType>();
|
||||
auto wrappedContainerType1 = v1.getType().cast<LLVM::LLVMType>();
|
||||
auto vType = LLVMType::getVectorTy(
|
||||
wrappedContainerType1.getVectorElementType(), mask.size());
|
||||
build(b, result, vType, v1, v2, mask);
|
||||
|
@ -1068,10 +1068,10 @@ void LLVM::ShuffleVectorOp::build(Builder *b, OperationState &result, Value v1,
|
|||
}
|
||||
|
||||
static void printShuffleVectorOp(OpAsmPrinter &p, ShuffleVectorOp &op) {
|
||||
p << op.getOperationName() << ' ' << *op.v1() << ", " << *op.v2() << " "
|
||||
p << op.getOperationName() << ' ' << op.v1() << ", " << op.v2() << " "
|
||||
<< op.mask();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"mask"});
|
||||
p << " : " << op.v1()->getType() << ", " << op.v2()->getType();
|
||||
p << " : " << op.v1().getType() << ", " << op.v2().getType();
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
|
||||
|
@ -1329,7 +1329,7 @@ static LogicalResult verify(LLVMFuncOp op) {
|
|||
unsigned numArguments = funcType->getNumParams();
|
||||
Block &entryBlock = op.front();
|
||||
for (unsigned i = 0; i < numArguments; ++i) {
|
||||
Type argType = entryBlock.getArgument(i)->getType();
|
||||
Type argType = entryBlock.getArgument(i).getType();
|
||||
auto argLLVMType = argType.dyn_cast<LLVMType>();
|
||||
if (!argLLVMType)
|
||||
return op.emitOpError("entry block argument #")
|
||||
|
|
|
@ -48,30 +48,30 @@ Value Aliases::find(Value v) {
|
|||
|
||||
auto it = aliases.find(v);
|
||||
if (it != aliases.end()) {
|
||||
assert(it->getSecond()->getType().isa<MemRefType>() && "Memref expected");
|
||||
assert(it->getSecond().getType().isa<MemRefType>() && "Memref expected");
|
||||
return it->getSecond();
|
||||
}
|
||||
|
||||
while (true) {
|
||||
if (v.isa<BlockArgument>())
|
||||
return v;
|
||||
if (auto alloc = dyn_cast_or_null<AllocOp>(v->getDefiningOp())) {
|
||||
if (auto alloc = dyn_cast_or_null<AllocOp>(v.getDefiningOp())) {
|
||||
if (isStrided(alloc.getType()))
|
||||
return alloc.getResult();
|
||||
}
|
||||
if (auto slice = dyn_cast_or_null<SliceOp>(v->getDefiningOp())) {
|
||||
if (auto slice = dyn_cast_or_null<SliceOp>(v.getDefiningOp())) {
|
||||
auto it = aliases.insert(std::make_pair(v, find(slice.view())));
|
||||
return it.first->second;
|
||||
}
|
||||
if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
|
||||
if (auto view = dyn_cast_or_null<ViewOp>(v.getDefiningOp())) {
|
||||
auto it = aliases.insert(std::make_pair(v, view.source()));
|
||||
return it.first->second;
|
||||
}
|
||||
if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {
|
||||
if (auto view = dyn_cast_or_null<SubViewOp>(v.getDefiningOp())) {
|
||||
v = view.source();
|
||||
continue;
|
||||
}
|
||||
llvm::errs() << "View alias analysis reduces to: " << *v << "\n";
|
||||
llvm::errs() << "View alias analysis reduces to: " << v << "\n";
|
||||
llvm_unreachable("unsupported view alias case");
|
||||
}
|
||||
}
|
||||
|
@ -224,7 +224,7 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
|
|||
auto *op = dependence.dependentOpView.op;
|
||||
LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
|
||||
<< toStringRef(dt) << ": " << *src << " -> " << *op
|
||||
<< " on " << *dependence.indexingView);
|
||||
<< " on " << dependence.indexingView);
|
||||
res.push_back(op);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,7 +88,7 @@ Operation *mlir::edsc::makeGenericLinalgOp(
|
|||
for (auto it : llvm::enumerate(values))
|
||||
blockTypes.push_back((it.index() < nViews)
|
||||
? getElementTypeOrSelf(it.value())
|
||||
: it.value()->getType());
|
||||
: it.value().getType());
|
||||
|
||||
assert(op->getRegions().front().empty());
|
||||
op->getRegions().front().push_front(new Block);
|
||||
|
|
|
@ -120,7 +120,7 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
|
|||
|
||||
for (unsigned i = 0; i < nViews; ++i) {
|
||||
auto viewType = op.getShapedType(i);
|
||||
if (viewType.getElementType() != block.getArgument(i)->getType())
|
||||
if (viewType.getElementType() != block.getArgument(i).getType())
|
||||
return op.emitOpError("expected block argument ")
|
||||
<< i << " of the same type as elemental type of "
|
||||
<< ((i < nInputViews) ? "input " : "output ")
|
||||
|
@ -139,7 +139,7 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
|
|||
"number of loops");
|
||||
|
||||
for (unsigned i = 0; i < nLoops; ++i) {
|
||||
if (!block.getArgument(i)->getType().isIndex())
|
||||
if (!block.getArgument(i).getType().isIndex())
|
||||
return op.emitOpError("expected block argument ")
|
||||
<< i << " to be of IndexType";
|
||||
}
|
||||
|
@ -148,7 +148,7 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
|
|||
unsigned memrefArgIndex = i + nLoops;
|
||||
auto viewType = op.getShapedType(i);
|
||||
if (viewType.getElementType() !=
|
||||
block.getArgument(memrefArgIndex)->getType())
|
||||
block.getArgument(memrefArgIndex).getType())
|
||||
return op.emitOpError("expected block argument ")
|
||||
<< memrefArgIndex << " of the same type as elemental type of "
|
||||
<< ((i < nInputViews) ? "input " : "output ")
|
||||
|
@ -314,10 +314,10 @@ static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); }
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, RangeOp op) {
|
||||
p << op.getOperationName() << " " << *op.min() << ":" << *op.max() << ":"
|
||||
<< *op.step();
|
||||
p << op.getOperationName() << " " << op.min() << ":" << op.max() << ":"
|
||||
<< op.step();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getResult()->getType();
|
||||
p << " : " << op.getResult().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
@ -541,7 +541,7 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
|
|||
result.addOperands(base);
|
||||
result.addOperands(indexings);
|
||||
|
||||
auto memRefType = base->getType().cast<MemRefType>();
|
||||
auto memRefType = base.getType().cast<MemRefType>();
|
||||
int64_t offset;
|
||||
SmallVector<int64_t, 4> strides;
|
||||
auto res = getStridesAndOffset(memRefType, strides, offset);
|
||||
|
@ -560,7 +560,7 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
|
|||
|
||||
static void print(OpAsmPrinter &p, SliceOp op) {
|
||||
auto indexings = op.indexings();
|
||||
p << SliceOp::getOperationName() << " " << *op.view() << "[" << indexings
|
||||
p << SliceOp::getOperationName() << " " << op.view() << "[" << indexings
|
||||
<< "] ";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getBaseViewType();
|
||||
|
@ -599,7 +599,7 @@ static LogicalResult verify(SliceOp op) {
|
|||
<< rank << " indexings, got " << llvm::size(op.indexings());
|
||||
unsigned index = 0;
|
||||
for (auto indexing : op.indexings()) {
|
||||
if (indexing->getType().isa<IndexType>())
|
||||
if (indexing.getType().isa<IndexType>())
|
||||
--rank;
|
||||
++index;
|
||||
}
|
||||
|
@ -618,7 +618,7 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
|
|||
auto permutationMap = permutation.getValue();
|
||||
assert(permutationMap);
|
||||
|
||||
auto memRefType = view->getType().cast<MemRefType>();
|
||||
auto memRefType = view.getType().cast<MemRefType>();
|
||||
auto rank = memRefType.getRank();
|
||||
auto originalSizes = memRefType.getShape();
|
||||
// Compute permuted sizes.
|
||||
|
@ -644,10 +644,10 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
|
|||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, TransposeOp op) {
|
||||
p << op.getOperationName() << " " << *op.view() << " " << op.permutation();
|
||||
p << op.getOperationName() << " " << op.view() << " " << op.permutation();
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
{TransposeOp::getPermutationAttrName()});
|
||||
p << " : " << op.view()->getType();
|
||||
p << " : " << op.view().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseTransposeOp(OpAsmParser &parser,
|
||||
|
@ -698,9 +698,9 @@ LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) {
|
|||
|
||||
for (unsigned i = 0; i != nOutputViews; ++i) {
|
||||
auto elementType = genericOp.getOutputShapedType(i).getElementType();
|
||||
if (op.getOperand(i)->getType() != elementType)
|
||||
if (op.getOperand(i).getType() != elementType)
|
||||
return op.emitOpError("type of return operand ")
|
||||
<< i << " (" << op.getOperand(i)->getType()
|
||||
<< i << " (" << op.getOperand(i).getType()
|
||||
<< ") doesn't match view element type (" << elementType << ")";
|
||||
}
|
||||
return success();
|
||||
|
@ -765,7 +765,7 @@ static ParseResult parseLinalgStructuredOp(OpAsmParser &parser,
|
|||
|
||||
static LogicalResult verify(FillOp op) {
|
||||
auto viewType = op.getOutputShapedType(0);
|
||||
auto fillType = op.value()->getType();
|
||||
auto fillType = op.value().getType();
|
||||
if (viewType.getElementType() != fillType)
|
||||
return op.emitOpError("expects fill type to match view elemental type");
|
||||
return success();
|
||||
|
@ -816,9 +816,9 @@ verifyStrideOrDilation(ConvOp op, ArrayRef<Attribute> attrs, bool isStride) {
|
|||
}
|
||||
|
||||
static LogicalResult verify(ConvOp op) {
|
||||
auto oType = op.output()->getType().cast<MemRefType>();
|
||||
auto fType = op.filter()->getType().cast<MemRefType>();
|
||||
auto iType = op.input()->getType().cast<MemRefType>();
|
||||
auto oType = op.output().getType().cast<MemRefType>();
|
||||
auto fType = op.filter().getType().cast<MemRefType>();
|
||||
auto iType = op.input().getType().cast<MemRefType>();
|
||||
if (oType.getElementType() != iType.getElementType() ||
|
||||
oType.getElementType() != fType.getElementType())
|
||||
return op.emitOpError("expects memref elemental types to match");
|
||||
|
|
|
@ -133,8 +133,7 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
|
|||
if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
|
||||
LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
|
||||
<< "\n");
|
||||
LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view
|
||||
<< "\n");
|
||||
LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
|
||||
return ViewDimension{view, static_cast<unsigned>(en2.index())};
|
||||
}
|
||||
}
|
||||
|
@ -146,9 +145,9 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
|
|||
unsigned consumerIdx, unsigned producerIdx,
|
||||
OperationFolder *folder) {
|
||||
auto subView = dyn_cast_or_null<SubViewOp>(
|
||||
consumer.getInput(consumerIdx)->getDefiningOp());
|
||||
auto slice = dyn_cast_or_null<SliceOp>(
|
||||
consumer.getInput(consumerIdx)->getDefiningOp());
|
||||
consumer.getInput(consumerIdx).getDefiningOp());
|
||||
auto slice =
|
||||
dyn_cast_or_null<SliceOp>(consumer.getInput(consumerIdx).getDefiningOp());
|
||||
assert(subView || slice);
|
||||
(void)subView;
|
||||
(void)slice;
|
||||
|
@ -272,13 +271,13 @@ Optional<FusionInfo> mlir::linalg::fuseProducerOf(
|
|||
auto producerIdx = producer.getIndexOfOutput(producedView).getValue();
|
||||
// `consumerIdx` and `producerIdx` exist by construction.
|
||||
LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
|
||||
<< " view: " << *producedView
|
||||
<< " view: " << producedView
|
||||
<< " output index: " << producerIdx);
|
||||
|
||||
// Must be a subview or a slice to guarantee there are loops we can fuse
|
||||
// into.
|
||||
auto subView = dyn_cast_or_null<SubViewOp>(consumedView->getDefiningOp());
|
||||
auto slice = dyn_cast_or_null<SliceOp>(consumedView->getDefiningOp());
|
||||
auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
|
||||
auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
|
||||
if (!subView && !slice) {
|
||||
LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
|
||||
continue;
|
||||
|
|
|
@ -166,7 +166,7 @@ mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) {
|
|||
|
||||
// TODO(ntv): non-identity layout.
|
||||
auto isStaticMemRefWithIdentityLayout = [](Value v) {
|
||||
auto m = v->getType().dyn_cast<MemRefType>();
|
||||
auto m = v.getType().dyn_cast<MemRefType>();
|
||||
if (!m || !m.hasStaticShape() || !m.getAffineMaps().empty())
|
||||
return false;
|
||||
return true;
|
||||
|
@ -281,7 +281,7 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter,
|
|||
LinalgOp linOp = cast<LinalgOp>(op);
|
||||
SetVector<Value> subViews;
|
||||
for (auto it : linOp.getInputsAndOutputs())
|
||||
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
|
||||
if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
|
||||
subViews.insert(sv);
|
||||
if (!subViews.empty()) {
|
||||
promoteSubViewOperands(rewriter, linOp, subViews);
|
||||
|
|
|
@ -47,10 +47,10 @@ static llvm::cl::opt<bool> clPromoteDynamic(
|
|||
llvm::cl::cat(clOptionsCategory), llvm::cl::init(false));
|
||||
|
||||
static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers) {
|
||||
auto *ctx = size->getContext();
|
||||
auto *ctx = size.getContext();
|
||||
auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
|
||||
if (!dynamicBuffers)
|
||||
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size->getDefiningOp()))
|
||||
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
|
||||
return alloc(
|
||||
MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)));
|
||||
Value mul = muli(constant_index(width), size);
|
||||
|
@ -116,7 +116,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
|
|||
res.reserve(subViews.size());
|
||||
DenseMap<Value, PromotionInfo> promotionInfoMap;
|
||||
for (auto v : subViews) {
|
||||
SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
|
||||
SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
|
||||
auto viewType = subView.getType();
|
||||
// TODO(ntv): support more cases than just float.
|
||||
if (!viewType.getElementType().isa<FloatType>())
|
||||
|
@ -128,7 +128,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
|
|||
}
|
||||
|
||||
for (auto v : subViews) {
|
||||
SubViewOp subView = cast<SubViewOp>(v->getDefiningOp());
|
||||
SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
|
||||
auto info = promotionInfoMap.find(v);
|
||||
if (info == promotionInfoMap.end())
|
||||
continue;
|
||||
|
@ -146,7 +146,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
|
|||
auto info = promotionInfoMap.find(v);
|
||||
if (info == promotionInfoMap.end())
|
||||
continue;
|
||||
copy(cast<SubViewOp>(v->getDefiningOp()), info->second.partialLocalView);
|
||||
copy(cast<SubViewOp>(v.getDefiningOp()), info->second.partialLocalView);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
@ -208,7 +208,7 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
|
|||
SetVector<Value> subViews;
|
||||
OpBuilder b(op);
|
||||
for (auto it : op.getInputsAndOutputs())
|
||||
if (auto sv = dyn_cast_or_null<SubViewOp>(it->getDefiningOp()))
|
||||
if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
|
||||
subViews.insert(sv);
|
||||
if (!subViews.empty()) {
|
||||
promoteSubViewOperands(b, op, subViews, dynamicBuffers, &folder);
|
||||
|
|
|
@ -45,8 +45,8 @@ static llvm::cl::list<unsigned>
|
|||
llvm::cl::cat(clOptionsCategory));
|
||||
|
||||
static bool isZero(Value v) {
|
||||
return isa_and_nonnull<ConstantIndexOp>(v->getDefiningOp()) &&
|
||||
cast<ConstantIndexOp>(v->getDefiningOp()).getValue() == 0;
|
||||
return isa_and_nonnull<ConstantIndexOp>(v.getDefiningOp()) &&
|
||||
cast<ConstantIndexOp>(v.getDefiningOp()).getValue() == 0;
|
||||
}
|
||||
|
||||
using LoopIndexToRangeIndexMap = DenseMap<int, int>;
|
||||
|
@ -201,8 +201,8 @@ void transformIndexedGenericOpIndices(
|
|||
// variable and replace all uses of the previous value.
|
||||
Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
|
||||
pivs[rangeIndex->second]->getValue());
|
||||
for (auto &use : oldIndex->getUses()) {
|
||||
if (use.getOwner() == newIndex->getDefiningOp())
|
||||
for (auto &use : oldIndex.getUses()) {
|
||||
if (use.getOwner() == newIndex.getDefiningOp())
|
||||
continue;
|
||||
use.set(newIndex);
|
||||
}
|
||||
|
@ -258,7 +258,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
|
|||
for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs();
|
||||
++viewIndex) {
|
||||
Value view = *(viewIteratorBegin + viewIndex);
|
||||
unsigned rank = view->getType().cast<MemRefType>().getRank();
|
||||
unsigned rank = view.getType().cast<MemRefType>().getRank();
|
||||
auto map = loopToOperandRangesMaps(linalgOp)[viewIndex];
|
||||
// If the view is not tiled, we can use it as is.
|
||||
if (!isTiled(map, tileSizes)) {
|
||||
|
@ -299,8 +299,8 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
|
|||
// defined.
|
||||
if (folder)
|
||||
for (auto v : llvm::concat<Value>(lbs, subViewSizes))
|
||||
if (v->use_empty())
|
||||
v->getDefiningOp()->erase();
|
||||
if (v.use_empty())
|
||||
v.getDefiningOp()->erase();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
|
@ -35,9 +35,9 @@ using namespace mlir::loop;
|
|||
mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv,
|
||||
ValueHandle range) {
|
||||
assert(range.getType() && "expected !linalg.range type");
|
||||
assert(range.getValue()->getDefiningOp() &&
|
||||
assert(range.getValue().getDefiningOp() &&
|
||||
"need operations to extract range parts");
|
||||
auto rangeOp = cast<RangeOp>(range.getValue()->getDefiningOp());
|
||||
auto rangeOp = cast<RangeOp>(range.getValue().getDefiningOp());
|
||||
auto lb = rangeOp.min();
|
||||
auto ub = rangeOp.max();
|
||||
auto step = rangeOp.step();
|
||||
|
@ -168,7 +168,7 @@ mlir::linalg::getAssumedNonViewOperands(LinalgOp linalgOp) {
|
|||
res.reserve(nOperands);
|
||||
for (unsigned i = 0; i < nOperands; ++i) {
|
||||
res.push_back(op->getOperand(numViews + i));
|
||||
auto t = res.back()->getType();
|
||||
auto t = res.back().getType();
|
||||
(void)t;
|
||||
assert((t.isIntOrIndexOrFloat() || t.isa<VectorType>()) &&
|
||||
"expected scalar or vector type");
|
||||
|
|
|
@ -69,23 +69,22 @@ void ForOp::build(Builder *builder, OperationState &result, Value lb, Value ub,
|
|||
}
|
||||
|
||||
LogicalResult verify(ForOp op) {
|
||||
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step()->getDefiningOp()))
|
||||
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp()))
|
||||
if (cst.getValue() <= 0)
|
||||
return op.emitOpError("constant step operand must be positive");
|
||||
|
||||
// Check that the body defines as single block argument for the induction
|
||||
// variable.
|
||||
auto *body = op.getBody();
|
||||
if (body->getNumArguments() != 1 ||
|
||||
!body->getArgument(0)->getType().isIndex())
|
||||
if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex())
|
||||
return op.emitOpError("expected body to have a single index argument for "
|
||||
"the induction variable");
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, ForOp op) {
|
||||
p << op.getOperationName() << " " << *op.getInductionVar() << " = "
|
||||
<< *op.lowerBound() << " to " << *op.upperBound() << " step " << *op.step();
|
||||
p << op.getOperationName() << " " << op.getInductionVar() << " = "
|
||||
<< op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
|
||||
p.printRegion(op.region(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
|
@ -126,11 +125,11 @@ static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) {
|
|||
Region &ForOp::getLoopBody() { return region(); }
|
||||
|
||||
bool ForOp::isDefinedOutsideOfLoop(Value value) {
|
||||
return !region().isAncestor(value->getParentRegion());
|
||||
return !region().isAncestor(value.getParentRegion());
|
||||
}
|
||||
|
||||
LogicalResult ForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
|
||||
for (auto *op : ops)
|
||||
for (auto op : ops)
|
||||
op->moveBefore(this->getOperation());
|
||||
return success();
|
||||
}
|
||||
|
@ -139,8 +138,8 @@ ForOp mlir::loop::getForInductionVarOwner(Value val) {
|
|||
auto ivArg = val.dyn_cast<BlockArgument>();
|
||||
if (!ivArg)
|
||||
return ForOp();
|
||||
assert(ivArg->getOwner() && "unlinked block argument");
|
||||
auto *containingInst = ivArg->getOwner()->getParentOp();
|
||||
assert(ivArg.getOwner() && "unlinked block argument");
|
||||
auto *containingInst = ivArg.getOwner()->getParentOp();
|
||||
return dyn_cast_or_null<ForOp>(containingInst);
|
||||
}
|
||||
|
||||
|
@ -205,7 +204,7 @@ static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, IfOp op) {
|
||||
p << IfOp::getOperationName() << " " << *op.condition();
|
||||
p << IfOp::getOperationName() << " " << op.condition();
|
||||
p.printRegion(op.thenRegion(),
|
||||
/*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
|
|
|
@ -36,8 +36,8 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context)
|
|||
OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
|
||||
/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
|
||||
/// value of x if the casts invert each other.
|
||||
auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg()->getDefiningOp());
|
||||
if (!srcScastOp || srcScastOp.arg()->getType() != getType())
|
||||
auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg().getDefiningOp());
|
||||
if (!srcScastOp || srcScastOp.arg().getType() != getType())
|
||||
return OpFoldResult();
|
||||
return srcScastOp.arg();
|
||||
}
|
||||
|
|
|
@ -52,7 +52,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
|
|||
// Does the qbarrier convert to a quantized type. This will not be true
|
||||
// if a quantized type has not yet been chosen or if the cast to an equivalent
|
||||
// storage type is not supported.
|
||||
Type qbarrierResultType = qbarrier.getResult()->getType();
|
||||
Type qbarrierResultType = qbarrier.getResult().getType();
|
||||
QuantizedType quantizedElementType =
|
||||
QuantizedType::getQuantizedElementType(qbarrierResultType);
|
||||
if (!quantizedElementType) {
|
||||
|
@ -66,7 +66,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
|
|||
// type? This will not be true if the qbarrier is superfluous (converts
|
||||
// from and to a quantized type).
|
||||
if (!quantizedElementType.isCompatibleExpressedType(
|
||||
qbarrier.arg()->getType())) {
|
||||
qbarrier.arg().getType())) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
|
@ -86,7 +86,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
|
|||
// When creating the new const op, use a fused location that combines the
|
||||
// original const and the qbarrier that led to the quantization.
|
||||
auto fusedLoc = FusedLoc::get(
|
||||
{qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()},
|
||||
{qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()},
|
||||
rewriter.getContext());
|
||||
auto newConstOp =
|
||||
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
|
||||
|
|
|
@ -104,7 +104,7 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface {
|
|||
// Replace the values directly with the return operands.
|
||||
assert(valuesToRepl.size() == 1 &&
|
||||
"spv.ReturnValue expected to only handle one result");
|
||||
valuesToRepl.front()->replaceAllUsesWith(retValOp.value());
|
||||
valuesToRepl.front().replaceAllUsesWith(retValOp.value());
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -167,8 +167,8 @@ printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer,
|
|||
|
||||
static LogicalResult verifyCastOp(Operation *op,
|
||||
bool requireSameBitWidth = true) {
|
||||
Type operandType = op->getOperand(0)->getType();
|
||||
Type resultType = op->getResult(0)->getType();
|
||||
Type operandType = op->getOperand(0).getType();
|
||||
Type resultType = op->getResult(0).getType();
|
||||
|
||||
// ODS checks that result type and operand type have the same shape.
|
||||
if (auto vectorType = operandType.dyn_cast<VectorType>()) {
|
||||
|
@ -271,8 +271,8 @@ static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
|
|||
//
|
||||
// TODO(ravishankarm): Check that the value type satisfies restrictions of
|
||||
// SPIR-V OpLoad/OpStore operations
|
||||
if (val->getType() !=
|
||||
ptr->getType().cast<spirv::PointerType>().getPointeeType()) {
|
||||
if (val.getType() !=
|
||||
ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
|
||||
return op.emitOpError("mismatch in result type and pointer type");
|
||||
}
|
||||
return success();
|
||||
|
@ -497,11 +497,11 @@ static void printBitFieldExtractOp(Operation *op, OpAsmPrinter &printer) {
|
|||
}
|
||||
|
||||
static LogicalResult verifyBitFieldExtractOp(Operation *op) {
|
||||
if (op->getOperand(0)->getType() != op->getResult(0)->getType()) {
|
||||
if (op->getOperand(0).getType() != op->getResult(0).getType()) {
|
||||
return op->emitError("expected the same type for the first operand and "
|
||||
"result, but provided ")
|
||||
<< op->getOperand(0)->getType() << " and "
|
||||
<< op->getResult(0)->getType();
|
||||
<< op->getOperand(0).getType() << " and "
|
||||
<< op->getResult(0).getType();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -547,13 +547,12 @@ static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
|
|||
printer << spirv::stringifyMemorySemantics(
|
||||
static_cast<spirv::MemorySemantics>(
|
||||
memorySemanticsAttr.getInt()))
|
||||
<< "\" " << op->getOperands() << " : "
|
||||
<< op->getOperand(0)->getType();
|
||||
<< "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
|
||||
}
|
||||
|
||||
// Verifies an atomic update op.
|
||||
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
|
||||
auto ptrType = op->getOperand(0)->getType().cast<spirv::PointerType>();
|
||||
auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
|
||||
auto elementType = ptrType.getPointeeType();
|
||||
if (!elementType.isa<IntegerType>())
|
||||
return op->emitOpError(
|
||||
|
@ -561,7 +560,7 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
|
|||
<< elementType;
|
||||
|
||||
if (op->getNumOperands() > 1) {
|
||||
auto valueType = op->getOperand(1)->getType();
|
||||
auto valueType = op->getOperand(1).getType();
|
||||
if (valueType != elementType)
|
||||
return op->emitOpError("expected value to have the same type as the "
|
||||
"pointer operand's pointee type ")
|
||||
|
@ -595,8 +594,8 @@ static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) {
|
|||
}
|
||||
|
||||
static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
|
||||
printer << unaryOp->getName() << ' ' << *unaryOp->getOperand(0) << " : "
|
||||
<< unaryOp->getOperand(0)->getType();
|
||||
printer << unaryOp->getName() << ' ' << unaryOp->getOperand(0) << " : "
|
||||
<< unaryOp->getOperand(0).getType();
|
||||
}
|
||||
|
||||
/// Result of a logical op must be a scalar or vector of boolean type.
|
||||
|
@ -634,7 +633,7 @@ static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
|
|||
|
||||
static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
|
||||
printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
|
||||
<< logicalOp->getOperand(0)->getType();
|
||||
<< logicalOp->getOperand(0).getType();
|
||||
}
|
||||
|
||||
static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
|
||||
|
@ -657,16 +656,16 @@ static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
|
|||
static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
|
||||
Value base = op->getOperand(0);
|
||||
Value shift = op->getOperand(1);
|
||||
printer << op->getName() << ' ' << *base << ", " << *shift << " : "
|
||||
<< base->getType() << ", " << shift->getType();
|
||||
printer << op->getName() << ' ' << base << ", " << shift << " : "
|
||||
<< base.getType() << ", " << shift.getType();
|
||||
}
|
||||
|
||||
static LogicalResult verifyShiftOp(Operation *op) {
|
||||
if (op->getOperand(0)->getType() != op->getResult(0)->getType()) {
|
||||
if (op->getOperand(0).getType() != op->getResult(0).getType()) {
|
||||
return op->emitError("expected the same type for the first operand and "
|
||||
"result, but provided ")
|
||||
<< op->getOperand(0)->getType() << " and "
|
||||
<< op->getResult(0)->getType();
|
||||
<< op->getOperand(0).getType() << " and "
|
||||
<< op->getResult(0).getType();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -704,7 +703,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
|
|||
}
|
||||
index = 0;
|
||||
if (resultType.isa<spirv::StructType>()) {
|
||||
Operation *op = indexSSA->getDefiningOp();
|
||||
Operation *op = indexSSA.getDefiningOp();
|
||||
if (!op) {
|
||||
emitError(baseLoc, "'spv.AccessChain' op index must be an "
|
||||
"integer spv.constant to access "
|
||||
|
@ -734,7 +733,7 @@ static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
|
|||
|
||||
void spirv::AccessChainOp::build(Builder *builder, OperationState &state,
|
||||
Value basePtr, ValueRange indices) {
|
||||
auto type = getElementPtrType(basePtr->getType(), indices, state.location);
|
||||
auto type = getElementPtrType(basePtr.getType(), indices, state.location);
|
||||
assert(type && "Unable to deduce return type based on basePtr and indices");
|
||||
build(builder, state, type, basePtr, indices);
|
||||
}
|
||||
|
@ -768,14 +767,14 @@ static ParseResult parseAccessChainOp(OpAsmParser &parser,
|
|||
}
|
||||
|
||||
static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
|
||||
printer << spirv::AccessChainOp::getOperationName() << ' ' << *op.base_ptr()
|
||||
<< '[' << op.indices() << "] : " << op.base_ptr()->getType();
|
||||
printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr()
|
||||
<< '[' << op.indices() << "] : " << op.base_ptr().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
|
||||
SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
|
||||
accessChainOp.indices().end());
|
||||
auto resultType = getElementPtrType(accessChainOp.base_ptr()->getType(),
|
||||
auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
|
||||
indices, accessChainOp.getLoc());
|
||||
if (!resultType) {
|
||||
return failure();
|
||||
|
@ -808,7 +807,7 @@ struct CombineChainedAccessChain
|
|||
PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
|
||||
accessChainOp.base_ptr()->getDefiningOp());
|
||||
accessChainOp.base_ptr().getDefiningOp());
|
||||
|
||||
if (!parentAccessChainOp) {
|
||||
return matchFailure();
|
||||
|
@ -868,7 +867,7 @@ static void print(spirv::AddressOfOp addressOfOp, OpAsmPrinter &printer) {
|
|||
printer.printSymbolName(addressOfOp.variable());
|
||||
|
||||
// Print the type.
|
||||
printer << " : " << addressOfOp.pointer()->getType();
|
||||
printer << " : " << addressOfOp.pointer().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
|
||||
|
@ -878,7 +877,7 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
|
|||
if (!varOp) {
|
||||
return addressOfOp.emitOpError("expected spv.globalVariable symbol");
|
||||
}
|
||||
if (addressOfOp.pointer()->getType() != varOp.type()) {
|
||||
if (addressOfOp.pointer().getType() != varOp.type()) {
|
||||
return addressOfOp.emitOpError(
|
||||
"result type mismatch with the referenced global variable's type");
|
||||
}
|
||||
|
@ -926,7 +925,7 @@ static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
|
|||
<< stringifyScope(atomOp.memory_scope()) << "\" \""
|
||||
<< stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
|
||||
<< stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
|
||||
<< atomOp.getOperands() << " : " << atomOp.pointer()->getType();
|
||||
<< atomOp.getOperands() << " : " << atomOp.pointer().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
|
||||
|
@ -934,19 +933,19 @@ static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
|
|||
// "The type of Value must be the same as Result Type. The type of the value
|
||||
// pointed to by Pointer must be the same as Result Type. This type must also
|
||||
// match the type of Comparator."
|
||||
if (atomOp.getType() != atomOp.value()->getType())
|
||||
if (atomOp.getType() != atomOp.value().getType())
|
||||
return atomOp.emitOpError("value operand must have the same type as the op "
|
||||
"result, but found ")
|
||||
<< atomOp.value()->getType() << " vs " << atomOp.getType();
|
||||
<< atomOp.value().getType() << " vs " << atomOp.getType();
|
||||
|
||||
if (atomOp.getType() != atomOp.comparator()->getType())
|
||||
if (atomOp.getType() != atomOp.comparator().getType())
|
||||
return atomOp.emitOpError(
|
||||
"comparator operand must have the same type as the op "
|
||||
"result, but found ")
|
||||
<< atomOp.comparator()->getType() << " vs " << atomOp.getType();
|
||||
<< atomOp.comparator().getType() << " vs " << atomOp.getType();
|
||||
|
||||
Type pointeeType =
|
||||
atomOp.pointer()->getType().cast<spirv::PointerType>().getPointeeType();
|
||||
atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
if (atomOp.getType() != pointeeType)
|
||||
return atomOp.emitOpError(
|
||||
"pointer operand's pointee type must have the same "
|
||||
|
@ -966,8 +965,8 @@ static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
|
|||
static LogicalResult verify(spirv::BitcastOp bitcastOp) {
|
||||
// TODO: The SPIR-V spec validation rules are different for different
|
||||
// versions.
|
||||
auto operandType = bitcastOp.operand()->getType();
|
||||
auto resultType = bitcastOp.result()->getType();
|
||||
auto operandType = bitcastOp.operand().getType();
|
||||
auto resultType = bitcastOp.result().getType();
|
||||
if (operandType == resultType) {
|
||||
return bitcastOp.emitError(
|
||||
"result type must be different from operand type");
|
||||
|
@ -1026,15 +1025,15 @@ static void print(spirv::BitFieldInsertOp bitFieldInsertOp,
|
|||
OpAsmPrinter &printer) {
|
||||
printer << spirv::BitFieldInsertOp::getOperationName() << ' '
|
||||
<< bitFieldInsertOp.getOperands() << " : "
|
||||
<< bitFieldInsertOp.base()->getType() << ", "
|
||||
<< bitFieldInsertOp.offset()->getType() << ", "
|
||||
<< bitFieldInsertOp.count()->getType();
|
||||
<< bitFieldInsertOp.base().getType() << ", "
|
||||
<< bitFieldInsertOp.offset().getType() << ", "
|
||||
<< bitFieldInsertOp.count().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::BitFieldInsertOp bitFieldOp) {
|
||||
auto baseType = bitFieldOp.base()->getType();
|
||||
auto insertType = bitFieldOp.insert()->getType();
|
||||
auto resultType = bitFieldOp.getResult()->getType();
|
||||
auto baseType = bitFieldOp.base().getType();
|
||||
auto insertType = bitFieldOp.insert().getType();
|
||||
auto resultType = bitFieldOp.getResult().getType();
|
||||
|
||||
if ((baseType != insertType) || (baseType != resultType)) {
|
||||
return bitFieldOp.emitError("expected the same type for the base operand, "
|
||||
|
@ -1199,7 +1198,7 @@ static void print(spirv::CompositeConstructOp compositeConstructOp,
|
|||
OpAsmPrinter &printer) {
|
||||
printer << spirv::CompositeConstructOp::getOperationName() << " "
|
||||
<< compositeConstructOp.constituents() << " : "
|
||||
<< compositeConstructOp.getResult()->getType();
|
||||
<< compositeConstructOp.getResult().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
|
||||
|
@ -1214,11 +1213,11 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
|
|||
}
|
||||
|
||||
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
|
||||
if (constituents[index]->getType() != cType.getElementType(index)) {
|
||||
if (constituents[index].getType() != cType.getElementType(index)) {
|
||||
return compositeConstructOp.emitError(
|
||||
"operand type mismatch: expected operand type ")
|
||||
<< cType.getElementType(index) << ", but provided "
|
||||
<< constituents[index]->getType();
|
||||
<< constituents[index].getType();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1234,7 +1233,7 @@ void spirv::CompositeExtractOp::build(Builder *builder, OperationState &state,
|
|||
ArrayRef<int32_t> indices) {
|
||||
auto indexAttr = builder->getI32ArrayAttr(indices);
|
||||
auto elementType =
|
||||
getElementType(composite->getType(), indexAttr, state.location);
|
||||
getElementType(composite.getType(), indexAttr, state.location);
|
||||
if (!elementType) {
|
||||
return;
|
||||
}
|
||||
|
@ -1268,13 +1267,13 @@ static ParseResult parseCompositeExtractOp(OpAsmParser &parser,
|
|||
static void print(spirv::CompositeExtractOp compositeExtractOp,
|
||||
OpAsmPrinter &printer) {
|
||||
printer << spirv::CompositeExtractOp::getOperationName() << ' '
|
||||
<< *compositeExtractOp.composite() << compositeExtractOp.indices()
|
||||
<< " : " << compositeExtractOp.composite()->getType();
|
||||
<< compositeExtractOp.composite() << compositeExtractOp.indices()
|
||||
<< " : " << compositeExtractOp.composite().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
|
||||
auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
|
||||
auto resultType = getElementType(compExOp.composite()->getType(),
|
||||
auto resultType = getElementType(compExOp.composite().getType(),
|
||||
indicesArrayAttr, compExOp.getLoc());
|
||||
if (!resultType)
|
||||
return failure();
|
||||
|
@ -1321,21 +1320,21 @@ static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
|
|||
static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
|
||||
auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>();
|
||||
auto objectType =
|
||||
getElementType(compositeInsertOp.composite()->getType(), indicesArrayAttr,
|
||||
getElementType(compositeInsertOp.composite().getType(), indicesArrayAttr,
|
||||
compositeInsertOp.getLoc());
|
||||
if (!objectType)
|
||||
return failure();
|
||||
|
||||
if (objectType != compositeInsertOp.object()->getType()) {
|
||||
if (objectType != compositeInsertOp.object().getType()) {
|
||||
return compositeInsertOp.emitOpError("object operand type should be ")
|
||||
<< objectType << ", but found "
|
||||
<< compositeInsertOp.object()->getType();
|
||||
<< compositeInsertOp.object().getType();
|
||||
}
|
||||
|
||||
if (compositeInsertOp.composite()->getType() != compositeInsertOp.getType()) {
|
||||
if (compositeInsertOp.composite().getType() != compositeInsertOp.getType()) {
|
||||
return compositeInsertOp.emitOpError("result type should be the same as "
|
||||
"the composite type, but found ")
|
||||
<< compositeInsertOp.composite()->getType() << " vs "
|
||||
<< compositeInsertOp.composite().getType() << " vs "
|
||||
<< compositeInsertOp.getType();
|
||||
}
|
||||
|
||||
|
@ -1345,10 +1344,10 @@ static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
|
|||
static void print(spirv::CompositeInsertOp compositeInsertOp,
|
||||
OpAsmPrinter &printer) {
|
||||
printer << spirv::CompositeInsertOp::getOperationName() << " "
|
||||
<< *compositeInsertOp.object() << ", "
|
||||
<< *compositeInsertOp.composite() << compositeInsertOp.indices()
|
||||
<< " : " << compositeInsertOp.object()->getType() << " into "
|
||||
<< compositeInsertOp.composite()->getType();
|
||||
<< compositeInsertOp.object() << ", " << compositeInsertOp.composite()
|
||||
<< compositeInsertOp.indices() << " : "
|
||||
<< compositeInsertOp.object().getType() << " into "
|
||||
<< compositeInsertOp.composite().getType();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1707,12 +1706,12 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
|
|||
}
|
||||
|
||||
for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
|
||||
if (functionCallOp.getOperand(i)->getType() != functionType.getInput(i)) {
|
||||
if (functionCallOp.getOperand(i).getType() != functionType.getInput(i)) {
|
||||
return functionCallOp.emitOpError(
|
||||
"operand type mismatch: expected operand type ")
|
||||
<< functionType.getInput(i) << ", but provided "
|
||||
<< functionCallOp.getOperand(i)->getType()
|
||||
<< " for operand number " << i;
|
||||
<< functionCallOp.getOperand(i).getType() << " for operand number "
|
||||
<< i;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1724,10 +1723,10 @@ static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
|
|||
}
|
||||
|
||||
if (functionCallOp.getNumResults() &&
|
||||
(functionCallOp.getResult(0)->getType() != functionType.getResult(0))) {
|
||||
(functionCallOp.getResult(0).getType() != functionType.getResult(0))) {
|
||||
return functionCallOp.emitOpError("result type mismatch: expected ")
|
||||
<< functionType.getResult(0) << ", but provided "
|
||||
<< functionCallOp.getResult(0)->getType();
|
||||
<< functionCallOp.getResult(0).getType();
|
||||
}
|
||||
|
||||
return success();
|
||||
|
@ -1955,7 +1954,7 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
|
|||
void spirv::LoadOp::build(Builder *builder, OperationState &state,
|
||||
Value basePtr, IntegerAttr memory_access,
|
||||
IntegerAttr alignment) {
|
||||
auto ptrType = basePtr->getType().cast<spirv::PointerType>();
|
||||
auto ptrType = basePtr.getType().cast<spirv::PointerType>();
|
||||
build(builder, state, ptrType.getPointeeType(), basePtr, memory_access,
|
||||
alignment);
|
||||
}
|
||||
|
@ -1986,7 +1985,7 @@ static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
|
|||
auto *op = loadOp.getOperation();
|
||||
SmallVector<StringRef, 4> elidedAttrs;
|
||||
StringRef sc = stringifyStorageClass(
|
||||
loadOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
|
||||
loadOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
|
||||
printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "
|
||||
<< loadOp.ptr();
|
||||
|
||||
|
@ -2414,7 +2413,7 @@ static ParseResult parseReferenceOfOp(OpAsmParser &parser,
|
|||
static void print(spirv::ReferenceOfOp referenceOfOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::ReferenceOfOp::getOperationName() << ' ';
|
||||
printer.printSymbolName(referenceOfOp.spec_const());
|
||||
printer << " : " << referenceOfOp.reference()->getType();
|
||||
printer << " : " << referenceOfOp.reference().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
|
||||
|
@ -2424,7 +2423,7 @@ static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
|
|||
if (!specConstOp) {
|
||||
return referenceOfOp.emitOpError("expected spv.specConstant symbol");
|
||||
}
|
||||
if (referenceOfOp.reference()->getType() !=
|
||||
if (referenceOfOp.reference().getType() !=
|
||||
specConstOp.default_value().getType()) {
|
||||
return referenceOfOp.emitOpError("result type mismatch with the referenced "
|
||||
"specialization constant's type");
|
||||
|
@ -2461,7 +2460,7 @@ static ParseResult parseReturnValueOp(OpAsmParser &parser,
|
|||
|
||||
static void print(spirv::ReturnValueOp retValOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::ReturnValueOp::getOperationName() << ' ' << retValOp.value()
|
||||
<< " : " << retValOp.value()->getType();
|
||||
<< " : " << retValOp.value().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::ReturnValueOp retValOp) {
|
||||
|
@ -2472,7 +2471,7 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) {
|
|||
"returns 1 value but enclosing function requires ")
|
||||
<< numFnResults << " results";
|
||||
|
||||
auto operandType = retValOp.value()->getType();
|
||||
auto operandType = retValOp.value().getType();
|
||||
auto fnResultType = funcOp.getType().getResult(0);
|
||||
if (operandType != fnResultType)
|
||||
return retValOp.emitOpError(" return value's type (")
|
||||
|
@ -2488,7 +2487,7 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) {
|
|||
|
||||
void spirv::SelectOp::build(Builder *builder, OperationState &state, Value cond,
|
||||
Value trueValue, Value falseValue) {
|
||||
build(builder, state, trueValue->getType(), cond, trueValue, falseValue);
|
||||
build(builder, state, trueValue.getType(), cond, trueValue, falseValue);
|
||||
}
|
||||
|
||||
static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
|
||||
|
@ -2514,19 +2513,18 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &state) {
|
|||
|
||||
static void print(spirv::SelectOp op, OpAsmPrinter &printer) {
|
||||
printer << spirv::SelectOp::getOperationName() << " " << op.getOperands()
|
||||
<< " : " << op.condition()->getType() << ", "
|
||||
<< op.result()->getType();
|
||||
<< " : " << op.condition().getType() << ", " << op.result().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::SelectOp op) {
|
||||
auto resultTy = op.result()->getType();
|
||||
if (op.true_value()->getType() != resultTy) {
|
||||
auto resultTy = op.result().getType();
|
||||
if (op.true_value().getType() != resultTy) {
|
||||
return op.emitOpError("result type and true value type must be the same");
|
||||
}
|
||||
if (op.false_value()->getType() != resultTy) {
|
||||
if (op.false_value().getType() != resultTy) {
|
||||
return op.emitOpError("result type and false value type must be the same");
|
||||
}
|
||||
if (auto conditionTy = op.condition()->getType().dyn_cast<VectorType>()) {
|
||||
if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) {
|
||||
auto resultVectorTy = resultTy.dyn_cast<VectorType>();
|
||||
if (!resultVectorTy) {
|
||||
return op.emitOpError("result expected to be of vector type when "
|
||||
|
@ -2695,7 +2693,7 @@ struct ConvertSelectionOpToSelect
|
|||
cast<spirv::StoreOp>(trueBlock->front()).getOperation()->getAttrs();
|
||||
|
||||
auto selectOp = rewriter.create<spirv::SelectOp>(
|
||||
selectionOp.getLoc(), trueValue->getType(), brConditionalOp.condition(),
|
||||
selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
|
||||
trueValue, falseValue);
|
||||
rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
|
||||
selectOp.getResult(), storeOpAttributes);
|
||||
|
@ -2773,7 +2771,7 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
|
|||
// attributes and a valid type of the value.
|
||||
if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
|
||||
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
|
||||
!isValidType(trueBrStoreOp.value()->getType())) {
|
||||
!isValidType(trueBrStoreOp.value().getType())) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
|
@ -2879,13 +2877,13 @@ static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
|
|||
auto *op = storeOp.getOperation();
|
||||
SmallVector<StringRef, 4> elidedAttrs;
|
||||
StringRef sc = stringifyStorageClass(
|
||||
storeOp.ptr()->getType().cast<spirv::PointerType>().getStorageClass());
|
||||
storeOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
|
||||
printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "
|
||||
<< storeOp.ptr() << ", " << storeOp.value();
|
||||
|
||||
printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
|
||||
|
||||
printer << " : " << storeOp.value()->getType();
|
||||
printer << " : " << storeOp.value().getType();
|
||||
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||
}
|
||||
|
||||
|
@ -3025,7 +3023,7 @@ static LogicalResult verify(spirv::VariableOp varOp) {
|
|||
"spv.globalVariable for module-level variables.");
|
||||
}
|
||||
|
||||
auto pointerType = varOp.pointer()->getType().cast<spirv::PointerType>();
|
||||
auto pointerType = varOp.pointer().getType().cast<spirv::PointerType>();
|
||||
if (varOp.storage_class() != pointerType.getStorageClass())
|
||||
return varOp.emitOpError(
|
||||
"storage class must match result pointer's storage class");
|
||||
|
@ -3033,7 +3031,7 @@ static LogicalResult verify(spirv::VariableOp varOp) {
|
|||
if (varOp.getNumOperands() != 0) {
|
||||
// SPIR-V spec: "Initializer must be an <id> from a constant instruction or
|
||||
// a global (module scope) OpVariable instruction".
|
||||
auto *initOp = varOp.getOperand(0)->getDefiningOp();
|
||||
auto *initOp = varOp.getOperand(0).getDefiningOp();
|
||||
if (!initOp || !(isa<spirv::ConstantOp>(initOp) || // for normal constant
|
||||
isa<spirv::ReferenceOfOp>(initOp) || // for spec constant
|
||||
isa<spirv::AddressOfOp>(initOp)))
|
||||
|
|
|
@ -1775,7 +1775,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
|
|||
<< " from block " << block << "\n");
|
||||
if (!isFnEntryBlock(block)) {
|
||||
for (BlockArgument blockArg : block->getArguments()) {
|
||||
auto newArg = newBlock->addArgument(blockArg->getType());
|
||||
auto newArg = newBlock->addArgument(blockArg.getType());
|
||||
mapper.map(blockArg, newArg);
|
||||
LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg
|
||||
<< " to " << newArg << '\n');
|
||||
|
@ -1816,7 +1816,7 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
|
|||
// make sure the old merge block has the same block argument list.
|
||||
assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported");
|
||||
for (BlockArgument blockArg : headerBlock->getArguments()) {
|
||||
mergeBlock->addArgument(blockArg->getType());
|
||||
mergeBlock->addArgument(blockArg.getType());
|
||||
}
|
||||
|
||||
// If the loop header block has block arguments, make sure the spv.branch op
|
||||
|
@ -2200,7 +2200,7 @@ LogicalResult Deserializer::processBitcast(ArrayRef<uint32_t> words) {
|
|||
"spirv::BitcastOp, only ")
|
||||
<< wordIndex << " of " << words.size() << " processed";
|
||||
}
|
||||
if (resultTypes[0] == operands[0]->getType() &&
|
||||
if (resultTypes[0] == operands[0].getType() &&
|
||||
resultTypes[0].isa<IntegerType>()) {
|
||||
// TODO(b/130356985): This check is added to ignore error in Op verification
|
||||
// due to both signed and unsigned integers mapping to the same
|
||||
|
|
|
@ -507,10 +507,10 @@ void Serializer::printValueIDMap(raw_ostream &os) {
|
|||
Value val = valueIDPair.first;
|
||||
os << " " << val << " "
|
||||
<< "id = " << valueIDPair.second << ' ';
|
||||
if (auto *op = val->getDefiningOp()) {
|
||||
if (auto *op = val.getDefiningOp()) {
|
||||
os << "from op '" << op->getName() << "'";
|
||||
} else if (auto arg = val.dyn_cast<BlockArgument>()) {
|
||||
Block *block = arg->getOwner();
|
||||
Block *block = arg.getOwner();
|
||||
os << "from argument of block " << block << ' ';
|
||||
os << " in op '" << block->getParentOp()->getName() << "'";
|
||||
}
|
||||
|
@ -714,7 +714,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
|
|||
// Declare the parameters.
|
||||
for (auto arg : op.getArguments()) {
|
||||
uint32_t argTypeID = 0;
|
||||
if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) {
|
||||
if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
|
||||
return failure();
|
||||
}
|
||||
auto argValueID = getNextID();
|
||||
|
@ -1397,7 +1397,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
|
|||
|
||||
// Get the type <id> and result <id> for this OpPhi instruction.
|
||||
uint32_t phiTypeID = 0;
|
||||
if (failed(processType(arg->getLoc(), arg->getType(), phiTypeID)))
|
||||
if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
|
||||
return failure();
|
||||
uint32_t phiID = getNextID();
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() {
|
|||
|
||||
// Change the type for the direct users.
|
||||
target.addDynamicallyLegalOp<spirv::AddressOfOp>([](spirv::AddressOfOp op) {
|
||||
return VulkanLayoutUtils::isLegalType(op.pointer()->getType());
|
||||
return VulkanLayoutUtils::isLegalType(op.pointer().getType());
|
||||
});
|
||||
|
||||
// TODO: Change the type for the indirect users such as spv.Load, spv.Store,
|
||||
|
|
|
@ -79,7 +79,7 @@ struct StdInlinerInterface : public DialectInlinerInterface {
|
|||
// Replace the values directly with the return operands.
|
||||
assert(returnOp.getNumOperands() == valuesToRepl.size());
|
||||
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
|
||||
valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
|
||||
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
@ -96,9 +96,9 @@ static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
|
|||
|
||||
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
|
||||
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
|
||||
<< *op->getOperand(0);
|
||||
<< op->getOperand(0);
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
p << " : " << op->getOperand(0)->getType();
|
||||
p << " : " << op->getOperand(0).getType();
|
||||
}
|
||||
|
||||
/// A custom binary operation printer that omits the "std." prefix from the
|
||||
|
@ -109,20 +109,20 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
|
|||
|
||||
// If not all the operand and result types are the same, just use the
|
||||
// generic assembly form to avoid omitting information in printing.
|
||||
auto resultType = op->getResult(0)->getType();
|
||||
if (op->getOperand(0)->getType() != resultType ||
|
||||
op->getOperand(1)->getType() != resultType) {
|
||||
auto resultType = op->getResult(0).getType();
|
||||
if (op->getOperand(0).getType() != resultType ||
|
||||
op->getOperand(1).getType() != resultType) {
|
||||
p.printGenericOp(op);
|
||||
return;
|
||||
}
|
||||
|
||||
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
|
||||
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
|
||||
<< *op->getOperand(0) << ", " << *op->getOperand(1);
|
||||
<< op->getOperand(0) << ", " << op->getOperand(1);
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
|
||||
// Now we can output only one type for all operands and the result.
|
||||
p << " : " << op->getResult(0)->getType();
|
||||
p << " : " << op->getResult(0).getType();
|
||||
}
|
||||
|
||||
/// A custom cast operation printer that omits the "std." prefix from the
|
||||
|
@ -130,13 +130,13 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
|
|||
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
|
||||
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
|
||||
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
|
||||
<< *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
|
||||
<< op->getResult(0)->getType();
|
||||
<< op->getOperand(0) << " : " << op->getOperand(0).getType() << " to "
|
||||
<< op->getResult(0).getType();
|
||||
}
|
||||
|
||||
/// A custom cast operation verifier.
|
||||
template <typename T> static LogicalResult verifyCastOp(T op) {
|
||||
auto opType = op.getOperand()->getType();
|
||||
auto opType = op.getOperand().getType();
|
||||
auto resType = op.getType();
|
||||
if (!T::areCastCompatible(opType, resType))
|
||||
return op.emitError("operand type ") << opType << " and result type "
|
||||
|
@ -209,8 +209,8 @@ static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
|
|||
static LogicalResult foldMemRefCast(Operation *op) {
|
||||
bool folded = false;
|
||||
for (OpOperand &operand : op->getOpOperands()) {
|
||||
auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
|
||||
if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
|
||||
auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
|
||||
if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
|
||||
operand.set(cast.getOperand());
|
||||
folded = true;
|
||||
}
|
||||
|
@ -281,7 +281,7 @@ static ParseResult parseAllocOp(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
static LogicalResult verify(AllocOp op) {
|
||||
auto memRefType = op.getResult()->getType().dyn_cast<MemRefType>();
|
||||
auto memRefType = op.getResult().getType().dyn_cast<MemRefType>();
|
||||
if (!memRefType)
|
||||
return op.emitOpError("result must be a memref");
|
||||
|
||||
|
@ -338,7 +338,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
|
|||
newShapeConstants.push_back(dimSize);
|
||||
continue;
|
||||
}
|
||||
auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp();
|
||||
auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp();
|
||||
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
|
||||
// Dynamic shape dimension will be folded.
|
||||
newShapeConstants.push_back(constantIndexOp.getValue());
|
||||
|
@ -489,14 +489,14 @@ static LogicalResult verify(CallOp op) {
|
|||
return op.emitOpError("incorrect number of operands for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
|
||||
if (op.getOperand(i)->getType() != fnType.getInput(i))
|
||||
if (op.getOperand(i).getType() != fnType.getInput(i))
|
||||
return op.emitOpError("operand type mismatch");
|
||||
|
||||
if (fnType.getNumResults() != op.getNumResults())
|
||||
return op.emitOpError("incorrect number of results for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
|
||||
if (op.getResult(i)->getType() != fnType.getResult(i))
|
||||
if (op.getResult(i).getType() != fnType.getResult(i))
|
||||
return op.emitOpError("result type mismatch");
|
||||
|
||||
return success();
|
||||
|
@ -553,12 +553,12 @@ static ParseResult parseCallIndirectOp(OpAsmParser &parser,
|
|||
static void print(OpAsmPrinter &p, CallIndirectOp op) {
|
||||
p << "call_indirect " << op.getCallee() << '(' << op.getArgOperands() << ')';
|
||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"});
|
||||
p << " : " << op.getCallee()->getType();
|
||||
p << " : " << op.getCallee().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(CallIndirectOp op) {
|
||||
// The callee must be a function.
|
||||
auto fnType = op.getCallee()->getType().dyn_cast<FunctionType>();
|
||||
auto fnType = op.getCallee().getType().dyn_cast<FunctionType>();
|
||||
if (!fnType)
|
||||
return op.emitOpError("callee must have function type");
|
||||
|
||||
|
@ -567,14 +567,14 @@ static LogicalResult verify(CallIndirectOp op) {
|
|||
return op.emitOpError("incorrect number of operands for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
|
||||
if (op.getOperand(i + 1)->getType() != fnType.getInput(i))
|
||||
if (op.getOperand(i + 1).getType() != fnType.getInput(i))
|
||||
return op.emitOpError("operand type mismatch");
|
||||
|
||||
if (fnType.getNumResults() != op.getNumResults())
|
||||
return op.emitOpError("incorrect number of results for callee");
|
||||
|
||||
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
|
||||
if (op.getResult(i)->getType() != fnType.getResult(i))
|
||||
if (op.getResult(i).getType() != fnType.getResult(i))
|
||||
return op.emitOpError("result type mismatch");
|
||||
|
||||
return success();
|
||||
|
@ -616,7 +616,7 @@ static Type getI1SameShape(Builder *build, Type type) {
|
|||
static void buildCmpIOp(Builder *build, OperationState &result,
|
||||
CmpIPredicate predicate, Value lhs, Value rhs) {
|
||||
result.addOperands({lhs, rhs});
|
||||
result.types.push_back(getI1SameShape(build, lhs->getType()));
|
||||
result.types.push_back(getI1SameShape(build, lhs.getType()));
|
||||
result.addAttribute(
|
||||
CmpIOp::getPredicateAttrName(),
|
||||
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
|
||||
|
@ -668,7 +668,7 @@ static void print(OpAsmPrinter &p, CmpIOp op) {
|
|||
<< '"' << ", " << op.lhs() << ", " << op.rhs();
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/{CmpIOp::getPredicateAttrName()});
|
||||
p << " : " << op.lhs()->getType();
|
||||
p << " : " << op.lhs().getType();
|
||||
}
|
||||
|
||||
// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
|
||||
|
@ -769,7 +769,7 @@ CmpFPredicate CmpFOp::getPredicateByName(StringRef name) {
|
|||
static void buildCmpFOp(Builder *build, OperationState &result,
|
||||
CmpFPredicate predicate, Value lhs, Value rhs) {
|
||||
result.addOperands({lhs, rhs});
|
||||
result.types.push_back(getI1SameShape(build, lhs->getType()));
|
||||
result.types.push_back(getI1SameShape(build, lhs.getType()));
|
||||
result.addAttribute(
|
||||
CmpFOp::getPredicateAttrName(),
|
||||
build->getI64IntegerAttr(static_cast<int64_t>(predicate)));
|
||||
|
@ -824,7 +824,7 @@ static void print(OpAsmPrinter &p, CmpFOp op) {
|
|||
<< ", " << op.rhs();
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/{CmpFOp::getPredicateAttrName()});
|
||||
p << " : " << op.lhs()->getType();
|
||||
p << " : " << op.lhs().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(CmpFOp op) {
|
||||
|
@ -1123,14 +1123,13 @@ void ConstantFloatOp::build(Builder *builder, OperationState &result,
|
|||
}
|
||||
|
||||
bool ConstantFloatOp::classof(Operation *op) {
|
||||
return ConstantOp::classof(op) &&
|
||||
op->getResult(0)->getType().isa<FloatType>();
|
||||
return ConstantOp::classof(op) && op->getResult(0).getType().isa<FloatType>();
|
||||
}
|
||||
|
||||
/// ConstantIntOp only matches values whose result type is an IntegerType.
|
||||
bool ConstantIntOp::classof(Operation *op) {
|
||||
return ConstantOp::classof(op) &&
|
||||
op->getResult(0)->getType().isa<IntegerType>();
|
||||
op->getResult(0).getType().isa<IntegerType>();
|
||||
}
|
||||
|
||||
void ConstantIntOp::build(Builder *builder, OperationState &result,
|
||||
|
@ -1151,7 +1150,7 @@ void ConstantIntOp::build(Builder *builder, OperationState &result,
|
|||
|
||||
/// ConstantIndexOp only matches values whose result type is Index.
|
||||
bool ConstantIndexOp::classof(Operation *op) {
|
||||
return ConstantOp::classof(op) && op->getResult(0)->getType().isIndex();
|
||||
return ConstantOp::classof(op) && op->getResult(0).getType().isIndex();
|
||||
}
|
||||
|
||||
void ConstantIndexOp::build(Builder *builder, OperationState &result,
|
||||
|
@ -1174,11 +1173,11 @@ struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
// Check that the memref operand's defining operation is an AllocOp.
|
||||
Value memref = dealloc.memref();
|
||||
if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp()))
|
||||
if (!isa_and_nonnull<AllocOp>(memref.getDefiningOp()))
|
||||
return matchFailure();
|
||||
|
||||
// Check that all of the uses of the AllocOp are other DeallocOps.
|
||||
for (auto *user : memref->getUsers())
|
||||
for (auto *user : memref.getUsers())
|
||||
if (!isa<DeallocOp>(user))
|
||||
return matchFailure();
|
||||
|
||||
|
@ -1190,7 +1189,7 @@ struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
|
|||
} // end anonymous namespace.
|
||||
|
||||
static void print(OpAsmPrinter &p, DeallocOp op) {
|
||||
p << "dealloc " << *op.memref() << " : " << op.memref()->getType();
|
||||
p << "dealloc " << op.memref() << " : " << op.memref().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
@ -1203,7 +1202,7 @@ static ParseResult parseDeallocOp(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
static LogicalResult verify(DeallocOp op) {
|
||||
if (!op.memref()->getType().isa<MemRefType>())
|
||||
if (!op.memref().getType().isa<MemRefType>())
|
||||
return op.emitOpError("operand must be a memref");
|
||||
return success();
|
||||
}
|
||||
|
@ -1224,9 +1223,9 @@ LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, DimOp op) {
|
||||
p << "dim " << *op.getOperand() << ", " << op.getIndex();
|
||||
p << "dim " << op.getOperand() << ", " << op.getIndex();
|
||||
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
|
||||
p << " : " << op.getOperand()->getType();
|
||||
p << " : " << op.getOperand().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
@ -1251,7 +1250,7 @@ static LogicalResult verify(DimOp op) {
|
|||
return op.emitOpError("requires an integer attribute named 'index'");
|
||||
int64_t index = indexAttr.getValue().getSExtValue();
|
||||
|
||||
auto type = op.getOperand()->getType();
|
||||
auto type = op.getOperand().getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
if (index >= tensorType.getRank())
|
||||
return op.emitOpError("index is out of range");
|
||||
|
@ -1270,7 +1269,7 @@ static LogicalResult verify(DimOp op) {
|
|||
|
||||
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Constant fold dim when the size along the index referred to is a constant.
|
||||
auto opType = memrefOrTensor()->getType();
|
||||
auto opType = memrefOrTensor().getType();
|
||||
int64_t indexSize = -1;
|
||||
if (auto tensorType = opType.dyn_cast<RankedTensorType>())
|
||||
indexSize = tensorType.getShape()[getIndex()];
|
||||
|
@ -1286,7 +1285,7 @@ OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
|
||||
// The size at getIndex() is now a dynamic size of a memref.
|
||||
auto memref = memrefOrTensor()->getDefiningOp();
|
||||
auto memref = memrefOrTensor().getDefiningOp();
|
||||
if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
|
||||
return *(alloc.getDynamicSizes().begin() +
|
||||
memrefType.getDynamicDimIndex(getIndex()));
|
||||
|
@ -1367,16 +1366,15 @@ void DmaStartOp::build(Builder *builder, OperationState &result,
|
|||
}
|
||||
|
||||
void DmaStartOp::print(OpAsmPrinter &p) {
|
||||
p << "dma_start " << *getSrcMemRef() << '[' << getSrcIndices() << "], "
|
||||
<< *getDstMemRef() << '[' << getDstIndices() << "], " << *getNumElements()
|
||||
<< ", " << *getTagMemRef() << '[' << getTagIndices() << ']';
|
||||
p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], "
|
||||
<< getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
|
||||
<< ", " << getTagMemRef() << '[' << getTagIndices() << ']';
|
||||
if (isStrided())
|
||||
p << ", " << *getStride() << ", " << *getNumElementsPerStride();
|
||||
p << ", " << getStride() << ", " << getNumElementsPerStride();
|
||||
|
||||
p.printOptionalAttrDict(getAttrs());
|
||||
p << " : " << getSrcMemRef()->getType();
|
||||
p << ", " << getDstMemRef()->getType();
|
||||
p << ", " << getTagMemRef()->getType();
|
||||
p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
|
||||
<< ", " << getTagMemRef().getType();
|
||||
}
|
||||
|
||||
// Parse DmaStartOp.
|
||||
|
@ -1506,7 +1504,7 @@ void DmaWaitOp::print(OpAsmPrinter &p) {
|
|||
p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], "
|
||||
<< getNumElements();
|
||||
p.printOptionalAttrDict(getAttrs());
|
||||
p << " : " << getTagMemRef()->getType();
|
||||
p << " : " << getTagMemRef().getType();
|
||||
}
|
||||
|
||||
// Parse DmaWaitOp.
|
||||
|
@ -1553,10 +1551,10 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, ExtractElementOp op) {
|
||||
p << "extract_element " << *op.getAggregate() << '[' << op.getIndices();
|
||||
p << "extract_element " << op.getAggregate() << '[' << op.getIndices();
|
||||
p << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getAggregate()->getType();
|
||||
p << " : " << op.getAggregate().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||
|
@ -1577,7 +1575,7 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
|||
}
|
||||
|
||||
static LogicalResult verify(ExtractElementOp op) {
|
||||
auto aggregateType = op.getAggregate()->getType().cast<ShapedType>();
|
||||
auto aggregateType = op.getAggregate().getType().cast<ShapedType>();
|
||||
|
||||
// This should be possible with tablegen type constraints
|
||||
if (op.getType() != aggregateType.getElementType())
|
||||
|
@ -1634,7 +1632,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, LoadOp op) {
|
||||
p << "load " << *op.getMemRef() << '[' << op.getIndices() << ']';
|
||||
p << "load " << op.getMemRef() << '[' << op.getIndices() << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getMemRefType();
|
||||
}
|
||||
|
@ -1781,7 +1779,7 @@ OpFoldResult MulIOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, PrefetchOp op) {
|
||||
p << PrefetchOp::getOperationName() << " " << *op.memref() << '[';
|
||||
p << PrefetchOp::getOperationName() << " " << op.memref() << '[';
|
||||
p.printOperands(op.indices());
|
||||
p << ']' << ", " << (op.isWrite() ? "write" : "read");
|
||||
p << ", locality<" << op.localityHint();
|
||||
|
@ -1851,7 +1849,7 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, RankOp op) {
|
||||
p << "rank " << *op.getOperand() << " : " << op.getOperand()->getType();
|
||||
p << "rank " << op.getOperand() << " : " << op.getOperand().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
@ -1866,7 +1864,7 @@ static ParseResult parseRankOp(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
|
||||
// Constant fold rank when the rank of the tensor is known.
|
||||
auto type = getOperand()->getType();
|
||||
auto type = getOperand().getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank());
|
||||
return IntegerAttr();
|
||||
|
@ -1954,10 +1952,10 @@ static LogicalResult verify(ReturnOp op) {
|
|||
<< " operands, but enclosing function returns " << results.size();
|
||||
|
||||
for (unsigned i = 0, e = results.size(); i != e; ++i)
|
||||
if (op.getOperand(i)->getType() != results[i])
|
||||
if (op.getOperand(i).getType() != results[i])
|
||||
return op.emitError()
|
||||
<< "type of return operand " << i << " ("
|
||||
<< op.getOperand(i)->getType()
|
||||
<< op.getOperand(i).getType()
|
||||
<< ") doesn't match function result type (" << results[i] << ")";
|
||||
|
||||
return success();
|
||||
|
@ -1997,13 +1995,13 @@ static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, SelectOp op) {
|
||||
p << "select " << op.getOperands() << " : " << op.getTrueValue()->getType();
|
||||
p << "select " << op.getOperands() << " : " << op.getTrueValue().getType();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
}
|
||||
|
||||
static LogicalResult verify(SelectOp op) {
|
||||
auto trueType = op.getTrueValue()->getType();
|
||||
auto falseType = op.getFalseValue()->getType();
|
||||
auto trueType = op.getTrueValue().getType();
|
||||
auto falseType = op.getFalseValue().getType();
|
||||
|
||||
if (trueType != falseType)
|
||||
return op.emitOpError(
|
||||
|
@ -2032,7 +2030,7 @@ OpFoldResult SelectOp::fold(ArrayRef<Attribute> operands) {
|
|||
static LogicalResult verify(SignExtendIOp op) {
|
||||
// Get the scalar type (which is either directly the type of the operand
|
||||
// or the vector's/tensor's element type.
|
||||
auto srcType = getElementTypeOrSelf(op.getOperand()->getType());
|
||||
auto srcType = getElementTypeOrSelf(op.getOperand().getType());
|
||||
auto dstType = getElementTypeOrSelf(op.getType());
|
||||
|
||||
// For now, index is forbidden for the source and the destination type.
|
||||
|
@ -2054,7 +2052,7 @@ static LogicalResult verify(SignExtendIOp op) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, SplatOp op) {
|
||||
p << "splat " << *op.getOperand();
|
||||
p << "splat " << op.getOperand();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getType();
|
||||
}
|
||||
|
@ -2074,7 +2072,7 @@ static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
static LogicalResult verify(SplatOp op) {
|
||||
// TODO: we could replace this by a trait.
|
||||
if (op.getOperand()->getType() !=
|
||||
if (op.getOperand().getType() !=
|
||||
op.getType().cast<ShapedType>().getElementType())
|
||||
return op.emitError("operand should be of elemental type of result type");
|
||||
|
||||
|
@ -2103,8 +2101,8 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, StoreOp op) {
|
||||
p << "store " << *op.getValueToStore();
|
||||
p << ", " << *op.getMemRef() << '[' << op.getIndices() << ']';
|
||||
p << "store " << op.getValueToStore();
|
||||
p << ", " << op.getMemRef() << '[' << op.getIndices() << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getMemRefType();
|
||||
}
|
||||
|
@ -2130,7 +2128,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
static LogicalResult verify(StoreOp op) {
|
||||
// First operand must have same type as memref element type.
|
||||
if (op.getValueToStore()->getType() != op.getMemRefType().getElementType())
|
||||
if (op.getValueToStore().getType() != op.getMemRefType().getElementType())
|
||||
return op.emitOpError(
|
||||
"first operand must have same type memref element type");
|
||||
|
||||
|
@ -2251,9 +2249,9 @@ static Type getTensorTypeFromMemRefType(Builder &b, Type type) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, TensorLoadOp op) {
|
||||
p << "tensor_load " << *op.getOperand();
|
||||
p << "tensor_load " << op.getOperand();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getOperand()->getType();
|
||||
p << " : " << op.getOperand().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseTensorLoadOp(OpAsmParser &parser,
|
||||
|
@ -2274,9 +2272,9 @@ static ParseResult parseTensorLoadOp(OpAsmParser &parser,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, TensorStoreOp op) {
|
||||
p << "tensor_store " << *op.tensor() << ", " << *op.memref();
|
||||
p << "tensor_store " << op.tensor() << ", " << op.memref();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.memref()->getType();
|
||||
p << " : " << op.memref().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseTensorStoreOp(OpAsmParser &parser,
|
||||
|
@ -2298,7 +2296,7 @@ static ParseResult parseTensorStoreOp(OpAsmParser &parser,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(TruncateIOp op) {
|
||||
auto srcType = getElementTypeOrSelf(op.getOperand()->getType());
|
||||
auto srcType = getElementTypeOrSelf(op.getOperand().getType());
|
||||
auto dstType = getElementTypeOrSelf(op.getType());
|
||||
|
||||
if (srcType.isa<IndexType>())
|
||||
|
@ -2344,13 +2342,13 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, ViewOp op) {
|
||||
p << op.getOperationName() << ' ' << *op.getOperand(0) << '[';
|
||||
p << op.getOperationName() << ' ' << op.getOperand(0) << '[';
|
||||
auto dynamicOffset = op.getDynamicOffset();
|
||||
if (dynamicOffset != nullptr)
|
||||
p.printOperand(dynamicOffset);
|
||||
p << "][" << op.getDynamicSizes() << ']';
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
|
||||
p << " : " << op.getOperand(0).getType() << " to " << op.getType();
|
||||
}
|
||||
|
||||
Value ViewOp::getDynamicOffset() {
|
||||
|
@ -2382,8 +2380,8 @@ static LogicalResult verifyDynamicStrides(MemRefType memrefType,
|
|||
}
|
||||
|
||||
static LogicalResult verify(ViewOp op) {
|
||||
auto baseType = op.getOperand(0)->getType().cast<MemRefType>();
|
||||
auto viewType = op.getResult()->getType().cast<MemRefType>();
|
||||
auto baseType = op.getOperand(0).getType().cast<MemRefType>();
|
||||
auto viewType = op.getResult().getType().cast<MemRefType>();
|
||||
|
||||
// The base memref should have identity layout map (or none).
|
||||
if (baseType.getAffineMaps().size() > 1 ||
|
||||
|
@ -2453,7 +2451,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
int64_t newOffset = oldOffset;
|
||||
unsigned dynamicOffsetOperandCount = 0;
|
||||
if (dynamicOffset != nullptr) {
|
||||
auto *defOp = dynamicOffset->getDefiningOp();
|
||||
auto *defOp = dynamicOffset.getDefiningOp();
|
||||
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
|
||||
// Dynamic offset will be folded into the map.
|
||||
newOffset = constantIndexOp.getValue();
|
||||
|
@ -2478,7 +2476,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
|
|||
newShapeConstants.push_back(dimSize);
|
||||
continue;
|
||||
}
|
||||
auto *defOp = viewOp.getOperand(dynamicDimPos)->getDefiningOp();
|
||||
auto *defOp = viewOp.getOperand(dynamicDimPos).getDefiningOp();
|
||||
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
|
||||
// Dynamic shape dimension will be folded.
|
||||
newShapeConstants.push_back(constantIndexOp.getValue());
|
||||
|
@ -2590,7 +2588,7 @@ void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
|
|||
ValueRange strides, Type resultType,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
if (!resultType)
|
||||
resultType = inferSubViewResultType(source->getType().cast<MemRefType>());
|
||||
resultType = inferSubViewResultType(source.getType().cast<MemRefType>());
|
||||
auto segmentAttr = b->getI32VectorAttr(
|
||||
{1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()),
|
||||
static_cast<int32_t>(strides.size())});
|
||||
|
@ -2637,13 +2635,13 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, SubViewOp op) {
|
||||
p << op.getOperationName() << ' ' << *op.getOperand(0) << '[' << op.offsets()
|
||||
p << op.getOperationName() << ' ' << op.getOperand(0) << '[' << op.offsets()
|
||||
<< "][" << op.sizes() << "][" << op.strides() << ']';
|
||||
|
||||
SmallVector<StringRef, 1> elidedAttrs = {
|
||||
SubViewOp::getOperandSegmentSizeAttr()};
|
||||
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
|
||||
p << " : " << op.getOperand(0)->getType() << " to " << op.getType();
|
||||
p << " : " << op.getOperand(0).getType() << " to " << op.getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(SubViewOp op) {
|
||||
|
@ -2757,8 +2755,8 @@ static LogicalResult verify(SubViewOp op) {
|
|||
}
|
||||
|
||||
raw_ostream &mlir::operator<<(raw_ostream &os, SubViewOp::Range &range) {
|
||||
return os << "range " << *range.offset << ":" << *range.size << ":"
|
||||
<< *range.stride;
|
||||
return os << "range " << range.offset << ":" << range.size << ":"
|
||||
<< range.stride;
|
||||
}
|
||||
|
||||
SmallVector<SubViewOp::Range, 8> SubViewOp::getRanges() {
|
||||
|
@ -2827,7 +2825,7 @@ public:
|
|||
}
|
||||
SmallVector<int64_t, 4> staticShape(subViewOp.getNumSizes());
|
||||
for (auto size : llvm::enumerate(subViewOp.sizes())) {
|
||||
auto defOp = size.value()->getDefiningOp();
|
||||
auto defOp = size.value().getDefiningOp();
|
||||
assert(defOp);
|
||||
staticShape[size.index()] = cast<ConstantIndexOp>(defOp).getValue();
|
||||
}
|
||||
|
@ -2873,7 +2871,7 @@ public:
|
|||
|
||||
SmallVector<int64_t, 4> staticStrides(subViewOp.getNumStrides());
|
||||
for (auto stride : llvm::enumerate(subViewOp.strides())) {
|
||||
auto defOp = stride.value()->getDefiningOp();
|
||||
auto defOp = stride.value().getDefiningOp();
|
||||
assert(defOp);
|
||||
assert(baseStrides[stride.index()] > 0);
|
||||
staticStrides[stride.index()] =
|
||||
|
@ -2924,7 +2922,7 @@ public:
|
|||
|
||||
auto staticOffset = baseOffset;
|
||||
for (auto offset : llvm::enumerate(subViewOp.offsets())) {
|
||||
auto defOp = offset.value()->getDefiningOp();
|
||||
auto defOp = offset.value().getDefiningOp();
|
||||
assert(defOp);
|
||||
assert(baseStrides[offset.index()] > 0);
|
||||
staticOffset +=
|
||||
|
@ -2959,7 +2957,7 @@ void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(ZeroExtendIOp op) {
|
||||
auto srcType = getElementTypeOrSelf(op.getOperand()->getType());
|
||||
auto srcType = getElementTypeOrSelf(op.getOperand().getType());
|
||||
auto dstType = getElementTypeOrSelf(op.getType());
|
||||
|
||||
if (srcType.isa<IndexType>())
|
||||
|
|
|
@ -163,9 +163,9 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
|
|||
assert(op->getNumResults() == 1 &&
|
||||
"only support broadcast check on one result");
|
||||
|
||||
auto type1 = op->getOperand(0)->getType();
|
||||
auto type2 = op->getOperand(1)->getType();
|
||||
auto retType = op->getResult(0)->getType();
|
||||
auto type1 = op->getOperand(0).getType();
|
||||
auto type2 = op->getOperand(1).getType();
|
||||
auto retType = op->getResult(0).getType();
|
||||
|
||||
// We forbid broadcasting vector and tensor.
|
||||
if (hasBothVectorAndTensorType({type1, type2, retType}))
|
||||
|
|
|
@ -67,7 +67,7 @@ void vector::ContractionOp::build(Builder *builder, OperationState &result,
|
|||
ArrayAttr indexingMaps,
|
||||
ArrayAttr iteratorTypes) {
|
||||
result.addOperands({lhs, rhs, acc});
|
||||
result.addTypes(acc->getType());
|
||||
result.addTypes(acc.getType());
|
||||
result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
|
||||
result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
|
||||
}
|
||||
|
@ -125,13 +125,13 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
|
|||
attrs.push_back(attr);
|
||||
|
||||
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
|
||||
p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", ";
|
||||
p << *op.rhs() << ", " << *op.acc();
|
||||
p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
|
||||
p << op.rhs() << ", " << op.acc();
|
||||
if (op.masks().size() == 2)
|
||||
p << ", " << op.masks();
|
||||
|
||||
p.printOptionalAttrDict(op.getAttrs(), attrNames);
|
||||
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
|
||||
p << " : " << op.lhs().getType() << ", " << op.rhs().getType() << " into "
|
||||
<< op.getResultType();
|
||||
}
|
||||
|
||||
|
@ -211,7 +211,7 @@ static LogicalResult verify(ContractionOp op) {
|
|||
if (map.getNumDims() != numIterators)
|
||||
return op.emitOpError("expected indexing map ")
|
||||
<< index << " to have " << numIterators << " number of inputs";
|
||||
auto operandType = op.getOperand(index)->getType().cast<VectorType>();
|
||||
auto operandType = op.getOperand(index).getType().cast<VectorType>();
|
||||
unsigned rank = operandType.getShape().size();
|
||||
if (map.getNumResults() != rank)
|
||||
return op.emitOpError("expected indexing map ")
|
||||
|
@ -351,10 +351,10 @@ SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, vector::ExtractElementOp op) {
|
||||
p << op.getOperationName() << " " << *op.vector() << "[" << *op.position()
|
||||
<< " : " << op.position()->getType() << "]";
|
||||
p << op.getOperationName() << " " << op.vector() << "[" << op.position()
|
||||
<< " : " << op.position().getType() << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.vector()->getType();
|
||||
p << " : " << op.vector().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||
|
@ -398,15 +398,15 @@ void vector::ExtractOp::build(Builder *builder, OperationState &result,
|
|||
Value source, ArrayRef<int64_t> position) {
|
||||
result.addOperands(source);
|
||||
auto positionAttr = getVectorSubscriptAttr(*builder, position);
|
||||
result.addTypes(inferExtractOpResultType(source->getType().cast<VectorType>(),
|
||||
result.addTypes(inferExtractOpResultType(source.getType().cast<VectorType>(),
|
||||
positionAttr));
|
||||
result.addAttribute(getPositionAttrName(), positionAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, vector::ExtractOp op) {
|
||||
p << op.getOperationName() << " " << *op.vector() << op.position();
|
||||
p << op.getOperationName() << " " << op.vector() << op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {"position"});
|
||||
p << " : " << op.vector()->getType();
|
||||
p << " : " << op.vector().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) {
|
||||
|
@ -495,13 +495,13 @@ static ParseResult parseExtractSlicesOp(OpAsmParser &parser,
|
|||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, ExtractSlicesOp op) {
|
||||
p << op.getOperationName() << ' ' << *op.vector() << ", ";
|
||||
p << op.getOperationName() << ' ' << op.vector() << ", ";
|
||||
p << op.sizes() << ", " << op.strides();
|
||||
p.printOptionalAttrDict(
|
||||
op.getAttrs(),
|
||||
/*elidedAttrs=*/{ExtractSlicesOp::getSizesAttrName(),
|
||||
ExtractSlicesOp::getStridesAttrName()});
|
||||
p << " : " << op.vector()->getType();
|
||||
p << " : " << op.vector().getType();
|
||||
p << " into " << op.getResultTupleType();
|
||||
}
|
||||
|
||||
|
@ -594,7 +594,7 @@ void ExtractSlicesOp::getStrides(SmallVectorImpl<int64_t> &results) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, BroadcastOp op) {
|
||||
p << op.getOperationName() << " " << *op.source() << " : "
|
||||
p << op.getOperationName() << " " << op.source() << " : "
|
||||
<< op.getSourceType() << " to " << op.getVectorType();
|
||||
}
|
||||
|
||||
|
@ -642,15 +642,15 @@ void ShuffleOp::build(Builder *builder, OperationState &result, Value v1,
|
|||
Value v2, ArrayRef<int64_t> mask) {
|
||||
result.addOperands({v1, v2});
|
||||
auto maskAttr = getVectorSubscriptAttr(*builder, mask);
|
||||
result.addTypes(v1->getType());
|
||||
result.addTypes(v1.getType());
|
||||
result.addAttribute(getMaskAttrName(), maskAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, ShuffleOp op) {
|
||||
p << op.getOperationName() << " " << *op.v1() << ", " << *op.v2() << " "
|
||||
p << op.getOperationName() << " " << op.v1() << ", " << op.v2() << " "
|
||||
<< op.mask();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {ShuffleOp::getMaskAttrName()});
|
||||
p << " : " << op.v1()->getType() << ", " << op.v2()->getType();
|
||||
p << " : " << op.v1().getType() << ", " << op.v2().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(ShuffleOp op) {
|
||||
|
@ -725,10 +725,10 @@ static ParseResult parseShuffleOp(OpAsmParser &parser, OperationState &result) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, InsertElementOp op) {
|
||||
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest() << "["
|
||||
<< *op.position() << " : " << op.position()->getType() << "]";
|
||||
p << op.getOperationName() << " " << op.source() << ", " << op.dest() << "["
|
||||
<< op.position() << " : " << op.position().getType() << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.dest()->getType();
|
||||
p << " : " << op.dest().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseInsertElementOp(OpAsmParser &parser,
|
||||
|
@ -766,12 +766,12 @@ void InsertOp::build(Builder *builder, OperationState &result, Value source,
|
|||
Value dest, ArrayRef<int64_t> position) {
|
||||
result.addOperands({source, dest});
|
||||
auto positionAttr = getVectorSubscriptAttr(*builder, position);
|
||||
result.addTypes(dest->getType());
|
||||
result.addTypes(dest.getType());
|
||||
result.addAttribute(getPositionAttrName(), positionAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, InsertOp op) {
|
||||
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
|
||||
p << op.getOperationName() << " " << op.source() << ", " << op.dest()
|
||||
<< op.position();
|
||||
p.printOptionalAttrDict(op.getAttrs(), {InsertOp::getPositionAttrName()});
|
||||
p << " : " << op.getSourceType() << " into " << op.getDestVectorType();
|
||||
|
@ -851,13 +851,13 @@ static ParseResult parseInsertSlicesOp(OpAsmParser &parser,
|
|||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, InsertSlicesOp op) {
|
||||
p << op.getOperationName() << ' ' << *op.vectors() << ", ";
|
||||
p << op.getOperationName() << ' ' << op.vectors() << ", ";
|
||||
p << op.sizes() << ", " << op.strides();
|
||||
p.printOptionalAttrDict(
|
||||
op.getAttrs(),
|
||||
/*elidedAttrs=*/{InsertSlicesOp::getSizesAttrName(),
|
||||
InsertSlicesOp::getStridesAttrName()});
|
||||
p << " : " << op.vectors()->getType();
|
||||
p << " : " << op.vectors().getType();
|
||||
p << " into " << op.getResultVectorType();
|
||||
}
|
||||
|
||||
|
@ -890,14 +890,13 @@ void InsertStridedSliceOp::build(Builder *builder, OperationState &result,
|
|||
result.addOperands({source, dest});
|
||||
auto offsetsAttr = getVectorSubscriptAttr(*builder, offsets);
|
||||
auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
|
||||
result.addTypes(dest->getType());
|
||||
result.addTypes(dest.getType());
|
||||
result.addAttribute(getOffsetsAttrName(), offsetsAttr);
|
||||
result.addAttribute(getStridesAttrName(), stridesAttr);
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, InsertStridedSliceOp op) {
|
||||
p << op.getOperationName() << " " << *op.source() << ", " << *op.dest()
|
||||
<< " ";
|
||||
p << op.getOperationName() << " " << op.source() << ", " << op.dest() << " ";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getSourceVectorType() << " into " << op.getDestVectorType();
|
||||
}
|
||||
|
@ -1049,10 +1048,10 @@ static LogicalResult verify(InsertStridedSliceOp op) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, OuterProductOp op) {
|
||||
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
|
||||
p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
|
||||
if (!op.acc().empty())
|
||||
p << ", " << op.acc();
|
||||
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType();
|
||||
p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseOuterProductOp(OpAsmParser &parser,
|
||||
|
@ -1103,7 +1102,7 @@ static LogicalResult verify(OuterProductOp op) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, ReshapeOp op) {
|
||||
p << op.getOperationName() << " " << *op.vector() << ", [" << op.input_shape()
|
||||
p << op.getOperationName() << " " << op.vector() << ", [" << op.input_shape()
|
||||
<< "], [" << op.output_shape() << "], " << op.fixed_vector_sizes();
|
||||
SmallVector<StringRef, 2> elidedAttrs = {
|
||||
ReshapeOp::getOperandSegmentSizeAttr(),
|
||||
|
@ -1193,18 +1192,18 @@ static LogicalResult verify(ReshapeOp op) {
|
|||
// If all shape operands are produced by constant ops, verify that product
|
||||
// of dimensions for input/output shape match.
|
||||
auto isDefByConstant = [](Value operand) {
|
||||
return isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp());
|
||||
return isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
|
||||
};
|
||||
if (llvm::all_of(op.input_shape(), isDefByConstant) &&
|
||||
llvm::all_of(op.output_shape(), isDefByConstant)) {
|
||||
int64_t numInputElements = 1;
|
||||
for (auto operand : op.input_shape())
|
||||
numInputElements *=
|
||||
cast<ConstantIndexOp>(operand->getDefiningOp()).getValue();
|
||||
cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
|
||||
int64_t numOutputElements = 1;
|
||||
for (auto operand : op.output_shape())
|
||||
numOutputElements *=
|
||||
cast<ConstantIndexOp>(operand->getDefiningOp()).getValue();
|
||||
cast<ConstantIndexOp>(operand.getDefiningOp()).getValue();
|
||||
if (numInputElements != numOutputElements)
|
||||
return op.emitError("product of input and output shape sizes must match");
|
||||
}
|
||||
|
@ -1245,7 +1244,7 @@ void StridedSliceOp::build(Builder *builder, OperationState &result,
|
|||
auto sizesAttr = getVectorSubscriptAttr(*builder, sizes);
|
||||
auto stridesAttr = getVectorSubscriptAttr(*builder, strides);
|
||||
result.addTypes(
|
||||
inferStridedSliceOpResultType(source->getType().cast<VectorType>(),
|
||||
inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
|
||||
offsetsAttr, sizesAttr, stridesAttr));
|
||||
result.addAttribute(getOffsetsAttrName(), offsetsAttr);
|
||||
result.addAttribute(getSizesAttrName(), sizesAttr);
|
||||
|
@ -1253,9 +1252,9 @@ void StridedSliceOp::build(Builder *builder, OperationState &result,
|
|||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, StridedSliceOp op) {
|
||||
p << op.getOperationName() << " " << *op.vector();
|
||||
p << op.getOperationName() << " " << op.vector();
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.vector()->getType() << " to " << op.getResult()->getType();
|
||||
p << " : " << op.vector().getType() << " to " << op.getResult().getType();
|
||||
}
|
||||
|
||||
static ParseResult parseStridedSliceOp(OpAsmParser &parser,
|
||||
|
@ -1305,7 +1304,7 @@ static LogicalResult verify(StridedSliceOp op) {
|
|||
|
||||
auto resultType = inferStridedSliceOpResultType(
|
||||
op.getVectorType(), op.offsets(), op.sizes(), op.strides());
|
||||
if (op.getResult()->getType() != resultType) {
|
||||
if (op.getResult().getType() != resultType) {
|
||||
op.emitOpError("expected result type to be ") << resultType;
|
||||
return failure();
|
||||
}
|
||||
|
@ -1328,7 +1327,7 @@ public:
|
|||
PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp.
|
||||
auto defOp = stridedSliceOp.vector()->getDefiningOp();
|
||||
auto defOp = stridedSliceOp.vector().getDefiningOp();
|
||||
auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
|
||||
if (!constantMaskOp)
|
||||
return matchFailure();
|
||||
|
@ -1365,7 +1364,7 @@ public:
|
|||
|
||||
// Replace 'stridedSliceOp' with ConstantMaskOp with sliced mask region.
|
||||
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
|
||||
stridedSliceOp, stridedSliceOp.getResult()->getType(),
|
||||
stridedSliceOp, stridedSliceOp.getResult().getType(),
|
||||
vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
|
||||
return matchSuccess();
|
||||
}
|
||||
|
@ -1503,7 +1502,7 @@ static LogicalResult verify(TransferReadOp op) {
|
|||
// Consistency of elemental types in memref and vector.
|
||||
MemRefType memrefType = op.getMemRefType();
|
||||
VectorType vectorType = op.getVectorType();
|
||||
auto paddingType = op.padding()->getType();
|
||||
auto paddingType = op.padding().getType();
|
||||
auto permutationMap = op.permutation_map();
|
||||
auto memrefElementType = memrefType.getElementType();
|
||||
|
||||
|
@ -1540,8 +1539,8 @@ static LogicalResult verify(TransferReadOp op) {
|
|||
// TransferWriteOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
static void print(OpAsmPrinter &p, TransferWriteOp op) {
|
||||
p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref()
|
||||
<< "[" << op.indices() << "]";
|
||||
p << op.getOperationName() << " " << op.vector() << ", " << op.memref() << "["
|
||||
<< op.indices() << "]";
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.getVectorType() << ", " << op.getMemRefType();
|
||||
}
|
||||
|
@ -1596,12 +1595,12 @@ static MemRefType inferVectorTypeCastResultType(MemRefType t) {
|
|||
void TypeCastOp::build(Builder *builder, OperationState &result, Value source) {
|
||||
result.addOperands(source);
|
||||
result.addTypes(
|
||||
inferVectorTypeCastResultType(source->getType().cast<MemRefType>()));
|
||||
inferVectorTypeCastResultType(source.getType().cast<MemRefType>()));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, TypeCastOp op) {
|
||||
auto type = op.getOperand()->getType().cast<MemRefType>();
|
||||
p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to "
|
||||
auto type = op.getOperand().getType().cast<MemRefType>();
|
||||
p << op.getOperationName() << ' ' << op.memref() << " : " << type << " to "
|
||||
<< inferVectorTypeCastResultType(type);
|
||||
}
|
||||
|
||||
|
@ -1665,14 +1664,14 @@ static ParseResult parseTupleGetOp(OpAsmParser &parser,
|
|||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, TupleGetOp op) {
|
||||
p << op.getOperationName() << ' ' << *op.getOperand() << ", " << op.index();
|
||||
p << op.getOperationName() << ' ' << op.getOperand() << ", " << op.index();
|
||||
p.printOptionalAttrDict(op.getAttrs(),
|
||||
/*elidedAttrs=*/{TupleGetOp::getIndexAttrName()});
|
||||
p << " : " << op.getOperand()->getType();
|
||||
p << " : " << op.getOperand().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(TupleGetOp op) {
|
||||
auto tupleType = op.getOperand()->getType().cast<TupleType>();
|
||||
auto tupleType = op.getOperand().getType().cast<TupleType>();
|
||||
if (op.getIndex() < 0 ||
|
||||
op.getIndex() >= static_cast<int64_t>(tupleType.size()))
|
||||
return op.emitOpError("tuple get index out of range");
|
||||
|
@ -1696,12 +1695,12 @@ ParseResult parseConstantMaskOp(OpAsmParser &parser, OperationState &result) {
|
|||
|
||||
static void print(OpAsmPrinter &p, ConstantMaskOp op) {
|
||||
p << op.getOperationName() << ' ' << op.mask_dim_sizes() << " : "
|
||||
<< op.getResult()->getType();
|
||||
<< op.getResult().getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(ConstantMaskOp &op) {
|
||||
// Verify that array attr size matches the rank of the vector result.
|
||||
auto resultType = op.getResult()->getType().cast<VectorType>();
|
||||
auto resultType = op.getResult().getType().cast<VectorType>();
|
||||
if (static_cast<int64_t>(op.mask_dim_sizes().size()) != resultType.getRank())
|
||||
return op.emitOpError(
|
||||
"must specify array attr of size equal vector result rank");
|
||||
|
@ -1749,7 +1748,7 @@ static void print(OpAsmPrinter &p, CreateMaskOp op) {
|
|||
static LogicalResult verify(CreateMaskOp op) {
|
||||
// Verify that an operand was specified for each result vector each dimension.
|
||||
if (op.getNumOperands() !=
|
||||
op.getResult()->getType().cast<VectorType>().getRank())
|
||||
op.getResult().getType().cast<VectorType>().getRank())
|
||||
return op.emitOpError(
|
||||
"must specify an operand for each result vector dimension");
|
||||
return success();
|
||||
|
@ -1768,7 +1767,7 @@ ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, PrintOp op) {
|
||||
p << op.getOperationName() << ' ' << *op.source() << " : "
|
||||
p << op.getOperationName() << ' ' << op.source() << " : "
|
||||
<< op.getPrintType();
|
||||
}
|
||||
|
||||
|
@ -1783,19 +1782,19 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
// Return if any of 'createMaskOp' operands are not defined by a constant.
|
||||
auto is_not_def_by_constant = [](Value operand) {
|
||||
return !isa_and_nonnull<ConstantIndexOp>(operand->getDefiningOp());
|
||||
return !isa_and_nonnull<ConstantIndexOp>(operand.getDefiningOp());
|
||||
};
|
||||
if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant))
|
||||
return matchFailure();
|
||||
// Gather constant mask dimension sizes.
|
||||
SmallVector<int64_t, 4> maskDimSizes;
|
||||
for (auto operand : createMaskOp.operands()) {
|
||||
auto defOp = operand->getDefiningOp();
|
||||
auto defOp = operand.getDefiningOp();
|
||||
maskDimSizes.push_back(cast<ConstantIndexOp>(defOp).getValue());
|
||||
}
|
||||
// Replace 'createMaskOp' with ConstantMaskOp.
|
||||
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
|
||||
createMaskOp, createMaskOp.getResult()->getType(),
|
||||
createMaskOp, createMaskOp.getResult().getType(),
|
||||
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
|
||||
return matchSuccess();
|
||||
}
|
||||
|
|
|
@ -195,7 +195,7 @@ static void initUnrolledVectorState(VectorType vectorType, Value initValue,
|
|||
auto tupleType =
|
||||
generateExtractSlicesOpResultType(vectorType, sizes, strides, builder);
|
||||
state.slicesTuple = builder.create<vector::ExtractSlicesOp>(
|
||||
initValue->getLoc(), tupleType, initValue, sizes, strides);
|
||||
initValue.getLoc(), tupleType, initValue, sizes, strides);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -232,7 +232,7 @@ static Value getOrCreateUnrolledVectorSlice(
|
|||
if (valueSlice == nullptr) {
|
||||
// Return tuple element at 'sliceLinearIndex'.
|
||||
auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
|
||||
auto initValueType = initValue->getType().cast<VectorType>();
|
||||
auto initValueType = initValue.getType().cast<VectorType>();
|
||||
auto vectorType =
|
||||
VectorType::get(state.unrolledShape, initValueType.getElementType());
|
||||
// Initialize 'cache' with slice from 'initValue'.
|
||||
|
@ -311,7 +311,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
|
|||
unsigned resultIndex,
|
||||
ArrayRef<int64_t> targetShape,
|
||||
PatternRewriter &builder) {
|
||||
auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
|
||||
auto shapedType = op->getResult(0).getType().dyn_cast_or_null<ShapedType>();
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
assert(false && "Expected a statically shaped result type");
|
||||
|
||||
|
@ -379,7 +379,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
|
|||
SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
|
||||
SmallVector<Value, 4> vectorTupleValues(resultValueState.numInstances);
|
||||
for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
|
||||
vectorTupleTypes[i] = caches[resultIndex][i]->getType().cast<VectorType>();
|
||||
vectorTupleTypes[i] = caches[resultIndex][i].getType().cast<VectorType>();
|
||||
vectorTupleValues[i] = caches[resultIndex][i];
|
||||
}
|
||||
TupleType tupleType = builder.getTupleType(vectorTupleTypes);
|
||||
|
@ -387,7 +387,7 @@ static Value unrollSingleResultStructuredOp(Operation *op,
|
|||
vectorTupleValues);
|
||||
|
||||
// Create InsertSlicesOp(Tuple(result_vectors)).
|
||||
auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
|
||||
auto resultVectorType = op->getResult(0).getType().cast<VectorType>();
|
||||
SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
|
||||
SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
|
||||
|
||||
|
@ -411,7 +411,7 @@ static void getVectorContractionOpUnrollState(
|
|||
vectors.resize(numIterators);
|
||||
unsigned accOperandIndex = vector::ContractionOp::getAccOperandIndex();
|
||||
for (unsigned i = 0; i < numIterators; ++i) {
|
||||
vectors[i].type = contractionOp.getOperand(i)->getType().cast<VectorType>();
|
||||
vectors[i].type = contractionOp.getOperand(i).getType().cast<VectorType>();
|
||||
vectors[i].indexMap = iterationIndexMapList[i];
|
||||
vectors[i].operandIndex = i;
|
||||
vectors[i].isAcc = i == accOperandIndex ? true : false;
|
||||
|
@ -437,7 +437,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
|
|||
std::vector<VectorState> &vectors,
|
||||
unsigned &resultIndex) {
|
||||
// Verify that operation and operands all have the same vector shape.
|
||||
auto resultType = op->getResult(0)->getType().dyn_cast_or_null<VectorType>();
|
||||
auto resultType = op->getResult(0).getType().dyn_cast_or_null<VectorType>();
|
||||
assert(resultType && "Expected op with vector result type");
|
||||
auto resultShape = resultType.getShape();
|
||||
// Verify that all operands have the same vector type as result.
|
||||
|
@ -515,7 +515,7 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
|
|||
getAffineConstantExpr(offsets[it.index()], ctx);
|
||||
auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
|
||||
sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(
|
||||
it.value()->getLoc(), map, ArrayRef<Value>(it.value()));
|
||||
it.value().getLoc(), map, ArrayRef<Value>(it.value()));
|
||||
}
|
||||
// Call 'fn' to generate slice 'i' at 'sliceIndices'.
|
||||
fn(i, sliceIndices);
|
||||
|
@ -536,8 +536,8 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
|
|||
// Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
|
||||
Value xferReadResult = xferReadOp.getResult();
|
||||
auto extractSlicesOp =
|
||||
dyn_cast<vector::ExtractSlicesOp>(*xferReadResult->getUsers().begin());
|
||||
if (!xferReadResult->hasOneUse() || !extractSlicesOp)
|
||||
dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
|
||||
if (!xferReadResult.hasOneUse() || !extractSlicesOp)
|
||||
return matchFailure();
|
||||
|
||||
// Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
|
||||
|
@ -587,14 +587,14 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
|
|||
if (!xferWriteOp.permutation_map().isIdentity())
|
||||
return matchFailure();
|
||||
// Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
|
||||
auto *vectorDefOp = xferWriteOp.vector()->getDefiningOp();
|
||||
auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
|
||||
auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
|
||||
if (!insertSlicesOp)
|
||||
return matchFailure();
|
||||
|
||||
// Get TupleOp operand of 'insertSlicesOp'.
|
||||
auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
|
||||
insertSlicesOp.vectors()->getDefiningOp());
|
||||
insertSlicesOp.vectors().getDefiningOp());
|
||||
if (!tupleOp)
|
||||
return matchFailure();
|
||||
|
||||
|
@ -634,19 +634,19 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
|
|||
PatternRewriter &rewriter) const override {
|
||||
// Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
|
||||
auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
|
||||
tupleGetOp.vectors()->getDefiningOp());
|
||||
tupleGetOp.vectors().getDefiningOp());
|
||||
if (!extractSlicesOp)
|
||||
return matchFailure();
|
||||
|
||||
// Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
|
||||
auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
|
||||
extractSlicesOp.vector()->getDefiningOp());
|
||||
extractSlicesOp.vector().getDefiningOp());
|
||||
if (!insertSlicesOp)
|
||||
return matchFailure();
|
||||
|
||||
// Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
|
||||
auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
|
||||
insertSlicesOp.vectors()->getDefiningOp());
|
||||
insertSlicesOp.vectors().getDefiningOp());
|
||||
if (!tupleOp)
|
||||
return matchFailure();
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ mlir::edsc::ValueHandle::ValueHandle(index_t cst) {
|
|||
auto &b = ScopedContext::getBuilder();
|
||||
auto loc = ScopedContext::getLocation();
|
||||
v = b.create<ConstantIndexOp>(loc, cst.v).getResult();
|
||||
t = v->getType();
|
||||
t = v.getType();
|
||||
}
|
||||
|
||||
ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) {
|
||||
|
@ -139,8 +139,8 @@ static Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
|
|||
if (lbs.size() != 1 || ubs.size() != 1)
|
||||
return Optional<ValueHandle>();
|
||||
|
||||
auto *lbDef = lbs.front().getValue()->getDefiningOp();
|
||||
auto *ubDef = ubs.front().getValue()->getDefiningOp();
|
||||
auto *lbDef = lbs.front().getValue().getDefiningOp();
|
||||
auto *ubDef = ubs.front().getValue().getDefiningOp();
|
||||
if (!lbDef || !ubDef)
|
||||
return Optional<ValueHandle>();
|
||||
|
||||
|
@ -305,7 +305,7 @@ categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
|
|||
unsigned &numSymbols) {
|
||||
AffineExpr d;
|
||||
Value resultVal = nullptr;
|
||||
if (auto constant = dyn_cast_or_null<ConstantIndexOp>(val->getDefiningOp())) {
|
||||
if (auto constant = dyn_cast_or_null<ConstantIndexOp>(val.getDefiningOp())) {
|
||||
d = getAffineConstantExpr(constant.getValue(), context);
|
||||
} else if (isValidSymbol(val) && !isValidDim(val)) {
|
||||
d = getAffineSymbolExpr(numSymbols++, context);
|
||||
|
@ -344,8 +344,8 @@ template <typename IOp, typename FOp>
|
|||
static ValueHandle createBinaryHandle(
|
||||
ValueHandle lhs, ValueHandle rhs,
|
||||
function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
|
||||
auto thisType = lhs.getValue()->getType();
|
||||
auto thatType = rhs.getValue()->getType();
|
||||
auto thisType = lhs.getValue().getType();
|
||||
auto thatType = rhs.getValue().getType();
|
||||
assert(thisType == thatType && "cannot mix types in operators");
|
||||
(void)thisType;
|
||||
(void)thatType;
|
||||
|
|
|
@ -14,7 +14,7 @@ using namespace mlir;
|
|||
using namespace mlir::edsc;
|
||||
|
||||
static SmallVector<ValueHandle, 8> getMemRefSizes(Value memRef) {
|
||||
MemRefType memRefType = memRef->getType().cast<MemRefType>();
|
||||
MemRefType memRefType = memRef.getType().cast<MemRefType>();
|
||||
assert(isStrided(memRefType) && "Expected strided MemRef type");
|
||||
|
||||
SmallVector<ValueHandle, 8> res;
|
||||
|
@ -31,7 +31,7 @@ static SmallVector<ValueHandle, 8> getMemRefSizes(Value memRef) {
|
|||
}
|
||||
|
||||
mlir::edsc::MemRefView::MemRefView(Value v) : base(v) {
|
||||
assert(v->getType().isa<MemRefType>() && "MemRefType expected");
|
||||
assert(v.getType().isa<MemRefType>() && "MemRefType expected");
|
||||
|
||||
auto memrefSizeValues = getMemRefSizes(v);
|
||||
for (auto &size : memrefSizeValues) {
|
||||
|
@ -42,7 +42,7 @@ mlir::edsc::MemRefView::MemRefView(Value v) : base(v) {
|
|||
}
|
||||
|
||||
mlir::edsc::VectorView::VectorView(Value v) : base(v) {
|
||||
auto vectorType = v->getType().cast<VectorType>();
|
||||
auto vectorType = v.getType().cast<VectorType>();
|
||||
|
||||
for (auto s : vectorType.getShape()) {
|
||||
lbs.push_back(static_cast<index_t>(0));
|
||||
|
|
|
@ -405,7 +405,7 @@ void AliasState::visitOperation(Operation *op) {
|
|||
for (auto ®ion : op->getRegions())
|
||||
for (auto &block : region)
|
||||
for (auto arg : block.getArguments())
|
||||
visitType(arg->getType());
|
||||
visitType(arg.getType());
|
||||
|
||||
// Visit each of the attributes.
|
||||
for (auto elt : op->getAttrs())
|
||||
|
@ -615,7 +615,7 @@ void SSANameState::numberValuesInBlock(
|
|||
DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
|
||||
auto setArgNameFn = [&](Value arg, StringRef name) {
|
||||
assert(!valueIDs.count(arg) && "arg numbered multiple times");
|
||||
assert(arg.cast<BlockArgument>()->getOwner() == &block &&
|
||||
assert(arg.cast<BlockArgument>().getOwner() == &block &&
|
||||
"arg not defined in 'block'");
|
||||
setValueName(arg, name);
|
||||
};
|
||||
|
@ -659,11 +659,11 @@ void SSANameState::numberValuesInOp(
|
|||
SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
|
||||
auto setResultNameFn = [&](Value result, StringRef name) {
|
||||
assert(!valueIDs.count(result) && "result numbered multiple times");
|
||||
assert(result->getDefiningOp() == &op && "result not defined by 'op'");
|
||||
assert(result.getDefiningOp() == &op && "result not defined by 'op'");
|
||||
setValueName(result, name);
|
||||
|
||||
// Record the result number for groups not anchored at 0.
|
||||
if (int resultNo = result.cast<OpResult>()->getResultNumber())
|
||||
if (int resultNo = result.cast<OpResult>().getResultNumber())
|
||||
resultGroups.push_back(resultNo);
|
||||
};
|
||||
if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
|
||||
|
@ -684,10 +684,10 @@ void SSANameState::numberValuesInOp(
|
|||
|
||||
void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
|
||||
Optional<int> &lookupResultNo) const {
|
||||
Operation *owner = result->getOwner();
|
||||
Operation *owner = result.getOwner();
|
||||
if (owner->getNumResults() == 1)
|
||||
return;
|
||||
int resultNo = result->getResultNumber();
|
||||
int resultNo = result.getResultNumber();
|
||||
|
||||
// If this operation has multiple result groups, we will need to find the
|
||||
// one corresponding to this result.
|
||||
|
@ -2009,7 +2009,7 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
|
|||
interleaveComma(block->getArguments(), [&](BlockArgument arg) {
|
||||
printValueID(arg);
|
||||
os << ": ";
|
||||
printType(arg->getType());
|
||||
printType(arg.getType());
|
||||
});
|
||||
os << ')';
|
||||
}
|
||||
|
@ -2068,7 +2068,7 @@ void OperationPrinter::printSuccessorAndUseList(Operation *term,
|
|||
[this](Value operand) { printValueID(operand); });
|
||||
os << " : ";
|
||||
interleaveComma(succOperands,
|
||||
[this](Value operand) { printType(operand->getType()); });
|
||||
[this](Value operand) { printType(operand.getType()); });
|
||||
os << ')';
|
||||
}
|
||||
|
||||
|
|
|
@ -91,7 +91,7 @@ void Block::dropAllReferences() {
|
|||
|
||||
void Block::dropAllDefinedValueUses() {
|
||||
for (auto arg : getArguments())
|
||||
arg->dropAllUses();
|
||||
arg.dropAllUses();
|
||||
for (auto &op : *this)
|
||||
op.dropAllDefinedValueUses();
|
||||
dropAllUses();
|
||||
|
|
|
@ -377,7 +377,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
|
|||
// Ask the dialect to materialize a constant operation for this value.
|
||||
Attribute attr = it.value().get<Attribute>();
|
||||
auto *constOp = dialect->materializeConstant(
|
||||
cstBuilder, attr, op->getResult(it.index())->getType(), op->getLoc());
|
||||
cstBuilder, attr, op->getResult(it.index()).getType(), op->getLoc());
|
||||
if (!constOp) {
|
||||
// Erase any generated constants.
|
||||
for (Operation *cst : generatedConstants)
|
||||
|
|
|
@ -96,9 +96,9 @@ LogicalResult FuncOp::verify() {
|
|||
auto fnInputTypes = getType().getInputs();
|
||||
Block &entryBlock = front();
|
||||
for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i)
|
||||
if (fnInputTypes[i] != entryBlock.getArgument(i)->getType())
|
||||
if (fnInputTypes[i] != entryBlock.getArgument(i).getType())
|
||||
return emitOpError("type of entry block argument #")
|
||||
<< i << '(' << entryBlock.getArgument(i)->getType()
|
||||
<< i << '(' << entryBlock.getArgument(i).getType()
|
||||
<< ") must match the type of the corresponding argument in "
|
||||
<< "function signature(" << fnInputTypes[i] << ')';
|
||||
|
||||
|
|
|
@ -809,7 +809,7 @@ LogicalResult OpTrait::impl::verifySameTypeOperands(Operation *op) {
|
|||
if (nOperands < 2)
|
||||
return success();
|
||||
|
||||
auto type = op->getOperand(0)->getType();
|
||||
auto type = op->getOperand(0).getType();
|
||||
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1))
|
||||
if (opType != type)
|
||||
return op->emitOpError() << "requires all operands to have the same type";
|
||||
|
@ -847,7 +847,7 @@ LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
|
|||
if (failed(verifyAtLeastNOperands(op, 1)))
|
||||
return failure();
|
||||
|
||||
auto type = op->getOperand(0)->getType();
|
||||
auto type = op->getOperand(0).getType();
|
||||
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
|
||||
if (failed(verifyCompatibleShape(opType, type)))
|
||||
return op->emitOpError() << "requires the same shape for all operands";
|
||||
|
@ -860,7 +860,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
|
|||
failed(verifyAtLeastNResults(op, 1)))
|
||||
return failure();
|
||||
|
||||
auto type = op->getOperand(0)->getType();
|
||||
auto type = op->getOperand(0).getType();
|
||||
for (auto resultType : op->getResultTypes()) {
|
||||
if (failed(verifyCompatibleShape(resultType, type)))
|
||||
return op->emitOpError()
|
||||
|
@ -917,7 +917,7 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
|
|||
failed(verifyAtLeastNResults(op, 1)))
|
||||
return failure();
|
||||
|
||||
auto type = op->getResult(0)->getType();
|
||||
auto type = op->getResult(0).getType();
|
||||
auto elementType = getElementTypeOrSelf(type);
|
||||
for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
|
||||
if (getElementTypeOrSelf(resultType) != elementType ||
|
||||
|
@ -946,7 +946,7 @@ static LogicalResult verifySuccessor(Operation *op, unsigned succNo) {
|
|||
|
||||
auto operandIt = operands.begin();
|
||||
for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) {
|
||||
if ((*operandIt)->getType() != destBB->getArgument(i)->getType())
|
||||
if ((*operandIt).getType() != destBB->getArgument(i).getType())
|
||||
return op->emitError() << "type mismatch for bb argument #" << i
|
||||
<< " of successor #" << succNo;
|
||||
}
|
||||
|
@ -1056,9 +1056,9 @@ LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op,
|
|||
|
||||
void impl::buildBinaryOp(Builder *builder, OperationState &result, Value lhs,
|
||||
Value rhs) {
|
||||
assert(lhs->getType() == rhs->getType());
|
||||
assert(lhs.getType() == rhs.getType());
|
||||
result.addOperands({lhs, rhs});
|
||||
result.types.push_back(lhs->getType());
|
||||
result.types.push_back(lhs.getType());
|
||||
}
|
||||
|
||||
ParseResult impl::parseOneResultSameOperandTypeOp(OpAsmParser &parser,
|
||||
|
@ -1077,7 +1077,7 @@ void impl::printOneResultOp(Operation *op, OpAsmPrinter &p) {
|
|||
|
||||
// If not all the operand and result types are the same, just use the
|
||||
// generic assembly form to avoid omitting information in printing.
|
||||
auto resultType = op->getResult(0)->getType();
|
||||
auto resultType = op->getResult(0).getType();
|
||||
if (llvm::any_of(op->getOperandTypes(),
|
||||
[&](Type type) { return type != resultType; })) {
|
||||
p.printGenericOp(op);
|
||||
|
@ -1113,15 +1113,15 @@ ParseResult impl::parseCastOp(OpAsmParser &parser, OperationState &result) {
|
|||
}
|
||||
|
||||
void impl::printCastOp(Operation *op, OpAsmPrinter &p) {
|
||||
p << op->getName() << ' ' << *op->getOperand(0);
|
||||
p << op->getName() << ' ' << op->getOperand(0);
|
||||
p.printOptionalAttrDict(op->getAttrs());
|
||||
p << " : " << op->getOperand(0)->getType() << " to "
|
||||
<< op->getResult(0)->getType();
|
||||
p << " : " << op->getOperand(0).getType() << " to "
|
||||
<< op->getResult(0).getType();
|
||||
}
|
||||
|
||||
Value impl::foldCastOp(Operation *op) {
|
||||
// Identity cast
|
||||
if (op->getOperand(0)->getType() == op->getResult(0)->getType())
|
||||
if (op->getOperand(0).getType() == op->getResult(0).getType())
|
||||
return op->getOperand(0);
|
||||
return nullptr;
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue