[mlir:PDL] Add support for DialectConversion with pattern configurations

Up until now PDL(L) has not supported dialect conversion because we had no
way of remapping values or integrating with type conversions. This commit
rectifies that by adding a new "pattern configuration" concept to PDL. This
essentially allows for attaching external configurations to patterns, which
can hook into pattern events (for now just the scope of a rewrite, but we
could also pass configs to native rewrites as well). This allows for injecting
the type converter into the conversion pattern rewriter.

Differential Revision: https://reviews.llvm.org/D133142
This commit is contained in:
River Riddle 2022-09-08 16:59:39 -07:00
parent f3a86a23c1
commit 8c66344ee9
19 changed files with 669 additions and 95 deletions

View File

@ -13,12 +13,14 @@
#ifndef MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H #ifndef MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
#define MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H #define MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H
#include <memory> #include "mlir/Support/LLVM.h"
namespace mlir { namespace mlir {
class ModuleOp; class ModuleOp;
class Operation;
template <typename OpT> template <typename OpT>
class OperationPass; class OperationPass;
class PDLPatternConfigSet;
#define GEN_PASS_DECL_CONVERTPDLTOPDLINTERP #define GEN_PASS_DECL_CONVERTPDLTOPDLINTERP
#include "mlir/Conversion/Passes.h.inc" #include "mlir/Conversion/Passes.h.inc"
@ -26,6 +28,12 @@ class OperationPass;
/// Creates and returns a pass to convert PDL ops to PDL interpreter ops. /// Creates and returns a pass to convert PDL ops to PDL interpreter ops.
std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass(); std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass();
/// Creates and returns a pass to convert PDL ops to PDL interpreter ops.
/// `configMap` holds a map of the configurations for each pattern being
/// compiled.
std::unique_ptr<OperationPass<ModuleOp>> createPDLToPDLInterpPass(
DenseMap<Operation *, PDLPatternConfigSet *> &configMap);
} // namespace mlir } // namespace mlir
#endif // MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H #endif // MLIR_CONVERSION_PDLTOPDLINTERP_PDLTOPDLINTERP_H

View File

@ -600,10 +600,16 @@ public:
class PatternRewriter : public RewriterBase { class PatternRewriter : public RewriterBase {
public: public:
using RewriterBase::RewriterBase; using RewriterBase::RewriterBase;
/// A hook used to indicate if the pattern rewriter can recover from failure
/// during the rewrite stage of a pattern. For example, if the pattern
/// rewriter supports rollback, it may progress smoothly even if IR was
/// changed during the rewrite.
virtual bool canRecoverFromRewriteFailure() const { return false; }
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PDLPatternModule // PDL Patterns
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -796,6 +802,108 @@ protected:
SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges; SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
}; };
//===----------------------------------------------------------------------===//
// PDLPatternConfig
/// An individual configuration for a pattern, which can be accessed by native
/// functions via the PDLPatternConfigSet. This allows for injecting additional
/// configuration into PDL patterns that is specific to certain compilation
/// flows.
class PDLPatternConfig {
public:
virtual ~PDLPatternConfig() = default;
/// Hooks that are invoked at the beginning and end of a rewrite of a matched
/// pattern. These can be used to setup any specific state necessary for the
/// rewrite.
virtual void notifyRewriteBegin(PatternRewriter &rewriter) {}
virtual void notifyRewriteEnd(PatternRewriter &rewriter) {}
/// Return the TypeID that represents this configuration.
TypeID getTypeID() const { return id; }
protected:
PDLPatternConfig(TypeID id) : id(id) {}
private:
TypeID id;
};
/// This class provides a base class for users implementing a type of pattern
/// configuration.
template <typename T>
class PDLPatternConfigBase : public PDLPatternConfig {
public:
/// Support LLVM style casting.
static bool classof(const PDLPatternConfig *config) {
return config->getTypeID() == getConfigID();
}
/// Return the type id used for this configuration.
static TypeID getConfigID() { return TypeID::get<T>(); }
protected:
PDLPatternConfigBase() : PDLPatternConfig(getConfigID()) {}
};
/// This class contains a set of configurations for a specific pattern.
/// Configurations are uniqued by TypeID, meaning that only one configuration of
/// each type is allowed.
class PDLPatternConfigSet {
public:
PDLPatternConfigSet() = default;
/// Construct a set with the given configurations.
template <typename... ConfigsT>
PDLPatternConfigSet(ConfigsT &&...configs) {
(addConfig(std::forward<ConfigsT>(configs)), ...);
}
/// Get the configuration defined by the given type. Asserts that the
/// configuration of the provided type exists.
template <typename T>
const T &get() const {
const T *config = tryGet<T>();
assert(config && "configuration not found");
return *config;
}
/// Get the configuration defined by the given type, returns nullptr if the
/// configuration does not exist.
template <typename T>
const T *tryGet() const {
for (const auto &configIt : configs)
if (const T *config = dyn_cast<T>(configIt.get()))
return config;
return nullptr;
}
/// Notify the configurations within this set at the beginning or end of a
/// rewrite of a matched pattern.
void notifyRewriteBegin(PatternRewriter &rewriter) {
for (const auto &config : configs)
config->notifyRewriteBegin(rewriter);
}
void notifyRewriteEnd(PatternRewriter &rewriter) {
for (const auto &config : configs)
config->notifyRewriteEnd(rewriter);
}
protected:
/// Add a configuration to the set.
template <typename T>
void addConfig(T &&config) {
assert(!tryGet<std::decay_t<T>>() && "configuration already exists");
configs.emplace_back(
std::make_unique<std::decay_t<T>>(std::forward<T>(config)));
}
/// The set of configurations for this pattern. This uses a vector instead of
/// a map with the expectation that the number of configurations per set is
/// small (<= 1).
SmallVector<std::unique_ptr<PDLPatternConfig>> configs;
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PDLPatternModule // PDLPatternModule
@ -807,9 +915,11 @@ using PDLConstraintFunction =
/// A native PDL rewrite function. This function performs a rewrite on the /// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed /// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only /// back to PDL should be added to the provided result list. This method is only
/// invoked when the corresponding match was successful. /// invoked when the corresponding match was successful. Returns failure if an
using PDLRewriteFunction = /// invariant of the rewrite was broken (certain rewriters may recover from
std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; /// partial pattern application).
using PDLRewriteFunction = std::function<LogicalResult(
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
namespace detail { namespace detail {
namespace pdl_function_builder { namespace pdl_function_builder {
@ -1034,6 +1144,13 @@ struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
results.push_back(types); results.push_back(types);
} }
}; };
template <unsigned N>
struct ProcessPDLValue<SmallVector<Type, N>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
SmallVector<Type, N> values) {
results.push_back(TypeRange(values));
}
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Value // Value
@ -1061,6 +1178,13 @@ struct ProcessPDLValue<ResultRange> {
results.push_back(values); results.push_back(values);
} }
}; };
template <unsigned N>
struct ProcessPDLValue<SmallVector<Value, N>> {
static void processAsResult(PatternRewriter &, PDLResultList &results,
SmallVector<Value, N> values) {
results.push_back(ValueRange(values));
}
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PDL Function Builder: Argument Handling // PDL Function Builder: Argument Handling
@ -1111,28 +1235,49 @@ void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
/// Store a single result within the result list. /// Store a single result within the result list.
template <typename T> template <typename T>
static void processResults(PatternRewriter &rewriter, PDLResultList &results, static LogicalResult processResults(PatternRewriter &rewriter,
T &&value) { PDLResultList &results, T &&value) {
ProcessPDLValue<T>::processAsResult(rewriter, results, ProcessPDLValue<T>::processAsResult(rewriter, results,
std::forward<T>(value)); std::forward<T>(value));
return success();
} }
/// Store a std::pair<> as individual results within the result list. /// Store a std::pair<> as individual results within the result list.
template <typename T1, typename T2> template <typename T1, typename T2>
static void processResults(PatternRewriter &rewriter, PDLResultList &results, static LogicalResult processResults(PatternRewriter &rewriter,
std::pair<T1, T2> &&pair) { PDLResultList &results,
processResults(rewriter, results, std::move(pair.first)); std::pair<T1, T2> &&pair) {
processResults(rewriter, results, std::move(pair.second)); if (failed(processResults(rewriter, results, std::move(pair.first))) ||
failed(processResults(rewriter, results, std::move(pair.second))))
return failure();
return success();
} }
/// Store a std::tuple<> as individual results within the result list. /// Store a std::tuple<> as individual results within the result list.
template <typename... Ts> template <typename... Ts>
static void processResults(PatternRewriter &rewriter, PDLResultList &results, static LogicalResult processResults(PatternRewriter &rewriter,
std::tuple<Ts...> &&tuple) { PDLResultList &results,
std::tuple<Ts...> &&tuple) {
auto applyFn = [&](auto &&...args) { auto applyFn = [&](auto &&...args) {
(processResults(rewriter, results, std::move(args)), ...); return (succeeded(processResults(rewriter, results, std::move(args))) &&
...);
}; };
std::apply(applyFn, std::move(tuple)); return success(std::apply(applyFn, std::move(tuple)));
}
/// Handle LogicalResult propagation.
inline LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
LogicalResult &&result) {
return result;
}
template <typename T>
static LogicalResult processResults(PatternRewriter &rewriter,
PDLResultList &results,
FailureOr<T> &&result) {
if (failed(result))
return failure();
return processResults(rewriter, results, std::move(*result));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1192,23 +1337,26 @@ buildConstraintFn(ConstraintFnT &&constraintFn) {
/// This overload handles the case of no return values. /// This overload handles the case of no return values.
template <typename PDLFnT, std::size_t... I, template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>> typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value> std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value,
LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &, ArrayRef<PDLValue> values, PDLResultList &, ArrayRef<PDLValue> values,
std::index_sequence<I...>) { std::index_sequence<I...>) {
fn(rewriter, fn(rewriter,
(ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg( (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
values[I]))...); values[I]))...);
return success();
} }
/// This overload handles the case of return values, which need to be packaged /// This overload handles the case of return values, which need to be packaged
/// into the result list. /// into the result list.
template <typename PDLFnT, std::size_t... I, template <typename PDLFnT, std::size_t... I,
typename FnTraitsT = llvm::function_traits<PDLFnT>> typename FnTraitsT = llvm::function_traits<PDLFnT>>
std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value> std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value,
LogicalResult>
processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
PDLResultList &results, ArrayRef<PDLValue> values, PDLResultList &results, ArrayRef<PDLValue> values,
std::index_sequence<I...>) { std::index_sequence<I...>) {
processResults( return processResults(
rewriter, results, rewriter, results,
fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
processAsArg(values[I]))...)); processAsArg(values[I]))...));
@ -1240,14 +1388,17 @@ buildRewriteFn(RewriteFnT &&rewriteFn) {
std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args - std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1>(); 1>();
assertArgs<RewriteFnT>(rewriter, values, argIndices); assertArgs<RewriteFnT>(rewriter, values, argIndices);
processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values, return processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
argIndices); argIndices);
}; };
} }
} // namespace pdl_function_builder } // namespace pdl_function_builder
} // namespace detail } // namespace detail
//===----------------------------------------------------------------------===//
// PDLPatternModule
/// This class contains all of the necessary data for a set of PDL patterns, or /// This class contains all of the necessary data for a set of PDL patterns, or
/// pattern rewrites specified in the form of the PDL dialect. This PDL module /// pattern rewrites specified in the form of the PDL dialect. This PDL module
/// contained by this pattern may contain any number of `pdl.pattern` /// contained by this pattern may contain any number of `pdl.pattern`
@ -1256,9 +1407,17 @@ class PDLPatternModule {
public: public:
PDLPatternModule() = default; PDLPatternModule() = default;
/// Construct a PDL pattern with the given module. /// Construct a PDL pattern with the given module and configurations.
PDLPatternModule(OwningOpRef<ModuleOp> pdlModule) PDLPatternModule(OwningOpRef<ModuleOp> module)
: pdlModule(std::move(pdlModule)) {} : pdlModule(std::move(module)) {}
template <typename... ConfigsT>
PDLPatternModule(OwningOpRef<ModuleOp> module, ConfigsT &&...patternConfigs)
: PDLPatternModule(std::move(module)) {
auto configSet = std::make_unique<PDLPatternConfigSet>(
std::forward<ConfigsT>(patternConfigs)...);
attachConfigToPatterns(*pdlModule, *configSet);
configs.emplace_back(std::move(configSet));
}
/// Merge the state in `other` into this pattern module. /// Merge the state in `other` into this pattern module.
void mergeIn(PDLPatternModule &&other); void mergeIn(PDLPatternModule &&other);
@ -1344,6 +1503,14 @@ public:
return rewriteFunctions; return rewriteFunctions;
} }
/// Return the set of the registered pattern configs.
SmallVector<std::unique_ptr<PDLPatternConfigSet>> takeConfigs() {
return std::move(configs);
}
DenseMap<Operation *, PDLPatternConfigSet *> takeConfigMap() {
return std::move(configMap);
}
/// Clear out the patterns and functions within this module. /// Clear out the patterns and functions within this module.
void clear() { void clear() {
pdlModule = nullptr; pdlModule = nullptr;
@ -1352,9 +1519,17 @@ public:
} }
private: private:
/// Attach the given pattern config set to the patterns defined within the
/// given module.
void attachConfigToPatterns(ModuleOp module, PDLPatternConfigSet &configSet);
/// The module containing the `pdl.pattern` operations. /// The module containing the `pdl.pattern` operations.
OwningOpRef<ModuleOp> pdlModule; OwningOpRef<ModuleOp> pdlModule;
/// The set of configuration sets referenced by patterns within `pdlModule`.
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
DenseMap<Operation *, PDLPatternConfigSet *> configMap;
/// The external functions referenced from within the PDL module. /// The external functions referenced from within the PDL module.
llvm::StringMap<PDLConstraintFunction> constraintFunctions; llvm::StringMap<PDLConstraintFunction> constraintFunctions;
llvm::StringMap<PDLRewriteFunction> rewriteFunctions; llvm::StringMap<PDLRewriteFunction> rewriteFunctions;

