DETECTION OF TUBERS WITH CONVOLUTIONAL NEURAL NETWORKS

PERFORMANCE OF ResNet50 IN THE TEST SET (STUDY REQUESTED DURING PEER-REVIEW)

Import packages and functions

In [1]:
# Import packages
%matplotlib inline
from PIL import Image
import numpy as np
import os
from skimage.color import gray2rgb
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
!pip install tensorflow
!pip install keras
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten, GaussianNoise, BatchNormalization, GlobalAveragePooling2D
from keras.layers import Conv2D, MaxPooling2D
from keras import Sequential
from keras.optimizers import Adam
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
from keras.preprocessing import image
from keras.models import Model
from keras import backend as K
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score
!pip install git+https://github.com/raghakot/keras-vis.git --upgrade
from vis.visualization import visualize_cam, visualize_saliency, overlay
from keras import activations
from matplotlib import pyplot as plt
import matplotlib.cm as cm
import zipfile
from keras.models import model_from_json
import matplotlib as mpl
Requirement already satisfied: tensorflow in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (1.12.0)
Requirement already satisfied: gast>=0.2.0 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (0.2.2)
Requirement already satisfied: keras-applications>=1.0.6 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (1.0.6)
Requirement already satisfied: numpy>=1.13.3 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (1.14.3)
Requirement already satisfied: keras-preprocessing>=1.0.5 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (1.0.5)
Requirement already satisfied: grpcio>=1.8.6 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (1.18.0)
Requirement already satisfied: six>=1.10.0 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (1.11.0)
Requirement already satisfied: tensorboard<1.13.0,>=1.12.0 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (1.12.2)
Requirement already satisfied: absl-py>=0.1.6 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (0.7.0)
Requirement already satisfied: termcolor>=1.1.0 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (1.1.0)
Requirement already satisfied: wheel>=0.26 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (0.32.3)
Requirement already satisfied: astor>=0.6.0 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (0.7.1)
Requirement already satisfied: protobuf>=3.6.1 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorflow) (3.6.1)
Requirement already satisfied: h5py in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from keras-applications>=1.0.6->tensorflow) (2.7.1)
Requirement already satisfied: werkzeug>=0.11.10 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorboard<1.13.0,>=1.12.0->tensorflow) (0.14.1)
Requirement already satisfied: markdown>=2.6.8 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from tensorboard<1.13.0,>=1.12.0->tensorflow) (3.0.1)
Requirement already satisfied: setuptools in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from protobuf>=3.6.1->tensorflow) (40.7.1)
You are using pip version 19.0.1, however version 19.3.1 is available.
You should consider upgrading via the 'python -m pip install --upgrade pip' command.
Requirement already satisfied: keras in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (2.2.4)
Requirement already satisfied: pyyaml in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from keras) (3.12)
Requirement already satisfied: scipy>=0.14 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from keras) (1.1.0)
Requirement already satisfied: h5py in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from keras) (2.7.1)
Requirement already satisfied: keras-applications>=1.0.6 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from keras) (1.0.6)
Requirement already satisfied: keras-preprocessing>=1.0.5 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from keras) (1.0.5)
Requirement already satisfied: six>=1.9.0 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from keras) (1.11.0)
Requirement already satisfied: numpy>=1.9.1 in c:\users\ivansanchezfernandez\anaconda3\lib\site-packages (from keras) (1.14.3)
You are using pip version 19.0.1, however version 19.3.1 is available.
You should consider upgrading via the 'python -m pip install --upgrade pip' command.
C:\Users\IvanSanchezFernandez\Anaconda3\lib\site-packages\h5py\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Collecting git+https://github.com/raghakot/keras-vis.git
  Cloning https://github.com/raghakot/keras-vis.git to c:\users\ivansa~1\appdata\local\temp\pip-req-build-noz36bis
