# Source code for caliber.ood.mahalanobis

from typing import Optional

import numpy as np
from scipy import stats
from scipy.linalg.lapack import dtrtri

from caliber.binary_classification.base import AbstractBinaryClassificationModel


class MahalanobisBinaryClassificationModel(AbstractBinaryClassificationModel):
    """Binary classifier scoring embeddings by their Mahalanobis distance.

    A mean and covariance are estimated from the training embeddings. This is
    done either once for all of them or separately for each target value.
    For a new embedding, the squared Mahalanobis distance to the relevant
    mean is mapped through the chi-squared CDF with ``d`` degrees of freedom
    (``d`` = number of embedding features). This gives a score in
    ``[0, 1]`` that grows with the distance. :meth:`predict` thresholds that
    score.
    """

    def __init__(self, threshold: float):
        """
        Parameters
        ----------
        threshold:
            Scores strictly greater than this value are predicted as ``1``.
        """
        super().__init__()
        self.threshold = threshold
        # Arrays after an unconditional fit, dicts keyed by target value after
        # a per-target fit, ``None`` before ``fit`` is called.
        self._mean, self._chol = None, None

    def fit(self, embeddings: np.ndarray, targets: Optional[np.ndarray] = None):
        """Estimate the mean and inverse Cholesky factor of the covariance.

        Parameters
        ----------
        embeddings:
            2-D array with one sample per row and one feature per column.
        targets:
            Optional array with one entry per row of ``embeddings``. When
            given, a separate mean/factor pair is estimated for each unique
            target value and stored in dicts keyed by that value.
        """
        if targets is None:
            self._mean, self._chol = self._get_mean_and_chol(embeddings)
            return

        self._mean, self._chol = {}, {}
        for y in np.unique(targets):
            self._mean[y], self._chol[y] = self._get_mean_and_chol(
                embeddings[targets == y]
            )

    def predict_proba(self, embeddings: np.ndarray) -> np.ndarray:
        """Return the chi-squared CDF of each embedding's squared Mahalanobis distance.

        If the model was fitted per target, each embedding is first assigned
        to the target whose mean is nearest in *Euclidean* distance. It is
        then scored against that target's mean and covariance.

        Returns
        -------
        A 1-D array with one score per row of ``embeddings``.
        """
        if not isinstance(self._mean, dict):
            return self._get_probs(embeddings, self._mean, self._chol)

        modes = self._assign_modes(embeddings)
        probs = np.zeros(len(embeddings))
        for y in np.unique(modes):
            rows = np.where(modes == y)[0]
            probs[rows] = self._get_probs(
                embeddings[rows], self._mean[y], self._chol[y]
            )
        return probs

    def predict(self, embeddings: np.ndarray) -> np.ndarray:
        """Return ``1`` where the score exceeds ``threshold`` and ``0`` elsewhere."""
        return (self.predict_proba(embeddings) > self.threshold).astype(int)

    def _assign_modes(self, embeddings: np.ndarray) -> np.ndarray:
        """Return, per embedding, the target key whose mean is Euclidean-nearest."""
        keys = list(self._mean.keys())
        dists = np.stack(
            [np.linalg.norm(embeddings - self._mean[k][None], axis=1) for k in keys],
            axis=1,
        )
        return np.array(keys)[np.argmin(dists, axis=1)]

    @staticmethod
    def _get_mean_and_chol(embeddings: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Return the column mean and ``inv(L).T`` where ``cov = L @ L.T``.

        With this factor, ``(x - mean) @ factor`` has squared norm equal to
        the squared Mahalanobis distance of ``x`` to ``mean``.

        Raises
        ------
        numpy.linalg.LinAlgError
            If the covariance is not positive definite or its Cholesky factor
            cannot be inverted.
        """
        mean = np.mean(embeddings, axis=0)
        # ``np.cov`` squeezes its output: with a single feature it returns a
        # 0-d array that ``np.linalg.cholesky`` rejects. Keep it (d, d).
        cov = np.atleast_2d(np.cov(embeddings, rowvar=False))
        chol = np.linalg.cholesky(cov)
        inv_chol, info = dtrtri(chol, lower=True)
        if info != 0:
            raise np.linalg.LinAlgError(
                f"dtrtri failed to invert the Cholesky factor (info={info})"
            )
        return mean, inv_chol.T

    def _get_probs(
        self, embeddings: np.ndarray, mean: np.ndarray, chol: np.ndarray
    ) -> np.ndarray:
        """Map squared Mahalanobis distances through the chi-squared CDF.

        The degrees of freedom equal the number of embedding features.
        """
        transformed_embeddings = self._get_transformed_embeddings(
            embeddings, mean, chol
        )
        return stats.chi2(embeddings.shape[1]).cdf(transformed_embeddings)

    @staticmethod
    def _get_transformed_embeddings(
        embeddings: np.ndarray, mean: np.ndarray, chol: np.ndarray
    ) -> np.ndarray:
        """Return the squared Mahalanobis distance of each row to ``mean``.

        ``chol`` is the factor produced by ``_get_mean_and_chol``.
        """
        return np.sum(np.matmul(embeddings - mean[None], chol) ** 2, axis=1)