# Source code for caliber.multiclass_classification.ood.da_kolmogorov_interpolant

from typing import Any, Optional

import numpy as np
from scipy import stats
from scipy.special import kolmogorov

from caliber.multiclass_classification.base import AbstractMulticlassClassificationModel


class DistanceAwareKolmogorovInterpolantMulticlassClassificationModel(
    AbstractMulticlassClassificationModel
):
    """Shrink class probabilities towards uniform as distances drift from training.

    During ``fit`` the empirical CDF (ECDF) of the training distances is
    stored. During prediction the ECDF of the incoming distances is compared,
    point by point, with the stored training ECDF.

    For every input ``i`` the scaled gap
    ``sqrt(n) * |F_new(d_i) - F_train(d_i)|`` (``n`` = number of incoming
    distances) is passed to :func:`scipy.special.kolmogorov`. That function is
    the survival function of the Kolmogorov distribution, so the result is a
    weight ``w_i`` in ``[0, 1]``.

    The returned probabilities are ``w_i * p_i + (1 - w_i) / K``. Here ``p_i``
    is the (optionally recalibrated) probability row and ``K`` is the number
    of classes. ``w_i = 1`` keeps ``p_i`` unchanged, and ``w_i -> 0`` moves it
    towards the uniform distribution.

    Parameters
    ----------
    model : optional
        Optional inner model. It is called as ``model.fit(probs, targets)``
        in ``fit`` and as ``model.predict_proba(probs)`` in ``predict_proba``.
        When ``None``, the input probabilities are used as-is.
    """

    def __init__(self, model: Optional[Any] = None):
        super().__init__()
        self.model = model
        # Training-distance ECDF (``scipy.stats.ecdf(...).cdf``); set by ``fit``.
        self._train_ecdf = None

    def fit(self, probs: np.ndarray, distances: np.ndarray, targets: np.ndarray):
        """Fit the optional inner model and store the training-distance ECDF.

        Parameters
        ----------
        probs : np.ndarray
            Predicted class probabilities, forwarded to the inner model.
        distances : np.ndarray
            One-dimensional sample of training distances used to build the
            reference ECDF.
        targets : np.ndarray
            Labels, used only by the inner model (ignored when ``model`` is
            ``None``).
        """
        if self.model is not None:
            self.model.fit(probs, targets)
        self._train_ecdf = stats.ecdf(distances).cdf

    def predict_proba(self, probs: np.ndarray, distances: np.ndarray) -> np.ndarray:
        """Return probabilities interpolated towards uniform by distance shift.

        Parameters
        ----------
        probs : np.ndarray
            Class probabilities with one row per input.
        distances : np.ndarray
            One distance per input (same length as the number of rows of
            ``probs``).

        Returns
        -------
        np.ndarray
            Array with the same shape as the (inner-model) probabilities.

        Raises
        ------
        ValueError
            If ``fit`` has not been called, or if ``distances`` and the
            probabilities disagree in length.
        """
        if self._train_ecdf is None:
            raise ValueError(
                "The model has not been fitted; call `fit` before `predict_proba`."
            )
        # Copy so an inner model cannot mutate the caller's array in place.
        probs = np.copy(probs)
        if self.model is not None:
            probs = self.model.predict_proba(probs)
        probs = np.asarray(probs)

        n = len(distances)
        # Without this check a single distance would broadcast its weight
        # over every row of ``probs`` silently.
        if n != probs.shape[0]:
            raise ValueError(
                f"`distances` has {n} entries but `probs` has {probs.shape[0]} rows."
            )

        new_ecdf = stats.ecdf(distances).cdf
        gap = np.abs(
            new_ecdf.evaluate(distances) - self._train_ecdf.evaluate(distances)
        )
        # kolmogorov(0) == 1: identical ECDFs leave probabilities untouched;
        # larger gaps drive the weight towards 0 (i.e. towards uniform).
        w = kolmogorov(np.sqrt(n) * gap)[:, None]
        return w * probs + (1 - w) / probs.shape[1]

    def predict(self, probs: np.ndarray, distances: np.ndarray) -> np.ndarray:
        """Return the index of the most probable class for each input row."""
        return np.argmax(self.predict_proba(probs, distances), axis=1)