/**
 * Copyright 2020 Huawei Technologies Co., Ltd

 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at

 * http://www.apache.org/licenses/LICENSE-2.0

 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_SOMAS_SOMAS_SOLVER_ALG_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_SOMAS_SOMAS_SOLVER_ALG_H_

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstring>
#include <list>
#include <memory>
#include <numeric>
#include <set>
#include <stack>
#include <unordered_map>
#include <utility>
#include <vector>

#include "backend/optimizer/somas/somas_solver_pre.h"
#include "utils/ms_context.h"

using std::pair;
using std::set;
using std::stack;
using std::unordered_map;
using std::vector;

namespace mindspore {
namespace somas {
class Interval {
 public:
  Interval() { m_a_ = m_b_ = 0; }
  explicit Interval(SomasSolverTensorDescPtr t) {
    m_a_ = t->offset_;
    m_b_ = m_a_ + t->size_;
  }
  Interval(const size_t &a, const size_t &b) {
    m_a_ = a;
    m_b_ = b;
  }
  ~Interval() = default;

  size_t m_a_;
  size_t m_b_;
  bool intersect(const Interval &i) { return (in(i.m_a_) || in(i.m_b_)); }
  bool in(const size_t &a) { return ((a > m_a_) && (a < m_b_)); }
  Interval intersection(const Interval &i) {
    if (m_a_ < i.m_a_)
      return Interval(m_a_, i.m_b_);
    else
      return Interval(i.m_a_, m_b_);
  }
  void merge(const Interval &i) {
    m_a_ = std::min(m_a_, i.m_a_);
    m_b_ = std::max(m_b_, i.m_b_);
  }
  size_t &lb() { return m_a_; }
  size_t &ub() { return m_b_; }
  bool contains(size_t width) { return (m_b_ - m_a_) >= width; }
  bool contains(const Interval &a) { return ((a.m_a_ >= m_a_) && (a.m_b_ <= m_b_)); }
  Interval &operator=(const Interval &in) {
    m_a_ = in.m_a_;
    m_b_ = in.m_b_;
    return *this;
  }
};

/**
 * A block of tensors handled as one unit by the SOMAS solver.
 *
 * The block starts at m_start_tensor_ and, as log() shows, continues through the tensors reachable via
 * their right_ links. m_size_ is presumably the total size of the block (TODO confirm against the code
 * that fills it). offsets_ and offsets_candidates_ are keyed by a uint32_t that appears to be a solution
 * id (compare m_current_sol_) -- verify against callers.
 */
class BlockTensor {
 public:
  SomasSolverTensorDescPtr m_start_tensor_;
  unordered_map<uint32_t,
                std::set<pair<size_t, size_t>, bool (*)(const pair<size_t, size_t> &, const pair<size_t, size_t> &)>>
    offsets_candidates_;
  uint32_t m_current_sol_;
  bool m_bre_allocate_;
  unordered_map<uint32_t, size_t> offsets_;
  size_t m_size_;
  BlockTensor()
      : m_start_tensor_(nullptr),
        offsets_candidates_(),
        m_current_sol_(0),
        m_bre_allocate_(true),
        offsets_(),
        m_size_(0) {}
  ~BlockTensor() = default;

  // Copies every field except the current solution index, which is deliberately reset to 0.
  // NOTE(review): the implicitly generated copy constructor DOES copy m_current_sol_; confirm the
  // asymmetry between copy-construction and copy-assignment is intended.
  BlockTensor &operator=(const BlockTensor &bt) {
    m_bre_allocate_ = bt.m_bre_allocate_;
    m_current_sol_ = 0;
    m_start_tensor_ = bt.m_start_tensor_;
    offsets_candidates_ = bt.offsets_candidates_;
    offsets_ = bt.offsets_;
    m_size_ = bt.m_size_;
    return *this;
  }

  // Dumps the block header and every tensor of its right_-linked chain at DEBUG level.
  // A block without a start tensor is reported instead of being dereferenced.
  void log() const {
    if (m_start_tensor_ == nullptr) {
      MS_LOG(DEBUG) << "Block of Tensors [empty]";
      return;
    }
    MS_LOG(DEBUG) << "Block of Tensors [" << m_start_tensor_->index_ << "]\nsize:  " << m_size_ << "\nTensors:";
    for (SomasSolverTensorDescPtr p = m_start_tensor_; p != nullptr; p = p->right_) {
      MS_LOG(DEBUG) << "[" << p->index_ << "," << p->size_ << "]";
    }
  }

