球坐标系角距离(angular distance)下的KD树最近点查询算法

代码地址:https://github.com/mhy12345/angular-kdtree

给定一个坐标集合S,对于若干询问坐标p,分别求出该坐标在坐标集合中最邻近(K临近)的坐标。

KNN算法

问题概述

在K临近算法是一个经典的问题,通常在平面欧几里得坐标系下,我们使用KDTree这个算法来解决,时间复杂度为单次查询O(\sqrt{N})。现在,我们需要将这个KD树算法扩展到球坐标系下,来解决球坐标系中的最近点查询。具体来讲,定义问题为——

通过经度纬度定义球坐标系坐标点,并给定角距离公式 angular_separation 如下,这里的距离即天文学中的天体视角

double angular_separation(double lon1, double lat1, double lon2, double lat2) {
    double sdlon = sin(lon2 - lon1);
    double cdlon = cos(lon2 - lon1);
    double slat1 = sin(lat1);
    double slat2 = sin(lat2);
    double clat1 = cos(lat1);
    double clat2 = cos(lat2);

    double num1 = clat2 * sdlon;
    double num2 = clat1 * slat2 - slat1 * clat2 * cdlon;
    double denominator = slat1 * slat2 + clat1 * clat2 * cdlon;
    return atan2(hypot(num1, num2), denominator);
}

我们对于每个询问坐标q,查询到与q距离最近的一个集合S中的坐标。

实现思路

与普通KD树相似,使用横纵轴轮流分割的方式建树。即,第一层将所有点按照维度大小对半分,第二层在第一层的基础上按照经度对半分。这样,KD树每一个节点都维护一个球坐标bounding-box内的坐标集合。对于每一次询问,只需要递归对每个子树遍历查询,并通过计算坐标与bounding-box的最短角距离来进行遍历剪枝。

使用样例

C++语言

#include "kdtree.h"
#include <iostream>
using namespace std;

int main() {
    KDTree kdt = KDTree(); // 新建一个KDTree实例
    Point point_to_insert = Point(
        random()%10000/10000.0*PI*2,  // x, 维度坐标,范围(0, 2pi)
        random()%10000/10000.0*PI - PI/2,  // y, 精度坐标,范围(-pi/2, pi/2)
        NULL // tag, 一个类型为const char*的标签,可用于储存必要的信息
        );
    kdt.Insert(point_to_insert);
    Point point_to_query = Point(0,0,NULL);
    pair<double,const Point*> kdt_res = kdt.Search(
        point_to_query,  // pt, 待查询的坐标, tag可为空,
        0.1 // r, 搜索距离,即查询角距离小于该值的区域
        );
    if (kdt_res.second) { // NULL表示没有找到,反之储存结果
        cout<<"Distance : "<<kdt_res.first<<endl;
        cout<<Point : "<<kdt_res.second->x<<" "<<kdt_res.second.y<<endl;
    }
}

Python语言

import ctypes
from ctypes import *

so = cdll.LoadLibrary   
lib = so("/path/to/kdt_toolbox.so")
lib.KDTBatchInsert.argtypes = [c_int, POINTER(c_double), POINTER(c_double), POINTER(c_char_p)]
lib.KDTInsert.argtypes = [c_double, c_double, c_char_p]
lib.KDTSearch.argtypes = [c_double,c_double]
lib.KDTSearch.restype = c_char_p

xs = np.array(... ,dtype=np.double)
ys = np.array(... ,dtype=np.double)
sid = list(map(lambda x:x.encode('ascii'), ...))

char_p_arr = ctypes.c_char_p * len(data)
sid_p = char_p_arr(*sid)

xs_p = xs.ctypes.data_as(POINTER(c_double))
ys_p = ys.ctypes.data_as(POINTER(c_double))

lib.KDTInit()
args = c_int(len(data)),xs_p, ys_p, sid_p
lib.KDTBatchInsert(*args)

args = c_double(xs[0]),c_double(ys[0])
res = lib.KDTSearch(*args)
assert(res.decode('ascii') == str(sid[0]))

如果你觉得有用,记得点star哦~

发表评论

电子邮件地址不会被公开。 必填项已用*标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据