Source code for caliber.binary_classification.pred_from_probs_mixin
import numpy as np
from caliber.binary_classification.checks_mixin import BinaryClassificationChecksMixin
class PredFromProbsBinaryClassificationMixin(BinaryClassificationChecksMixin):
def __init__(self, threshold: float, *args, **kwargs):
super().__init__(*args, **kwargs)
self.threshold = threshold
def predict(self, probs: np.ndarray) -> np.ndarray:
self._check_probs(probs)
return (probs > self.threshold).astype(int)