error: unable to write file images/conv_vis/block4_conv3_filters.jpg
error: unable to write file images/conv_vis/block5_conv3_filters.jpg
error: unable to write file images/conv_vis/block5_conv3_filters_no_tv.jpg
error: unable to write file images/conv_vis/cover.jpg
error: unable to write file images/conv_vis/filter_67.png
error: unable to write file images/dense_vis/cover.png
error: unable to write file images/dense_vis/ouzel_vis.png
error: unable to write file images/dense_vis/random_imagenet.png
error: unable to write file images/dense_vis/random_imagenet_no_tv.png
error: unable to write file images/opt_progress.gif
error: unable to write file resources/imagenet_class_index.json
error: unable to write file tests/vis/backend/test_backend.py
error: unable to write file tests/vis/utils/test_utils.py
error: unable to write file tests/vis/visualization/test_saliency.py
error: unable to write file vis/backend/tensorflow_backend.py
error: unable to write file vis/backend/theano_backend.py
error: unable to write file vis/backprop_modifiers.py
error: unable to write file vis/callbacks.py
error: unable to create file vis/grad_modifiers.py: No space left on device
error: unable to create file vis/input_modifiers.py: No space left on device
error: unable to create file vis/losses.py: No space left on device
error: unable to create file vis/optimizer.py: No space left on device
error: unable to create file vis/regularizers.py: No space left on device
fatal: cannot create directory at 'vis/utils': No space left on device
warning: Clone succeeded, but checkout failed.
You can inspect what was checked out with 'git status'
and retry the checkout with 'git checkout -f HEAD'

Command "git clone -q https://github.com/raghakot/keras-vis.git C:\Users\IVANSA~1\AppData\Local\Temp\pip-req-build-noz36bis" failed with error code 128 in None
You are using pip version 19.0.1, however version 19.3.1 is available.
You should consider upgrading via the 'python -m pip install --upgrade pip' command.

FIRST PART: DATA INGESTION

Import the original images from the local computer

We used Magnetic Resonance Imaging (MRI) scans from 114 patients with tuberous sclerosis complex (TSC) and from 114 patients with structurally normal MRI (controls).

For each MRI, we manually selected representative axial T2 and T2 FLAIR slices with tubers (in patients with TSC) and with normal findings (in controls). These axial slices were converted to deidentified .jpg images.

We created three folders for TSC and three folders for controls: TSCtrain (566 images), TSCval (130 images), and TSCtest (210 images) and Controltrain (561 images), Controlval (118 images), and Controltest (226 images). Individual patients belonged to only one of the categories (none of the patients had images in different folders).

For the model development part, done in a cloud computing environment, we only used the train and validation subsets. We selected the model with the lowest binary cross-entropy error in the validation set as the best model. The best model (InceptionV3) was saved and its performance was evaluated on the local computer on the test set (data not seen previously by the model).

During the peer-review process, reviewers requested the performance of the TSCCNN and ResNet50 architectures on the test set.

In [2]:
# Make every figure in this notebook 16 x 10 inches by default
mpl.rcParams.update({'figure.figsize': (16, 10)})
In [3]:
# Extract both test-set archives into the current working directory
for archive_name in ("Controltest.zip", "TSCtest.zip"):
    with zipfile.ZipFile(archive_name, "r") as archive:
        archive.extractall()

Paths to the original image folders

In [4]:
# Path to the folder with the original images
pathtoimagesControltest = './Controltest/'

pathtoimagesTSCtest = './TSCtest/'

SECOND PART: IMPORT OF THE FINAL DATA

Import images and labels for the test set

In [5]:
## CONTROLS AND TSC

# Target size passed to PIL's Image.resize, i.e. (width, height); square here
image_size = (224, 224)