View File

@ -574,6 +574,11 @@ public:
// PatternRewriter Hooks // PatternRewriter Hooks
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
/// Indicate that the conversion rewriter can recover from rewrite failure.
/// Recovery is supported via rollback, allowing for continued processing of
/// patterns even if a failure is encountered during the rewrite step.
bool canRecoverFromRewriteFailure() const override { return true; }
/// PatternRewriter hook for replacing the results of an operation when the /// PatternRewriter hook for replacing the results of an operation when the
/// given functor returns true. /// given functor returns true.
void replaceOpWithIf( void replaceOpWithIf(
@ -891,6 +896,35 @@ private:
MLIRContext &ctx; MLIRContext &ctx;
}; };
//===----------------------------------------------------------------------===//
// PDL Configuration
//===----------------------------------------------------------------------===//
/// A PDL configuration that is used to supported dialect conversion
/// functionality.
class PDLConversionConfig final
: public PDLPatternConfigBase<PDLConversionConfig> {
public:
PDLConversionConfig(TypeConverter *converter) : converter(converter) {}
~PDLConversionConfig() final = default;
/// Return the type converter used by this configuration, which may be nullptr
/// if no type conversions are expected.
TypeConverter *getTypeConverter() const { return converter; }
/// Hooks that are invoked at the beginning and end of a rewrite of a matched
/// pattern.
void notifyRewriteBegin(PatternRewriter &rewriter) final;
void notifyRewriteEnd(PatternRewriter &rewriter) final;
private:
/// An optional type converter to use for the pattern.
TypeConverter *converter;
};
/// Register the dialect conversion PDL functions with the given pattern set.
void registerConversionPDLFunctions(RewritePatternSet &patterns);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Op Conversion Entry Points // Op Conversion Entry Points
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,30 @@
//===- DialectConversion.pdll - DialectConversion PDLL Support -*- PDLL -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines various utilities for interacting with dialect conversion
// within PDLL.
//
//===----------------------------------------------------------------------===//
/// This rewrite returns the converted value of `value`, whose type is defined
/// by the type converted specified in the `PDLConversionConfig` of the current
/// pattern.
Rewrite convertValue(value: Value) -> Value;
/// This rewrite returns the converted values of `values`, whose type is defined
/// by the type converted specified in the `PDLConversionConfig` of the current
/// pattern.
Rewrite convertValues(values: ValueRange) -> ValueRange;
/// This rewrite returns the converted type of `type` as defined by the type
/// converted specified in the `PDLConversionConfig` of the current pattern.
Rewrite convertType(type: Type) -> Type;
/// This rewrite returns the converted types of `types` as defined by the type
/// converted specified in the `PDLConversionConfig` of the current pattern.
Rewrite convertTypes(types: TypeRange) -> TypeRange;

