Convert the dialect type parse/print hooks into virtual functions on the Dialect class.
PiperOrigin-RevId: 235589945
This commit is contained in:
parent
f1f86eac60
commit
b4f033f6c6
|
@ -35,9 +35,6 @@ using DialectConstantFoldHook = std::function<bool(
|
||||||
const Instruction *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
|
const Instruction *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
|
||||||
using DialectExtractElementHook =
|
using DialectExtractElementHook =
|
||||||
std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
|
std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
|
||||||
using DialectTypeParserHook =
|
|
||||||
std::function<Type(StringRef, Location, MLIRContext *)>;
|
|
||||||
using DialectTypePrinterHook = std::function<void(Type, raw_ostream &)>;
|
|
||||||
|
|
||||||
/// Dialects are groups of MLIR operations and behavior associated with the
|
/// Dialects are groups of MLIR operations and behavior associated with the
|
||||||
/// entire group. For example, hooks into other systems for constant folding,
|
/// entire group. For example, hooks into other systems for constant folding,
|
||||||
|
@ -80,11 +77,16 @@ public:
|
||||||
return Attribute();
|
return Attribute();
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Registered parsing/printing hooks for types registered to the dialect.
|
/// Parse a type registered to this dialect.
|
||||||
DialectTypeParserHook typeParseHook = nullptr;
|
virtual Type parseType(StringRef tyData, Location loc,
|
||||||
|
MLIRContext *context) const;
|
||||||
|
|
||||||
|
/// Print a type registered to this dialect.
|
||||||
/// Note: The data printed for the provided type must not include any '"'
|
/// Note: The data printed for the provided type must not include any '"'
|
||||||
/// characters.
|
/// characters.
|
||||||
DialectTypePrinterHook typePrintHook = nullptr;
|
virtual void printType(Type, raw_ostream &) const {
|
||||||
|
assert(0 && "dialect has no registered type printing hook");
|
||||||
|
}
|
||||||
|
|
||||||
/// Registered hooks for getting identifier aliases for symbols. The
|
/// Registered hooks for getting identifier aliases for symbols. The
|
||||||
/// identifier is used in place of the symbol when printing textual IR.
|
/// identifier is used in place of the symbol when printing textual IR.
|
||||||
|
|
|
@ -76,6 +76,13 @@ public:
|
||||||
llvm::LLVMContext &getLLVMContext() { return llvmContext; }
|
llvm::LLVMContext &getLLVMContext() { return llvmContext; }
|
||||||
llvm::Module &getLLVMModule() { return module; }
|
llvm::Module &getLLVMModule() { return module; }
|
||||||
|
|
||||||
|
/// Parse a type registered to this dialect.
|
||||||
|
Type parseType(StringRef tyData, Location loc,
|
||||||
|
MLIRContext *context) const override;
|
||||||
|
|
||||||
|
/// Print a type registered to this dialect.
|
||||||
|
void printType(Type type, raw_ostream &os) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::LLVMContext llvmContext;
|
llvm::LLVMContext llvmContext;
|
||||||
llvm::Module module;
|
llvm::Module module;
|
||||||
|
|
|
@ -715,8 +715,7 @@ void ModulePrinter::printType(Type type) {
|
||||||
default: {
|
default: {
|
||||||
auto &dialect = type.getDialect();
|
auto &dialect = type.getDialect();
|
||||||
os << '!' << dialect.getNamespace() << "<\"";
|
os << '!' << dialect.getNamespace() << "<\"";
|
||||||
assert(dialect.typePrintHook && "Expected dialect type printing hook.");
|
dialect.printType(type, os);
|
||||||
dialect.typePrintHook(type, os);
|
|
||||||
os << "\">";
|
os << "\">";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/DialectHooks.h"
|
#include "mlir/IR/DialectHooks.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
#include "llvm/ADT/Twine.h"
|
||||||
#include "llvm/Support/ManagedStatic.h"
|
#include "llvm/Support/ManagedStatic.h"
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
@ -65,3 +66,11 @@ Dialect::Dialect(StringRef namePrefix, MLIRContext *context)
|
||||||
}
|
}
|
||||||
|
|
||||||
Dialect::~Dialect() {}
|
Dialect::~Dialect() {}
|
||||||
|
|
||||||
|
/// Parse a type registered to this dialect.
|
||||||
|
Type Dialect::parseType(StringRef tyData, Location loc,
|
||||||
|
MLIRContext *context) const {
|
||||||
|
context->emitError(loc, "dialect '" + getNamespace() +
|
||||||
|
"' provides no type parsing hook");
|
||||||
|
return Type();
|
||||||
|
}
|
||||||
|
|
|
@ -57,27 +57,6 @@ LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
|
||||||
return Base::get(context, FIRST_LLVM_TYPE, llvmType);
|
return Base::get(context, FIRST_LLVM_TYPE, llvmType);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Type parseLLVMType(StringRef data, Location loc, MLIRContext *ctx) {
|
|
||||||
llvm::SMDiagnostic errorMessage;
|
|
||||||
auto *llvmDialect =
|
|
||||||
static_cast<LLVMDialect *>(ctx->getRegisteredDialect("llvm"));
|
|
||||||
assert(llvmDialect && "LLVM dialect not registered");
|
|
||||||
llvm::Type *type =
|
|
||||||
llvm::parseType(data, errorMessage, llvmDialect->getLLVMModule());
|
|
||||||
if (!type) {
|
|
||||||
ctx->emitError(loc, errorMessage.getMessage());
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
return LLVMType::get(ctx, type);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void printLLVMType(Type ty, raw_ostream &os) {
|
|
||||||
auto type = ty.dyn_cast<LLVMType>();
|
|
||||||
assert(type && "printing wrong type");
|
|
||||||
assert(type.getUnderlyingType() && "no underlying LLVM type");
|
|
||||||
type.getUnderlyingType()->print(os);
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::Type *LLVMType::getUnderlyingType() const {
|
llvm::Type *LLVMType::getUnderlyingType() const {
|
||||||
return static_cast<ImplType *>(type)->underlyingType;
|
return static_cast<ImplType *>(type)->underlyingType;
|
||||||
}
|
}
|
||||||
|
@ -91,9 +70,24 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
|
||||||
addOperations<
|
addOperations<
|
||||||
#include "mlir/LLVMIR/llvm_ops.inc"
|
#include "mlir/LLVMIR/llvm_ops.inc"
|
||||||
>();
|
>();
|
||||||
|
}
|
||||||
|
|
||||||
typeParseHook = parseLLVMType;
|
/// Parse a type registered to this dialect.
|
||||||
typePrintHook = printLLVMType;
|
Type LLVMDialect::parseType(StringRef tyData, Location loc,
|
||||||
|
MLIRContext *context) const {
|
||||||
|
llvm::SMDiagnostic errorMessage;
|
||||||
|
llvm::Type *type = llvm::parseType(tyData, errorMessage, module);
|
||||||
|
if (!type)
|
||||||
|
return (context->emitError(loc, errorMessage.getMessage()), nullptr);
|
||||||
|
return LLVMType::get(context, type);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Print a type registered to this dialect.
|
||||||
|
void LLVMDialect::printType(Type type, raw_ostream &os) const {
|
||||||
|
auto llvmType = type.dyn_cast<LLVMType>();
|
||||||
|
assert(llvmType && "printing wrong type");
|
||||||
|
assert(llvmType.getUnderlyingType() && "no underlying LLVM type");
|
||||||
|
llvmType.getUnderlyingType()->print(os);
|
||||||
}
|
}
|
||||||
|
|
||||||
static DialectRegistration<LLVMDialect> llvmDialect;
|
static DialectRegistration<LLVMDialect> llvmDialect;
|
||||||
|
|
|
@ -496,15 +496,7 @@ Type Parser::parseExtendedType() {
|
||||||
return aliasIt->second;
|
return aliasIt->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, check for a registered dialect with this name.
|
// Otherwise, we are parsing a dialect-specific type.
|
||||||
auto *dialect = state.context->getRegisteredDialect(identifier);
|
|
||||||
if (dialect) {
|
|
||||||
// Make sure that the dialect provides a parsing hook.
|
|
||||||
if (!dialect->typeParseHook)
|
|
||||||
return (emitError("dialect '" + dialect->getNamespace() +
|
|
||||||
"' provides no type parsing hook"),
|
|
||||||
nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Consume the '<'.
|
// Consume the '<'.
|
||||||
if (parseToken(Token::less, "expected '<' in dialect type"))
|
if (parseToken(Token::less, "expected '<' in dialect type"))
|
||||||
|
@ -522,8 +514,8 @@ Type Parser::parseExtendedType() {
|
||||||
Type result;
|
Type result;
|
||||||
|
|
||||||
// If we found a registered dialect, then ask it to parse the type.
|
// If we found a registered dialect, then ask it to parse the type.
|
||||||
if (dialect) {
|
if (auto *dialect = state.context->getRegisteredDialect(identifier)) {
|
||||||
result = dialect->typeParseHook(typeData, loc, state.context);
|
result = dialect->parseType(typeData, loc, state.context);
|
||||||
if (!result)
|
if (!result)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Reference in New Issue