  // True when the start tensor has neither a left_ nor a right_ neighbour.
  // Precondition: m_start_tensor_ must be non-null; it is dereferenced without a check.
  bool Alone() const { return (m_start_tensor_->right_ == nullptr) && (m_start_tensor_->left_ == nullptr); }
};

/**
 * Placement state of one solver run: a base offset, the blocks registered via addStart(), and the
 * alignment / branching strategy / algorithm ids configured through the setters.
 *
 * m_foot_print_next_ (exposed by reference through Next()) links footprints together, presumably as a
 * singly linked list -- confirm against the solver code that builds it. The non-inline members
 * (addElem, Destroy, findOffset, ConstrainedBLocks, Merge, findFirst, Result, printStats) are defined
 * out of line; pointer parameters such as `offset` in findOffset/findFirst look like out-parameters
 * (TODO confirm in the .cc).
 */
class FootPrint : public std::enable_shared_from_this<FootPrint> {
 public:
  uint32_t m_solId_;                                // id set by setCurrentSol()
  std::shared_ptr<FootPrint> m_foot_print_next_;    // next footprint; see Next()

  FootPrint()
      : m_solId_(0),
        m_foot_print_next_(nullptr),
        m_offset_(0),
        m_starts_(),
        m_alignment_(0),
        m_branching_strategy_(0),
        m_algorithm_(0) {}
  ~FootPrint() = default;
  void setAlignment(const size_t a) { m_alignment_ = a; }
  void setBranchingStrategy(uint32_t bs) { m_branching_strategy_ = bs; }
  void setCurrentSol(uint32_t solId) { m_solId_ = solId; }
  void setAlgorithm(uint32_t algorithm) { m_algorithm_ = algorithm; }
  // Registers a block; the pointer is stored as-is (not owned by this footprint).
  void addStart(BlockTensor *elemIndex) { m_starts_.push_back(elemIndex); }
  void addElem(BlockTensor *block, const size_t &offset);
  // Mutable access to the link so callers can both read and replace the next footprint.
  std::shared_ptr<FootPrint> &Next() { return m_foot_print_next_; }
  vector<BlockTensor *> &getStarts() { return m_starts_; }
  void Destroy();
  size_t getOffset() const { return m_offset_; }
  void setOffset(const size_t &offset) { m_offset_ = offset; }
  bool findOffset(const std::vector<DynamicBitSet> *constraints, const BlockTensor &block, size_t *offset);
  void ConstrainedBLocks(const std::vector<DynamicBitSet> *constraints, const BlockTensor &b1, const BlockTensor &b2,
                         vector<Interval> *oInterval_l);
  void Merge(vector<Interval> *l_interval, stack<Interval> *l_merged);
  bool findFirst(stack<Interval> *merged, const BlockTensor &block, size_t *offset);
  size_t Result();
  void printStats();

 private:
  size_t m_offset_;
  vector<BlockTensor *> m_starts_;  // non-owning
  size_t m_alignment_;
  uint32_t m_branching_strategy_;
  uint32_t m_algorithm_;
};

/**
 * Fast heuristic allocator used by the SOMAS solver.
 *
 * Holds the alignment applied to placements (kDefaultAlignment until setAlignment() overrides it;
 * presumably in bytes -- confirm) and a counter of allocated tensors. Eval() and Destroy() are defined
 * out of line.
 */
class FastHeuristic {
 public:
  FastHeuristic() : m_alignment_(kDefaultAlignment), m_tensors_allocated_(0) {}
  ~FastHeuristic() = default;

  // Scalars are cheaper to pass by value than by const reference.
  void setAlignment(size_t a) { m_alignment_ = a; }
  void Destroy();
  bool Eval(vector<BlockTensor> *block_tensors_v, std::shared_ptr<FootPrint> foot_print,
            const std::vector<DynamicBitSet> *pConstraints);

 private:
  // Alignment used until setAlignment() is called.
  static constexpr size_t kDefaultAlignment = 512;

  size_t m_alignment_;
  size_t m_tensors_allocated_;
};
}  // namespace somas
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_SOMAS_SOMAS_SOLVER_ALG_H_
