Commit 88e31a47 authored by Alex Fout

updates

parent 1d00b381
......@@ -7,7 +7,7 @@ from preprocessing import extractWaves
from matplotlib import pyplot as plt
from config import data_directory
from config import feature_directory
plot_ignore_columns = ["window", "time"]
......@@ -26,7 +26,10 @@ class EEGSession(object):
self.window_size = None
self.n_windows = None
self.examples = None
self.remove_artifacts(artifact_remove_mode)
try:
self.remove_artifacts(artifact_remove_mode)
except:
pass
def remove_artifacts(self, mode="normal"):
"""
......@@ -149,6 +152,8 @@ class EEGSession(object):
:return: data matrix where rows are examples and columns are calculated features
:rtype: np.ndarray
"""
if self.examples is not None:
return self.examples
if channels == "all":
channels = [col for col in self.raw.columns if col not in plot_ignore_columns]
if filtered_waves:
......@@ -157,7 +162,7 @@ class EEGSession(object):
epoch_size = list(self.waves.values())[0].shape[0]
n_epochs = int(list(self.waves.values())[0].shape[0] / epoch_size)
examples = []
wave_matrices = {k: v.as_matrix() for k, v in self.waves.items()}
wave_matrices = {k: v[channels].as_matrix() for k, v in self.waves.items()}
for i in range(n_epochs):
feature_list = []
for wave_name, wave_matrix in wave_matrices.items():
......@@ -199,11 +204,26 @@ class EEGSession(object):
return self.examples
def save_examples(self):
def save_examples(self, subfolder):
"""
Saves examples to the data folder with .csv extension
Saves examples to the feature folder with .csv extension
"""
if self.examples is not None:
np.savetxt(os.path.join(data_directory, self.id + ".csv"), self.examples, delimiter=",")
np.savetxt(os.path.join(feature_directory, subfolder, self.id + ".csv"), self.examples, delimiter=",")
else:
print("Examples has not been computed yet, not saving anything")
def load_examples(self, subfolder):
    """
    Reads this session's previously saved examples from the feature folder.

    The file looked up is ``<feature_directory>/<subfolder>/<self.id>.csv``.
    If it exists it is parsed as comma-delimited float64 values; a result
    containing any NaN is discarded. If it does not exist a message is
    printed. Either way ``self.examples`` is updated and returned.

    :param subfolder: directory name under the feature folder to read from
    :return: the loaded examples, or None when the file is missing or contains NaN
    :rtype: np.ndarray or None
    """
    csv_path = os.path.join(feature_directory, subfolder, self.id + ".csv")

    # Guard clause: nothing on disk means nothing to load.
    if not os.path.exists(csv_path):
        print("Examples file not found, loading nothing")
        self.examples = None
        return self.examples

    self.examples = np.loadtxt(csv_path, delimiter=",", dtype=np.float64)

    # A file holding NaN values is treated the same as no file at all.
    contains_nan = np.any(np.isnan(self.examples))
    if contains_nan:
        self.examples = None
    return self.examples