View File

@ -37,7 +37,8 @@ namespace {
/// given module containing PDL pattern operations. /// given module containing PDL pattern operations.
struct PatternLowering { struct PatternLowering {
public: public:
PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule); PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
DenseMap<Operation *, PDLPatternConfigSet *> *configMap);
/// Generate code for matching and rewriting based on the pattern operations /// Generate code for matching and rewriting based on the pattern operations
/// within the module. /// within the module.
@ -140,13 +141,19 @@ private:
/// The set of operation values whose whose location will be used for newly /// The set of operation values whose whose location will be used for newly
/// generated operations. /// generated operations.
SetVector<Value> locOps; SetVector<Value> locOps;
/// A mapping between pattern operations and the corresponding configuration
/// set.
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
}; };
} // namespace } // namespace
PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc, PatternLowering::PatternLowering(
ModuleOp rewriterModule) pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule,
DenseMap<Operation *, PDLPatternConfigSet *> *configMap)
: builder(matcherFunc.getContext()), matcherFunc(matcherFunc), : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {} rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule),
configMap(configMap) {}
void PatternLowering::lower(ModuleOp module) { void PatternLowering::lower(ModuleOp module) {
PredicateUniquer predicateUniquer; PredicateUniquer predicateUniquer;
@ -589,10 +596,14 @@ void PatternLowering::generate(SuccessNode *successNode, Block *&currentBlock) {
rootKindAttr = builder.getStringAttr(*rootKind); rootKindAttr = builder.getStringAttr(*rootKind);
builder.setInsertionPointToEnd(currentBlock); builder.setInsertionPointToEnd(currentBlock);
builder.create<pdl_interp::RecordMatchOp>( auto matchOp = builder.create<pdl_interp::RecordMatchOp>(
pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(), pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(), rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.getBenefitAttr(),
failureBlockStack.back()); failureBlockStack.back());
// Set the config of the lowered match to the parent pattern.
if (configMap)
configMap->try_emplace(matchOp, configMap->lookup(pattern));
} }
SymbolRefAttr PatternLowering::generateRewriter( SymbolRefAttr PatternLowering::generateRewriter(
@ -922,7 +933,14 @@ void PatternLowering::generateOperationResultTypeRewriter(
namespace { namespace {
struct PDLToPDLInterpPass struct PDLToPDLInterpPass
: public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> { : public impl::ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
PDLToPDLInterpPass() = default;
PDLToPDLInterpPass(const PDLToPDLInterpPass &rhs) = default;
PDLToPDLInterpPass(DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
: configMap(&configMap) {}
void runOnOperation() final; void runOnOperation() final;
/// A map containing the configuration for each pattern.
DenseMap<Operation *, PDLPatternConfigSet *> *configMap = nullptr;
}; };
} // namespace } // namespace
@ -946,15 +964,24 @@ void PDLToPDLInterpPass::runOnOperation() {
module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName()); module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
// Generate the code for the patterns within the module. // Generate the code for the patterns within the module.
PatternLowering generator(matcherFunc, rewriterModule); PatternLowering generator(matcherFunc, rewriterModule, configMap);
generator.lower(module); generator.lower(module);
// After generation, delete all of the pattern operations. // After generation, delete all of the pattern operations.
for (pdl::PatternOp pattern : for (pdl::PatternOp pattern :
llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) llvm::make_early_inc_range(module.getOps<pdl::PatternOp>())) {
// Drop the now dead config mappings.
if (configMap)
configMap->erase(pattern);
pattern.erase(); pattern.erase();
}
} }
std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() { std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
return std::make_unique<PDLToPDLInterpPass>(); return std::make_unique<PDLToPDLInterpPass>();
} }
std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass(
DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
return std::make_unique<PDLToPDLInterpPass>(configMap);
}

View File

@ -158,11 +158,15 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
if (!other.pdlModule) if (!other.pdlModule)
return; return;
// Steal the functions of the other module. // Steal the functions and config of the other module.
for (auto &it : other.constraintFunctions) for (auto &it : other.constraintFunctions)
registerConstraintFunction(it.first(), std::move(it.second)); registerConstraintFunction(it.first(), std::move(it.second));
for (auto &it : other.rewriteFunctions) for (auto &it : other.rewriteFunctions)
registerRewriteFunction(it.first(), std::move(it.second)); registerRewriteFunction(it.first(), std::move(it.second));
for (auto &it : other.configs)
configs.emplace_back(std::move(it));
for (auto &it : other.configMap)
configMap.insert(it);
// Steal the other state if we have no patterns. // Steal the other state if we have no patterns.
if (!pdlModule) { if (!pdlModule) {
@ -176,6 +180,18 @@ void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
other.pdlModule->getBody()->getOperations()); other.pdlModule->getBody()->getOperations());
} }
void PDLPatternModule::attachConfigToPatterns(ModuleOp module,
PDLPatternConfigSet &configSet) {
// Attach the configuration to the symbols within the module. We only add
// to symbols to avoid hardcoding any specific operation names here (given
// that we don't depend on any PDL dialect). We can't use
// cast<SymbolOpInterface> here because patterns may be optional symbols.
module->walk([&](Operation *op) {
if (op->hasTrait<SymbolOpInterface::Trait>())
configMap[op] = &configSet;
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Function Registry // Function Registry

View File

@ -34,21 +34,23 @@ using namespace mlir::detail;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
PDLPatternConfigSet *configSet,
ByteCodeAddr rewriterAddr) { ByteCodeAddr rewriterAddr) {
PatternBenefit benefit = matchOp.getBenefit();
MLIRContext *ctx = matchOp.getContext();
// Collect the set of generated operations.
SmallVector<StringRef, 8> generatedOps; SmallVector<StringRef, 8> generatedOps;
if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr()) if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
generatedOps = generatedOps =
llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>()); llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
PatternBenefit benefit = matchOp.getBenefit();
MLIRContext *ctx = matchOp.getContext();
// Check to see if this is pattern matches a specific operation type. // Check to see if this is pattern matches a specific operation type.
if (Optional<StringRef> rootKind = matchOp.getRootKind()) if (Optional<StringRef> rootKind = matchOp.getRootKind())
return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx, return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
generatedOps); generatedOps);
return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx, return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
generatedOps); benefit, ctx, generatedOps);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -194,14 +196,15 @@ public:
ByteCodeField &maxValueRangeMemoryIndex, ByteCodeField &maxValueRangeMemoryIndex,
ByteCodeField &maxLoopLevel, ByteCodeField &maxLoopLevel,
llvm::StringMap<PDLConstraintFunction> &constraintFns, llvm::StringMap<PDLConstraintFunction> &constraintFns,
llvm::StringMap<PDLRewriteFunction> &rewriteFns) llvm::StringMap<PDLRewriteFunction> &rewriteFns,
const DenseMap<Operation *, PDLPatternConfigSet *> &configMap)
: ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
rewriterByteCode(rewriterByteCode), patterns(patterns), rewriterByteCode(rewriterByteCode), patterns(patterns),
maxValueMemoryIndex(maxValueMemoryIndex), maxValueMemoryIndex(maxValueMemoryIndex),
maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
maxLoopLevel(maxLoopLevel) { maxLoopLevel(maxLoopLevel), configMap(configMap) {
for (const auto &it : llvm::enumerate(constraintFns)) for (const auto &it : llvm::enumerate(constraintFns))
constraintToMemIndex.try_emplace(it.value().first(), it.index()); constraintToMemIndex.try_emplace(it.value().first(), it.index());
for (const auto &it : llvm::enumerate(rewriteFns)) for (const auto &it : llvm::enumerate(rewriteFns))
@ -328,6 +331,9 @@ private:
ByteCodeField &maxTypeRangeMemoryIndex; ByteCodeField &maxTypeRangeMemoryIndex;
ByteCodeField &maxValueRangeMemoryIndex; ByteCodeField &maxValueRangeMemoryIndex;
ByteCodeField &maxLoopLevel; ByteCodeField &maxLoopLevel;
/// A map of pattern configurations.
const DenseMap<Operation *, PDLPatternConfigSet *> &configMap;
}; };
/// This class provides utilities for writing a bytecode stream. /// This class provides utilities for writing a bytecode stream.
@ -969,7 +975,8 @@ void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
ByteCodeField patternIndex = patterns.size(); ByteCodeField patternIndex = patterns.size();
patterns.emplace_back(PDLByteCodePattern::create( patterns.emplace_back(PDLByteCodePattern::create(
op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()])); op, configMap.lookup(op),
rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
writer.append(OpCode::RecordMatch, patternIndex, writer.append(OpCode::RecordMatch, patternIndex,
SuccessorRange(op.getOperation()), op.getMatchedOps()); SuccessorRange(op.getOperation()), op.getMatchedOps());
writer.appendPDLValueList(op.getInputs()); writer.appendPDLValueList(op.getInputs());
@ -1014,13 +1021,16 @@ void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
// PDLByteCode // PDLByteCode
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
PDLByteCode::PDLByteCode(ModuleOp module, PDLByteCode::PDLByteCode(
llvm::StringMap<PDLConstraintFunction> constraintFns, ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
llvm::StringMap<PDLRewriteFunction> rewriteFns) { const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
llvm::StringMap<PDLConstraintFunction> constraintFns,
llvm::StringMap<PDLRewriteFunction> rewriteFns)
: configs(std::move(configs)) {
Generator generator(module.getContext(), uniquedData, matcherByteCode, Generator generator(module.getContext(), uniquedData, matcherByteCode,
rewriterByteCode, patterns, maxValueMemoryIndex, rewriterByteCode, patterns, maxValueMemoryIndex,
maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
maxLoopLevel, constraintFns, rewriteFns); maxLoopLevel, constraintFns, rewriteFns, configMap);
generator.generate(module); generator.generate(module);
// Initialize the external functions. // Initialize the external functions.
@ -1076,14 +1086,15 @@ public:
/// Start executing the code at the current bytecode index. `matches` is an /// Start executing the code at the current bytecode index. `matches` is an
/// optional field provided when this function is executed in a matching /// optional field provided when this function is executed in a matching
/// context. /// context.
void execute(PatternRewriter &rewriter, LogicalResult
SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr, execute(PatternRewriter &rewriter,
Optional<Location> mainRewriteLoc = {}); SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
Optional<Location> mainRewriteLoc = {});
private: private:
/// Internal implementation of executing each of the bytecode commands. /// Internal implementation of executing each of the bytecode commands.
void executeApplyConstraint(PatternRewriter &rewriter); void executeApplyConstraint(PatternRewriter &rewriter);
void executeApplyRewrite(PatternRewriter &rewriter); LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
void executeAreEqual(); void executeAreEqual();
void executeAreRangesEqual(); void executeAreRangesEqual();
void executeBranch(); void executeBranch();
@ -1345,7 +1356,7 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
selectJump(succeeded(constraintFn(rewriter, args))); selectJump(succeeded(constraintFn(rewriter, args)));
} }
void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
SmallVector<PDLValue, 16> args; SmallVector<PDLValue, 16> args;
@ -1359,7 +1370,7 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
// Execute the rewrite function. // Execute the rewrite function.
ByteCodeField numResults = read(); ByteCodeField numResults = read();
ByteCodeRewriteResultList results(numResults); ByteCodeRewriteResultList results(numResults);
rewriteFn(rewriter, results, args); LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
assert(results.getResults().size() == numResults && assert(results.getResults().size() == numResults &&
"native PDL rewrite function returned unexpected number of results"); "native PDL rewrite function returned unexpected number of results");
@ -1395,6 +1406,13 @@ void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
allocatedTypeRangeMemory.push_back(std::move(it)); allocatedTypeRangeMemory.push_back(std::move(it));
for (auto &it : results.getAllocatedValueRanges()) for (auto &it : results.getAllocatedValueRanges())
allocatedValueRangeMemory.push_back(std::move(it)); allocatedValueRangeMemory.push_back(std::move(it));
// Process the result of the rewrite.
if (failed(rewriteResult)) {
LLVM_DEBUG(llvm::dbgs() << " - Failed");
return failure();
}
return success();
} }
void ByteCodeExecutor::executeAreEqual() { void ByteCodeExecutor::executeAreEqual() {
@ -2017,10 +2035,10 @@ void ByteCodeExecutor::executeSwitchTypes() {
}); });
} }
void ByteCodeExecutor::execute( LogicalResult
PatternRewriter &rewriter, ByteCodeExecutor::execute(PatternRewriter &rewriter,
SmallVectorImpl<PDLByteCode::MatchResult> *matches, SmallVectorImpl<PDLByteCode::MatchResult> *matches,
Optional<Location> mainRewriteLoc) { Optional<Location> mainRewriteLoc) {
while (true) { while (true) {
// Print the location of the operation being executed. // Print the location of the operation being executed.
LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
@ -2031,7 +2049,8 @@ void ByteCodeExecutor::execute(
executeApplyConstraint(rewriter); executeApplyConstraint(rewriter);
break; break;
case ApplyRewrite: case ApplyRewrite:
executeApplyRewrite(rewriter); if (failed(executeApplyRewrite(rewriter)))
return failure();
break; break;
case AreEqual: case AreEqual:
executeAreEqual(); executeAreEqual();
@ -2078,7 +2097,7 @@ void ByteCodeExecutor::execute(
case Finalize: case Finalize:
executeFinalize(); executeFinalize();
LLVM_DEBUG(llvm::dbgs() << "\n"); LLVM_DEBUG(llvm::dbgs() << "\n");
return; return success();
case ForEach: case ForEach:
executeForEach(); executeForEach();
break; break;
@ -2166,8 +2185,6 @@ void ByteCodeExecutor::execute(
} }
} }
/// Run the pattern matcher on the given root operation, collecting the matched
/// patterns in `matches`.
void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
SmallVectorImpl<MatchResult> &matches, SmallVectorImpl<MatchResult> &matches,
PDLByteCodeMutableState &state) const { PDLByteCodeMutableState &state) const {
@ -2181,7 +2198,8 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
constraintFunctions, rewriteFunctions); constraintFunctions, rewriteFunctions);
executor.execute(rewriter, &matches); LogicalResult executeResult = executor.execute(rewriter, &matches);
assert(succeeded(executeResult) && "unexpected matcher execution failure");
// Order the found matches by benefit. // Order the found matches by benefit.
std::stable_sort(matches.begin(), matches.end(), std::stable_sort(matches.begin(), matches.end(),
@ -2190,9 +2208,13 @@ void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
}); });
} }
/// Run the rewriter of the given pattern on the root operation `op`. LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, const MatchResult &match,
PDLByteCodeMutableState &state) const { PDLByteCodeMutableState &state) const {
auto *configSet = match.pattern->getConfigSet();
if (configSet)
configSet->notifyRewriteBegin(rewriter);
// The arguments of the rewrite function are stored at the start of the // The arguments of the rewrite function are stored at the start of the
// memory buffer. // memory buffer.
llvm::copy(match.values, state.memory.begin()); llvm::copy(match.values, state.memory.begin());
@ -2204,5 +2226,24 @@ void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
state.allocatedValueRangeMemory, state.loopIndex, uniquedData, state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
rewriterByteCode, state.currentPatternBenefits, patterns, rewriterByteCode, state.currentPatternBenefits, patterns,
constraintFunctions, rewriteFunctions); constraintFunctions, rewriteFunctions);
executor.execute(rewriter, /*matches=*/nullptr, match.location); LogicalResult result =
executor.execute(rewriter, /*matches=*/nullptr, match.location);
if (configSet)
configSet->notifyRewriteEnd(rewriter);
// If the rewrite failed, check if the pattern rewriter can recover. If it
// can, we can signal to the pattern applicator to keep trying patterns. If it
// doesn't, we need to bail. Bailing here should be fine, given that we have
// no means to propagate such a failure to the user, and it also indicates a
// bug in the user code (i.e. failable rewrites should not be used with
// pattern rewriters that don't support it).
if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
llvm::report_fatal_error(
"Native PDL Rewrite failed, but the pattern "
"rewriter doesn't support recovery. Failable pattern rewrites should "
"not be used with pattern rewriters that do not support them.");
}
return result;
} }

View File

@ -38,19 +38,27 @@ using OwningOpRange = llvm::OwningArrayRef<Operation *>;
class PDLByteCodePattern : public Pattern { class PDLByteCodePattern : public Pattern {
public: public:
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
PDLPatternConfigSet *configSet,
ByteCodeAddr rewriterAddr); ByteCodeAddr rewriterAddr);
/// Return the bytecode address of the rewriter for this pattern. /// Return the bytecode address of the rewriter for this pattern.
ByteCodeAddr getRewriterAddr() const { return rewriterAddr; } ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
/// Return the configuration set for this pattern, or null if there is none.
PDLPatternConfigSet *getConfigSet() const { return configSet; }
private: private:
template <typename... Args> template <typename... Args>
PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs) PDLByteCodePattern(ByteCodeAddr rewriterAddr, PDLPatternConfigSet *configSet,
: Pattern(std::forward<Args>(patternArgs)...), Args &&...patternArgs)
rewriterAddr(rewriterAddr) {} : Pattern(std::forward<Args>(patternArgs)...), rewriterAddr(rewriterAddr),
configSet(configSet) {}
/// The address of the rewriter for this pattern. /// The address of the rewriter for this pattern.
ByteCodeAddr rewriterAddr; ByteCodeAddr rewriterAddr;
/// The optional config set for this pattern.
PDLPatternConfigSet *configSet;
}; };
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -148,6 +156,8 @@ public:
/// Create a ByteCode instance from the given module containing operations in /// Create a ByteCode instance from the given module containing operations in
/// the PDL interpreter dialect. /// the PDL interpreter dialect.
PDLByteCode(ModuleOp module, PDLByteCode(ModuleOp module,
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
const DenseMap<Operation *, PDLPatternConfigSet *> &configMap,
llvm::StringMap<PDLConstraintFunction> constraintFns, llvm::StringMap<PDLConstraintFunction> constraintFns,
llvm::StringMap<PDLRewriteFunction> rewriteFns); llvm::StringMap<PDLRewriteFunction> rewriteFns);
@ -165,9 +175,9 @@ public:
PDLByteCodeMutableState &state) const; PDLByteCodeMutableState &state) const;
/// Run the rewriter of the given pattern that was previously matched in /// Run the rewriter of the given pattern that was previously matched in
/// `match`. /// `match`. Returns if a failure was encountered during the rewrite.
void rewrite(PatternRewriter &rewriter, const MatchResult &match, LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match,
PDLByteCodeMutableState &state) const; PDLByteCodeMutableState &state) const;
private: private:
/// Execute the given byte code starting at the provided instruction `inst`. /// Execute the given byte code starting at the provided instruction `inst`.
@ -177,6 +187,9 @@ private:
PDLByteCodeMutableState &state, PDLByteCodeMutableState &state,
SmallVectorImpl<MatchResult> *matches) const; SmallVectorImpl<MatchResult> *matches) const;
/// The set of pattern configs referenced within the bytecode.
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs;
/// A vector containing pointers to uniqued data. The storage is intentionally /// A vector containing pointers to uniqued data. The storage is intentionally
/// opaque such that we can store a wide range of data types. The types of /// opaque such that we can store a wide range of data types. The types of
/// data stored here include: /// data stored here include:

