/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp32/gather_fp32.h"
#include <string.h>
#include "nnacl/errorcode.h"

/**
 * Returns the product of shape[index + 1] .. shape[rank - 1], i.e. the number of
 * elements one step along dimension `index` spans when the tensor is laid out with
 * the last dimension contiguous (C order). Returns 1 when `index` is the last
 * dimension or beyond it.
 *
 * @param shape  dimension sizes, at least `rank` entries.
 * @param rank   number of dimensions in `shape`.
 * @param index  dimension whose stride is requested (expected >= -1; not checked).
 *
 * Deliberately not declared plain `inline`: under C99/C11 rules a non-static `inline`
 * definition with no `extern` declaration in the translation unit provides no external
 * definition, so any call the compiler does not inline (e.g. at -O0, or from another
 * file) fails to link.
 */
int Stride(const int *shape, int rank, int index) {
  int stride = 1;
  for (int i = index + 1; i < rank; ++i) {
    stride *= shape[i];
  }
  return stride;
}

/**
 * Gathers float rows selected by `indices`.
 *
 * `input` is addressed as `outer_size` consecutive groups of `limit` rows, each row
 * being `inner_size` floats. For every group m and every i in [0, indices_element_size),
 * row `indices[i]` of input group m is copied to row i of output group m, so `output`
 * must have room for outer_size * indices_element_size * inner_size floats.
 *
 * @return NNACL_OK on success; NNACL_ERR as soon as an index outside [0, limit) is
 *         seen. Rows copied before the bad index was reached remain in `output`.
 */
int Gather(const float *input, int outer_size, int inner_size, int limit, const int *indices, int indices_element_size,
           float *output) {
  const size_t row = (size_t)inner_size;
  for (int m = 0; m < outer_size; ++m) {
    /* Offsets in size_t: the int product inner_size * m * limit can overflow (UB) on large tensors. */
    const float *inputm = input + row * (size_t)m * (size_t)limit;
    float *outputm = output + row * (size_t)m * (size_t)indices_element_size;
    for (int i = 0; i < indices_element_size; ++i) {
      /* Valid rows are 0 .. limit - 1; accepting indices[i] == limit would read one row past the group. */
      if (indices[i] < 0 || indices[i] >= limit) {
        return NNACL_ERR;
      }
      memcpy(outputm + (size_t)i * row, inputm + (size_t)indices[i] * row, sizeof(float) * row);
    }
  }
  return NNACL_OK;
}

/**
 * Gathers int32 rows selected by `indices`; the int32_t counterpart of Gather.
 *
 * `input` is addressed as `outer_size` consecutive groups of `limit` rows, each row
 * being `inner_size` int32 values. For every group m and every i in
 * [0, indices_element_size), row `indices[i]` of input group m is copied to row i of
 * output group m, so `output` must have room for
 * outer_size * indices_element_size * inner_size values.
 *
 * @return NNACL_OK on success; NNACL_ERR as soon as an index outside [0, limit) is
 *         seen. Rows copied before the bad index was reached remain in `output`.
 */
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, const int *indices,
                int indices_element_size, int32_t *output) {
  const size_t row = (size_t)inner_size;
  for (int m = 0; m < outer_size; ++m) {
    /* Offsets in size_t: the int product inner_size * m * limit can overflow (UB) on large tensors. */
    const int32_t *inputm = input + row * (size_t)m * (size_t)limit;
    int32_t *outputm = output + row * (size_t)m * (size_t)indices_element_size;
    for (int i = 0; i < indices_element_size; ++i) {
      /* Valid rows are 0 .. limit - 1; accepting indices[i] == limit would read one row past the group. */
      if (indices[i] < 0 || indices[i] >= limit) {
        return NNACL_ERR;
      }
      memcpy(outputm + (size_t)i * row, inputm + (size_t)indices[i] * row, sizeof(int32_t) * row);
    }
  }
  return NNACL_OK;
}
