Convert the dialect type parse/print hooks into virtual functions on the Dialect class.
PiperOrigin-RevId: 235589945
This commit is contained in:
parent
f1f86eac60
commit
b4f033f6c6
|
@ -35,9 +35,6 @@ using DialectConstantFoldHook = std::function<bool(
|
|||
const Instruction *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
|
||||
using DialectExtractElementHook =
|
||||
std::function<Attribute(const OpaqueElementsAttr, ArrayRef<uint64_t>)>;
|
||||
using DialectTypeParserHook =
|
||||
std::function<Type(StringRef, Location, MLIRContext *)>;
|
||||
using DialectTypePrinterHook = std::function<void(Type, raw_ostream &)>;
|
||||
|
||||
/// Dialects are groups of MLIR operations and behavior associated with the
|
||||
/// entire group. For example, hooks into other systems for constant folding,
|
||||
|
@ -80,11 +77,16 @@ public:
|
|||
return Attribute();
|
||||
};
|
||||
|
||||
/// Registered parsing/printing hooks for types registered to the dialect.
|
||||
DialectTypeParserHook typeParseHook = nullptr;
|
||||
/// Parse a type registered to this dialect.
|
||||
virtual Type parseType(StringRef tyData, Location loc,
|
||||
MLIRContext *context) const;
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
/// Note: The data printed for the provided type must not include any '"'
|
||||
/// characters.
|
||||
DialectTypePrinterHook typePrintHook = nullptr;
|
||||
virtual void printType(Type, raw_ostream &) const {
|
||||
assert(0 && "dialect has no registered type printing hook");
|
||||
}
|
||||
|
||||
/// Registered hooks for getting identifier aliases for symbols. The
|
||||
/// identifier is used in place of the symbol when printing textual IR.
|
||||
|
|
|
@ -76,6 +76,13 @@ public:
|
|||
llvm::LLVMContext &getLLVMContext() { return llvmContext; }
|
||||
llvm::Module &getLLVMModule() { return module; }
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
Type parseType(StringRef tyData, Location loc,
|
||||
MLIRContext *context) const override;
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
void printType(Type type, raw_ostream &os) const override;
|
||||
|
||||
private:
|
||||
llvm::LLVMContext llvmContext;
|
||||
llvm::Module module;
|
||||
|
|
|
@ -715,8 +715,7 @@ void ModulePrinter::printType(Type type) {
|
|||
default: {
|
||||
auto &dialect = type.getDialect();
|
||||
os << '!' << dialect.getNamespace() << "<\"";
|
||||
assert(dialect.typePrintHook && "Expected dialect type printing hook.");
|
||||
dialect.typePrintHook(type, os);
|
||||
dialect.printType(type, os);
|
||||
os << "\">";
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectHooks.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -65,3 +66,11 @@ Dialect::Dialect(StringRef namePrefix, MLIRContext *context)
|
|||
}
|
||||
|
||||
Dialect::~Dialect() {}
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
Type Dialect::parseType(StringRef tyData, Location loc,
|
||||
MLIRContext *context) const {
|
||||
context->emitError(loc, "dialect '" + getNamespace() +
|
||||
"' provides no type parsing hook");
|
||||
return Type();
|
||||
}
|
||||
|
|
|
@ -57,27 +57,6 @@ LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
|
|||
return Base::get(context, FIRST_LLVM_TYPE, llvmType);
|
||||
}
|
||||
|
||||
static Type parseLLVMType(StringRef data, Location loc, MLIRContext *ctx) {
|
||||
llvm::SMDiagnostic errorMessage;
|
||||
auto *llvmDialect =
|
||||
static_cast<LLVMDialect *>(ctx->getRegisteredDialect("llvm"));
|
||||
assert(llvmDialect && "LLVM dialect not registered");
|
||||
llvm::Type *type =
|
||||
llvm::parseType(data, errorMessage, llvmDialect->getLLVMModule());
|
||||
if (!type) {
|
||||
ctx->emitError(loc, errorMessage.getMessage());
|
||||
return {};
|
||||
}
|
||||
return LLVMType::get(ctx, type);
|
||||
}
|
||||
|
||||
static void printLLVMType(Type ty, raw_ostream &os) {
|
||||
auto type = ty.dyn_cast<LLVMType>();
|
||||
assert(type && "printing wrong type");
|
||||
assert(type.getUnderlyingType() && "no underlying LLVM type");
|
||||
type.getUnderlyingType()->print(os);
|
||||
}
|
||||
|
||||
llvm::Type *LLVMType::getUnderlyingType() const {
|
||||
return static_cast<ImplType *>(type)->underlyingType;
|
||||
}
|
||||
|
@ -91,9 +70,24 @@ LLVMDialect::LLVMDialect(MLIRContext *context)
|
|||
addOperations<
|
||||
#include "mlir/LLVMIR/llvm_ops.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
typeParseHook = parseLLVMType;
|
||||
typePrintHook = printLLVMType;
|
||||
/// Parse a type registered to this dialect.
|
||||
Type LLVMDialect::parseType(StringRef tyData, Location loc,
|
||||
MLIRContext *context) const {
|
||||
llvm::SMDiagnostic errorMessage;
|
||||
llvm::Type *type = llvm::parseType(tyData, errorMessage, module);
|
||||
if (!type)
|
||||
return (context->emitError(loc, errorMessage.getMessage()), nullptr);
|
||||
return LLVMType::get(context, type);
|
||||
}
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
void LLVMDialect::printType(Type type, raw_ostream &os) const {
|
||||
auto llvmType = type.dyn_cast<LLVMType>();
|
||||
assert(llvmType && "printing wrong type");
|
||||
assert(llvmType.getUnderlyingType() && "no underlying LLVM type");
|
||||
llvmType.getUnderlyingType()->print(os);
|
||||
}
|
||||
|
||||
static DialectRegistration<LLVMDialect> llvmDialect;
|
||||
|
|
|
@ -496,15 +496,7 @@ Type Parser::parseExtendedType() {
|
|||
return aliasIt->second;
|
||||
}
|
||||
|
||||
// Otherwise, check for a registered dialect with this name.
|
||||
auto *dialect = state.context->getRegisteredDialect(identifier);
|
||||
if (dialect) {
|
||||
// Make sure that the dialect provides a parsing hook.
|
||||
if (!dialect->typeParseHook)
|
||||
return (emitError("dialect '" + dialect->getNamespace() +
|
||||
"' provides no type parsing hook"),
|
||||
nullptr);
|
||||
}
|
||||
// Otherwise, we are parsing a dialect-specific type.
|
||||
|
||||
// Consume the '<'.
|
||||
if (parseToken(Token::less, "expected '<' in dialect type"))
|
||||
|
@ -522,8 +514,8 @@ Type Parser::parseExtendedType() {
|
|||
Type result;
|
||||
|
||||
// If we found a registered dialect, then ask it to parse the type.
|
||||
if (dialect) {
|
||||
result = dialect->typeParseHook(typeData, loc, state.context);
|
||||
if (auto *dialect = state.context->getRegisteredDialect(identifier)) {
|
||||
result = dialect->parseType(typeData, loc, state.context);
|
||||
if (!result)
|
||||
return nullptr;
|
||||
} else {
|
||||
|
|
Loading…
Reference in New Issue