-
Notifications
You must be signed in to change notification settings - Fork 327
Expand file tree
/
Copy pathbase.py
More file actions
367 lines (306 loc) · 15.4 KB
/
base.py
File metadata and controls
367 lines (306 loc) · 15.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
"""
Base classes for active learning algorithms
------------------------------------------
"""
import abc
import sys
import warnings
from typing import Any, Callable, Iterator, List, Tuple, Union
import numpy as np
import scipy.sparse as sp
from modAL.utils.data import data_hstack, modALinput, retrieve_rows
from sklearn.base import BaseEstimator
from sklearn.ensemble._base import _BaseHeterogeneousEnsemble
from sklearn.pipeline import Pipeline
# abc.ABC is only available from Python 3.4 on; on older interpreters build an
# equivalent empty abstract base class directly from ABCMeta.
ABC = abc.ABC if sys.version_info >= (3, 4) else abc.ABCMeta('ABC', (), {})
class BaseLearner(ABC, BaseEstimator):
    """
    Core abstraction in modAL.

    Args:
        estimator: The estimator to be used in the active learning loop.
        query_strategy: Function providing the query strategy for the active learning loop,
            for instance, modAL.uncertainty.uncertainty_sampling.
        on_transformed: Whether to transform samples with the pipeline defined by the estimator
            when applying the query strategy.
        force_all_finite: When True, forces all values of the data finite.
            When False, accepts np.nan and np.inf values.
        **fit_kwargs: keyword arguments.

    Attributes:
        estimator: The estimator to be used in the active learning loop.
        query_strategy: Function providing the query strategy for the active learning loop.
    """

    def __init__(self,
                 estimator: BaseEstimator,
                 query_strategy: Callable,
                 on_transformed: bool = False,
                 force_all_finite: bool = True,
                 **fit_kwargs
                 ) -> None:
        assert callable(query_strategy), 'query_strategy must be callable'
        assert isinstance(force_all_finite,
                          bool), 'force_all_finite must be a bool'

        self.estimator = estimator
        self.query_strategy = query_strategy
        self.on_transformed = on_transformed
        self.force_all_finite = force_all_finite

    def transform_without_estimating(self, X: modALinput) -> Union[np.ndarray, sp.csr_matrix]:
        """
        Transforms the data as supplied to the estimator.

        * In case the estimator is an sklearn pipeline, it applies all pipeline components but the last one.
        * In case the estimator is an ensemble, it concatenates the transformations for each classifier
          (pipeline) in the ensemble.
        * Otherwise returns the non-transformed dataset X.

        Args:
            X: dataset to be transformed

        Returns:
            Transformed data set
        """
        Xt = []
        pipes = [self.estimator]
        if isinstance(self.estimator, _BaseHeterogeneousEnsemble):
            pipes = self.estimator.estimators_

        ################################
        # transform data with pipelines used by estimator
        for pipe in pipes:
            if isinstance(pipe, Pipeline):
                # NOTE: The used pipeline class might be an extension to sklearn's!
                # Create a new instance of the used pipeline class with all
                # components but the final estimator, which is replaced by an empty (passthrough) component.
                # This prevents any special handling of the final transformation pipe, which is usually
                # expected to be an estimator.
                transformation_pipe = pipe.__class__(
                    steps=[*pipe.steps[:-1], ('passthrough', 'passthrough')])
                Xt.append(transformation_pipe.transform(X))

        # in case no transformation pipelines are used by the estimator,
        # return the original, non-transformed data
        if not Xt:
            return X

        ################################
        # concatenate all transformations and return
        return data_hstack(Xt)

    def _fit_on_new(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwargs) -> 'BaseLearner':
        """
        Fits self.estimator to the given data and labels.

        Args:
            X: The new samples for which the labels are supplied by the expert.
            y: Labels corresponding to the new instances in X.
            bootstrap: If True, the method trains the model on a set bootstrapped from X.
            **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.

        Returns:
            self
        """
        if not bootstrap:
            self.estimator.fit(X, y, **fit_kwargs)
        else:
            bootstrap_idx = np.random.choice(
                range(X.shape[0]), X.shape[0], replace=True)
            # BUGFIX: fit_kwargs were previously dropped in the bootstrap branch,
            # making its behavior inconsistent with the non-bootstrap branch.
            self.estimator.fit(X[bootstrap_idx], y[bootstrap_idx], **fit_kwargs)

        return self

    @abc.abstractmethod
    def fit(self, *args, **kwargs) -> None:
        pass

    def predict(self, X: modALinput, **predict_kwargs) -> Any:
        """
        Estimator predictions for X. Interface with the predict method of the estimator.

        Args:
            X: The samples to be predicted.
            **predict_kwargs: Keyword arguments to be passed to the predict method of the estimator.

        Returns:
            Estimator predictions for X.
        """
        return self.estimator.predict(X, **predict_kwargs)

    def predict_proba(self, X: modALinput, **predict_proba_kwargs) -> Any:
        """
        Class probabilities if the predictor is a classifier. Interface with the predict_proba method of the classifier.

        Args:
            X: The samples for which the class probabilities are to be predicted.
            **predict_proba_kwargs: Keyword arguments to be passed to the predict_proba method of the classifier.

        Returns:
            Class probabilities for X.
        """
        return self.estimator.predict_proba(X, **predict_proba_kwargs)

    def query(self, X_pool, *query_args, return_metrics: bool = False, **query_kwargs) -> Union[Tuple, modALinput]:
        """
        Finds the n_instances most informative point in the data provided by calling the query_strategy function.

        Args:
            X_pool: Pool of unlabeled instances to retrieve most informative instances from
            return_metrics: boolean to indicate, if the corresponding query metrics should be (not) returned
            *query_args: The arguments for the query strategy. For instance, in the case of
                :func:`~modAL.uncertainty.uncertainty_sampling`, it is the pool of samples from which the query strategy
                should choose instances to request labels.
            **query_kwargs: Keyword arguments for the query strategy function.

        Returns:
            Value of the query_strategy function. Should be the indices of the instances from the pool chosen to be
            labelled and the instances themselves. Can be different in other cases, for instance only the instance to be
            labelled upon query synthesis.
            query_metrics: returns also the corresponding metrics, if return_metrics == True
        """
        # Call the strategy exactly once; strategies that support metrics return a
        # (result, metrics) pair, all others return the result alone.
        # BUGFIX: previously a bare `except:` swallowed all exceptions (including
        # KeyboardInterrupt) and the strategy was invoked a second time on unpack
        # failure, duplicating any side effects.
        query_output = self.query_strategy(
            self, X_pool, *query_args, **query_kwargs)
        try:
            query_result, query_metrics = query_output
        except (TypeError, ValueError):
            query_result = query_output
            query_metrics = None

        if return_metrics:
            if query_metrics is None:
                warnings.warn(
                    "The selected query strategy doesn't support return_metrics")
            return query_result, retrieve_rows(X_pool, query_result), query_metrics
        else:
            return query_result, retrieve_rows(X_pool, query_result)

    def score(self, X: modALinput, y: modALinput, **score_kwargs) -> Any:
        """
        Interface for the score method of the predictor.

        Args:
            X: The samples for which prediction accuracy is to be calculated.
            y: Ground truth labels for X.
            **score_kwargs: Keyword arguments to be passed to the .score() method of the predictor.

        Returns:
            The score of the predictor.
        """
        return self.estimator.score(X, y, **score_kwargs)

    @abc.abstractmethod
    def teach(self, *args, **kwargs) -> None:
        pass
