Refactor FunctionAttr to hold the internal function reference by name instead of pointer. The one downside to this is that the function reference held by a FunctionAttr needs to be explicitly looked up from the parent module. This provides several benefits though:

* There is no longer a need to explicitly remap function attrs.
      - This removes a potentially expensive call from the destructor of Function.
      - This will enable some interprocedural transformations to now run intraprocedurally.
      - This wasn't scalable and forces dialect defined attributes to override
        a virtual function.
    * Replacing a function is now a trivial operation.
    * This is a necessary first step to representing functions as operations.

--

PiperOrigin-RevId: 249510802
This commit is contained in:
River Riddle 2019-05-22 13:41:23 -07:00 committed by Mehdi Amini
parent d5397f4efe
commit c33862b0ed
32 changed files with 130 additions and 454 deletions

View File

@ -860,11 +860,11 @@ of the specified [float type](#floating-point-types).
Syntax:
``` {.ebnf}
function-attribute ::= function-id `:` function-type
function-attribute ::= function-id
```
A function attribute is a literal attribute that represents a reference to the
given function object.
A function attribute is a literal attribute that represents a named reference to
the given function.
#### String Attribute

View File

@ -45,11 +45,6 @@ class AttributeStorage : public StorageUniquer::BaseStorage {
friend StorageUniquer;
public:
/// Returns if the derived attribute is or contains a function pointer.
bool isOrContainsFunctionCache() const {
return typeAndContainsFunctionAttrPair.getInt();
}
/// Get the type of this attribute.
Type getType() const;
@ -60,14 +55,11 @@ public:
}
protected:
/// Construct a new attribute storage instance with the given type and a
/// boolean that signals if the derived attribute is or contains a function
/// pointer.
/// Construct a new attribute storage instance with the given type.
/// Note: All attributes require a valid type. If no type is provided here,
/// the type of the attribute will automatically default to NoneType
/// upon initialization in the uniquer.
AttributeStorage(Type type, bool isOrContainsFunctionCache = false);
AttributeStorage(bool isOrContainsFunctionCache);
AttributeStorage(Type type);
AttributeStorage();
/// Set the type of this attribute.
@ -81,10 +73,8 @@ private:
/// The dialect for this attribute.
Dialect *dialect;
/// This field is a pair of:
/// - The type of the attribute value.
/// - A boolean that is true if this is, or contains, a function attribute.
llvm::PointerIntPair<const void *, 1, bool> typeAndContainsFunctionAttrPair;
/// The opaque type of the attribute value.
const void *type;
};
/// Default storage type for attributes that require no additional
@ -114,13 +104,6 @@ public:
getInitFn(ctx, T::getClassID()), kind, std::forward<Args>(args)...);
}
/// Erase a uniqued instance of attribute T.
template <typename T, typename... Args>
static void erase(MLIRContext *ctx, unsigned kind, Args &&... args) {
return ctx->getAttributeUniquer().erase<typename T::ImplType>(
kind, std::forward<Args>(args)...);
}
private:
/// Returns a functor used to initialize new attribute storage instances.
static std::function<void(AttributeStorage *)>

View File

@ -25,7 +25,6 @@ namespace mlir {
class AffineMap;
class Dialect;
class Function;
class FunctionAttr;
class FunctionType;
class Identifier;
class IntegerSet;
@ -45,7 +44,6 @@ struct ArrayAttributeStorage;
struct AffineMapAttributeStorage;
struct IntegerSetAttributeStorage;
struct TypeAttributeStorage;
struct FunctionAttributeStorage;
struct SplatElementsAttributeStorage;
struct DenseElementsAttributeStorage;
struct DenseIntElementsAttributeStorage;
@ -118,16 +116,6 @@ public:
/// Get the dialect this attribute is registered to.
Dialect &getDialect() const;
/// Return true if this field is, or contains, a function attribute.
bool isOrContainsFunction() const;
/// Replace a function attribute or function attributes nested in an array
/// attribute with another function attribute as defined by the provided
/// remapping table. Return the original attribute if it (or any of nested
/// attributes) is not present in the table.
Attribute remapFunctionAttrs(
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable) const;
/// Print the attribute.
void print(raw_ostream &os) const;
void dump() const;
@ -383,34 +371,24 @@ public:
};
/// A function attribute represents a reference to a function object.
///
/// When working with IR, it is important to know that a function attribute can
/// exist with a null Function inside of it, which occurs when a function object
/// is deleted that had an attribute which referenced it. No references to this
/// attribute should persist across the transformation, but that attribute will
/// remain in MLIRContext.
class FunctionAttr
: public Attribute::AttrBase<FunctionAttr, Attribute,
detail::FunctionAttributeStorage> {
detail::StringAttributeStorage> {
public:
using Base::Base;
using ValueType = Function *;
static FunctionAttr get(Function *value);
static FunctionAttr get(StringRef value, MLIRContext *ctx);
Function *getValue() const;
FunctionType getType() const;
/// Returns the name of the held function reference.
StringRef getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardAttributes::Function;
}
/// This function is used by the internals of the Function class to null out
/// attributes referring to functions that are about to be deleted.
static void dropFunctionReference(Function *value);
/// This function is used by the internals of the Function class to update the
/// type of the function attribute for 'value'.
static void resetType(Function *value);

View File

@ -113,6 +113,7 @@ public:
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
TypeAttr getTypeAttr(Type type);
FunctionAttr getFunctionAttr(Function *value);
FunctionAttr getFunctionAttr(StringRef value);
ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef<char> data);
ElementsAttr getDenseElementsAttr(ShapedType type,

View File

@ -43,8 +43,6 @@ public:
ArrayRef<NamedAttribute> attrs,
ArrayRef<NamedAttributeList> argAttrs);
~Function();
/// The source location the function was defined or derived from.
Location getLoc() { return location; }
@ -72,7 +70,6 @@ public:
void setType(FunctionType newType) {
type = newType;
argAttrs.resize(type.getNumInputs());
FunctionAttr::resetType(this);
}
MLIRContext *getContext();

