/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/gpu/insert_format_transform_op.h"
#include <memory>
#include <vector>
#include <string>
#include <algorithm>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h"

namespace mindspore {
namespace opt {
namespace {
// Returns the Transpose permutation that converts data laid out as `src_format` into `dst_format`.
// Only NCHW -> NHWC ({0, 2, 3, 1}) and NHWC -> NCHW ({0, 3, 1, 2}) are supported; any other pair is
// reported through MS_LOG(EXCEPTION).
std::vector<int64_t> TransposeAxis(const std::string &src_format, const std::string &dst_format) {
  const bool nchw_to_nhwc = (src_format == kOpFormat_NCHW) && (dst_format == kOpFormat_NHWC);
  if (nchw_to_nhwc) {
    return {0, 2, 3, 1};
  }
  const bool nhwc_to_nchw = (src_format == kOpFormat_NHWC) && (dst_format == kOpFormat_NCHW);
  if (nhwc_to_nchw) {
    return {0, 3, 1, 2};
  }
  MS_LOG(EXCEPTION) << "Invalid format transform, from " << src_format << " to " << dst_format;
}

// A Transpose can be replaced by a no-op Reshape when it does not change the element order in
// memory. Both supported permutations ({0, 2, 3, 1} and {0, 3, 1, 2}) keep axis 0 in place and
// rotate axes 1..3. The element order is therefore unchanged whenever at most one of the rotated
// axes has an extent other than 1, i.e. when `out_shape` has one of these forms:
// 1. out_shape [x, 1, 1, y]
// 2. out_shape [x, y, 1, 1]
// 3. out_shape [x, 1, y, 1]
// Axis 0 must NOT be counted. For example, NHWC [1, 1, w, c] (w, c > 1) comes from NCHW [1, c, 1, w],
// and moving `c` from the outer to the inner position really does reorder the data.
// `out_shape` is the shape produced by the transpose and must be 4-D (otherwise an exception is raised).
bool IsFakeTranspose(const std::vector<size_t> &out_shape, const std::vector<int64_t> &transpose_perm) {
  if (out_shape.size() != 4) {
    MS_LOG(EXCEPTION) << "Invalid data shape, 4-D data was needed, but get " << out_shape.size() << "-D.";
  }
  const std::vector<int64_t> perm1 = {0, 2, 3, 1};
  const std::vector<int64_t> perm2 = {0, 3, 1, 2};
  if ((transpose_perm != perm1) && (transpose_perm != perm2)) {
    return false;
  }
  // Count unit extents among the rotated axes only; axis 0 is never moved by these permutations.
  const auto unit_axes = std::count(out_shape.begin() + 1, out_shape.end(), 1);
  return unit_axes >= 2;
}

void SetTransposeOpBuildInfo(const std::string &input_format, const std::string &output_format,
                             const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto input_type = AnfAlgo::GetPrevNodeOutputInferDataType(node, 0);
  auto output_type = AnfAlgo::GetOutputInferDataType(node, 0);
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({input_format});
  builder.SetInputsDeviceType({input_type});
  builder.SetOutputsFormat({output_format});
  builder.SetOutputsDeviceType({output_type});
  builder.SetKernelType(UNKNOWN_KERNEL_TYPE);
  builder.SetProcessor(kernel::Processor::CUDA);
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
}

// Inserts a format-transform node on the edge from `node` into input `used_node_index` of `used_node`,
// and returns the new node.
// If IsFakeTranspose() decides the permutation does not reorder data for the device shape of that
// input, the new node is a Reshape without a perm attribute. Otherwise it is a Transpose carrying
// `transpose_perm` as its kAttrPerm attribute.
// The new node reuses the inferred type and shape of `used_node`'s original input. The graph manager
// reroutes the edge. `used_node_index` is a 0-based data-input index; CNode input 0 is the primitive,
// hence the +1 in SetEdge.
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
                           int used_node_index, const std::vector<int64_t> &transpose_perm) {
  // Validate everything before the debug log below dereferences the nodes.
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(used_node);
  MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
                << ", index: " << used_node_index;
  // 0. Judge whether it is a fake transpose.
  auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index);
  const bool is_fake = IsFakeTranspose(transed_shape, transpose_perm);
  // 1. Create a transpose node, or a fake transpose node (reshape).
  mindspore::PrimitivePtr transpose_prim;
  if (is_fake) {
    transpose_prim = std::make_shared<Primitive>(prim::kPrimReshape->name());
  } else {
    transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name());
  }
  MS_EXCEPTION_IF_NULL(transpose_prim);
  // 2. Set the input of transpose: input 0 is the primitive, input 1 is the data producer.
  std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node};
  auto transpose_op = graph->NewCNode(transpose_input);
  MS_EXCEPTION_IF_NULL(transpose_op);
  // 3. Set the output info of transpose. Explicit vector types avoid the `auto x = {...}`
  //    std::initializer_list deduction.
  const std::vector<TypeId> transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)};
  const std::vector<std::vector<size_t>> transpose_shape = {
    AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)};
  AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get());
  if (!is_fake) {
    AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op);
  }
  // 4. Set the new edge of transpose op.
  FuncGraphManagerPtr manager = graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->SetEdge(used_node, used_node_index + 1, transpose_op);
  return transpose_op;
}
}  // namespace

