Source code for baselines.differentiable_triage

import copy
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import random
import shutil
import time
import torch.utils.data as data
import sys
import pickle
import logging
from tqdm import tqdm

sys.path.append("..")
from helpers.utils import *
from helpers.metrics import *
from .basemethod import BaseMethod

eps_cst = 1e-8


def weighted_cross_entropy_loss(outputs, labels, weights):
    """
    Weighted cross-entropy loss.

    outputs: network outputs (logits; softmax is applied inside F.cross_entropy)
    labels: target labels
    weights: per-example weights

    return: weighted cross-entropy loss as a scalar
    """
    outputs = weights * F.cross_entropy(outputs, labels, reduction="none")  # per-example CE
    return torch.sum(outputs) / torch.sum(weights)
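
# A minimal usage sketch (illustrative shapes only; `outputs` are raw logits and
# the per-example weights down-weight points that will be deferred):
#
#     logits = torch.randn(4, 3)                    # batch of 4, 3 classes
#     labels = torch.tensor([0, 2, 1, 1])
#     weights = torch.tensor([1.0, 0.1, 1.0, 1.0])  # second point is down-weighted
#     loss = weighted_cross_entropy_loss(logits, labels, weights)
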
class DifferentiableTriage(BaseMethod):
    def __init__(
        self,
        model_class,
        model_rejector,
        device,
        weight_low=0.00,
        strategy="human_error",
        plotting_interval=100,
    ):
        """Method from the paper 'Differentiable Learning Under Triage' adapted to this setting.

        Args:
            model_class (nn.Module): classifier network
            model_rejector (nn.Module): rejector network with two outputs (keep vs. defer)
            device: device to run the models on
            weight_low (float in [0,1], optional): weight for points that are deferred,
                so that the classifier trains less on them. Defaults to 0.
            strategy (str, optional): one of "model_first" or "human_error".
                "model_first": the rejector label is 1 only if the human is correct and the model is wrong.
                "human_error": the rejector label is 1 if the human gets it right, otherwise 0.
            plotting_interval (int, optional): how often (in batches) to log progress. Defaults to 100.
        """
        self.model_class = model_class
        self.model_rejector = model_rejector
        self.device = device
        self.weight_low = weight_low
        self.plotting_interval = plotting_interval
        self.strategy = strategy
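
    # Example construction (a sketch; `classifier_net` and `rejector_net` stand in
    # for any nn.Module instances, with the rejector producing two output logits):
    #
    #     method = DifferentiableTriage(
    #         model_class=classifier_net,
    #         model_rejector=rejector_net,
    #         device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    #         weight_low=0.1,
    #         strategy="human_error",
    #     )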
    def fit_epoch_class(self, dataloader, optimizer, verbose=True, epoch=1):
        """
        Train the classifier for a single epoch with standard cross-entropy.

        Args:
            dataloader (dataloader): training data yielding (x, y, human_predictions)
            optimizer (optimizer): optimizer for the classifier parameters
            verbose (bool, optional): whether to log the loss. Defaults to True.
            epoch (int, optional): current epoch number, used for logging. Defaults to 1.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        end = time.time()
        self.model_class.train()

        for batch, (data_x, data_y, hum_preds) in enumerate(dataloader):
            data_x = data_x.to(self.device)
            data_y = data_y.to(self.device)

            outputs = self.model_class(data_x)
            loss = F.cross_entropy(outputs, data_y)  # standard cross-entropy on all points
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            prec1 = accuracy(outputs.data, data_y, topk=(1,))[0]
            losses.update(loss.data.item(), data_x.size(0))
            top1.update(prec1.item(), data_x.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if torch.isnan(loss):
                print("Nan loss")
                logging.warning("NaN loss encountered, stopping epoch")
                break
            if verbose and batch % self.plotting_interval == 0:
                logging.info(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})".format(
                        epoch,
                        batch,
                        len(dataloader),
                        batch_time=batch_time,
                        loss=losses,
                        top1=top1,
                    )
                )
    def find_machine_samples(self, model_outputs, data_y, hum_preds):
        """
        Compute rejector labels and per-example classifier weights for a batch.

        Args:
            model_outputs (tensor): classifier outputs (logits) for the batch
            data_y (tensor): ground-truth labels
            hum_preds (tensor): human predictions

        Returns:
            tuple: (rejector_labels, soft_weights_classifier), binary deferral labels
                and classifier weights, one per point in the batch
        """
        max_class_probs, predicted_class = torch.max(model_outputs.data, 1)
        model_error = predicted_class != data_y
        hum_error = hum_preds != data_y
        rejector_labels = []
        soft_weights_classifier = []
        if self.strategy == "model_first":
            # defer (label 1) only when the model is wrong and the human is right
            for i in range(len(model_outputs)):
                if not model_error[i]:
                    rejector_labels.append(0)
                    soft_weights_classifier.append(1.0)
                elif not hum_error[i]:
                    rejector_labels.append(1)
                    soft_weights_classifier.append(self.weight_low)
                else:
                    rejector_labels.append(0)
                    soft_weights_classifier.append(1.0)
        else:
            # "human_error": defer whenever the human is right
            for i in range(len(model_outputs)):
                if not hum_error[i]:
                    rejector_labels.append(1)
                    soft_weights_classifier.append(self.weight_low)
                else:
                    rejector_labels.append(0)
                    soft_weights_classifier.append(1.0)
        rejector_labels = torch.tensor(rejector_labels, dtype=torch.long, device=self.device)
        soft_weights_classifier = torch.tensor(soft_weights_classifier).to(self.device)
        return rejector_labels, soft_weights_classifier
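
    # Toy illustration of the labeling rules above. With strategy="human_error",
    # a point is marked for deferral (label 1) whenever the human prediction
    # matches the label, independently of the model:
    #
    #     data_y    = [1, 0, 2]
    #     hum_preds = [1, 1, 2]  -> rejector_labels          = [1, 0, 1]
    #                            -> soft classifier weights  = [weight_low, 1.0, weight_low]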
    def fit_epoch_class_triage(self, dataloader, optimizer, verbose=True, epoch=1):
        """
        Train the classifier for one epoch using the triage weights
        produced by find_machine_samples.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        end = time.time()
        self.model_class.train()

        for batch, (data_x, data_y, hum_preds) in enumerate(dataloader):
            data_x = data_x.to(self.device)
            data_y = data_y.to(self.device)
            hum_preds = hum_preds.to(self.device)

            outputs = self.model_class(data_x)
            # down-weight points that should be deferred to the human
            rejector_labels, soft_weights_classifier = self.find_machine_samples(
                outputs, data_y, hum_preds
            )
            loss = weighted_cross_entropy_loss(outputs, data_y, soft_weights_classifier)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            prec1 = accuracy(outputs.data, data_y, topk=(1,))[0]
            losses.update(loss.data.item(), data_x.size(0))
            top1.update(prec1.item(), data_x.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if torch.isnan(loss):
                print("Nan loss")
                logging.warning("NaN loss encountered, stopping epoch")
                break
            if verbose and batch % self.plotting_interval == 0:
                logging.info(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})".format(
                        epoch,
                        batch,
                        len(dataloader),
                        batch_time=batch_time,
                        loss=losses,
                        top1=top1,
                    )
                )
    def fit_epoch_rejector(self, dataloader, optimizer, verbose=True, epoch=1):
        """
        Train the rejector for one epoch on the labels produced by find_machine_samples.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        end = time.time()
        self.model_rejector.train()

        for batch, (data_x, data_y, hum_preds) in enumerate(dataloader):
            data_x = data_x.to(self.device)
            data_y = data_y.to(self.device)
            hum_preds = hum_preds.to(self.device)

            outputs_class = self.model_class(data_x)
            rejector_labels, soft_weights_classifier = self.find_machine_samples(
                outputs_class, data_y, hum_preds
            )
            outputs = self.model_rejector(data_x)
            loss = F.cross_entropy(outputs, rejector_labels)  # binary deferral targets
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            prec1 = accuracy(outputs.data, rejector_labels, topk=(1,))[0]
            losses.update(loss.data.item(), data_x.size(0))
            top1.update(prec1.item(), data_x.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if torch.isnan(loss):
                print("Nan loss")
                logging.warning("NaN loss encountered, stopping epoch")
                break
            if verbose and batch % self.plotting_interval == 0:
                logging.info(
                    "Epoch: [{0}][{1}/{2}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})".format(
                        epoch,
                        batch,
                        len(dataloader),
                        batch_time=batch_time,
                        loss=losses,
                        top1=top1,
                    )
                )
    def fit(
        self,
        dataloader_train,
        dataloader_val,
        dataloader_test,
        epochs,
        optimizer,
        lr,
        verbose=True,
        test_interval=5,
        scheduler=None,
    ):
        """Train the classifier on triage-weighted data, then the rejector; return test deferral metrics."""
        optimizer_class = optimizer(self.model_class.parameters(), lr=lr)
        optimizer_rejector = optimizer(self.model_rejector.parameters(), lr=lr)
        if scheduler is not None:
            scheduler_class = scheduler(optimizer_class, len(dataloader_train) * epochs)
            scheduler_rejector = scheduler(optimizer_rejector, len(dataloader_train) * epochs)
        self.model_class.train()
        self.model_rejector.train()

        logging.info("Re-training classifier on triage-weighted data")
        for epoch in tqdm(range(int(epochs))):
            self.fit_epoch_class_triage(
                dataloader_train, optimizer_class, verbose=verbose, epoch=epoch
            )
            if verbose and epoch % test_interval == 0:
                logging.info(compute_classification_metrics(self.test(dataloader_val)))
            if scheduler is not None:
                scheduler_class.step()

        # now fit the rejector on all data
        logging.info("Fitting rejector on all data")
        for epoch in tqdm(range(int(epochs))):
            self.fit_epoch_rejector(
                dataloader_train, optimizer_rejector, verbose=verbose, epoch=epoch
            )
            if verbose and epoch % test_interval == 0:
                logging.info(compute_deferral_metrics(self.test(dataloader_val)))
            if scheduler is not None:
                scheduler_rejector.step()

        return compute_deferral_metrics(self.test(dataloader_test))
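
    # Example call (a sketch): `optimizer` is an optimizer constructor and
    # `scheduler`, when given, is called as scheduler(opt, total_steps), e.g.
    #
    #     metrics = method.fit(
    #         train_loader, val_loader, test_loader,
    #         epochs=30,
    #         optimizer=torch.optim.AdamW,
    #         lr=1e-3,
    #         scheduler=lambda opt, t: torch.optim.lr_scheduler.CosineAnnealingLR(opt, t),
    #     )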
    def fit_hyperparam(
        self,
        dataloader_train,
        dataloader_val,
        dataloader_test,
        epochs,
        optimizer,
        lr,
        verbose=True,
        test_interval=5,
        scheduler=None,
    ):
        """Grid-search weight_low on the validation set, then refit with the best value."""
        weight_low_grid = [0, 0.1, 1]
        best_weight = 0
        best_acc = 0
        model_rejector_dict = copy.deepcopy(self.model_rejector.state_dict())
        model_class_dict = copy.deepcopy(self.model_class.state_dict())
        for weight in tqdm(weight_low_grid):
            self.weight_low = weight
            # restore the initial weights before each candidate run
            self.model_rejector.load_state_dict(model_rejector_dict)
            self.model_class.load_state_dict(model_class_dict)
            self.fit(
                dataloader_train,
                dataloader_val,
                dataloader_test,
                epochs,
                optimizer=optimizer,
                lr=lr,
                verbose=verbose,
                test_interval=test_interval,
                scheduler=scheduler,
            )
            val_acc = compute_deferral_metrics(self.test(dataloader_val))["system_acc"]
            logging.info(f"weight low: {weight}, validation system accuracy: {val_acc}")
            if val_acc > best_acc:
                best_acc = val_acc
                best_weight = weight
        # refit from the initial weights using the best weight_low
        self.weight_low = best_weight
        self.model_rejector.load_state_dict(model_rejector_dict)
        self.model_class.load_state_dict(model_class_dict)
        self.fit(
            dataloader_train,
            dataloader_val,
            dataloader_test,
            epochs,
            optimizer=optimizer,
            lr=lr,
            verbose=verbose,
            test_interval=test_interval,
            scheduler=scheduler,
        )
        test_metrics = compute_deferral_metrics(self.test(dataloader_test))
        return test_metrics
    def test(self, dataloader):
        """Evaluate classifier and rejector; return predictions and deferral decisions."""
        defers_all = []
        truths_all = []
        hum_preds_all = []
        predictions_all = []  # classifier predictions only
        rej_score_all = []  # rejector probability of deferring
        class_probs_all = []  # classifier class probabilities
        self.model_rejector.eval()
        self.model_class.eval()
        with torch.no_grad():
            for batch, (data_x, data_y, hum_preds) in enumerate(dataloader):
                data_x = data_x.to(self.device)
                data_y = data_y.to(self.device)
                hum_preds = hum_preds.to(self.device)

                outputs_class = self.model_class(data_x)
                outputs_class = F.softmax(outputs_class, dim=1)
                outputs_rejector = self.model_rejector(data_x)
                outputs_rejector = F.softmax(outputs_rejector, dim=1)
                _, predictions_rejector = torch.max(outputs_rejector.data, 1)
                max_class_probs, predicted_class = torch.max(outputs_class.data, 1)

                predictions_all.extend(predicted_class.cpu().numpy())
                truths_all.extend(data_y.cpu().numpy())
                hum_preds_all.extend(hum_preds.cpu().numpy())
                defers_all.extend(predictions_rejector.cpu().numpy())
                rej_score_all.extend(outputs_rejector[:, 1].cpu().numpy())
                class_probs_all.extend(outputs_class.cpu().numpy())
        # convert to numpy
        defers_all = np.array(defers_all)
        truths_all = np.array(truths_all)
        hum_preds_all = np.array(hum_preds_all)
        predictions_all = np.array(predictions_all)
        rej_score_all = np.array(rej_score_all)
        class_probs_all = np.array(class_probs_all)
        data = {
            "defers": defers_all,
            "labels": truths_all,
            "hum_preds": hum_preds_all,
            "preds": predictions_all,
            "rej_score": rej_score_all,
            "class_probs": class_probs_all,
        }
        return data
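

if __name__ == "__main__":
    # End-to-end sketch on synthetic data (hypothetical dimensions, a simulated human,
    # and simple MLPs; assumes helpers.utils and helpers.metrics are importable as above).
    torch.manual_seed(0)
    n, d, k = 1000, 10, 3
    x = torch.randn(n, d)
    y = torch.randint(0, k, (n,))
    # simulated human: correct 70% of the time, a random class otherwise
    hum = torch.where(torch.rand(n) < 0.7, y, torch.randint(0, k, (n,)))
    loader = data.DataLoader(data.TensorDataset(x, y, hum), batch_size=64, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    classifier_net = nn.Sequential(nn.Linear(d, 32), nn.ReLU(), nn.Linear(32, k)).to(device)
    rejector_net = nn.Sequential(nn.Linear(d, 32), nn.ReLU(), nn.Linear(32, 2)).to(device)

    method = DifferentiableTriage(classifier_net, rejector_net, device, weight_low=0.1)
    metrics = method.fit_hyperparam(
        loader, loader, loader, epochs=5, optimizer=torch.optim.Adam, lr=1e-3
    )
    print(metrics)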