View File

@ -16,7 +16,9 @@
using namespace mlir; using namespace mlir;
static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) { static LogicalResult
convertPDLToPDLInterp(ModuleOp pdlModule,
DenseMap<Operation *, PDLPatternConfigSet *> &configMap) {
// Skip the conversion if the module doesn't contain pdl. // Skip the conversion if the module doesn't contain pdl.
if (pdlModule.getOps<pdl::PatternOp>().empty()) if (pdlModule.getOps<pdl::PatternOp>().empty())
return success(); return success();
@ -37,7 +39,7 @@ static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
// mode. // mode.
pdlPipeline.enableVerifier(false); pdlPipeline.enableVerifier(false);
#endif #endif
pdlPipeline.addPass(createPDLToPDLInterpPass()); pdlPipeline.addPass(createPDLToPDLInterpPass(configMap));
if (failed(pdlPipeline.run(pdlModule))) if (failed(pdlPipeline.run(pdlModule)))
return failure(); return failure();
@ -123,13 +125,16 @@ FrozenRewritePatternSet::FrozenRewritePatternSet(
ModuleOp pdlModule = pdlPatterns.getModule(); ModuleOp pdlModule = pdlPatterns.getModule();
if (!pdlModule) if (!pdlModule)
return; return;
if (failed(convertPDLToPDLInterp(pdlModule))) DenseMap<Operation *, PDLPatternConfigSet *> configMap =
pdlPatterns.takeConfigMap();
if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
llvm::report_fatal_error( llvm::report_fatal_error(
"failed to lower PDL pattern module to the PDL Interpreter"); "failed to lower PDL pattern module to the PDL Interpreter");
// Generate the pdl bytecode. // Generate the pdl bytecode.
impl->pdlByteCode = std::make_unique<detail::PDLByteCode>( impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
pdlModule, pdlPatterns.takeConstraintFunctions(), pdlModule, pdlPatterns.takeConfigs(), configMap,
pdlPatterns.takeConstraintFunctions(),
pdlPatterns.takeRewriteFunctions()); pdlPatterns.takeRewriteFunctions());
} }

