[mlir] Remove support for non-prefixed accessors

This finishes off a year long pursuit to LLVMify the generated
operation accessors, prefixing them with get/set. Support for
any other accessor naming is fully removed after this commit.

https://discourse.llvm.org/t/psa-raw-accessors-are-being-removed/65629

Differential Revision: https://reviews.llvm.org/D136727
This commit is contained in:
River Riddle 2022-10-25 18:29:53 -07:00
parent f9048cc131
commit b74192b7ae
23 changed files with 251 additions and 360 deletions

View File

@ -371,7 +371,7 @@ public:
}
patterns.insert<ReturnOpConversion>(context, newArg);
target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
[](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
[](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
assert(func.getFunctionType() ==
getNewFunctionType(funcTy, shouldBoxResult));
} else {

View File

@ -108,6 +108,29 @@ are included, you may want to specify a full namespace path or a partial one. In
to use full namespaces whenever you can. This makes it easier for dialects within different namespaces,
and projects, to interact with each other.
### C++ Accessor Generation
When generating accessors for dialects and their components (attributes, operations, types, etc.),
we prefix the name with `get` and `set` respectively, and transform `snake_style` names to camel
case (`UpperCamel` when prefixed, and `lowerCamel` for individual variable names). For example, if an
operation were defined as:
```tablegen
def MyOp : MyDialect<"op"> {
let arguments = (ins StrAttr:$value, StrAttr:$other_value);
}
```
It would have accessors generated for the `value` and `other_value` attributes as follows:
```c++
StringAttr MyOp::getValue();
void MyOp::setValue(StringAttr newValue);
StringAttr MyOp::getOtherValue();
void MyOp::setOtherValue(StringAttr newValue);
```
### Dependent Dialects
MLIR has a very large ecosystem, and contains dialects that server many different purposes. It
@ -279,59 +302,6 @@ void MyDialect::getCanonicalizationPatterns(RewritePatternSet &results) const;
See the documentation for [Canonicalization in MLIR](Canonicalization.md) for a much more
detailed description about canonicalization patterns.
### C++ Accessor Prefix
Historically, MLIR has generated accessors for operation components (such as attribute, operands,
results) using the tablegen definition name verbatim. This means that if an operation was defined
as:
```tablegen
def MyOp : MyDialect<"op"> {
let arguments = (ins StrAttr:$value, StrAttr:$other_value);
}
```
It would have accessors generated for the `value` and `other_value` attributes as follows:
```c++
StringAttr MyOp::value();
void MyOp::value(StringAttr newValue);
StringAttr MyOp::other_value();
void MyOp::other_value(StringAttr newValue);
```
Since then, we have decided to move accessors over to a style that matches the rest of the
code base. More specifically, this means that we prefix accessors with `get` and `set`
respectively, and transform `snake_style` names to camel case (`UpperCamel` when prefixed,
and `lowerCamel` for individual variable names). If we look at the same example as above, this
would produce:
```c++
StringAttr MyOp::getValue();
void MyOp::setValue(StringAttr newValue);
StringAttr MyOp::getOtherValue();
void MyOp::setOtherValue(StringAttr newValue);
```
The form in which accessors are generated is controlled by the `emitAccessorPrefix` field.
This field may any of the following values:
* `kEmitAccessorPrefix_Raw`
- Don't emit any `get`/`set` prefix.
* `kEmitAccessorPrefix_Prefixed`
- Only emit with `get`/`set` prefix.
* `kEmitAccessorPrefix_Both`
- Emit with **and** without prefix.
All new dialects are strongly encouraged to use the default `kEmitAccessorPrefix_Prefixed`
value, as the `Raw` form is deprecated and in the process of being removed.
Note: Remove this section when all dialects have been switched to the new accessor form.
## Defining an Extensible dialect
This section documents the design and API of the extensible dialects. Extensible

View File

@ -560,14 +560,14 @@ class AffineMinMaxOpBase<string mnemonic, list<Trait> traits = []> :
let extraClassDeclaration = [{
static StringRef getMapAttrStrName() { return "map"; }
AffineMap getAffineMap() { return getMap(); }
ValueRange getMapOperands() { return operands(); }
ValueRange getMapOperands() { return getOperands(); }
ValueRange getDimOperands() {
return OperandRange{operands().begin(),
operands().begin() + getMap().getNumDims()};
return OperandRange{getOperands().begin(),
getOperands().begin() + getMap().getNumDims()};
}
ValueRange getSymbolOperands() {
return OperandRange{operands().begin() + getMap().getNumDims(),
operands().end()};
return OperandRange{getOperands().begin() + getMap().getNumDims(),
getOperands().end()};
}
}];
let hasCustomAssemblyFormat = 1;

View File

@ -17,11 +17,6 @@
// Dialect definitions
//===----------------------------------------------------------------------===//
// "Enum" values for emitAccessorPrefix of Dialect.
defvar kEmitAccessorPrefix_Raw = 0; // Don't emit any getter/setter prefix.
defvar kEmitAccessorPrefix_Prefixed = 1; // Only emit with getter/setter prefix.
defvar kEmitAccessorPrefix_Both = 2; // Emit without and with prefix.
class Dialect {
// The name of the dialect.
string name = ?;
@ -88,17 +83,6 @@ class Dialect {
// If this dialect overrides the hook for canonicalization patterns.
bit hasCanonicalizer = 0;
// Whether to emit raw/with no prefix or format changes, or emit with
// accessor with prefix only and UpperCamel suffix or to emit accessors with
// both.
//
// If emitting with prefix is specified then the attribute/operand's
// name is converted to UpperCamel from snake_case (which would result in
// leaving UpperCamel unchanged while also converting lowerCamel to
// UpperCamel) and prefixed with `get` or `set` depending on if it is a getter
// or setter.
int emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
// If this dialect can be extended at runtime with new operations or types.
bit isExtensible = 0;
}

View File

@ -98,10 +98,6 @@ public:
// Returns whether the dialect is defined.
explicit operator bool() const { return def != nullptr; }
// Returns how the accessors should be prefixed in dialect.
enum class EmitPrefix { Raw = 0, Prefixed = 1, Both = 2 };
EmitPrefix getEmitAccessorPrefix() const;
private:
const llvm::Record *def;
std::vector<StringRef> dependentDialects;

View File

@ -302,16 +302,11 @@ public:
// Returns the builders of this operation.
ArrayRef<Builder> getBuilders() const { return builders; }
// Returns the preferred getter name for the accessor.
std::string getGetterName(StringRef name) const {
return getGetterNames(name).front();
}
// Returns the getter name for the accessor of `name`.
std::string getGetterName(StringRef name) const;
// Returns the getter names for the accessor.
SmallVector<std::string, 2> getGetterNames(StringRef name) const;
// Returns the setter names for the accessor.
SmallVector<std::string, 2> getSetterNames(StringRef name) const;
// Returns the setter name for the accessor of `name`.
std::string getSetterName(StringRef name) const;
private:
// Populates the vectors containing operands, attributes, results and traits.

View File

@ -103,7 +103,7 @@ public:
LogicalResult matchAndRewrite(AffineMinOp op,
PatternRewriter &rewriter) const override {
Value reduced =
lowerAffineMapMin(rewriter, op.getLoc(), op.getMap(), op.operands());
lowerAffineMapMin(rewriter, op.getLoc(), op.getMap(), op.getOperands());
if (!reduced)
return failure();
@ -119,7 +119,7 @@ public:
LogicalResult matchAndRewrite(AffineMaxOp op,
PatternRewriter &rewriter) const override {
Value reduced =
lowerAffineMapMax(rewriter, op.getLoc(), op.getMap(), op.operands());
lowerAffineMapMax(rewriter, op.getLoc(), op.getMap(), op.getOperands());
if (!reduced)
return failure();
@ -141,7 +141,7 @@ public:
rewriter.replaceOpWithNewOp<scf::YieldOp>(op);
return success();
}
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.operands());
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.getOperands());
return success();
}
};

View File

@ -536,7 +536,7 @@ static bool isGpuAsyncTokenType(Value value) {
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
async::YieldOp yieldOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
if (llvm::none_of(yieldOp.getOperands(), isGpuAsyncTokenType))
return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
Location loc = yieldOp.getLoc();

View File

@ -54,7 +54,7 @@ LogicalResult YieldOp::verify() {
MutableOperandRange
YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
return operandsMutable();
return getOperandsMutable();
}
//===----------------------------------------------------------------------===//

View File

@ -82,7 +82,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
SmallVector<Value> newReturnValues;
BitVector erasedResultIndices(funcOp.getFunctionType().getNumResults());
DenseMap<int64_t, int64_t> resultToArgs;
for (const auto &it : llvm::enumerate(returnOp.operands())) {
for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
bool erased = false;
for (BlockArgument bbArg : funcOp.getArguments()) {
Value val = it.value();
@ -105,7 +105,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
// Update function.
funcOp.eraseResults(erasedResultIndices);
returnOp.operandsMutable().assign(newReturnValues);
returnOp.getOperandsMutable().assign(newReturnValues);
// Update function calls.
module.walk([&](func::CallOp callOp) {
@ -114,7 +114,7 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) {
rewriter.setInsertionPoint(callOp);
auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp,
callOp.operands());
callOp.getOperands());
SmallVector<Value> newResults;
int64_t nextResult = 0;
for (int64_t i = 0; i < callOp.getNumResults(); ++i) {

View File

@ -484,7 +484,7 @@ struct FuncOpInterface
}
// 3. Rewrite the terminator without the in-place bufferizable values.
returnOp.operandsMutable().assign(returnValues);
returnOp.getOperandsMutable().assign(returnValues);
// 4. Rewrite the FuncOp type to buffer form.
funcOp.setType(FunctionType::get(op->getContext(), argTypes,

View File

@ -1066,14 +1066,14 @@ LogicalResult gpu::ReturnOp::verify() {
FunctionType funType = function.getFunctionType();
if (funType.getNumResults() != operands().size())
if (funType.getNumResults() != getOperands().size())
return emitOpError()
.append("expected ", funType.getNumResults(), " result operands")
.attachNote(function.getLoc())
.append("return type declared here");
for (const auto &pair : llvm::enumerate(
llvm::zip(function.getFunctionType().getResults(), operands()))) {
llvm::zip(function.getFunctionType().getResults(), getOperands()))) {
auto [type, operand] = pair.value();
if (type != operand.getType())
return emitOpError() << "unexpected type `" << operand.getType()

View File

@ -1091,7 +1091,7 @@ struct PadOpVectorizationWithTransferWritePattern
auto minOp1 = v1.getDefiningOp<AffineMinOp>();
auto minOp2 = v2.getDefiningOp<AffineMinOp>();
if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
minOp1.operands() == minOp2.operands())
minOp1.getOperands() == minOp2.getOperands())
continue;
// Add additional cases as needed.

View File

@ -192,8 +192,8 @@ struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
return failure();
};
return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
op.operands(), IsMin, loopMatcher);
return scf::canonicalizeMinMaxOpInLoop(
rewriter, op, op.getAffineMap(), op.getOperands(), IsMin, loopMatcher);
}
};

