Commit 1f38fb32 authored by Alex Fout's avatar Alex Fout

fixed remove artifacts function, simplified get_examples function, and added save_examples function

parent 52d9e221
import os
import numpy as np
import pandas as pd
import scipy as sp
from feature_extractors import extractors
from matplotlib import pyplot as plt
from config import data_directory
plot_ignore_columns = ["window", "time"]
class EEGSession():
def __init__(self, raw, artifacts):
def __init__(self, id, raw, artifacts):
self.id = str(id)
self.raw = raw
self.artifacts = artifacts
self.window_size = None
......@@ -20,8 +24,12 @@ class EEGSession():
"""
for each channel, replaces artifact frames from the raw data frame. artifacts are indicated by 1's in self.artifacts for the same frame/channel
"""
for i, col in enumerate(self.raw.columns):
self.raw[col] = pd.Series(np.zeros(self.artifacts[:, i] == 1, self.raw[col].as_matrix, self.raw[col]))
cols = [col for col in self.raw.columns if col not in plot_ignore_columns]
# replace each colums with zeros where the artifacts matrix is 1's:
for i, col in enumerate(cols):
# make sure the artifacts file is the same length as the raw file. this is not true for some datasets
if len(self.artifacts.as_matrix()[:, i]) == len(self.raw[col].as_matrix()):
self.raw[col] = pd.Series(np.where(self.artifacts.as_matrix()[:, i] == 1, np.zeros_like(self.raw[col].as_matrix()), self.raw[col].as_matrix()), dtype=np.float64)
def extract_windows(self, window_size="256"):
"""
......@@ -96,7 +104,7 @@ class EEGSession():
def plot_dataframe(self, df_name, channels=""):
pass
def get_examples(self, epoch_size="all", channels="all", coh=False, corr=True):
def get_examples(self, feature_args, epoch_size="all", channels="all"):
if channels == "all":
channels = [col for col in self.raw.columns if col not in plot_ignore_columns]
if epoch_size == "all":
......@@ -107,26 +115,20 @@ class EEGSession():
for i in range(n_epochs):
feature_list = []
raw_epoch = raw_matrix[i*epoch_size:(i+1)*epoch_size]
raw_epoch = raw_matrix[i*epoch_size:(i+1)*epoch_size].astype(np.float64)
# extract alpha, beta, waves etc.
# correlation features
if corr:
feature_list.append(np.ndarray.flatten(np.corrcoef(np.transpose(raw_epoch))))
# coherance
if coh:
coherences = []
for i in range(raw_epoch.shape[1]):
for j in range(i):
coherences.append(sp.signal.coherence(raw_epoch[:, i], raw_epoch[:, j])[0])
feature_list.append(np.hstack(coherences))
for extractor, args in feature_args:
extractor = extractors[extractor]
feature_list.append(extractor(raw_epoch, *args))
# create feature array for this exmaple
features = np.hstack(feature_list)
examples.append(features)
# create numpy array for all these features
examples = np.vstack(examples)
self.examples = np.vstack(examples)
return self.examples
return examples
def save_examples(self):
np.savetxt(os.path.join(data_directory, self.id + ".csv"), self.examples, delimiter=",")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment