Medical Specialty Classification Using Patient Queries
I have been having pain around my knee for more than a month. Which medical specialist should I see?
Recently, one of my friends came down with pain around his wrist. Because I work at a hospital, he asked me to pick a medical specialist for him to consult. Since I am a good friend, I suggested he visit the best hand surgeon in town.
Being a good friend, I followed up on his consult. To my surprise, he said he had to see a urologist because it could be a kidney stone.
WHAT??! Kidney stones. You must be kidding.
Apparently, that is a possibility. The wrist pain could be a result of gout, which in turn could be a result of kidney stones. For more info on joint pain, gout and kidney stones, see this link.
That is when the idea for this blog popped up.
In this blog we ask the question: can we build an AI chatbot that takes in a patient's query and outputs the appropriate specialty the patient should see?
Such an AI assistant could come in handy for healthcare providers to guide patients to the right specialty.
- It can be used by telephone operators (who might not have adequate medical knowledge to answer such questions) at hospitals to better serve patients.
- It can also be integrated with a hospital's telemedicine app to guide patients.
Now that we have defined our problem statement, let's spell out the plan to achieve this goal.
For this we will loosely follow this paper as our guide. The paper uses NLP to tackle a similar problem.
Briefly, the following is our plan:
- Building the dataset
  - extract patient queries from online medical forums
  - the queries are organised under different topics, such as diabetes and thyroid issues
  - manually map the topics to specialties
  - clean the queries
- Training with fastai's ULMFiT
  - build a language model using the dataset we have prepared
  - build a text classifier using the dataset and the encoder we built during the language model training
- Inference and understanding what we have built
This is the raw dataset after scraping the data from medical forums.
df_raw.head(5)
There are about 1368 unique topics. Let's take a look at some of them.
df_raw['topic'].unique()[random.choices(range(1368), k=50)]
As we can see, some of the topics are names of medicines. I decided to remove these during preprocessing.
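The post doesn't include a code cell for this step; below is a minimal, hypothetical sketch of how such topics could be filtered out. The medicine names listed are placeholders, not the actual ones removed, and the topic-to-specialty mapping described next would be applied on top of this.

# Hypothetical examples of topics that are medicine names; the real list
# was curated manually by eyeballing df_raw['topic'].unique().
medicine_topics = {'metformin', 'lisinopril', 'atorvastatin'}

# Keep only rows whose topic is a condition, not a medicine name.
df_processed = df_raw[~df_raw['topic'].str.lower().isin(medicine_topics)].copy()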
df_processed.head(5)
I mapped the topics to specialties to the best of my knowledge. Let's take a look at some of these mappings.
# renamed to avoid shadowing Python's built-in map
topic_map = df_processed[['topic', 'specialty']].drop_duplicates('topic').set_index('topic')
topic_map.head(5)
topic_map['specialty'].unique()
Apart from the mapping, some basic preprocessing was also done (sketched after the list):
- removing white spaces
- removing unnecessary words
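The cleaning code isn't shown in the post; here is a minimal sketch, assuming the queries live in the text column and that "unnecessary words" means common forum filler. The filler list is a hypothetical example.

import re

def clean_query(text):
    # Collapse runs of white space into single spaces and trim the ends.
    text = re.sub(r'\s+', ' ', text).strip()
    # Hypothetical examples of the "unnecessary words" that were removed.
    filler = {'hi', 'hello', 'thanks', 'regards'}
    return ' '.join(w for w in text.split() if w.lower() not in filler)

df_processed['text'] = df_processed['text'].apply(clean_query)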
Next, on to training the language model.
dls_lm = TextDataLoaders.from_df(df_processed,
is_lm=True,
valid_pct=0.1)
learn = language_model_learner(dls_lm,
AWD_LSTM,
metrics=[accuracy, Perplexity()],
wd=0.1).to_fp16()
We have defined the language model dataloader and learner as above.
Let's take a look at some of the vocab. As expected, fastai sets up additional tokens, such as xxunk, which is used when a word is not part of the vocab, and xxbos, which indicates the beginning of a sentence.
dls_lm.vocab[:30]
That is a sufficient number of words in the vocab.
The training for the language model is defined below.
learn.fit_one_cycle(1, 1e-2)
learn.unfreeze()
learn.fit_one_cycle(10, 1e-3)
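The classification section below loads a vocab from models/vocabv2.pkl and reuses the encoder, so both need to be saved at this point. The post doesn't show that step; a sketch, assuming the encoder filename 'finetuned' and fastcore's save_pickle helper (re-exported by fastai):

# Save the fine-tuned encoder so the classifier can reuse its weights.
# The filename 'finetuned' is an assumption; any name works as long as
# load_encoder later uses the same one.
learn.save_encoder('finetuned')

# Save the language model's vocab for the classification dataloader.
save_pickle('models/vocabv2.pkl', dls_lm.vocab)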
For classification, we first load the vocab that we built during the language model training.
vocab = load_pickle('models/vocabv2.pkl')
Then, define our classification dataloader.
dls_clas = TextDataLoaders.from_df(df_processed,
text_col='text',
label_col='specialty',
text_vocab=vocab,
valid_pct=0.1,
seq_len=100,
bs=64,
is_lm=False,
y_block=CategoryBlock())
Let's take a look at some samples from the dls.
dls_clas.show_batch()
Let's define our learner to do the job.
learn = text_classifier_learner(dls_clas,
AWD_LSTM,
drop_mult=0.6,
metrics=accuracy)
We also need to load the encoder from the language model training. But first, let's take a look at the model we are using.
learn.model
The model contains the encoder module and the classification module. Next, we load the encoder weights from the language model training.
learn.model[0]
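The load call itself isn't shown in the post; assuming the encoder was saved as 'finetuned' earlier (as in the save_encoder sketch above), loading it with fastai's load_encoder looks like this:

# Load the fine-tuned encoder weights from the language model training.
learn = learn.load_encoder('finetuned')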
Our training setup is as below. Following the ULMFiT recipe, we train the classifier head first, then gradually unfreeze more layer groups with discriminative learning rates: each slice spreads rates from lr/2.6**4 for the earliest layer group up to lr for the head.
learn.fit_one_cycle(3, 2e-2)
learn.freeze_to(-2)
learn.fit_one_cycle(1, slice(1e-2/(2.6**4),1e-2))
learn.freeze_to(-3)
learn.fit_one_cycle(1, slice(5e-3/(2.6**4),5e-3))
learn.unfreeze()
learn.fit_one_cycle(20, slice(1e-3/(2.6**4), 1e-3),
                    cbs=[GradientAccumulation(16),
                         SaveModelCallback(fname='classi_v2'),
                         EarlyStoppingCallback(comp=np.less, patience=3)])
Let's try to interpret what we have achieved so far.
interp = ClassificationInterpretation.from_learner(learn)
Let's take a look at the confusion matrix.
interp.plot_confusion_matrix(figsize=(15,15))
A weighted average F1-score of 0.87 looks decent. The metrics look fine for most classes, except perhaps O&G, oncologist, and plastic surgeon; our dataset might not contain enough examples for these specialties.
interp.print_classification_report()
Let's take a look at one of the queries that our model got wrong. Honestly, even I would have been confused by it. Perhaps this shows a limitation of the dataset preparation method we used.
interp.plot_top_losses(1)
Let's make some queries and see what our model says.
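get_prediction is a small helper that isn't defined in the post; a plausible reconstruction, wrapping fastai's learn.predict (which returns the decoded label, the class index, and the per-class probabilities):

def get_prediction(query, learn):
    # learn.predict returns (decoded label, class index, probabilities).
    pred_class, pred_idx, probs = learn.predict(query)
    confidence = probs[pred_idx].item() * 100
    print(f"Query '{query}' outputs {pred_class} with a confidence of {confidence:.2f}%")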
get_prediction('lump under my armpit', learn)
The query lump under my armpit outputs general surgeon with a confidence of 70.66%. That is a good start.
Below we have a few more examples. The model prefers descriptive queries, which is understandable as that is how it was trained. A bare seizure query gives a wrong prediction, while a descriptive query that includes seizure predicts correctly. The word feeling strongly predicts psychiatrist; a more descriptive query around feeling does help with better prediction.
get_prediction('seizure', learn)
get_prediction('i have had seizure 3 times in the past 2 days', learn)
get_prediction('i am feeling cold', learn)
get_prediction('i am feeling cold at night with slight cough and fever', learn)
get_prediction('knee pain with slight swelling', learn)
get_prediction('my ldl levels are high', learn)
We have come to the end of the blog. What is next?
- increase the size of the dataset
- use a transformer-based model