// Copyright @yucwang 2022

#include "cpfinder.h"

#include "cukd/builder.h"
#include "cukd/fcp.h"

namespace cukd {

// Point record used throughout this file: a float3 position plus the
// bookkeeping fields of FCPPoint (e.g. originalId) for single-precision data.
using FCPPointF3 = FCPPoint<float3, float>;

// Converts a planar (structure-of-arrays) float buffer into FCPPointF3 records.
//
// Expected layout of `p` (3 * num floats): all x coordinates first, then all
// y coordinates, then all z coordinates, i.e. point i is
// (p[i], p[num + i], p[2 * num + i]). Each output record also stores its
// input index in `originalId`, so an index into the (possibly reordered)
// output array can later be mapped back to the caller's numbering.
//
// Launch with a 1D grid of at least `num` threads; surplus threads exit early.
// Both `p` and `dst` are dereferenced on the device, so they must be
// device-accessible.
__global__ void _floatToFCPointF3(const float* p, FCPPointF3* dst, int num) {
    const int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid >= num) return;

    // Offset the base pointers rather than computing `tid + 2 * num` in int:
    // pointer arithmetic does not overflow once 3 * num exceeds INT_MAX.
    const float* xs = p;
    const float* ys = xs + num;
    const float* zs = ys + num;

    dst[tid].p.x = xs[tid];
    dst[tid].p.y = ys[tid];
    dst[tid].p.z = zs[tid];
    dst[tid].originalId = tid;
}

// Host launcher for _floatToFCPointF3: converts `num` planar xyz points from
// `p` (x block, then y block, then z block) into `dst`.
//
// The launch is asynchronous on the default stream; the caller is
// responsible for synchronizing and checking for errors. `p` and `dst` must
// be device-accessible.
void floatToFCPointF3(const float* p, FCPPointF3* dst, int num) {
    // A grid with zero blocks is an invalid launch configuration, and an
    // empty input has nothing to convert, so skip the launch entirely.
    if (num <= 0) return;

    const int blockSize = 128;
    const int numBlocks = cukd::common::divRoundUp(num, blockSize);
    _floatToFCPointF3<<<numBlocks, blockSize>>>(p, dst, num);
}

// One thread per query: asks cukd::fcp for the closest node in d_nodes to
// d_queries[tid] and stores that node's originalId in d_results[tid].
//
// d_nodes is presumably the array previously passed to cukd::buildTree
// (originalId maps a node back to the caller's point numbering). Launch with
// a 1D grid of at least numQueries threads; surplus threads exit early.
//
// NOTE(review): cukd::fcp presumably returns a negative index when it finds
// no node (empty tree, or no distance ever compares smaller, e.g. NaN query
// coordinates) -- TODO confirm against cukd/fcp.h. Such results, and any
// index outside [0, numNodes), produce -1 instead of an out-of-bounds read.
__global__ void d_fcp(int *d_results,
                    FCPPointF3 *d_queries,
                    int numQueries,
                    FCPPointF3 *d_nodes,
                    int numNodes)
{
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid >= numQueries) return;

  const int curId = cukd::fcp<float3>(d_queries[tid], d_nodes, numNodes);
  const bool found = (curId >= 0) && (curId < numNodes);
  d_results[tid] = found ? d_nodes[curId].originalId : -1;
}

// Host launcher for d_fcp: one thread per query, 128 threads per block.
//
// The launch is asynchronous on the default stream; the caller must
// synchronize before reading d_results and is responsible for error checks.
// All pointers must be device-accessible.
void _fcp(int *d_results,
         FCPPointF3 *d_queries,
         int numQueries,
         FCPPointF3 *d_nodes,
         int numNodes)
{
  // No queries means no work; launching a zero-block grid would instead be
  // an invalid launch configuration error.
  if (numQueries <= 0) return;

  const int blockSize = 128;
  const int numBlocks = cukd::common::divRoundUp(numQueries, blockSize);
  d_fcp<<<numBlocks, blockSize>>>(d_results, d_queries, numQueries, d_nodes, numNodes);
}

// For every query point, finds the closest reference point.
//
// refPoint holds refNumber points and queries holds queriesNumber points,
// both in planar layout (all x, then all y, then all z). results receives
// queriesNumber ints: for each query, the index (into refPoint's numbering)
// of the nearest reference point. All three buffers are accessed from
// kernels, so they must be device-accessible.
//
// Work happens on temporary managed copies: the reference copy is turned
// into a KD-tree in place, then queried. Each stage is followed by
// CUKD_CUDA_SYNC_CHECK(). If anything inside throws (a failing check macro,
// or the tree build), both temporary buffers are freed before the exception
// propagates.
void FindClosestPoint(const float* refPoint, int refNumber, const float* queries,
                         int queriesNumber, int* results) {
    // Nothing to answer. This also avoids cudaMallocManaged(0), which the
    // runtime rejects as an invalid value.
    if (queriesNumber <= 0) return;

    FCPPointF3* refPointCopy = 0;
    FCPPointF3* queryCopy = 0;
    try {
        CUKD_CUDA_CALL(MallocManaged((void**)&refPointCopy, refNumber * sizeof(FCPPointF3)));
        CUKD_CUDA_CALL(MallocManaged((void**)&queryCopy, queriesNumber * sizeof(FCPPointF3)));

        floatToFCPointF3(refPoint, refPointCopy, refNumber);
        floatToFCPointF3(queries, queryCopy, queriesNumber);
        CUKD_CUDA_SYNC_CHECK();

        // Build the KD-tree in place over the reference copy.
        cukd::buildTree<float3, float, 3, FCPPointInterface<float3, float>>(refPointCopy, refNumber);
        CUKD_CUDA_SYNC_CHECK();

        _fcp(results, queryCopy, queriesNumber, refPointCopy, refNumber);
        CUKD_CUDA_SYNC_CHECK();
    } catch (...) {
        // Best-effort cleanup so a failure does not leak the managed buffers.
        // cudaFree(nullptr) is a no-op, so a buffer that was never allocated
        // is harmless here. Errors from these frees are deliberately ignored
        // in favor of the original exception, which is rethrown.
        cudaFree(refPointCopy);
        cudaFree(queryCopy);
        throw;
    }

    CUKD_CUDA_CALL(Free((void*) refPointCopy));
    CUKD_CUDA_CALL(Free((void*) queryCopy));
}

} // namespace cukd