Source code for datasetsdefer.chestxray

from .basedataset import BaseDataset
import torch
import numpy as np
import os
import random
import shutil
import sys
import logging
import pandas as pd

sys.path.append("../")
import torchvision.transforms as transforms
from datasetsdefer.generic_dataset import GenericImageExpertDataset
import requests
import urllib.request
import tarfile


class ChestXrayDataset(BaseDataset):
    """Chest X-ray dataset from NIH with multiple radiologist annotations per point from Google Research."""

    def __init__(
        self,
        non_deferral_dataset,
        use_data_aug,
        data_dir,
        label_chosen,
        test_split=0.2,
        val_split=0.1,
        batch_size=1000,
        transforms=None,
    ):
        """
        See https://nihcc.app.box.com/v/ChestXray-NIHCC for the NIH data and the
        Google Research expert labels fetched in generate_data().

        non_deferral_dataset (bool): if True, the dataset is the non-deferral dataset,
            i.e. the full NIH dataset without the val-test portion of the human-labeled
            data; otherwise it is the deferral dataset, which is only about 4k points in total
        data_dir: where to save files for the model
        label_chosen (int in 0,1,2,3): if non_deferral_dataset is False, which label to use
            among 0, 1, 2, 3, corresponding to Fracture, Pneumothorax, Airspace Opacity and
            Nodule/Mass; if True, the labels are No Finding (or not), Pneumothorax,
            Effusion and Nodule/Mass
        use_data_aug (bool): whether to use data augmentation
        test_split: fraction of the data used for testing
        val_split: fraction of the data used for validation (taken from the training set)
        batch_size: batch size for training
        transforms: data transforms
        """
        self.non_deferral_dataset = non_deferral_dataset
        self.data_dir = data_dir
        self.use_data_aug = use_data_aug
        self.label_chosen = label_chosen
        self.test_split = test_split
        self.val_split = val_split
        self.batch_size = batch_size
        self.n_dataset = 2
        self.train_split = 1 - test_split - val_split
        self.transforms = transforms
        self.generate_data()
    def generate_data(self):
        """
        Generate the training, validation and test sets.
        """
        links = [
            "https://nihcc.box.com/shared/static/vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz",
            "https://nihcc.box.com/shared/static/i28rlmbvmfjbl8p2n3ril0pptcmcu9d1.gz",
            "https://nihcc.box.com/shared/static/f1t00wrtdk94satdfb9olcolqx20z2jp.gz",
            "https://nihcc.box.com/shared/static/0aowwzs5lhjrceb3qp67ahp0rd1l1etg.gz",
            "https://nihcc.box.com/shared/static/v5e3goj22zr6h8tzualxfsqlqaygfbsn.gz",
            "https://nihcc.box.com/shared/static/asi7ikud9jwnkrnkj99jnpfkjdes7l6l.gz",
            "https://nihcc.box.com/shared/static/jn1b4mw4n6lnh74ovmcjb8y48h8xj07n.gz",
            "https://nihcc.box.com/shared/static/tvpxmn7qyrgl0w8wfh9kqfjskv6nmm1j.gz",
            "https://nihcc.box.com/shared/static/upyy3ml7qdumlgk2rfcvlb9k6gvqq2pj.gz",
            "https://nihcc.box.com/shared/static/l6nilvfa9cg3s28tqv1qc1olm3gnz54p.gz",
            "https://nihcc.box.com/shared/static/hhq8fkdgvcari67vfhs7ppg2w6ni4jze.gz",
            "https://nihcc.box.com/shared/static/ioqwiy20ihqwyr8pf4c24eazhh281pbu.gz",
        ]
        max_links = 12  # the NIH release is split into 12 archives
        links = links[:max_links]

        if not os.path.exists(self.data_dir + "/images_nih"):
            logging.info("Downloading NIH dataset")
            for idx, link in enumerate(links):
                fn = self.data_dir + "/images_%02d.tar.gz" % (idx + 1)
                if not os.path.exists(fn):
                    logging.info("downloading " + fn + "...")
                    urllib.request.urlretrieve(link, fn)  # download the archive
            logging.info("Download complete. Please check the checksums")
            # make the image directory
            if not os.path.exists(self.data_dir + "/images_nih"):
                os.makedirs(self.data_dir + "/images_nih")
            # extract the archives, then delete them
            for idx in range(max_links):
                fn = self.data_dir + "/images_%02d.tar.gz" % (idx + 1)
                logging.info("Extracting " + fn + "...")
                file = tarfile.open(fn)
                file.extractall(self.data_dir + "/images_nih")
                file.close()
                os.remove(fn)
            logging.info("Done")
        else:
            # double check that all files are there and extracted:
            # count the files in the directory and re-download if any are missing
            num_files = len(
                [
                    name
                    for name in os.listdir(self.data_dir + "/images_nih")
                    if os.path.isfile(
                        os.path.join(self.data_dir + "/images_nih", name)
                    )
                ]
            )
            if num_files != 112120:  # the full NIH dataset contains 112120 images
                logging.info("Files missing. Re-downloading...")
                shutil.rmtree(self.data_dir + "/images_nih")
                for idx, link in enumerate(links):
                    # check if the archive already exists
                    fn = self.data_dir + "/images_%02d.tar.gz" % (idx + 1)
                    if not os.path.exists(fn):
                        logging.info("downloading " + fn + "...")
                        urllib.request.urlretrieve(link, fn)
                logging.info("Download complete. Please check the checksums")
                # make the image directory
                if not os.path.exists(self.data_dir + "/images_nih"):
                    os.makedirs(self.data_dir + "/images_nih")
                # extract the archives, then delete them
                for idx in range(max_links):
                    fn = self.data_dir + "/images_%02d.tar.gz" % (idx + 1)
                    logging.info("Extracting " + fn + "...")
                    file = tarfile.open(fn)
                    file.extractall(self.data_dir + "/images_nih")
                    file.close()
                    os.remove(fn)
                logging.info("Done")

        # DOWNLOAD CSV DATA FOR LABELS
        if (
            not os.path.exists(
                self.data_dir
                + "/four_findings_expert_labels_individual_readers.csv"
            )
            or not os.path.exists(
                self.data_dir + "/four_findings_expert_labels_test_labels.csv"
            )
            or not os.path.exists(
                self.data_dir
                + "/four_findings_expert_labels_validation_labels.csv"
            )
            or not os.path.exists(self.data_dir + "/Data_Entry_2017_v2020.csv")
        ):
            logging.info("Downloading readers NIH data")
            r = requests.get(
                "https://storage.googleapis.com/gcs-public-data--healthcare-nih-chest-xray-labels/four_findings_expert_labels/individual_readers.csv",
                allow_redirects=True,
            )
            with open(
                self.data_dir
                + "/four_findings_expert_labels_individual_readers.csv",
                "wb",
            ) as f:
                f.write(r.content)
            r = requests.get(
                "https://storage.googleapis.com/gcs-public-data--healthcare-nih-chest-xray-labels/four_findings_expert_labels/test_labels.csv",
                allow_redirects=True,
            )
            with open(
                self.data_dir + "/four_findings_expert_labels_test_labels.csv", "wb"
            ) as f:
                f.write(r.content)
            r = requests.get(
                "https://storage.googleapis.com/gcs-public-data--healthcare-nih-chest-xray-labels/four_findings_expert_labels/validation_labels.csv",
                allow_redirects=True,
            )
            with open(
                self.data_dir
                + "/four_findings_expert_labels_validation_labels.csv",
                "wb",
            ) as f:
                f.write(r.content)
            logging.info("Finished downloading readers NIH data")
            r = requests.get(
                "https://dl2.boxcloud.com/d/1/b1!54XsklVtK4wWu3E-qHtsc9JEcNkALHT_hWAUX_n2jJE1XNe5w7dXSoPbm8e1OqaA8xcmquQ2M-qbMKV_StRSatX2KMHE__vr4Z16j1mcnDlFWCXF71jpbZum1lRwn_i5iom8wJ7J0bx7Px2MJgefbb8QKUxvuMGEmnA3-e69TegyuT8wLqB0YPx19Bp8Iue4TKEJ457zFnPjtfC2p5le1yQjfIoxzKXi2oSFxZcSTAv8se3Eynm6ssbAKhs7CC9NbJruD1wuJmQYjcy3YQvEAIgdTTaIDItLiX-GWVIqgZkhPwbnnE0XGT2dm9TlS0WKq7saLydgx4ji88SmVTtr4t82V2h1AnL-6KXl2LPam8DWjSIeja-ehu5vSPf2Uyn4ShIwBHrJM7yFZWa9VrePrd7ANGMVKU879rD01gBrXTvoPdDSKe9KRHfPbCsBCkW0B2NN5T_GvKQAtN4-OFnmGF2QQDFLq17X32XKbPyfXHPAU3eICueFNOPo9MPOESoJ48gpSAIkIjB5VpCYDSK_2Bpm4U3EVfFGCx2FOo48B6_jTo1Xw2c04_RWVEDHfm30IBFMn8Qd5vFea1rFLCXQkKUCVipyLO4z2ezkIk8TaeL0u_6UnaN8bfsIhjyhxM1JwfJT0Z150njCPHHd3CmJe-Cg0Qq-TI9-P1yPFAfEBmkYVyeLzQ9H6HjdmAvMPUvXyXck-_EOBCSV9eIEkH_ZOhN1DHdd3kB-feB-2p_d65cXI1o-f-C5Ep15-AW8mnLn9UU6rG_EgsEyMev0kC75Zo0hO2rOXG_fbZ10PHt4fplmV02pYGDfvGvXtPOHF3nG_qqLF9hNgf1D6IqekWTxeS3SGEL-M3OEl6rhI_1TFg0OQLOy9UtibVPerHYmAwTyTJWFxbjp0pHXGdLZixpLWvZZ0H56aYlILyDP6PJcWMtvdXTRvBE3GdMRv0A1wVKK5tVcYeIEbAgR668AN2FIlbEDsrWjdE6DgvfiXtpaCWY8FmxRWu0UkA-GvAgVuMFhS4FFb5xzTnoUS9IBjkPB_hu8jCsC1qJXBavmWnOEVuTRMQIkhXeIdAw_OR2847VNxMpLOvpb2CIVHrRiDYXI25CO5VBR0jFk4YlX-1mfHApalBLgOaVVdQK4YCd_7J2EGij_41HipGcRSqgRUcZfecfgslc3vTP_vDAo4bASKeeHnKa-6WbP7FGrXyhur-Zv-olIQZPWX5iXJetVv4fHm0QoMLq2q1zk_ARioJSXCIaf9pM6a0rkCN5I64-DaugzU2XZ/download"
            )
            with open(self.data_dir + "/Data_Entry_2017_v2020.csv", "wb") as f:
                f.write(r.content)
            try:
                readers_data = pd.read_csv(
                    self.data_dir
                    + "/four_findings_expert_labels_individual_readers.csv"
                )
                test_data = pd.read_csv(
                    self.data_dir + "/four_findings_expert_labels_test_labels.csv"
                )
                validation_data = pd.read_csv(
                    self.data_dir
                    + "/four_findings_expert_labels_validation_labels.csv"
                )
                all_dataset_data = pd.read_csv(
                    self.data_dir + "/Data_Entry_2017_v2020.csv"
                )
            except:
                logging.error("Failed to load readers NIH data")
                raise
        else:
            logging.info("Loading readers NIH data")
            try:
                readers_data = pd.read_csv(
                    self.data_dir
                    + "/four_findings_expert_labels_individual_readers.csv"
                )
                test_data = pd.read_csv(
                    self.data_dir + "/four_findings_expert_labels_test_labels.csv"
                )
                validation_data = pd.read_csv(
                    self.data_dir
                    + "/four_findings_expert_labels_validation_labels.csv"
                )
                all_dataset_data = pd.read_csv(
                    self.data_dir + "/Data_Entry_2017_v2020.csv"
                )
            except:
                logging.error("Failed to load readers NIH data")
                raise

        # adjudicated ground-truth labels for the validation and test images
        data_labels = {}
        for i in range(len(validation_data)):
            labels = [
                validation_data.iloc[i]["Fracture"],
                validation_data.iloc[i]["Pneumothorax"],
                validation_data.iloc[i]["Airspace opacity"],
                validation_data.iloc[i]["Nodule or mass"],
            ]
            # convert YES to 1 and everything else to 0
            labels = [1 if x == "YES" else 0 for x in labels]
            data_labels[validation_data.iloc[i]["Image Index"]] = labels
        for i in range(len(test_data)):
            labels = [
                test_data.iloc[i]["Fracture"],
                test_data.iloc[i]["Pneumothorax"],
                test_data.iloc[i]["Airspace opacity"],
                test_data.iloc[i]["Nodule or mass"],
            ]
            # convert YES to 1 and everything else to 0
            labels = [1 if x == "YES" else 0 for x in labels]
            data_labels[test_data.iloc[i]["Image Index"]] = labels

        # individual radiologist (expert) labels, keyed by image
        data_human_labels = {}
        for i in range(len(readers_data)):
            labels = [
                readers_data.iloc[i]["Fracture"],
                readers_data.iloc[i]["Pneumothorax"],
                readers_data.iloc[i]["Airspace opacity"],
                readers_data.iloc[i]["Nodule/mass"],
            ]
            # convert YES to 1 and everything else to 0
            labels = [1 if x == "YES" else 0 for x in labels]
            if readers_data.iloc[i]["Image ID"] in data_human_labels:
                data_human_labels[readers_data.iloc[i]["Image ID"]].append(labels)
            else:
                data_human_labels[readers_data.iloc[i]["Image ID"]] = [labels]
        # each key in data_human_labels maps to a list of reader labels;
        # sample a single reader per image
        data_human_labels = {
            k: random.sample(v, 1)[0] for k, v in data_human_labels.items()
        }

        labels_categories = [
            "Fracture",
            "Pneumothorax",
            "Airspace opacity",
            "Nodule/mass",
        ]
        self.label_to_idx = {
            labels_categories[i]: i for i in range(len(labels_categories))
        }

        image_to_patient_id = {}
        for i in range(len(readers_data)):
            image_to_patient_id[readers_data.iloc[i]["Image ID"]] = readers_data.iloc[
                i
            ]["Patient ID"]
        patient_ids = list(set(image_to_patient_id.values()))

        data_all_nih_label = {}
        # the original dataset has the following labels: ['Atelectasis', 'Cardiomegaly',
        # 'Consolidation', 'Edema', 'Effusion', 'Emphysema', 'Fibrosis', 'Hernia',
        # 'Infiltration', 'Mass', 'No Finding', 'Nodule', 'Pleural_Thickening',
        # 'Pneumonia', 'Pneumothorax']
        for i in range(len(all_dataset_data)):
            if not all_dataset_data["Patient ID"][i] in patient_ids:
                labels = [0, 0, 0, 0]
                if "Pneumothorax" in all_dataset_data["Finding Labels"][i]:
                    labels[1] = 1
                if "Effusion" in all_dataset_data["Finding Labels"][i]:
                    labels[2] = 1
                if (
                    "Mass" in all_dataset_data["Finding Labels"][i]
                    or "Nodule" in all_dataset_data["Finding Labels"][i]
                ):
                    labels[3] = 1
                if "No Finding" in all_dataset_data["Finding Labels"][i]:
                    labels[0] = 0
                else:
                    labels[0] = 1
                data_all_nih_label[all_dataset_data["Image Index"][i]] = labels

        # build the transforms; which images are used below depends on non_deferral_dataset
        transform_train = transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        transform_test = transforms.Compose(
            [
                transforms.Resize(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        if self.non_deferral_dataset:
            # non-deferral dataset: use the full NIH data (excluding the expert-labeled
            # patients) with a placeholder expert equal to the label
            data_y = []
            data_expert = []
            image_paths = []
            for key, value in list(data_all_nih_label.items()):
                image_path = self.data_dir + "/images_nih/" + key
                # only keep images that were actually extracted
                if os.path.isfile(image_path):
                    data_y.append(value[self.label_chosen])
                    image_paths.append(self.data_dir + "/images_nih/" + key)
                    data_expert.append(value[self.label_chosen])  # placeholder expert

            data_y = np.array(data_y)
            data_expert = np.array(data_expert)
            image_paths = np.array(image_paths)

            random_seed = random.randrange(10000)
            test_size = int(self.test_split * len(image_paths))
            val_size = int(self.val_split * len(image_paths))
            train_size = len(image_paths) - test_size - val_size

            train_x, val_x, test_x = torch.utils.data.random_split(
                image_paths,
                [train_size, val_size, test_size],
                generator=torch.Generator().manual_seed(random_seed),
            )
            train_y, val_y, test_y = torch.utils.data.random_split(
                data_y,
                [train_size, val_size, test_size],
                generator=torch.Generator().manual_seed(random_seed),
            )
            train_h, val_h, test_h = torch.utils.data.random_split(
                data_expert,
                [train_size, val_size, test_size],
                generator=torch.Generator().manual_seed(random_seed),
            )

            data_train = GenericImageExpertDataset(
                train_x.dataset[train_x.indices],
                train_y.dataset[train_y.indices],
                train_h.dataset[train_h.indices],
                transform_train,
                to_open=True,
            )
            data_val = GenericImageExpertDataset(
                val_x.dataset[val_x.indices],
                val_y.dataset[val_y.indices],
                val_h.dataset[val_h.indices],
                transform_test,
                to_open=True,
            )
            data_test = GenericImageExpertDataset(
                test_x.dataset[test_x.indices],
                test_y.dataset[test_y.indices],
                test_h.dataset[test_h.indices],
                transform_test,
                to_open=True,
            )

            self.data_train_loader = torch.utils.data.DataLoader(
                data_train,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=0,
                pin_memory=True,
            )
            self.data_val_loader = torch.utils.data.DataLoader(
                data_val,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=0,
                pin_memory=True,
            )
            self.data_test_loader = torch.utils.data.DataLoader(
                data_test,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=0,
                pin_memory=True,
            )
        else:
            # deferral dataset: split the expert-labeled patients into train, test and val
            random.shuffle(patient_ids)
            # split patients according to train_split / test_split / val_split
            train_patient_ids = patient_ids[: int(len(patient_ids) * self.train_split)]
            test_patient_ids = patient_ids[
                int(len(patient_ids) * self.train_split) : int(
                    len(patient_ids) * (self.train_split + self.test_split)
                )
            ]
            val_patient_ids = patient_ids[
                int(len(patient_ids) * (self.train_split + self.test_split)) :
            ]
            # go from patient ids to image ids
            train_image_ids = np.array(
                [k for k, v in image_to_patient_id.items() if v in train_patient_ids]
            )
            val_image_ids = np.array(
                [k for k, v in image_to_patient_id.items() if v in val_patient_ids]
            )
            test_image_ids = np.array(
                [k for k, v in image_to_patient_id.items() if v in test_patient_ids]
            )
            # remove images that are not in the directory
            train_image_ids = np.array(
                [
                    k
                    for k in train_image_ids
                    if os.path.isfile(self.data_dir + "/images_nih/" + k)
                ]
            )
            val_image_ids = np.array(
                [
                    k
                    for k in val_image_ids
                    if os.path.isfile(self.data_dir + "/images_nih/" + k)
                ]
            )
            test_image_ids = np.array(
                [
                    k
                    for k in test_image_ids
                    if os.path.isfile(self.data_dir + "/images_nih/" + k)
                ]
            )
            logging.info("Finished splitting data into train, test and validation")
            # log split sizes
            logging.info("Train size: {}".format(len(train_image_ids)))
            logging.info("Test size: {}".format(len(test_image_ids)))
            logging.info("Validation size: {}".format(len(val_image_ids)))

            train_y = np.array(
                [data_labels[k][self.label_chosen] for k in train_image_ids]
            )
            val_y = np.array(
                [data_labels[k][self.label_chosen] for k in val_image_ids]
            )
            test_y = np.array(
                [data_labels[k][self.label_chosen] for k in test_image_ids]
            )
            train_h = np.array(
                [data_human_labels[k][self.label_chosen] for k in train_image_ids]
            )
            val_h = np.array(
                [data_human_labels[k][self.label_chosen] for k in val_image_ids]
            )
            test_h = np.array(
                [data_human_labels[k][self.label_chosen] for k in test_image_ids]
            )

            train_image_ids = np.array(
                [self.data_dir + "/images_nih/" + k for k in train_image_ids]
            )
            val_image_ids = np.array(
                [self.data_dir + "/images_nih/" + k for k in val_image_ids]
            )
            test_image_ids = np.array(
                [self.data_dir + "/images_nih/" + k for k in test_image_ids]
            )

            data_train = GenericImageExpertDataset(
                train_image_ids, train_y, train_h, transform_train, to_open=True
            )
            data_val = GenericImageExpertDataset(
                val_image_ids, val_y, val_h, transform_test, to_open=True
            )
            data_test = GenericImageExpertDataset(
                test_image_ids, test_y, test_h, transform_test, to_open=True
            )

            self.data_train_loader = torch.utils.data.DataLoader(
                data_train,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=0,
                pin_memory=True,
            )
            self.data_val_loader = torch.utils.data.DataLoader(
                data_val,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=0,
                pin_memory=True,
            )
            self.data_test_loader = torch.utils.data.DataLoader(
                data_test,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=0,
                pin_memory=True,
            )