Commit 2df55bdc authored by Alex Fout's avatar Alex Fout

added docs to classes/functions and cleaned up some formatting

parent b33bfd76
...@@ -13,17 +13,28 @@ from config import data_directory ...@@ -13,17 +13,28 @@ from config import data_directory
plot_ignore_columns = ["window", "time"] plot_ignore_columns = ["window", "time"]
class EEGSession(): class EEGSession(object):
"""
This class represents an EEG session that has a "raw" file and an "artifact" file. The raw file contains channel
measurements for each time step (rows are time steps and columns are channels) where there are 256 time steps per
second
"""
def __init__(self, id, raw, artifacts): def __init__(self, id, raw, artifacts):
self.id = str(id) self.id = str(id)
self.raw = raw self.raw = raw
self.artifacts = artifacts self.artifacts = artifacts
self.window_size = None self.window_size = None
self.n_windows = None self.n_windows = None
self.examples = None
def remove_artifacts(self, mode="normal"): def remove_artifacts(self, mode="normal"):
""" """
for each channel, replaces artifact frames from the raw data frame. artifacts are indicated by 1's in self.artifacts for the same frame/channel for each channel, replaces artifact frames from the raw data frame. artifacts are indicated by 1's in self.artifacts for the same frame/channel
:param mode: specifies what to do when replacing artifacts. options are:
"zero": replace with zeros (these sections have zero variance and will mess up correlation features)
"normal": replace with data from a random normal distribution with the same mean and variance as all
non-artifact data in that channel
:type mode: string
""" """
cols = [col for col in self.raw.columns if col not in plot_ignore_columns] cols = [col for col in self.raw.columns if col not in plot_ignore_columns]
# replace each column with zeros where the artifacts matrix is 1's: # replace each column with zeros where the artifacts matrix is 1's:
...@@ -43,9 +54,10 @@ class EEGSession(): ...@@ -43,9 +54,10 @@ class EEGSession():
self.raw = None self.raw = None
return return
def extract_windows(self, window_size="256"): def extract_windows(self, window_size="256"):
""" """
adds a "window" column to the pandas data frame in order to plot several windows on top of one another. the value in this
column increments by 1 every <window_size> time steps
:param window_size: number of frames in a window :param window_size: number of frames in a window
:type window_size: string :type window_size: string
:return: dataframe with "window" column :return: dataframe with "window" column
...@@ -118,6 +130,24 @@ class EEGSession(): ...@@ -118,6 +130,24 @@ class EEGSession():
pass pass
def get_examples(self, feature_args, epoch_size="all", channels="all", filtered_waves=True): def get_examples(self, feature_args, epoch_size="all", channels="all", filtered_waves=True):
"""
generates data examples from this session. time series are split up into epochs and static features are
computed for each epoch. Each epoch is treated as a separate example from this session
:param feature_args: a list of tuples that specifies the features to calculate. the first element of each tuple
is a string which corresponds to the feature being calculated. the second element of the tuple is a list of
positional arguments which will be passed to the feature generation function
:type feature_args: list
:param epoch_size: size of epoch or "all"
:type epoch_size: int or string
:param channels: list of strings which specify which columns to use to calculate features. corresponds to
columns in the pandas data array
:type channels: list
:param filtered_waves: indicates whether to calculate features based on the raw features or the filtered (alpha,
beta, ...) versions of those waves.
:type filtered_waves: bool
:return: data matrix where rows are examples and columns are calculated features
:rtype: np.ndarray
"""
if channels == "all": if channels == "all":
channels = [col for col in self.raw.columns if col not in plot_ignore_columns] channels = [col for col in self.raw.columns if col not in plot_ignore_columns]
if filtered_waves: if filtered_waves:
...@@ -169,4 +199,10 @@ class EEGSession(): ...@@ -169,4 +199,10 @@ class EEGSession():
return self.examples return self.examples
def save_examples(self): def save_examples(self):
np.savetxt(os.path.join(data_directory, self.id + ".csv"), self.examples, delimiter=",") """
Saves examples to the data folder with .csv extension
"""
if self.examples is not None:
np.savetxt(os.path.join(data_directory, self.id + ".csv"), self.examples, delimiter=",")
else:
print("Examples has not been computed yet, not saving anything")
...@@ -2,16 +2,38 @@ from sklearn.decomposition import PCA ...@@ -2,16 +2,38 @@ from sklearn.decomposition import PCA
class Embedding(object):
    """
    This class calculates a low dimensional embedding based on some training data.

    Only the "pca" type is handled: for any other type, train() fits nothing and embed() returns None.
    """

    def __init__(self, type="pca", **kwargs):
        """
        initializes embedding options
        :param type: specifies type of embedding
        :type type: string
        :param kwargs: optional keyword arguments. "n_components" is passed on to PCA when training and is
        stored as None when it is not supplied
        :type kwargs: dict
        """
        self.type = type
        # .get() instead of [] so that omitting n_components does not raise KeyError
        self.n_components = kwargs.get("n_components")
        # fitted model; remains None until train() has run for the "pca" type
        self.pca = None

    def train(self, train_data):
        """
        trains an embedding based on passed training data
        :param train_data: training data with which to calculate the embedding
        :type train_data: ndarray
        """
        if self.type == "pca":
            pca = PCA(n_components=self.n_components)
            pca.fit(train_data)
            self.pca = pca

    def embed(self, data):
        """
        embeds data according to a trained embedding
        :param data: data to embed
        :type data: ndarray
        :return: embedding of the data (None when the embedding type is not "pca")
        :rtype: ndarray
        :raises AttributeError: if called for the "pca" type before train()
        """
        if self.type == "pca":
            if self.pca is None:
                # same exception type as the missing-attribute failure this replaces, but self-explanatory
                raise AttributeError("embed() called before train(): no fitted PCA available")
            return self.pca.transform(data)
