Source code for baselines.one_v_all

import copy
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import random
import shutil
import time
import torch.utils.data as data
import sys
import logging
from tqdm import tqdm

sys.path.append("..")
from helpers.utils import *
from helpers.metrics import *
from .basemethod import BaseMethod

eps_cst = 1e-8


class OVASurrogate(BaseMethod):
    """Method of OvA surrogate from Calibrated Learning to Defer with One-vs-All Classifiers
    https://proceedings.mlr.press/v162/verma22c/verma22c.pdf"""

    def __init__(self, alpha, plotting_interval, model, device):
        self.alpha = alpha
        self.plotting_interval = plotting_interval
        self.model = model
        self.device = device

    # from https://github.com/rajevv/OvA-L2D/blob/main/losses/losses.py
    def LogisticLossOVA(self, outputs, y):
        """Binary logistic surrogate in base 2, log2(1 + exp(-y * output)), with y in {+1, -1}."""
        # Replace exactly-zero logits so their contribution to the loss vanishes,
        # as in the reference OvA-L2D implementation linked above.
        outputs[torch.where(outputs == 0.0)] = (-1 * y) * (-1 * np.inf)
        l = torch.log2(1 + torch.exp((-1 * y) * outputs + eps_cst) + eps_cst)
        return l
    def ova_loss(self, outputs, m, labels):
        """
        outputs: network outputs (n_classes class logits followed by one deferral logit)
        m: 0/1 indicator that the expert prediction equals the target (hum_preds == target)
        labels: target labels
        """
        batch_size = outputs.size()[0]
        l1 = self.LogisticLossOVA(outputs[range(batch_size), labels], 1)
        l2 = torch.sum(
            self.LogisticLossOVA(outputs[:, :-1], -1), dim=1
        ) - self.LogisticLossOVA(outputs[range(batch_size), labels], -1)
        l3 = self.LogisticLossOVA(outputs[range(batch_size), -1], -1)
        l4 = self.LogisticLossOVA(outputs[range(batch_size), -1], 1)
        l5 = m * (l4 - l3)

        l = l1 + l2 + l3 + l5
        return torch.mean(l)
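    # For reference, a reading of the terms above (notation assumed here, not from the
    # original source): with phi(z) = log2(1 + exp(-z)) and m = 1[expert correct],
    # the per-example OvA surrogate is
    #   phi(g_y(x)) + sum_{y' != y} phi(-g_{y'}(x)) + phi(-g_perp(x))
    #       + m * (phi(g_perp(x)) - phi(-g_perp(x))),
    # i.e. l1 + l2 + l3 + m * (l4 - l3), where g_perp is the extra (deferral) output.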
    def fit_epoch(self, dataloader, optimizer, verbose=True, epoch=1):
        """
        Fit the model for one epoch
        dataloader: dataloader
        optimizer: optimizer
        verbose: print loss
        epoch: epoch number
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        end = time.time()

        self.model.train()
        for batch, (data_x, data_y, hum_preds) in enumerate(dataloader):
            data_x = data_x.to(self.device)
            data_y = data_y.to(self.device)
            hum_preds = hum_preds.to(self.device)
            # 0/1 indicator of whether the expert prediction is correct
            m = ((hum_preds == data_y) * 1).to(self.device)
            # alpha-weighted variant of the indicator (not used by the OvA loss)
            m2 = self.alpha * (hum_preds == data_y) * 1 + (hum_preds != data_y) * 1
            outputs = self.model(data_x)
            loss = self.ova_loss(outputs, m, data_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            prec1 = accuracy(outputs.data, data_y, topk=(1,))[0]
            losses.update(loss.data.item(), data_x.size(0))
            top1.update(prec1.item(), data_x.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if torch.isnan(loss):
                print("Nan loss")
                logging.warning("NaN loss")
                break
            if verbose and batch % self.plotting_interval == 0:
                logging.info(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})".format(
                        epoch,
                        batch,
                        len(dataloader),
                        batch_time=batch_time,
                        loss=losses,
                        top1=top1,
                    )
                )
    def fit(
        self,
        dataloader_train,
        dataloader_val,
        dataloader_test,
        epochs,
        optimizer,
        lr,
        verbose=True,
        test_interval=5,
        scheduler=None,
    ):
        # `optimizer` and `scheduler` are passed as constructors and instantiated
        # here on the model parameters.
        optimizer = optimizer(self.model.parameters(), lr)
        if scheduler is not None:
            scheduler = scheduler(optimizer, len(dataloader_train) * epochs)
        for epoch in tqdm(range(epochs)):
            self.fit_epoch(dataloader_train, optimizer, verbose, epoch)
            if verbose and epoch % test_interval == 0:
                data_test = self.test(dataloader_val)
                logging.info(compute_deferral_metrics(data_test))
            if scheduler is not None:
                scheduler.step()
        final_test = self.test(dataloader_test)
        return compute_deferral_metrics(final_test)
    def test(self, dataloader):
        defers_all = []
        truths_all = []
        hum_preds_all = []
        predictions_all = []  # classifier only
        rej_score_all = []  # rejector probability
        class_probs_all = []  # classifier probability
        self.model.eval()
        with torch.no_grad():
            for batch, (data_x, data_y, hum_preds) in enumerate(dataloader):
                data_x = data_x.to(self.device)
                data_y = data_y.to(self.device)
                hum_preds = hum_preds.to(self.device)
                outputs = self.model(data_x)
                outputs_class = F.softmax(outputs[:, :-1], dim=1)
                outputs = F.softmax(outputs, dim=1)
                _, predicted = torch.max(outputs.data, 1)
                max_probs, predicted_class = torch.max(outputs.data[:, :-1], 1)
                predictions_all.extend(predicted_class.cpu().numpy())
                # defer whenever the last (deferral) output wins the argmax
                defers_all.extend(
                    (predicted.cpu().numpy() == len(outputs.data[0]) - 1).astype(int)
                )
                truths_all.extend(data_y.cpu().numpy())
                hum_preds_all.extend(hum_preds.cpu().numpy())
                for i in range(len(outputs.data)):
                    rej_score_all.append(
                        outputs.data[i][-1].item()
                        - outputs.data[i][predicted_class[i]].item()
                    )
                class_probs_all.extend(outputs_class.cpu().numpy())

        # convert to numpy
        defers_all = np.array(defers_all)
        truths_all = np.array(truths_all)
        hum_preds_all = np.array(hum_preds_all)
        predictions_all = np.array(predictions_all)
        rej_score_all = np.array(rej_score_all)
        class_probs_all = np.array(class_probs_all)
        data = {
            "defers": defers_all,
            "labels": truths_all,
            "hum_preds": hum_preds_all,
            "preds": predictions_all,
            "rej_score": rej_score_all,
            "class_probs": class_probs_all,
        }
        return data
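
# Minimal usage sketch. Illustrative assumptions: a linear model on 784-dim inputs and
# synthetic (x, y, expert prediction) triples; any model whose final layer has
# n_classes + 1 outputs (the extra output is the deferral head) can be used instead,
# and the dataloaders must yield (x, y, human_prediction) batches.
if __name__ == "__main__":
    import torch.optim as optim

    device = "cuda" if torch.cuda.is_available() else "cpu"
    n_classes, n_samples, dim = 10, 512, 784
    x = torch.randn(n_samples, dim)
    y = torch.randint(0, n_classes, (n_samples,))
    hum = torch.randint(0, n_classes, (n_samples,))  # synthetic expert predictions
    loader = data.DataLoader(data.TensorDataset(x, y, hum), batch_size=64, shuffle=True)

    net = nn.Linear(dim, n_classes + 1).to(device)  # last output = deferral logit
    method = OVASurrogate(alpha=1.0, plotting_interval=5, model=net, device=device)
    metrics = method.fit(
        loader, loader, loader, epochs=2, optimizer=optim.AdamW, lr=1e-3
    )
    print(metrics)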