# Source code for baselines.selective_prediction

import copy
import math
from pyexpat import model
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import random
import shutil
import time
import torch.utils.data as data
import sys
import pickle
import logging
from tqdm import tqdm

sys.path.append("..")
from helpers.utils import *
from helpers.metrics import *
from .basemethod import BaseMethod

eps_cst = 1e-8


class SelectivePrediction(BaseMethod):
    """Selective prediction baseline.

    Trains a classifier on all data and defers to the human whenever the
    classifier's confidence (its maximum softmax probability) is below a
    threshold. The threshold is selected on a validation set so as to
    maximize the accuracy of the combined classifier + human system.
    """

    def __init__(self, model_class, device, plotting_interval=100):
        """Store the classifier and training configuration.

        Args:
            model_class: classifier network; it is called on a batch of inputs
                and its output is treated as per-class logits.
            device: device every input/label batch is moved to.
            plotting_interval: number of batches between training log lines.
        """
        self.model_class = model_class
        self.device = device
        self.plotting_interval = plotting_interval
        # NOTE: the spelling "treshold" is kept on purpose; the attribute is
        # part of the object's public surface.
        self.treshold_class = 0.5

    def fit_epoch_class(self, dataloader, optimizer, verbose=True, epoch=1):
        """Train the classifier for one epoch with the cross-entropy loss.

        Args:
            dataloader: iterable yielding ``(data_x, data_y, hum_preds)``
                batches; the human predictions are not used for training.
            optimizer: instantiated optimizer over the classifier parameters.
            verbose: when true, log running statistics every
                ``plotting_interval`` batches.
            epoch: epoch index, used only in the log line.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        end = time.time()

        self.model_class.train()
        for batch, (data_x, data_y, _hum_preds) in enumerate(dataloader):
            data_x = data_x.to(self.device)
            data_y = data_y.to(self.device)

            outputs = self.model_class(data_x)
            loss = F.cross_entropy(outputs, data_y)

            # Abort *before* the update: back-propagating a NaN loss would
            # write NaNs into the weights and into the running meters.
            if torch.isnan(loss):
                logging.warning("NAN LOSS")
                break

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_size = data_x.size(0)
            prec1 = accuracy(outputs.detach(), data_y, topk=(1,))[0]
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if verbose and batch % self.plotting_interval == 0:
                logging.info(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})".format(
                        epoch,
                        batch,
                        len(dataloader),
                        batch_time=batch_time,
                        loss=losses,
                        top1=top1,
                    )
                )

    def set_optimal_threshold(self, dataloader):
        """Set the threshold that maximizes system accuracy on a validation set.

        Candidate thresholds are every observed confidence value plus 20
        evenly spaced points in ``[0, 1]``. For each candidate, samples whose
        confidence is strictly below it are answered by the human and the rest
        by the classifier. The first candidate reaching the best accuracy is
        stored in ``self.treshold_class``.

        Args:
            dataloader: dataloader of the validation set.
        """
        data_preds = self.test(dataloader)
        max_probs = data_preds["max_probs"]
        preds = data_preds["preds"]
        hum_preds = data_preds["hum_preds"]
        labels = data_preds["labels"]

        if labels.size == 0:
            # Nothing to optimize over; keep the current threshold.
            logging.warning("Empty validation set, threshold left unchanged")
            return

        treshold_grid = np.append(max_probs, np.linspace(0, 1, 20))
        best_treshold = 0
        best_acc = 0
        # optimize for system accuracy
        for treshold in treshold_grid:
            defers = max_probs < treshold
            system_preds = np.where(defers, hum_preds, preds)
            # Fraction of correct system predictions (plain accuracy).
            acc = np.mean(system_preds == labels)
            if acc > best_acc:
                best_acc = acc
                best_treshold = treshold
        logging.info(f"Best treshold {best_treshold} with accuracy {best_acc}")
        self.treshold_class = float(best_treshold)

    def fit(
        self,
        dataloader_train,
        dataloader_val,
        dataloader_test,
        epochs,
        optimizer,
        lr,
        verbose=True,
        test_interval=5,
        scheduler=None,
    ):
        """Train the classifier, tune the deferral threshold and evaluate.

        Args:
            dataloader_train: training batches.
            dataloader_val: validation batches, used for periodic logging and
                for choosing the confidence threshold.
            dataloader_test: test batches used for the returned metrics.
            epochs: number of training epochs.
            optimizer: optimizer factory, called as
                ``optimizer(parameters, lr=lr)``.
            lr: learning rate passed to the optimizer factory.
            verbose: when true, log training/validation statistics.
            test_interval: log validation metrics every this many epochs.
            scheduler: optional scheduler factory, called as
                ``scheduler(optimizer_instance, len(dataloader_train) * epochs)``.

        Returns:
            The result of ``compute_deferral_metrics`` on the test set.
        """
        optimizer_class = optimizer(self.model_class.parameters(), lr=lr)
        if scheduler is not None:
            # The scheduler must wrap the optimizer *instance*, not the
            # factory that was passed in.
            scheduler = scheduler(optimizer_class, len(dataloader_train) * epochs)
            # NOTE(review): the horizon above is counted in batches while
            # step() below is called once per epoch -- confirm this is intended.

        self.model_class.train()
        for epoch in tqdm(range(epochs)):
            self.fit_epoch_class(
                dataloader_train, optimizer_class, verbose=verbose, epoch=epoch
            )
            if verbose and epoch % test_interval == 0:
                logging.info(compute_classification_metrics(self.test(dataloader_val)))
            if scheduler is not None:
                scheduler.step()

        self.set_optimal_threshold(dataloader_val)
        return compute_deferral_metrics(self.test(dataloader_test))

    def test(self, dataloader):
        """Run the classifier over a dataloader and collect predictions.

        Args:
            dataloader: iterable yielding ``(data_x, data_y, hum_preds)``.

        Returns:
            dict of numpy arrays with keys ``defers`` (1 where the confidence
            is below ``self.treshold_class``), ``labels``, ``max_probs``,
            ``hum_preds``, ``preds`` (classifier argmax), ``rej_score``
            (``1 - max_probs``) and ``class_probs`` (softmax outputs).
        """
        defers_all = []
        max_probs = []
        truths_all = []
        hum_preds_all = []
        predictions_all = []  # classifier only
        rej_score_all = []  # rejector score: 1 - classifier confidence
        class_probs_all = []  # classifier probability
        threshold = float(self.treshold_class)

        self.model_class.eval()
        with torch.no_grad():
            for data_x, data_y, hum_preds in dataloader:
                data_x = data_x.to(self.device)

                logits = self.model_class(data_x)
                outputs_class = F.softmax(logits, dim=1)
                max_class_probs, predicted_class = torch.max(outputs_class, 1)

                # Whole-batch comparison instead of a per-sample Python loop.
                defers = (max_class_probs < threshold).long()
                confidences = max_class_probs.cpu().numpy()

                predictions_all.extend(predicted_class.cpu().numpy())
                truths_all.extend(data_y.cpu().numpy())
                hum_preds_all.extend(hum_preds.cpu().numpy())
                class_probs_all.extend(outputs_class.cpu().numpy())
                max_probs.extend(confidences)
                # Subtract in float64 so scores match a Python-float computation.
                rej_score_all.extend(1.0 - confidences.astype(np.float64))
                defers_all.extend(defers.cpu().numpy())

        # convert to numpy
        results = {
            "defers": np.array(defers_all),
            "labels": np.array(truths_all),
            "max_probs": np.array(max_probs),
            "hum_preds": np.array(hum_preds_all),
            "preds": np.array(predictions_all),
            "rej_score": np.array(rej_score_all),
            "class_probs": np.array(class_probs_all),
        }
        return results