View File

@ -191,20 +191,21 @@ LogicalResult PatternApplicator::matchAndRewrite(
Operation *dumpRootOp = getDumpRootOp(op); Operation *dumpRootOp = getDumpRootOp(op);
#endif #endif
if (pdlMatch) { if (pdlMatch) {
bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
} else { } else {
LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
<< bestPattern->getDebugName() << "\"\n");
const auto *pattern = static_cast<const RewritePattern *>(bestPattern); const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
LLVM_DEBUG(llvm::dbgs()
<< "Trying to match \"" << pattern->getDebugName() << "\"\n");
result = pattern->matchAndRewrite(op, rewriter); result = pattern->matchAndRewrite(op, rewriter);
LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result "
<< succeeded(result) << "\n");
if (succeeded(result) && onSuccess && failed(onSuccess(*pattern))) LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName()
result = failure(); << "\" result " << succeeded(result) << "\n");
} }
// Process the result of the pattern application.
if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
result = failure();
if (succeeded(result)) { if (succeeded(result)) {
LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp)); LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
break; break;

View File

@ -93,10 +93,12 @@ void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
os << "} // end namespace\n\n"; os << "} // end namespace\n\n";
// Emit function to add the generated matchers to the pattern list. // Emit function to add the generated matchers to the pattern list.
os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(" os << "template <typename... ConfigsT>\n"
"::mlir::RewritePatternSet &patterns) {\n"; "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
"::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {\n";
for (const auto &name : patternNames) for (const auto &name : patternNames)
os << " patterns.add<" << name << ">(patterns.getContext());\n"; os << " patterns.add<" << name
<< ">(patterns.getContext(), configs...);\n";
os << "}\n"; os << "}\n";
} }
@ -104,14 +106,15 @@ void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
StringSet<> &nativeFunctions) { StringSet<> &nativeFunctions) {
const char *patternClassStartStr = R"( const char *patternClassStartStr = R"(
struct {0} : ::mlir::PDLPatternModule {{ struct {0} : ::mlir::PDLPatternModule {{
{0}(::mlir::MLIRContext *context) template <typename... ConfigsT>
{0}(::mlir::MLIRContext *context, ConfigsT &&...configs)
: ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
)"; )";
os << llvm::formatv(patternClassStartStr, patternName); os << llvm::formatv(patternClassStartStr, patternName);
os << "R\"mlir("; os << "R\"mlir(";
pattern->print(os, OpPrintingFlags().enableDebugInfo()); pattern->print(os, OpPrintingFlags().enableDebugInfo());
os << "\n )mlir\", context)) {\n"; os << "\n )mlir\", context), std::forward<ConfigsT>(configs)...) {\n";
// Register any native functions used within the pattern. // Register any native functions used within the pattern.
StringSet<> registeredNativeFunctions; StringSet<> registeredNativeFunctions;

View File

@ -3272,6 +3272,76 @@ auto ConversionTarget::getOpInfo(OperationName op) const
return llvm::None; return llvm::None;
} }
//===----------------------------------------------------------------------===//
// PDL Configuration
//===----------------------------------------------------------------------===//
void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) {
auto &rewriterImpl =
static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
rewriterImpl.currentTypeConverter = getTypeConverter();
}
void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) {
auto &rewriterImpl =
static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
rewriterImpl.currentTypeConverter = nullptr;
}
/// Remap the given value using the rewriter and the type converter in the
/// provided config.
static FailureOr<SmallVector<Value>>
pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) {
SmallVector<Value> mappedValues;
if (failed(rewriter.getRemappedValues(values, mappedValues)))
return failure();
return std::move(mappedValues);
}
void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
patterns.getPDLPatterns().registerRewriteFunction(
"convertValue",
[](PatternRewriter &rewriter, Value value) -> FailureOr<Value> {
auto results = pdllConvertValues(
static_cast<ConversionPatternRewriter &>(rewriter), value);
if (failed(results))
return failure();
return results->front();
});
patterns.getPDLPatterns().registerRewriteFunction(
"convertValues", [](PatternRewriter &rewriter, ValueRange values) {
return pdllConvertValues(
static_cast<ConversionPatternRewriter &>(rewriter), values);
});
patterns.getPDLPatterns().registerRewriteFunction(
"convertType",
[](PatternRewriter &rewriter, Type type) -> FailureOr<Type> {
auto &rewriterImpl =
static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
if (TypeConverter *converter = rewriterImpl.currentTypeConverter) {
if (Type newType = converter->convertType(type))
return newType;
return failure();
}
return type;
});
patterns.getPDLPatterns().registerRewriteFunction(
"convertTypes",
[](PatternRewriter &rewriter,
TypeRange types) -> FailureOr<SmallVector<Type>> {
auto &rewriterImpl =
static_cast<ConversionPatternRewriter &>(rewriter).getImpl();
TypeConverter *converter = rewriterImpl.currentTypeConverter;
if (!converter)
return SmallVector<Type>(types);
SmallVector<Type> remappedTypes;
if (failed(converter->convertTypes(types, remappedTypes)))
return failure();
return std::move(remappedTypes);
});
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Op Conversion Entry Points // Op Conversion Entry Points
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -0,0 +1,18 @@
// RUN: mlir-opt %s -test-dialect-conversion-pdll | FileCheck %s
// CHECK-LABEL: @TestSingleConversion
func.func @TestSingleConversion() {
// CHECK: %[[CAST:.*]] = "test.cast"() : () -> f64
// CHECK-NEXT: "test.return"(%[[CAST]]) : (f64) -> ()
%result = "test.cast"() : () -> (i64)
"test.return"(%result) : (i64) -> ()
}
// CHECK-LABEL: @TestLingeringConversion
func.func @TestLingeringConversion() -> i64 {
// CHECK: %[[ORIG_CAST:.*]] = "test.cast"() : () -> f64
// CHECK: %[[MATERIALIZE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ORIG_CAST]] : f64 to i64
// CHECK-NEXT: return %[[MATERIALIZE_CAST]] : i64
%result = "test.cast"() : () -> (i64)
return %result : i64
}

