//  Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Resize.h>
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_unique2.h>
#include <ATen/ops/_unique2_native.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/argsort.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/cumsum.h>
#include <ATen/ops/full.h>
#include <ATen/ops/masked_select.h>
#include <ATen/ops/nonzero.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/unique_consecutive.h>
#include <ATen/ops/unique_consecutive_native.h>
#include <ATen/ops/unique_dim_consecutive.h>
#include <ATen/ops/unique_dim_consecutive_native.h>
#include <ATen/ops/unique_dim_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {
namespace mps {

// Cached graph for the unique family of ops, bundled with the graph tensors that are needed to
// feed the input and to read the results back after a run.
struct UniqueCachedGraph : public MPSCachedGraph {
  UniqueCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  // Placeholder the input tensor is fed through.
  MPSGraphTensor* inputTensor_ = nil;
  // Tensor holding the unique values.
  MPSGraphTensor* outputTensor_ = nil;
  // Inverse indices; remains nil when the graph was built without return_inverse.
  MPSGraphTensor* inverseIndicesTensor_ = nil;
  // Per-group counts; remains nil when the graph was built without return_counts.
  MPSGraphTensor* countsTensor_ = nil;
  // Largest group index produced by the graph (number of unique entries minus one).
  MPSGraphTensor* lengthTensor_ = nil;
};

// Produces the graph-cache key for one configuration of the unique op. The key has the form
//   _unique2_mps:<dtype>[<shape>]:[<dim or None>]:[<return_inverse>]:[<return_counts>]:[<consecutive>]
// with the three flags rendered through std::to_string.
static std::string getUniqueKey(const ScalarType& dtype,
                                const IntArrayRef& base_shape,
                                const bool return_inverse,
                                const bool return_counts,
                                const bool consecutive,
                                std::optional<int64_t> dimOpt) {
  // Wraps one field in square brackets.
  const auto bracketed = [](const std::string& field) { return "[" + field + "]"; };
  const std::string dimField = dimOpt.has_value() ? std::to_string(dimOpt.value()) : std::string("None");

  std::string key = "_unique2_mps:";
  key += getMPSTypeString(dtype);
  key += bracketed(getArrayRefString(base_shape));
  key += ":" + bracketed(dimField);
  key += ":" + bracketed(std::to_string(return_inverse));
  key += ":" + bracketed(std::to_string(return_counts));
  key += ":" + bracketed(std::to_string(consecutive));
  return key;
}

// Adds to uniqueGraph->graph() the operations that compute the unique values of
// uniqueGraph->inputTensor_ and, when requested, the inverse indices and the per-group counts.
//
// Returns {result, inverseIndices, counts, length}:
//   - result:         the (sorted, unless `consecutive`) input with the first element of every group
//                     scattered to that group's index along `dim`; it keeps the full destination
//                     shape, the caller trims it using `length`.
//   - inverseIndices: nil unless return_inverse.
//   - counts:         nil unless return_counts.
//   - length:         the last value of the running group index, i.e. the largest group index
//                     (the caller adds 1 to get the number of unique entries).
//
// Without a dim, an input of rank > 1 is flattened first. The sort / argsort below always run
// along axis 0, so: dim arg not supported when non consecutive, ie sorted.
static std::array<MPSGraphTensor*, 4> buildUniqueGraph(const Tensor& self,
                                                       UniqueCachedGraph* uniqueGraph,
                                                       const bool return_inverse,
                                                       const bool return_counts,
                                                       const bool consecutive,
                                                       std::optional<int64_t> dimOpt) {
  // Wrap a possibly negative dim; fall back to axis 0 when no dim was given.
  int64_t dim = dimOpt.has_value() ? maybe_wrap_dim(dimOpt.value(), self.dim()) : 0;

  MPSGraph* graph = uniqueGraph->graph();
  MPSGraphTensor* inputTensor = uniqueGraph->inputTensor_;
  MPSShape* shape = [inputTensor shape];
  MPSShape* destShape = shape;
  // Number of entries compared against each other: the extent of `dim` (replaced by the total
  // element count below when the input is flattened).
  uint64_t length = [shape[dim] unsignedIntValue];
  MPSDataType dataType = [inputTensor dataType];

  MPSGraphTensor* resultTensor = nil;
  MPSGraphTensor* inverseIndicesTensor = nil;
  MPSGraphTensor* countTensor = nil;
  MPSGraphTensor* lengthTensor = nil;

  // Flatten only when no dim was requested and the input is not already 1-D.
  const bool needsFlatten = !(dimOpt.has_value() || [shape count] == 1);
  if (needsFlatten) {
    inputTensor = [graph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil];
    // Recompute length as the overflow-checked product of all extents.
    length = 1;
    for (const auto i : c10::irange([shape count])) {
      if (c10::mul_overflows(length, [shape[i] unsignedIntValue], &length)) {
        TORCH_CHECK(false, "RuntimeError: Tensor size overflow");
      }
    }

    destShape = @[ [NSNumber numberWithUnsignedInteger:length] ];
  }

  if (length <= 1) {
    // Trivial case: at most one entry, so everything is unique. The input is returned as is,
    // the largest group index is 0, the inverse index is 0 and the count is 1.
    resultTensor = inputTensor;
    lengthTensor = [graph constantWithScalar:0.0f dataType:MPSDataTypeInt32];
    if (return_inverse) {
      inverseIndicesTensor = [graph constantWithScalar:0.0f dataType:MPSDataTypeInt32];
    }
    if (return_counts) {
      countTensor = [graph constantWithScalar:1.0f dataType:MPSDataTypeInt32];
    }
    return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor};
  }

