/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/parallel/ops_info/slice_info.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"

namespace mindspore {
namespace parallel {
// Converts a ValueTuple of int64 immediates into a plain vector.
//
// input_value: the value to convert; it must be castable to ValueTuplePtr.
// input:       output vector, must not be null. The converted elements are appended to it.
// Returns SUCCESS when every element was converted. Returns FAILED when input_value is not a
// ValueTuple or when an element is not an Int64Imm; elements converted before the failing one
// stay in *input.
Status SliceInfo::GetInput(const ValuePtr &input_value, std::vector<int64_t> *input) {
  MS_EXCEPTION_IF_NULL(input_value);
  // The output pointer is dereferenced unconditionally in the loop below, so guard it as well.
  MS_EXCEPTION_IF_NULL(input);
  ValueTuplePtr value_tuple = input_value->cast<ValueTuplePtr>();
  if (value_tuple == nullptr) {
    MS_LOG(ERROR) << name_ << ": Input value must be ValueTuplePtr.";
    return FAILED;
  }

  for (auto &element : value_tuple->value()) {
    MS_EXCEPTION_IF_NULL(element);
    if (!element->isa<Int64Imm>()) {
      MS_LOG(ERROR) << name_ << ": The value must be int64";
      return FAILED;
    }
    input->push_back(element->cast<Int64ImmPtr>()->value());
  }

  return SUCCESS;
}

// Reads the "begin" and "size" arguments of the operator from input_value_ into begin_ and size_.
//
// Returns FAILED when the number of input values is not SLICE_INPUTS_SIZE, when either argument
// cannot be converted by GetInput, or when the two arguments have different lengths.
Status SliceInfo::GetAttrs() {
  if (input_value_.size() != SLICE_INPUTS_SIZE) {
    MS_LOG(ERROR) << name_ << ": The size of input value must be " << SLICE_INPUTS_SIZE << ", but got "
                  << input_value_.size();
    return FAILED;
  }

  // GetInput appends to its output vector, so start from empty vectors: otherwise a repeated
  // call (e.g. after an earlier failure) would leave stale or duplicated entries behind.
  begin_.clear();
  size_.clear();
  if ((GetInput(input_value_[SLICE_BEGIN_INDEX], &begin_) != SUCCESS) ||
      (GetInput(input_value_[SLICE_SIZE_INDEX], &size_) != SUCCESS)) {
    return FAILED;
  }

  // The other methods of this class walk begin_ and size_ with a shared index, so their lengths
  // have to agree.
  if (begin_.size() != size_.size()) {
    MS_LOG(ERROR) << name_ << ": The size of begin (" << begin_.size() << ") must be equal to the size of size ("
                  << size_.size() << ")";
    return FAILED;
  }

  return SUCCESS;
}

// Validates a sharding strategy for this operator.
//
// A dimension counts as "not fully fetched" when its begin is non-zero or its size is smaller
// than the corresponding input dimension; such a dimension must have a strategy value of 1.
//
// strategy: the strategy to check, must not be null.
// Returns SUCCESS when the strategy is acceptable, FAILED otherwise.
Status SliceInfo::CheckStrategy(const StrategyPtr &strategy) {
  MS_EXCEPTION_IF_NULL(strategy);
  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Invalid strategy";
    return FAILED;
  }

  std::vector<Dimensions> stra = strategy->GetInputDim();
  if (stra.empty()) {
    MS_LOG(ERROR) << name_ << ": The strategy is empty";
    return FAILED;
  }
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }

  const auto &strategy_value = stra[0];
  const auto &input_shape = inputs_shape_[0];

  // The loop below indexes size_, input_shape and strategy_value with indices taken from begin_,
  // so make sure none of them is shorter than begin_ before touching them.
  const size_t dim_num = begin_.size();
  if ((size_.size() != dim_num) || (input_shape.size() < dim_num) || (strategy_value.size() < dim_num)) {
    MS_LOG(ERROR) << name_ << ": The size of begin (" << dim_num << "), the size of size (" << size_.size()
                  << "), the input dimension (" << input_shape.size() << ") and the strategy dimension ("
                  << strategy_value.size() << ") do not match";
    return FAILED;
  }

  for (size_t i = 0; i < dim_num; ++i) {
    bool no_fully_fetch = ((begin_[i] != 0) || (size_[i] < input_shape[i]));
    if (no_fully_fetch && (strategy_value[i] != 1)) {
      MS_LOG(ERROR) << name_ << ": When a dimension is not fully fetched, the dimension can not be split now";
      return FAILED;
    }
  }

  return SUCCESS;
}

