-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdenmune_wrapper.py
36 lines (28 loc) · 1.21 KB
/
denmune_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import pandas as pd
from denmune import DenMune
from typing import Union, Tuple, Dict, Optional, List
class DenMuneWrapper:
    """Adapter exposing ``denmune.DenMune`` through a ``fit``/``fit_predict`` API.

    ``DenMune`` receives its training data in its constructor and does the
    work in its own ``fit_predict``. This wrapper instead takes the
    hyperparameters up front and the data in ``fit``/``fit_predict``. It
    stores the result in ``labels_``, mirroring the scikit-learn estimator
    naming convention.

    Parameters
    ----------
    k : int
        Number of nearest neighbours, forwarded to DenMune as ``k_nearest``.
    rgn_tsne : bool, default False
        Forwarded unchanged to the DenMune constructor.
    **kwargs
        Extra keyword arguments forwarded unchanged to
        ``DenMune.fit_predict`` on every call to ``fit``.

    Attributes
    ----------
    model : DenMune or None
        Instance built by the most recent ``fit``; ``None`` before fitting.
    labels : object or None
        First value returned by ``DenMune.fit_predict``. It is indexed with
        ``'train'`` below, so presumably a mapping of split name to labels.
    labels_ : object or None
        ``labels['train']``, presumably the cluster labels of the training
        data (exact type is whatever DenMune returns; TODO confirm).
    validity_ : object or None
        Second value returned by ``DenMune.fit_predict``.
    """

    def __init__(self, k: int, rgn_tsne: bool = False, **kwargs):
        self.k = k
        self.kwargs = kwargs
        self.rgn_tsne = rgn_tsne
        # Populated by ``fit``; ``None`` marks an unfitted wrapper.
        self.model = None
        self.labels_ = None
        self.labels = None
        self.validity_ = None

    def fit(
        self,
        X: Union[np.ndarray, pd.DataFrame],
        y: Optional[Union[np.ndarray, pd.Series]] = None,
    ) -> "DenMuneWrapper":
        """Build a DenMune model on ``X`` and run its clustering.

        Parameters
        ----------
        X : numpy.ndarray or pandas.DataFrame
            Training data, passed to DenMune as ``train_data``.
        y : numpy.ndarray or pandas.Series, optional
            Ground-truth labels, passed to DenMune as ``train_truth``.

        Returns
        -------
        DenMuneWrapper
            ``self``, so calls can be chained (``wrapper.fit(X).labels_``).

        Raises
        ------
        TypeError
            If ``X`` is neither a ``numpy.ndarray`` nor a ``pandas.DataFrame``.
        """
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O`` and would let unsupported input reach DenMune.
        if not isinstance(X, (np.ndarray, pd.DataFrame)):
            raise TypeError(
                "X must be a numpy.ndarray or pandas.DataFrame, "
                f"got {type(X).__name__}"
            )
        self.model = DenMune(
            train_data=X,
            train_truth=y,
            k_nearest=self.k,
            rgn_tsne=self.rgn_tsne,
        )
        self.labels, self.validity_ = self.model.fit_predict(**self.kwargs)
        self.labels_ = self.labels["train"]
        return self

    def fit_predict(
        self,
        X: Union[np.ndarray, pd.DataFrame],
        y: Optional[Union[np.ndarray, pd.Series]] = None,
    ) -> np.ndarray:
        """Fit on ``X`` (and optional ``y``) and return ``labels_``.

        Equivalent to ``self.fit(X, y)`` followed by reading ``self.labels_``.
        Raises ``TypeError`` under the same conditions as ``fit``.
        """
        self.fit(X, y)
        return self.labels_