  // #issue 104398441 sortWithTensor only supports the following types, cast if necessary:
  // float-flagged types go to Float32, everything else to Int32.
  if (dataType != MPSDataTypeInt32 && dataType != MPSDataTypeFloat32 && dataType != MPSDataTypeFloat16) {
    dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
    inputTensor = [graph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
  }

  // In consecutive mode the input order is kept; otherwise equal values are brought together by
  // sorting along axis 0.
  MPSGraphTensor* sortedInput = nil;
  if (consecutive) {
    sortedInput = inputTensor;
  } else {
    sortedInput = [graph sortWithTensor:inputTensor axis:0 name:nil];
  }

  // Compare every entry (from the second one on) with its predecessor along `dim`.
  MPSGraphTensor* frontNMinusOne = [graph sliceTensor:sortedInput dimension:dim start:0 length:length - 1 name:nil];
  MPSGraphTensor* backNMinusOne = [graph sliceTensor:sortedInput dimension:dim start:1 length:length - 1 name:nil];
  MPSGraphTensor* notEqualToPreviousElement = [graph notEqualWithPrimaryTensor:backNMinusOne
                                                               secondaryTensor:frontNMinusOne
                                                                          name:nil];
  // Int32 mask: non-zero where an entry differs from the previous one, i.e. starts a new group.
  MPSGraphTensor* mask = [graph castTensor:notEqualToPreviousElement toType:MPSDataTypeInt32 name:@"castMaskTensor"];

  // If comparing tensors, not scalars, check if entire tensor matches previous element using reductionOr over tensor
  if (dimOpt.has_value() && [shape count] != 1) {
    // Reduce over every axis except `dim`, then squeeze those axes away.
    NSMutableArray* axes = [[NSMutableArray alloc] initWithCapacity:[shape count] - 1];
    for (const auto axis : c10::irange([shape count])) {
      if (static_cast<decltype(dim)>(axis) != dim) {
        [axes addObject:[NSNumber numberWithUnsignedInteger:axis]];
      }
    }
    mask = [graph reductionOrWithTensor:mask axes:axes name:nil];
    mask = [graph squeezeTensor:mask axes:axes name:nil];
    // Explicit release pairs with the alloc/init above.
    [axes release];
  }

  // Running sum of the mask = group index of entries 1..length-1 (entry 0 belongs to group 0).
  MPSGraphTensor* scannedIndices = [graph cumulativeSumWithTensor:mask axis:0 name:nil];
  // scannedIndices has length-1 entries; the last one (position length-2) is the largest group index.
  lengthTensor = [graph sliceTensor:scannedIndices dimension:0 start:length - 2 length:1 name:nil];

  // Keep the group index only where a new group starts; every other position gets -1.
  // NOTE(review): presumably the scatter below skips the -1 indices - confirm against the MPSGraph docs.
  MPSGraphTensor* minusOneTensor = [graph constantWithScalar:-1.0f dataType:MPSDataTypeInt32];
  MPSGraphTensor* maskedIndices = [graph selectWithPredicateTensor:mask
                                               truePredicateTensor:scannedIndices
                                              falsePredicateTensor:minusOneTensor
                                                              name:nil];

  // Prepend index 0 for the very first entry, which has no predecessor to be compared with.
  MPSGraphTensor* zeroTensor = [graph constantWithScalar:0.0f shape:@[ @1 ] dataType:MPSDataTypeInt32];
  MPSGraphTensor* maskedIndicesWithHead = [graph concatTensors:@[ zeroTensor, maskedIndices ] dimension:0 name:nil];
  MPSGraphTensor* scannedIndicesWithHead = [graph concatTensors:@[ zeroTensor, scannedIndices ] dimension:0 name:nil];

  // Write the first entry of every group to the position given by its group index along `dim`.
  resultTensor = [graph scatterWithUpdatesTensor:sortedInput
                                   indicesTensor:maskedIndicesWithHead
                                           shape:destShape
                                            axis:dim
                                            mode:MPSGraphScatterModeSet
                                            name:nil];
  // Cast back to the placeholder's data type if the input was cast for sorting
  if ([uniqueGraph->inputTensor_ dataType] != dataType) {
    resultTensor = [graph castTensor:resultTensor toType:[uniqueGraph->inputTensor_ dataType] name:@"castResultTensor"];
  }

  // Compute optional returned tensors if requested
  if (return_inverse) {
    // Original position of every entry of the (sorted) sequence: the identity in consecutive
    // mode, the argsort of the input otherwise.
    MPSGraphTensor* argSortedInput = nil;
    if (consecutive)
      argSortedInput = [graph coordinateAlongAxis:0
                                        withShape:@[ [NSNumber numberWithUnsignedInteger:length] ]
                                             name:nil];
    else
      argSortedInput = [graph argSortWithTensor:inputTensor axis:0 name:nil];
    // Place each entry's group index at that entry's original position.
    inverseIndicesTensor = [graph scatterWithUpdatesTensor:scannedIndicesWithHead
                                             indicesTensor:argSortedInput
                                                     shape:@[ [NSNumber numberWithUnsignedInteger:length] ]
                                                      axis:0
                                                      mode:MPSGraphScatterModeAdd
                                                      name:nil];
    // A flattened input gets its inverse indices back in the original input shape.
    if (needsFlatten)
      inverseIndicesTensor = [graph reshapeTensor:inverseIndicesTensor withShape:shape name:nil];
  }

  if (return_counts) {
    // Scatter-add a 1 per entry into the slot of its group index to obtain the group sizes.
    MPSGraphTensor* unitTensor = [graph constantWithScalar:1.0f
                                                     shape:@[ [NSNumber numberWithUnsignedInteger:length] ]
                                                  dataType:MPSDataTypeInt32];
    countTensor = [graph scatterWithUpdatesTensor:unitTensor
                                    indicesTensor:scannedIndicesWithHead
                                            shape:@[ [NSNumber numberWithUnsignedInteger:length] ]
                                             axis:0
                                             mode:MPSGraphScatterModeAdd
                                             name:nil];
  }

  return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor};
}

