[mlir] Add getNumThreads to MLIRContext

Querying threads directly from the thread pool fails if there is no thread pool or if multithreading is not enabled. Returns 1 by default.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D116259
This commit is contained in:
Mogball 2021-12-24 01:41:21 +00:00
parent a2baf634a1
commit 41a64338cc
4 changed files with 19 additions and 2 deletions

View File

@ -147,6 +147,13 @@ public:
/// this call in this case.
void setThreadPool(llvm::ThreadPool &pool);
/// Return the number of threads used by the thread pool in this context. The
/// number of computed hardware threads can change over the lifetime of a
/// process based on affinity changes, so users should use the number of
/// threads actually in the thread pool for dispatching work. Returns 1 if
/// multithreading is disabled.
unsigned getNumThreads();
/// Return the thread pool used by this context. This method requires that
/// multithreading be enabled within the context, and should generally not be
/// used directly. Users should instead prefer the threading utilities within

View File

@ -518,6 +518,16 @@ void MLIRContext::setThreadPool(llvm::ThreadPool &pool) {
enableMultithreading();
}
unsigned MLIRContext::getNumThreads() {
if (isMultithreadingEnabled()) {
assert(impl->threadPool &&
"multi-threading is enabled but threadpool not set");
return impl->threadPool->getThreadCount();
}
// No multithreading or active thread pool. Return 1 thread.
return 1;
}
llvm::ThreadPool &MLIRContext::getThreadPool() {
assert(isMultithreadingEnabled() &&
"expected multi-threading to be enabled within the context");

View File

@ -679,8 +679,7 @@ InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
// Note: The number of pass managers here needs to remain constant
// to prevent issues with pass instrumentations that rely on having the same
// pass manager for the main thread.
llvm::ThreadPool &threadPool = ctx->getThreadPool();
size_t numThreads = threadPool.getThreadCount();
size_t numThreads = ctx->getNumThreads();
if (opPipelines.size() < numThreads) {
// Reserve before resizing so that we can use a reference to the first
// element.

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s
// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
// RUN: mlir-opt %s -inline='default-pipeline=''' -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC
// RUN: mlir-opt %s -inline | FileCheck %s --check-prefix INLINE_SIMPLIFY