"""Nearest neighbour searches"""
from __future__ import annotations
from abc import ABC, abstractmethod
from sklearn.neighbors._unsupervised import NearestNeighbors
class NNSearch(ABC):
    """
    Abstract base class for nearest neighbour searches. Subclasses must
    implement __init__ and Index.

    A search object only holds parameter values. The data-dependent part
    lives in the nested `Index` class, which `construct` instantiates.
    """

    @abstractmethod
    def __init__(self, **kwargs):
        pass

    def construct(self, X) -> NNSearch.Index:
        """
        Construct the index based on the data X.

        Parameters
        ----------
        X : array shape=(n_instances, n_features, )
            Construction instances.

        Returns
        -------
        I : Index
            Constructed index
        """
        # `self.Index` is looked up on the concrete subclass, so every search
        # type builds its own flavour of index.
        return self.Index(self, X)

    class Index(ABC):
        """
        Abstract base class for the index object created by NNSearch.construct.
        Subclasses must implement __init__ and query.

        Parameters
        ----------
        search : NNSearch
            The search object that contains all the relevant parameter values.

        X : array shape=(n_instances, n_features, )
            Construction instances.
        """

        @abstractmethod
        def __init__(self, search: NNSearch, X):
            # Abstract, yet it carries the shared set-up: the construction
            # instances are kept for `query_self`, their count for `__len__`.
            self._X = X
            self._len = len(X)

        def query_self(self, k):
            """
            Identify the k nearest neighbours of each construction instance,
            skipping the first hit.

            The construction instances are queried against the index for
            k + 1 neighbours, and the first column of every returned array is
            dropped. This assumes that the first hit for each construction
            instance is that instance itself.

            Parameters
            ----------
            k : int
                Number of neighbours to return

            Returns
            -------
            list of arrays
                The arrays returned by `query`, in the same order, each with
                its first column removed.
            """
            neighbours = self.query(self._X, k + 1)
            return [array[:, 1:] for array in neighbours]

        @abstractmethod
        def query(self, X, k: int):
            """
            Identify the k nearest neighbours for each of the instances in X.

            Parameters
            ----------
            X : array shape=(n_instances, n_features, )
                Query instances.

            k : int
                Number of neighbours to return

            Returns
            -------
            I : array shape=(n_instances, k, )
                Indices of the k nearest neighbours among the construction
                instances for each query instance.

            D : array shape=(n_instances, k, )
                Distances to the k nearest neighbours among the construction
                instances for each query instance.
            """
            pass

        def __len__(self):
            """Return the number of construction instances."""
            return self._len
class BallTree(NNSearch):
    """
    Nearest neighbour search with a Ball tree.

    Parameters
    ----------
    metric : str, default='manhattan'
        The metric through which distances are defined.

    leaf_size : int, default=30
        The leaf size to be used for the Ball tree.

    n_jobs : int, default=1
        The number of parallel jobs to run for neighbour search. -1 means using
        all processors.
    """

    def __init__(self, *, metric: str = 'manhattan', leaf_size: int = 30,
                 n_jobs: int = 1):
        # Stored as the keyword arguments that Index passes on to
        # NearestNeighbors.
        self.construction_params = {
            'algorithm': 'ball_tree',
            'metric': metric,
            'leaf_size': leaf_size,
            'n_jobs': n_jobs,
        }

    class Index(NNSearch.Index):
        """
        Index that fits a NearestNeighbors estimator on the construction
        instances, using the parameters stored on the search object.
        """

        def __init__(self, search: BallTree, X):
            super().__init__(search, X)
            estimator = NearestNeighbors(**search.construction_params)
            self.tree = estimator.fit(X)

        def query(self, X, k: int):
            """
            Identify the k nearest neighbours for each of the instances in X.

            Parameters
            ----------
            X : array shape=(n_instances, n_features, )
                Query instances.

            k : int
                Number of neighbours to return

            Returns
            -------
            I : array shape=(n_instances, k, )
                Indices of the k nearest neighbours among the construction
                instances for each query instance.

            D : array shape=(n_instances, k, )
                Distances to the k nearest neighbours among the construction
                instances for each query instance.
            """
            # kneighbors returns (distances, indices); swap to (I, D).
            distances, indices = self.tree.kneighbors(X, n_neighbors=k)
            return indices, distances
class KDTree(NNSearch):
    """
    Nearest neighbour search with a KD-tree.

    Parameters
    ----------
    metric : str, default='manhattan'
        The metric through which distances are defined.

    leaf_size : int, default=30
        The leaf size to be used for the KD-tree.

    n_jobs : int, default=1
        The number of parallel jobs to run for neighbour search. -1 means using
        all processors.
    """

    def __init__(self, *, metric: str = 'manhattan', leaf_size: int = 30,
                 n_jobs: int = 1):
        # Stored as the keyword arguments that Index passes on to
        # NearestNeighbors.
        self.construction_params = {
            'algorithm': 'kd_tree',
            'metric': metric,
            'leaf_size': leaf_size,
            'n_jobs': n_jobs,
        }

    class Index(NNSearch.Index):
        """
        Index that fits a NearestNeighbors estimator on the construction
        instances, using the parameters stored on the search object.
        """

        def __init__(self, search: KDTree, X):
            super().__init__(search, X)
            estimator = NearestNeighbors(**search.construction_params)
            self.tree = estimator.fit(X)

        def query(self, X, k: int):
            """
            Identify the k nearest neighbours for each of the instances in X.

            Parameters
            ----------
            X : array shape=(n_instances, n_features, )
                Query instances.

            k : int
                Number of neighbours to return

            Returns
            -------
            I : array shape=(n_instances, k, )
                Indices of the k nearest neighbours among the construction
                instances for each query instance.

            D : array shape=(n_instances, k, )
                Distances to the k nearest neighbours among the construction
                instances for each query instance.
            """
            # kneighbors returns (distances, indices); swap to (I, D).
            distances, indices = self.tree.kneighbors(X, n_neighbors=k)
            return indices, distances