class BaseCommittee(ABC, BaseEstimator):
    """
    Base class for query-by-committee setup.

    Args:
        learner_list: List of ActiveLearner objects to form committee.
        query_strategy: Function to query labels.
        on_transformed: Whether to transform samples with the pipeline defined by each learner's estimator
            when applying the query strategy.
    """

    def __init__(self, learner_list: List[BaseLearner], query_strategy: Callable, on_transformed: bool = False) -> None:
        assert isinstance(learner_list, list), 'learners must be supplied in a list'

        self.learner_list = learner_list
        self.query_strategy = query_strategy
        self.on_transformed = on_transformed
        # TODO: update training data when using fit() and teach() methods
        self.X_training = None

    def __iter__(self) -> Iterator[BaseLearner]:
        for learner in self.learner_list:
            yield learner

    def __len__(self) -> int:
        return len(self.learner_list)

    def _add_training_data(self, X: modALinput, y: modALinput) -> None:
        """
        Adds the new data and label to the known data for each learner, but does not retrain the model.

        Args:
            X: The new samples for which the labels are supplied by the expert.
            y: Labels corresponding to the new instances in X.

        Note:
            If the learners have been fitted, the features in X have to agree with the training samples which the
            classifier has seen.
        """
        for learner in self.learner_list:
            learner._add_training_data(X, y)

    def _fit_to_known(self, bootstrap: bool = False, **fit_kwargs) -> None:
        """
        Fits all learners to the training data and labels provided to it so far.

        Args:
            bootstrap: If True, each estimator is trained on a bootstrapped dataset. Useful when
                using bagging to build the ensemble.
            **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
        """
        for learner in self.learner_list:
            learner._fit_to_known(bootstrap=bootstrap, **fit_kwargs)

    def _fit_on_new(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwargs) -> None:
        """
        Fits all learners to the given data and labels.

        Args:
            X: The new samples for which the labels are supplied by the expert.
            y: Labels corresponding to the new instances in X.
            bootstrap: If True, the method trains the model on a set bootstrapped from X.
            **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
        """
        for learner in self.learner_list:
            learner._fit_on_new(X, y, bootstrap=bootstrap, **fit_kwargs)

    def fit(self, X: modALinput, y: modALinput, **fit_kwargs) -> 'BaseCommittee':
        """
        Fits every learner to a subset sampled with replacement from X. Calling this method makes the learner forget
        the data it has seen up until this point and replaces it with X! If you would like to perform bootstrapping on
        each learner using the data it has seen, use the method .rebag()!

        Args:
            X: The samples to be fitted on.
            y: The corresponding labels.
            **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.

        Returns:
            self
        """
        for learner in self.learner_list:
            learner.fit(X, y, **fit_kwargs)

        return self

    def transform_without_estimating(self, X: modALinput) -> Union[np.ndarray, sp.csr_matrix]:
        """
        Transforms the data as supplied to each learner's estimator and concatenates transformations.

        Args:
            X: dataset to be transformed

        Returns:
            Transformed data set
        """
        return data_hstack([learner.transform_without_estimating(X) for learner in self.learner_list])

    def query(self, X_pool, *query_args, return_metrics: bool = False, **query_kwargs) -> Union[Tuple, modALinput]:
        """
        Finds the n_instances most informative point in the data provided by calling the query_strategy function.

        Args:
            X_pool: Pool of unlabeled instances to retrieve most informative instances from
            return_metrics: boolean to indicate, if the corresponding query metrics should be (not) returned
            *query_args: The arguments for the query strategy. For instance, in the case of
                :func:`~modAL.disagreement.max_disagreement_sampling`, it is the pool of samples from which the query
                strategy should choose instances to request labels.
            **query_kwargs: Keyword arguments for the query strategy function.

        Returns:
            Return value of the query_strategy function. Should be the indices of the instances from the pool chosen to
            be labelled and the instances themselves. Can be different in other cases, for instance only the instance to
            be labelled upon query synthesis.
            query_metrics: returns also the corresponding metrics, if return_metrics == True
        """
        # BUGFIX: `return_metrics` used to be declared positionally *before*
        # *query_args, so any positional query argument was silently consumed as
        # return_metrics; it is now keyword-only, matching BaseLearner.query.
        # Call the strategy exactly once; strategies that support metrics return a
        # (result, metrics) pair, all others return the result alone.
        # BUGFIX: previously a bare `except:` swallowed all exceptions and the
        # strategy was invoked a second time on unpack failure.
        query_output = self.query_strategy(
            self, X_pool, *query_args, **query_kwargs)
        try:
            query_result, query_metrics = query_output
        except (TypeError, ValueError):
            query_result = query_output
            query_metrics = None

        if return_metrics:
            if query_metrics is None:
                warnings.warn(
                    "The selected query strategy doesn't support return_metrics")
            return query_result, retrieve_rows(X_pool, query_result), query_metrics
        else:
            return query_result, retrieve_rows(X_pool, query_result)

    def rebag(self, **fit_kwargs) -> None:
        """
        Refits every learner with a dataset bootstrapped from its training instances. Contrary to .bag(), it bootstraps
        the training data for each learner based on its own examples.

        Todo:
            Where is .bag()?

        Args:
            **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
        """
        self._fit_to_known(bootstrap=True, **fit_kwargs)

    def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new: bool = False, **fit_kwargs) -> None:
        """
        Adds X and y to the known training data for each learner and retrains learners with the augmented dataset.

        Args:
            X: The new samples for which the labels are supplied by the expert.
            y: Labels corresponding to the new instances in X.
            bootstrap: If True, trains each learner on a bootstrapped set. Useful when building the ensemble by bagging.
            only_new: If True, the model is retrained using only X and y, ignoring the previously provided examples.
            **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
        """
        self._add_training_data(X, y)
        if not only_new:
            self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
        else:
            self._fit_on_new(X, y, bootstrap=bootstrap, **fit_kwargs)

    @abc.abstractmethod
    def predict(self, X: modALinput) -> Any:
        pass

    @abc.abstractmethod
    def vote(self, X: modALinput) -> Any:  # TODO: clarify typing
        pass