forked from OSSInnovation/mindspore
!9880 change cast format test
From: @lianliguang Reviewed-by: @kisnwang,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
d988e13fb5
|
@ -219,6 +219,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
|
|||
} else {
|
||||
data_layout_pm->AddPass(std::make_shared<MergeCastToOp>());
|
||||
data_layout_pm->AddPass(std::make_shared<ConvertCastFormat>());
|
||||
data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>());
|
||||
data_layout_pm->AddPass(std::make_shared<InsertTransOp>());
|
||||
}
|
||||
data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||
|
|
|
@ -16,8 +16,11 @@
|
|||
#include "backend/optimizer/ascend/format_type/convert_cast_format.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef ConvertCastFormat::DefinePattern() const {
|
||||
|
@ -26,8 +29,8 @@ const BaseRef ConvertCastFormat::DefinePattern() const {
|
|||
return VectorRef({X, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr ConvertCastFormat::Process(const mindspore::FuncGraphPtr &, const mindspore::AnfNodePtr &node,
|
||||
const mindspore::EquivPtr &) const {
|
||||
const AnfNodePtr ConvertCastFormat::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -44,26 +47,77 @@ const AnfNodePtr ConvertCastFormat::Process(const mindspore::FuncGraphPtr &, con
|
|||
continue;
|
||||
}
|
||||
auto cast_node = input_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cast_node);
|
||||
auto input_node_name = AnfAlgo::GetCNodeName(cast_node);
|
||||
if (input_node_name != prim::kPrimCast->name()) {
|
||||
continue;
|
||||
}
|
||||
auto format = AnfAlgo::GetInputFormat(node, input_index);
|
||||
auto cast_input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_node, 0), 0).first;
|
||||
auto cast_input_format = AnfAlgo::GetOutputFormat(cast_input_node, 0);
|
||||
// change cast to default that can be more faster when it cast other hw format
|
||||
if (cast_input_format != format) {
|
||||
if (cast_input_format == kOpFormat_DEFAULT || format == kOpFormat_DEFAULT) {
|
||||
auto info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(
|
||||
AnfAlgo::GetSelectKernelBuildInfo(cast_node));
|
||||
info_builder->SetInputsFormat({kOpFormat_DEFAULT});
|
||||
info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get());
|
||||
}
|
||||
}
|
||||
ChangeCastFormat(cast_node, func_graph);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void ConvertCastFormat::SetCastFormat(const CNodePtr &cast_node, const string &format) const {
|
||||
auto info_builder =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(cast_node));
|
||||
info_builder->SetInputsFormat({format});
|
||||
info_builder->SetOutputsFormat({format});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get());
|
||||
}
|
||||
|
||||
void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const {
|
||||
MS_EXCEPTION_IF_NULL(cast_node);
|
||||
auto input_node_name = AnfAlgo::GetCNodeName(cast_node);
|
||||
if (input_node_name != prim::kPrimCast->name()) {
|
||||
return;
|
||||
}
|
||||
if (AnfAlgo::HasNodeAttr(kAttrVisited, cast_node) && AnfAlgo::GetNodeAttr<bool>(cast_node, kAttrVisited)) {
|
||||
return;
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast_node);
|
||||
auto used_cast_node_list = GetRealNodeUsedList(func_graph, cast_node);
|
||||
MS_EXCEPTION_IF_NULL(used_cast_node_list);
|
||||
std::unordered_map<string, size_t> format_counter;
|
||||
for (const auto &node_info : *used_cast_node_list) {
|
||||
MS_EXCEPTION_IF_NULL(node_info.first);
|
||||
auto cast_out_node = node_info.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cast_out_node);
|
||||
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cast_out_node); ++index) {
|
||||
if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast<CNodePtr>(), index), 0).first !=
|
||||
cast_node) {
|
||||
continue;
|
||||
}
|
||||
auto format = AnfAlgo::GetInputFormat(cast_out_node, index);
|
||||
auto it = format_counter.find(format);
|
||||
if (it == format_counter.end()) {
|
||||
format_counter[format] = 1;
|
||||
} else {
|
||||
it->second++;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto cast_input_format = AnfAlgo::GetPrevNodeOutputFormat(cast_node, 0);
|
||||
string convert_format = kOpFormat_DEFAULT;
|
||||
if (cast_input_format == kOpFormat_DEFAULT) {
|
||||
SetCastFormat(cast_node, convert_format);
|
||||
return;
|
||||
}
|
||||
if (format_counter.size() == 1 && format_counter.begin()->first == kOpFormat_DEFAULT) {
|
||||
SetCastFormat(cast_node, convert_format);
|
||||
return;
|
||||
}
|
||||
auto it = format_counter.find(cast_input_format);
|
||||
if (it == format_counter.end()) {
|
||||
format_counter[cast_input_format] = 1;
|
||||
} else {
|
||||
it->second++;
|
||||
}
|
||||
if (format_counter.size() < 2) {
|
||||
size_t max_counter = 0;
|
||||
for (const auto &iter : format_counter) {
|
||||
if (iter.second > max_counter) {
|
||||
max_counter = iter.second;
|
||||
convert_format = iter.first;
|
||||
}
|
||||
}
|
||||
// change cast to default that can be more faster when it cast other hw format
|
||||
SetCastFormat(cast_node, convert_format);
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -25,10 +27,11 @@ class ConvertCastFormat : public PatternProcessPass {
|
|||
explicit ConvertCastFormat(bool multigraph = true) : PatternProcessPass("convert_cast_format", multigraph) {}
|
||||
~ConvertCastFormat() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
bool NeedChangeCastFormat();
|
||||
void ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const;
|
||||
void SetCastFormat(const CNodePtr &cast_node, const string &format) const;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue