[MLIR] Move eraseArguments and eraseResults to FunctionLike

Previously, they were only defined for `FuncOp`.

To support this, `FunctionLike` needs a way to get an updated type
from the concrete operation. This adds a new hook for that purpose,
called `getTypeWithoutArgsAndResults`.

For now, `FunctionLike` continues to assume the type is
`FunctionType`, and concrete operations that use another type can hide
the `getType`, `setType`, and `getTypeWithoutArgsAndResults` methods.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D90363
This commit is contained in:
mikeurbach 2020-10-22 11:39:39 -06:00
parent 50c2f2b6f0
commit 2e36e0dad5
11 changed files with 255 additions and 89 deletions

View File

@ -255,11 +255,17 @@ particular:
- they can have argument and result attributes that are stored in dictionary - they can have argument and result attributes that are stored in dictionary
attributes on the operation itself. attributes on the operation itself.
This trait does *NOT* provide type support for the functions, meaning that This trait provides limited type support for the declared or defined functions.
concrete Ops must handle the type of the declared or defined function. The convenience function `getTypeAttrName()` returns the name of an attribute
`getTypeAttrName()` is a convenience function that returns the name of the that can be used to store the function type. In addition, this trait provides
attribute that can be used to store the function type, but the trait makes no `getType` and `setType` helpers to store a `FunctionType` in the attribute named
assumption based on it. by `getTypeAttrName()`.
In general, this trait assumes concrete ops use `FunctionType` under the hood.
If this is not the case, in order to use the function type support, concrete ops
must define the following methods, using the same name, to hide the ones defined
for `FunctionType`: `addBodyBlock`, `getType`, `getTypeWithoutArgsAndResults`
and `setType`.
### HasParent ### HasParent

View File

