Commit 82680bc9 authored by Can Pervane's avatar Can Pervane

Added lda (linear discriminant analysis) to embedding. This will find new...

Added lda (linear discriminant analysis) to embedding. This will find new features that seperates best the samples according to a label given by the user
parent 82416a50
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
class Embedding(object):
......@@ -15,6 +16,7 @@ class Embedding(object):
"""
self.type = type
self.n_components = kwargs["n_components"]
self.label_data = kwargs["label_data"]
def train(self, train_data):
"""
......@@ -26,10 +28,20 @@ class Embedding(object):
pca = PCA(n_components=self.n_components)
pca.fit(train_data)
self.pca = pca
if self.type == "lda":
"""
label_data : 1xn_sample, labels the data
train_data : n_samplexn_features
"""
lda = LinearDiscriminantAnalysis(n_components=self.n_components)
lda.fit(train_data, self.label_data)
self.lda = lda
def embed(self, data):
"""
embeds data according to a trained embedding
data: n_Samplesxn_features
return: n_SamplesxNewfeatures
:param data: data to embed
:type data: ndarray
:return: embedding of the data
......@@ -37,3 +49,5 @@ class Embedding(object):
"""
if self.type == "pca":
return self.pca.transform(data)
if self.type == "lda":
return self.lda.transform(data)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment