Pre Merge pull request !49933 from lanzhineng/notwait
Commit d6c4232ed9
@@ -141,11 +141,29 @@ void AnalysisSchedule::SetNextReady() {
    MS_EXCEPTION_IF_NULL(item);
    return item->HasResult();
  });
  if (it == schedule_list_.end()) {
    if (IntToSize(infer_thread_count_.load()) >= schedule_list_.size()) {
      MS_LOG(DEBUG) << "There is some task to be added. Please wait.";
    while (it == schedule_list_.end()) {
      if (IntToSize(infer_thread_count_.load()) > schedule_list_.size()) {
        MS_LOG(DEBUG) << "There is some task to be added. Please wait. "
                      << " infer_count: " << infer_thread_count_.load() << " schedule: " << schedule_list_.size();
        return;
      }

      (void)std::for_each(schedule_list_.begin(), schedule_list_.end(),
                          [](const auto &item) { MS_LOG(DEBUG) << "Leave infer thread: " << item->thread_id(); });
      if (enable_waiting_branch_eval()) {
        // Try to set one of possible result.
        auto possible_it = std::find_if(schedule_list_.cbegin(), schedule_list_.cend(), [](const auto &item) {
          MS_EXCEPTION_IF_NULL(item);
          return item->SetPossibleResult();
        });
        if (possible_it != schedule_list_.end()) {
          MS_LOG(DEBUG) << "Try to set one branch result from the other branch. " << (*possible_it)->thread_id()
                        << " result: " << (*possible_it)->HasResult();
          it = possible_it;
          break;
        }
      }

      // Enter endless loop if there is not ready result.
      (void)activate_threads_.insert(schedule_list_.front()->thread_id());
      // Let the first thread to trigger endless loop exception.
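Read as a whole, the hunk above makes SetNextReady retry instead of giving up after one check: while more inference threads than queued tasks are still on the way it simply waits, and, when waiting for branch evaluation is enabled, it lets a queued task adopt a tentative result from its sibling switch branch (SetPossibleResult) before falling through to the endless-loop error path. The following is a minimal standalone sketch of that control flow, using hypothetical toy names rather than the real MindSpore types:

// sketch_scheduler.cc -- toy illustration only; all names here are hypothetical.
#include <algorithm>
#include <deque>
#include <iostream>
#include <memory>

struct ToyTask {
  int id = 0;
  bool has_result = false;
  bool sibling_finished = false;  // stands in for "the other switch branch already has a value"
  bool HasResult() const { return has_result; }
  bool SetPossibleResult() {
    if (!has_result && sibling_finished) {
      has_result = true;  // adopt the sibling's tentative result
      return true;
    }
    return false;
  }
};

int main() {
  std::deque<std::shared_ptr<ToyTask>> schedule_list;
  schedule_list.push_back(std::make_shared<ToyTask>(ToyTask{1, false, false}));
  schedule_list.push_back(std::make_shared<ToyTask>(ToyTask{2, false, true}));

  auto it = std::find_if(schedule_list.begin(), schedule_list.end(),
                         [](const auto &item) { return item->HasResult(); });
  while (it == schedule_list.end()) {
    // Nothing is ready: try to unblock one task with a tentative result first.
    auto possible_it = std::find_if(schedule_list.begin(), schedule_list.end(),
                                    [](const auto &item) { return item->SetPossibleResult(); });
    if (possible_it != schedule_list.end()) {
      it = possible_it;
      break;
    }
    std::cout << "no runnable task: this is where the endless-loop error would be raised\n";
    return 1;
  }
  std::cout << "next ready task: " << (*it)->id << "\n";
  return 0;
}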
@@ -155,6 +173,7 @@ void AnalysisSchedule::SetNextReady() {
    schedule_list_.pop_front();
    return;
  }

  auto async_task = *it;
  (void)activate_threads_.insert(async_task->thread_id());
  async_task->SetReady();
@@ -166,6 +185,7 @@ void AnalysisSchedule::SetNextReady() {
}

AbstractBasePtr AsyncAbstract::GetResult() {
  ClearPossibleResult();
  auto async_task = AsyncInferTask::MakeShared(shared_from_this());
  MS_LOG(DEBUG) << GetInferThread() << " is waiting for async: " << async_task.get();
  AnalysisSchedule::GetInstance().Add2Schedule(async_task);
@@ -173,6 +193,25 @@ AbstractBasePtr AsyncAbstract::GetResult() {
  MS_LOG(DEBUG) << GetInferThread() << " success to get async result: " << async_task.get() << " " << ret->ToString();
  return ret;
}
void AsyncAbstract::ClearPossibleResult() {
  std::lock_guard<std::mutex> lock(lock_);
  if (result_ != nullptr && result_->isa<AsyncAbstractFuncAtom>()) {
    result_ = nullptr;
  }
}

bool AsyncAbstract::SetPossibleResult() {
  std::lock_guard<std::mutex> lock(lock_);
  if (not_copy_from_other_ && switchAbstract_ != nullptr && switchAbstract_->HasResult()) {
    result_ = switchAbstract_->TryGetResult();
    if (NeedWaitForBranches(result_)) {
      result_ = AsyncAbstractFuncAtom::MakeShared(shared_from_this(), std::vector<size_t>{0});
    }
    not_copy_from_other_ = false;
    return true;
  }
  return false;
}

namespace {
AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const std::vector<std::size_t> &index,
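In the hunk above, SetPossibleResult copies the linked switch branch's value as a tentative result (substituting an AsyncAbstractFuncAtom placeholder when that value still depends on other branches), and ClearPossibleResult discards such a placeholder before GetResult blocks for the real value. Below is a small self-contained sketch of that tentative-value lifecycle; the names (ToyAsyncValue and friends) are invented stand-ins for the real classes, and the placeholder wrapping is omitted:

// sketch_possible_result.cc -- toy illustration only; names are hypothetical.
#include <iostream>
#include <memory>
#include <mutex>
#include <optional>
#include <string>

class ToyAsyncValue {
 public:
  explicit ToyAsyncValue(std::shared_ptr<ToyAsyncValue> sibling = nullptr) : sibling_(std::move(sibling)) {}

  void SetReal(const std::string &v) {
    std::lock_guard<std::mutex> lock(lock_);
    value_ = v;
    tentative_ = false;
  }
  bool HasValue() {
    std::lock_guard<std::mutex> lock(lock_);
    return value_.has_value();
  }
  std::string Peek() {
    std::lock_guard<std::mutex> lock(lock_);
    return value_.value_or("<empty>");
  }
  // Like SetPossibleResult(): copy the sibling's value once and mark it tentative.
  bool SetPossible() {
    if (sibling_ == nullptr || copied_ || !sibling_->HasValue()) {
      return false;
    }
    const std::string borrowed = sibling_->Peek();
    std::lock_guard<std::mutex> lock(lock_);
    value_ = borrowed;
    tentative_ = true;
    copied_ = true;
    return true;
  }
  // Like ClearPossibleResult(): drop only a tentative copy, keep a real result.
  void ClearPossible() {
    std::lock_guard<std::mutex> lock(lock_);
    if (tentative_) {
      value_.reset();
      tentative_ = false;
    }
  }

 private:
  std::mutex lock_;
  std::optional<std::string> value_;
  bool tentative_ = false;
  bool copied_ = false;
  std::shared_ptr<ToyAsyncValue> sibling_;
};

int main() {
  auto true_branch = std::make_shared<ToyAsyncValue>();
  ToyAsyncValue false_branch(true_branch);  // linked to its sibling, like switchAbstract_
  true_branch->SetReal("abstract of the finished branch");
  std::cout << "borrowed: " << false_branch.SetPossible() << ", value: " << false_branch.Peek() << "\n";
  false_branch.ClearPossible();  // GetResult() clears a tentative copy before blocking for the real one
  std::cout << "has value after clear: " << false_branch.HasValue() << "\n";
  return 0;
}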
@@ -202,6 +241,22 @@ AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const
                    << abs->ToString();
}
}  // namespace
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(abstract);
  if (abstract->isa<AbstractFunction>()) {
    return true;
  }
  if (abstract->isa<AbstractSequence>()) {
    auto seq = abstract->cast_ptr<AbstractSequence>();
    MS_EXCEPTION_IF_NULL(seq);
    auto elements = seq->elements();
    if (std::any_of(elements.begin(), elements.end(),
                    [](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
      return true;
    }
  }
  return false;
}

AbstractFunctionPtr AsyncAbstractFuncAtom::GetUnique() {
  if (resolved_ != nullptr) {
@@ -301,5 +356,10 @@ std::string ArgsToString(const AbstractBasePtrList &args_abs_list) {
  }
  return buffer.str();
}
bool enable_waiting_branch_eval() {
  static std::string ms_env = common::GetEnv("MS_DEV_NOT_WAIT_BRANCH_EVAL");
  static bool enable_waiting_branch_eval_ = ms_env != "1";
  return enable_waiting_branch_eval_;
}
}  // namespace abstract
}  // namespace mindspore
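enable_waiting_branch_eval() reads MS_DEV_NOT_WAIT_BRANCH_EVAL once: setting the variable to "1" switches the new waiting behaviour off, any other value (or leaving it unset) keeps it on, and the cached decision is later used to choose between the blocking GetResult() and the non-blocking TryGetResult(). A minimal sketch of that env-flag pattern outside MindSpore, with a hypothetical helper name:

// sketch_env_flag.cc -- toy illustration only; the helper name is hypothetical.
#include <cstdlib>
#include <iostream>
#include <string>

bool waiting_enabled_sketch() {
  // Read the environment variable once and cache the decision, as the patch does.
  static const std::string value = [] {
    const char *raw = std::getenv("MS_DEV_NOT_WAIT_BRANCH_EVAL");
    return std::string(raw == nullptr ? "" : raw);
  }();
  static const bool enabled = (value != "1");
  return enabled;
}

int main() {
  if (waiting_enabled_sketch()) {
    std::cout << "would block on GetResult() until the branch finishes\n";
  } else {
    std::cout << "would call TryGetResult() and continue without waiting\n";
  }
  return 0;
}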
@@ -40,6 +40,7 @@ namespace abstract {

class AsyncInferTask;
class AsyncAbstract;
class AsyncAbstractFuncAtom;
using AsyncInferTaskPtr = std::shared_ptr<AsyncInferTask>;
using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>;
class AnalysisSchedule {
@@ -212,7 +213,7 @@ class NormalCache {

class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
 public:
  AsyncAbstract() = default;
  explicit AsyncAbstract(std::shared_ptr<AsyncAbstract> switchAbstract = nullptr) : switchAbstract_(switchAbstract) {}
  ~AsyncAbstract() = default;
  AbstractBasePtr GetResult();
  AbstractBasePtr TryGetResult() {
@@ -224,10 +225,12 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
    return result_ != nullptr;
  }
  void set_result(const AbstractBasePtr &result) {
    MS_EXCEPTION_IF_NULL(result);
    std::lock_guard<std::mutex> lock(lock_);
    result_ = result;
  }

  void ClearPossibleResult();

  std::string ToString() {
    std::ostringstream buffer;
    std::lock_guard<std::mutex> lock(lock_);
@@ -235,9 +238,13 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
    return buffer.str();
  }

  bool SetPossibleResult();

 private:
  std::mutex lock_;
  AbstractBasePtr result_{nullptr};
  bool not_copy_from_other_{true};
  std::shared_ptr<AsyncAbstract> switchAbstract_;
};

// Wrap AsyncAbstract, so it can work with Join method of AbstractFunction.
@@ -320,6 +327,7 @@ class AsyncInferTask {
  }

  bool HasResult() { return abstract_ptr_->HasResult(); }
  bool SetPossibleResult() { return abstract_ptr_->SetPossibleResult(); }
  int ready() {
    std::lock_guard<std::mutex> lock(lock_);
    return SizeToInt(ready_);
@@ -454,6 +462,8 @@ class AnalysisResultCacheMgr {
};

std::string ArgsToString(const AbstractBasePtrList &args_abs_list);
bool enable_waiting_branch_eval();
bool NeedWaitForBranches(const AbstractBasePtr &abstract);

inline std::string GetInferThread() { return std::string(" INFER:") + AnalysisSchedule::thread_id() + ":"; }
}  // namespace abstract
@@ -44,7 +44,9 @@ std::atomic<size_t> function_call_depth;
std::atomic<size_t> stack_frame_depth;

void ResetFunctionCallDepth() { function_call_depth = 0; }

void IncreaseFunctionCallDepth() { ++function_call_depth; }

void DecreaseFunctionCallDepth() {
  if (function_call_depth == 0) {
    MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it.";
@@ -55,31 +57,19 @@ void DecreaseFunctionCallDepth() {
size_t FunctionCallDepth() { return function_call_depth; }

void ResetStackFrameDepth() { stack_frame_depth = 0; }

void IncreaseStackFrameDepth() { ++stack_frame_depth; }

void DecreaseStackFrameDepth() {
  if (stack_frame_depth == 0) {
    MS_LOG(EXCEPTION) << "Current stack frame depth is already 0, can not decrease it.";
  }
  stack_frame_depth--;
}

size_t StackFrameDepth() { return stack_frame_depth; }

namespace {
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(abstract);
  if (abstract->isa<AbstractFunction>()) {
    return true;
  }
  if (abstract->isa<AbstractSequence>()) {
    auto elements = abstract->cast_ptr<AbstractSequence>()->elements();
    if (std::any_of(elements.begin(), elements.end(),
                    [](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
      return true;
    }
  }
  return false;
}

void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
                   std::string thread_id, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main,
                   AsyncInferTaskPtr async_task, trace::TraceGraphEvalStack graph_evals,
@@ -100,11 +90,11 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
    // Acquire GIL for eval to callback python.
    EvalResultPtr result;
    {
      MS_LOG(DEBUG) << std::this_thread::get_id() << " begin.";
      MS_LOG(DEBUG) << eval->ToString() << "_" << AnalysisSchedule::thread_id() << " begin.";
      py::gil_scoped_acquire py_guard;
      result = eval->Run(engine, args_conf_list, out_conf);
    }
    MS_LOG(DEBUG) << std::this_thread::get_id() << " end.";
    MS_LOG(DEBUG) << eval->ToString() << "_" << AnalysisSchedule::thread_id() << " end.";
    MS_EXCEPTION_IF_NULL(result);
    MS_EXCEPTION_IF_NULL(result->abstract());

@@ -112,6 +102,7 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
    AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract());
    // Broaden the result of switch(c,t,f)()
    auto broaden_abstract = result->abstract()->Broaden();

    // Notify the thread of waiting for branch value and the main thread to continue.
    async_result_branch->set_result(broaden_abstract);
    async_result_main->set_result(broaden_abstract);
@@ -119,7 +110,8 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
                  << " asyncResult address = " << async_result_branch.get()
                  << " value = " << async_result_branch->TryGetResult()->ToString();
  } catch (const std::exception &ex) {
    MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
    MS_LOG(INFO) << GetInferThread() << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString()
                 << " threw exception: " << ex.what();
    AnalysisSchedule::GetInstance().HandleException(ex);
  }
  trace::ClearTraceStack();
@@ -179,17 +171,34 @@ void BuildPossibleSpecs(const AbstractBasePtr &first_result,
  std::size_t len = branch_async_abstract_list.size();

  for (size_t i = 0; i < len; ++i) {
    auto result = branch_async_abstract_list[i]->TryGetResult();
    AbstractBasePtr result;
    if (enable_waiting_branch_eval()) {
      result = branch_async_abstract_list[i]->GetResult();
    } else {
      result = branch_async_abstract_list[i]->TryGetResult();
    }

    if (result) {
      out_specs->push_back(result);
      if (result->isa<AsyncAbstractFuncAtom>()) {
        branch_async_abstract_list[i]->ClearPossibleResult();
        pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
        MS_LOG(DEBUG) << "Pending add: " << branch_async_abstract_list[i].get() << "_"
                      << branch_async_abstract_list[i]->ToString();
      } else {
        out_specs->push_back(result);
      }
    } else {
      pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
      MS_LOG(DEBUG) << "Pending add: " << branch_async_abstract_list[i].get() << "_"
                    << branch_async_abstract_list[i]->ToString();
    }
  }

  if (first_result->isa<AbstractFunction>()) {
    for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
      auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], std::vector<size_t>{0});
      out_specs->push_back(async_func);
      MS_LOG(DEBUG) << "out_specs add: " << async_func.get() << "_" << async_func->ToString();
    }
  } else if (first_result->isa<AbstractSequence>()) {
    const auto &new_first_result =
@@ -203,7 +212,6 @@ void BuildPossibleSpecs(const AbstractBasePtr &first_result,
    MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result";
  }
}

void CheckInterpretedObject(const AbstractBasePtr &abs) {
  static const auto support_fallback = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
  static const auto use_fallback = (support_fallback != "0");
@@ -936,6 +944,7 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstr
EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
  MS_EXCEPTION_IF_NULL(func);
  MS_LOG(DEBUG) << "GetEvaluatorFor: " << func->ToString() << " tracking_id: " << func->tracking_id();

  if (func->isa<PrimitiveAbstractClosure>()) {
    return _GetEvaluatorFor(std::static_pointer_cast<PrimitiveAbstractClosure>(func));
  }
@@ -963,6 +972,7 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
  if (func->isa<PartialAbstractClosure>()) {
    return _GetEvaluatorFor(std::static_pointer_cast<PartialAbstractClosure>(func));
  }

  MS_LOG(EXCEPTION) << "Cannot GetEvaluator from " << func->type_name();
}

@@ -1161,7 +1171,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
  MS_EXCEPTION_IF_NULL(out_conf);
  MS_EXCEPTION_IF_NULL(out_conf->node());
  // Release GIL for C++
  MS_LOG(DEBUG) << std::this_thread::get_id() << " begin.";
  MS_LOG(DEBUG) << out_conf->func_graph()->ToString() << "_" << std::this_thread::get_id() << " begin.";
  py::gil_scoped_release infer_gil_release;

  // Only one thread to run
@@ -1187,7 +1197,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
    static std::atomic<int> id_count{0};
    std::string thread_id = AnalysisSchedule::thread_id() + "." + std::to_string(id_count.fetch_add(1));
    MS_EXCEPTION_IF_NULL(evaluator);
    AsyncAbstractPtr async_result_branch = std::make_shared<AsyncAbstract>();
    AsyncAbstractPtr async_result_branch = std::make_shared<AsyncAbstract>(async_result_main);
    // Control the order to run.
    AsyncAbstractPtr control_run_order = std::make_shared<AsyncAbstract>();
    control_run_order->set_result(std::make_shared<AbstractScalar>(1));
@@ -1216,11 +1226,19 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
  AbstractBasePtrList out_specs;
  size_t len = evaluators.size();
  if (NeedWaitForBranches(first_result)) {
    MS_LOG(DEBUG) << GetInferThread() << " BuildPossibleSpecs.";
    BuildPossibleSpecs(first_result, async_result_branches, &out_specs);
  } else {
    for (size_t i = 0; i < len; ++i) {
      // Not wait to get the result of branch.
      auto result = async_result_branches[i]->TryGetResult();
      AbstractBasePtr result;
      if (enable_waiting_branch_eval()) {
        // wait to get the result of branch.
        result = async_result_branches[i]->GetResult();
      } else {
        // Not wait to get the result of branch.
        result = async_result_branches[i]->TryGetResult();
      }

      if (result) {
        MS_LOG(DEBUG) << "#" << i << ": " << GetInferThread() << " async get " << evaluators[i]->ToString()
                      << ", result: " << result->ToString() << ", args: " << args_conf_list;
@@ -1228,13 +1246,13 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
      }
    }
  }

  MS_LOG(DEBUG) << std::this_thread::get_id() << " finish.";
  MS_LOG(DEBUG) << GetInferThread() << " finish.";
  const auto &processed_result = ProcessEvalResults(out_specs, out_conf->node());
  if (processed_result != nullptr) {
    // This is the final switch()() value.
    AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, processed_result->abstract());
  }
  MS_LOG(DEBUG) << GetInferThread() << "join finish.";
  return processed_result;
}
@@ -147,6 +147,8 @@ def test_branch_value_compatible():

    try:
        forward_net(x, y, i)
    except RuntimeError as e:
        assert 'limit' in str(e)
    except ValueError as e:
        assert 'Join Failed' in str(e)