// Looks up the cached graph for this dtype / shape / flag / dim combination. On a cache miss the
// graph is created: an input placeholder matching `self` is registered and buildUniqueGraph
// supplies the output tensors that get stored on the cached graph.
static UniqueCachedGraph* getUniqueGraph(const Tensor& self,
                                         const bool return_inverse,
                                         const bool return_counts,
                                         const bool consecutive,
                                         std::optional<int64_t> dim) {
  @autoreleasepool {
    const std::string cacheKey =
        getUniqueKey(self.scalar_type(), self.sizes(), return_inverse, return_counts, consecutive, dim);

    // Invoked only when the key is not cached yet.
    auto populateGraph = [&](auto mpsGraph, auto newCachedGraph) {
      newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self), getMPSShape(self));

      // buildUniqueGraph returns {values, inverse indices, counts, max group index}.
      const auto [uniqueValues, inverseIndices, groupCounts, maxGroupIndex] =
          buildUniqueGraph(self, newCachedGraph, return_inverse, return_counts, consecutive, dim);
      newCachedGraph->outputTensor_ = uniqueValues;
      newCachedGraph->inverseIndicesTensor_ = inverseIndices;
      newCachedGraph->countsTensor_ = groupCounts;
      newCachedGraph->lengthTensor_ = maxGroupIndex;
    };
    return LookUpOrCreateCachedGraph<UniqueCachedGraph>(cacheKey, populateGraph);
  }
}