View File

@ -1,8 +1,18 @@
add_mlir_pdll_library(MLIRTestDialectConversionPDLLPatternsIncGen
TestDialectConversion.pdll
TestDialectConversionPDLLPatterns.h.inc
EXTRA_INCLUDES
${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test
${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test
)
# Exclude tests from libMLIR.so # Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms add_mlir_library(MLIRTestTransforms
TestCommutativityUtils.cpp TestCommutativityUtils.cpp
TestConstantFold.cpp TestConstantFold.cpp
TestControlFlowSink.cpp TestControlFlowSink.cpp
TestDialectConversion.cpp
TestInlining.cpp TestInlining.cpp
TestIntRangeInference.cpp TestIntRangeInference.cpp
TestTopologicalSort.cpp TestTopologicalSort.cpp
@ -12,8 +22,12 @@ add_mlir_library(MLIRTestTransforms
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms ${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
DEPENDS
MLIRTestDialectConversionPDLLPatternsIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRAnalysis MLIRAnalysis
MLIRFuncDialect
MLIRInferIntRangeInterface MLIRInferIntRangeInterface
MLIRTestDialect MLIRTestDialect
MLIRTransforms MLIRTransforms

View File

@ -0,0 +1,96 @@
//===- TestDialectConversion.cpp - Test DialectConversion functionality ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace test;
//===----------------------------------------------------------------------===//
// Test PDLL Support
//===----------------------------------------------------------------------===//
#include "TestDialectConversionPDLLPatterns.h.inc"
namespace {
struct PDLLTypeConverter : public TypeConverter {
PDLLTypeConverter() {
addConversion(convertType);
addArgumentMaterialization(materializeCast);
addSourceMaterialization(materializeCast);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
// Convert I64 to F64.
if (t.isSignlessInteger(64)) {
results.push_back(FloatType::getF64(t.getContext()));
return success();
}
// Otherwise, convert the type directly.
results.push_back(t);
return success();
}
/// Hook for materializing a conversion.
static Optional<Value> materializeCast(OpBuilder &builder, Type resultType,
ValueRange inputs, Location loc) {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
}
};
struct TestDialectConversionPDLLPass
: public PassWrapper<TestDialectConversionPDLLPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectConversionPDLLPass)
StringRef getArgument() const final { return "test-dialect-conversion-pdll"; }
StringRef getDescription() const final {
return "Test DialectConversion PDLL functionality";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<pdl::PDLDialect, pdl_interp::PDLInterpDialect>();
}
LogicalResult initialize(MLIRContext *ctx) override {
// Build the pattern set within the `initialize` to avoid recompiling PDL
// patterns during each `runOnOperation` invocation.
RewritePatternSet patternList(ctx);
registerConversionPDLFunctions(patternList);
populateGeneratedPDLLPatterns(patternList, PDLConversionConfig(&converter));
patterns = std::move(patternList);
return success();
}
void runOnOperation() final {
mlir::ConversionTarget target(getContext());
target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
target.addDynamicallyLegalDialect<TestDialect>(
[this](Operation *op) { return converter.isLegal(op); });
if (failed(mlir::applyFullConversion(getOperation(), target, patterns)))
signalPassFailure();
}
FrozenRewritePatternSet patterns;
PDLLTypeConverter converter;
};
} // namespace
namespace mlir {
namespace test {
void registerTestDialectConversionPasses() {
PassRegistration<TestDialectConversionPDLLPass>();
}
} // namespace test
} // namespace mlir

View File

@ -0,0 +1,19 @@
//===- TestPDLL.pdll - Test PDLL functionality ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestOps.td"
#include "mlir/Transforms/DialectConversion.pdll"
/// Change the result type of a producer.
// FIXME: We shouldn't need to specify arguments for the result cast.
Pattern => replace op<test.cast>(args: ValueRange) -> (results: TypeRange)
with op<test.cast>(args) -> (convertTypes(results));
/// Pass through test.return conversion.
Pattern => replace op<test.return>(args: ValueRange)
with op<test.return>(convertValues(args));

View File

@ -0,0 +1 @@
config.suffixes.remove('.pdll')

View File

@ -5,18 +5,19 @@
// check that we handle overlap. // check that we handle overlap.
// CHECK: struct GeneratedPDLLPattern0 : ::mlir::PDLPatternModule { // CHECK: struct GeneratedPDLLPattern0 : ::mlir::PDLPatternModule {
// CHECK: template <typename... ConfigsT>
// CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( // CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
// CHECK: R"mlir( // CHECK: R"mlir(
// CHECK: pdl.pattern // CHECK: pdl.pattern
// CHECK: operation "test.op" // CHECK: operation "test.op"
// CHECK: )mlir", context)) // CHECK: )mlir", context), std::forward<ConfigsT>(configs)...)
// CHECK: struct NamedPattern : ::mlir::PDLPatternModule { // CHECK: struct NamedPattern : ::mlir::PDLPatternModule {
// CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>( // CHECK: : ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
// CHECK: R"mlir( // CHECK: R"mlir(
// CHECK: pdl.pattern // CHECK: pdl.pattern
// CHECK: operation "test.op2" // CHECK: operation "test.op2"
// CHECK: )mlir", context)) // CHECK: )mlir", context), std::forward<ConfigsT>(configs)...)
// CHECK: struct GeneratedPDLLPattern1 : ::mlir::PDLPatternModule { // CHECK: struct GeneratedPDLLPattern1 : ::mlir::PDLPatternModule {
@ -25,13 +26,13 @@
// CHECK: R"mlir( // CHECK: R"mlir(
// CHECK: pdl.pattern // CHECK: pdl.pattern
// CHECK: operation "test.op3" // CHECK: operation "test.op3"
// CHECK: )mlir", context)) // CHECK: )mlir", context), std::forward<ConfigsT>(configs)...)
// CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns) { // CHECK: static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns(::mlir::RewritePatternSet &patterns, ConfigsT &&...configs) {
// CHECK-NEXT: patterns.add<GeneratedPDLLPattern0>(patterns.getContext()); // CHECK-NEXT: patterns.add<GeneratedPDLLPattern0>(patterns.getContext(), configs...);
// CHECK-NEXT: patterns.add<NamedPattern>(patterns.getContext()); // CHECK-NEXT: patterns.add<NamedPattern>(patterns.getContext(), configs...);
// CHECK-NEXT: patterns.add<GeneratedPDLLPattern1>(patterns.getContext()); // CHECK-NEXT: patterns.add<GeneratedPDLLPattern1>(patterns.getContext(), configs...);
// CHECK-NEXT: patterns.add<GeneratedPDLLPattern2>(patterns.getContext()); // CHECK-NEXT: patterns.add<GeneratedPDLLPattern2>(patterns.getContext(), configs...);
// CHECK-NEXT: } // CHECK-NEXT: }
Pattern => erase op<test.op>; Pattern => erase op<test.op>;

View File

@ -76,6 +76,7 @@ void registerTestDataLayoutQuery();
void registerTestDeadCodeAnalysisPass(); void registerTestDeadCodeAnalysisPass();
void registerTestDecomposeCallGraphTypes(); void registerTestDecomposeCallGraphTypes();
void registerTestDiagnosticsPass(); void registerTestDiagnosticsPass();
void registerTestDialectConversionPasses();
void registerTestDominancePass(); void registerTestDominancePass();
void registerTestDynamicPipelinePass(); void registerTestDynamicPipelinePass();
void registerTestExpandMathPass(); void registerTestExpandMathPass();
@ -170,6 +171,7 @@ void registerTestPasses() {
mlir::test::registerTestConstantFold(); mlir::test::registerTestConstantFold();
mlir::test::registerTestControlFlowSink(); mlir::test::registerTestControlFlowSink();
mlir::test::registerTestDiagnosticsPass(); mlir::test::registerTestDiagnosticsPass();
mlir::test::registerTestDialectConversionPasses();
#if MLIR_CUDA_CONVERSIONS_ENABLED #if MLIR_CUDA_CONVERSIONS_ENABLED
mlir::test::registerTestGpuSerializeToCubinPass(); mlir::test::registerTestGpuSerializeToCubinPass();
#endif #endif