// Uses the strategy of the first input as the device matrix shape.
// Returns FAILED when the current strategy holds no input dimensions.
Status SliceInfo::InferDevMatrixShape() {
  MS_EXCEPTION_IF_NULL(strategy_);
  const std::vector<Dimensions> input_strategies = strategy_->GetInputDim();
  if (!input_strategies.empty()) {
    dev_matrix_shape_ = input_strategies[0];
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": The strategy is empty";
  return FAILED;
}

// Builds the tensor map [rank - 1, rank - 2, ..., 0] from the rank of the first input and
// appends it to both the input and the output tensor maps.
// Returns FAILED when no input shape is available.
Status SliceInfo::InferTensorMap() {
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }

  // dev_matrix_shape_ cannot stand in for inputs_shape_[0] here, because the tensor may not be
  // fully split across all devices.
  int64_t rank = SizeToInt(inputs_shape_[0].size());
  TensorMap descending_map;
  for (int64_t dim = rank - 1; dim >= 0; --dim) {
    descending_map.push_back(dim);
  }

  inputs_tensor_map_.push_back(descending_map);
  outputs_tensor_map_.push_back(descending_map);
  return SUCCESS;
}

// Rebuilds mirror_ops_ from the tensor map of the first input.
// When a group is created, three entries are stored: the mirror operators built from the first
// group, followed by two empty operator vectors. When no group is created, mirror_ops_ stays
// empty and SUCCESS is returned.
Status SliceInfo::InferMirrorOps() {
  mirror_ops_.clear();
  if (inputs_tensor_map_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
    return FAILED;
  }

  std::vector<Group> mirror_groups;
  Shape data_tensor_map = inputs_tensor_map_[0];
  if (CreateGroupByTensorMap(data_tensor_map, &mirror_groups) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Create group for input failed.";
    return FAILED;
  }

  if (mirror_groups.empty()) {
    MS_LOG(INFO) << name_ << ": The mirror group is empty.";
    return SUCCESS;
  }

  Group &first_group = mirror_groups[0];
  OperatorVector data_mirror = CreateMirrorOps(first_group.name(), first_group.GetDevNum());
  // The two remaining entries are intentionally left without operators.
  OperatorVector second_mirror;
  OperatorVector third_mirror;
  mirror_ops_.push_back(data_mirror);
  mirror_ops_.push_back(second_mirror);
  mirror_ops_.push_back(third_mirror);
  return SUCCESS;
}

// Initializes the tensor layouts of the first input and the first output from the device matrix,
// the tensor maps and the shapes, then appends the resulting tensor infos.
// Returns FAILED when any of the required shapes/maps is missing or a layout cannot be initialized.
Status SliceInfo::InferTensorInfo() {
  const bool missing_args =
    inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty();
  if (missing_args) {
    MS_LOG(ERROR) << name_ << ": Invalid args";
    return FAILED;
  }

  TensorLayout in_layout;
  TensorLayout out_layout;
  if (in_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
    return FAILED;
  }
  if (out_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
    return FAILED;
  }

  TensorInfo in_info(in_layout);
  TensorInfo out_info(out_layout);
  inputs_tensor_info_.push_back(in_info);
  outputs_tensor_info_.push_back(out_info);
  return SUCCESS;
}

// Generates the batch strategies with the single input marked as splittable.
// Note: if the batch dimension is not fully fetched, the batch strategy may not work.
std::shared_ptr<Strategys> SliceInfo::GenerateBatchStrategies() {
  split_flag_list_ = {true};
  auto batch_strategies = GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
  return batch_strategies;
}

// Forwards to SetCostUnderStrategyBase and returns its status unchanged.
Status SliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
  const Status cost_status = SetCostUnderStrategyBase(strategy);
  return cost_status;
}

// Generates candidate strategies for the given stage and sets the cost under each of them.
//
// Dimensions that are not fully fetched (begin is non-zero, or size is smaller than the input
// dimension) are marked as not splittable before the candidates are generated.
//
// stage_id: forwarded to GenerateStrategiesForIndependentInputs.
// Returns FAILED when the attributes cannot be inferred, the inputs shape is empty, the
// begin/size lengths do not fit the input, or strategy generation fails; SUCCESS otherwise
// (regardless of how many candidates had their cost set successfully).
Status SliceInfo::GenerateStrategies(int64_t stage_id) {
  if (InferAttrs() != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": Infer attrs failed";
    return FAILED;
  }
  if (inputs_shape_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
    return FAILED;
  }
  Shape input_split(inputs_shape_[0].size(), 1);

  // The loop below indexes size_, inputs_shape_[0] and input_split with indices taken from
  // begin_, so reject mismatching lengths instead of reading/writing out of bounds.
  if ((size_.size() != begin_.size()) || (begin_.size() > input_split.size())) {
    MS_LOG(ERROR) << name_ << ": The size of begin (" << begin_.size() << ") and the size of size (" << size_.size()
                  << ") must be equal and not larger than the input dimension (" << input_split.size() << ")";
    return FAILED;
  }

  for (size_t i = 0; i < begin_.size(); ++i) {
    bool no_fully_fetch = ((begin_[i] != 0) || (size_[i] < inputs_shape_[0][i]));
    if (no_fully_fetch) {
      input_split[i] = 0;
    }
  }
  Shapes splittable_inputs = {input_split};

  std::vector<StrategyPtr> sp_vector;
  if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
    return FAILED;
  }

  size_t success = 0;
  for (auto &sp : sp_vector) {
    PrintStrategy(sp);
    if (SetCostUnderStrategy(sp) == SUCCESS) {
      success++;
      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
      PrintStrategy(sp);
    }
  }
  return SUCCESS;
}