...@@ -3,60 +3,94 @@ import scipy as sp ...@@ -3,60 +3,94 @@ import scipy as sp
def correlation(raw_matrix):
    """
    Calculates pairwise correlations between columns of the matrix
    :param raw_matrix: 2-D data matrix; correlations are computed between its columns
    :type raw_matrix: ndarray
    :return: 1-D vector with the entries of the column correlation matrix that lie strictly above the diagonal
    (one value per unordered pair of distinct columns). Empty when the matrix has a single column
    :rtype: ndarray
    """
    # np.corrcoef collapses to a 0-d value for a single column; atleast_2d keeps the indexing below valid
    corr = np.atleast_2d(np.corrcoef(np.transpose(raw_matrix)))
    return corr[np.triu_indices(corr.shape[0], k=1, m=corr.shape[1])]
def coherence(raw_matrix):
    """
    Calculates the pairwise coherence values of the matrix
    :param raw_matrix: 2-D data matrix; coherence is estimated between pairs of its columns
    :type raw_matrix: ndarray
    :return: 1-D vector made of the coherence estimates of every column pair (i, j) with j <= i, concatenated
    in that order. Each pair contributes one value per frequency bin
    :rtype: ndarray
    """
    # imported here because `import scipy as sp` alone does not guarantee that sp.signal is loaded
    from scipy import signal

    coherences = []
    for i in range(raw_matrix.shape[1]):
        # NOTE(review): j == i is included, so every column's coherence with itself is part of the output
        for j in range(i + 1):
            # signal.coherence returns (frequencies, Cxy); the feature is Cxy, not the frequency grid
            _, cxy = signal.coherence(raw_matrix[:, i], raw_matrix[:, j])
            coherences.append(cxy)
    return np.hstack(coherences)
def rms(raw_matrix):
    """
    Calculates the root mean square value of a time series
    :param raw_matrix: 2-D data matrix; each column is treated as one series
    :type raw_matrix: ndarray
    :return: 1-D vector with the root mean square of each column (empty when there are no columns)
    :rtype: ndarray
    """
    data = np.asarray(raw_matrix)
    # reduce over the rows (axis 0) so one value is produced per column
    return np.sqrt(np.mean(data ** 2, axis=0))
def meanAbs(raw_matrix):
    """
    Calculates the mean of absolute values
    :param raw_matrix: 2-D data matrix; each column is treated as one series
    :type raw_matrix: ndarray
    :return: 1-D vector with the mean absolute value of each column (empty when there are no columns)
    :rtype: ndarray
    """
    data = np.asarray(raw_matrix)
    # reduce over the rows (axis 0) so one value is produced per column
    return np.mean(np.abs(data), axis=0)
def std(raw_matrix):
    """
    standard deviation of a time series
    :param raw_matrix: 2-D data matrix; each column is treated as one series
    :type raw_matrix: ndarray
    :return: 1-D vector with the standard deviation of each column, using np.std's default settings
    (empty when there are no columns)
    :rtype: ndarray
    """
    data = np.asarray(raw_matrix)
    # reduce over the rows (axis 0) so one value is produced per column
    return np.std(data, axis=0)
def subBandRatio(raw_matrix, nBands=6):
    """
    The ratio of the mean of absolute values, between adjacent columns
    Note: This measure was used in a paper where the columns of the matrix represent the
    frequency bands
    NOTE: this is an unimplemented placeholder. The body is only `pass`, so the function currently returns
    None and never uses raw_matrix or nBands. The parameter and return descriptions below describe the
    intended contract, not the current behaviour.
    :param raw_matrix: data matrix where rows are examples and columns are raw features
    :type raw_matrix: ndarray
    :param nBands: presumably the number of frequency bands (columns) -- TODO confirm once implemented
    :type nBands: int
    :return: (intended) feature matrix where rows are examples and columns are calculated features
    :rtype: ndarray
    """
    pass
# Registry mapping a feature name to the function that computes it.
# NOTE(review): presumably the lookup table for the feature-name strings passed to get_examples -- confirm
# against the caller. subBandRatio is not listed here.
extractors = {
    "correlation": correlation,
    "coherence": coherence,
    "rms": rms,
    "meanAbs": meanAbs,
    "std": std,
}
...@@ -63,12 +63,12 @@ class Patient(object): ...@@ -63,12 +63,12 @@ class Patient(object):
# count number of concussions # count number of concussions
self.n_concussions = len(self.intermediate_tests) self.n_concussions = len(self.intermediate_tests)
def load_session(self, filename, id=""): def load_session(self, filename, id):
""" """
:param filename: file prefix, as in: <prefix>.raw and <prefix>.art :param filename: file prefix, as in: <prefix>.raw and <prefix>.art
:type filename: string :type filename: string
:param suffix: optional suffix for the id of the eeg session :param id: id of the eeg session
:type suffix: string :type id: string
:return: an EEGSession object which has a pandas data frame for each of Session.raw and Session.art :return: an EEGSession object which has a pandas data frame for each of Session.raw and Session.art
:rtype: EEGSession :rtype: EEGSession
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment