Source code for caliber.multiclass_classification.ood.da_histogram_binning

import numpy as np
from scipy.stats import expon

from caliber.multiclass_classification.base import AbstractMulticlassClassificationModel


[docs] class DistanceAwareHistogramBinningMulticlassClassificationModel( AbstractMulticlassClassificationModel ): def __init__( self, n_prob_bins: int = 10, n_dist_bins: int = 10, conf_distance: float = 0.95, min_prob_bin: float = 0.01, ): super().__init__() self.n_prob_bins = n_prob_bins self.n_dist_bins = n_dist_bins self.conf_distance = conf_distance self._min_prob_bin = min_prob_bin self._quantile_distance = None self._prob_bin_edges = None self._dist_bin_edges = None self._ood_threshold = None self._max_distance = None self._n_classes = None
[docs] def fit(self, probs: np.ndarray, distances: np.ndarray, targets: np.ndarray): self._quantile_distance = np.quantile(distances, self.conf_distance) self._max_distance = np.max(distances) self._ood_threshold = expon(0, self._max_distance) self._n_classes = probs.shape[1] self._prob_bin_edges = self._get_prob_bin_edges() self._dist_bin_edges = self._get_dist_bin_edges() prob_bin_indices = np.digitize(probs, self._prob_bin_edges) dist_bin_indices = np.digitize(distances, self._dist_bin_edges) top_class_indices = np.argmax(probs, axis=1) self._params = np.empty( (self.n_prob_bins + 1, self.n_dist_bins + 1, self._n_classes) ) for i in range(1, self.n_prob_bins + 2): for j in range(1, self.n_dist_bins + 2): for c in range(self._n_classes): mask = self._get_mask( i, j, c, prob_bin_indices, dist_bin_indices, top_class_indices ) self._fit_bin(i, j, c, mask, targets)
[docs] def predict_proba(self, probs: np.ndarray, distances: np.ndarray) -> np.ndarray: if self._prob_bin_edges is None: raise ValueError("Run `fit` first.") if probs.shape[1] != self._n_classes: raise ValueError( "The number of classes when fitting and predicting must be the same." ) prob_bin_indices = np.digitize(probs, self._prob_bin_edges) dist_bin_indices = np.digitize(distances, self._dist_bin_edges) top_class_indices = np.argmax(probs, axis=1) probs = np.copy(probs) for i in range(1, self.n_prob_bins + 2): for j in range(1, self.n_dist_bins + 2): for c in range(0, self._n_classes): mask = self._get_mask( i, j, c, prob_bin_indices, dist_bin_indices, top_class_indices ) if not np.isnan(self._params[i - 1, j - 1, c]): probs[mask, c] = self._params[i - 1, j - 1, c] else: cdf = self._get_cdf(distances[mask]) probs[mask, c] = (1 - cdf) * probs[ mask, c ] + cdf / self._n_classes return probs
[docs] def predict(self, probs: np.ndarray, distances: np.ndarray) -> np.ndarray: return np.argmax(self.predict_proba(probs, distances), axis=1)
def _get_prob_bin_edges(self) -> np.ndarray: return np.linspace(0, 1, self.n_prob_bins + 1) def _get_dist_bin_edges(self) -> np.ndarray: return np.linspace(0, self._max_distance, self.n_dist_bins + 1) def _fit_bin(self, i: int, j: int, c: int, mask: np.ndarray, targets: np.ndarray): prob_bin = np.mean(mask) if prob_bin >= self._min_prob_bin: cdf = self._ood_threshold.cdf( self._dist_bin_edges[j - 1] - self._quantile_distance ) self._params[i - 1, j - 1, c] = (1 - cdf) * np.mean( targets[mask] == c ) + 0.5 * cdf else: self._params[i - 1, j - 1, c] = np.nan @staticmethod def _get_mask( prob_bin_idx: int, dist_bin_idx: int, class_idx: int, prob_bin_indices: np.ndarray, dist_bin_indices: np.ndarray, top_class_indices: np.ndarray, ) -> np.ndarray: return ( (prob_bin_indices[:, class_idx] == prob_bin_idx) * (top_class_indices == class_idx) * (dist_bin_indices == dist_bin_idx) ) def _get_cdf(self, d: float) -> np.ndarray: d -= self._quantile_distance return self._ood_threshold.cdf(d) * (d > 0)