Refactor FunctionAttr to hold the internal function reference by name instead of pointer. The one downside to this is that the function reference held by a FunctionAttr needs to be explicitly looked up from the parent module. This provides several benefits though:
* There is no longer a need to explicitly remap function attrs. - This removes a potentially expensive call from the destructor of Function. - This will enable some interprocedural transformations to now run intraprocedurally. - This wasn't scalable and forces dialect defined attributes to override a virtual function. * Replacing a function is now a trivial operation. * This is a necessary first step to representing functions as operations. -- PiperOrigin-RevId: 249510802
This commit is contained in:
parent
d5397f4efe
commit
c33862b0ed
|
@ -860,11 +860,11 @@ of the specified [float type](#floating-point-types).
|
|||
Syntax:
|
||||
|
||||
``` {.ebnf}
|
||||
function-attribute ::= function-id `:` function-type
|
||||
function-attribute ::= function-id
|
||||
```
|
||||
|
||||
A function attribute is a literal attribute that represents a reference to the
|
||||
given function object.
|
||||
A function attribute is a literal attribute that represents a named reference to
|
||||
the given function.
|
||||
|
||||
#### String Attribute
|
||||
|
||||
|
|
|
@ -45,11 +45,6 @@ class AttributeStorage : public StorageUniquer::BaseStorage {
|
|||
friend StorageUniquer;
|
||||
|
||||
public:
|
||||
/// Returns if the derived attribute is or contains a function pointer.
|
||||
bool isOrContainsFunctionCache() const {
|
||||
return typeAndContainsFunctionAttrPair.getInt();
|
||||
}
|
||||
|
||||
/// Get the type of this attribute.
|
||||
Type getType() const;
|
||||
|
||||
|
@ -60,14 +55,11 @@ public:
|
|||
}
|
||||
|
||||
protected:
|
||||
/// Construct a new attribute storage instance with the given type and a
|
||||
/// boolean that signals if the derived attribute is or contains a function
|
||||
/// pointer.
|
||||
/// Construct a new attribute storage instance with the given type.
|
||||
/// Note: All attributes require a valid type. If no type is provided here,
|
||||
/// the type of the attribute will automatically default to NoneType
|
||||
/// upon initialization in the uniquer.
|
||||
AttributeStorage(Type type, bool isOrContainsFunctionCache = false);
|
||||
AttributeStorage(bool isOrContainsFunctionCache);
|
||||
AttributeStorage(Type type);
|
||||
AttributeStorage();
|
||||
|
||||
/// Set the type of this attribute.
|
||||
|
@ -81,10 +73,8 @@ private:
|
|||
/// The dialect for this attribute.
|
||||
Dialect *dialect;
|
||||
|
||||
/// This field is a pair of:
|
||||
/// - The type of the attribute value.
|
||||
/// - A boolean that is true if this is, or contains, a function attribute.
|
||||
llvm::PointerIntPair<const void *, 1, bool> typeAndContainsFunctionAttrPair;
|
||||
/// The opaque type of the attribute value.
|
||||
const void *type;
|
||||
};
|
||||
|
||||
/// Default storage type for attributes that require no additional
|
||||
|
@ -114,13 +104,6 @@ public:
|
|||
getInitFn(ctx, T::getClassID()), kind, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Erase a uniqued instance of attribute T.
|
||||
template <typename T, typename... Args>
|
||||
static void erase(MLIRContext *ctx, unsigned kind, Args &&... args) {
|
||||
return ctx->getAttributeUniquer().erase<typename T::ImplType>(
|
||||
kind, std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
private:
|
||||
/// Returns a functor used to initialize new attribute storage instances.
|
||||
static std::function<void(AttributeStorage *)>
|
||||
|
|
|
@ -25,7 +25,6 @@ namespace mlir {
|
|||
class AffineMap;
|
||||
class Dialect;
|
||||
class Function;
|
||||
class FunctionAttr;
|
||||
class FunctionType;
|
||||
class Identifier;
|
||||
class IntegerSet;
|
||||
|
@ -45,7 +44,6 @@ struct ArrayAttributeStorage;
|
|||
struct AffineMapAttributeStorage;
|
||||
struct IntegerSetAttributeStorage;
|
||||
struct TypeAttributeStorage;
|
||||
struct FunctionAttributeStorage;
|
||||
struct SplatElementsAttributeStorage;
|
||||
struct DenseElementsAttributeStorage;
|
||||
struct DenseIntElementsAttributeStorage;
|
||||
|
@ -118,16 +116,6 @@ public:
|
|||
/// Get the dialect this attribute is registered to.
|
||||
Dialect &getDialect() const;
|
||||
|
||||
/// Return true if this field is, or contains, a function attribute.
|
||||
bool isOrContainsFunction() const;
|
||||
|
||||
/// Replace a function attribute or function attributes nested in an array
|
||||
/// attribute with another function attribute as defined by the provided
|
||||
/// remapping table. Return the original attribute if it (or any of nested
|
||||
/// attributes) is not present in the table.
|
||||
Attribute remapFunctionAttrs(
|
||||
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable) const;
|
||||
|
||||
/// Print the attribute.
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
@ -383,34 +371,24 @@ public:
|
|||
};
|
||||
|
||||
/// A function attribute represents a reference to a function object.
|
||||
///
|
||||
/// When working with IR, it is important to know that a function attribute can
|
||||
/// exist with a null Function inside of it, which occurs when a function object
|
||||
/// is deleted that had an attribute which referenced it. No references to this
|
||||
/// attribute should persist across the transformation, but that attribute will
|
||||
/// remain in MLIRContext.
|
||||
class FunctionAttr
|
||||
: public Attribute::AttrBase<FunctionAttr, Attribute,
|
||||
detail::FunctionAttributeStorage> {
|
||||
detail::StringAttributeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
using ValueType = Function *;
|
||||
|
||||
static FunctionAttr get(Function *value);
|
||||
static FunctionAttr get(StringRef value, MLIRContext *ctx);
|
||||
|
||||
Function *getValue() const;
|
||||
|
||||
FunctionType getType() const;
|
||||
/// Returns the name of the held function reference.
|
||||
StringRef getValue() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardAttributes::Function;
|
||||
}
|
||||
|
||||
/// This function is used by the internals of the Function class to null out
|
||||
/// attributes referring to functions that are about to be deleted.
|
||||
static void dropFunctionReference(Function *value);
|
||||
|
||||
/// This function is used by the internals of the Function class to update the
|
||||
/// type of the function attribute for 'value'.
|
||||
static void resetType(Function *value);
|
||||
|
|
|
@ -113,6 +113,7 @@ public:
|
|||
IntegerSetAttr getIntegerSetAttr(IntegerSet set);
|
||||
TypeAttr getTypeAttr(Type type);
|
||||
FunctionAttr getFunctionAttr(Function *value);
|
||||
FunctionAttr getFunctionAttr(StringRef value);
|
||||
ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
|
||||
ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef<char> data);
|
||||
ElementsAttr getDenseElementsAttr(ShapedType type,
|
||||
|
|
|
@ -43,8 +43,6 @@ public:
|
|||
ArrayRef<NamedAttribute> attrs,
|
||||
ArrayRef<NamedAttributeList> argAttrs);
|
||||
|
||||
~Function();
|
||||
|
||||
/// The source location the function was defined or derived from.
|
||||
Location getLoc() { return location; }
|
||||
|
||||
|
@ -72,7 +70,6 @@ public:
|
|||
void setType(FunctionType newType) {
|
||||
type = newType;
|
||||
argAttrs.resize(type.getNumInputs());
|
||||
FunctionAttr::resetType(this);
|
||||
}
|
||||
|
||||
MLIRContext *getContext();
|
||||
|
|
|
@ -665,7 +665,7 @@ def StrArrayAttr : TypedArrayAttrBase<StrAttr, "string array attribute"> {
|
|||
def FunctionAttr : Attr<CPred<"$_self.isa<FunctionAttr>()">,
|
||||
"function attribute"> {
|
||||
let storageType = [{ FunctionAttr }];
|
||||
let returnType = [{ Function * }];
|
||||
let returnType = [{ StringRef }];
|
||||
let constBuilderCall = "$_builder.getFunctionAttr($0)";
|
||||
}
|
||||
|
||||
|
|
|
@ -380,11 +380,6 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Resolve a parse function name and a type into a function reference.
|
||||
virtual ParseResult resolveFunctionName(StringRef name, FunctionType type,
|
||||
llvm::SMLoc loc,
|
||||
Function *&result) = 0;
|
||||
|
||||
/// Emit a diagnostic at the specified location and return failure.
|
||||
virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
|
||||
const Twine &message = {}) = 0;
|
||||
|
|
|
@ -219,12 +219,16 @@ def CallOp : Std_Op<"call"> {
|
|||
result->addOperands(operands);
|
||||
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
||||
result->addTypes(callee->getType().getResults());
|
||||
}]>, OpBuilder<
|
||||
"Builder *builder, OperationState *result, StringRef callee,"
|
||||
"ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{
|
||||
result->addOperands(operands);
|
||||
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
||||
result->addTypes(results);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Function *getCallee() {
|
||||
return getAttrOfType<FunctionAttr>("callee").getValue();
|
||||
}
|
||||
Function *getCallee();
|
||||
|
||||
/// Get the argument operands to the called function.
|
||||
operand_range getArgOperands() {
|
||||
|
|
|
@ -87,7 +87,7 @@ private:
|
|||
std::unique_ptr<llvm::Module> llvmModule;
|
||||
|
||||
// Mappings between original and translated values, used for lookups.
|
||||
llvm::DenseMap<Function *, llvm::Function *> functionMapping;
|
||||
llvm::StringMap<llvm::Function *> functionMapping;
|
||||
llvm::DenseMap<Value *, llvm::Value *> valueMapping;
|
||||
llvm::DenseMap<Block *, llvm::BasicBlock *> blockMapping;
|
||||
};
|
||||
|
|
|
@ -120,21 +120,6 @@ Operation *createComposedAffineApplyOp(FuncBuilder *builder, Location loc,
|
|||
void createAffineComputationSlice(Operation *opInst,
|
||||
SmallVectorImpl<AffineApplyOp> *sliceOps);
|
||||
|
||||
/// Replaces (potentially nested) function attributes in the operation "op"
|
||||
/// with those specified in "remappingTable".
|
||||
void remapFunctionAttrs(
|
||||
Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable);
|
||||
|
||||
/// Replaces (potentially nested) function attributes all operations of the
|
||||
/// Function "fn" with those specified in "remappingTable".
|
||||
void remapFunctionAttrs(
|
||||
Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable);
|
||||
|
||||
/// Replaces (potentially nested) function attributes in the entire module
|
||||
/// with those specified in "remappingTable". Ignores external functions.
|
||||
void remapFunctionAttrs(
|
||||
Module &module, const DenseMap<Attribute, FunctionAttr> &remappingTable);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_UTILS_H
|
||||
|
|
|
@ -78,33 +78,6 @@ public:
|
|||
return fn.getContext()->getRegisteredDialect(dialectNamePair.first);
|
||||
}
|
||||
|
||||
template <typename ErrorContext>
|
||||
LogicalResult verifyAttribute(Attribute attr, ErrorContext &ctx) {
|
||||
if (!attr.isOrContainsFunction())
|
||||
return success();
|
||||
|
||||
// If we have a function attribute, check that it is non-null and in the
|
||||
// same module as the operation that refers to it.
|
||||
if (auto fnAttr = attr.dyn_cast<FunctionAttr>()) {
|
||||
if (!fnAttr.getValue())
|
||||
return failure("attribute refers to deallocated function!", ctx);
|
||||
|
||||
if (fnAttr.getValue()->getModule() != fn.getModule())
|
||||
return failure("attribute refers to function '" +
|
||||
Twine(fnAttr.getValue()->getName()) +
|
||||
"' defined in another module!",
|
||||
ctx);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Otherwise, we must have an array attribute, remap the elements.
|
||||
for (auto elt : attr.cast<ArrayAttr>().getValue())
|
||||
if (failed(verifyAttribute(elt, ctx)))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult verify();
|
||||
LogicalResult verifyBlock(Block &block, bool isTopLevel);
|
||||
LogicalResult verifyOperation(Operation &op);
|
||||
|
@ -143,8 +116,6 @@ LogicalResult FuncVerifier::verify() {
|
|||
if (!identifierRegex.match(attr.first))
|
||||
return failure("invalid attribute name '" + attr.first.strref() + "'",
|
||||
fn);
|
||||
if (failed(verifyAttribute(attr.second, fn)))
|
||||
return failure();
|
||||
|
||||
/// Check that the attribute is a dialect attribute, i.e. contains a '.' for
|
||||
/// the namespace.
|
||||
|
@ -165,8 +136,6 @@ LogicalResult FuncVerifier::verify() {
|
|||
llvm::formatv("invalid attribute name '{0}' on argument {1}",
|
||||
attr.first.strref(), i),
|
||||
fn);
|
||||
if (failed(verifyAttribute(attr.second, fn)))
|
||||
return failure();
|
||||
|
||||
/// Check that the attribute is a dialect attribute, i.e. contains a '.'
|
||||
/// for the namespace.
|
||||
|
@ -280,8 +249,6 @@ LogicalResult FuncVerifier::verifyOperation(Operation &op) {
|
|||
if (!identifierRegex.match(attr.first))
|
||||
return failure("invalid attribute name '" + attr.first.strref() + "'",
|
||||
op);
|
||||
if (failed(verifyAttribute(attr.second, op)))
|
||||
return failure();
|
||||
|
||||
// Check for any optional dialect specific attributes.
|
||||
if (!attr.first.strref().contains('.'))
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "mlir/GPU/GPUDialect.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
|
@ -303,7 +304,9 @@ void LaunchFuncOp::build(Builder *builder, OperationState *result,
|
|||
}
|
||||
|
||||
Function *LaunchFuncOp::kernel() {
|
||||
return this->getAttr(getKernelAttrName()).cast<FunctionAttr>().getValue();
|
||||
auto kernelAttr = getAttr(getKernelAttrName()).cast<FunctionAttr>();
|
||||
return getOperation()->getFunction()->getModule()->getNamedFunction(
|
||||
kernelAttr.getValue());
|
||||
}
|
||||
|
||||
unsigned LaunchFuncOp::getNumKernelOperands() {
|
||||
|
@ -321,7 +324,11 @@ LogicalResult LaunchFuncOp::verify() {
|
|||
} else if (!kernelAttr.isa<FunctionAttr>()) {
|
||||
return emitOpError("attribute 'kernel' must be a function");
|
||||
}
|
||||
|
||||
Function *kernelFunc = this->kernel();
|
||||
if (!kernelFunc)
|
||||
return emitError() << "kernel function '" << kernelAttr << "' is undefined";
|
||||
|
||||
if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
|
||||
GPUDialect::getKernelFuncAttrName())) {
|
||||
return emitError("kernel function is missing the '")
|
||||
|
|
|
@ -679,14 +679,7 @@ void ModulePrinter::printAttributeOptionalType(Attribute attr,
|
|||
printType(attr.cast<TypeAttr>().getValue());
|
||||
break;
|
||||
case StandardAttributes::Function: {
|
||||
auto *function = attr.cast<FunctionAttr>().getValue();
|
||||
if (!function) {
|
||||
os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
|
||||
} else {
|
||||
printFunctionReference(function);
|
||||
os << " : ";
|
||||
printType(function->getType());
|
||||
}
|
||||
os << '@' << attr.cast<FunctionAttr>().getValue();
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::OpaqueElements: {
|
||||
|
@ -1317,7 +1310,7 @@ void FunctionPrinter::numberValueID(Value *value) {
|
|||
} else {
|
||||
specialName << 'c' << intCst.getInt() << '_' << type;
|
||||
}
|
||||
} else if (cst.isa<FunctionAttr>()) {
|
||||
} else if (type.isa<FunctionType>()) {
|
||||
specialName << 'f';
|
||||
} else {
|
||||
specialName << "cst";
|
||||
|
|
|
@ -214,8 +214,7 @@ struct StringAttributeStorage : public AttributeStorage {
|
|||
struct ArrayAttributeStorage : public AttributeStorage {
|
||||
using KeyTy = ArrayRef<Attribute>;
|
||||
|
||||
ArrayAttributeStorage(bool hasFunctionAttr, ArrayRef<Attribute> value)
|
||||
: AttributeStorage(hasFunctionAttr), value(value) {}
|
||||
ArrayAttributeStorage(ArrayRef<Attribute> value) : value(value) {}
|
||||
|
||||
/// Key equality function.
|
||||
bool operator==(const KeyTy &key) const { return key == value; }
|
||||
|
@ -223,13 +222,8 @@ struct ArrayAttributeStorage : public AttributeStorage {
|
|||
/// Construct a new storage instance.
|
||||
static ArrayAttributeStorage *construct(AttributeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
// Check to see if any of the elements have a function attr.
|
||||
bool hasFunctionAttr = llvm::any_of(
|
||||
key, [](Attribute elt) { return elt.isOrContainsFunction(); });
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
return new (allocator.allocate<ArrayAttributeStorage>())
|
||||
ArrayAttributeStorage(hasFunctionAttr, allocator.copyInto(key));
|
||||
ArrayAttributeStorage(allocator.copyInto(key));
|
||||
}
|
||||
|
||||
ArrayRef<Attribute> value;
|
||||
|
@ -293,37 +287,6 @@ struct TypeAttributeStorage : public AttributeStorage {
|
|||
Type value;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a function.
|
||||
struct FunctionAttributeStorage : public AttributeStorage {
|
||||
using KeyTy = Function *;
|
||||
|
||||
FunctionAttributeStorage(Function *value)
|
||||
: AttributeStorage(value->getType(), /*isOrContainsFunctionCache=*/true),
|
||||
value(value) {}
|
||||
|
||||
/// Key equality function.
|
||||
bool operator==(const KeyTy &key) const { return key == value; }
|
||||
|
||||
/// Construct a new storage instance.
|
||||
static FunctionAttributeStorage *
|
||||
construct(AttributeStorageAllocator &allocator, KeyTy key) {
|
||||
return new (allocator.allocate<FunctionAttributeStorage>())
|
||||
FunctionAttributeStorage(key);
|
||||
}
|
||||
|
||||
/// Storage cleanup function.
|
||||
void cleanup() {
|
||||
// Null out the function reference in the attribute to avoid dangling
|
||||
// pointers.
|
||||
value = nullptr;
|
||||
}
|
||||
|
||||
/// Reset the type of this attribute to the type of the held function.
|
||||
void resetType() { setType(value->getType()); }
|
||||
|
||||
Function *value;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a vector or tensor constant,
|
||||
/// inwhich all elements have the same value.
|
||||
struct SplatElementsAttributeStorage : public AttributeStorage {
|
||||
|
|
|
@ -32,20 +32,15 @@ using namespace mlir::detail;
|
|||
// AttributeStorage
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
AttributeStorage::AttributeStorage(Type type, bool isOrContainsFunctionCache)
|
||||
: typeAndContainsFunctionAttrPair(type.getAsOpaquePointer(),
|
||||
isOrContainsFunctionCache) {}
|
||||
AttributeStorage::AttributeStorage(bool isOrContainsFunctionCache)
|
||||
: AttributeStorage(/*type=*/nullptr, isOrContainsFunctionCache) {}
|
||||
AttributeStorage::AttributeStorage()
|
||||
: AttributeStorage(/*type=*/nullptr, /*isOrContainsFunctionCache=*/false) {}
|
||||
AttributeStorage::AttributeStorage(Type type)
|
||||
: type(type.getAsOpaquePointer()) {}
|
||||
AttributeStorage::AttributeStorage() : type(nullptr) {}
|
||||
|
||||
Type AttributeStorage::getType() const {
|
||||
return Type::getFromOpaquePointer(
|
||||
typeAndContainsFunctionAttrPair.getPointer());
|
||||
return Type::getFromOpaquePointer(type);
|
||||
}
|
||||
void AttributeStorage::setType(Type type) {
|
||||
typeAndContainsFunctionAttrPair.setPointer(type.getAsOpaquePointer());
|
||||
void AttributeStorage::setType(Type newType) {
|
||||
type = newType.getAsOpaquePointer();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -61,42 +56,6 @@ MLIRContext *Attribute::getContext() const { return getType().getContext(); }
|
|||
/// Get the dialect this attribute is registered to.
|
||||
Dialect &Attribute::getDialect() const { return impl->getDialect(); }
|
||||
|
||||
bool Attribute::isOrContainsFunction() const {
|
||||
return impl->isOrContainsFunctionCache();
|
||||
}
|
||||
|
||||
// Given an attribute that could refer to a function attribute in the remapping
|
||||
// table, walk it and rewrite it to use the mapped function. If it doesn't
|
||||
// refer to anything in the table, then it is returned unmodified.
|
||||
Attribute Attribute::remapFunctionAttrs(
|
||||
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable) const {
|
||||
// Most attributes are trivially unrelated to function attributes, skip them
|
||||
// rapidly.
|
||||
if (!isOrContainsFunction())
|
||||
return *this;
|
||||
|
||||
// If we have a function attribute, remap it.
|
||||
if (auto fnAttr = this->dyn_cast<FunctionAttr>()) {
|
||||
auto it = remappingTable.find(fnAttr);
|
||||
return it != remappingTable.end() ? it->second : *this;
|
||||
}
|
||||
|
||||
// Otherwise, we must have an array attribute, remap the elements.
|
||||
auto arrayAttr = this->cast<ArrayAttr>();
|
||||
SmallVector<Attribute, 8> remappedElts;
|
||||
bool anyChange = false;
|
||||
for (auto elt : arrayAttr.getValue()) {
|
||||
auto newElt = elt.remapFunctionAttrs(remappingTable);
|
||||
remappedElts.push_back(newElt);
|
||||
anyChange |= (elt != newElt);
|
||||
}
|
||||
|
||||
if (!anyChange)
|
||||
return *this;
|
||||
|
||||
return ArrayAttr::get(remappedElts, getContext());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpaqueAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -293,27 +252,14 @@ Type TypeAttr::getValue() const { return getImpl()->value; }
|
|||
|
||||
FunctionAttr FunctionAttr::get(Function *value) {
|
||||
assert(value && "Cannot get FunctionAttr for a null function");
|
||||
return Base::get(value->getContext(), StandardAttributes::Function, value);
|
||||
return get(value->getName(), value->getContext());
|
||||
}
|
||||
|
||||
/// This function is used by the internals of the Function class to null out
|
||||
/// attributes referring to functions that are about to be deleted.
|
||||
void FunctionAttr::dropFunctionReference(Function *value) {
|
||||
AttributeUniquer::erase<FunctionAttr>(value->getContext(),
|
||||
StandardAttributes::Function, value);
|
||||
FunctionAttr FunctionAttr::get(StringRef value, MLIRContext *ctx) {
|
||||
return Base::get(ctx, StandardAttributes::Function, value);
|
||||
}
|
||||
|
||||
/// This function is used by the internals of the Function class to update the
|
||||
/// type of the attribute for 'value'.
|
||||
void FunctionAttr::resetType(Function *value) {
|
||||
FunctionAttr::get(value).getImpl()->resetType();
|
||||
}
|
||||
|
||||
Function *FunctionAttr::getValue() const { return getImpl()->value; }
|
||||
|
||||
FunctionType FunctionAttr::getType() const {
|
||||
return Attribute::getType().cast<FunctionType>();
|
||||
}
|
||||
StringRef FunctionAttr::getValue() const { return getImpl()->value; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ElementsAttr
|
||||
|
|
|
@ -172,6 +172,9 @@ TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
|
|||
FunctionAttr Builder::getFunctionAttr(Function *value) {
|
||||
return FunctionAttr::get(value);
|
||||
}
|
||||
FunctionAttr Builder::getFunctionAttr(StringRef value) {
|
||||
return FunctionAttr::get(value, getContext());
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) {
|
||||
return SplatElementsAttr::get(type, elt);
|
||||
|
|
|
@ -37,11 +37,6 @@ Function::Function(Location location, StringRef name, FunctionType type,
|
|||
: name(Identifier::get(name, type.getContext())), location(location),
|
||||
type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
|
||||
|
||||
Function::~Function() {
|
||||
// Clean up function attributes referring to this function.
|
||||
FunctionAttr::dropFunctionReference(this);
|
||||
}
|
||||
|
||||
MLIRContext *Function::getContext() { return getType().getContext(); }
|
||||
|
||||
/// Swap the name of the given function with this one.
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mlir/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "llvm/AsmParser/Parser.h"
|
||||
|
@ -393,43 +394,22 @@ static void printCallOp(OpAsmPrinter *p, CallOp &op) {
|
|||
// callee (first operand) otherwise.
|
||||
*p << op.getOperationName() << ' ';
|
||||
if (isDirect)
|
||||
*p << '@' << callee.getValue()->getName().strref();
|
||||
*p << '@' << callee.getValue();
|
||||
else
|
||||
*p << *op.getOperand(0);
|
||||
|
||||
*p << '(';
|
||||
p->printOperands(std::next(op.operand_begin(), callee.hasValue() ? 0 : 1),
|
||||
op.operand_end());
|
||||
p->printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1));
|
||||
*p << ')';
|
||||
|
||||
p->printOptionalAttrDict(op.getAttrs(), {"callee"});
|
||||
|
||||
if (isDirect) {
|
||||
*p << " : " << callee.getValue()->getType();
|
||||
return;
|
||||
}
|
||||
|
||||
// Reconstruct the function MLIR function type from LLVM function type,
|
||||
// and print it.
|
||||
auto operandType = op.getOperand(0)->getType().cast<LLVM::LLVMType>();
|
||||
auto *llvmPtrType =
|
||||
dyn_cast<llvm::PointerType>(operandType.getUnderlyingType());
|
||||
assert(llvmPtrType &&
|
||||
"operand #0 must have LLVM pointer type for indirect calls");
|
||||
auto *llvmType = dyn_cast<llvm::FunctionType>(llvmPtrType->getElementType());
|
||||
assert(llvmType &&
|
||||
"operand #0 must have LLVM Function pointer type for indirect calls");
|
||||
|
||||
auto *llvmResultType = llvmType->getReturnType();
|
||||
SmallVector<Type, 1> resultTypes;
|
||||
if (!llvmResultType->isVoidTy())
|
||||
resultTypes.push_back(LLVM::LLVMType::get(op.getContext(), llvmResultType));
|
||||
|
||||
// Reconstruct the function MLIR function type from operand and result types.
|
||||
SmallVector<Type, 1> resultTypes(op.getOperation()->getResultTypes());
|
||||
SmallVector<Type, 8> argTypes;
|
||||
argTypes.reserve(llvmType->getNumParams());
|
||||
for (int i = 0, e = llvmType->getNumParams(); i < e; ++i)
|
||||
argTypes.push_back(
|
||||
LLVM::LLVMType::get(op.getContext(), llvmType->getParamType(i)));
|
||||
argTypes.reserve(op.getNumOperands());
|
||||
for (auto *operand : llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1))
|
||||
argTypes.push_back(operand->getType());
|
||||
|
||||
*p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
|
||||
}
|
||||
|
@ -467,10 +447,7 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
|
|||
return parser->emitError(trailingTypeLoc, "expected function type");
|
||||
if (isDirect) {
|
||||
// Add the direct callee as an Op attribute.
|
||||
Function *func;
|
||||
if (parser->resolveFunctionName(calleeName, funcType, calleeLoc, func))
|
||||
return failure();
|
||||
auto funcAttr = parser->getBuilder().getFunctionAttr(func);
|
||||
auto funcAttr = parser->getBuilder().getFunctionAttr(calleeName);
|
||||
attrs.push_back(parser->getBuilder().getNamedAttr("callee", funcAttr));
|
||||
|
||||
// Make sure types match.
|
||||
|
|
|
@ -57,23 +57,12 @@ public:
|
|||
: context(module->getContext()), module(module), lex(sourceMgr, context),
|
||||
curToken(lex.lexToken()) {}
|
||||
|
||||
~ParserState() {
|
||||
// Destroy the forward references upon error.
|
||||
for (auto forwardRef : functionForwardRefs)
|
||||
delete forwardRef.second;
|
||||
functionForwardRefs.clear();
|
||||
}
|
||||
|
||||
// A map from attribute alias identifier to Attribute.
|
||||
llvm::StringMap<Attribute> attributeAliasDefinitions;
|
||||
|
||||
// A map from type alias identifier to Type.
|
||||
llvm::StringMap<Type> typeAliasDefinitions;
|
||||
|
||||
// This keeps track of all forward references to functions along with the
|
||||
// temporary function used to represent them.
|
||||
llvm::DenseMap<Identifier, Function *> functionForwardRefs;
|
||||
|
||||
private:
|
||||
ParserState(const ParserState &) = delete;
|
||||
void operator=(const ParserState &) = delete;
|
||||
|
@ -190,8 +179,6 @@ public:
|
|||
ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
|
||||
|
||||
// Attribute parsing.
|
||||
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||
FunctionType type);
|
||||
Attribute parseExtendedAttribute(Type type);
|
||||
Attribute parseAttribute(Type type = {});
|
||||
|
||||
|
@ -998,32 +985,6 @@ TensorLiteralParser::parseList(llvm::SmallVectorImpl<int64_t> &dims) {
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Given a parsed reference to a function name like @foo and a type that it
|
||||
/// corresponds to, resolve it to a concrete function object (possibly
|
||||
/// synthesizing a forward reference) or emit an error and return null on
|
||||
/// failure.
|
||||
Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||
FunctionType type) {
|
||||
Identifier name = builder.getIdentifier(nameStr.drop_front());
|
||||
|
||||
// See if the function has already been defined in the module.
|
||||
Function *function = getModule()->getNamedFunction(name);
|
||||
|
||||
// If not, get or create a forward reference to one.
|
||||
if (!function) {
|
||||
auto &entry = state.functionForwardRefs[name];
|
||||
if (!entry)
|
||||
entry = new Function(getEncodedSourceLocation(nameLoc), name, type,
|
||||
/*attrs=*/{});
|
||||
function = entry;
|
||||
}
|
||||
|
||||
if (function->getType() != type)
|
||||
return (emitError(nameLoc, "reference to function with mismatched type"),
|
||||
nullptr);
|
||||
return function;
|
||||
}
|
||||
|
||||
/// Parse an extended attribute.
|
||||
///
|
||||
/// extended-attribute ::= (dialect-attribute | attribute-alias)
|
||||
|
@ -1218,22 +1179,9 @@ Attribute Parser::parseAttribute(Type type) {
|
|||
}
|
||||
|
||||
case Token::at_identifier: {
|
||||
auto nameLoc = getToken().getLoc();
|
||||
auto nameStr = getTokenSpelling();
|
||||
consumeToken(Token::at_identifier);
|
||||
|
||||
if (parseToken(Token::colon, "expected ':' and function type"))
|
||||
return nullptr;
|
||||
auto typeLoc = getToken().getLoc();
|
||||
Type type = parseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
auto fnType = type.dyn_cast<FunctionType>();
|
||||
if (!fnType)
|
||||
return (emitError(typeLoc, "expected function type"), nullptr);
|
||||
|
||||
auto *function = resolveFunctionReference(nameStr, nameLoc, fnType);
|
||||
return function ? builder.getFunctionAttr(function) : nullptr;
|
||||
return builder.getFunctionAttr(nameStr.drop_front());
|
||||
}
|
||||
case Token::kw_opaque: {
|
||||
consumeToken(Token::kw_opaque);
|
||||
|
@ -3224,7 +3172,7 @@ public:
|
|||
if (parser.getToken().isNot(Token::at_identifier))
|
||||
return failure();
|
||||
|
||||
result = parser.getTokenSpelling();
|
||||
result = parser.getTokenSpelling().drop_front();
|
||||
parser.consumeToken(Token::at_identifier);
|
||||
return success();
|
||||
}
|
||||
|
@ -3325,13 +3273,6 @@ public:
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Resolve a parse function name and a type into a function reference.
|
||||
ParseResult resolveFunctionName(StringRef name, FunctionType type,
|
||||
llvm::SMLoc loc, Function *&result) override {
|
||||
result = parser.resolveFunctionReference(name, loc, type);
|
||||
return failure(result == nullptr);
|
||||
}
|
||||
|
||||
/// Parse a region that takes `arguments` of `argTypes` types. This
|
||||
/// effectively defines the SSA values of `arguments` and assignes their type.
|
||||
ParseResult parseRegion(Region ®ion, ArrayRef<OperandType> arguments,
|
||||
|
@ -3549,8 +3490,6 @@ public:
|
|||
ParseResult parseModule();
|
||||
|
||||
private:
|
||||
ParseResult finalizeModule();
|
||||
|
||||
ParseResult parseAttributeAliasDef();
|
||||
|
||||
ParseResult parseTypeAliasDef();
|
||||
|
@ -3782,42 +3721,6 @@ ParseResult ModuleParser::parseFunc() {
|
|||
return parser.parseFunctionBody(hadNamedArguments);
|
||||
}
|
||||
|
||||
/// Finish the end of module parsing - when the result is valid, do final
|
||||
/// checking.
|
||||
ParseResult ModuleParser::finalizeModule() {
|
||||
// Resolve all forward references, building a remapping table of attributes.
|
||||
DenseMap<Attribute, FunctionAttr> remappingTable;
|
||||
for (auto forwardRef : getState().functionForwardRefs) {
|
||||
auto name = forwardRef.first;
|
||||
|
||||
// Resolve the reference.
|
||||
auto *resolvedFunction = getModule()->getNamedFunction(name);
|
||||
if (!resolvedFunction) {
|
||||
forwardRef.second->emitError("reference to undefined function '")
|
||||
<< name << "'";
|
||||
return failure();
|
||||
}
|
||||
|
||||
remappingTable[builder.getFunctionAttr(forwardRef.second)] =
|
||||
builder.getFunctionAttr(resolvedFunction);
|
||||
}
|
||||
|
||||
// If there was nothing to remap, then we're done.
|
||||
if (remappingTable.empty())
|
||||
return success();
|
||||
|
||||
// Otherwise, walk the entire module replacing uses of one attribute set
|
||||
// with the correct ones.
|
||||
remapFunctionAttrs(*getModule(), remappingTable);
|
||||
|
||||
// Now that all references to the forward definition placeholders are
|
||||
// resolved, we can deallocate the placeholders.
|
||||
for (auto forwardRef : getState().functionForwardRefs)
|
||||
delete forwardRef.second;
|
||||
getState().functionForwardRefs.clear();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// This is the top-level module parser.
|
||||
ParseResult ModuleParser::parseModule() {
|
||||
while (1) {
|
||||
|
@ -3828,7 +3731,7 @@ ParseResult ModuleParser::parseModule() {
|
|||
|
||||
// If we got to the end of the file, then we're done.
|
||||
case Token::eof:
|
||||
return finalizeModule();
|
||||
return success();
|
||||
|
||||
// If we got an error token, then the lexer already emitted an error, just
|
||||
// stop. Someday we could introduce error recovery if there was demand
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
@ -419,19 +420,18 @@ static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
|
|||
llvm::SMLoc calleeLoc;
|
||||
FunctionType calleeType;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||
Function *callee = nullptr;
|
||||
if (parser->parseFunctionName(calleeName, calleeLoc) ||
|
||||
parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(calleeType) ||
|
||||
parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
|
||||
parser->addTypesToList(calleeType.getResults(), result->types) ||
|
||||
parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
|
||||
result->operands))
|
||||
return failure();
|
||||
|
||||
result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
|
||||
result->addAttribute("callee",
|
||||
parser->getBuilder().getFunctionAttr(calleeName));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -450,9 +450,14 @@ static LogicalResult verify(CallOp op) {
|
|||
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
|
||||
if (!fnAttr)
|
||||
return op.emitOpError("requires a 'callee' function attribute");
|
||||
auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
|
||||
fnAttr.getValue());
|
||||
if (!fn)
|
||||
return op.emitOpError() << "'" << fnAttr.getValue()
|
||||
<< "' does not reference a valid function";
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
auto fnType = fnAttr.getValue()->getType();
|
||||
auto fnType = fn->getType();
|
||||
if (fnType.getNumInputs() != op.getNumOperands())
|
||||
return op.emitOpError("incorrect number of operands for callee");
|
||||
|
||||
|
@ -470,6 +475,11 @@ static LogicalResult verify(CallOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
Function *CallOp::getCallee() {
|
||||
auto name = getAttrOfType<FunctionAttr>("callee").getValue();
|
||||
return getOperation()->getFunction()->getModule()->getNamedFunction(name);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CallIndirectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -494,8 +504,10 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
|
|||
return matchFailure();
|
||||
|
||||
// Replace with a direct call.
|
||||
SmallVector<Type, 8> callResults(op->getResultTypes());
|
||||
SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callOperands);
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callResults,
|
||||
callOperands);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
@ -1108,19 +1120,30 @@ static void print(OpAsmPrinter *p, ConstantOp &op) {
|
|||
if (op.getAttrs().size() > 1)
|
||||
*p << ' ';
|
||||
p->printAttributeAndType(op.getValue());
|
||||
|
||||
// If the value is a function, print a trailing type.
|
||||
if (op.getValue().isa<FunctionAttr>()) {
|
||||
*p << " : ";
|
||||
p->printType(op.getType());
|
||||
}
|
||||
}
|
||||
|
||||
static ParseResult parseConstantOp(OpAsmParser *parser,
|
||||
OperationState *result) {
|
||||
Attribute valueAttr;
|
||||
Type type;
|
||||
|
||||
if (parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseAttribute(valueAttr, "value", result->attributes))
|
||||
return failure();
|
||||
|
||||
// If the attribute is a function, then we expect a trailing type.
|
||||
Type type;
|
||||
if (!valueAttr.isa<FunctionAttr>())
|
||||
type = valueAttr.getType();
|
||||
else if (parser->parseColonType(type))
|
||||
return failure();
|
||||
|
||||
// Add the attribute type to the list.
|
||||
return parser->addTypeToList(valueAttr.getType(), result->types);
|
||||
return parser->addTypeToList(type, result->types);
|
||||
}
|
||||
|
||||
/// The constant op requires an attribute, and furthermore requires that it
|
||||
|
@ -1131,7 +1154,7 @@ static LogicalResult verify(ConstantOp &op) {
|
|||
return op.emitOpError("requires a 'value' attribute");
|
||||
|
||||
auto type = op.getType();
|
||||
if (type != value.getType())
|
||||
if (!value.getType().isa<NoneType>() && type != value.getType())
|
||||
return op.emitOpError() << "requires attribute's type (" << value.getType()
|
||||
<< ") to match op's return type (" << type << ")";
|
||||
|
||||
|
@ -1162,8 +1185,20 @@ static LogicalResult verify(ConstantOp &op) {
|
|||
}
|
||||
|
||||
if (type.isa<FunctionType>()) {
|
||||
if (!value.isa<FunctionAttr>())
|
||||
auto fnAttr = value.dyn_cast<FunctionAttr>();
|
||||
if (!fnAttr)
|
||||
return op.emitOpError("requires 'value' to be a function reference");
|
||||
|
||||
// Try to find the referenced function.
|
||||
auto *fn = op.getOperation()->getFunction()->getModule()->getNamedFunction(
|
||||
fnAttr.getValue());
|
||||
if (!fn)
|
||||
return op.emitOpError("reference to undefined function 'bar'");
|
||||
|
||||
// Check that the referenced function has the correct type.
|
||||
if (fn->getType() != type)
|
||||
return op.emitOpError("reference to function with mismatched type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -326,7 +326,7 @@ bool ModuleTranslation::convertOneFunction(Function &func) {
|
|||
// function.
|
||||
blockMapping.clear();
|
||||
valueMapping.clear();
|
||||
llvm::Function *llvmFunc = functionMapping.lookup(&func);
|
||||
llvm::Function *llvmFunc = functionMapping.lookup(func.getName());
|
||||
// Add function arguments to the value remapping table.
|
||||
// If there was noalias info then we decorate each argument accordingly.
|
||||
unsigned int argIdx = 0;
|
||||
|
@ -378,7 +378,6 @@ bool ModuleTranslation::convertFunctions() {
|
|||
// Declare all functions first because there may be function calls that form a
|
||||
// call graph with cycles.
|
||||
for (Function &function : mlirModule) {
|
||||
Function *functionPtr = &function;
|
||||
mlir::BoolAttr isVarArgsAttr =
|
||||
function.getAttrOfType<BoolAttr>("std.varargs");
|
||||
bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
|
||||
|
@ -390,7 +389,7 @@ bool ModuleTranslation::convertFunctions() {
|
|||
llvm::FunctionCallee llvmFuncCst =
|
||||
llvmModule->getOrInsertFunction(function.getName(), functionType);
|
||||
assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
|
||||
functionMapping[functionPtr] =
|
||||
functionMapping[function.getName()] =
|
||||
cast<llvm::Function>(llvmFuncCst.getCallee());
|
||||
}
|
||||
|
||||
|
|
|
@ -284,45 +284,3 @@ void mlir::createAffineComputationSlice(
|
|||
opInst->setOperand(idx, newOperands[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
remapFunctionAttrs(NamedAttributeList &attrs,
|
||||
const DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||
for (auto attr : attrs.getAttrs()) {
|
||||
// Do the remapping, if we got the same thing back, then it must contain
|
||||
// functions that aren't getting remapped.
|
||||
auto newVal = attr.second.remapFunctionAttrs(remappingTable);
|
||||
if (newVal == attr.second)
|
||||
continue;
|
||||
|
||||
// Otherwise, replace the existing attribute with the new one. It is safe
|
||||
// to mutate the attribute list while we walk it because underlying
|
||||
// attribute lists are uniqued and immortal.
|
||||
attrs.set(attr.first, newVal);
|
||||
}
|
||||
}
|
||||
|
||||
void mlir::remapFunctionAttrs(
|
||||
Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||
::remapFunctionAttrs(op.getAttrList(), remappingTable);
|
||||
}
|
||||
|
||||
void mlir::remapFunctionAttrs(
|
||||
Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||
|
||||
// Remap the attributes of the function.
|
||||
::remapFunctionAttrs(fn.getAttrList(), remappingTable);
|
||||
|
||||
// Remap the attributes of the arguments of this function.
|
||||
for (auto &attrList : fn.getAllArgAttrs())
|
||||
::remapFunctionAttrs(attrList, remappingTable);
|
||||
|
||||
// Look at all operations in a Function.
|
||||
fn.walk([&](Operation *op) { remapFunctionAttrs(*op, remappingTable); });
|
||||
}
|
||||
|
||||
void mlir::remapFunctionAttrs(
|
||||
Module &module, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||
for (auto &fn : module)
|
||||
remapFunctionAttrs(fn, remappingTable);
|
||||
}
|
||||
|
|
|
@ -94,7 +94,16 @@ func @launch_func_missing_callee_attribute(%sz : index) {
|
|||
|
||||
func @launch_func_no_function_attribute(%sz : index) {
|
||||
// expected-error@+1 {{attribute 'kernel' must be a function}}
|
||||
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz) {kernel: "bar"}
|
||||
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz) {kernel: 10}
|
||||
: (index, index, index, index, index, index) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @launch_func_undefined_function(%sz : index) {
|
||||
// expected-error@+1 {{kernel function '@kernel_1' is undefined}}
|
||||
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz) { kernel: @kernel_1 }
|
||||
: (index, index, index, index, index, index) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -107,8 +116,7 @@ func @kernel_1(%arg1 : !llvm<"float*">) {
|
|||
|
||||
func @launch_func_missing_kernel_attr(%sz : index, %arg : !llvm<"float*">) {
|
||||
// expected-error@+1 {{kernel function is missing the 'gpu.kernel' attribute}}
|
||||
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
|
||||
{kernel: @kernel_1 : (!llvm<"float*">) -> ()}
|
||||
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg) {kernel: @kernel_1}
|
||||
: (index, index, index, index, index, index, !llvm<"float*">) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -122,7 +130,7 @@ func @kernel_1(%arg1 : !llvm<"float*">) attributes { gpu.kernel } {
|
|||
func @launch_func_kernel_operand_size(%sz : index, %arg : !llvm<"float*">) {
|
||||
// expected-error@+1 {{got 2 kernel operands but expected 1}}
|
||||
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg, %arg)
|
||||
{kernel: @kernel_1 : (!llvm<"float*">) -> ()}
|
||||
{kernel: @kernel_1}
|
||||
: (index, index, index, index, index, index, !llvm<"float*">,
|
||||
!llvm<"float*">) -> ()
|
||||
return
|
||||
|
@ -137,7 +145,7 @@ func @kernel_1(%arg1 : !llvm<"float*">) attributes { gpu.kernel } {
|
|||
func @launch_func_kernel_operand_types(%sz : index, %arg : f32) {
|
||||
// expected-error@+1 {{type of function argument 0 does not match}}
|
||||
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz, %arg)
|
||||
{kernel: @kernel_1 : (!llvm<"float*">) -> ()}
|
||||
{kernel: @kernel_1}
|
||||
: (index, index, index, index, index, index, f32) -> ()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -82,9 +82,8 @@ func @foo() {
|
|||
// CHECK: %c8 = constant 8
|
||||
%cst = constant 8 : index
|
||||
|
||||
// CHECK: "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %0, %1) {kernel: @kernel_1 : (f32, memref<?xf32, 1>) -> ()} : (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
|
||||
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1)
|
||||
{kernel: @kernel_1 : (f32, memref<?xf32, 1>) -> ()}
|
||||
// CHECK: "gpu.launch_func"(%c8, %c8, %c8, %c8, %c8, %c8, %0, %1) {kernel: @kernel_1} : (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
|
||||
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel: @kernel_1 }
|
||||
: (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -319,7 +319,7 @@ func @calls(%arg0: i32) {
|
|||
// CHECK: %1 = call @return_op(%0) : (i32) -> i32
|
||||
%y = call @return_op(%x) : (i32) -> i32
|
||||
// CHECK: %2 = call @return_op(%0) : (i32) -> i32
|
||||
%z = "std.call"(%x) {callee: @return_op : (i32) -> i32} : (i32) -> i32
|
||||
%z = "std.call"(%x) {callee: @return_op} : (i32) -> i32
|
||||
|
||||
// CHECK: %f = constant @affine_apply : () -> ()
|
||||
%f = constant @affine_apply : () -> ()
|
||||
|
|
|
@ -26,7 +26,7 @@ func @dim3(tensor<1xf32>) {
|
|||
|
||||
func @constant() {
|
||||
^bb:
|
||||
%x = "std.constant"(){value: "xyz"} : () -> i32 // expected-error {{requires attribute's type (none) to match op's return type (i32)}}
|
||||
%x = "std.constant"(){value: "xyz"} : () -> i32 // expected-error {{requires a result type that aligns with the 'value' attribute}}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -149,7 +149,7 @@ func @intlimit2() {
|
|||
// -----
|
||||
|
||||
func @calls(%arg0: i32) {
|
||||
%x = call @calls() : () -> i32 // expected-error {{reference to function with mismatched type}}
|
||||
%x = call @calls() : () -> i32 // expected-error {{incorrect number of operands for callee}}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -492,14 +492,6 @@ func @invalid_result_type() -> () -> () // expected-error {{expected a top leve
|
|||
|
||||
// -----
|
||||
|
||||
func @func() -> (() -> ())
|
||||
func @referer() {
|
||||
%f = constant @func : () -> () -> () // expected-error {{reference to function with mismatched type}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#map1 = (i)[j] -> (i+j)
|
||||
|
||||
func @bound_symbol_mismatch(%N : index) {
|
||||
|
|
|
@ -372,8 +372,8 @@ func @attributes() {
|
|||
// CHECK: "foo"() {d: 1.000000e-09 : f64, func: [], i123: 7 : i64, if: "foo"} : () -> ()
|
||||
"foo"() {if: "foo", func: [], i123: 7, d: 1.e-9} : () -> ()
|
||||
|
||||
// CHECK: "foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> ()
|
||||
"foo"() {fn: @attributes : () -> (), if: @ifinst : (index) -> ()} : () -> ()
|
||||
// CHECK: "foo"() {fn: @attributes, if: @ifinst} : () -> ()
|
||||
"foo"() {fn: @attributes, if: @ifinst} : () -> ()
|
||||
|
||||
// CHECK: "foo"() {int: 0 : i42} : () -> ()
|
||||
"foo"() {int: 0 : i42} : () -> ()
|
||||
|
@ -923,15 +923,3 @@ func @none_type() {
|
|||
%none_val = "foo.unknown_op"() : () -> none
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fn_attr_remap
|
||||
// CHECK: {some_dialect.arg_attr: @fn_attr_ref : () -> ()}
|
||||
func @fn_attr_remap(%arg0: i1 {some_dialect.arg_attr: @fn_attr_ref : () -> ()}) -> ()
|
||||
// CHECK-NEXT: {some_dialect.fn_attr: @fn_attr_ref : () -> ()}
|
||||
attributes {some_dialect.fn_attr: @fn_attr_ref : () -> ()} {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fn_attr_ref
|
||||
func @fn_attr_ref() -> ()
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ func @body(i32)
|
|||
|
||||
// CHECK-LABEL: func @indirect_const_call(%arg0: !llvm.i32) {
|
||||
func @indirect_const_call(%arg0: i32) {
|
||||
// CHECK-NEXT: %0 = llvm.constant(@body : (!llvm.i32) -> ()) : !llvm<"void (i32)*">
|
||||
// CHECK-NEXT: %0 = llvm.constant(@body) : !llvm<"void (i32)*">
|
||||
%0 = constant @body : (i32) -> ()
|
||||
// CHECK-NEXT: llvm.call %0(%arg0) : (!llvm.i32) -> ()
|
||||
call_indirect %0(%arg0) : (i32) -> ()
|
||||
|
|
|
@ -52,12 +52,12 @@ func @ops(%arg0 : !llvm.i32, %arg1 : !llvm.float) {
|
|||
// CHECK-NEXT: %17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
|
||||
// CHECK-NEXT: %18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }">
|
||||
// CHECK-NEXT: %19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }">
|
||||
// CHECK-NEXT: %20 = llvm.constant(@foo : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">) : !llvm<"{ i32, double, i32 } (i32)*">
|
||||
// CHECK-NEXT: %20 = llvm.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
|
||||
// CHECK-NEXT: %21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
|
||||
%17 = llvm.call @foo(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
|
||||
%18 = llvm.extractvalue %17[0] : !llvm<"{ i32, double, i32 }">
|
||||
%19 = llvm.insertvalue %18, %17[2] : !llvm<"{ i32, double, i32 }">
|
||||
%20 = llvm.constant(@foo : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">) : !llvm<"{ i32, double, i32 } (i32)*">
|
||||
%20 = llvm.constant(@foo) : !llvm<"{ i32, double, i32 } (i32)*">
|
||||
%21 = llvm.call %20(%arg0) : (!llvm.i32) -> !llvm<"{ i32, double, i32 }">
|
||||
|
||||
|
||||
|
|
|
@ -766,7 +766,7 @@ func @ops(%arg0: !llvm.float, %arg1: !llvm.float, %arg2: !llvm.i32, %arg3: !llvm
|
|||
// CHECK-LABEL: define void @indirect_const_call(i64) {
|
||||
func @indirect_const_call(%arg0: !llvm.i64) {
|
||||
// CHECK-NEXT: call void @body(i64 %0)
|
||||
%0 = llvm.constant(@body : (!llvm.i64) -> ()) : !llvm<"void (i64)*">
|
||||
%0 = llvm.constant(@body) : !llvm<"void (i64)*">
|
||||
llvm.call %0(%arg0) : (!llvm.i64) -> ()
|
||||
// CHECK-NEXT: ret void
|
||||
llvm.return
|
||||
|
|
|
@ -104,7 +104,7 @@ def BOp : NS_Op<"b_op", []> {
|
|||
// CHECK: APFloat BOp::f64_attr()
|
||||
// CHECK: StringRef BOp::str_attr()
|
||||
// CHECK: ElementsAttr BOp::elements_attr()
|
||||
// CHECK: Function *BOp::function_attr()
|
||||
// CHECK: StringRef BOp::function_attr()
|
||||
// CHECK: SomeType BOp::type_attr()
|
||||
// CHECK: ArrayAttr BOp::array_attr()
|
||||
// CHECK: ArrayAttr BOp::some_attr_array()
|
||||
|
|
Loading…
Reference in New Issue