@ -16,6 +16,10 @@
#include "mlir/IR/BlockSupport.h" #include "mlir/IR/BlockSupport.h"
#include "mlir/IR/Visitors.h" #include "mlir/IR/Visitors.h"
namespace llvm {
class BitVector;
} // end namespace llvm
namespace mlir { namespace mlir {
class TypeRange; class TypeRange;
template <typename ValueRangeT> class ValueTypeRange; template <typename ValueRangeT> class ValueTypeRange;
@ -98,6 +102,13 @@ public:
/// Erase the argument at 'index' and remove it from the argument list. /// Erase the argument at 'index' and remove it from the argument list.
void eraseArgument(unsigned index); void eraseArgument(unsigned index);
/// Erases the arguments listed in `argIndices` and removes them from the
/// argument list.
/// `argIndices` is allowed to have duplicates and can be in any order.
void eraseArguments(ArrayRef<unsigned> argIndices);
/// Erases the arguments that have their corresponding bit set in
/// `eraseIndices` and removes them from the argument list.
void eraseArguments(llvm::BitVector eraseIndices);
unsigned getNumArguments() { return arguments.size(); } unsigned getNumArguments() { return arguments.size(); }
BlockArgument getArgument(unsigned i) { return arguments[i]; } BlockArgument getArgument(unsigned i) { return arguments[i]; }

View File

@ -59,18 +59,6 @@ public:
void print(OpAsmPrinter &p); void print(OpAsmPrinter &p);
LogicalResult verify(); LogicalResult verify();
/// Erase a single argument at `argIndex`.
void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
/// Erases the arguments listed in `argIndices`.
/// `argIndices` is allowed to have duplicates and can be in any order.
void eraseArguments(ArrayRef<unsigned> argIndices);
/// Erase a single result at `resultIndex`.
void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
/// Erases the results listed in `resultIndices`.
/// `resultIndices` is allowed to have duplicates and can be in any order.
void eraseResults(ArrayRef<unsigned> resultIndices);
/// Create a deep copy of this function and all of its blocks, remapping /// Create a deep copy of this function and all of its blocks, remapping
/// any operands that use values outside of the function using the map that is /// any operands that use values outside of the function using the map that is
/// provided (leaving them alone if no entry is present). If the mapper /// provided (leaving them alone if no entry is present). If the mapper

View File

@ -71,6 +71,14 @@ inline ArrayRef<NamedAttribute> getResultAttrs(Operation *op, unsigned index) {
return resultDict ? resultDict.getValue() : llvm::None; return resultDict ? resultDict.getValue() : llvm::None;
} }
/// Erase the specified arguments and update the function type attribute.
void eraseFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
unsigned originalNumArgs, Type newType);
/// Erase the specified results and update the function type attribute.
void eraseFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
unsigned originalNumResults, Type newType);
} // namespace impl } // namespace impl
namespace OpTrait { namespace OpTrait {
@ -84,12 +92,21 @@ namespace OpTrait {
/// arguments; /// arguments;
/// - they can have argument attributes that are stored in a dictionary /// - they can have argument attributes that are stored in a dictionary
/// attribute on the Op itself. /// attribute on the Op itself.
/// This trait does *NOT* provide type support for the functions, meaning that
/// concrete Ops must handle the type of the declared or defined function.
/// `getTypeAttrName()` is a convenience function that returns the name of the
/// attribute that can be used to store the function type, but the trait makes
/// no assumption based on it.
/// ///
/// This trait provides limited type support for the declared or defined
/// functions. The convenience function `getTypeAttrName()` returns the name of
/// an attribute that can be used to store the function type. In addition, this
/// trait provides `getType` and `setType` helpers to store a `FunctionType` in
/// the attribute named by `getTypeAttrName()`.
///
/// In general, this trait assumes concrete ops use `FunctionType` under the
/// hood. If this is not the case, in order to use the function type support,
/// concrete ops must define the following methods, using the same name, to hide
/// the ones defined for `FunctionType`: `addBodyBlock`, `getType`,
/// `getTypeWithoutArgsAndResults` and `setType`.
///
/// Besides the requirements above, concrete ops must interact with this trait
/// using the following functions:
/// - Concrete ops *must* define a member function `getNumFuncArguments()` that /// - Concrete ops *must* define a member function `getNumFuncArguments()` that
/// returns the number of function arguments based exclusively on type (so /// returns the number of function arguments based exclusively on type (so
/// that it can be called on function declarations). /// that it can be called on function declarations).
@ -183,6 +200,19 @@ public:
return getTypeAttr().getValue().template cast<FunctionType>(); return getTypeAttr().getValue().template cast<FunctionType>();
} }
/// Return the type of this function without the specified arguments and
/// results. This is used to update the function's signature in the
/// `eraseArguments` and `eraseResults` methods. The arrays of indices are
/// allowed to have duplicates and can be in any order.
///
/// Note that the concrete class must define a method with the same name to
/// hide this one if the concrete class does not use FunctionType for the
/// function type under the hood.
FunctionType getTypeWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
ArrayRef<unsigned> resultIndices) {
return getType().getWithoutArgsAndResults(argIndices, resultIndices);
}
bool isTypeAttrValid() { bool isTypeAttrValid() {
auto typeAttr = getTypeAttr(); auto typeAttr = getTypeAttr();
if (!typeAttr) if (!typeAttr)
@ -204,7 +234,7 @@ public:
void setType(FunctionType newType); void setType(FunctionType newType);
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Argument Handling // Argument and Result Handling
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
using BlockArgListType = Region::BlockArgListType; using BlockArgListType = Region::BlockArgListType;
@ -229,6 +259,30 @@ public:
return getBody().getArgumentTypes(); return getBody().getArgumentTypes();
} }
/// Erase a single argument at `argIndex`.
void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
/// Erases the arguments listed in `argIndices`.
/// `argIndices` is allowed to have duplicates and can be in any order.
void eraseArguments(ArrayRef<unsigned> argIndices) {
unsigned originalNumArgs = getNumArguments();
Type newType = getTypeWithoutArgsAndResults(argIndices, {});
::mlir::impl::eraseFunctionArguments(this->getOperation(), argIndices,
originalNumArgs, newType);
}
/// Erase a single result at `resultIndex`.
void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
/// Erases the results listed in `resultIndices`.
/// `resultIndices` is allowed to have duplicates and can be in any order.
void eraseResults(ArrayRef<unsigned> resultIndices) {
unsigned originalNumResults = getNumResults();
Type newType = getTypeWithoutArgsAndResults({}, resultIndices);
::mlir::impl::eraseFunctionResults(this->getOperation(), resultIndices,
originalNumResults, newType);
}
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Argument Attributes // Argument Attributes
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//

View File

@ -238,15 +238,19 @@ public:
static FunctionType get(TypeRange inputs, TypeRange results, static FunctionType get(TypeRange inputs, TypeRange results,
MLIRContext *context); MLIRContext *context);
// Input types. /// Input types.
unsigned getNumInputs() const; unsigned getNumInputs() const;
Type getInput(unsigned i) const { return getInputs()[i]; } Type getInput(unsigned i) const { return getInputs()[i]; }
ArrayRef<Type> getInputs() const; ArrayRef<Type> getInputs() const;
// Result types. /// Result types.
unsigned getNumResults() const; unsigned getNumResults() const;
Type getResult(unsigned i) const { return getResults()[i]; } Type getResult(unsigned i) const { return getResults()[i]; }
ArrayRef<Type> getResults() const; ArrayRef<Type> getResults() const;
/// Returns a new function type without the specified arguments and results.
FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
ArrayRef<unsigned> resultIndices);
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -25,8 +25,6 @@
#include "llvm/Support/CommandLine.h" #include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include <set>
#define DEBUG_TYPE "linalg-drop-unit-dims" #define DEBUG_TYPE "linalg-drop-unit-dims"
using namespace mlir; using namespace mlir;
@ -166,9 +164,8 @@ LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
for (unsigned unitDimLoop : unitDims) { for (unsigned unitDimLoop : unitDims) {
entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero); entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
} }
std::set<unsigned> orderedUnitDims(unitDims.begin(), unitDims.end()); SmallVector<unsigned, 8> unitDimsToErase(unitDims.begin(), unitDims.end());
for (unsigned i : llvm::reverse(orderedUnitDims)) entryBlock->eraseArguments(unitDimsToErase);
entryBlock->eraseArgument(i);
return success(); return success();
} }