......@@ -15,8 +15,10 @@ class Embedding(object):
:type kwargs: dict
"""
self.type = type
self.n_components = kwargs["n_components"]
self.label_data = kwargs["label_data"]
if "n_components" in kwargs:
self.n_components = kwargs["n_components"]
if "label_data" in kwargs:
self.label_data = kwargs["label_data"]
def train(self, train_data):
"""
......
......@@ -11,11 +11,11 @@ from embedding import Embedding
class Patient(object):
def __init__(self, pid):
def __init__(self, pid, subfolder, load_session_raw=True, load_session_examples=False):
"""
loads data for a patient, given a string which represents the patient id (which is a number).
for example, file 1a.raw represents raw measurements from the first session for patient 1.
Stores results as instance varibles of this object.
Stores results as instance variables of this object.
IMPORTANT: Requires that you set the 'data_directory' variable in config.py
......@@ -28,7 +28,17 @@ class Patient(object):
self.intermediate_tests = []
self.n_concussions = 0
self.columns = ["fp1", "fp2", "f3", "f4", "f7", "f8", "c3", "c4", "p3", "p4", "o1", "o2", "t3", "t4", "t5", "t6", "fz", "cz", "pz"]
self.load_all(self.pid)
self.make_sessions(self.pid, load_session_raw, load_session_examples, subfolder)
# Sets up the pre/post test sessions for patient `pid`.
# - load_session_raw truthy: delegates to self.load_all(pid).
# - load_session_raw falsy: creates EEGSession objects with None for the two remaining
#   positional arguments and assigns them to self.pre_test / self.post_test.
# - load_session_examples truthy: calls load_examples(subfolder) on both sessions.
# NOTE(review): indentation is not visible in this view, so it is unclear whether the
# load_session_examples check is nested inside the else branch or runs in both cases — confirm.
def make_sessions(self, pid, load_session_raw, load_session_examples, subfolder):
if load_session_raw:
self.load_all(pid)
else:
# NOTE(review): the id suffixes are styled differently ("_pretest" vs "_post_test");
# verify they match the names used when the example files were saved.
self.pre_test = EEGSession(pid + "_pretest", None, None)
self.post_test = EEGSession(pid + "_post_test", None, None)
if load_session_examples:
# NOTE(review): callers elsewhere check pre_test/post_test against None before use;
# if load_all can leave either as None these calls would fail — TODO confirm.
self.pre_test.load_examples(subfolder)
self.post_test.load_examples(subfolder)
def load_all(self, pid):
"""
......
......@@ -4,6 +4,7 @@ import numpy as np
ignore_columns = ["time", "window"]
def stft(session, **kwargs):
"""
performs
......@@ -18,6 +19,7 @@ def stft(session, **kwargs):
for col in columns:
session.stft[col] = signal.stft(session.raw[col])
def extractWaves(session, n=4001, samplingRate=256, wave='all'):
"""
Extracts a given waveform from the EEG data.
......@@ -43,9 +45,9 @@ def extractWaves(session, n=4001, samplingRate=256, wave='all'):
if (wave == 'all'):
waves = ['delta', 'theta', 'alpha', 'beta', 'gamma']
for i in waves:
b[i] = FIR(n,samplingRate, i)
b[i] = FIR(n, samplingRate, i)
else:
b[wave] = FIR(n,samplingRate, wave)
b[wave] = FIR(n, samplingRate, wave)
if not hasattr(session, "waves"):
# create a dictionary of pandas dataframes
......@@ -57,13 +59,15 @@ def extractWaves(session, n=4001, samplingRate=256, wave='all'):
for col in columns:
# apply filter, via convolution
s = pd.Series(np.convolve(session.raw[col], b[key], mode='valid'))
df['_'.join([col,key])] = s
#df['_'.join([col,key])] = s # includes wave name
df[col] = s # doesn't include wave name
if 'window' in session.raw.columns:
df['window'] = session.raw['window'][chop:-chop].reset_index(drop=True)
df['time'] = session.raw['time'][chop:-chop].reset_index(drop=True)
session.waves[key] = df
return 0
def FIR(n=4001, samplingRate=256, wave='alpha'):
"""
......
......@@ -4,23 +4,31 @@ from itertools import cycle
import sys
from config import pid_noConcussion, pid_3stepProtocol, pid_testRetest, pid_concussion, feature_functions, epoch_size, \
embedding_args, pid_testlist, channels
embedding_args, pid_testlist, channels, subfolder
from patient import Patient
from embedding import Embedding
colors = cycle(['r', 'b', 'g', 'y', 'c', 'm', 'k'])
train_lists = [pid_concussion, pid_noConcussion]
examples_lists = [[], []]
train_examples = []
labels = ["concussion", "noconcussion"]
def embed_and_plot(emb, examples):
def embed_and_plot(emb, examples, all_color=None, linewidth=2):
pre_post_distances = []
alpha = 0.5 / np.log(len(examples)) if len(examples) > 1 else 1
for tup in examples:
if sys.version_info < (3, 0):
# for python2 use
color = colors.next()
if all_color is None:
if sys.version_info < (3, 0):
# for python2 use
color = colors.next()
else:
# for python3 use
color = next(colors)
else:
# for python3 use
color = next(colors)
color = all_color
pid = tup[0]
pre_emb = emb.embed(tup[1])
post_emb = emb.embed(tup[2])
......@@ -29,7 +37,7 @@ def embed_and_plot(emb, examples):
# calculate centroids and plot a line
pre_cent = centroid(pre_emb)
post_cent = centroid(post_emb)
plt.plot([pre_cent[0], post_cent[0]], [pre_cent[1], post_cent[1]], '-', linewidth=3, color=color)
plt.plot([pre_cent[0], post_cent[0]], [pre_cent[1], post_cent[1]], '-', linewidth=linewidth, color=color)
# record distance
pre_post_distances.append(np.linalg.norm(post_cent - pre_cent))
return pre_post_distances
......@@ -42,72 +50,42 @@ def centroid(data):
return np.array([float(x_sum)/length, float(y_sum)/length])
# get training data from un-concussed individuals
noCon_ex= []
step_ex = []
retest_ex = []
con_ex =[]
n_keep = -1
n_keep = 1000
# for lst, pat_list in zip([pid_noConcussion, pid_3stepProtocol, pid_testRetest, pid_concussion], [noCon_pats, step_pats, retest_pats, con_pats]):
#for lst, pat_list in zip([pid_noConcussion], [noCon_pats]):
for pid in pid_noConcussion:
print("Processing pid: {}".format(pid))
p = Patient(pid, load_session_raw=False, load_session_examples=True)
# get examples from pre_test
pre = post = None
if p.pre_test is not None:
pre = p.pre_test.load_examples()
if pre is not None:
np.random.shuffle(pre)
# get examples from post_test
if p.post_test is not None:
post = p.post_test.load_examples()
if post is not None:
np.random.shuffle(post)
if post is not None and pre is not None:
noCon_ex.append((pid, pre, post))
for pid in pid_concussion:
print("Processing pid: {}".format(pid))
p = Patient(pid, load_session_raw=False, load_session_examples=True)
# get examples from pre_test
pre = post = None
if p.pre_test is not None:
pre = p.pre_test.load_examples()
if pre is not None:
np.random.shuffle(pre)
pre = pre[:n_keep]
# get examples from post_test
if p.post_test is not None:
post = p.post_test.load_examples()
if post is not None:
np.random.shuffle(post)
post = post[:n_keep]
if post is not None and pre is not None:
con_ex.append((pid, pre, post))
for i, pid_list in enumerate(train_lists):
for pid in pid_list:
print("Processing pid: {}".format(pid))
p = Patient(pid, subfolder, load_session_raw=False, load_session_examples=True)
# get examples from pre_test
pre = post = None
if p.pre_test is not None:
pre = p.pre_test.load_examples(subfolder)
if pre is not None:
np.random.shuffle(pre)
# get examples from post_test
if p.post_test is not None:
post = p.post_test.load_examples(subfolder)
if post is not None:
np.random.shuffle(post)
if post is not None and pre is not None:
train_examples.append((pid, pre, post))
examples_lists[i].append((pid, pre, post))
# create training data
train_data = np.vstack([tup[1][:n_keep] for tup in noCon_ex] + [tup[2][n_keep] for tup in noCon_ex] +
[tup[1][:n_keep] for tup in con_ex] + [tup[2][n_keep] for tup in con_ex])
train_data = np.vstack([tup[1][:n_keep] for tup in train_examples] + [tup[2][:n_keep] for tup in train_examples])
# create and train embedding
emb = Embedding(**embedding_args)
emb.train(train_data)
# visualize embedding
nocon_distances = embed_and_plot(emb, noCon_ex)
plt.title("No concussion, pre/post test centroid distance")
plt.legend()
plt.show()
plt.savefig()
con_distances = embed_and_plot(emb, con_ex)
plt.title("Concussion, pre/post test centroid distance")
colors = ["r", "b"]
f = plt.figure()
for label, examples_list, color in zip(labels, examples_lists, colors):
distances = embed_and_plot(emb, examples_list, all_color=color)
plt.title("{} pre/post test centroid distance".format(label))
plt.legend()
plt.show()
plt.hist(nocon_distances)
plt.title("No concussion, pre/post test centroid distance")
plt.show()
plt.hist(con_distances)
plt.title("Concussion, pre/post test centroid distance")
plt.show()
\ No newline at end of file
plt.savefig("distances")
from config import pid_noConcussion, pid_3stepProtocol, pid_testRetest, pid_concussion, feature_functions, epoch_size
from config import pid_noConcussion, pid_3stepProtocol, pid_testRetest, pid_concussion, feature_functions, epoch_size, \
channels
from patient import Patient
# go through each list of ids
......@@ -6,14 +7,14 @@ for lst in [pid_noConcussion, pid_3stepProtocol, pid_testRetest, pid_concussion]
# for each id...
for pid in lst:
print("Processing pid: {}".format(pid))
p = Patient(pid)
p = Patient(pid, load_session_examples=False, load_session_raw=True)
# generate file for pre_test
if p.pre_test is not None:
p.pre_test.remove_artifacts()
p.pre_test.get_examples(feature_functions, epoch_size=epoch_size)
p.pre_test.get_examples(feature_functions, epoch_size=epoch_size, channels=channels)
p.pre_test.save_examples()
if p.post_test is not None:
# generate file for post_test
p.post_test.remove_artifacts()
p.post_test.get_examples(feature_functions, epoch_size=epoch_size)
p.post_test.get_examples(feature_functions, epoch_size=epoch_size, channels=channels)
p.post_test.save_examples()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment