Source code for caliber.binary_classification.ood.da_kolmogorov_interpolant
from typing import Any, Optional
import numpy as np
from scipy import stats
from scipy.special import kolmogorov
from caliber.binary_classification.base import AbstractBinaryClassificationModel
[docs]
class DistanceAwareKolmogorovInterpolantBinaryClassificationModel(
    AbstractBinaryClassificationModel
):
    """Shrink predicted probabilities towards 0.5 according to a distance shift.

    During ``fit`` the empirical CDF of the training ``distances`` is stored.
    During ``predict_proba`` the empirical CDF of the incoming ``distances`` is
    compared with it at every incoming distance. The gap is scaled by
    ``sqrt(n)`` and passed through :func:`scipy.special.kolmogorov`, the
    survival function of the Kolmogorov distribution, which gives a weight
    ``w`` in ``[0, 1]``. The returned probability is ``w * p + 0.5 * (1 - w)``.
    A small gap leaves ``p`` almost unchanged; a large gap pulls it towards 0.5.

    Args:
        model: Optional inner model. When given, it is trained with
            ``model.fit(probs, targets)``, and its ``model.predict_proba(probs)``
            output replaces the raw ``probs`` before interpolation.
    """

    def __init__(self, model: Optional[Any] = None):
        super().__init__()
        self.model = model
        # Empirical CDF of the training distances; None until ``fit`` runs.
        self._train_ecdf = None

    def fit(self, probs: np.ndarray, distances: np.ndarray, targets: np.ndarray):
        """Fit the optional inner model and store the training-distance ECDF.

        Args:
            probs: Probabilities passed to the inner model's ``fit`` (if any).
            distances: 1-D sample of distances whose empirical CDF becomes the
                reference distribution.
            targets: Labels passed to the inner model's ``fit`` (if any).
        """
        if self.model is not None:
            self.model.fit(probs, targets)
        self._train_ecdf = stats.ecdf(distances).cdf

    def predict_proba(self, probs: np.ndarray, distances: np.ndarray) -> np.ndarray:
        """Return probabilities interpolated towards 0.5 by the distance shift.

        Args:
            probs: Predicted probabilities, presumably one per element of
                ``distances`` (TODO confirm expected shape against callers).
            distances: 1-D sample of distances for the inputs being scored.

        Returns:
            ``w * p + 0.5 * (1 - w)`` element-wise. Here ``p`` is ``probs`` (or
            the inner model's prediction) and ``w`` is the Kolmogorov weight
            computed for each distance.

        Raises:
            AttributeError: If called before :meth:`fit`.
        """
        if self._train_ecdf is None:
            # AttributeError keeps the exception type the unguarded code raised
            # (from ``None.evaluate``) but gives a readable message.
            raise AttributeError(
                f"{type(self).__name__} is not fitted yet; call `fit` first."
            )
        if self.model is not None:
            # Pass a copy so the inner model cannot mutate the caller's array.
            probs = self.model.predict_proba(np.copy(probs))
        else:
            # The arithmetic below already builds a new array; no copy needed.
            probs = np.asarray(probs)

        n_samples = len(distances)
        test_cdf = stats.ecdf(distances).cdf
        # Pointwise gap between the test and training empirical CDFs, evaluated
        # at each test distance.
        gap = np.abs(
            test_cdf.evaluate(distances) - self._train_ecdf.evaluate(distances)
        )
        # kolmogorov() is the Kolmogorov survival function. It equals 1 at a
        # zero gap and decreases towards 0 as the scaled gap grows.
        w = kolmogorov(np.sqrt(n_samples) * gap)
        return w * probs + 0.5 * (1 - w)

    def predict(self, probs: np.ndarray, distances: np.ndarray) -> np.ndarray:
        """Return 0/1 labels by thresholding :meth:`predict_proba` at 0.5.

        Args:
            probs: Predicted probabilities, as for :meth:`predict_proba`.
            distances: Distances for the inputs being scored.

        Returns:
            Integer array with 1 where the interpolated probability is >= 0.5
            and 0 elsewhere.
        """
        return (self.predict_proba(probs, distances) >= 0.5).astype(int)