View File

@ -167,14 +167,14 @@ static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
forOp.walk([&](OpTy affineOp) {
AffineMap map = affineOp.getAffineMap();
(void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
affineOp.operands(), IsMin, mainIv,
affineOp.getOperands(), IsMin, mainIv,
previousUb, step,
/*insideLoop=*/true);
});
partialIteration.walk([&](OpTy affineOp) {
AffineMap map = affineOp.getAffineMap();
(void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
affineOp.operands(), IsMin, partialIv,
affineOp.getOperands(), IsMin, partialIv,
previousUb, step, /*insideLoop=*/false);
});
}

View File

@ -67,7 +67,7 @@ struct AssumingOpInterface
assumingOp.getDoRegion().front().getTerminator());
// Create new op and move over region.
TypeRange newResultTypes(yieldOp.operands());
TypeRange newResultTypes(yieldOp.getOperands());
auto newOp = rewriter.create<shape::AssumingOp>(
op->getLoc(), newResultTypes, assumingOp.getWitness());
newOp.getDoRegion().takeBody(assumingOp.getRegion());
@ -130,7 +130,7 @@ struct AssumingYieldOpInterface
const BufferizationOptions &options) const {
auto yieldOp = cast<shape::AssumingYieldOp>(op);
SmallVector<Value> newResults;
for (Value value : yieldOp.operands()) {
for (Value value : yieldOp.getOperands()) {
if (value.getType().isa<TensorType>()) {
FailureOr<Value> buffer = getBuffer(rewriter, value, options);
if (failed(buffer))

View File

@ -5187,7 +5187,7 @@ public:
auto isNotDefByConstant = [](Value operand) {
return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
};
if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
return failure();
// CreateMaskOp for scalable vectors can be folded only if all dimensions
@ -5206,7 +5206,7 @@ public:
SmallVector<int64_t, 4> maskDimSizes;
maskDimSizes.reserve(createMaskOp->getNumOperands());
for (auto [operand, maxDimSize] : llvm::zip_equal(
createMaskOp.operands(), createMaskOp.getType().getShape())) {
createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
Operation *defOp = operand.getDefiningOp();
int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
dimSize = std::min(dimSize, maxDimSize);

View File

@ -183,7 +183,7 @@ static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
rewriter.updateRootInPlace(
yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
return newWarpOp;
}
@ -349,7 +349,7 @@ struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
SmallVector<Value> replacements;
auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
Location yieldLoc = yieldOp.getLoc();
for (const auto &it : llvm::enumerate(yieldOp.operands())) {
for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
Value sequentialVal = it.value();
Value distributedVal = warpOp->getResult(it.index());
DistributedLoadStoreHelper helper(sequentialVal, distributedVal,
@ -379,7 +379,7 @@ struct WarpOpToScfIfPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
}
// Step 6. Insert sync after all the stores and before all the loads.
if (!yieldOp.operands().empty()) {
if (!yieldOp.getOperands().empty()) {
rewriter.setInsertionPointAfter(ifOp);
options.warpSyncronizationFn(loc, rewriter, warpOp);
}

View File

@ -98,14 +98,6 @@ bool Dialect::useDefaultTypePrinterParser() const {
return def->getValueAsBit("useDefaultTypePrinterParser");
}
Dialect::EmitPrefix Dialect::getEmitAccessorPrefix() const {
int prefix = def->getValueAsInt("emitAccessorPrefix");
if (prefix < 0 || prefix > static_cast<int>(EmitPrefix::Both))
PrintFatalError(def->getLoc(), "Invalid accessor prefix value");
return static_cast<EmitPrefix>(prefix);
}
bool Dialect::isExtensible() const {
return def->getValueAsBit("isExtensible");
}

View File

@ -69,6 +69,41 @@ std::string Operator::getAdaptorName() const {
return std::string(llvm::formatv("{0}Adaptor", getCppClassName()));
}
/// Assert the invariants of accessors generated for the given name.
static void assertAccessorInvariants(const Operator &op, StringRef name) {
std::string accessorName =
convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
// Functor used to detect when an accessor will cause an overlap with an
// operation API.
//
// There are a little bit more invasive checks possible for cases where not
// all ops have the trait that would cause overlap. For many cases here,
// renaming would be better (e.g., we can only guard in limited manner
// against methods from traits and interfaces here, so avoiding these in op
// definition is safer).
auto nameOverlapsWithOpAPI = [&](StringRef newName) {
if (newName == "AttributeNames" || newName == "Attributes" ||
newName == "Operation")
return true;
if (newName == "Operands")
return op.getNumOperands() != 1 || op.getNumVariableLengthOperands() != 1;
if (newName == "Regions")
return op.getNumRegions() != 1 || op.getNumVariadicRegions() != 1;
if (newName == "Type")
return op.getNumResults() != 1;
return false;
};
if (nameOverlapsWithOpAPI(accessorName)) {
// This error could be avoided in situations where the final function is
// identical, but preferably the op definition should avoid using generic
// names.
PrintFatalError(op.getLoc(), "generated accessor for `" + name +
"` overlaps with a default one; please "
"rename to avoid overlap");
}
}
void Operator::assertInvariants() const {
// Check that the name of arguments/results/regions/successors don't overlap.
DenseMap<StringRef, StringRef> existingNames;
@ -76,8 +111,11 @@ void Operator::assertInvariants() const {
if (name.empty())
return;
auto insertion = existingNames.insert({name, entity});
if (insertion.second)
if (insertion.second) {
// Assert invariants for accessors generated for this name.
assertAccessorInvariants(*this, name);
return;
}
if (entity == insertion.first->second)
PrintFatalError(getLoc(), "op has a conflict with two " + entity +
" having the same name '" + name + "'");
@ -692,82 +730,10 @@ auto Operator::getArgToOperandOrAttribute(int index) const
return attrOrOperandMapping[index];
}
// Helper to return the names for accessor.
static SmallVector<std::string, 2>
getGetterOrSetterNames(bool isGetter, const Operator &op, StringRef name) {
Dialect::EmitPrefix prefixType = op.getDialect().getEmitAccessorPrefix();
std::string prefix;
if (prefixType != Dialect::EmitPrefix::Raw)
prefix = isGetter ? "get" : "set";
SmallVector<std::string, 2> names;
bool rawToo = prefixType == Dialect::EmitPrefix::Both;
// Whether to skip generating prefixed form for argument. This just does some
// basic checks.
//
// There are a little bit more invasive checks possible for cases where not
// all ops have the trait that would cause overlap. For many cases here,
// renaming would be better (e.g., we can only guard in limited manner against
// methods from traits and interfaces here, so avoiding these in op definition
// is safer).
auto skip = [&](StringRef newName) {
bool shouldSkip = newName == "getAttributeNames" ||
newName == "getAttributes" || newName == "getOperation";
if (newName == "getOperands") {
// To reduce noise, skip generating the prefixed form and the warning if
// $operands correspond to single variadic argument.
if (op.getNumOperands() == 1 && op.getNumVariableLengthOperands() == 1)
return true;
shouldSkip = true;
}
if (newName == "getRegions") {
if (op.getNumRegions() == 1 && op.getNumVariadicRegions() == 1)
return true;
shouldSkip = true;
}
if (newName == "getType") {
if (op.getNumResults() != 1)
return false;
shouldSkip = true;
}
if (!shouldSkip)
return false;
// This note could be avoided where the final function generated would
// have been identical. But preferably in the op definition avoiding using
// the generic name and then getting a more specialize type is better.
PrintNote(op.getLoc(),
"Skipping generation of prefixed accessor `" + newName +
"` as it overlaps with default one; generating raw form (`" +
name + "`) still");
return true;
};
if (!prefix.empty()) {
names.push_back(
prefix + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true));
// Skip cases which would overlap with default ones for now.
if (skip(names.back())) {
rawToo = true;
names.clear();
} else if (rawToo) {
LLVM_DEBUG(llvm::errs() << "WITH_GETTER(\"" << op.getQualCppClassName()
<< "::" << name << "\")\n"
<< "WITH_GETTER(\"" << op.getQualCppClassName()
<< "Adaptor::" << name << "\")\n";);
}
}
if (prefix.empty() || rawToo)
names.push_back(name.str());
return names;
std::string Operator::getGetterName(StringRef name) const {
return "get" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
}
SmallVector<std::string, 2> Operator::getGetterNames(StringRef name) const {
return getGetterOrSetterNames(/*isGetter=*/true, *this, name);
}
SmallVector<std::string, 2> Operator::getSetterNames(StringRef name) const {
return getGetterOrSetterNames(/*isGetter=*/false, *this, name);
std::string Operator::getSetterName(StringRef name) const {
return "set" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
}

View File

@ -30,6 +30,6 @@ def OpMultipleSingleResult : Op<Test_Dialect, "multiple_single_result"> {
}
def OpMultiVariadic : Op<Test_Dialect, "multi_variadic"> {
let arguments = (ins Variadic<I64>:$operands, Variadic<I64>:$operand2);
let results = (outs Variadic<I64>:$results, Variadic<I64>:$results2);
let arguments = (ins Variadic<I64>:$operands1, Variadic<I64>:$operand2);
let results = (outs Variadic<I64>:$results1, Variadic<I64>:$results2);
}

View File

@ -109,9 +109,9 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// DEFS-LABEL: NS::AOp definitions
// DEFS: AOpAdaptor::AOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsOperands(values), odsAttrs(attrs), odsRegions(regions)
// DEFS: ::mlir::RegionRange AOpAdaptor::getRegions()
// DEFS: ::mlir::RegionRange AOpAdaptor::getSomeRegions()
// DEFS-NEXT: return odsRegions.drop_front(1);
// DEFS: ::mlir::RegionRange AOpAdaptor::getRegions()
// Check AttrSizedOperandSegments
// ---

View File

@ -933,26 +933,24 @@ void OpEmitter::genAttrNameGetters() {
// users.
const char *attrNameMethodBody = " return getAttributeNameForIndex({0});";
for (auto &attrIt : llvm::enumerate(llvm::make_first_range(attributes))) {
for (StringRef name : op.getGetterNames(attrIt.value())) {
std::string methodName = (name + "AttrName").str();
std::string name = op.getGetterName(attrIt.value());
std::string methodName = name + "AttrName";
// Generate the non-static variant.
{
auto *method =
opClass.addInlineMethod("::mlir::StringAttr", methodName);
ERROR_IF_PRUNED(method, methodName, op);
method->body() << llvm::formatv(attrNameMethodBody, attrIt.index());
}
// Generate the non-static variant.
{
auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName);
ERROR_IF_PRUNED(method, methodName, op);
method->body() << llvm::formatv(attrNameMethodBody, attrIt.index());
}
// Generate the static variant.
{
auto *method = opClass.addStaticInlineMethod(
"::mlir::StringAttr", methodName,
MethodParameter("::mlir::OperationName", "name"));
ERROR_IF_PRUNED(method, methodName, op);
method->body() << llvm::formatv(attrNameMethodBody,
"name, " + Twine(attrIt.index()));
}
// Generate the static variant.
{
auto *method = opClass.addStaticInlineMethod(
"::mlir::StringAttr", methodName,
MethodParameter("::mlir::OperationName", "name"));
ERROR_IF_PRUNED(method, methodName, op);
method->body() << llvm::formatv(attrNameMethodBody,
"name, " + Twine(attrIt.index()));
}
}
}
@ -1014,13 +1012,12 @@ void OpEmitter::genAttrGetters() {
};
for (const NamedAttribute &namedAttr : op.getAttributes()) {
for (StringRef name : op.getGetterNames(namedAttr.name)) {
if (namedAttr.attr.isDerivedAttr()) {
emitDerivedAttr(name, namedAttr.attr);
} else {
emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr);
}
std::string name = op.getGetterName(namedAttr.name);
if (namedAttr.attr.isDerivedAttr()) {
emitDerivedAttr(name, namedAttr.attr);
} else {
emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr);
}
}
@ -1165,12 +1162,10 @@ void OpEmitter::genAttrSetters() {
for (const NamedAttribute &namedAttr : op.getAttributes()) {
if (namedAttr.attr.isDerivedAttr())
continue;
for (auto [setterName, getterName] :
llvm::zip(op.getSetterNames(namedAttr.name),
op.getGetterNames(namedAttr.name))) {
emitAttrWithStorageType(setterName, getterName, namedAttr.attr);
emitAttrWithReturnType(setterName, getterName, namedAttr.attr);
}
std::string setterName = op.getSetterName(namedAttr.name);
std::string getterName = op.getGetterName(namedAttr.name);
emitAttrWithStorageType(setterName, getterName, namedAttr.attr);
emitAttrWithReturnType(setterName, getterName, namedAttr.attr);
}
}
@ -1305,38 +1300,36 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
const auto &operand = op.getOperand(i);
if (operand.name.empty())
continue;
for (StringRef name : op.getGetterNames(operand.name)) {
if (operand.isOptional()) {
m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
std::string name = op.getGetterName(operand.name);
if (operand.isOptional()) {
m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " auto operands = getODSOperands(" << i << ");\n"
<< " return operands.empty() ? ::mlir::Value() : "
"*operands.begin();";
} else if (operand.isVariadicOfVariadic()) {
std::string segmentAttr = op.getGetterName(
operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
if (isAdaptor) {
m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " auto operands = getODSOperands(" << i << ");\n"
<< " return operands.empty() ? ::mlir::Value() : "
"*operands.begin();";
} else if (operand.isVariadicOfVariadic()) {
std::string segmentAttr = op.getGetterName(
operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
if (isAdaptor) {
m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>",
name);
ERROR_IF_PRUNED(m, name, op);
m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
segmentAttr, i);
continue;
}
m = opClass.addMethod("::mlir::OperandRangeRange", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSOperands(" << i << ").split("
<< segmentAttr << "Attr());";
} else if (operand.isVariadic()) {
m = opClass.addMethod(rangeType, name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSOperands(" << i << ");";
} else {
m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSOperands(" << i << ").begin();";
m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
segmentAttr, i);
continue;
}
m = opClass.addMethod("::mlir::OperandRangeRange", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr
<< "Attr());";
} else if (operand.isVariadic()) {
m = opClass.addMethod(rangeType, name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSOperands(" << i << ");";
} else {
m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSOperands(" << i << ").begin();";
}
}
}
@ -1367,37 +1360,37 @@ void OpEmitter::genNamedOperandSetters() {
const auto &operand = op.getOperand(i);
if (operand.name.empty())
continue;
for (StringRef name : op.getGetterNames(operand.name)) {
auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
? "::mlir::MutableOperandRangeRange"
: "::mlir::MutableOperandRange",
(name + "Mutable").str());
ERROR_IF_PRUNED(m, name, op);
auto &body = m->body();
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
<< " auto mutableRange = "
"::mlir::MutableOperandRange(getOperation(), "
"range.first, range.second";
if (attrSizedOperands) {
body << formatv(
", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
}
body << ");\n";
std::string name = op.getGetterName(operand.name);
// If this operand is a nested variadic, we split the range into a
// MutableOperandRangeRange that provides a range over all of the
// sub-ranges.
if (operand.isVariadicOfVariadic()) {
body << " return "
"mutableRange.split(*(*this)->getAttrDictionary().getNamed("
<< op.getGetterName(
operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
<< "AttrName()));\n";
} else {
// Otherwise, we use the full range directly.
body << " return mutableRange;\n";
}
auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
? "::mlir::MutableOperandRangeRange"
: "::mlir::MutableOperandRange",
name + "Mutable");
ERROR_IF_PRUNED(m, name, op);
auto &body = m->body();
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
<< " auto mutableRange = "
"::mlir::MutableOperandRange(getOperation(), "
"range.first, range.second";
if (attrSizedOperands) {
body << formatv(
", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
}
body << ");\n";
// If this operand is a nested variadic, we split the range into a
// MutableOperandRangeRange that provides a range over all of the
// sub-ranges.
if (operand.isVariadicOfVariadic()) {
body << " return "
"mutableRange.split(*(*this)->getAttrDictionary().getNamed("
<< op.getGetterName(
operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
<< "AttrName()));\n";
} else {
// Otherwise, we use the full range directly.
body << " return mutableRange;\n";
}
}
}
@ -1454,24 +1447,23 @@ void OpEmitter::genNamedResultGetters() {
const auto &result = op.getResult(i);
if (result.name.empty())
continue;
for (StringRef name : op.getGetterNames(result.name)) {
if (result.isOptional()) {
m = opClass.addMethod(
generateTypeForGetter(/*isAdaptor=*/false, result), name);
ERROR_IF_PRUNED(m, name, op);
m->body()
<< " auto results = getODSResults(" << i << ");\n"
<< " return results.empty() ? ::mlir::Value() : *results.begin();";
} else if (result.isVariadic()) {
m = opClass.addMethod("::mlir::Operation::result_range", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSResults(" << i << ");";
} else {
m = opClass.addMethod(
generateTypeForGetter(/*isAdaptor=*/false, result), name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSResults(" << i << ").begin();";
}
std::string name = op.getGetterName(result.name);
if (result.isOptional()) {
m = opClass.addMethod(generateTypeForGetter(/*isAdaptor=*/false, result),
name);
ERROR_IF_PRUNED(m, name, op);
m->body()
<< " auto results = getODSResults(" << i << ");\n"
<< " return results.empty() ? ::mlir::Value() : *results.begin();";
} else if (result.isVariadic()) {
m = opClass.addMethod("::mlir::Operation::result_range", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return getODSResults(" << i << ");";
} else {
m = opClass.addMethod(generateTypeForGetter(/*isAdaptor=*/false, result),
name);
ERROR_IF_PRUNED(m, name, op);
m->body() << " return *getODSResults(" << i << ").begin();";
}
}
}
@ -1482,22 +1474,21 @@ void OpEmitter::genNamedRegionGetters() {
const auto &region = op.getRegion(i);
if (region.name.empty())
continue;
std::string name = op.getGetterName(region.name);
for (StringRef name : op.getGetterNames(region.name)) {
// Generate the accessors for a variadic region.
if (region.isVariadic()) {
auto *m =
opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(" return (*this)->getRegions().drop_front({0});",
i);
continue;
}
auto *m = opClass.addMethod("::mlir::Region &", name);
// Generate the accessors for a variadic region.
if (region.isVariadic()) {
auto *m =
opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(" return (*this)->getRegion({0});", i);
m->body() << formatv(" return (*this)->getRegions().drop_front({0});",
i);
continue;
}
auto *m = opClass.addMethod("::mlir::Region &", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(" return (*this)->getRegion({0});", i);
}
}
@ -1507,23 +1498,21 @@ void OpEmitter::genNamedSuccessorGetters() {
const NamedSuccessor &successor = op.getSuccessor(i);
if (successor.name.empty())
continue;
for (StringRef name : op.getGetterNames(successor.name)) {
// Generate the accessors for a variadic successor list.
if (successor.isVariadic()) {
auto *m = opClass.addMethod("::mlir::SuccessorRange", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(
" return {std::next((*this)->successor_begin(), {0}), "
"(*this)->successor_end()};",
i);
continue;
}
auto *m = opClass.addMethod("::mlir::Block *", name);
std::string name = op.getGetterName(successor.name);
// Generate the accessors for a variadic successor list.
if (successor.isVariadic()) {
auto *m = opClass.addMethod("::mlir::SuccessorRange", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(" return (*this)->getSuccessor({0});", i);
m->body() << formatv(
" return {std::next((*this)->successor_begin(), {0}), "
"(*this)->successor_end()};",
i);
continue;
}
auto *m = opClass.addMethod("::mlir::Block *", name);
ERROR_IF_PRUNED(m, name, op);
m->body() << formatv(" return (*this)->getSuccessor({0});", i);
}
}
@ -2992,11 +2981,6 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
constructor->addMemberInitializer("odsOpName", "op->getName()");
}
{
auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands");
ERROR_IF_PRUNED(m, "getOperands", op);
m->body() << " return odsOperands;";
}
std::string sizeAttrInit;
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
@ -3009,6 +2993,11 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
/*rangeSizeCall=*/"odsOperands.size()",
/*getOperandCallPattern=*/"odsOperands[{0}]");
// Any invalid overlap for `getOperands` will have been diagnosed before here
// already.
if (auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands"))
m->body() << " return odsOperands;";
FmtContext fctx;
fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
@ -3046,36 +3035,35 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
const auto &attr = namedAttr.attr;
if (attr.isDerivedAttr())
continue;
for (const auto &emitName : op.getGetterNames(name)) {
emitAttrWithStorageType(name, emitName, attr);
emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr);
}
std::string emitName = op.getGetterName(name);
emitAttrWithStorageType(name, emitName, attr);
emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr);
}
unsigned numRegions = op.getNumRegions();
if (numRegions > 0) {
auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions");
ERROR_IF_PRUNED(m, "Adaptor::getRegions", op);
m->body() << " return odsRegions;";
}
for (unsigned i = 0; i < numRegions; ++i) {
const auto &region = op.getRegion(i);
if (region.name.empty())
continue;
// Generate the accessors for a variadic region.
for (StringRef name : op.getGetterNames(region.name)) {
if (region.isVariadic()) {
auto *m = adaptor.addMethod("::mlir::RegionRange", name);
ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
m->body() << formatv(" return odsRegions.drop_front({0});", i);
continue;
}
auto *m = adaptor.addMethod("::mlir::Region &", name);
std::string name = op.getGetterName(region.name);
if (region.isVariadic()) {
auto *m = adaptor.addMethod("::mlir::RegionRange", name);
ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
m->body() << formatv(" return *odsRegions[{0}];", i);
m->body() << formatv(" return odsRegions.drop_front({0});", i);
continue;
}
auto *m = adaptor.addMethod("::mlir::Region &", name);
ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
m->body() << formatv(" return *odsRegions[{0}];", i);
}
if (numRegions > 0) {
// Any invalid overlap for `getRegions` will have been diagnosed before here
// already.
if (auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions"))
m->body() << " return odsRegions;";
}
// Add verification function.