View File

@ -665,7 +665,7 @@ def StrArrayAttr : TypedArrayAttrBase<StrAttr, "string array attribute"> {
def FunctionAttr : Attr<CPred<"$_self.isa<FunctionAttr>()">,
"function attribute"> {
let storageType = [{ FunctionAttr }];
let returnType = [{ Function * }];
let returnType = [{ StringRef }];
let constBuilderCall = "$_builder.getFunctionAttr($0)";
}

View File

@ -380,11 +380,6 @@ public:
return success();
}
/// Resolve a parse function name and a type into a function reference.
virtual ParseResult resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc,
Function *&result) = 0;
/// Emit a diagnostic at the specified location and return failure.
virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
const Twine &message = {}) = 0;

View File

@ -219,12 +219,16 @@ def CallOp : Std_Op<"call"> {
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addTypes(callee->getType().getResults());
}]>, OpBuilder<
"Builder *builder, OperationState *result, StringRef callee,"
"ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addTypes(results);
}]>];
let extraClassDeclaration = [{
Function *getCallee() {
return getAttrOfType<FunctionAttr>("callee").getValue();
}
Function *getCallee();
/// Get the argument operands to the called function.
operand_range getArgOperands() {

View File

@ -87,7 +87,7 @@ private:
std::unique_ptr<llvm::Module> llvmModule;
// Mappings between original and translated values, used for lookups.
llvm::DenseMap<Function *, llvm::Function *> functionMapping;
llvm::StringMap<llvm::Function *> functionMapping;
llvm::DenseMap<Value *, llvm::Value *> valueMapping;
llvm::DenseMap<Block *, llvm::BasicBlock *> blockMapping;
};

View File

@ -120,21 +120,6 @@ Operation *createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
void createAffineComputationSlice(Operation *opInst,
SmallVectorImpl<AffineApplyOp> *sliceOps);
/// Replaces (potentially nested) function attributes in the operation "op"
/// with those specified in "remappingTable".
void remapFunctionAttrs(
Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable);
/// Replaces (potentially nested) function attributes all operations of the
/// Function "fn" with those specified in "remappingTable".
void remapFunctionAttrs(
Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable);
/// Replaces (potentially nested) function attributes in the entire module
/// with those specified in "remappingTable". Ignores external functions.
void remapFunctionAttrs(
Module &module, const DenseMap<Attribute, FunctionAttr> &remappingTable);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_UTILS_H

View File

@ -78,33 +78,6 @@ public:
return fn.getContext()->getRegisteredDialect(dialectNamePair.first);
}
template <typename ErrorContext>
LogicalResult verifyAttribute(Attribute attr, ErrorContext &ctx) {
if (!attr.isOrContainsFunction())
return success();
// If we have a function attribute, check that it is non-null and in the
// same module as the operation that refers to it.
if (auto fnAttr = attr.dyn_cast<FunctionAttr>()) {
if (!fnAttr.getValue())
return failure("attribute refers to deallocated function!", ctx);
if (fnAttr.getValue()->getModule() != fn.getModule())
return failure("attribute refers to function '" +
Twine(fnAttr.getValue()->getName()) +
"' defined in another module!",
ctx);
return success();
}
// Otherwise, we must have an array attribute, remap the elements.
for (auto elt : attr.cast<ArrayAttr>().getValue())
if (failed(verifyAttribute(elt, ctx)))
return failure();
return success();
}
LogicalResult verify();
LogicalResult verifyBlock(Block &block, bool isTopLevel);
LogicalResult verifyOperation(Operation &op);
@ -143,8 +116,6 @@ LogicalResult FuncVerifier::verify() {
if (!identifierRegex.match(attr.first))
return failure("invalid attribute name '" + attr.first.strref() + "'",
fn);
if (failed(verifyAttribute(attr.second, fn)))
return failure();
/// Check that the attribute is a dialect attribute, i.e. contains a '.' for
/// the namespace.
@ -165,8 +136,6 @@ LogicalResult FuncVerifier::verify() {
llvm::formatv("invalid attribute name '{0}' on argument {1}",
attr.first.strref(), i),
fn);
if (failed(verifyAttribute(attr.second, fn)))
return failure();
/// Check that the attribute is a dialect attribute, i.e. contains a '.'
/// for the namespace.
@ -280,8 +249,6 @@ LogicalResult FuncVerifier::verifyOperation(Operation &op) {
if (!identifierRegex.match(attr.first))
return failure("invalid attribute name '" + attr.first.strref() + "'",
op);
if (failed(verifyAttribute(attr.second, op)))
return failure();
// Check for any optional dialect specific attributes.
if (!attr.first.strref().contains('.'))

View File

@ -21,6 +21,7 @@
#include "mlir/GPU/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
@ -303,7 +304,9 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
}
Function *LaunchFuncOp::kernel() {
return this->getAttr(getKernelAttrName()).cast<FunctionAttr>().getValue();
auto kernelAttr = getAttr(getKernelAttrName()).cast<FunctionAttr>();
return getOperation()->getFunction()->getModule()->getNamedFunction(
kernelAttr.getValue());
}
unsigned LaunchFuncOp::getNumKernelOperands() {
@ -321,7 +324,11 @@ LogicalResult LaunchFuncOp::verify() {
} else if (!kernelAttr.isa<FunctionAttr>()) {
return emitOpError("attribute 'kernel' must be a function");
}
Function *kernelFunc = this->kernel();
if (!kernelFunc)
return emitError() << "kernel function '" << kernelAttr << "' is undefined";
if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
GPUDialect::getKernelFuncAttrName())) {
return emitError("kernel function is missing the '")

View File

@ -679,14 +679,7 @@ void ModulePrinter::printAttributeOptionalType(Attribute attr,
printType(attr.cast<TypeAttr>().getValue());
break;
case StandardAttributes::Function: {
auto *function = attr.cast<FunctionAttr>().getValue();
if (!function) {
os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
} else {
printFunctionReference(function);
os << " : ";
printType(function->getType());
}
os << '@' << attr.cast<FunctionAttr>().getValue();
break;
}
case StandardAttributes::OpaqueElements: {
@ -1317,7 +1310,7 @@ void FunctionPrinter::numberValueID(Value *value) {
} else {
specialName << 'c' << intCst.getInt() << '_' << type;
}
} else if (cst.isa<FunctionAttr>()) {
} else if (type.isa<FunctionType>()) {
specialName << 'f';
} else {
specialName << "cst";

View File

@ -214,8 +214,7 @@ struct StringAttributeStorage : public AttributeStorage {
struct ArrayAttributeStorage : public AttributeStorage {
using KeyTy = ArrayRef<Attribute>;
ArrayAttributeStorage(bool hasFunctionAttr, ArrayRef<Attribute> value)
: AttributeStorage(hasFunctionAttr), value(value) {}
ArrayAttributeStorage(ArrayRef<Attribute> value) : value(value) {}
/// Key equality function.
bool operator==(const KeyTy &key) const { return key == value; }
@ -223,13 +222,8 @@ struct ArrayAttributeStorage : public AttributeStorage {
/// Construct a new storage instance.
static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator,
const KeyTy &key) {
// Check to see if any of the elements have a function attr.
bool hasFunctionAttr = llvm::any_of(
key, [](Attribute elt) { return elt.isOrContainsFunction(); });
// Initialize the memory using placement new.
return new (allocator.allocate<ArrayAttributeStorage>())
ArrayAttributeStorage(hasFunctionAttr, allocator.copyInto(key));
ArrayAttributeStorage(allocator.copyInto(key));
}
ArrayRef<Attribute> value;
@ -293,37 +287,6 @@ struct TypeAttributeStorage : public AttributeStorage {
Type value;
};
/// An attribute representing a reference to a function.
struct FunctionAttributeStorage : public AttributeStorage {
using KeyTy = Function *;
FunctionAttributeStorage(Function *value)
: AttributeStorage(value->getType(), /*isOrContainsFunctionCache=*/true),
value(value) {}
/// Key equality function.
bool operator==(const KeyTy &key) const { return key == value; }
/// Construct a new storage instance.
static FunctionAttributeStorage *
construct(AttributeStorageAllocator &allocator, KeyTy key) {
return new (allocator.allocate<FunctionAttributeStorage>())
FunctionAttributeStorage(key);
}
/// Storage cleanup function.
void cleanup() {
// Null out the function reference in the attribute to avoid dangling
// pointers.
value = nullptr;
}
/// Reset the type of this attribute to the type of the held function.
void resetType() { setType(value->getType()); }
Function *value;
};
/// An attribute representing a reference to a vector or tensor constant,
/// inwhich all elements have the same value.
struct SplatElementsAttributeStorage : public AttributeStorage {

View File

@ -32,20 +32,15 @@ using namespace mlir::detail;
// AttributeStorage
//===----------------------------------------------------------------------===//
AttributeStorage::AttributeStorage(Type type, bool isOrContainsFunctionCache)
: typeAndContainsFunctionAttrPair(type.getAsOpaquePointer(),
isOrContainsFunctionCache) {}
AttributeStorage::AttributeStorage(bool isOrContainsFunctionCache)
: AttributeStorage(/*type=*/nullptr, isOrContainsFunctionCache) {}
AttributeStorage::AttributeStorage()
: AttributeStorage(/*type=*/nullptr, /*isOrContainsFunctionCache=*/false) {}
AttributeStorage::AttributeStorage(Type type)
: type(type.getAsOpaquePointer()) {}
AttributeStorage::AttributeStorage() : type(nullptr) {}
Type AttributeStorage::getType() const {
return Type::getFromOpaquePointer(
typeAndContainsFunctionAttrPair.getPointer());
return Type::getFromOpaquePointer(type);
}
void AttributeStorage::setType(Type type) {
typeAndContainsFunctionAttrPair.setPointer(type.getAsOpaquePointer());
void AttributeStorage::setType(Type newType) {
type = newType.getAsOpaquePointer();
}
//===----------------------------------------------------------------------===//
@ -61,42 +56,6 @@ MLIRContext *Attribute::getContext() const { return getType().getContext(); }
/// Get the dialect this attribute is registered to.
Dialect &Attribute::getDialect() const { return impl->getDialect(); }
bool Attribute::isOrContainsFunction() const {
return impl->isOrContainsFunctionCache();
}
// Given an attribute that could refer to a function attribute in the remapping
// table, walk it and rewrite it to use the mapped function. If it doesn't
// refer to anything in the table, then it is returned unmodified.
Attribute Attribute::remapFunctionAttrs(
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable) const {
// Most attributes are trivially unrelated to function attributes, skip them
// rapidly.
if (!isOrContainsFunction())
return *this;
// If we have a function attribute, remap it.
if (auto fnAttr = this->dyn_cast<FunctionAttr>()) {
auto it = remappingTable.find(fnAttr);
return it != remappingTable.end() ? it->second : *this;
}
// Otherwise, we must have an array attribute, remap the elements.
auto arrayAttr = this->cast<ArrayAttr>();
SmallVector<Attribute, 8> remappedElts;
bool anyChange = false;
for (auto elt : arrayAttr.getValue()) {
auto newElt = elt.remapFunctionAttrs(remappingTable);
remappedElts.push_back(newElt);
anyChange |= (elt != newElt);
}
if (!anyChange)
return *this;
return ArrayAttr::get(remappedElts, getContext());
}
//===----------------------------------------------------------------------===//
// OpaqueAttr
//===----------------------------------------------------------------------===//
@ -293,27 +252,14 @@ Type TypeAttr::getValue() const { return getImpl()->value; }
FunctionAttr FunctionAttr::get(Function *value) {
assert(value && "Cannot get FunctionAttr for a null function");
return Base::get(value->getContext(), StandardAttributes::Function, value);
return get(value->getName(), value->getContext());
}
/// This function is used by the internals of the Function class to null out
/// attributes referring to functions that are about to be deleted.
void FunctionAttr::dropFunctionReference(Function *value) {
AttributeUniquer::erase<FunctionAttr>(value->getContext(),
StandardAttributes::Function, value);
FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) {
return Base::get(ctx, StandardAttributes::Function, value);
}
/// This function is used by the internals of the Function class to update the
/// type of the attribute for 'value'.
void FunctionAttr::resetType(Function *value) {
FunctionAttr::get(value).getImpl()->resetType();
}
Function *FunctionAttr::getValue() const { return getImpl()->value; }
FunctionType FunctionAttr::getType() const {
return Attribute::getType().cast<FunctionType>();
}
StringRef FunctionAttr::getValue() const { return getImpl()->value; }
//===----------------------------------------------------------------------===//
// ElementsAttr

View File

@ -172,6 +172,9 @@ TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
FunctionAttr Builder::getFunctionAttr(Function *value) {
return FunctionAttr::get(value);
}
FunctionAttr Builder::getFunctionAttr(StringRef value) {
return FunctionAttr::get(value, getContext());
}
ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) {
return SplatElementsAttr::get(type, elt);

View File

@ -37,11 +37,6 @@ Function::Function(Location location, StringRef name, FunctionType type,
: name(Identifier::get(name, type.getContext())), location(location),
type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
Function::~Function() {
// Clean up function attributes referring to this function.
FunctionAttr::dropFunctionReference(this);
}
MLIRContext *Function::getContext() { return getType().getContext(); }
/// Swap the name of the given function with this one.

View File

@ -22,6 +22,7 @@
#include "mlir/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/AsmParser/Parser.h"
@ -393,43 +394,22 @@ static void printCallOp(OpAsmPrinter *p, CallOp &op) {
// callee (first operand) otherwise.
*p << op.getOperationName() << ' ';
if (isDirect)
*p << '@' << callee.getValue()->getName().strref();
*p << '@' << callee.getValue();
else
*p << *op.getOperand(0);
*p << '(';
p->printOperands(std::next(op.operand_begin(), callee.hasValue() ? 0 : 1),
op.operand_end());
p->printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1));
*p << ')';
p->printOptionalAttrDict(op.getAttrs(), {"callee"});
if (isDirect) {
*p << " : " << callee.getValue()->getType();
return;
}
// Reconstruct the function MLIR function type from LLVM function type,
// and print it.
auto operandType = op.getOperand(0)->getType().cast<LLVM::LLVMType>();
auto *llvmPtrType =
dyn_cast<llvm::PointerType>(operandType.getUnderlyingType());
assert(llvmPtrType &&
"operand #0 must have LLVM pointer type for indirect calls");
auto *llvmType = dyn_cast<llvm::FunctionType>(llvmPtrType->getElementType());
assert(llvmType &&
"operand #0 must have LLVM Function pointer type for indirect calls");
auto *llvmResultType = llvmType->getReturnType();
SmallVector<Type, 1> resultTypes;
if (!llvmResultType->isVoidTy())
resultTypes.push_back(LLVM::LLVMType::get(op.getContext(), llvmResultType));
// Reconstruct the function MLIR function type from operand and result types.
SmallVector<Type, 1> resultTypes(op.getOperation()->getResultTypes());
SmallVector<Type, 8> argTypes;
argTypes.reserve(llvmType->getNumParams());
for (int i = 0, e = llvmType->getNumParams(); i < e; ++i)
argTypes.push_back(
LLVM::LLVMType::get(op.getContext(), llvmType->getParamType(i)));
argTypes.reserve(op.getNumOperands());
for (auto *operand : llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1))
argTypes.push_back(operand->getType());
*p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
}
@ -467,10 +447,7 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
return parser->emitError(trailingTypeLoc, "expected function type");
if (isDirect) {
// Add the direct callee as an Op attribute.
Function *func;
if (parser->resolveFunctionName(calleeName, funcType, calleeLoc, func))
return failure();
auto funcAttr = parser->getBuilder().getFunctionAttr(func);
auto funcAttr = parser->getBuilder().getFunctionAttr(calleeName);
attrs.push_back(parser->getBuilder().getNamedAttr("callee", funcAttr));
// Make sure types match.

View File

@ -57,23 +57,12 @@ public:
: context(module->getContext()), module(module), lex(sourceMgr, context),
curToken(lex.lexToken()) {}
~ParserState() {
// Destroy the forward references upon error.
for (auto forwardRef : functionForwardRefs)
delete forwardRef.second;
functionForwardRefs.clear();
}
// A map from attribute alias identifier to Attribute.
llvm::StringMap<Attribute> attributeAliasDefinitions;
// A map from type alias identifier to Type.
llvm::StringMap<Type> typeAliasDefinitions;
// This keeps track of all forward references to functions along with the
// temporary function used to represent them.
llvm::DenseMap<Identifier, Function *> functionForwardRefs;
private:
ParserState(const ParserState &) = delete;
void operator=(const ParserState &) = delete;
@ -190,8 +179,6 @@ public:
ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
// Attribute parsing.
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType type);
Attribute parseExtendedAttribute(Type type);
Attribute parseAttribute(Type type = {});
@ -998,32 +985,6 @@ TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) {
return success();
}
/// Given a parsed reference to a function name like @foo and a type that it
/// corresponds to, resolve it to a concrete function object (possibly
/// synthesizing a forward reference) or emit an error and return null on
/// failure.
Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType type) {
Identifier name = builder.getIdentifier(nameStr.drop_front());
// See if the function has already been defined in the module.
Function *function = getModule()->getNamedFunction(name);
// If not, get or create a forward reference to one.
if (!function) {
auto &entry = state.functionForwardRefs[name];
if (!entry)
entry = new Function(getEncodedSourceLocation(nameLoc), name, type,
/*attrs=*/{});
function = entry;
}
if (function->getType() != type)
return (emitError(nameLoc, "reference to function with mismatched type"),
nullptr);
return function;
}
/// Parse an extended attribute.
///
/// extended-attribute ::= (dialect-attribute | attribute-alias)
@ -1218,22 +1179,9 @@ Attribute Parser::parseAttribute(Type type) {
}
case Token::at_identifier: {
auto nameLoc = getToken().getLoc();
auto nameStr = getTokenSpelling();
consumeToken(Token::at_identifier);
if (parseToken(Token::colon, "expected ':' and function type"))
return nullptr;
auto typeLoc = getToken().getLoc();
Type type = parseType();
if (!type)
return nullptr;
auto fnType = type.dyn_cast<FunctionType>();
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
auto *function = resolveFunctionReference(nameStr, nameLoc, fnType);
return function ? builder.getFunctionAttr(function) : nullptr;
return builder.getFunctionAttr(nameStr.drop_front());
}
case Token::kw_opaque: {
consumeToken(Token::kw_opaque);
@ -3224,7 +3172,7 @@ public:
if (parser.getToken().isNot(Token::at_identifier))
return failure();
result = parser.getTokenSpelling();
result = parser.getTokenSpelling().drop_front();
parser.consumeToken(Token::at_identifier);
return success();
}
@ -3325,13 +3273,6 @@ public:
return success();
}
/// Resolve a parse function name and a type into a function reference.
ParseResult resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc, Function *&result) override {
result = parser.resolveFunctionReference(name, loc, type);
return failure(result == nullptr);
}
/// Parse a region that takes `arguments` of `argTypes` types. This
/// effectively defines the SSA values of `arguments` and assignes their type.
ParseResult parseRegion(Region &region, ArrayRef<OperandType> arguments,
@ -3549,8 +3490,6 @@ public:
ParseResult parseModule();
private:
ParseResult finalizeModule();
ParseResult parseAttributeAliasDef();
ParseResult parseTypeAliasDef();
@ -3782,42 +3721,6 @@ ParseResult ModuleParser::parseFunc() {
return parser.parseFunctionBody(hadNamedArguments);
}
/// Finish the end of module parsing - when the result is valid, do final
/// checking.
ParseResult ModuleParser::finalizeModule() {
// Resolve all forward references, building a remapping table of attributes.
DenseMap<Attribute, FunctionAttr> remappingTable;
for (auto forwardRef : getState().functionForwardRefs) {
auto name = forwardRef.first;
// Resolve the reference.
auto *resolvedFunction = getModule()->getNamedFunction(name);
if (!resolvedFunction) {
forwardRef.second->emitError("reference to undefined function '")
<< name << "'";
return failure();
}
remappingTable[builder.getFunctionAttr(forwardRef.second)] =
builder.getFunctionAttr(resolvedFunction);
}
// If there was nothing to remap, then we're done.
if (remappingTable.empty())
return success();
// Otherwise, walk the entire module replacing uses of one attribute set
// with the correct ones.
remapFunctionAttrs(*getModule(), remappingTable);
// Now that all references to the forward definition placeholders are
// resolved, we can deallocate the placeholders.
for (auto forwardRef : getState().functionForwardRefs)
delete forwardRef.second;
getState().functionForwardRefs.clear();
return success();
}
/// This is the top-level module parser.
ParseResult ModuleParser::parseModule() {
while (1) {
@ -3828,7 +3731,7 @@ ParseResult ModuleParser::parseModule() {
// If we got to the end of the file, then we're done.
case Token::eof:
return finalizeModule();
return success();
// If we got an error token, then the lexer already emitted an error, just
// stop. Someday we could introduce error recovery if there was demand

View File

@ -21,6 +21,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
@ -419,19 +420,18 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
llvm::SMLoc calleeLoc;
FunctionType calleeType;
SmallVector<OpAsmParser::OperandType, 4> operands;
Function *callee = nullptr;
if (parser->parseFunctionName(calleeName, calleeLoc) ||
parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
parser->addTypesToList(calleeType.getResults(), result->types) ||
parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
result->operands))
return failure();
result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
result->addAttribute("callee",
parser->getBuilder().getFunctionAttr(calleeName));
return success();
}
@ -450,9 +450,14 @@ static LogicalResult verify(CallOp op) {
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' function attribute");
auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
fnAttr.getValue());
if (!fn)
return op.emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
// Verify that the operand and result types match the callee.
auto fnType = fnAttr.getValue()->getType();
auto fnType = fn->getType();
if (fnType.getNumInputs() != op.getNumOperands())
return op.emitOpError("incorrect number of operands for callee");
@ -470,6 +475,11 @@ static LogicalResult verify(CallOp op) {
return success();
}
Function *CallOp::getCallee() {
auto name = getAttrOfType<FunctionAttr>("callee").getValue();
return getOperation()->getFunction()->getModule()->getNamedFunction(name);
}
//===----------------------------------------------------------------------===//
// CallIndirectOp
//===----------------------------------------------------------------------===//
@ -494,8 +504,10 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
return matchFailure();
// Replace with a direct call.
SmallVector<Type, 8> callResults(op->getResultTypes());
SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callOperands);
rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callResults,
callOperands);
return matchSuccess();
}
};
@ -1108,19 +1120,30 @@ static void print(OpAsmPrinter *p, ConstantOp &op) {
if (op.getAttrs().size() > 1)
*p << ' ';
p->printAttributeAndType(op.getValue());
// If the value is a function, print a trailing type.
if (op.getValue().isa<FunctionAttr>()) {
*p << " : ";
p->printType(op.getType());
}
}
static ParseResult parseConstantOp(OpAsmParser *parser,
OperationState *result) {
Attribute valueAttr;
Type type;
if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseAttribute(valueAttr, "value", result->attributes))
return failure();
// If the attribute is a function, then we expect a trailing type.
Type type;
if (!valueAttr.isa<FunctionAttr>())
type = valueAttr.getType();
else if (parser->parseColonType(type))
return failure();
// Add the attribute type to the list.
return parser->addTypeToList(valueAttr.getType(), result->types);
return parser->addTypeToList(type, result->types);
}
/// The constant op requires an attribute, and furthermore requires that it
@ -1131,7 +1154,7 @@ static LogicalResult verify(ConstantOp &op) {
return op.emitOpError("requires a 'value' attribute");
auto type = op.getType();
if (type != value.getType())
if (!value.getType().isa<NoneType>() && type != value.getType())
return op.emitOpError() << "requires attribute's type (" << value.getType()
<< ") to match op's return type (" << type << ")";
@ -1162,8 +1185,20 @@ static LogicalResult verify(ConstantOp &op) {
}
if (type.isa<FunctionType>()) {
if (!value.isa<FunctionAttr>())
auto fnAttr = value.dyn_cast<FunctionAttr>();
if (!fnAttr)
return op.emitOpError("requires 'value' to be a function reference");
// Try to find the referenced function.
auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
fnAttr.getValue());
if (!fn)
return op.emitOpError("reference to undefined function 'bar'");
// Check that the referenced function has the correct type.
if (fn->getType() != type)
return op.emitOpError("reference to function with mismatched type");
return success();
}