// Executes a cached unique graph on the current MPS stream. `input` is fed through the graph's
// placeholder; `output` and `length` are always bound as result targets, `inverse_indices` and
// `counts` only when the corresponding flag is set.
static void runUniqueGraph(UniqueCachedGraph* uniqueGraph,
                           const Tensor& input,
                           Tensor& output,
                           Tensor& inverse_indices,
                           Tensor& counts,
                           Tensor& length,
                           bool return_inverse,
                           bool return_counts) {
  // Feed dictionary: just the input placeholder.
  Placeholder inputBinding(uniqueGraph->inputTensor_, input);
  auto feeds = dictionaryFromPlaceholders(inputBinding);

  // Mandatory result targets.
  Placeholder uniqueBinding(uniqueGraph->outputTensor_, output);
  Placeholder maxIndexBinding(uniqueGraph->lengthTensor_, length);

  NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* targets = [NSMutableDictionary dictionary];
  [targets setObject:uniqueBinding.getMPSGraphTensorData() forKey:uniqueBinding.getMPSGraphTensor()];
  [targets setObject:maxIndexBinding.getMPSGraphTensorData() forKey:maxIndexBinding.getMPSGraphTensor()];

  // Optional result targets.
  if (return_inverse) {
    Placeholder inverseBinding(uniqueGraph->inverseIndicesTensor_, inverse_indices);
    [targets setObject:inverseBinding.getMPSGraphTensorData() forKey:inverseBinding.getMPSGraphTensor()];
  }
  if (return_counts) {
    Placeholder countsBinding(uniqueGraph->countsTensor_, counts);
    [targets setObject:countsBinding.getMPSGraphTensorData() forKey:countsBinding.getMPSGraphTensor()];
  }

  runMPSGraph(getCurrentMPSStream(), uniqueGraph->graph(), feeds, targets);
}

} // namespace mps

// Shared implementation behind the graph-based unique entry points.
//
// Allocates worst-case sized result tensors on MPS, runs the cached graph, reads back the largest
// group index and trims `output` (along dim) and `counts` to the number of unique entries.
// Returns (output, inverse_indices, counts); a tensor that was not requested is allocated with an
// empty shape. An input without elements returns the freshly allocated tensors without running
// the graph.
static std::tuple<Tensor, Tensor, Tensor> _unique_impl_mps(const Tensor& self,
                                                           const bool return_inverse,
                                                           const bool return_counts,
                                                           const bool consecutive,
                                                           std::optional<int64_t> dimOpt) {
  const Tensor& input = self.contiguous();
  const bool hasDim = dimOpt.has_value();

  // Flat element count; the IntArrayRefs below refer to this variable, so it has to stay alive
  // until the allocations are done.
  int64_t totalElems = c10::multiply_integers(input.sizes());
  int64_t dim = hasDim ? maybe_wrap_dim(dimOpt.value(), self.dim()) : 0;

  // Without a dim everything is computed on the flattened input.
  IntArrayRef outputShape(totalElems);
  IntArrayRef inverseIndicesShape = input.sizes();
  IntArrayRef countsShape(totalElems);
  if (hasDim) {
    // With a dim the output keeps the input shape, inverse indices and counts span that dim.
    const IntArrayRef dimExtent(input.sizes()[dim]);
    outputShape = input.sizes();
    inverseIndicesShape = dimExtent;
    countsShape = dimExtent;
  }
  if (!return_inverse) {
    inverseIndicesShape = IntArrayRef();
  }
  if (!return_counts) {
    countsShape = IntArrayRef();
  }

  const auto emptyOnMps = [](IntArrayRef shape, ScalarType dtype) {
    return at::empty(shape, dtype, std::nullopt, kMPS, std::nullopt, std::nullopt);
  };
  Tensor output = emptyOnMps(outputShape, input.scalar_type());
  Tensor inverse_indices = emptyOnMps(inverseIndicesShape, ScalarType::Long);
  Tensor counts = emptyOnMps(countsShape, ScalarType::Long);
  Tensor length = emptyOnMps({1}, ScalarType::Int);

  if (input.numel() == 0) {
    return std::make_tuple(std::move(output), std::move(inverse_indices), std::move(counts));
  }

  @autoreleasepool {
    mps::UniqueCachedGraph* cachedGraph =
        mps::getUniqueGraph(input, return_inverse, return_counts, consecutive, dimOpt);
    mps::runUniqueGraph(cachedGraph, input, output, inverse_indices, counts, length, return_inverse, return_counts);
  }

  // The graph reports the largest group index, so the number of unique entries is one more.
  const int64_t uniqueCount = length.item<int64_t>() + 1;
  if (output.dim() != 0) {
    output = at::slice(output, dim, 0, uniqueCount);
  }
  if (return_counts) {
    counts = at::slice(counts, 0, 0, uniqueCount);
  }

  return std::make_tuple(std::move(output), std::move(inverse_indices), std::move(counts));
}

