An attempt to classify patients with RBD (REM sleep behaviour disorder) and to predict whether an RBD patient will develop Parkinson's disease (PD), using an RNN-augmented LGGNet on PSG data.
The code architecture and the 'LGGnet' model are taken from yi-ding-cs's
repository (https://github.com/yi-ding-cs/LGG).
Check their repository for a better understanding of the original model.
I added an RNN part (a GRU layer) and split the training into 3 phases:
- 1st phase: Train the non-RNN part (LGGNet) only
- 2nd phase: Train the RNN part only
- 3rd phase: Train the whole model
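In practice, the three phases amount to choosing which parameters receive gradients. The following is a minimal PyTorch sketch of that idea, not the repository's actual code: the toy model, the `features` and `gru` module names, and the helper `set_phase` are hypothetical, while `fcLGG`, `fc` and `out_weights` follow the layer names used further down in this README.

```python
from torch import nn

# Which (hypothetical) sub-modules are trained in each phase.
PHASE_TRAINABLE = {
    1: {"features", "fcLGG"},
    2: {"gru", "fc", "fcLGG", "out_weights"},
    3: {"features", "fcLGG", "gru", "fc", "out_weights"},
}


def build_toy_model(n_in=16, hidden=8):
    """Toy stand-in: `features` plays the role of LGGNet up to the GCN."""
    return nn.ModuleDict({
        "features": nn.Linear(n_in, hidden),
        "fcLGG": nn.Linear(hidden, 1),
        "gru": nn.GRU(hidden, hidden, batch_first=True),
        "fc": nn.Linear(hidden, 1),
        "out_weights": nn.Linear(2, 1),
    })


def set_phase(model, phase):
    """Freeze everything except the sub-modules trained in `phase`."""
    names = PHASE_TRAINABLE[phase]
    for name, module in model.items():
        for p in module.parameters():
            p.requires_grad = name in names
    # Only these parameters should be handed to the optimiser.
    return [p for p in model.parameters() if p.requires_grad]


model = build_toy_model()
for phase in (1, 2, 3):
    trainable = set_phase(model, phase)
    print(phase, sum(p.numel() for p in trainable))
```

Passing only the returned list to the optimiser keeps frozen layers untouched even when weight decay is used.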
A ResNet-50-inspired model is also used as a comparison with RNNLGGnet. This model is adapted to take the same inputs as
RNNLGGnet.
The first dataset considered comes from a sleep centre and contains whole-night EEG, EOG (electrooculogram), ECG (electrocardiogram) and EMG (electromyogram) recordings of patients suffering from REM sleep behaviour disorder (RBD). Some of these patients were later diagnosed with PD. The dataset does not include healthy subjects and is not publicly available. On its own, it can only train neural networks to predict whether an RBD sufferer will develop PD. For this reason, the DREAMS database (https://zenodo.org/records/2650142#.ZFJ0pnbMJD8), and in particular the DREAMS Subjects Database, which comprises whole-night EEG and ECG data from healthy subjects, is added to complete the dataset. This makes it possible to detect RBD by comparison with healthy subjects.
A drawback of using multiple datasets simultaneously is data incompatibility. To mitigate it, all data is preprocessed with a band-pass filter (0.5-50 Hz), resampled to 128 Hz, and finally normalised. An optional ICA (independent component analysis) step is available in the code. The data is cut into epochs of 4 seconds (512 samples at 128 Hz) by default.
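The preprocessing chain described above can be sketched as follows. This is an illustration with NumPy and SciPy, not the code in preprocess.py: the function name `preprocess`, the 4th-order Butterworth filter, the per-channel z-score normalisation and the 200 Hz input rate in the usage lines are all assumptions.

```python
from math import gcd

import numpy as np
from scipy.signal import butter, resample_poly, sosfiltfilt


def preprocess(signal, fs_in, fs_out=128, band=(0.5, 50.0), epoch_seconds=4):
    """Band-pass, resample, normalise and epoch a (channels, samples) array."""
    # Zero-phase Butterworth band-pass (the filter order is an assumption).
    sos = butter(4, band, btype="bandpass", fs=fs_in, output="sos")
    filtered = sosfiltfilt(sos, signal, axis=-1)

    # Polyphase resampling from fs_in to fs_out (integer rates assumed).
    g = gcd(int(fs_in), int(fs_out))
    resampled = resample_poly(filtered, fs_out // g, fs_in // g, axis=-1)

    # Per-channel z-score (the kind of normalisation is an assumption).
    mean = resampled.mean(axis=-1, keepdims=True)
    std = resampled.std(axis=-1, keepdims=True)
    normalised = (resampled - mean) / (std + 1e-8)

    # Non-overlapping epochs; the incomplete tail is dropped.
    n = fs_out * epoch_seconds
    n_epochs = normalised.shape[-1] // n
    trimmed = normalised[:, : n_epochs * n]
    # Result shape: (epochs, channels, samples per epoch).
    return trimmed.reshape(signal.shape[0], n_epochs, n).transpose(1, 0, 2)


rng = np.random.default_rng(0)
raw = rng.standard_normal((3, 200 * 20))   # 3 channels, 20 s at 200 Hz
epochs = preprocess(raw, fs_in=200)
print(epochs.shape)                        # (5, 3, 512)
```

With the default settings, each epoch holds 4 s x 128 Hz = 512 samples, which matches the epoch length quoted above.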
The dataset from the sleep centre includes the channels F3-A2, F4-A1, C3-A2, C4-A1, O1-A2, O2-A1, LOC-A2, EEG ROC-A1, EMG Chin, ECG EKG, EMG Left_Leg, EMG Right_Leg, Snoring Snore, Airflow, Resp Thorax, Resp Abdomen, Manual. Those are the channels used for PD detection.
However, because the recordings from the sleep centre and from DREAMS do not include the same channels, RBD detection is limited to the three channels they share (O1, O2, and ECG).
In phase 1, I added a sigmoid layer at the end of the forward method of the model, and the fc layer is augmented by an
intermediate Linear layer.
Otherwise, the model is trained exactly as in the original repository, except for the early stopping condition. To
reduce the risk of overfitting, the early stopping condition now uses the validation set instead of the unused data from the
inner 3-fold split (which comes from the same subjects as the training data).
The data is shuffled between subjects while training.
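As an illustration of early stopping driven by the validation set, here is a minimal, framework-free sketch. The class `EarlyStopping`, its `patience` and `min_delta` parameters, and the choice of monitoring the validation loss are assumptions; the repository may track a different metric, such as validation accuracy.

```python
class EarlyStopping:
    """Stop when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch and return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def epochs_run(val_losses, patience=2):
    """Number of epochs actually run for a given validation-loss history."""
    stopper = EarlyStopping(patience=patience)
    for epoch, loss in enumerate(val_losses, start=1):
        if stopper.step(loss):
            return epoch
    return len(val_losses)


# The loss stops improving after epoch 3, so training halts at epoch 5.
print(epochs_run([0.9, 0.7, 0.6, 0.65, 0.7, 0.5]))  # 5
```

Because the validation subjects are never seen during training, the stopping signal reflects generalisation to new subjects, not to held-out data from subjects the model already knows.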
In phase 2, a bigger RNNLGGNet model is trained.
The output of the GCN is fed to both the GRU layer and the fc layer from phase 1 (named fcLGG).
The output of the GRU layer goes into another fc layer.
Then, the outputs of fcLGG and fc go into a Linear layer (named out_weights) before the sigmoid.
This phase only trains the GRU, fc, fcLGG and final Linear layers.
Since the GRU layer's hidden state is intended to be specific to each subject, the data is not shuffled during training but
fed to the model sequentially, one subject after another.
The hidden state is reset at each new subject.
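The phase-2 head and the per-subject reset of the hidden state can be sketched as below. This is a simplified PyTorch illustration, not the code in model.py: the class `RecurrentHead`, the helper `run_subjects`, the feature and hidden sizes, the batch size of 4 and the detaching of the hidden state between batches are assumptions. Only the layer names `fcLGG`, `fc` and `out_weights` come from the description above.

```python
import torch
from torch import nn


class RecurrentHead(nn.Module):
    """Hypothetical phase-2 head operating on GCN output features."""

    def __init__(self, n_features=16, hidden=8):
        super().__init__()
        self.fcLGG = nn.Linear(n_features, 1)   # classifier kept from phase 1
        self.gru = nn.GRU(n_features, hidden, batch_first=True)
        self.fc = nn.Linear(hidden, 1)          # classifier on the GRU output
        self.out_weights = nn.Linear(2, 1)      # combines both branches

    def forward(self, gcn_out, h=None):
        # gcn_out: (n_epochs, n_features), consecutive epochs of one subject.
        seq = gcn_out.unsqueeze(0)               # (1, n_epochs, n_features)
        rnn_out, h = self.gru(seq, h)            # rnn_out: (1, n_epochs, hidden)
        both = torch.cat(
            [self.fcLGG(gcn_out), self.fc(rnn_out.squeeze(0))], dim=1
        )                                        # (n_epochs, 2)
        return torch.sigmoid(self.out_weights(both)), h


def run_subjects(head, subjects, batch_size=4):
    """subjects: list of (n_epochs, n_features) tensors in temporal order."""
    outputs = []
    for features in subjects:
        h = None                                 # reset at each new subject
        for batch in features.split(batch_size): # unshuffled, chronological
            probs, h = head(batch, h)
            h = h.detach()                       # assumption: truncated BPTT
            outputs.append(probs)
    return torch.cat(outputs)


torch.manual_seed(0)
head = RecurrentHead()
subjects = [torch.randn(10, 16), torch.randn(6, 16)]
print(run_subjects(head, subjects).shape)        # torch.Size([16, 1])
```

Because `h` is set back to `None` for every subject, the predictions for one subject do not depend on which subject was processed before.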
In phase 3, all the parameters are trained simultaneously. Everything else is the same as in the 2nd phase.
The ResNet model is trained in the same way as the modified first phase of RNNLGGnet.
RNNLGGnet is an RNN-augmented LGGnet. Recurrent neural networks (RNN) are a type
of neural network designed to handle time-series data, which can vary in length. They can also retain
information to influence future predictions. The best-known variants are long short-term memory networks (LSTM) and gated
recurrent units (GRU). The choice of RNN was straightforward: LSTM and GRU perform similarly, while GRU
requires fewer parameters, which makes it easier to train.
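The parameter saving can be made concrete with a short count. A GRU layer has three weight blocks (reset gate, update gate, candidate state) where an LSTM has four (input, forget and output gates, plus the cell candidate). The sketch below counts one bias vector per block; the input and hidden sizes are arbitrary illustrative values, not the ones used in this repository.

```python
def recurrent_params(input_size, hidden_size, n_blocks):
    """Parameters of one recurrent layer with one bias vector per block."""
    per_block = hidden_size * (input_size + hidden_size) + hidden_size
    return n_blocks * per_block


def gru_params(input_size, hidden_size):
    return recurrent_params(input_size, hidden_size, n_blocks=3)


def lstm_params(input_size, hidden_size):
    return recurrent_params(input_size, hidden_size, n_blocks=4)


# Illustrative sizes only.
print(gru_params(64, 32))    # 9312
print(lstm_params(64, 32))   # 12416
print(gru_params(64, 32) / lstm_params(64, 32))  # 0.75
```

Some implementations, PyTorch included, keep two bias vectors per block. That changes the absolute counts but not the 3:4 ratio between GRU and LSTM.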
The multiphase pipeline is designed to give each part of the model a distinct role. The LGGnet is meant to focus on
learning to recognise patterns, while the RNN part tries to keep the relevant information from past data. Because
of the RNN, phases 2 and 3 present the training data sequentially, subject by subject, resetting the
RNN part between subjects. This choice is motivated by the wish to let the model predict from an
entire time-series input, as a physician would.
The ResNet-50-inspired model (ResNet) is designed to handle 1D data, in particular the time-series data provided. While
RNNLGGnet aims at predicting over a subject's whole segment, ResNet attempts to recognise the patterns
associated with either RBD or PD. ResNet-50 networks are indeed well known to perform very well in image
classification.
Short version: don't train phases 2 and 3. Phase 1 alone (which is in effect an LGGnet that overfits slightly
less) performs much better.
Link to results: https://docs.google.com/document/d/1E_bEYQ98wCeGFsEgza_1uB-pqniT3MLgYNjN5CrfDJs
Note that "enhanced" in the results means that the modified phase 1 was used.
- RBD
Performance is assessed using subject-wise cross-validation: for each fold, a certain percentage of the subjects is excluded from the training set and used as the validation set. In the following results, 20% of the subjects are held out per fold, out of the 39 subjects for RBD (both DREAMS and sleep centre datasets) and the 19 subjects for PD (sleep centre dataset only).
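A subject-wise split with 20% of the subjects held out corresponds to 5 folds. The sketch below shows the idea with the standard library only. The function `subject_wise_folds` and the integer subject identifiers are hypothetical, and unlike edf_to_dat.py this version does not balance the labels across folds.

```python
import random


def subject_wise_folds(subject_ids, n_folds=5, seed=0):
    """Yield (train_subjects, val_subjects) pairs with no subject in both."""
    ids = list(subject_ids)
    random.Random(seed).shuffle(ids)
    # Deal the shuffled subjects into n_folds groups of near-equal size.
    folds = [ids[k::n_folds] for k in range(n_folds)]
    for k, val in enumerate(folds):
        train = [s for j, fold in enumerate(folds) if j != k for s in fold]
        yield train, val


# 39 subjects for RBD: each fold validates on 7 or 8 unseen subjects.
for train, val in subject_wise_folds(range(39)):
    print(len(train), len(val))
```

Splitting by subject, not by epoch, prevents epochs from the same recording from appearing in both the training and the validation sets.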
RNNLGGnet easily reaches 99% accuracy when predicting whether a subject is suffering from RBD after
phase 1. Phase 2 actually reduces accuracy to below 50% and tends to make the model favour one output
over the other (0 or 1 for the presence of RBD). Phase 3 behaves inconsistently: in some runs it corrected
the bias introduced by phase 2, in others it deepened the bias until the model gave almost only one type of output.
As a result, it is more efficient to keep only the slightly modified LGGnet training step. The high accuracy might also come from the use of distinct datasets, since data from DREAMS is labelled 0 (healthy) while data from the sleep centre is labelled 1 (has RBD), even though everything has been normalised. A single dataset recorded under the same conditions will be crucial to obtain trustworthy results.
The adapted ResNet-50 obtains results similar to LGGnet when detecting RBD, easily reaching 99%
accuracy. The resemblance does not stop there: for PD prediction, ResNet-50 reached 68% accuracy at best for
one subject group, and between 19% and 31% accuracy for the other groups.
- Parkinson
For Parkinson's disease, the results depend heavily on the subjects used for training. The best results involve the EEG
channels.
RNNLGGnet can occasionally reach 80% accuracy when determining whether a
patient will develop PD. However, this result is only attained for one subject group out of the 5
folds. For the other 4 groups, the accuracy reaches 40% at best. Overall, the mean accuracy is only
about 50%. ResNet-50, for its part, reached 68% accuracy at best for
one subject group, and between 19% and 31% accuracy for the other groups.
Given the spread of the results, the labels should be refined. They are currently naive for the Parkinson
part: all REM segments of a subject who later develops Parkinson's disease are labelled 1.
However, the signs of Parkinson's disease are probably not visible across all REM segments. It would be better to isolate the characteristic markers.
Moreover, the time between the recording and the actual onset of Parkinson's disease might be too long for the recording to show any markers specific to the disease.
- Preprocessing
  - `concatenate_HOSP.py` concatenates `edf` files from the dataset used (not public). It can be useful as a base for other databases.
  - `edf_preprocessing.py` preprocesses `edf` files and stores the results as `fif` files. Only useful for `visualize_fif.py`.
  - `edf_to_dat.py` preprocesses `edf` files and stores the results as `dat` files, while distributing the files to ensure label distribution in the cross-validation groups. Used for training.
  - `explore.py` helps explore the channels of each raw data file. Useful to verify which channels are in common.
  - `preprocess.py` contains the code to preprocess data.
- Training
  - `cross_validation.py` is in charge of the cross-validations. Based on `LGGnet`'s repository.
  - `main.py` is the code to run to train models. Based on `LGGnet`'s repository.
  - `model.py` defines the models (`LGGnet`, `RNNLGGnet`, `ResNet`). Based on `LGGnet`'s repository.
  - `prepare_data.py` prepares the data for cross-validations. Based on `LGGnet`'s repository.
  - `train.py` provides the functions used by `cross_validation.py` for training. Based on `LGGnet`'s repository.
  - `utils.py` has miscellaneous tools. Based on `LGGnet`'s repository.
- Checking/Visualising
  - `compare.py` evaluates the performance of multiple models.
  - `visualize.py` and `visualize_fif.py` plot `hdf` and `fif` files. They are of limited use for now.
- Other
  - `example.py` gives example usage code for a model trained using this repository.
- Sawtooth waves
  - `ConceFT.py` contains code implementing the calculations from https://iopscience.iop.org/article/10.1088/1361-6579/ad66aa. Its purpose was to detect sawtooth waves, but the execution time is too long in its current state, so the parameters have not been tuned for sawtooth waves. Check the article for more details.
  - `stw_psd.py` was meant to find a way to detect sawtooth waves using the PSD (power spectral density).
RBD should be detectable using simple signal processing. I did not have the time to search properly for an efficient method.
Before searching for one, I would advise retraining LGGnet or ResNet on a single dataset containing subjects
both with and without RBD, to make sure these models can actually detect RBD.
Sawtooth waves seem to be the key to knowing whether an RBD sufferer will develop Parkinson's disease in the future. I imagine checking how often these waves occur or how long they last. Unfortunately, not enough work has been done on detecting these waves.