View File

@ -326,7 +326,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) {
// function.
blockMapping.clear();
valueMapping.clear();
llvm::Function *llvmFunc = functionMapping.lookup(&func);
llvm::Function *llvmFunc = functionMapping.lookup(func.getName());
// Add function arguments to the value remapping table.
// If there was noalias info then we decorate each argument accordingly.
unsigned int argIdx = 0;
@ -378,7 +378,6 @@ bool ModuleTranslation::convertFunctions() {
// Declare all functions first because there may be function calls that form a
// call graph with cycles.
for (Function &function : mlirModule) {
Function *functionPtr = &function;
mlir::BoolAttr isVarArgsAttr =
function.getAttrOfType<BoolAttr>("std.varargs");
bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
@ -390,7 +389,7 @@ bool ModuleTranslation::convertFunctions() {
llvm::FunctionCallee llvmFuncCst =
llvmModule->getOrInsertFunction(function.getName(), functionType);
assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
functionMapping[functionPtr] =
functionMapping[function.getName()] =
cast<llvm::Function>(llvmFuncCst.getCallee());
}

View File

@ -284,45 +284,3 @@ void mlir::createAffineComputationSlice(
opInst->setOperand(idx, newOperands[idx]);
}
}
static void
remapFunctionAttrs(NamedAttributeList &attrs,
const DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto attr : attrs.getAttrs()) {
// Do the remapping, if we got the same thing back, then it must contain
// functions that aren't getting remapped.
auto newVal = attr.second.remapFunctionAttrs(remappingTable);
if (newVal == attr.second)
continue;
// Otherwise, replace the existing attribute with the new one. It is safe
// to mutate the attribute list while we walk it because underlying
// attribute lists are uniqued and immortal.
attrs.set(attr.first, newVal);
}
}
void mlir::remapFunctionAttrs(
Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
::remapFunctionAttrs(op.getAttrList(), remappingTable);
}
void mlir::remapFunctionAttrs(
Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
// Remap the attributes of the function.
::remapFunctionAttrs(fn.getAttrList(), remappingTable);
// Remap the attributes of the arguments of this function.
for (auto &attrList : fn.getAllArgAttrs())
::remapFunctionAttrs(attrList, remappingTable);
// Look at all operations in a Function.
fn.walk([&](Operation *op) { remapFunctionAttrs(*op, remappingTable); });
}
void mlir::remapFunctionAttrs(
Module &module, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto &fn : module)
remapFunctionAttrs(fn, remappingTable);
}

