Source code for frlearn.neighbours.neighbour_search

"""Nearest neighbour searches"""
from __future__ import annotations

from abc import ABC, abstractmethod

from sklearn.neighbors._unsupervised import NearestNeighbors


class NNSearch(ABC):
    """
    Abstract base class for nearest neighbour searches.
    Subclasses must implement __init__ and Index.
    """

    @abstractmethod
    def __init__(self, **kwargs):
        pass

    def construct(self, X) -> NNSearch.Index:
        """
        Construct the index based on the data X.

        Parameters
        ----------
        X : array shape=(n_instances, n_features, )
            Construction instances.

        Returns
        -------
        I : Index
            Constructed index
        """
        # `self.Index` is looked up on the instance, so a subclass that
        # defines its own nested `Index` class gets that class instantiated.
        return self.Index(self, X)

    class Index(ABC):
        """
        Abstract base class for the index object created by NNSearch.construct.
        Subclasses must implement __init__ and query.

        Parameters
        ----------
        search : NNSearch
            The search object that contains all the relevant parameter values.

        X : array shape=(n_instances, n_features, )
            Construction instances.
        """

        @abstractmethod
        def __init__(self, search: NNSearch, X):
            # Keep the construction instances so that `query_self` can reuse
            # them, and cache their number for `__len__`.
            self._X = X
            self._len = len(X)

        def query_self(self, k):
            """
            Query the index with its own construction instances.

            Requests `k + 1` neighbours for every construction instance and
            drops the first column of each array returned by `query`.

            NOTE(review): this presumably relies on every instance being
            returned as its own first neighbour; with duplicate construction
            instances that may not hold -- confirm.

            Parameters
            ----------
            k : int
                Number of neighbours to return, not counting the dropped
                first column.

            Returns
            -------
            list of arrays
                The arrays returned by `query`, in the same order, each
                without its first column.
            """
            return [a[:, 1:] for a in self.query(self._X, k + 1)]

        @abstractmethod
        def query(self, X, k: int):
            """
            Identify the k nearest neighbours for each of the instances in X.

            Parameters
            ----------
            X : array shape=(n_instances, n_features, )
                Query instances.
            k : int
                Number of neighbours to return

            Returns
            -------
            I : array shape=(n_instances, k, )
                Indices of the k nearest neighbours among the construction
                instances for each query instance.
            D : array shape=(n_instances, k, )
                Distances to the k nearest neighbours among the construction
                instances for each query instance.
            """
            pass

        def __len__(self):
            # Number of construction instances, as recorded by `__init__`.
            return self._len
class BallTree(NNSearch):
    """
    Nearest neighbour search with a Ball tree.

    Parameters
    ----------
    metric : str, default='manhattan'
        The metric through which distances are defined.

    leaf_size : int, default=30
        The leaf size to be used for the Ball tree.

    n_jobs : int, default=1
        The number of parallel jobs to run for neighbour search.
        -1 means using all processors.
    """

    def __init__(self, *, metric: str = 'manhattan', leaf_size: int = 30, n_jobs: int = 1):
        # Keyword arguments forwarded verbatim to NearestNeighbors when an
        # index is constructed.
        self.construction_params = {
            'algorithm': 'ball_tree',
            'metric': metric,
            'leaf_size': leaf_size,
            'n_jobs': n_jobs,
        }

    class Index(NNSearch.Index):
        """
        Index that wraps a NearestNeighbors estimator fitted on the
        construction instances with the parameters of the search object.
        """

        def __init__(self, search: BallTree, X):
            super().__init__(search, X)
            self.tree = NearestNeighbors(**search.construction_params).fit(X)

        def query(self, X, k: int):
            """
            Identify the k nearest neighbours for each of the instances in X.

            Returns
            -------
            tuple
                `(indices, distances)`: the two results of `kneighbors`
                in swapped order.
            """
            # kneighbors yields (distances, indices); this method returns
            # them the other way round.
            distances, indices = self.tree.kneighbors(X, n_neighbors=k)
            return indices, distances
class KDTree(NNSearch):
    """
    Nearest neighbour search with a KD-tree.

    Parameters
    ----------
    metric : str, default='manhattan'
        The metric through which distances are defined.

    leaf_size : int, default=30
        The leaf size to be used for the KD-tree.

    n_jobs : int, default=1
        The number of parallel jobs to run for neighbour search.
        -1 means using all processors.
    """

    def __init__(self, *, metric: str = 'manhattan', leaf_size: int = 30, n_jobs: int = 1):
        # Keyword arguments forwarded verbatim to NearestNeighbors when an
        # index is constructed.
        self.construction_params = {
            'algorithm': 'kd_tree',
            'metric': metric,
            'leaf_size': leaf_size,
            'n_jobs': n_jobs,
        }

    class Index(NNSearch.Index):
        """
        Index that wraps a NearestNeighbors estimator fitted on the
        construction instances with the parameters of the search object.
        """

        def __init__(self, search: KDTree, X):
            super().__init__(search, X)
            self.tree = NearestNeighbors(**search.construction_params).fit(X)

        def query(self, X, k: int):
            """
            Identify the k nearest neighbours for each of the instances in X.

            Returns
            -------
            tuple
                `(indices, distances)`: the two results of `kneighbors`
                in swapped order.
            """
            # kneighbors yields (distances, indices); this method returns
            # them the other way round.
            distances, indices = self.tree.kneighbors(X, n_neighbors=k)
            return indices, distances