What is DICOM?
Understanding DICOM and using fastai to read and work with DICOM images
- What is DICOM?
- Building a classifier with fastai
- Interpreting the classifier with fastai
- CAM and GradCAM
What is DICOM?
- DICOM stands for Digital Imaging and Communications in Medicine
- It is a software integration standard used in medical imaging
- It establishes rules that allow different medical imaging modalities (such as X-Ray, Ultrasound, CT and MRI) from different vendors and hospitals to exchange information with each other
The core of DICOM
- DICOM File Format
- This is the important part for deep learning (DL).
- DICOM images use the .dcm extension.
- A .dcm file stores patient data and image pixel values under different tags. Apart from images, DICOM also contains patient details (such as patient name and age) and image acquisition data (such as the type of equipment used).
- DICOM Network protocol
- The protocol allows information exchange between the different imaging modalities connected to the hospital network
- It is used to search for images in the archive and to display images on the workstation
- This protocol can also be used to monitor treatment, schedule procedures and report status
Let's use fastai to read a .dcm file and use it to understand the information it contains. We will use the SIIM-ACR Pneumothorax Segmentation dataset, in the small 250-image version that fastai provides as URLs.SIIM_SMALL.
! pip install fastai -q --upgrade
! pip install pydicom kornia opencv-python scikit-image nbdev -q
import fastai
print(fastai.__version__)
from fastai.basics import *
from fastai.callback.all import *
from fastai.vision.all import *
from fastai.medical.imaging import *
import pydicom
import pandas as pd
# downloading the dataset
pneumothorax_source = untar_data(URLs.SIIM_SMALL)
# get the paths of all the dcm files in the train folder
items = get_dicom_files(pneumothorax_source/'train')
# let's read the dcm file for patient 11
patient = 11
xray_sample = items[patient].dcmread()
# let's take a look at the DICOM metadata
xray_sample
As can be seen, the DICOM header has many rows. Each row contains a data element.
An example of a data element is
(0028, 0010) Rows US: 1024
Let's break down the data element. (0028, 0010) is the tag. It has two parts: the Group (0028) and the Element (0010). Related elements share a group; from our example of patient 11 above, we can see that group 0010 holds all the patient details.
In our data element example, the tag is followed by Rows, the name that describes the data element. Next comes the Value Representation (VR), which describes the data type of the data element. In our example, the VR is US, which means Unsigned Short. Finally comes the value itself: 1024, i.e. the image has 1024 rows. The file also stores a Value Length (the number of bytes the value occupies, 2 for a single US value), but it is not shown in this printout.
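To make the tag / VR / value length / value layout concrete, here is a minimal stdlib sketch that encodes and decodes the (0028, 0010) Rows element. It assumes the Explicit VR Little Endian transfer syntax, where a US element is laid out as a 4-byte tag, a 2-character VR, a 2-byte value length and then the value. encode_us_element and decode_us_element are hypothetical helpers for illustration only; real files should be read with pydicom.

```python
import struct

# Short-form element layout in Explicit VR Little Endian (an assumption of this sketch):
# group (uint16) | element (uint16) | VR (2 ASCII chars) | value length (uint16) | value
US_ELEMENT = struct.Struct('<HH2sHH')

def encode_us_element(group: int, element: int, value: int) -> bytes:
    """Encode a single Unsigned Short (US) data element."""
    return US_ELEMENT.pack(group, element, b'US', 2, value)  # a US value occupies 2 bytes

def decode_us_element(raw: bytes):
    """Return ((group, element), VR, value length, value) for a US data element."""
    group, element, vr, length, value = US_ELEMENT.unpack(raw)
    return (group, element), vr.decode('ascii'), length, value

raw = encode_us_element(0x0028, 0x0010, 1024)
print(raw.hex(' '))  # 28 00 10 00 55 53 02 00 00 04
(group, element), vr, length, value = decode_us_element(raw)
print(f"({group:04x}, {element:04x}) {vr} length={length} value={value}")
# (0028, 0010) US length=2 value=1024
```

Note how the value length (2 bytes) and the value (1024) are separate fields.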
There are thousands of data elements defined by DICOM. In this post, we will focus on the 0028 group, which describes image/pixel-related attributes, and on (7fe0, 0010), which contains the pixel data.
- (0028, 0002) Samples per Pixel: indicates whether the image is grayscale (1) or RGB (3). In our example, we have a grayscale image.
- (0028, 0004) Photometric Interpretation: describes the color space of the image. Some of the possible values are MONOCHROME1, MONOCHROME2, PALETTE COLOR and RGB. In our case it is MONOCHROME2, where low values are dark and high values are bright. It is the opposite in MONOCHROME1. PALETTE COLOR describes a color image with a single sample per pixel that is mapped through a color lookup table. RGB describes red, green and blue image planes; for RGB, samples per pixel would be 3.
- (0028, 0010) Rows: the number of rows in the image. In our example, there are 1024 rows.
- (0028, 0011) Columns: the number of columns in the image. In our example, there are 1024 columns.
- (0028, 0030) Pixel Spacing: the physical distance, in mm, between the centers of two neighbouring pixels. In our example, [0.19431099999999998, 0.19431099999999998], the first number is the Row Spacing (the distance between adjacent rows) and the second is the Column Spacing (the distance between adjacent columns).
- (0028, 0100) Bits Allocated: the number of bits allocated for each pixel sample.
- (0028, 0101) Bits Stored: the number of bits actually used for each pixel sample. An 8-bit image would have pixel values between 0 and 255.
- (0028, 0102) High Bit: the most significant bit for pixel sample data. Each sample shall have the same high bit.
- (0028, 0103) Pixel Representation: either unsigned (0) or signed (1). As a refresher, a signed integer can hold negative values, while an unsigned integer can only hold zero and positive values.
- (0028, 2110) Lossy Image Compression: specifies whether an image has undergone lossy compression. 00 means the image has not been subjected to lossy compression; 01 means it has. Lossy (irreversible) compression is the class of data encoding methods that use inexact approximations and partial data discarding to represent the content. These techniques are used to reduce data size for storing, handling and transmitting content.
- (0028, 2114) Lossy Image Compression Method: the method used for lossy compression. Possible values include ISO_10918_1: JPEG Lossy Compression, ISO_14495_1: JPEG-LS Near-lossless Compression, ISO_15444_1: JPEG 2000 Irreversible Compression, ISO_13818_2: MPEG2 Compression, ISO_14496_10: MPEG-4 AVC/H.264 Compression and ISO_23008_2: HEVC/H.265 Lossy Compression. In our example, it is ISO_10918_1, which is JPEG Lossy Compression.
- (7fe0, 0010) Pixel Data: an array of pixel data. Its VR is OB (Other Byte). We will take a look at it below.
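Several of these attributes combine to tell us how to interpret pixel values. The sketch below, assuming NumPy and using hypothetical helper names, computes the representable range from Bits Stored and Pixel Representation, flips a MONOCHROME1 image so it follows the MONOCHROME2 convention (low = dark), and converts Rows, Columns and Pixel Spacing into a physical size in mm. With pydicom, the inputs would come from attributes such as ds.BitsStored, ds.PixelRepresentation, ds.PhotometricInterpretation and ds.PixelSpacing.

```python
import numpy as np

def stored_value_range(bits_stored: int, pixel_representation: int):
    """Smallest and largest value a pixel sample can hold."""
    if pixel_representation == 0:          # unsigned
        return 0, 2**bits_stored - 1
    half = 2**(bits_stored - 1)            # signed, two's complement
    return -half, half - 1

def to_monochrome2(pixels: np.ndarray, photometric: str,
                   bits_stored: int, pixel_representation: int) -> np.ndarray:
    """Return pixels where low values are dark and high values are bright."""
    pixels = pixels.astype(np.int64)       # avoid overflow when flipping
    if photometric == 'MONOCHROME1':       # inverted grayscale: flip within the stored range
        lo, hi = stored_value_range(bits_stored, pixel_representation)
        return lo + hi - pixels
    return pixels                          # MONOCHROME2 already uses the desired convention

def physical_size_mm(rows: int, columns: int, pixel_spacing):
    """Approximate height and width of the imaged area in mm."""
    row_spacing, col_spacing = (float(s) for s in pixel_spacing)
    return rows * row_spacing, columns * col_spacing

print(stored_value_range(8, 0))    # (0, 255)
print(stored_value_range(16, 1))   # (-32768, 32767)
print(to_monochrome2(np.array([0, 100, 255]), 'MONOCHROME1', 8, 0).tolist())  # [255, 155, 0]
print(physical_size_mm(1024, 1024, [0.194311, 0.194311]))  # roughly 199 mm x 199 mm
```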
Apart from the above, let's also look at the following element:
- (0008, 0060) Modality: the type of equipment that originally acquired the data used to create the images in this Series. The full list of values is defined in the DICOM standard; some examples are CT: Computed Tomography, CR: Computed Radiography, DX: Digital Radiography, ES: Endoscopy and IVUS: Intravascular Ultrasound.
Let's take a look at a sample of PixelData
xray_sample.PixelData[:200]
The raw PixelData is a sequence of bytes that is hard to read directly, so let's use .pixel_array to decode it into a more familiar NumPy array.
xray_sample.pixel_array
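Conceptually, for uncompressed ("native") pixel data, .pixel_array reinterprets the PixelData bytes using Bits Allocated and Pixel Representation and reshapes them to Rows × Columns. The sketch below assumes NumPy, little-endian byte order, one sample per pixel and a single frame; decode_native_pixels is a hypothetical helper, and pydicom's real implementation handles many more cases (compressed transfer syntaxes, multiple frames, RGB and so on).

```python
import numpy as np

def decode_native_pixels(raw: bytes, rows: int, columns: int,
                         bits_allocated: int, pixel_representation: int) -> np.ndarray:
    """Interpret raw little-endian PixelData bytes as a rows x columns array."""
    kind = 'u' if pixel_representation == 0 else 'i'      # unsigned or signed integers
    dtype = np.dtype(f'<{kind}{bits_allocated // 8}')     # e.g. '<u2' for 16 bits allocated
    return np.frombuffer(raw, dtype=dtype, count=rows * columns).reshape(rows, columns)

# a tiny 2 x 3 image with 16 bits allocated, unsigned
raw = np.arange(6, dtype='<u2').tobytes()
print(raw[:4])  # b'\x00\x00\x01\x00'
print(decode_native_pixels(raw, rows=2, columns=3, bits_allocated=16, pixel_representation=0))
```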
Let's use .show to display the image
xray_sample.show()
fastai provides the following function to create a DataFrame from the DICOM files. Apart from reading the DICOM metadata, it also calculates summary statistics of the image pixels (mean/min/max/std) when px_summ is set to True.
dicom_dataframe = pd.DataFrame.from_dicoms(items, px_summ=True)
dicom_dataframe[:5]
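As a rough NumPy sketch of what the img_min, img_max, img_mean and img_std columns hold for a single image: pixel_summary is a hypothetical helper, fastai's from_dicoms computes these statistics itself, and img_pct_window is not reproduced here.

```python
import numpy as np

def pixel_summary(pixels: np.ndarray) -> dict:
    """Per-image pixel statistics in the spirit of fastai's img_* columns."""
    px = pixels.astype(np.float64)
    return {'img_min': float(px.min()), 'img_max': float(px.max()),
            'img_mean': float(px.mean()), 'img_std': float(px.std())}  # population std

print(pixel_summary(np.array([[0, 2], [4, 6]])))
```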
We have 250 samples which wouldn't be sufficient to build anything of significance. We will use it to understand the DICOM data and to learn how fastai can be used to work with medical images.
len(dicom_dataframe)
Let's take a look at the different columns. img_min, img_max, img_mean, img_std and img_pct_window are calculated by fastai's from_dicoms function.
dicom_dataframe.columns
We have CR as the Modality, which is Computed Radiography.
dicom_dataframe['Modality'].unique()
We can see the age distribution of the patient.
plt.style.use('seaborn-v0_8')  # this style was named 'seaborn' before matplotlib 3.6
dicom_dataframe['PatientAge'].astype(int).hist(bins=10)
plt.xlabel('Age')
plt.ylabel('Frequency')
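The astype(int) cast works here because PatientAge in this dataset appears to hold plain numbers. In general, though, DICOM stores ages with the Age String (AS) VR: three digits followed by D, W, M or Y for days, weeks, months or years (for example 045Y). A small stdlib sketch of converting such strings before plotting, using a hypothetical helper age_in_years:

```python
def age_in_years(age: str) -> float:
    """Convert a DICOM Age String such as '045Y' or '018M' to years.

    Bare numbers (as this dataset appears to use) are treated as years.
    """
    age = age.strip()
    if age.isdigit():
        return float(age)
    value, unit = int(age[:-1]), age[-1].upper()
    days_per_unit = {'D': 1, 'W': 7, 'M': 365.25 / 12, 'Y': 365.25}
    if unit not in days_per_unit:
        raise ValueError(f"not a DICOM age string: {age!r}")
    return value * days_per_unit[unit] / 365.25

print(age_in_years('045Y'))  # 45.0
print(age_in_years('018M'))  # 1.5
print(age_in_years('38'))    # 38.0
```

With such a helper, dicom_dataframe['PatientAge'].map(age_in_years).hist(bins=10) would plot ages regardless of which form they are stored in.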
We have an equal number of male (M) and female (F) patients.
dicom_dataframe['PatientSex'].hist()
Next, let's use fastai.medical to build a simple classifier. Again, we only have 250 samples, which won't be enough to build anything of significance. The goal is to show how we can use fastai.medical to work with medical images.
Let's get the different folders/files.
pneumothorax_source.ls()
Let's read labels.csv and see what is in it
train = pd.read_csv(pneumothorax_source/'labels.csv')
train.head()
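Let's build a simple DataLoaders object
pneumothorax = DataBlock(blocks=(ImageBlock(cls=PILDicom), CategoryBlock),  # DICOM image in, category out
                         get_x=lambda x: pneumothorax_source/f"{x[0]}",   # column 0: relative path to the .dcm file
                         get_y=lambda x: x[1],                            # column 1: the label
                         splitter=RandomSplitter(),                       # random train/validation split
                         batch_tfms=aug_transforms(size=400))             # batch augmentations at 400x400
dls = pneumothorax.dataloaders(train.values, bs=8)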
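RandomSplitter() holds out a random 20% of the items (its default valid_pct is 0.2) as the validation set, so our 250 samples become 200 for training and 50 for validation. The plain-Python sketch below mimics that behaviour; random_split is a hypothetical stand-in, and fastai's version differs in details such as the random number generator it uses.

```python
import random

def random_split(n_items: int, valid_pct: float = 0.2, seed=None):
    """Return (train_idxs, valid_idxs): a random valid_pct of the indices is held out."""
    idxs = list(range(n_items))
    random.Random(seed).shuffle(idxs)
    cut = int(valid_pct * n_items)
    return idxs[cut:], idxs[:cut]

train_idxs, valid_idxs = random_split(250, seed=42)
print(len(train_idxs), len(valid_idxs))  # 200 50
```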
Once we have built the DataLoaders, we can take a look at a batch.
dls.show_batch(max_n=8)
Let's use fastai's xresnet model. xresnet is based on the "Bag of Tricks for Image Classification with Convolutional Neural Networks" paper. We also use the Mish activation instead of the usual ReLU, and we add self-attention (sa=True).
model = xresnet50(pretrained=False, act_cls=Mish, sa=True, n_out=2)
# the first conv layer of the stem expects 3-channel (RGB) input
model[0][0]
# Here, we replace the first layer so it accepts single-channel (grayscale) DICOM images
model[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
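Mish is a smooth activation defined as x * tanh(softplus(x)), where softplus(x) = ln(1 + exp(x)). Unlike ReLU, it lets small negative values through instead of zeroing them. A plain-Python sketch of the formula (fastai's Mish module applies the same function to tensors):

```python
import math

def softplus(x: float) -> float:
    # numerically stable ln(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def mish(x: float) -> float:
    return x * math.tanh(softplus(x))

def relu(x: float) -> float:
    return max(x, 0.0)

for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={x:+.1f}  relu={relu(x):.4f}  mish={mish(x):+.4f}")
```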
For the optimizer, we will use ranger, which combines RAdam with Lookahead.
learn = Learner(dls,
model=model,
loss_func=LabelSmoothingCrossEntropy(),
metrics= accuracy,
opt_func=ranger)
learn.lr_find()
learn.fit_flat_cos(5, 5e-4)
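The ranger optimizer used above pairs RAdam with Lookahead. The idea behind Lookahead is simple: an inner optimizer takes k fast steps, then a set of slow weights moves a fraction alpha of the way towards the fast weights, and the fast weights restart from there. Below is a minimal NumPy sketch on a toy quadratic. To keep it short it assumes plain gradient descent as the inner optimizer instead of RAdam; lookahead_gd and its hyperparameters (k=6, alpha=0.5) are illustrative, not fastai's ranger implementation.

```python
import numpy as np

def lookahead_gd(grad_fn, w0, lr=0.1, k=6, alpha=0.5, n_steps=300):
    """Lookahead wrapped around plain gradient descent (the inner optimizer)."""
    slow = np.array(w0, dtype=float)
    fast = slow.copy()
    for step in range(1, n_steps + 1):
        fast -= lr * grad_fn(fast)           # fast (inner) update
        if step % k == 0:                    # every k steps...
            slow += alpha * (fast - slow)    # ...move the slow weights towards the fast ones
            fast = slow.copy()               # ...and restart the fast weights from there
    return slow

# toy quadratic f(w) = 0.5 * ||w - target||^2, whose gradient is (w - target)
target = np.array([3.0, -1.0])
w = lookahead_gd(lambda w: w - target, w0=[0.0, 0.0])
print(w)  # close to [3, -1]
```

With k=1 and alpha=1 the slow weights simply copy the fast ones after every step, so the sketch reduces to plain gradient descent.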
Among the many cool things fastai provides are its interpretation tools. Let's use them to see how well our classifier is doing.
learn.show_results(max_n=8)
# Let's create a ClassificationInterpretation
interp = ClassificationInterpretation.from_learner(learn)
Let's take a look at our top_losses. Our classifier confuses No Pneumothorax for Pneumothorax. This is likely because of the lack of training data. Again, the goal here is to understand how we can use fastai and its many tools.
interp.plot_top_losses(6, figsize=(14,8))
As expected, there are many false negatives (the model predicts "No Pneumothorax" when the label is "Pneumothorax").
interp.plot_confusion_matrix(figsize=(7,7))
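To make the false-negative count concrete: treating Pneumothorax as the positive class, a false negative is a Pneumothorax image predicted as No Pneumothorax. A minimal plain-Python sketch of how a 2 x 2 confusion matrix like this one is tallied (rows are actual classes, columns are predicted classes; the labels and predictions below are made up for illustration, not this model's outputs):

```python
from collections import Counter

CLASSES = ['No Pneumothorax', 'Pneumothorax']

def confusion_matrix(actual, predicted, classes=CLASSES):
    """Rows are actual classes, columns are predicted classes."""
    counts = Counter(zip(actual, predicted))
    return [[counts[(a, p)] for p in classes] for a in classes]

actual    = ['Pneumothorax', 'Pneumothorax', 'No Pneumothorax', 'Pneumothorax', 'No Pneumothorax']
predicted = ['No Pneumothorax', 'Pneumothorax', 'No Pneumothorax', 'No Pneumothorax', 'Pneumothorax']
cm = confusion_matrix(actual, predicted)
tn, fp = cm[0]
fn, tp = cm[1]
print(cm)  # [[1, 1], [2, 1]]
print(f"false negatives={fn}, recall={tp / (tp + fn):.2f}")
```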
credits: https://docs.paperspace.com/machine-learning/wiki/interpretability
Interpretability, or explainability, is the degree to which a model's prediction/decision can be explained in human terms. This is a huge area of research, as ML models are often described as black boxes with no interpretability. Several tools have been developed to address this; among them are CAM and GradCAM.
Class Activation Map (CAM) uses the activations of the last convolutional layer and the weights of the final linear layer to produce a heatmap visualization. The visualization gives an idea of why the model made its decision. In medical imaging, this sort of heatmap could assist radiologists and other doctors, in addition to providing the classification itself. Fastbook has a chapter dedicated to CAM and GradCAM.
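Concretely, if A holds the k feature maps of the last convolutional layer (shape k x h x w) and W is the final linear layer's weight matrix (shape classes x k), the CAM for class c is the sum over k of W[c, k] * A[k]. Because xresnet's head average-pools before the linear layer, averaging a class's CAM over all positions gives that class's logit minus the bias. The NumPy sketch below checks this with random stand-in arrays instead of real activations; the einsum mirrors the one used in get_cammap later in this post.

```python
import numpy as np

rng = np.random.default_rng(0)
k, h, w, n_classes = 8, 4, 4, 2
A = rng.standard_normal((k, h, w))        # stand-in for the last conv layer's activations
W = rng.standard_normal((n_classes, k))   # stand-in for the final linear layer's weights
b = rng.standard_normal(n_classes)        # its bias

# CAM: weight each feature map by the class weights and sum over the k maps
cam = np.einsum('ck,kij->cij', W, A)      # shape (n_classes, h, w)

# the model's head: global average pool, then the linear layer
logits = W @ A.mean(axis=(1, 2)) + b

# averaging each CAM over space recovers the logits (without the bias)
print(np.allclose(cam.mean(axis=(1, 2)) + b, logits))  # True
```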
Let's see how we can make use of CAM.
Below, we define a hook class. Hooks are similar to callbacks: they let us inject code into the forward and backward passes.
class Hook():
    # called by PyTorch after the module's forward pass; stores a copy of its output
    def hook_func(self, m, i, o): self.stored = o.detach().clone()
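Let's define the paths for the No Pneumothorax and Pneumothorax classes
# sort the class folders so index 0 is 'No Pneumothorax' and index 1 is 'Pneumothorax'
class_dirs = (pneumothorax_source/'train').ls().sorted()
nopneumo = class_dirs[0].ls()
pneumo = class_dirs[1].ls()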
Then, we create a Hook instance and use register_forward_hook to attach its hook_func to the forward pass. learn.model[-5] is the last stage of residual blocks, just before the head (pooling, flatten, dropout and the final linear layer), so register_forward_hook attaches our hook to the output of the last convolutional layers.
hook_output = Hook()
hook = learn.model[-5].register_forward_hook(hook_output.hook_func)
Let's define a function to grab a sample image from either the nopneumo or the pneumo folder.
def grab_x(path, patient):
    # build a one-item test DataLoader with the same transforms as training
    x = first(dls.test_dl([path[patient]]))
    return x[0]
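Let's define a function to get the CAM map. As you can see, we make use of the einsum function, an awesome function that is well worth learning.
def get_cammap(x):
    with torch.no_grad():
        output = learn.model.eval()(x)
    # activations captured by the hook on the last conv stage: (channels, h, w)
    act = hook_output.stored[0]
    # predicted class probabilities
    print(F.softmax(output, dim=-1))
    # for every class c, weight the channel maps by the final linear layer's weights and sum
    cam_map = torch.einsum('ck,kij->cij', learn.model[-1].weight, act)
    # the hook stays registered so later calls see fresh activations; call hook.remove() when done
    return cam_map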
Then, let's define a function to plot two images: the left image is the input image and the right image is the input image superimposed with the CAM activation heatmap. Use cls=0 to see the No Pneumothorax class activation and cls=1 to see the Pneumothorax class activation.
def plot_cam(x, cls, cam_map, img_size=400):
    # undo the batch transforms so the image can be displayed
    x_dec = TensorDicom(dls.train.decode((x,))[0][0])
    _,ax = plt.subplots(1,2, figsize=(15,10))
    x_dec.show(ctx=ax[0])
    x_dec.show(ctx=ax[1])
    # stretch the low-resolution CAM over the full image and blend it in
    ax[1].imshow(cam_map[cls].detach().cpu(), alpha=0.6, extent=(0,img_size,img_size,0),
                 interpolation='bilinear', cmap='magma');
Let's write a function to wrap everything.
def get_plot_cam_image(path, patient, cls):
    x = grab_x(path, patient)
    cam_map = get_cammap(x)
    plot_cam(x, cls, cam_map, img_size=400)
Below is an example for No Pneumothorax. Areas in bright yellow/orange correspond to high activations, while areas in purple correspond to low activations. Unfortunately, our classifier hasn't learnt enough for the map to be meaningful, so let's look at a picture from fastbook instead.
The activation map on the cat allows one to peek into the model's 'reasons' for its prediction. In medical imaging, this might highlight tumors and other abnormalities that radiologists could further scrutinize.
get_plot_cam_image(nopneumo, 5, 0)
Below is the same image, but showing the activation for cls=1, i.e. Pneumothorax. Some of the bright orange areas are around the lungs, as opposed to the No Pneumothorax map above, where the lungs appeared purple, indicating no activation around the lungs.
get_plot_cam_image(nopneumo, 5, 1)
hook.remove()  # we are done with CAM, so detach the forward hook
Having seen how to use CAM in fastai, let's take a look at GradCAM.
GradCAM is similar to CAM, except that GradCAM uses gradients to build the visualization. Because it uses gradients, we can plot the visualization for earlier conv layers too. With CAM, we could only observe the visualization for the final conv layer, because the activations of the conv layer have to be multiplied by the weight matrix of the last layer, which only works for the final conv layer. This variant was introduced in the 2016 paper "Grad-CAM: Why Did You Say That? Visual Explanations from Deep Networks via Gradient-based Localization".
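Grad-CAM replaces CAM's weights W[c, k] with the spatial average of the gradient of the class score with respect to each feature map. For a head that average-pools and then applies a linear layer, that gradient is W[c, k] / (h * w) at every position, so Grad-CAM reduces to CAM scaled by 1 / (h * w); for earlier layers, the gradients are what make the method work at all. The NumPy sketch below checks this with random stand-in arrays and an analytic gradient (no autograd). The Grad-CAM paper also applies a ReLU to the final map to keep only positive evidence, which get_gradcammap in this post does not do, so the sketch computes both variants.

```python
import numpy as np

rng = np.random.default_rng(1)
k, h, w, n_classes = 8, 5, 5, 2
A = rng.standard_normal((k, h, w))        # stand-in activations of the hooked layer
W = rng.standard_normal((n_classes, k))   # stand-in final linear weights
cls = 1

# score_c = sum_k W[c, k] * mean(A[k]), so d score_c / d A[k, i, j] = W[c, k] / (h * w)
grad = np.broadcast_to((W[cls] / (h * w))[:, None, None], (k, h, w))

alpha = grad.mean(axis=(1, 2), keepdims=True)   # one weight per feature map
gradcam = (alpha * A).sum(axis=0)               # weighted sum, as in get_gradcammap
gradcam_relu = np.maximum(gradcam, 0)           # the paper's ReLU step

cam = np.einsum('k,kij->ij', W[cls], A)
print(np.allclose(gradcam, cam / (h * w)))      # True
```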
# A hook to store the output of a layer
class Hook():
    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_func)
    def hook_func(self, m, i, o): self.stored = o.detach().clone()
    # context manager support: the hook is removed when the `with` block exits
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()
# A hook to store the grad of a layer
class HookBwd():
    def __init__(self, m):
        # newer PyTorch versions prefer register_full_backward_hook over register_backward_hook
        self.hook = m.register_backward_hook(self.hook_func)
    # go[0] is the gradient of the backpropagated quantity (here the class score) w.r.t. the layer's output
    def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()
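def get_gradcammap(x, cls, model_layer):
    with HookBwd(model_layer) as hookg:
        with Hook(model_layer) as hook:
            # x already lives on the DataLoaders' device, so no .cuda() call is needed
            output = learn.model.eval()(x)
            act = hook.stored
        output[0,cls].backward()
        grad = hookg.stored
    # average the gradients over the spatial dimensions: one weight per feature map
    w = grad[0].mean(dim=[1,2], keepdim=True)
    # weighted sum of the feature maps
    cam_map = (w * act[0]).sum(0)
    return cam_map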
def plot_gcam(x, cam_map, img_size=400):
    x_dec = TensorDicom(dls.train.decode((x,))[0][0])
    _,ax = plt.subplots(1,2, figsize=(15,10))
    x_dec.show(ctx=ax[0])
    x_dec.show(ctx=ax[1])
    ax[1].imshow(cam_map.detach().cpu(), alpha=0.6, extent=(0,img_size,img_size,0),
                 interpolation='bilinear', cmap='magma');
def get_plot_gcam_image(path, patient, cls, layer=None):
    # default to the last conv stage, the same layer used for CAM
    if layer is None: layer = learn.model[-5]
    x = grab_x(path, patient)
    cam_map = get_gradcammap(x, cls, layer)
    plot_gcam(x, cam_map)
get_plot_gcam_image(nopneumo, 7, cls=0, layer=learn.model[-5])
get_plot_gcam_image(nopneumo, 7, cls=1, layer=learn.model[-5])