Pre Merge pull request !49933 from lanzhineng/notwait

This commit is contained in:
lanzhineng 2023-03-08 12:39:20 +00:00 committed by Gitee
commit d6c4232ed9
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 122 additions and 32 deletions

View File

@ -141,11 +141,29 @@ void AnalysisSchedule::SetNextReady() {
MS_EXCEPTION_IF_NULL(item);
return item->HasResult();
});
if (it == schedule_list_.end()) {
if (IntToSize(infer_thread_count_.load()) >= schedule_list_.size()) {
MS_LOG(DEBUG) << "There is some task to be added. Please wait.";
while (it == schedule_list_.end()) {
if (IntToSize(infer_thread_count_.load()) > schedule_list_.size()) {
MS_LOG(DEBUG) << "There is some task to be added. Please wait. "
<< " infer_count: " << infer_thread_count_.load() << " schedule: " << schedule_list_.size();
return;
}
(void)std::for_each(schedule_list_.begin(), schedule_list_.end(),
[](const auto &item) { MS_LOG(DEBUG) << "Leave infer thread: " << item->thread_id(); });
if (enable_waiting_branch_eval()) {
// Try to set one of possible result.
auto possible_it = std::find_if(schedule_list_.cbegin(), schedule_list_.cend(), [](const auto &item) {
MS_EXCEPTION_IF_NULL(item);
return item->SetPossibleResult();
});
if (possible_it != schedule_list_.end()) {
MS_LOG(DEBUG) << "Try to set one branch result from the other branch. " << (*possible_it)->thread_id()
<< " result: " << (*possible_it)->HasResult();
it = possible_it;
break;
}
}
// Enter endless loop if there is not ready result.
(void)activate_threads_.insert(schedule_list_.front()->thread_id());
// Let the first thread to trigger endless loop exception.
@ -155,6 +173,7 @@ void AnalysisSchedule::SetNextReady() {
schedule_list_.pop_front();
return;
}
auto async_task = *it;
(void)activate_threads_.insert(async_task->thread_id());
async_task->SetReady();
@ -166,6 +185,7 @@ void AnalysisSchedule::SetNextReady() {
}
AbstractBasePtr AsyncAbstract::GetResult() {
ClearPossibleResult();
auto async_task = AsyncInferTask::MakeShared(shared_from_this());
MS_LOG(DEBUG) << GetInferThread() << " is waiting for async: " << async_task.get();
AnalysisSchedule::GetInstance().Add2Schedule(async_task);
@ -173,6 +193,25 @@ AbstractBasePtr AsyncAbstract::GetResult() {
MS_LOG(DEBUG) << GetInferThread() << " success to get async result: " << async_task.get() << " " << ret->ToString();
return ret;
}
void AsyncAbstract::ClearPossibleResult() {
std::lock_guard<std::mutex> lock(lock_);
if (result_ != nullptr && result_->isa<AsyncAbstractFuncAtom>()) {
result_ = nullptr;
}
}
bool AsyncAbstract::SetPossibleResult() {
std::lock_guard<std::mutex> lock(lock_);
if (not_copy_from_other_ && switchAbstract_ != nullptr && switchAbstract_->HasResult()) {
result_ = switchAbstract_->TryGetResult();
if (NeedWaitForBranches(result_)) {
result_ = AsyncAbstractFuncAtom::MakeShared(shared_from_this(), std::vector<size_t>{0});
}
not_copy_from_other_ = false;
return true;
}
return false;
}
namespace {
AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const std::vector<std::size_t> &index,
@ -202,6 +241,22 @@ AbstractFunctionPtr GetAbstractFuncRecursively(const AbstractBasePtr &abs, const
<< abs->ToString();
}
} // namespace
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
MS_EXCEPTION_IF_NULL(abstract);
if (abstract->isa<AbstractFunction>()) {
return true;
}
if (abstract->isa<AbstractSequence>()) {
auto seq = abstract->cast_ptr<AbstractSequence>();
MS_EXCEPTION_IF_NULL(seq);
auto elements = seq->elements();
if (std::any_of(elements.begin(), elements.end(),
[](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
return true;
}
}
return false;
}
AbstractFunctionPtr AsyncAbstractFuncAtom::GetUnique() {
if (resolved_ != nullptr) {
@ -301,5 +356,10 @@ std::string ArgsToString(const AbstractBasePtrList &args_abs_list) {
}
return buffer.str();
}
// Report whether waiting for branch evaluation is enabled.
// The MS_DEV_NOT_WAIT_BRANCH_EVAL environment variable is read once, on the first call; waiting is
// enabled unless its value is exactly "1".
bool enable_waiting_branch_eval() {
  static const bool waiting_enabled = [] {
    const std::string not_wait_flag = common::GetEnv("MS_DEV_NOT_WAIT_BRANCH_EVAL");
    return not_wait_flag != "1";
  }();
  return waiting_enabled;
}
} // namespace abstract
} // namespace mindspore

View File

@ -40,6 +40,7 @@ namespace abstract {
class AsyncInferTask;
class AsyncAbstract;
class AsyncAbstractFuncAtom;
using AsyncInferTaskPtr = std::shared_ptr<AsyncInferTask>;
using AsyncAbstractPtr = std::shared_ptr<AsyncAbstract>;
class AnalysisSchedule {
@ -212,7 +213,7 @@ class NormalCache {
class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
public:
AsyncAbstract() = default;
explicit AsyncAbstract(std::shared_ptr<AsyncAbstract> switchAbstract = nullptr) : switchAbstract_(switchAbstract) {}
~AsyncAbstract() = default;
AbstractBasePtr GetResult();
AbstractBasePtr TryGetResult() {
@ -224,10 +225,12 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
return result_ != nullptr;
}
void set_result(const AbstractBasePtr &result) {
MS_EXCEPTION_IF_NULL(result);
std::lock_guard<std::mutex> lock(lock_);
result_ = result;
}
void ClearPossibleResult();
std::string ToString() {
std::ostringstream buffer;
std::lock_guard<std::mutex> lock(lock_);
@ -235,9 +238,13 @@ class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
return buffer.str();
}
bool SetPossibleResult();
private:
std::mutex lock_;
AbstractBasePtr result_{nullptr};
bool not_copy_from_other_{true};
std::shared_ptr<AsyncAbstract> switchAbstract_;
};
// Wrap AsyncAbstract, so it can work with Join method of AbstractFunction.
@ -320,6 +327,7 @@ class AsyncInferTask {
}
// Forwards to the wrapped abstract: returns whatever abstract_ptr_->HasResult() reports.
bool HasResult() { return abstract_ptr_->HasResult(); }
// Forwards to the wrapped abstract: returns whatever abstract_ptr_->SetPossibleResult() reports.
bool SetPossibleResult() { return abstract_ptr_->SetPossibleResult(); }
int ready() {
std::lock_guard<std::mutex> lock(lock_);
return SizeToInt(ready_);
@ -454,6 +462,8 @@ class AnalysisResultCacheMgr {
};
std::string ArgsToString(const AbstractBasePtrList &args_abs_list);
bool enable_waiting_branch_eval();
bool NeedWaitForBranches(const AbstractBasePtr &abstract);
// Build the log prefix for the current infer thread: " INFER:" + AnalysisSchedule::thread_id() + ":".
inline std::string GetInferThread() {
  std::string tag(" INFER:");
  tag += AnalysisSchedule::thread_id();
  tag += ":";
  return tag;
}
} // namespace abstract

View File

@ -44,7 +44,9 @@ std::atomic<size_t> function_call_depth;
std::atomic<size_t> stack_frame_depth;
// Set the global function call depth counter back to 0.
void ResetFunctionCallDepth() {
  function_call_depth.store(0);
}
// Add one to the global function call depth counter.
void IncreaseFunctionCallDepth() {
  (void)function_call_depth.fetch_add(1);
}
void DecreaseFunctionCallDepth() {
if (function_call_depth == 0) {
MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it.";
@ -55,31 +57,19 @@ void DecreaseFunctionCallDepth() {
// Return the current value of the global function call depth counter.
size_t FunctionCallDepth() {
  return function_call_depth.load();
}
// Set the global stack frame depth counter back to 0.
void ResetStackFrameDepth() {
  stack_frame_depth.store(0);
}
// Add one to the global stack frame depth counter.
void IncreaseStackFrameDepth() {
  (void)stack_frame_depth.fetch_add(1);
}
// Subtract one from the global stack frame depth counter.
// When the counter is already 0 the condition is reported through MS_LOG(EXCEPTION) first.
void DecreaseStackFrameDepth() {
  const bool already_zero = (stack_frame_depth == 0);
  if (already_zero) {
    MS_LOG(EXCEPTION) << "Current stack frame depth is already 0, can not decrease it.";
  }
  --stack_frame_depth;
}
// Return the current value of the global stack frame depth counter.
size_t StackFrameDepth() {
  return stack_frame_depth.load();
}
namespace {
bool NeedWaitForBranches(const AbstractBasePtr &abstract) {
MS_EXCEPTION_IF_NULL(abstract);
if (abstract->isa<AbstractFunction>()) {
return true;
}
if (abstract->isa<AbstractSequence>()) {
auto elements = abstract->cast_ptr<AbstractSequence>()->elements();
if (std::any_of(elements.begin(), elements.end(),
[](const AbstractBasePtr &item) { return NeedWaitForBranches(item); })) {
return true;
}
}
return false;
}
void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList args_conf_list, AnfNodeConfigPtr out_conf,
std::string thread_id, AsyncAbstractPtr async_result_branch, AsyncAbstractPtr async_result_main,
AsyncInferTaskPtr async_task, trace::TraceGraphEvalStack graph_evals,
@ -100,11 +90,11 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
// Acquire GIL for eval to callback python.
EvalResultPtr result;
{
MS_LOG(DEBUG) << std::this_thread::get_id() << " begin.";
MS_LOG(DEBUG) << eval->ToString() << "_" << AnalysisSchedule::thread_id() << " begin.";
py::gil_scoped_acquire py_guard;
result = eval->Run(engine, args_conf_list, out_conf);
}
MS_LOG(DEBUG) << std::this_thread::get_id() << " end.";
MS_LOG(DEBUG) << eval->ToString() << "_" << AnalysisSchedule::thread_id() << " end.";
MS_EXCEPTION_IF_NULL(result);
MS_EXCEPTION_IF_NULL(result->abstract());
@ -112,6 +102,7 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
AnalysisResultCacheMgr::GetInstance().CheckSwitchValueJoinable(out_conf, result->abstract());
// Broaden the result of switch(c,t,f)()
auto broaden_abstract = result->abstract()->Broaden();
// Notify the thread of waiting for branch value and the main thread to continue.
async_result_branch->set_result(broaden_abstract);
async_result_main->set_result(broaden_abstract);
@ -119,7 +110,8 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
<< " asyncResult address = " << async_result_branch.get()
<< " value = " << async_result_branch->TryGetResult()->ToString();
} catch (const std::exception &ex) {
MS_LOG(INFO) << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString() << " threw exception.";
MS_LOG(INFO) << GetInferThread() << "Eval node: " << out_conf->node()->ToString() << " " << eval->ToString()
<< " threw exception: " << ex.what();
AnalysisSchedule::GetInstance().HandleException(ex);
}
trace::ClearTraceStack();
@ -179,17 +171,34 @@ void BuildPossibleSpecs(const AbstractBasePtr &first_result,
std::size_t len = branch_async_abstract_list.size();
for (size_t i = 0; i < len; ++i) {
auto result = branch_async_abstract_list[i]->TryGetResult();
AbstractBasePtr result;
if (enable_waiting_branch_eval()) {
result = branch_async_abstract_list[i]->GetResult();
} else {
result = branch_async_abstract_list[i]->TryGetResult();
}
if (result) {
out_specs->push_back(result);
if (result->isa<AsyncAbstractFuncAtom>()) {
branch_async_abstract_list[i]->ClearPossibleResult();
pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
MS_LOG(DEBUG) << "Pending add: " << branch_async_abstract_list[i].get() << "_"
<< branch_async_abstract_list[i]->ToString();
} else {
out_specs->push_back(result);
}
} else {
pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
MS_LOG(DEBUG) << "Pending add: " << branch_async_abstract_list[i].get() << "_"
<< branch_async_abstract_list[i]->ToString();
}
}
if (first_result->isa<AbstractFunction>()) {
for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], std::vector<size_t>{0});
out_specs->push_back(async_func);
MS_LOG(DEBUG) << "out_specs add: " << async_func.get() << "_" << async_func->ToString();
}
} else if (first_result->isa<AbstractSequence>()) {
const auto &new_first_result =
@ -203,7 +212,6 @@ void BuildPossibleSpecs(const AbstractBasePtr &first_result,
MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result";
}
}
void CheckInterpretedObject(const AbstractBasePtr &abs) {
static const auto support_fallback = common::GetEnv("MS_DEV_ENABLE_FALLBACK");
static const auto use_fallback = (support_fallback != "0");
@ -936,6 +944,7 @@ EvaluatorPtr AnalysisEngine::_GetEvaluatorFor(const std::shared_ptr<PartialAbstr
EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
MS_EXCEPTION_IF_NULL(func);
MS_LOG(DEBUG) << "GetEvaluatorFor: " << func->ToString() << " tracking_id: " << func->tracking_id();
if (func->isa<PrimitiveAbstractClosure>()) {
return _GetEvaluatorFor(std::static_pointer_cast<PrimitiveAbstractClosure>(func));
}
@ -963,6 +972,7 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) {
if (func->isa<PartialAbstractClosure>()) {
return _GetEvaluatorFor(std::static_pointer_cast<PartialAbstractClosure>(func));
}
MS_LOG(EXCEPTION) << "Cannot GetEvaluator from " << func->type_name();
}
@ -1161,7 +1171,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
// Release GIL for C++
MS_LOG(DEBUG) << std::this_thread::get_id() << " begin.";
MS_LOG(DEBUG) << out_conf->func_graph()->ToString() << "_" << std::this_thread::get_id() << " begin.";
py::gil_scoped_release infer_gil_release;
// Only one thread to run
@ -1187,7 +1197,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
static std::atomic<int> id_count{0};
std::string thread_id = AnalysisSchedule::thread_id() + "." + std::to_string(id_count.fetch_add(1));
MS_EXCEPTION_IF_NULL(evaluator);
AsyncAbstractPtr async_result_branch = std::make_shared<AsyncAbstract>();
AsyncAbstractPtr async_result_branch = std::make_shared<AsyncAbstract>(async_result_main);
// Control the order to run.
AsyncAbstractPtr control_run_order = std::make_shared<AsyncAbstract>();
control_run_order->set_result(std::make_shared<AbstractScalar>(1));
@ -1216,11 +1226,19 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
AbstractBasePtrList out_specs;
size_t len = evaluators.size();
if (NeedWaitForBranches(first_result)) {
MS_LOG(DEBUG) << GetInferThread() << " BuildPossibleSpecs.";
BuildPossibleSpecs(first_result, async_result_branches, &out_specs);
} else {
for (size_t i = 0; i < len; ++i) {
// Not wait to get the result of branch.
auto result = async_result_branches[i]->TryGetResult();
AbstractBasePtr result;
if (enable_waiting_branch_eval()) {
// wait to get the result of branch.
result = async_result_branches[i]->GetResult();
} else {
// Not wait to get the result of branch.
result = async_result_branches[i]->TryGetResult();
}
if (result) {
MS_LOG(DEBUG) << "#" << i << ": " << GetInferThread() << " async get " << evaluators[i]->ToString()
<< ", result: " << result->ToString() << ", args: " << args_conf_list;
@ -1228,13 +1246,13 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
}
}
}
MS_LOG(DEBUG) << std::this_thread::get_id() << " finish.";
MS_LOG(DEBUG) << GetInferThread() << " finish.";
const auto &processed_result = ProcessEvalResults(out_specs, out_conf->node());
if (processed_result != nullptr) {
// This is the final switch()() value.
AnalysisResultCacheMgr::GetInstance().SetSwitchValue(out_conf, processed_result->abstract());
}
MS_LOG(DEBUG) << GetInferThread() << "join finish.";
return processed_result;
}

View File

@ -147,6 +147,8 @@ def test_branch_value_compatible():
try:
forward_net(x, y, i)
except RuntimeError as e:
assert 'limit' in str(e)
except ValueError as e:
assert 'Join Failed' in str(e)