// Pattern callback.
// Skips the node (returns nullptr) when it is not a real CNode kernel, or when its name is not in
// device::gpu::kKernelFormatPositionMap. Otherwise:
//  - every input whose selected format is neither DEFAULT nor the node's origin data format gets a
//    transform from the origin format to that input format;
//  - every output whose selected format is neither DEFAULT nor the origin format gets a transform back
//    to the origin format in front of each of its real users. TupleGetItem users are handled by
//    ProcessForTupleItem.
// A DEFAULT origin data format is treated as NCHW. Returns `node` after processing.
const AnfNodePtr InsertFormatTransformOp::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
                                                  const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(equiv);
  if (!AnfAlgo::IsRealCNodeKernel(node)) {
    return nullptr;
  }
  auto iter = device::gpu::kKernelFormatPositionMap.find(AnfAlgo::GetCNodeName(node));
  if (iter == device::gpu::kKernelFormatPositionMap.end()) {
    return nullptr;
  }
  auto origin_data_format = AnfAlgo::GetOriginDataFormat(node);
  if (origin_data_format == kOpFormat_DEFAULT) {
    origin_data_format = kOpFormat_NCHW;
  }
  MS_LOG(DEBUG) << "Process node: " << node->fullname_with_scope();

  // Insert input transpose from origin_data_format to input_format.
  auto inputs_format = AnfAlgo::GetAllInputFormats(node);
  for (size_t i = 0; i < inputs_format.size(); i++) {
    const auto &input_format = inputs_format[i];
    if ((input_format == kOpFormat_DEFAULT) || (input_format == origin_data_format)) {
      continue;
    }
    auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
    MS_EXCEPTION_IF_NULL(input_node);
    auto transpose_perm = TransposeAxis(origin_data_format, input_format);
    auto transpose_op = InsertTransposeOp(graph, input_node, node, i, transpose_perm);
    SetTransposeOpBuildInfo(kOpFormat_DEFAULT, input_format, transpose_op);
  }

  // Insert output transpose from output_format to origin_data_format.
  auto outputs_format = AnfAlgo::GetAllOutputFormats(node);
  for (size_t i = 0; i < outputs_format.size(); i++) {
    const auto &output_format = outputs_format[i];
    if ((output_format == kOpFormat_DEFAULT) || (output_format == origin_data_format)) {
      continue;
    }
    // Find all nodes connected with node output, and change their inputs to transpose.
    auto used_node_list = GetRealNodeUsedListByOutputIdx(graph, node, i);
    MS_EXCEPTION_IF_NULL(used_node_list);
    if (used_node_list->empty()) {
      continue;
    }
    // The permutation depends only on this output's format, so compute it once rather than per user.
    auto transpose_perm = TransposeAxis(output_format, origin_data_format);
    for (const auto &used_node_info : *used_node_list) {
      auto used_node = used_node_info.first;
      auto used_node_index = used_node_info.second - 1;
      if (AnfAlgo::GetCNodeName(used_node) == prim::kPrimTupleGetItem->name()) {
        MS_LOG(DEBUG) << "The used node of [" << node->fullname_with_scope() << "] is tuple item.";
        // The tuple item need get next used nodes again.
        ProcessForTupleItem(graph, used_node, used_node_index, transpose_perm, output_format);
        continue;
      }
      auto transpose_op = InsertTransposeOp(graph, node, used_node, used_node_index, transpose_perm);
      SetTransposeOpBuildInfo(output_format, kOpFormat_DEFAULT, transpose_op);
    }
  }
  return node;
}

// Handles a TupleGetItem `node` that consumes a format-changed kernel output. For each real user of
// `node`'s output `node_index`, inserts a transform built from `transpose_perm` between `node` and
// that user. The transform's kernel info uses `transpose_format` as input format and DEFAULT as
// output format. A TupleGetItem consuming another TupleGetItem is rejected with an exception.
void InsertFormatTransformOp::ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int node_index,
                                                  const std::vector<int64_t> &transpose_perm,
                                                  const std::string &transpose_format) const {
  const auto users = GetRealNodeUsedListByOutputIdx(graph, node, node_index);
  for (const auto &user : *users) {
    const auto &user_node = user.first;
    const auto user_input_index = user.second - 1;
    const bool nested_tuple_item = AnfAlgo::GetCNodeName(user_node) == prim::kPrimTupleGetItem->name();
    if (nested_tuple_item) {
      MS_LOG(EXCEPTION) << "The used node of tuple item can't be tuple item.";
    }
    auto transform_node = InsertTransposeOp(graph, node, user_node, user_input_index, transpose_perm);
    SetTransposeOpBuildInfo(transpose_format, kOpFormat_DEFAULT, transform_node);
  }
}
}  // namespace opt
}  // namespace mindspore