// Initializes the operator info through InitWithAutoRepeatCalc and logs the outcome.
Status SliceInfo::Init(const StrategyPtr &strategy) {
  const bool initialized = (InitWithAutoRepeatCalc(strategy) == SUCCESS);
  if (initialized) {
    MS_LOG(INFO) << name_ << ": Init success.";
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": Init failed.";
  return FAILED;
}

// Initializes the operator info for the cost model through InitForCostModelWithAutoRepeatCalc
// and logs the outcome.
Status SliceInfo::InitForCostModel(const StrategyPtr &strategy) {
  const bool initialized = (InitForCostModelWithAutoRepeatCalc(strategy) == SUCCESS);
  if (initialized) {
    MS_LOG(INFO) << name_ << ": Init for cost model success.";
    return SUCCESS;
  }

  MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
  return FAILED;
}

// Returns the replacement graph for cnode.
//
// The replacement graph is only computed when at least one dimension of the first input's
// strategy is greater than 1; otherwise replace_graph_ is returned as it currently is.
// Raises an exception when strategy_ is null or when ComputeReplaceGraph fails.
ReplaceGraphPtr SliceInfo::replace_graph(const CNodePtr &cnode) {
  // Same guard as the other users of strategy_ in this class: fail with a diagnosable exception
  // instead of dereferencing a null pointer.
  MS_EXCEPTION_IF_NULL(strategy_);
  auto input_strategy = strategy_->GetInputDim().at(0);
  const bool is_split =
    std::any_of(input_strategy.begin(), input_strategy.end(), [](const int64_t &shard) { return shard > 1; });
  if (is_split && (ComputeReplaceGraph(cnode) != SUCCESS)) {
    MS_LOG(EXCEPTION) << name_ << ": InferReplaceOp failed.";
  }
  return replace_graph_;
}

// Wraps a vector of int64 values into a value node and returns it as an AnfNodePtr.
AnfNodePtr CreateValueTupleAndNodePtr(const std::vector<int64_t> &value_tuple) {
  ValueTuplePtr tuple_value = MakeValue(value_tuple)->cast<ValueTuplePtr>();
  return NewValueNode(tuple_value)->cast<AnfNodePtr>();
}

// Builds replace_graph_: a single SLICE node that takes the virtual input node, the original
// begin and an adjusted size.
//
// For every dimension whose strategy value is 1 the original size is kept; for every other
// dimension the size is replaced by the corresponding dimension of the input's slice shape.
//
// cnode: the node used to initialize the graph generator.
// Returns FAILED when the graph generator cannot be initialized, the strategy or the input
// tensor info is missing, or the lengths of size/strategy/slice shape do not match.
Status SliceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
  GenerateGraph gen_g = GenerateGraph();
  if (gen_g.Init(cnode) != SUCCESS) {
    MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed";
    return FAILED;
  }

  MS_EXCEPTION_IF_NULL(strategy_);
  std::vector<Dimensions> stra = strategy_->GetInputDim();
  if (stra.empty()) {
    MS_LOG(ERROR) << name_ << ": The strategy is empty";
    return FAILED;
  }
  if (inputs_tensor_info_.empty()) {
    MS_LOG(ERROR) << name_ << ": The inputs tensor info is empty";
    return FAILED;
  }
  Dimensions input_stra = stra[0];
  Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();

  // The loop below indexes input_stra and input_slice_shape with indices taken from size_, so
  // make sure neither of them is shorter than size_.
  if ((input_stra.size() < size_.size()) || (input_slice_shape.size() < size_.size())) {
    MS_LOG(ERROR) << name_ << ": The size of size (" << size_.size() << ") is larger than the strategy dimension ("
                  << input_stra.size() << ") or the slice shape dimension (" << input_slice_shape.size() << ")";
    return FAILED;
  }

  std::vector<int64_t> sliced_size_shape_int;
  sliced_size_shape_int.reserve(size_.size());
  for (size_t i = 0; i < size_.size(); ++i) {
    // A split dimension can only be fully fetched, so its new size is the whole local slice.
    sliced_size_shape_int.push_back((input_stra[i] == 1) ? size_[i] : input_slice_shape[i]);
  }
  auto new_begin = CreateValueTupleAndNodePtr(begin_);
  auto new_size = CreateValueTupleAndNodePtr(sliced_size_shape_int);

  auto slice = gen_g.PushBack({gen_g.NewOpInst(SLICE), gen_g.virtual_input_node(), new_begin, new_size});

  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(slice, 1)};
  replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
    std::make_pair(input_nodes, slice));

  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore
