// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project

#include <gtest/gtest.h>

#include <Kokkos_Macros.hpp>
#ifdef KOKKOS_ENABLE_EXPERIMENTAL_CXX20_MODULES
import kokkos.core;
import kokkos.core_impl;
#else
#include <Kokkos_Core.hpp>
#endif

#include <regex>

namespace {

template <class IndexType>
void construct_mdrange_policy_variable_type() {
  (void)Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>>{
      Kokkos::Array<IndexType, 2>{}, Kokkos::Array<IndexType, 2>{}};

  (void)Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>>{
      {{IndexType(0), IndexType(0)}}, {{IndexType(2), IndexType(2)}}};

  (void)Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>>{
      {IndexType(0), IndexType(0)}, {IndexType(2), IndexType(2)}};
}

TEST(TEST_CATEGORY, md_range_policy_construction_from_arrays) {
  {
    // Check that construction from Kokkos::Array of the specified index type
    // works: without an execution space instance (p1), with an explicit
    // instance (p2), and with an explicit tile-size array (p3).
    using IndexType = unsigned long long;
    Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>,
                          Kokkos::IndexType<IndexType>>
        p1(Kokkos::Array<IndexType, 2>{{0, 1}},
           Kokkos::Array<IndexType, 2>{{2, 3}});
    // Previously an exact duplicate of p1; exercise the overload taking an
    // execution space instance instead so this line adds coverage.
    Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>,
                          Kokkos::IndexType<IndexType>>
        p2(TEST_EXECSPACE(), Kokkos::Array<IndexType, 2>{{0, 1}},
           Kokkos::Array<IndexType, 2>{{2, 3}});
    Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>,
                          Kokkos::IndexType<IndexType>>
        p3(Kokkos::Array<IndexType, 2>{{0, 1}},
           Kokkos::Array<IndexType, 2>{{2, 3}},
           Kokkos::Array<IndexType, 1>{{4}});
  }
  {
    // Check that construction from double-braced initializer list
    // works.
    using index_type = unsigned long long;
    Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>> p1({{0, 1}},
                                                              {{2, 3}});
    Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>,
                          Kokkos::IndexType<index_type>>
        p2({{0, 1}}, {{2, 3}});
  }
  {
    // Check that construction from Kokkos::Array of long compiles for backwards
    // compatibility, with and without an execution space instance. This was
    // broken in
    // https://github.com/kokkos/kokkos/pull/3527/commits/88ea8eec6567c84739d77bdd25fdbc647fae28bb#r512323639
    Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>> p1(
        Kokkos::Array<long, 2>{{0, 1}}, Kokkos::Array<long, 2>{{2, 3}});
    Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>> p2(
        TEST_EXECSPACE(), Kokkos::Array<long, 2>{{0, 1}},
        Kokkos::Array<long, 2>{{2, 3}});
    Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>> p3(
        Kokkos::Array<long, 2>{{0, 1}}, Kokkos::Array<long, 2>{{2, 3}},
        Kokkos::Array<long, 1>{{4}});
  }

  // Check that construction from various index types works.
  construct_mdrange_policy_variable_type<char>();
  construct_mdrange_policy_variable_type<int>();
  construct_mdrange_policy_variable_type<unsigned long>();
  construct_mdrange_policy_variable_type<std::int64_t>();
}

#ifndef KOKKOS_ENABLE_OPENMPTARGET  // FIXME_OPENMPTARGET
TEST(TEST_CATEGORY_DEATH, policy_bounds_unsafe_narrowing_conversions) {
  // The index type is unsigned, so a lower bound of -1 cannot be represented;
  // constructing the policy is expected to terminate with the message below.
  using Policy = Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>,
                                       Kokkos::IndexType<unsigned>>;

  ::testing::FLAGS_gtest_death_test_style = "threadsafe";

  const std::string message =
      "Kokkos::MDRangePolicy bound type error: an unsafe implicit conversion "
      "is performed on a bound (-1) in dimension (0), which may not preserve "
      "its original value.\n";

  // ASSERT_DEATH matches its second argument as a regex, so every literal
  // parenthesis in the message has to be escaped.
  const std::regex parenthesis("\\(|\\)");
  const std::string pattern = std::regex_replace(message, parenthesis, "\\$&");

  ASSERT_DEATH({ (void)Policy({-1, 0}, {2, 3}); }, pattern);
}

TEST(TEST_CATEGORY_DEATH, policy_invalid_bounds) {
  // Every lower bound (100) exceeds its upper bound (90). Without
  // KOKKOS_ENABLE_DEPRECATED_CODE_4 construction must terminate; with it, the
  // diagnostics are checked on stderr instead.
  using Policy = Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>>;

  ::testing::FLAGS_gtest_death_test_style = "threadsafe";

  // The order in which dimensions are reported follows the policy's inner
  // iteration direction: dimension 1 first for Iterate::Right, else 0 first.
  auto [dim0, dim1] = (Policy::inner_direction == Kokkos::Iterate::Right)
                          ? std::make_pair(1, 0)
                          : std::make_pair(0, 1);
  std::string msg1 =
      "Kokkos::MDRangePolicy bounds error: The lower bound (100) is greater "
      "than its upper bound (90) in dimension " +
      std::to_string(dim0) + ".\n";

  std::string msg2 =
      "Kokkos::MDRangePolicy bounds error: The lower bound (100) is greater "
      "than its upper bound (90) in dimension " +
      std::to_string(dim1) + ".\n";

#if !defined(KOKKOS_ENABLE_DEPRECATED_CODE_4)
  // ASSERT_DEATH matches its message as a regex; escape the parentheses.
  msg1 = std::regex_replace(msg1, std::regex("\\(|\\)"), "\\$&");
  // Only the first reported violation is matched (presumably the process
  // aborts on it), so msg2 is unused in this configuration.
  (void)msg2;
  ASSERT_DEATH({ (void)Policy({100, 100}, {90, 90}); }, msg1);
#else
  if (!Kokkos::show_warnings()) {
    GTEST_SKIP() << "Kokkos warning messages are disabled";
  }

  ::testing::internal::CaptureStderr();
  (void)Policy({100, 100}, {90, 90});
#ifdef KOKKOS_ENABLE_DEPRECATION_WARNINGS
  // Both violations are reported, in iteration order.
  ASSERT_EQ(::testing::internal::GetCapturedStderr(), msg1 + msg2);
#else
  // With deprecation warnings disabled nothing may be printed.
  ASSERT_TRUE(::testing::internal::GetCapturedStderr().empty());
  (void)msg1;
  (void)msg2;
#endif

#endif
}
#endif

TEST(TEST_CATEGORY, policy_get_tile_size) {
  constexpr int rank = 3;
  using Policy    = Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<rank>>;
  using tile_type = typename Policy::tile_type;

  // Extent used for every dimension of the iteration range below.
  constexpr int dim_length = 100;

  // The dimension iterated innermost according to Policy::inner_direction;
  // its recommended tile size is special-cased below.
  const std::size_t last_rank =
      (Policy::inner_direction == Kokkos::Iterate::Right) ? rank - 1 : 0;

  const auto default_size_properties =
      Kokkos::Impl::get_tile_size_properties(TEST_EXECSPACE());

  {
    // Without user-supplied tiles, the stored tiles equal the recommended ones.
    Policy policy_default({0, 0, 0}, {dim_length, dim_length, dim_length});

    auto rec_tile_sizes      = policy_default.tile_size_recommended();
    auto internal_tile_sizes = policy_default.m_tile;

    for (std::size_t i = 0; i < rank; ++i) {
      EXPECT_EQ(rec_tile_sizes[i], internal_tile_sizes[i])
          << " incorrect recommended tile size returned for rank " << i;
    }
  }
  {
    // User-supplied tiles must not change what tile_size_recommended() reports.
    Policy policy({0, 0, 0}, {dim_length, dim_length, dim_length},
                  tile_type{{2, 4, 16}});

    auto rec_tile_sizes = policy.tile_size_recommended();

    EXPECT_EQ(default_size_properties.max_total_tile_size,
              policy.max_total_tile_size());

    // If the backend reports no largest tile size (0), the innermost dimension
    // is expected to be tiled by its full extent.
    const auto innermost_expected =
        (default_size_properties.default_largest_tile_size == 0)
            ? dim_length
            : default_size_properties.default_largest_tile_size;

    int prod_rec_tile_size = 1;
    for (std::size_t i = 0; i < rank; ++i) {
      EXPECT_GT(rec_tile_sizes[i], 0)
          << " invalid default tile size for rank " << i;

      const auto expected_rec_tile_size =
          (i == last_rank) ? innermost_expected
                           : default_size_properties.default_tile_size;
      EXPECT_EQ(expected_rec_tile_size, rec_tile_sizes[i])
          << " incorrect recommended tile size returned for rank " << i;

      prod_rec_tile_size *= rec_tile_sizes[i];
    }
    EXPECT_LT(prod_rec_tile_size, policy.max_total_tile_size());
  }
}

// A policy constructed from bounds alone reports a default-constructed
// TEST_EXECSPACE instance as its execution space.
TEST(TEST_CATEGORY, md_range_policy_default_space) {
  using policy_t = Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>>;

  const policy_t policy_without_instance({42, 47}, {666, 999});
  const TEST_EXECSPACE expected_space{};

  ASSERT_EQ(policy_without_instance.space(), expected_space);
}

// Kokkos::Impl::PolicyUpdate rebinds an existing policy to a different
// execution space instance while carrying its bounds over unchanged.
TEST(TEST_CATEGORY, md_range_policy_impl_set_space) {
  using policy_t = Kokkos::MDRangePolicy<TEST_EXECSPACE, Kokkos::Rank<2>>;
  using point_t  = typename policy_t::point_type;

  // Two instances of equal weight carved out of the default instance.
  const auto [first_space, second_space] =
      Kokkos::Experimental::partition_space(TEST_EXECSPACE{}, 1, 1);

  const policy_t original(first_space, {42, 47}, {666, 999});
  ASSERT_EQ(original.space(), first_space);

  const policy_t rebound(Kokkos::Impl::PolicyUpdate{}, original, second_space);
  ASSERT_EQ(rebound.space(), second_space);

  const point_t expected_lower{42, 47};
  const point_t expected_upper{666, 999};
  ASSERT_EQ(rebound.m_lower, expected_lower);
  ASSERT_EQ(rebound.m_upper, expected_upper);
}

}  // namespace
