parent
e57d5dc3ba
commit
419a2d9134
@ -0,0 +1,86 @@ |
|||||||
|
|
||||||
|
import matplotlib.pyplot as plt |
||||||
|
import seaborn as sns |
||||||
|
import pandas as pd |
||||||
|
import numpy as np |
||||||
|
import random |
||||||
|
from scipy.io import loadmat |
||||||
|
from scipy import signal |
||||||
|
from google.colab import drive |
||||||
|
import warnings |
||||||
|
import string |
||||||
|
import os |
||||||
|
|
||||||
|
from keras.layers import * |
||||||
|
from keras.models import Model |
||||||
|
from keras import backend as K |
||||||
|
from keras import Sequential |
||||||
|
from keras.callbacks import ModelCheckpoint |
||||||
|
from keras.callbacks import EarlyStopping |
||||||
|
from sklearn.metrics import confusion_matrix |
||||||
|
from sklearn.preprocessing import scale |
||||||
|
from sklearn.preprocessing import MinMaxScaler |
||||||
|
|
||||||
|
# Install mne library for topoplot |
||||||
|
!pip install mne |
||||||
|
import mne |
||||||
|
|
||||||
|
warnings.filterwarnings('ignore') |
||||||
|
drive.mount('/content/drive') |
||||||
|
|
||||||
|
####################### SETTINGS ######################## |
||||||
|
# Set this variable with the desidered model's name # |
||||||
|
MODEL_SELECTED = "MCNN1" # |
||||||
|
# # |
||||||
|
# Set this variable with the desidered subject's letter # |
||||||
|
SUBJECT_SELECTED = "B" # |
||||||
|
# # |
||||||
|
# Set this variable with the desired electrode subset # |
||||||
|
ELECTRODE_SELECTED = "O" # CNN2c only # |
||||||
|
######################################################### |
||||||
|
|
||||||
|
model_names = ["CNN1", "CNN2a", "CNN2b", "CNN2c", "CNN3", "MCNN1", "MCNN2", "MCNN3"] |
||||||
|
subject_names = ["A", "B"] |
||||||
|
electrode_names = ["F", "C", "P", "O", "LT", "RT"] |
||||||
|
|
||||||
|
# Check for errors in the settings |
||||||
|
if MODEL_SELECTED not in model_names: |
||||||
|
raise ValueError("MODEL_SELECTED value {} is invalid.\nPlease enter one of the following parameters {}".format(MODEL_SELECTED, model_names)) |
||||||
|
elif SUBJECT_SELECTED not in subject_names: |
||||||
|
raise ValueError("SUBJECT_SELECTED value {} is invalid.\nPlease enter one of the following parameters {}".format(SUBJECT_SELECTED, subject_names)) |
||||||
|
elif MODEL_SELECTED == "CNN2c" and ELECTRODE_SELECTED not in electrode_names: |
||||||
|
raise ValueError("ELECTRODE_SELECTED value {} is invalid\nPlease enter one of the following parameters {}".format(ELECTRODE_SELECTED, electrodes_names)) |
||||||
|
|
||||||
|
# Google drive data paths |
||||||
|
MODEL_LOCATIONS_FILE_PATH = 'drive/My Drive/AY1920_DT_P300_SPELLER_03/Trained_models/' + MODEL_SELECTED + '/' + SUBJECT_SELECTED |
||||||
|
SUBJECT_TRAIN_FILE_PATH = 'drive/My Drive/AY1920_DT_P300_SPELLER_03/Dataset/Subject_' + SUBJECT_SELECTED + '_Train.mat' |
||||||
|
SUBJECT_TEST_FILE_PATH = 'drive/My Drive/AY1920_DT_P300_SPELLER_03/Dataset/Subject_' + SUBJECT_SELECTED + '_Test.mat' |
||||||
|
CHANNEL_LOCATIONS_FILE_PATH = 'drive/My Drive/AY1920_DT_P300_SPELLER_03/Dataset/channels.csv' |
||||||
|
CHANNEL_COORD = 'drive/My Drive/AY1920_DT_P300_SPELLER_03/Dataset/coordinates.csv' |
||||||
|
|
||||||
|
# Channel selection |
||||||
|
if MODEL_SELECTED == "CNN2a": |
||||||
|
CHANNELS = [10, 33, 48, 50, 52, 55, 59, 61] |
||||||
|
|
||||||
|
elif MODEL_SELECTED == "CNN2b": |
||||||
|
if SUBJECT_SELECTED == "A": |
||||||
|
CHANNELS = [10, 14, 17, 50, 55, 57, 59, 60] |
||||||
|
elif SUBJECT_SELECTED == "B": |
||||||
|
CHANNELS = [17, 50, 55, 56, 57, 58, 59, 60] |
||||||
|
|
||||||
|
elif MODEL_SELECTED == "CNN2c": |
||||||
|
if ELECTRODE_SELECTED == "F": |
||||||
|
CHANNELS = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37] |
||||||
|
elif ELECTRODE_SELECTED == "C": |
||||||
|
CHANNELS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] |
||||||
|
elif ELECTRODE_SELECTED == "P": |
||||||
|
CHANNELS = [15, 16, 17, 18, 19, 48, 49, 50, 51, 52] |
||||||
|
elif ELECTRODE_SELECTED == "O": |
||||||
|
CHANNELS = [55, 56, 57, 58, 59, 60, 61, 62] |
||||||
|
elif ELECTRODE_SELECTED == "LT": |
||||||
|
CHANNELS = [14, 38, 40, 44, 46, 47] |
||||||
|
elif ELECTRODE_SELECTED == "RT": |
||||||
|
CHANNELS = [20, 39, 41, 45, 53, 54] |
||||||
|
|
||||||
|
else: |
||||||
|
CHANNELS = [i for i in range(64)] |
||||||
Loading…
Reference in new issue