def load_test_images(directory, size=image_size):
    """Load every image in `directory` as a 3-channel array resized to `size`.

    Parameters
    ----------
    directory : str
        Folder containing one image per file.
    size : tuple of int
        Target (width, height) passed to PIL's `Image.resize`.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_images, height, width, 3) with the raw pixel values
        (no rescaling is done here).

    Raises
    ------
    ValueError
        If an image does not end up with exactly 3 channels (e.g. RGBA/CMYK),
        which would otherwise create a ragged array and fail confusingly later.
    """
    images = []
    # Sorted so the load order (and hence the seeded shuffle later on) does not
    # depend on the operating system's directory-listing order.
    for file_name in sorted(os.listdir(directory)):
        # The context manager closes the underlying file once pixels are copied.
        with Image.open(os.path.join(directory, file_name)) as img:
            # NOTE(review): the default resize filter changed in Pillow 7; pass
            # `resample=` explicitly if bit-exact reproduction matters.
            img_arr = np.array(img.resize(size))
        # Grayscale slices come back 2-D (height, width): replicate to 3 channels.
        if img_arr.ndim == 2:
            img_arr = gray2rgb(img_arr)
        if img_arr.ndim != 3 or img_arr.shape[2] != 3:
            raise ValueError('Unexpected image shape {} for {}'.format(img_arr.shape, file_name))
        images.append(img_arr)
    if not images:
        # Keep a 4-D shape so the concatenation below still works for an empty folder.
        return np.empty((0, size[1], size[0], 3), dtype=np.uint8)
    return np.array(images)


def make_labels(n_images, label):
    """Return an (n_images, 1) integer column vector filled with `label`."""
    return np.full((n_images, 1), label, dtype=int)


# Controls are labelled 0
Controltest_X = load_test_images(pathtoimagesControltest)
Controltest_y = make_labels(Controltest_X.shape[0], 0)

# TSC patients are labelled 1
TSCtest_X = load_test_images(pathtoimagesTSCtest)
TSCtest_y = make_labels(TSCtest_X.shape[0], 1)


## MERGE CONTROLS AND TSC

# Test merge files (controls first, then TSC; shuffled in the next cell)
test_X = np.concatenate([Controltest_X, TSCtest_X])
test_y = np.vstack((Controltest_y, TSCtest_y))
assert test_X.shape[0] == test_y.shape[0], "images and labels are misaligned"

# GPU expects values to be 32-bit floats
test_X = test_X.astype(np.float32)

# Rescale the values to be between 0 and 1 (assumes 8-bit source images)
test_X /= 255.
In [6]:
# Shuffle images and labels together so each image keeps its label; the fixed
# seed (123 is arbitrary) makes the order reproducible between runs.
SHUFFLE_SEED = 123
shuffled_test_X, shuffled_test_y = shuffle(test_X, test_y, random_state=SHUFFLE_SEED)
In [7]:
np.shape(shuffled_test_X)  # (n_images, height, width, channels)
Out[7]:
(499, 224, 224, 3)
In [8]:
# Example of an image to make sure they were converted right
fig, ax = plt.subplots()
ax.imshow(shuffled_test_X[0])
# Hide gridlines and tick marks: only the image itself is of interest.
# grid(False) replaces grid(b=None), which merely *toggled* the grid and whose
# `b` keyword was renamed to `visible` in Matplotlib 3.5 and later removed.
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
# Show the label next to the image (0 = control, 1 = TSC)
ax.set_title('Example test image (label = {})'.format(int(shuffled_test_y[0, 0])))
plt.show()
In [9]:
np.shape(shuffled_test_y)  # (n_images, 1) column of labels
Out[9]:
(499, 1)
In [10]:
shuffled_test_y[0, :]  # label of the example image shown above (0 = control, 1 = TSC)
Out[10]:
array([0])

THIRD PART: EVALUATE NEURAL NETWORK PERFORMANCE IN THE TEST SET

Load the model

In [11]:
# Load the model architecture from JSON, then its trained weights from HDF5.
# `with` guarantees the JSON file is closed even if reading it fails.
with open('ResNet50.json', 'r') as json_file:
    loaded_model_json = json_file.read()
model = model_from_json(loaded_model_json)
# load weights into new model
model.load_weights("ResNet50.h5")

Test the model with the test data

In [12]:
# Compiling is not required for model.predict; it attaches the optimizer, loss and
# accuracy metric so the loaded model could also be scored with model.evaluate.
resnet50_optimizer = Adam(lr = 0.00025)
model.compile(optimizer=resnet50_optimizer,
              loss='binary_crossentropy',
              metrics=['accuracy'])

