Ensures X can be worked on using the numpy API (useful for indexing!).
If X is an object that does not strictly follow the numpy API (like pandas.DataFrame),
then it internally stores the metadata (like columns), casts X to a numpy array, calls the generate function,
and finally restores the original type.
@TODO check y too!
Source code in badgers/core/decorators.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def numpy_API(generate_func):
    """
    Decorator that lets ``generate_func`` work on X through the numpy API (handy for indexing).

    The wrapped function must have the signature ``generate_func(self, X, y, **params)``
    and return a pair ``(Xt, yt)``.

    If X is a ``pandas.DataFrame``, which does not strictly follow the numpy API:
    its column labels are saved, X is converted with ``DataFrame.to_numpy()``,
    ``generate_func`` is called, and the returned Xt is wrapped back into a
    ``pandas.DataFrame`` using the saved column labels. Any other X is passed
    through untouched. y and yt are never converted.

    Note: only the column labels are restored. The rebuilt DataFrame gets pandas'
    default index, not the index of the input DataFrame.

    @TODO check y too!
    """

    @functools.wraps(generate_func)
    def wrapper(self, X, y, **params):
        source_type = None
        saved_columns = None
        if isinstance(X, pd.DataFrame):
            # Remember the column labels, then hand a plain ndarray to the generator.
            source_type = TabularDataType.PANDAS_DATAFRAME
            saved_columns = X.columns
            X = X.to_numpy()

        Xt, yt = generate_func(self, X, y, **params)

        if source_type != TabularDataType.PANDAS_DATAFRAME:
            # Input was not a DataFrame: return the generator's output as-is.
            return Xt, yt
        # Rebuild a DataFrame carrying the original column labels.
        return pd.DataFrame(Xt, columns=saved_columns), yt

    return wrapper
|