View File

@ -9,6 +9,7 @@
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "llvm/ADT/BitVector.h"
using namespace mlir; using namespace mlir;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -176,6 +177,22 @@ void Block::eraseArgument(unsigned index) {
arguments.erase(arguments.begin() + index); arguments.erase(arguments.begin() + index);
} }
void Block::eraseArguments(ArrayRef<unsigned> argIndices) {
llvm::BitVector eraseIndices(getNumArguments());
for (unsigned i : argIndices)
eraseIndices.set(i);
eraseArguments(eraseIndices);
}
void Block::eraseArguments(llvm::BitVector eraseIndices) {
// We do this in reverse so that we erase later indices before earlier
// indices, to avoid shifting the later indices.
unsigned originalNumArgs = getNumArguments();
for (unsigned i = 0; i < originalNumArgs; ++i)
if (eraseIndices.test(originalNumArgs - i - 1))
eraseArgument(originalNumArgs - i - 1);
}
/// Insert one value to the given position of the argument list. The existing /// Insert one value to the given position of the argument list. The existing
/// arguments are shifted. The block is expected not to have predecessors. /// arguments are shifted. The block is expected not to have predecessors.
BlockArgument Block::insertArgument(args_iterator it, Type type) { BlockArgument Block::insertArgument(args_iterator it, Type type) {

View File

@ -10,6 +10,7 @@ add_mlir_library(MLIRIR
Dominance.cpp Dominance.cpp
Function.cpp Function.cpp
FunctionImplementation.cpp FunctionImplementation.cpp
FunctionSupport.cpp
IntegerSet.cpp IntegerSet.cpp
Location.cpp Location.cpp
MLIRContext.cpp MLIRContext.cpp

View File

@ -98,65 +98,6 @@ LogicalResult FuncOp::verify() {
return success(); return success();
} }
void FuncOp::eraseArguments(ArrayRef<unsigned> argIndices) {
auto oldType = getType();
int originalNumArgs = oldType.getNumInputs();
llvm::BitVector eraseIndices(originalNumArgs);
for (auto index : argIndices)
eraseIndices.set(index);
auto shouldEraseArg = [&](int i) { return eraseIndices.test(i); };
// There are 3 things that need to be updated:
// - Function type.
// - Arg attrs.
// - Block arguments of entry block.
// Update the function type and arg attrs.
SmallVector<Type, 4> newInputTypes;
SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
for (int i = 0; i < originalNumArgs; i++) {
if (shouldEraseArg(i))
continue;
newInputTypes.emplace_back(oldType.getInput(i));
newArgAttrs.emplace_back(getArgAttrDict(i));
}
setType(FunctionType::get(newInputTypes, oldType.getResults(), getContext()));
setAllArgAttrs(newArgAttrs);
// Update the entry block's arguments.
// We do this in reverse so that we erase later indices before earlier
// indices, to avoid shifting the later indices.
Block &entry = front();
for (int i = 0; i < originalNumArgs; i++)
if (shouldEraseArg(originalNumArgs - i - 1))
entry.eraseArgument(originalNumArgs - i - 1);
}
void FuncOp::eraseResults(ArrayRef<unsigned> resultIndices) {
auto oldType = getType();
int originalNumResults = oldType.getNumResults();
llvm::BitVector eraseIndices(originalNumResults);
for (auto index : resultIndices)
eraseIndices.set(index);
auto shouldEraseResult = [&](int i) { return eraseIndices.test(i); };
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
// Update the function type and result attrs.
SmallVector<Type, 4> newResultTypes;
SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
for (int i = 0; i < originalNumResults; i++) {
if (shouldEraseResult(i))
continue;
newResultTypes.emplace_back(oldType.getResult(i));
newResultAttrs.emplace_back(getResultAttrDict(i));
}
setType(FunctionType::get(oldType.getInputs(), newResultTypes, getContext()));
setAllResultAttrs(newResultAttrs);
}
/// Clone the internal blocks from this function into dest and all attributes /// Clone the internal blocks from this function into dest and all attributes
/// from this function to dest. /// from this function to dest.
void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {

View File

@ -0,0 +1,103 @@
//===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/FunctionSupport.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/BitVector.h"
using namespace mlir;
/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
inline void iterateIndicesExcept(unsigned totalIndices,
ArrayRef<unsigned> indices,
function_ref<void(unsigned)> callback) {
llvm::BitVector skipIndices(totalIndices);
for (unsigned i : indices)
skipIndices.set(i);
for (unsigned i = 0; i < totalIndices; ++i)
if (!skipIndices.test(i))
callback(i);
}
//===----------------------------------------------------------------------===//
// Function Arguments and Results.
//===----------------------------------------------------------------------===//
void mlir::impl::eraseFunctionArguments(Operation *op,
ArrayRef<unsigned> argIndices,
unsigned originalNumArgs,
Type newType) {
// There are 3 things that need to be updated:
// - Function type.
// - Arg attrs.
// - Block arguments of entry block.
Block &entry = op->getRegion(0).front();
SmallString<8> nameBuf;
// Collect arg attrs to set.
SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
newArgAttrs.emplace_back(getArgAttrDict(op, i));
});
// Remove any arg attrs that are no longer needed.
for (unsigned i = newArgAttrs.size(), e = originalNumArgs; i < e; ++i)
op->removeAttr(getArgAttrName(i, nameBuf));
// Set the function type.
op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
// Set the new arg attrs, or remove them if empty.
for (unsigned i = 0, e = newArgAttrs.size(); i != e; ++i) {
auto nameAttr = getArgAttrName(i, nameBuf);
auto argAttr = newArgAttrs[i];
if (argAttr.empty())
op->removeAttr(nameAttr);
else
op->setAttr(nameAttr, argAttr.getDictionary(op->getContext()));
}
// Update the entry block's arguments.
entry.eraseArguments(argIndices);
}
void mlir::impl::eraseFunctionResults(Operation *op,
ArrayRef<unsigned> resultIndices,
unsigned originalNumResults,
Type newType) {
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
SmallString<8> nameBuf;
// Collect result attrs to set.
SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
newResultAttrs.emplace_back(getResultAttrDict(op, i));
});
// Remove any result attrs that are no longer needed.
for (unsigned i = newResultAttrs.size(), e = originalNumResults; i < e; ++i)
op->removeAttr(getResultAttrName(i, nameBuf));
// Set the function type.
op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
// Set the new result attrs, or remove them if empty.
for (unsigned i = 0, e = newResultAttrs.size(); i != e; ++i) {
auto nameAttr = getResultAttrName(i, nameBuf);
auto resultAttr = newResultAttrs[i];
if (resultAttr.empty())
op->removeAttr(nameAttr);
else
op->setAttr(nameAttr, resultAttr.getDictionary(op->getContext()));
}
}

