!9880 change cast format test

From: @lianliguang
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
This commit is contained in:
mindspore-ci-bot 2020-12-18 09:49:45 +08:00 committed by Gitee
commit d988e13fb5
3 changed files with 80 additions and 22 deletions

View File

@ -219,6 +219,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
} else {
data_layout_pm->AddPass(std::make_shared<MergeCastToOp>());
data_layout_pm->AddPass(std::make_shared<ConvertCastFormat>());
data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>());
data_layout_pm->AddPass(std::make_shared<InsertTransOp>());
}
data_layout_pm->AddPass(std::make_shared<GetitemTuple>());

View File

@ -16,8 +16,11 @@
#include "backend/optimizer/ascend/format_type/convert_cast_format.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
const BaseRef ConvertCastFormat::DefinePattern() const {
@ -26,8 +29,8 @@ const BaseRef ConvertCastFormat::DefinePattern() const {
return VectorRef({X, Xs});
}
const AnfNodePtr ConvertCastFormat::Process(const mindspore::FuncGraphPtr &, const mindspore::AnfNodePtr &node,
const mindspore::EquivPtr &) const {
const AnfNodePtr ConvertCastFormat::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node)) {
return nullptr;
}
@ -44,26 +47,77 @@ const AnfNodePtr ConvertCastFormat::Process(const mindspore::FuncGraphPtr &, con
continue;
}
auto cast_node = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cast_node);
auto input_node_name = AnfAlgo::GetCNodeName(cast_node);
if (input_node_name != prim::kPrimCast->name()) {
continue;
}
auto format = AnfAlgo::GetInputFormat(node, input_index);
auto cast_input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_node, 0), 0).first;
auto cast_input_format = AnfAlgo::GetOutputFormat(cast_input_node, 0);
// change cast to default that can be more faster when it cast other hw format
if (cast_input_format != format) {
if (cast_input_format == kOpFormat_DEFAULT || format == kOpFormat_DEFAULT) {
auto info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(
AnfAlgo::GetSelectKernelBuildInfo(cast_node));
info_builder->SetInputsFormat({kOpFormat_DEFAULT});
info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get());
}
}
ChangeCastFormat(cast_node, func_graph);
}
return nullptr;
}
void ConvertCastFormat::SetCastFormat(const CNodePtr &cast_node, const string &format) const {
auto info_builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(cast_node));
info_builder->SetInputsFormat({format});
info_builder->SetOutputsFormat({format});
AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get());
}
void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const {
MS_EXCEPTION_IF_NULL(cast_node);
auto input_node_name = AnfAlgo::GetCNodeName(cast_node);
if (input_node_name != prim::kPrimCast->name()) {
return;
}
if (AnfAlgo::HasNodeAttr(kAttrVisited, cast_node) && AnfAlgo::GetNodeAttr<bool>(cast_node, kAttrVisited)) {
return;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast_node);
auto used_cast_node_list = GetRealNodeUsedList(func_graph, cast_node);
MS_EXCEPTION_IF_NULL(used_cast_node_list);
std::unordered_map<string, size_t> format_counter;
for (const auto &node_info : *used_cast_node_list) {
MS_EXCEPTION_IF_NULL(node_info.first);
auto cast_out_node = node_info.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cast_out_node);
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cast_out_node); ++index) {
if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first !=
cast_node) {
continue;
}
auto format = AnfAlgo::GetInputFormat(cast_out_node, index);
auto it = format_counter.find(format);
if (it == format_counter.end()) {
format_counter[format] = 1;
} else {
it->second++;
}
}
}
auto cast_input_format = AnfAlgo::GetPrevNodeOutputFormat(cast_node, 0);
string convert_format = kOpFormat_DEFAULT;
if (cast_input_format == kOpFormat_DEFAULT) {
SetCastFormat(cast_node, convert_format);
return;
}
if (format_counter.size() == 1 && format_counter.begin()->first == kOpFormat_DEFAULT) {
SetCastFormat(cast_node, convert_format);
return;
}
auto it = format_counter.find(cast_input_format);
if (it == format_counter.end()) {
format_counter[cast_input_format] = 1;
} else {
it->second++;
}
if (format_counter.size() < 2) {
size_t max_counter = 0;
for (const auto &iter : format_counter) {
if (iter.second > max_counter) {
max_counter = iter.second;
convert_format = iter.first;
}
}
// change cast to default that can be more faster when it cast other hw format
SetCastFormat(cast_node, convert_format);
}
}
} // namespace opt
} // namespace mindspore

View File

@ -16,6 +16,8 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_
#include <string>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
@ -25,10 +27,11 @@ class ConvertCastFormat : public PatternProcessPass {
explicit ConvertCastFormat(bool multigraph = true) : PatternProcessPass("convert_cast_format", multigraph) {}
~ConvertCastFormat() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &, const EquivPtr &) const override;
private:
bool NeedChangeCastFormat();
void ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const;
void SetCastFormat(const CNodePtr &cast_node, const string &format) const;
};
} // namespace opt
} // namespace mindspore