# One score per test image (in the same shuffled order as shuffled_test_y)
testResNet50 = model.predict(shuffled_test_X, batch_size=16)
testResNet50
Out[12]:
array([[4.04824641e-05],
       [9.64355111e-01],
       [1.34498137e-07],
       [1.91993751e-02],
       [1.92810949e-05],
       [3.80822399e-04],
       [1.33629330e-08],
       [6.73725968e-03],
       [3.93042428e-04],
       [2.31301718e-04],
       [1.09248921e-08],
       [0.00000000e+00],
       [8.84703877e-06],
       [7.94118638e-09],
       [2.26143104e-08],
       [1.01196398e-04],
       [2.64528226e-02],
       [1.43634225e-07],
       [9.99906898e-01],
       [9.70076781e-06],
       [9.46363699e-09],
       [7.75704114e-03],
       [1.00000000e+00],
       [2.68115127e-05],
       [1.44215577e-04],
       [9.99959588e-01],
       [4.14162715e-12],
       [2.06526490e-10],
       [9.99999881e-01],
       [1.65087840e-05],
       [1.45655245e-01],
       [1.00000000e+00],
       [1.56712311e-03],
       [2.13648695e-02],
       [9.99962211e-01],
       [2.23436047e-07],
       [2.05710693e-14],
       [9.99610841e-01],
       [1.35714572e-05],
       [1.56904960e-07],
       [9.98852015e-01],
       [2.34237905e-06],
       [1.00000000e+00],
       [3.37540433e-02],
       [7.95376450e-02],
       [4.36060131e-03],
       [1.97926275e-02],
       [9.98673320e-01],
       [2.74353405e-11],
       [9.67591550e-05],
       [4.63243964e-07],
       [9.51221585e-01],
       [9.99991775e-01],
       [0.00000000e+00],
       [9.99894023e-01],
       [9.62791732e-04],
       [5.33263483e-06],
       [2.79609725e-04],
       [2.10892676e-06],
       [1.68454586e-04],
       [9.99501467e-01],
       [1.00000000e+00],
       [4.22167301e-04],
       [3.55677112e-05],
       [6.65327825e-05],
       [6.85180680e-07],
       [9.99999642e-01],
       [9.99945641e-01],
       [1.39172209e-06],
       [1.05721399e-03],
       [5.93247998e-04],
       [9.99989271e-01],
       [9.88096237e-01],
       [1.35035731e-03],
       [1.09545954e-05],
       [9.99804914e-01],
       [9.99997258e-01],
       [9.99850273e-01],
       [1.91062969e-14],
       [1.00000000e+00],
       [7.97810555e-01],
       [9.99999523e-01],
       [9.98119533e-01],
       [9.99942064e-01],
       [3.01489752e-04],
       [1.00000000e+00],
       [3.38092514e-06],
       [9.38011825e-01],
       [9.99866366e-01],
       [0.00000000e+00],
       [9.12125273e-08],
       [9.99816954e-01],
       [1.67517005e-08],
       [1.95319753e-05],
       [9.99952435e-01],
       [9.99950290e-01],
       [9.99979377e-01],
       [9.99992490e-01],
       [5.40369947e-04],
       [3.99783054e-11],
       [2.86333021e-02],
       [2.21684058e-06],
       [9.99984145e-01],
       [1.00000000e+00],
       [9.98567224e-01],
       [1.00000000e+00],
       [3.19666713e-01],
       [9.99999881e-01],
       [1.69066698e-04],
       [9.99992728e-01],
       [2.43302900e-04],
       [1.27740080e-07],
       [9.99979615e-01],
       [1.69139767e-05],
       [4.74393637e-05],
       [9.99993086e-01],
       [2.04184130e-02],
       [6.85191771e-05],
       [1.60174415e-04],
       [3.07130191e-04],
       [2.42566341e-04],
       [7.64821231e-01],
       [1.00000000e+00],
       [9.99988556e-01],
       [1.51595857e-04],
       [5.56958802e-02],
       [4.31551380e-05],
       [7.21960217e-02],
       [9.99996543e-01],
       [4.62486297e-01],
       [9.99977589e-01],
       [5.91151416e-03],
       [7.64214754e-01],
       [3.88993521e-06],
       [7.66507655e-05],
       [2.90765369e-04],
       [9.99991775e-01],
       [1.59517193e-04],
       [1.65164459e-27],
       [9.82807398e-01],
       [8.01126873e-07],
       [4.61785385e-05],
       [1.26973964e-06],
       [6.45093223e-09],
       [9.99999881e-01],
       [6.47667912e-04],
       [6.59189318e-05],
       [2.30633418e-06],
       [9.99960303e-01],
       [1.00000000e+00],
       [1.66472782e-05],
       [1.00000000e+00],
       [1.00000000e+00],
       [5.62212008e-05],
       [9.31209797e-05],
       [3.01089822e-05],
       [5.72998269e-22],
       [9.99999881e-01],
       [1.70966893e-01],
       [9.99999166e-01],
       [9.99947667e-01],
       [1.90808130e-13],
       [1.70645732e-02],
       [9.90076721e-01],
       [1.00000000e+00],
       [8.58354761e-05],
       [1.05125026e-03],
       [9.99995112e-01],
       [9.99271572e-01],
       [2.17188662e-03],
       [9.99971271e-01],
       [9.99995232e-01],
       [1.72268876e-04],
       [2.83232104e-04],
       [9.57520604e-01],
       [9.72910583e-01],
       [2.17218811e-04],
       [1.37371928e-04],
       [4.40445775e-03],
       [1.07509259e-05],
       [4.16055380e-04],
       [0.00000000e+00],
       [9.99561608e-01],
       [1.47870187e-05],
       [0.00000000e+00],
       [9.98784602e-01],
       [4.17083676e-04],
       [2.14510266e-07],
       [9.99909759e-01],
       [9.99917626e-01],
       [2.51408457e-03],
       [4.09871544e-04],
       [9.99890447e-01],
       [1.45933619e-02],
       [6.06916046e-06],
       [8.44406545e-01],
       [1.60973440e-25],
       [9.99996305e-01],
       [9.99988317e-01],
       [1.00000000e+00],
       [6.39466671e-05],
       [2.90723973e-10],
       [1.56784354e-05],
       [8.79158676e-02],
       [2.91677134e-05],
       [2.01183167e-18],
       [1.60475553e-04],
       [1.95669476e-04],
       [8.01126873e-07],
       [9.99999642e-01],
       [9.91696715e-01],
       [9.99788105e-01],
       [9.69437301e-01],
       [4.19664939e-05],
       [2.86333021e-02],
       [4.62294417e-03],
       [9.99998212e-01],
       [8.74309407e-08],
       [1.84196236e-09],
       [2.18433575e-04],
       [2.42846509e-05],
       [2.88475148e-05],
       [1.81203040e-06],
       [3.34074639e-06],
       [6.90106943e-04],
       [9.88882840e-01],
       [9.99349773e-01],
       [2.28220542e-05],
       [9.99998331e-01],
       [3.91802704e-03],
       [1.58487601e-05],
       [2.45597260e-03],
       [1.71472151e-02],
       [2.82231235e-06],
       [9.99943376e-01],
       [9.02124524e-01],
       [9.86783981e-01],
       [8.64363522e-28],
       [1.00000000e+00],
       [9.99907732e-01],
       [5.42552304e-03],
       [9.99983430e-01],
       [1.08106353e-04],
       [1.43920459e-10],
       [4.82254181e-05],
       [9.94006038e-01],
       [5.04288217e-03],
       [9.96676922e-01],
       [7.70291626e-06],
       [6.10072553e-01],
       [1.05866275e-06],
       [4.01324105e-05],
       [1.00000000e+00],
       [1.17642572e-04],
       [9.99996185e-01],
       [6.23968561e-28],
       [3.66602671e-06],
       [7.07563420e-04],
       [5.08713929e-05],
       [9.99998450e-01],
       [7.43270141e-16],
       [9.99634743e-01],
       [6.31848067e-08],
       [9.69967587e-05],
       [9.99999642e-01],
       [5.75678227e-28],
       [6.22712264e-08],
       [5.61116103e-05],
       [7.20351934e-01],
       [1.00000000e+00],
       [4.53309505e-04],
       [4.18438395e-09],
       [1.36343078e-04],
       [1.00000000e+00],
       [3.56623495e-04],
       [4.48890205e-05],
       [9.96679068e-01],
       [2.82413960e-02],
       [3.59575545e-16],
       [9.99528170e-01],
       [7.03334987e-01],
       [9.99985337e-01],
       [9.99958754e-01],
       [9.99910712e-01],
       [5.71482058e-04],
       [3.71144125e-07],
       [2.24545831e-03],
       [9.91875231e-01],
       [7.91134516e-06],
       [1.25644757e-08],
       [5.45969269e-05],
       [9.99983311e-01],
       [3.37794781e-01],
       [1.42589361e-05],
       [1.29109598e-03],
       [1.14822399e-18],
       [2.85292626e-04],
       [4.92930167e-06],
       [9.90916908e-01],
       [4.24527656e-03],
       [1.24115532e-03],
       [9.72329915e-01],
       [7.88826108e-01],
       [9.89049792e-01],
       [1.23315651e-04],
       [5.06865850e-04],
       [6.68113411e-04],
       [1.87135185e-03],
       [5.49206891e-21],
       [9.98354971e-01],
       [2.74350308e-03],
       [1.12942278e-01],
       [4.62879718e-04],
       [3.47326445e-10],
       [1.00000000e+00],
       [3.65633787e-05],
       [9.99840736e-01],
       [2.63274428e-12],
       [9.99979496e-01],
       [9.99150276e-01],
       [2.72178331e-05],
       [3.20226754e-05],
       [3.94183174e-02],
       [1.79349348e-01],
       [2.51993999e-25],
       [1.80276984e-03],
       [1.00000000e+00],
       [5.54274581e-02],
       [5.03087603e-02],
       [1.56255391e-11],
       [9.85920131e-01],
       [8.73764293e-05],
       [3.09841752e-08],
       [1.61057003e-02],
       [3.95034067e-06],
       [3.48682329e-06],
       [4.93961535e-02],
       [1.00000000e+00],
       [2.18521035e-03],
       [9.97713804e-01],
       [3.80649362e-05],
       [9.99996066e-01],
       [1.68994859e-01],
       [8.77717393e-04],
       [1.00644899e-03],
       [4.46913873e-05],
       [3.05271504e-04],
       [4.46592458e-06],
       [6.57207129e-05],
       [2.35794206e-09],
       [9.99967098e-01],
       [9.99914527e-01],
       [4.82658303e-04],
       [4.63958486e-07],
       [1.00000000e+00],
       [9.99961376e-01],
       [9.99638438e-01],
       [9.99747217e-01],
       [2.95535987e-03],
       [9.99530911e-01],
       [5.15024077e-08],
       [9.99889374e-01],
       [9.99603450e-01],
       [5.45184076e-01],
       [9.82501209e-01],
       [9.97623384e-01],
       [1.96987037e-02],
       [4.42107441e-03],
       [9.20440769e-04],
       [2.49005214e-04],
       [1.57722097e-05],
       [3.37544293e-03],
       [3.27324997e-05],
       [4.80695853e-05],
       [9.98897433e-01],
       [8.48214768e-05],
       [9.99983907e-01],
       [1.68942133e-05],
       [9.99999881e-01],
       [9.87843513e-01],
       [9.99989986e-01],
       [1.54991955e-01],
       [3.22988235e-05],
       [7.52202250e-05],
       [3.86031112e-03],
       [9.94161308e-01],
       [7.44783402e-11],
       [1.01932428e-05],
       [3.82144764e-23],
       [3.80464896e-11],
       [3.83904873e-04],
       [2.26714343e-04],
       [1.93421228e-03],
       [9.75408971e-01],
       [1.00000000e+00],
       [1.75708266e-17],
       [2.99904550e-05],
       [2.34444251e-05],
       [2.63116381e-04],
       [7.11865359e-05],
       [9.99871135e-01],
       [1.63089499e-04],
       [9.98622298e-01],
       [2.44862027e-03],
       [8.51557791e-01],
       [1.18297403e-05],
       [4.80955293e-07],
       [2.17035602e-04],
       [1.22802085e-05],
       [5.47172010e-01],
       [1.08072974e-01],
       [6.99270313e-05],
       [1.37258773e-08],
       [9.99999881e-01],
       [3.27069908e-02],
       [5.24073052e-09],
       [4.48695220e-10],
       [9.48164880e-01],
       [4.53850989e-05],
       [2.05244198e-02],
       [2.38185407e-06],
       [9.99915242e-01],
       [9.99978065e-01],
       [1.21619388e-01],
       [1.00000000e+00],
       [9.99996543e-01],
       [6.01608539e-04],
       [4.07690793e-01],
       [8.52803569e-05],
       [9.99998689e-01],
       [9.99999166e-01],
       [7.74859654e-06],
       [2.62641173e-04],
       [4.64482815e-04],
       [1.55799182e-06],
       [3.01155239e-01],
       [9.83088076e-01],
       [9.99980330e-01],
       [9.96753156e-01],
       [3.67795692e-05],
       [2.66248971e-01],
       [9.99841332e-01],
       [9.99999166e-01],
       [2.02511015e-04],
       [1.07427906e-04],
       [9.99989748e-01],
       [9.94196057e-01],
       [9.99973297e-01],
       [5.45706971e-05],
       [3.29598581e-04],
       [9.99778926e-01],
       [9.99999881e-01],
       [1.47237894e-04],
       [9.99998450e-01],
       [9.91480708e-01],
       [4.03864741e-01],
       [1.29920207e-02],
       [5.50120453e-12],
       [1.69519899e-05],
       [9.99976516e-01],
       [1.00000000e+00],
       [2.49727226e-08],
       [9.99108016e-01],
       [7.50426916e-05],
       [9.98073697e-01],
       [1.38421054e-03],
       [6.98602974e-01],
       [9.99999881e-01],
       [4.91045266e-01],
       [2.29586079e-03],
       [4.82530653e-04],
       [9.99998212e-01],
       [9.86854196e-01],
       [5.84911322e-04],
       [9.99991179e-01],
       [9.99964237e-01],
       [9.99993443e-01],
       [5.16579748e-05],
       [5.78787527e-04],
       [9.63169277e-01],
       [1.16534387e-07],
       [3.89328925e-05],
       [2.81662524e-05],
       [2.37497261e-05],
       [6.22449789e-08],
       [3.23103070e-02],
       [1.21884909e-06],
       [2.76602618e-03],
       [8.83053303e-07],
       [1.31156585e-05],
       [1.12716305e-04],
       [5.66848577e-08],
       [3.69390250e-06],
       [3.70577908e-07],
       [3.53625230e-03],
       [9.99188244e-01],
       [9.99999762e-01],
       [9.99931335e-01],
       [9.99893785e-01]], dtype=float32)
In [13]:
# Binarise the scores at 0.5: a score above 0.5 is predicted as TSC (label 1)
y_true = shuffled_test_y
y_predResNet50 = np.greater(testResNet50, 0.5)

# Rows are true labels, columns are predicted labels (0 = control, 1 = TSC)
confusion_matrix(y_true, y_predResNet50)
Out[13]:
array([[287,   2],
       [ 28, 182]], dtype=int64)
In [14]:
# Calculate accuracy in the test set: correctly classified images (the diagonal of
# the confusion matrix) divided by all test images. The matrix is computed once,
# and trace/sum also works if only one class were present (1x1 matrix).
cm_ResNet50 = confusion_matrix(y_true, y_predResNet50)
accuracy_ResNet50 = np.trace(cm_ResNet50) / cm_ResNet50.sum()
print('The accuracy in the test set is {}.'.format(accuracy_ResNet50))
The accuracy in the test set is 0.9398797595190381.
In [15]:
# Area under the ROC curve, computed from the raw (un-thresholded) model scores
auc_testResNet50 = roc_auc_score(y_true=y_true, y_score=testResNet50)
print('The AUC in the test set is {}.'.format(auc_testResNet50))
The AUC in the test set is 0.9902949415060143.