179 lines
7.0 KiB
C++
179 lines
7.0 KiB
C++
//===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a pass to deduce minimal version/extension/capability
|
|
// requirements for a spirv::ModuleOp.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "PassDetail.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Visitors.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// Pass to deduce minimal version/extension/capability requirements for a
|
|
/// spirv::ModuleOp.
|
|
class UpdateVCEPass final : public SPIRVUpdateVCEBase<UpdateVCEPass> {
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
/// Checks that `candidates` extension requirements are possible to be satisfied
|
|
/// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
|
|
/// errors attaching to the given `op` on failures.
|
|
///
|
|
/// `candidates` is a vector of vector for extension requirements following
|
|
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
|
|
/// convention.
|
|
static LogicalResult checkAndUpdateExtensionRequirements(
|
|
Operation *op, const spirv::TargetEnv &targetEnv,
|
|
const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
|
|
llvm::SetVector<spirv::Extension> &deducedExtensions) {
|
|
for (const auto &ors : candidates) {
|
|
if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
|
|
deducedExtensions.insert(*chosen);
|
|
} else {
|
|
SmallVector<StringRef, 4> extStrings;
|
|
for (spirv::Extension ext : ors)
|
|
extStrings.push_back(spirv::stringifyExtension(ext));
|
|
|
|
return op->emitError("'")
|
|
<< op->getName() << "' requires at least one extension in ["
|
|
<< llvm::join(extStrings, ", ")
|
|
<< "] but none allowed in target environment";
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Checks that `candidates`capability requirements are possible to be satisfied
|
|
/// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
|
|
/// errors attaching to the given `op` on failures.
|
|
///
|
|
/// `candidates` is a vector of vector for capability requirements following
|
|
/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
|
|
/// convention.
|
|
static LogicalResult checkAndUpdateCapabilityRequirements(
|
|
Operation *op, const spirv::TargetEnv &targetEnv,
|
|
const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
|
|
llvm::SetVector<spirv::Capability> &deducedCapabilities) {
|
|
for (const auto &ors : candidates) {
|
|
if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
|
|
deducedCapabilities.insert(*chosen);
|
|
} else {
|
|
SmallVector<StringRef, 4> capStrings;
|
|
for (spirv::Capability cap : ors)
|
|
capStrings.push_back(spirv::stringifyCapability(cap));
|
|
|
|
return op->emitError("'")
|
|
<< op->getName() << "' requires at least one capability in ["
|
|
<< llvm::join(capStrings, ", ")
|
|
<< "] but none allowed in target environment";
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
void UpdateVCEPass::runOnOperation() {
|
|
spirv::ModuleOp module = getOperation();
|
|
|
|
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
|
|
if (!targetAttr) {
|
|
module.emitError("missing 'spv.target_env' attribute");
|
|
return signalPassFailure();
|
|
}
|
|
|
|
spirv::TargetEnv targetEnv(targetAttr);
|
|
spirv::Version allowedVersion = targetAttr.getVersion();
|
|
|
|
spirv::Version deducedVersion = spirv::Version::V_1_0;
|
|
llvm::SetVector<spirv::Extension> deducedExtensions;
|
|
llvm::SetVector<spirv::Capability> deducedCapabilities;
|
|
|
|
// Walk each SPIR-V op to deduce the minimal version/extension/capability
|
|
// requirements.
|
|
WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
|
|
// Op min version requirements
|
|
if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
|
|
deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
|
|
if (deducedVersion > allowedVersion) {
|
|
return op->emitError("'") << op->getName() << "' requires min version "
|
|
<< spirv::stringifyVersion(deducedVersion)
|
|
<< " but target environment allows up to "
|
|
<< spirv::stringifyVersion(allowedVersion);
|
|
}
|
|
}
|
|
|
|
// Op extension requirements
|
|
if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
|
|
if (failed(checkAndUpdateExtensionRequirements(
|
|
op, targetEnv, extensions.getExtensions(), deducedExtensions)))
|
|
return WalkResult::interrupt();
|
|
|
|
// Op capability requirements
|
|
if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
|
|
if (failed(checkAndUpdateCapabilityRequirements(
|
|
op, targetEnv, capabilities.getCapabilities(),
|
|
deducedCapabilities)))
|
|
return WalkResult::interrupt();
|
|
|
|
SmallVector<Type, 4> valueTypes;
|
|
valueTypes.append(op->operand_type_begin(), op->operand_type_end());
|
|
valueTypes.append(op->result_type_begin(), op->result_type_end());
|
|
|
|
// Special treatment for global variables, whose type requirements are
|
|
// conveyed by type attributes.
|
|
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
|
|
valueTypes.push_back(globalVar.type());
|
|
|
|
// Requirements from values' types
|
|
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
|
|
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
|
|
for (Type valueType : valueTypes) {
|
|
typeExtensions.clear();
|
|
valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
|
|
if (failed(checkAndUpdateExtensionRequirements(
|
|
op, targetEnv, typeExtensions, deducedExtensions)))
|
|
return WalkResult::interrupt();
|
|
|
|
typeCapabilities.clear();
|
|
valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
|
|
if (failed(checkAndUpdateCapabilityRequirements(
|
|
op, targetEnv, typeCapabilities, deducedCapabilities)))
|
|
return WalkResult::interrupt();
|
|
}
|
|
|
|
return WalkResult::advance();
|
|
});
|
|
|
|
if (walkResult.wasInterrupted())
|
|
return signalPassFailure();
|
|
|
|
// TODO: verify that the deduced version is consistent with
|
|
// SPIR-V ops' maximal version requirements.
|
|
|
|
auto triple = spirv::VerCapExtAttr::get(
|
|
deducedVersion, deducedCapabilities.getArrayRef(),
|
|
deducedExtensions.getArrayRef(), &getContext());
|
|
module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
|
|
}
|
|
|
|
std::unique_ptr<OperationPass<spirv::ModuleOp>>
|
|
mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
|
|
return std::make_unique<UpdateVCEPass>();
|
|
}
|