// Copyright @yucwang 2022

#include "cukd/knnfinder.h"

#include <stdexcept>

#include "cukd/builder.h"
#include "cukd/knn.h"

namespace cukd {

// Shorthand for the 3D float point record used by the kernels in this file
// (fields used here: the position `p` and `originalId`).
using FCPPointF3 = FCPPoint<float3, float>;

/*! Converts `num` points from planar (structure-of-arrays) layout into
    FCPPointF3 records, tagging each record with its input index.

    Input layout: p[0..num) = x, p[num..2*num) = y, p[2*num..3*num) = z.
    Launch: 1D grid with at least `num` threads in total; thread i handles
    point i, and surplus threads (or any thread when num <= 0) exit at once.

    Indices are computed in size_t: with int arithmetic, `2*num` overflows
    for num > 2^30, and the flat thread index can wrap negative in the last
    block when num is close to INT_MAX. */
__global__ void _floatToFCPointF3(const float* p, FCPPointF3* dst, int num) {
    const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (num <= 0 || i >= static_cast<size_t>(num)) return;

    const size_t n = static_cast<size_t>(num);
    FCPPointF3 &out = dst[i];
    out.p.x = p[i];
    out.p.y = p[i + n];
    out.p.z = p[i + 2 * n];
    // i < num <= INT_MAX, so the narrowing back to int is lossless.
    out.originalId = static_cast<int>(i);
}

/*! Host-side launcher for _floatToFCPointF3: converts `num` points stored in
    planar x/y/z layout in `p` into FCPPointF3 records in `dst`.

    Both pointers are handed straight to the kernel, so they must be
    device-accessible. The launch goes to the default stream and is not
    synchronized here; callers must synchronize before reading `dst` on the
    host. Launch-configuration errors are reported immediately. */
void floatToFCPointF3(const float* p, FCPPointF3* dst, int num) {
    // Nothing to convert; launching a 0-block grid would also be an
    // invalid configuration error.
    if (num <= 0) return;

    const int bs = 128;
    const int nb = cukd::common::divRoundUp(num, bs);
    _floatToFCPointF3<<<nb, bs>>>(p, dst, num);
    // Kernel launches return no status; fetch launch errors explicitly.
    CUKD_CUDA_CALL(GetLastError());
}

/*! k-nearest-neighbor kernel, one thread per query.

    Thread q searches the tree in `d_nodes` for neighbors of d_queries[q]
    (cut off at `maxRadius`) and writes the `originalId` of each candidate to
    d_results[i*numQueries + q] for candidate slot i in [0, k). The output is
    therefore k rows of numQueries ints.

    A slot that was never filled (fewer than k points within `maxRadius`,
    e.g. numNodes < k) is written as -1. The raw sentinel id must not be used
    to index d_nodes.

    Launch: 1D grid with at least numQueries threads; surplus threads exit. */
template<const int k>
__global__ void d_knn(int *d_results,
                    FCPPointF3 *d_queries,
                    int numQueries,
                    FCPPointF3 *d_nodes,
                    int numNodes)
{
  // size_t so the flat index cannot wrap negative near INT_MAX.
  const size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (numQueries <= 0 || tid >= static_cast<size_t>(numQueries)) return;

  // Search cut-off: points farther than this are never reported.
  const float maxRadius = 65536.0f;
  cukd::FixedCandidateList<k> results(maxRadius);
  cukd::knn(results, d_queries[tid], d_nodes, numNodes);

#pragma unroll
  for (int i = 0; i < k; ++i) {
    const int nodeId = results.decode_pointID(results.entry[i]);
    // Unfilled slots keep the candidate list's "no point" id (presumably -1
    // in cukd::FixedCandidateList -- TODO confirm); only dereference ids
    // that actually name a node.
    const bool found = nodeId >= 0 && nodeId < numNodes;
    d_results[static_cast<size_t>(i) * numQueries + tid] =
        found ? d_nodes[nodeId].originalId : -1;
  }
}

/*! Launches d_knn<K> with one thread per query (blocks of 128 threads) on
    the default stream. Does not synchronize. */
template<int K>
static void launchKnnKernel(int *d_results,
                            FCPPointF3 *d_queries,
                            int numQueries,
                            FCPPointF3 *d_nodes,
                            int numNodes)
{
  const int bs = 128;
  const int nb = cukd::common::divRoundUp(numQueries, bs);
  d_knn<K><<<nb, bs>>>(d_results, d_queries, numQueries, d_nodes, numNodes);
}

/*! Runs a k-nearest-neighbor query for every point of `d_queries` against the
    tree stored in `d_nodes`. It maps the runtime `k` onto the matching
    compile-time instantiation of d_knn.

    \param d_results device-accessible output with room for k*numQueries ints,
                     laid out as k rows of numQueries entries.
    \param k         neighbors per query; supported values are 1..8.
    \throws std::invalid_argument if k is not in 1..8. Any other k would leave
            d_results partially unwritten (k > 8) or overrun it (k < 1).

    The kernel runs asynchronously; only launch errors are checked here. */
void _knn(int *d_results,
         FCPPointF3 *d_queries,
         int numQueries,
         FCPPointF3 *d_nodes,
         int numNodes,
         int k)
{
  // No queries: nothing to do, and a 0-block launch would be invalid.
  if (numQueries <= 0) return;

  switch (k)
  {
  case 1: launchKnnKernel<1>(d_results, d_queries, numQueries, d_nodes, numNodes); break;
  case 2: launchKnnKernel<2>(d_results, d_queries, numQueries, d_nodes, numNodes); break;
  case 3: launchKnnKernel<3>(d_results, d_queries, numQueries, d_nodes, numNodes); break;
  case 4: launchKnnKernel<4>(d_results, d_queries, numQueries, d_nodes, numNodes); break;
  case 5: launchKnnKernel<5>(d_results, d_queries, numQueries, d_nodes, numNodes); break;
  case 6: launchKnnKernel<6>(d_results, d_queries, numQueries, d_nodes, numNodes); break;
  case 7: launchKnnKernel<7>(d_results, d_queries, numQueries, d_nodes, numNodes); break;
  case 8: launchKnnKernel<8>(d_results, d_queries, numQueries, d_nodes, numNodes); break;
  default:
    throw std::invalid_argument("cukd::_knn: unsupported k (supported: 1..8)");
  }
  // Kernel launches return no status; fetch launch errors explicitly.
  CUKD_CUDA_CALL(GetLastError());
}

/*! Finds, for every query point, its k nearest points in a reference set.

    Both point sets use planar (structure-of-arrays) layout: refPoint holds
    refNumber x values, then refNumber y values, then refNumber z values.
    `queries` is laid out the same way with queriesNumber points. The
    pointers are read directly by a kernel, so they must be device-accessible
    (device or managed memory).

    \param results device-accessible output of k*queriesNumber ints. Entry
                   i*queriesNumber + q receives the original index (position
                   in refPoint) of the point in neighbor slot i for query q.
    \param k       neighbors per query; must be a value supported by _knn.
    \throws std::invalid_argument if refNumber <= 0 while there are queries.

    Temporary managed buffers are released on every exit path, including
    when one of the CUKD_CUDA_* checks or _knn throws. The function
    synchronizes the device before returning. */
void FindKNN(const float* refPoint, int refNumber, const float* queries,
                         int queriesNumber, int* results, int k) {
    // Nothing to compute. Also avoids cudaMallocManaged(0), which fails
    // with cudaErrorInvalidValue.
    if (queriesNumber <= 0) return;
    if (refNumber <= 0)
        throw std::invalid_argument("cukd::FindKNN: refNumber must be positive");

    FCPPointF3* refPointCopy = 0;
    FCPPointF3* refQueryCopy = 0;
    try {
        CUKD_CUDA_CALL(MallocManaged((void**)&refPointCopy, refNumber * sizeof(FCPPointF3)));
        CUKD_CUDA_CALL(MallocManaged((void**)&refQueryCopy, queriesNumber * sizeof(FCPPointF3)));

        // Each record keeps its input index in originalId, so results can
        // refer to the caller's numbering even if the tree build reorders
        // refPointCopy (presumably it does; see d_knn's id mapping).
        floatToFCPointF3(refPoint, refPointCopy, refNumber);
        floatToFCPointF3(queries, refQueryCopy, queriesNumber);
        CUKD_CUDA_SYNC_CHECK();

        // build KD Tree
        cukd::buildTree<float3, float, 3, FCPPointInterface<float3, float>>(refPointCopy, refNumber);
        CUKD_CUDA_SYNC_CHECK();

        _knn(results, refQueryCopy, queriesNumber, refPointCopy, refNumber, k);
        CUKD_CUDA_SYNC_CHECK();
    } catch (...) {
        // Best-effort cleanup (cudaFree(0) is a no-op); the original error
        // is what propagates to the caller.
        (void)cudaFree(refPointCopy);
        (void)cudaFree(refQueryCopy);
        throw;
    }

    CUKD_CUDA_CALL(Free((void*) refPointCopy));
    CUKD_CUDA_CALL(Free((void*) refQueryCopy));
}

} // namespace cukd