llvm-project/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp

124 lines
4.5 KiB
C++

//===- TransformDialect.cpp - Transform Dialect Definition ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/DialectImplementation.h"
using namespace mlir;
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
#ifndef NDEBUG
/// Debug-only sanity check for an operation named `name` that is being added
/// to the Transform dialect. Asserts that:
///   - the operation is registered in `context`;
///   - it implements TransformOpInterface, unless it is a terminator;
///   - it implements MemoryEffectOpInterface.
void transform::detail::checkImplementsTransformOpInterface(
    StringRef name, MLIRContext *context) {
  // `lookup` returns an optional-like value. Check it before dereferencing so
  // that an unregistered name fails with a readable assertion instead of
  // undefined behavior.
  auto maybeOpName = RegisteredOperationName::lookup(name, context);
  assert(maybeOpName && "ops injected into the transform dialect must be "
                        "registered before being checked");
  RegisteredOperationName opName = *maybeOpName;

  // Since the operation is being inserted into the Transform dialect and the
  // dialect does not implement the interface fallback, only check for the op
  // itself having the interface implementation.
  assert((opName.hasInterface<TransformOpInterface>() ||
          opName.hasTrait<OpTrait::IsTerminator>()) &&
         "non-terminator ops injected into the transform dialect must "
         "implement TransformOpInterface");
  assert(opName.hasInterface<MemoryEffectOpInterface>() &&
         "ops injected into the transform dialect must implement "
         "MemoryEffectOpInterface");
}
/// Debug-only sanity check for a type with `typeID` that is being added to the
/// Transform dialect. Asserts that the type's registered abstract description
/// in `context` implements TransformTypeInterface.
void transform::detail::checkImplementsTransformTypeInterface(
    TypeID typeID, MLIRContext *context) {
  const auto &abstractType = AbstractType::lookup(typeID, context);
  // Carry a message like the operation check does, so a failure names the
  // violated requirement instead of only echoing the expression.
  assert(abstractType.hasInterface(TransformTypeInterface::getInterfaceID()) &&
         "types injected into the transform dialect must implement "
         "TransformTypeInterface");
}
#endif // NDEBUG
namespace {
/// External model through which `pdl::OperationType` implements
/// TransformTypeInterface. It places no constraint on the payload.
struct PDLOperationTypeTransformTypeInterfaceImpl
    : public transform::TransformTypeInterface::ExternalModel<
          PDLOperationTypeTransformTypeInterfaceImpl, pdl::OperationType> {
  /// Accepts any list of payload operations; this check always succeeds.
  DiagnosedSilenceableFailure
  checkPayload(Type /*type*/, Location /*loc*/,
               ArrayRef<Operation *> /*payload*/) const {
    return DiagnosedSilenceableFailure::success();
  }
};
} // namespace
/// Dialect initialization hook. It:
///   - registers the dialect's own operations through the checked path;
///   - calls initializeTypes();
///   - attaches the TransformTypeInterface external model to
///     `pdl::OperationType`.
void transform::TransformDialect::initialize() {
  // Using the checked versions to enable the same assertions as for the ops
  // from extensions.
  addOperationsChecked<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
      >();
  initializeTypes();
  // Make `pdl::OperationType` implement TransformTypeInterface in this
  // context, using the always-succeeding model defined above.
  pdl::OperationType::attachInterface<
      PDLOperationTypeTransformTypeInterfaceImpl>(*getContext());
}
/// Registers every constraint function in `constraintFns` with this dialect's
/// PDL match hooks, keyed by the same name it has in the map. Each function
/// object is passed along with std::move.
void transform::TransformDialect::mergeInPDLMatchHooks(
    llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
  // Steal the constraint functions from the given map.
  for (auto &entry : constraintFns) {
    StringRef hookName = entry.getKey();
    PDLConstraintFunction &hook = entry.getValue();
    pdlMatchHooks.registerConstraintFunction(hookName, std::move(hook));
  }
}
/// Returns the constraint functions currently registered with this dialect's
/// PDL match hooks.
const llvm::StringMap<PDLConstraintFunction> &
transform::TransformDialect::getPDLConstraintHooks() const {
  const llvm::StringMap<PDLConstraintFunction> &registeredHooks =
      pdlMatchHooks.getConstraintFunctions();
  return registeredHooks;
}
/// Parses a Transform dialect type. It reads a mnemonic keyword and dispatches
/// to the parsing hook registered for that mnemonic. Returns a null type if no
/// keyword can be parsed, or if the mnemonic is unknown (an error is then
/// emitted at the keyword's location).
Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
  // Record the location before consuming the keyword so the diagnostic
  // points at the mnemonic itself.
  SMLoc mnemonicLoc = parser.getCurrentLocation();
  StringRef mnemonic;
  if (failed(parser.parseKeyword(&mnemonic)))
    return nullptr;

  auto hookIt = typeParsingHooks.find(mnemonic);
  if (hookIt != typeParsingHooks.end())
    return hookIt->getValue()(parser);

  parser.emitError(mnemonicLoc) << "unknown type mnemonic: " << mnemonic;
  return nullptr;
}
/// Prints `type` by invoking the printing hook registered for its TypeID.
/// Asserts that such a hook exists.
void transform::TransformDialect::printType(Type type,
                                            DialectAsmPrinter &printer) const {
  auto hookIt = typePrintingHooks.find(type.getTypeID());
  assert(hookIt != typePrintingHooks.end() && "printing unknown type");
  const auto &printHook = hookIt->getSecond();
  printHook(type, printer);
}
/// Aborts via llvm::report_fatal_error with a message saying that the type
/// mnemonic `mnemonic` is already registered with a different implementation.
void transform::TransformDialect::reportDuplicateTypeRegistration(
    StringRef mnemonic) {
  std::string message = "extensible dialect type '" + mnemonic.str() +
                        "' is already registered with a different "
                        "implementation";
  llvm::report_fatal_error(StringRef(message));
}
/// Aborts via llvm::report_fatal_error with a message saying that the
/// operation `opName` is already registered with a mismatching TypeID.
void transform::TransformDialect::reportDuplicateOpRegistration(
    StringRef opName) {
  std::string message = "extensible dialect operation '" + opName.str() +
                        "' is already registered with a mismatching TypeID";
  llvm::report_fatal_error(StringRef(message));
}
#include "mlir/Dialect/Transform/IR/TransformDialectEnums.cpp.inc"