Source code for caliber.multiclass_classification.minimizing.linear_scaling.calibration.focal_linear_scaling

from functools import partial
from typing import Optional

from caliber.multiclass_classification.metrics import focal_loss
from caliber.multiclass_classification.minimizing.linear_scaling.calibration.base import (
    CalibrationLinearScalingMulticlassClassificationModel,
)
from caliber.multiclass_classification.minimizing.linear_scaling.linear_scaling_smooth_fit_mixin_ import (
    LinearScalingSmoothFitMulticlassClassificationMixin,
)


[docs] class FocalLinearScalingMulticlassClassificationModel( LinearScalingSmoothFitMulticlassClassificationMixin, CalibrationLinearScalingMulticlassClassificationModel, ): def __init__( self, minimize_options: Optional[dict] = None, has_intercept: bool = True, has_shared_intercept: bool = False, has_cross_slopes: bool = True, has_shared_slope: bool = False, gamma: float = 2.0, ): super().__init__( partial(focal_loss, gamma=gamma), minimize_options, has_intercept=has_intercept, has_shared_intercept=has_shared_intercept, has_shared_slope=has_shared_slope, has_cross_slopes=has_cross_slopes, )