// Converts each of the three tensors of `out` with .to("mps") and returns them as a new tuple.
static std::tuple<Tensor, Tensor, Tensor> castToMPS(std::tuple<Tensor, Tensor, Tensor> out) {
  const auto& [first, second, third] = out;
  return std::make_tuple(first.to("mps"), second.to("mps"), third.to("mps"));
}

// unique_consecutive on MPS: forwards to the shared implementation in consecutive mode, passing
// the optional dim through unchanged.
std::tuple<Tensor, Tensor, Tensor> unique_consecutive_mps(const Tensor& self,
                                                          const bool return_inverse,
                                                          const bool return_counts,
                                                          std::optional<int64_t> dim) {
  constexpr bool kConsecutive = true;
  return _unique_impl_mps(self, return_inverse, return_counts, kConsecutive, dim);
}

// unique_dim_consecutive on MPS: forwards to the shared implementation in consecutive mode with
// the mandatory dim wrapped into an optional.
std::tuple<Tensor, Tensor, Tensor> unique_dim_consecutive_mps(const Tensor& self,
                                                              int64_t dim,
                                                              const bool return_inverse,
                                                              const bool return_counts) {
  constexpr bool kConsecutive = true;
  return _unique_impl_mps(self, return_inverse, return_counts, kConsecutive, std::optional<int64_t>(dim));
}

// _unique2 on MPS: forwards to the shared implementation in non-consecutive mode and without a
// dim. The `sorted` argument is not consulted.
std::tuple<Tensor, Tensor, Tensor> _unique2_mps(const Tensor& self,
                                                const bool sorted,
                                                const bool return_inverse,
                                                const bool return_counts) {
  constexpr bool kConsecutive = false;
  const std::optional<int64_t> noDim = std::nullopt;
  return _unique_impl_mps(self, return_inverse, return_counts, kConsecutive, noDim);
}

// Returns an int64 permutation (created with mat_2d's options) that orders the rows of the 2-D
// tensor `mat_2d` lexicographically, column 0 being the most significant key.
//
// The permutation is built by sorting on one column at a time, from the last column to the
// first. This only yields a lexicographic order if every pass is a *stable* sort: rows that tie
// on the current column must keep the relative order established by the passes over the less
// significant columns. An unstable pass may shuffle such rows, which leaves equal rows
// non-adjacent and makes the caller report duplicates as distinct entries.
static Tensor lexsort_rows_perm_mps(const Tensor& mat_2d) {
  const auto rows = mat_2d.size(0), cols = mat_2d.size(1);

  // Start from the identity permutation; it is already the answer when there is nothing to order.
  auto perm = arange(rows, mat_2d.options().dtype(kLong));
  if (rows <= 1 || cols == 0) {
    return perm;
  }

  // `c` is signed (same type as cols), so the loop terminates once it drops below zero.
  for (auto c = cols - 1; c >= 0; --c) {
    // Keys of column c in the order produced so far.
    auto keys = mat_2d.select(1, c).index_select(0, perm);
    const auto idx = argsort(keys, /*stable=*/true, /*dim=*/0, /*descending=*/false);
    perm = perm.index_select(0, idx);
  }
  return perm;
}

// Sorted unique along `dim`, built from regular tensor ops.
//
// The slices along `dim` are treated as rows of a 2-D matrix, ordered lexicographically, and
// every row that differs from its predecessor starts a new group.
//
// Returns (output, inverse_indices, counts):
//   - output:          the unique slices, with `dim` shrunk to the number of groups.
//   - inverse_indices: for every original slice the index of its group; an empty int64 tensor
//                      unless return_inverse.
//   - counts:          number of slices per group; an empty int64 tensor unless return_counts.
//
// Raises (TORCH_CHECK) when the input has a zero-sized dimension other than `dim`.
static std::tuple<Tensor, Tensor, Tensor> unique_dim_sorted_mps_impl(const Tensor& self,
                                                                     int64_t dim,
                                                                     bool return_inverse,
                                                                     bool return_counts) {
  dim = maybe_wrap_dim(dim, self.dim());

  auto sizes = self.sizes().vec();
  // How many dimensions have extent 0.
  auto num_zero_dims = std::count(sizes.begin(), sizes.end(), (int64_t)0);
  if (self.size(dim) == 0) {
    // Nothing to deduplicate along dim. The input is only well formed if dim is the sole
    // zero-sized dimension.
    TORCH_CHECK(num_zero_dims == 1,
                "Number of zero sized dimensions is more than one, so unique cannot be applied");
    auto output = at::empty(sizes, self.options());
    auto inverse_indices = at::empty({0}, self.options().dtype(kLong));
    auto counts = at::empty({0}, self.options().dtype(kLong));
    return {output, inverse_indices, counts};
  }

  // A zero-sized dimension other than dim leaves rows without any element; reject it here with
  // a clear message instead of failing in the view({rows, -1}) below.
  TORCH_CHECK(num_zero_dims == 0,
              "There are 0 sized dimensions, and they aren't selected, so unique cannot be applied");

  // Bring dim to the front and flatten the remaining dimensions: one row per slice.
  auto transposed = self.moveaxis(dim, 0);
  auto orig_sizes = transposed.sizes().vec();
  auto rows = transposed.size(0);
  auto input_flat = transposed.contiguous().view({rows, -1});

  auto perm = lexsort_rows_perm_mps(input_flat);
  auto input_sorted = input_flat.index_select(0, perm);

  // is_unique[i] is true when sorted row i starts a new group. rows >= 1 is guaranteed by the
  // early return above, and the first row always starts a group.
  Tensor is_unique = at::zeros({rows}, self.options().dtype(kBool));
  is_unique.narrow(0, 0, 1).fill_(true);
  if (rows > 1) {
    auto a = input_sorted.narrow(0, 1, rows - 1);
    auto b = input_sorted.narrow(0, 0, rows - 1);
    // A row differs from its predecessor as soon as any of its elements differs.
    auto row_changed = a.ne(b).any(1);
    is_unique.narrow(0, 1, rows - 1).copy_(row_changed);
  }

  // Positions (in sorted order) of the first row of every group, and the group index of each row.
  auto unique_pos = nonzero(is_unique).squeeze(1);
  auto group_id = cumsum(is_unique.to(kLong), 0).sub(1);

  auto unique_rows_2d = input_sorted.index_select(0, unique_pos);

  Tensor inverse_indices = empty({0}, self.options().dtype(kLong));
  if (return_inverse) {
    // perm[i] is the original position of sorted row i, so its group index is written there.
    inverse_indices = empty({rows}, self.options().dtype(kLong));
    inverse_indices.index_copy_(0, perm, group_id);
  }

  Tensor counts = empty({0}, self.options().dtype(kLong));
  if (return_counts) {
    // Add 1 per row into the slot of its group.
    const auto num_unique = unique_pos.size(0);
    counts = zeros({num_unique}, self.options().dtype(kLong));
    counts.scatter_add_(0, group_id, ones_like(group_id, group_id.options().dtype(kLong)));
  }

  // Restore the trailing dimensions and move the deduplicated axis back to `dim`.
  orig_sizes[0] = unique_rows_2d.size(0);
  auto output = unique_rows_2d.view(orig_sizes).moveaxis(0, dim);

  return std::make_tuple(std::move(output), std::move(inverse_indices), std::move(counts));
}

// unique_dim on MPS: delegates to the sorted implementation. The `sorted` flag is ignored, so
// the result is sorted either way.
std::tuple<Tensor, Tensor, Tensor> unique_dim_mps(const Tensor& self,
                                                  int64_t dim,
                                                  const bool /*sorted*/,
                                                  const bool return_inverse,
                                                  const bool return_counts) {
  auto uniqueResult = unique_dim_sorted_mps_impl(self, dim, return_inverse, return_counts);
  return uniqueResult;
}

} // namespace at::native