View File

@ -94,7 +94,16 @@ func @launch_func_missing_callee_attribute(%sz : index) {
func @launch_func_no_function_attribute(%sz : index) {
// expected-error@+1 {{attribute 'kernel' must be a function}}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz) {kernel: "bar"}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz) {kernel: 10}
: (index, index, index, index, index, index) -> ()
return
}
// -----
func @launch_func_undefined_function(%sz : index) {
// expected-error@+1 {{kernel function '@kernel_1' is undefined}}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz) { kernel: @kernel_1 }
: (index, index, index, index, index, index) -> ()
return
}
@ -107,8 +116,7 @@ func @kernel_1(%arg1 : !llvm<"float*">) {
func @launch_func_missing_kernel_attr(%sz : index, %arg : !llvm<"float*">) {
// expected-error@+1 {{kernel function is missing the 'gpu.kernel' attribute}}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
{kernel: @kernel_1 : (!llvm<"float*">) -> ()}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg) {kernel: @kernel_1}
: (index, index, index, index, index, index, !llvm<"float*">) -> ()
return
}
@ -122,7 +130,7 @@ func @kernel_1(%arg1 : !llvm<"float*">) attributes { gpu.kernel } {
func @launch_func_kernel_operand_size(%sz : index, %arg : !llvm<"float*">) {
// expected-error@+1 {{got 2 kernel operands but expected 1}}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg, %arg)
{kernel: @kernel_1 : (!llvm<"float*">) -> ()}
{kernel: @kernel_1}
: (index, index, index, index, index, index, !llvm<"float*">,
!llvm<"float*">) -> ()
return
@ -137,7 +145,7 @@ func @kernel_1(%arg1 : !llvm<"float*">) attributes { gpu.kernel } {
func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
// expected-error@+1 {{type of function argument 0 does not match}}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
{kernel: @kernel_1 : (!llvm<"float*">) -> ()}
{kernel: @kernel_1}
: (index, index, index, index, index, index, f32) -> ()
return
}

View File

@ -82,9 +82,8 @@ func @foo() {
// CHECK: %c8 = constant 8
%cst = constant 8 : index
// CHECK: "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %0, %1) {kernel: @kernel_1 : (f32, memref<?xf32, 1>) -> ()} : (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1)
{kernel: @kernel_1 : (f32, memref<?xf32, 1>) -> ()}
// CHECK: "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %0, %1) {kernel: @kernel_1} : (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel: @kernel_1 }
: (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
return
}

View File

@ -319,7 +319,7 @@ func @calls(%arg0: i32) {
// CHECK: %1 = call @return_op(%0) : (i32) -> i32
%y = call @return_op(%x) : (i32) -> i32
// CHECK: %2 = call @return_op(%0) : (i32) -> i32
%z = "std.call"(%x) {callee: @return_op : (i32) -> i32} : (i32) -> i32
%z = "std.call"(%x) {callee: @return_op} : (i32) -> i32
// CHECK: %f = constant @affine_apply : () -> ()
%f = constant @affine_apply : () -> ()

View File

@ -26,7 +26,7 @@ func @dim3(tensor<1xf32>) {
func @constant() {
^bb:
%x = "std.constant"(){value: "xyz"} : () -> i32 // expected-error {{requires attribute's type (none) to match op's return type (i32)}}
%x = "std.constant"(){value: "xyz"} : () -> i32 // expected-error {{requires a result type that aligns with the 'value' attribute}}
return
}
@ -149,7 +149,7 @@ func @intlimit2() {
// -----
func @calls(%arg0: i32) {
%x = call @calls() : () -> i32 // expected-error {{reference to function with mismatched type}}
%x = call @calls() : () -> i32 // expected-error {{incorrect number of operands for callee}}
return
}

View File

@ -492,14 +492,6 @@ func @invalid_result_type() -> () -> () // expected-error {{expected a top leve
// -----
func @func() -> (() -> ())
func @referer() {
%f = constant @func : () -> () -> () // expected-error {{reference to function with mismatched type}}
return
}
// -----
#map1 = (i)[j] -> (i+j)
func @bound_symbol_mismatch(%N : index) {

View File

@ -372,8 +372,8 @@ func @attributes() {
// CHECK: "foo"() {d: 1.000000e-09 : f64, func: [], i123: 7 : i64, if: "foo"} : () -> ()
"foo"() {if: "foo", func: [], i123: 7, d: 1.e-9} : () -> ()
// CHECK: "foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> ()
"foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> ()
// CHECK: "foo"() {fn: @attributes, if: @ifinst} : () -> ()
"foo"() {fn: @attributes, if: @ifinst} : () -> ()
// CHECK: "foo"() {int: 0 : i42} : () -> ()
"foo"() {int: 0 : i42} : () -> ()
@ -923,15 +923,3 @@ func @none_type() {
%none_val = "foo.unknown_op"() : () -> none
return
}
// CHECK-LABEL: func @fn_attr_remap
// CHECK: {some_dialect.arg_attr: @fn_attr_ref : () -> ()}
func @fn_attr_remap(%arg0: i1 {some_dialect.arg_attr: @fn_attr_ref : () -> ()}) -> ()
// CHECK-NEXT: {some_dialect.fn_attr: @fn_attr_ref : () -> ()}
attributes {some_dialect.fn_attr: @fn_attr_ref : () -> ()} {
return
}
// CHECK-LABEL: func @fn_attr_ref
func @fn_attr_ref() -> ()

View File

@ -34,7 +34,7 @@ func @body(i32)
// CHECK-LABEL: func @indirect_const_call(%arg0: !llvm.i32) {
func @indirect_const_call(%arg0: i32) {
// CHECK-NEXT: %0 = llvm.constant(@body : (!llvm.i32) -> ()) : !llvm<"void (i32)*">
// CHECK-NEXT: %0 = llvm.constant(@body) : !llvm<"void (i32)*">
%0 = constant @body : (i32) -> ()
// CHECK-NEXT: llvm.call %0(%arg0) : (!llvm.i32) -> ()
call_indirect %0(%arg0) : (i32) -> ()

View File

@ -52,12 +52,12 @@ func @ops(%arg0 : !llvm.i32, %arg1 : !llvm.float) {
// CHECK-NEXT: %17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
// CHECK-NEXT: %18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }">
// CHECK-NEXT: %19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }">
// CHECK-NEXT: %20 = llvm.constant(@foo : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">) : !llvm<"{ i32, double, i32 } (i32)*">
// CHECK-NEXT: %20 = llvm.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
// CHECK-NEXT: %21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
%17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
%18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }">
%19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }">
%20 = llvm.constant(@foo : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">) : !llvm<"{ i32, double, i32 } (i32)*">
%20 = llvm.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
%21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">

View File

@ -766,7 +766,7 @@ func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm
// CHECK-LABEL: define void @indirect_const_call(i64) {
func @indirect_const_call(%arg0: !llvm.i64) {
// CHECK-NEXT: call void @body(i64 %0)
%0 = llvm.constant(@body : (!llvm.i64) -> ()) : !llvm<"void (i64)*">
%0 = llvm.constant(@body) : !llvm<"void (i64)*">
llvm.call %0(%arg0) : (!llvm.i64) -> ()
// CHECK-NEXT: ret void
llvm.return

View File

@ -104,7 +104,7 @@ def BOp : NS_Op<"b_op", []> {
// CHECK: APFloat BOp::f64_attr()
// CHECK: StringRef BOp::str_attr()
// CHECK: ElementsAttr BOp::elements_attr()
// CHECK: Function *BOp::function_attr()
// CHECK: StringRef BOp::function_attr()
// CHECK: SomeType BOp::type_attr()
// CHECK: ArrayAttr BOp::array_attr()
// CHECK: ArrayAttr BOp::some_attr_array()