View File

@ -10,6 +10,8 @@
#include "TypeDetail.h" #include "TypeDetail.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/Twine.h" #include "llvm/ADT/Twine.h"
using namespace mlir; using namespace mlir;
@ -46,6 +48,48 @@ ArrayRef<Type> FunctionType::getResults() const {
return getImpl()->getResults(); return getImpl()->getResults();
} }
/// Helper to call a callback once on each index in the range
/// [0, `totalIndices`), *except* for the indices given in `indices`.
/// `indices` is allowed to have duplicates and can be in any order.
inline void iterateIndicesExcept(unsigned totalIndices,
ArrayRef<unsigned> indices,
function_ref<void(unsigned)> callback) {
llvm::BitVector skipIndices(totalIndices);
for (unsigned i : indices)
skipIndices.set(i);
for (unsigned i = 0; i < totalIndices; ++i)
if (!skipIndices.test(i))
callback(i);
}
/// Returns a new function type without the specified arguments and results.
FunctionType
FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
ArrayRef<unsigned> resultIndices) {
ArrayRef<Type> newInputTypes = getInputs();
SmallVector<Type, 4> newInputTypesBuffer;
if (!argIndices.empty()) {
unsigned originalNumArgs = getNumInputs();
iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
newInputTypesBuffer.emplace_back(getInput(i));
});
newInputTypes = newInputTypesBuffer;
}
ArrayRef<Type> newResultTypes = getResults();
SmallVector<Type, 4> newResultTypesBuffer;
if (!resultIndices.empty()) {
unsigned originalNumResults = getNumResults();
iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
newResultTypesBuffer.emplace_back(getResult(i));
});
newResultTypes = newResultTypesBuffer;
}
return get(newInputTypes, newResultTypes, getContext());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// OpaqueType // OpaqueType
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//