Convert the dialect type parse/print hooks into virtual functions on the Dialect class.

PiperOrigin-RevId: 235589945
This commit is contained in:
River Riddle 2019-02-25 13:16:24 -08:00 committed by jpienaar
parent f1f86eac60
commit b4f033f6c6
6 changed files with 45 additions and 42 deletions

View File

@ -35,9 +35,6 @@ using DialectConstantFoldHook = std::function<bool(
const Instruction *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
using DialectExtractElementHook =
std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
using DialectTypeParserHook =
std::function<Type(StringRef, Location, MLIRContext *)>;
using DialectTypePrinterHook = std::function<void(Type, raw_ostream &)>;
/// Dialects are groups of MLIR operations and behavior associated with the
/// entire group. For example, hooks into other systems for constant folding,
@ -80,11 +77,16 @@ public:
return Attribute();
};
/// Registered parsing/printing hooks for types registered to the dialect.
DialectTypeParserHook typeParseHook = nullptr;
/// Parse a type registered to this dialect.
virtual Type parseType(StringRef tyData, Location loc,
MLIRContext *context) const;
/// Print a type registered to this dialect.
/// Note: The data printed for the provided type must not include any '"'
/// characters.
DialectTypePrinterHook typePrintHook = nullptr;
virtual void printType(Type, raw_ostream &) const {
assert(0 && "dialect has no registered type printing hook");
}
/// Registered hooks for getting identifier aliases for symbols. The
/// identifier is used in place of the symbol when printing textual IR.

View File

@ -76,6 +76,13 @@ public:
llvm::LLVMContext &getLLVMContext() { return llvmContext; }
llvm::Module &getLLVMModule() { return module; }
/// Parse a type registered to this dialect.
Type parseType(StringRef tyData, Location loc,
MLIRContext *context) const override;
/// Print a type registered to this dialect.
void printType(Type type, raw_ostream &os) const override;
private:
llvm::LLVMContext llvmContext;
llvm::Module module;

View File

@ -715,8 +715,7 @@ void ModulePrinter::printType(Type type) {
default: {
auto &dialect = type.getDialect();
os << '!' << dialect.getNamespace() << "<\"";
assert(dialect.typePrintHook && "Expected dialect type printing hook.");
dialect.typePrintHook(type, os);
dialect.printType(type, os);
os << "\">";
return;
}

View File

@ -18,6 +18,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectHooks.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ManagedStatic.h"
using namespace mlir;
@ -65,3 +66,11 @@ Dialect::Dialect(StringRef namePrefix, MLIRContext *context)
}
Dialect::~Dialect() {}
/// Parse a type registered to this dialect.
Type Dialect::parseType(StringRef tyData, Location loc,
MLIRContext *context) const {
context->emitError(loc, "dialect '" + getNamespace() +
"' provides no type parsing hook");
return Type();
}

View File

@ -57,27 +57,6 @@ LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
return Base::get(context, FIRST_LLVM_TYPE, llvmType);
}
static Type parseLLVMType(StringRef data, Location loc, MLIRContext *ctx) {
llvm::SMDiagnostic errorMessage;
auto *llvmDialect =
static_cast<LLVMDialect *>(ctx->getRegisteredDialect("llvm"));
assert(llvmDialect && "LLVM dialect not registered");
llvm::Type *type =
llvm::parseType(data, errorMessage, llvmDialect->getLLVMModule());
if (!type) {
ctx->emitError(loc, errorMessage.getMessage());
return {};
}
return LLVMType::get(ctx, type);
}
static void printLLVMType(Type ty, raw_ostream &os) {
auto type = ty.dyn_cast<LLVMType>();
assert(type && "printing wrong type");
assert(type.getUnderlyingType() && "no underlying LLVM type");
type.getUnderlyingType()->print(os);
}
llvm::Type *LLVMType::getUnderlyingType() const {
return static_cast<ImplType *>(type)->underlyingType;
}
@ -91,9 +70,24 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
addOperations<
#include "mlir/LLVMIR/llvm_ops.inc"
>();
}
typeParseHook = parseLLVMType;
typePrintHook = printLLVMType;
/// Parse a type registered to this dialect.
Type LLVMDialect::parseType(StringRef tyData, Location loc,
MLIRContext *context) const {
llvm::SMDiagnostic errorMessage;
llvm::Type *type = llvm::parseType(tyData, errorMessage, module);
if (!type)
return (context->emitError(loc, errorMessage.getMessage()), nullptr);
return LLVMType::get(context, type);
}
/// Print a type registered to this dialect.
void LLVMDialect::printType(Type type, raw_ostream &os) const {
auto llvmType = type.dyn_cast<LLVMType>();
assert(llvmType && "printing wrong type");
assert(llvmType.getUnderlyingType() && "no underlying LLVM type");
llvmType.getUnderlyingType()->print(os);
}
static DialectRegistration<LLVMDialect> llvmDialect;

View File

@ -496,15 +496,7 @@ Type Parser::parseExtendedType() {
return aliasIt->second;
}
// Otherwise, check for a registered dialect with this name.
auto *dialect = state.context->getRegisteredDialect(identifier);
if (dialect) {
// Make sure that the dialect provides a parsing hook.
if (!dialect->typeParseHook)
return (emitError("dialect '" + dialect->getNamespace() +
"' provides no type parsing hook"),
nullptr);
}
// Otherwise, we are parsing a dialect-specific type.
// Consume the '<'.
if (parseToken(Token::less, "expected '<' in dialect type"))
@ -522,8 +514,8 @@ Type Parser::parseExtendedType() {
Type result;
// If we found a registered dialect, then ask it to parse the type.
if (dialect) {
result = dialect->typeParseHook(typeData, loc, state.context);
if (auto *dialect = state.context->getRegisteredDialect(identifier)) {
result = dialect->parseType(typeData, loc, state.context);
if (!result)
return nullptr;
} else {