-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathsession.py
379 lines (310 loc) · 13.3 KB
/
session.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
# -*- coding: utf-8 -*-
import itertools
import json
import logging
from collections import Counter, defaultdict
from hashlib import md5
import numpy as np
from tqdm.auto import trange
from baytune.selection.ucb1 import UCB1
from baytune.tuning.tunable import Tunable
from baytune.tuning.tuners.base import StopTuning
from baytune.tuning.tuners.gaussian_process import GPTuner
# Module-level logger used by ``BTBSession`` to report progress, warnings and scoring errors.
LOGGER = logging.getLogger(__name__)
class BTBSession:
    """BTBSession class.

    A ``BTBSession`` represents the process of selecting and tuning several tunables
    until the best possible configuration for a specific ``scorer`` is found.

    For this, a loop is run in which for each iteration a combination of a ``Selector`` and
    ``Tuner`` is used to decide which tunable to score next and with which hyperparameters.

    While running, the ``BTBSession`` handles the errors discarding, if configured to do so,
    the tunables that have reached as many errors as the user specified.

    Attributes:
        best_proposal (dict):
            Best configuration found with the name of the tunable and the hyperparameters
            and crossvalidated score obtained for it.
        best_score (float):
            Best score obtained for this session so far.
        proposals (dict):
            Dictionary containing all the proposals generated by the ``BTBSession``.
        iterations (int):
            Amount of iterations run.
        errors (Counter):
            A Counter of the errors that each Tunable had during the session.

    Args:
        tunables (dict):
            Python dictionary that has as keys the name of the tunable and
            as value a dictionary with the tunable hyperparameters or an
            ``baytune.tuning.tunable.Tunable`` instance. The dictionary is stored
            by reference and entries are removed from it when a tunable is discarded.
        scorer (callable object / function):
            A callable object or function with signature ``scorer(tunable_name, config)``
            which should return only a single value.
        tuner_class (baytune.tuning.tuner.BaseTuner):
            A tuner based on BTB ``BaseTuner`` class. This tuner will manage the new proposals.
            Defaults to ``baytune.tuning.tuners.gaussian_process.GPTuner``
        selector_class (baytune.selection.selector.Selector):
            A selector based on BTB ``Selector`` class. This will determine which one of
            the tunables is performing better, and which one to test next. Defaults to
            ``baytune.selection.ucb1.UCB1``
        maximize (bool):
            If ``True`` the scores are interpreted as bigger is better, if ``False`` then smaller
            is better, this should depend on the problem type (maximization or minimization).
            Defaults to ``True``.
        max_errors (int):
            Amount of errors allowed for a tunable to not generate a score. Once this amount of
            errors is reached, the tunable will be removed from the list. Defaults to 1.
        verbose (bool):
            If ``True`` a progress bar will be displayed for the ``run`` process.
    """

    _tunables = None
    _scorer = None
    _tuner_class = None
    _selector = None
    _maximize = None
    _max_errors = None
    _best_normalized = None
    _tunable_names = None
    _normalized_scores = None
    _tuners = None
    _range = None

    best_proposal = None
    best_score = None
    proposals = None
    iterations = None
    errors = None

    def _normalize(self, score):
        """Return ``score`` oriented so that bigger is always better.

        Args:
            score (float or None):
                Raw score returned by the scorer.

        Returns:
            float or None:
                ``score`` itself when maximizing, ``-score`` when minimizing,
                or ``None`` if ``score`` is ``None``.
        """
        if score is None:
            return None

        return score if self._maximize else -score

    def __init__(
        self,
        tunables,
        scorer,
        tuner_class=GPTuner,
        selector_class=UCB1,
        maximize=True,
        max_errors=1,
        verbose=False,
    ):
        """Store the session configuration and initialize an empty session state."""
        self._tunables = tunables
        self._scorer = scorer
        self._tuner_class = tuner_class
        self._tunable_names = list(self._tunables.keys())
        self._selector = selector_class(self._tunable_names)
        self._maximize = maximize
        self._max_errors = max_errors

        self.best_proposal = None
        self.proposals = {}
        self.iterations = 0
        self.errors = Counter()
        self.best_score = None

        # Any real normalized score compares greater than -inf, so the first
        # successful trial always becomes the best one.
        self._best_normalized = -np.inf
        self._normalized_scores = defaultdict(list)
        self._tuners = {}
        self._range = trange if verbose else range

    def _make_dumpable(self, to_dump):
        """Return a copy of ``to_dump`` that ``json.dumps`` can serialize.

        Keys are converted to ``str``, NumPy scalars and arrays are converted
        to their plain Python equivalents, and the literal string ``"None"``
        is replaced by ``None``.

        Args:
            to_dump (dict):
                Hyperparameter configuration to convert.

        Returns:
            dict:
                New dictionary with the converted keys and values.
        """
        dumpable = {}
        for key, value in to_dump.items():
            if not isinstance(key, str):
                key = str(key)

            if isinstance(value, np.integer):
                value = int(value)
            elif isinstance(value, np.floating):
                value = float(value)
            elif isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, np.bool_):
                value = bool(value)
            elif isinstance(value, str) and value == "None":
                # Only compare actual strings: ``==`` on array-like values is
                # element-wise and cannot be used as a condition.
                value = None

            dumpable[key] = value

        return dumpable

    def _make_id(self, name, config):
        """Compute a deterministic identifier for a tunable name and configuration.

        Args:
            name (str):
                Name of the tunable.
            config (dict):
                Hyperparameter configuration.

        Returns:
            str:
                Hexadecimal MD5 digest of the JSON dump (with sorted keys) of
                the name and the dumpable version of the configuration.
        """
        proposal = {
            "name": name,
            "config": self._make_dumpable(config),
        }
        hashable = json.dumps(proposal, sort_keys=True).encode()

        return md5(hashable).hexdigest()

    def _remove_tunable(self, tunable_name):
        """Remove a tunable from the candidates list.

        This is necessary when:

        - Duplicates are not allowed and the tunable has exhausted all its
          configurations.
        - The tunable has reached ``max_errors`` errors.

        When this happens, the tunable is removed from the tunables dict
        and its scores are removed from the normalized_scores dict used by
        the selectors. Removing a name that is not present is a no-op.
        """
        self._normalized_scores.pop(tunable_name, None)
        self._tunables.pop(tunable_name, None)

    def _get_next_tunable_name(self):
        """Choose the name of the tunable to generate the next proposal for.

        Returns:
            str:
                The name chosen by the selector, or a random one among the
                remaining tunables when no scores have been recorded yet.
        """
        if self._normalized_scores:
            return self._selector.select(self._normalized_scores)

        # if _normalized_scores is still empty the selector crashes
        # this happens when max_errors > 1, all tunables have tuners
        # and all previous trials have crashed.
        return np.random.choice(list(self._tunables.keys()))

    def propose(self):
        """Propose a new configuration to score.

        Every time ``propose`` is called, a new tunable will be selected and a new
        hyperparameter proposal will be generated for it.

        At the beginning, the default hyperparameters of each one of the tunables
        will be returned sequentially in the same order as they were passed to
        the ``BTBSession``.

        After that, once a tuner entry exists for every tunable, the tunable
        used to generate the new proposals will be chosen each time by the selector.

        If a tunable runs out of proposals, it will be discarded from the list and will
        not be proposed again.

        Finally, when all the tunables have run out of proposals, a ``StopTuning`` exception
        will be raised.

        Returns:
            tuple (str, dict):
                * Name of the tunable to try next.
                * Hyperparameters proposal.

        Raises:
            StopTuning:
                If the ``BTBSession`` has run out of proposals to generate.
            TypeError:
                If a tunable is neither a ``dict`` nor a ``Tunable`` instance.
        """
        if not self._tunables:
            raise StopTuning("There are no tunables left to try.")

        if len(self._tuners) < len(self._tunable_names):
            # Warm-up phase: one tuner entry is created per call, so the number
            # of existing entries is the index of the next tunable to start.
            tunable_name = self._tunable_names[len(self._tuners)]
            tunable = self._tunables[tunable_name]

            if isinstance(tunable, dict):
                LOGGER.info("Creating Tunable instance from dict.")
                tunable = Tunable.from_dict(tunable)

            if not isinstance(tunable, Tunable):
                raise TypeError(
                    "Tunable can only be an instance of baytune.tuning.Tunable or dict"
                )

            LOGGER.info("Obtaining default configuration for %s", tunable_name)
            config = tunable.get_defaults()

            if tunable.cardinality == 1:
                LOGGER.warning(
                    "Skipping tuner creation for Tunable %s with cardinality 1",
                    tunable_name,
                )
                # ``None`` marks a tunable that only has its default configuration.
                tuner = None
            else:
                tuner = self._tuner_class(tunable)

            self._tuners[tunable_name] = tuner

        else:
            tunable_name = self._get_next_tunable_name()
            tuner = self._tuners[tunable_name]
            try:
                if tuner is None:
                    raise StopTuning(
                        "Tunable %s has no tunable hyperparameters" % tunable_name
                    )

                LOGGER.info(
                    "Generating new proposal configuration for %s", tunable_name
                )
                config = tuner.propose(1)

            except StopTuning:
                LOGGER.info("%s has no more configs to propose.", tunable_name)
                self._remove_tunable(tunable_name)
                # Retry with the remaining tunables; this raises ``StopTuning``
                # itself once none are left.
                tunable_name, config = self.propose()

        proposal_id = self._make_id(tunable_name, config)
        self.proposals[proposal_id] = {
            "id": proposal_id,
            "name": tunable_name,
            "config": config,
        }

        return tunable_name, config

    def handle_error(self, tunable_name):
        """Handle errors when ``score`` is ``None``.

        The error count of ``tunable_name`` is increased by one and, if it
        has reached ``self._max_errors``, the tunable is removed from the
        selector's choices.

        Args:
            tunable_name (str):
                The name of the tunable to which this configuration belongs.
        """
        self.errors[tunable_name] += 1
        errors = self.errors[tunable_name]

        if errors >= self._max_errors:
            LOGGER.warning(
                "Too many errors: %s. Removing tunable %s", errors, tunable_name
            )
            self._remove_tunable(tunable_name)

    def record(self, tunable_name, config, score):
        """Record the configuration and the obtained score to the tuner.

        If the score is the best one so far, the ``best_proposal`` and ``best_score`` are
        updated. If the score is ``None`` the trial is counted as an error instead and
        nothing is passed to the tuner.

        The proposal is looked up in ``self.proposals`` by the id computed from
        ``tunable_name`` and ``config``, so an unknown combination raises ``KeyError``.

        Args:
            tunable_name (str):
                The name of the tunable to which this configuration belongs.
            config (dict):
                Hyperparameter proposal, as given by the tunable.
            score (float):
                Obtained score with the given configuration.
        """
        proposal_id = self._make_id(tunable_name, config)
        proposal = self.proposals[proposal_id]
        proposal["score"] = score

        if score is None:
            self.handle_error(tunable_name)
            return

        normalized = self._normalize(score)
        self._normalized_scores[tunable_name].append(normalized)

        if normalized > self._best_normalized:
            LOGGER.info("New optimal found: %s - %s", tunable_name, score)
            self.best_proposal = proposal
            self.best_score = score
            self._best_normalized = normalized

        # Best effort: a failure while updating the tuner is logged but must
        # not discard the score that has already been stored above.
        try:
            tuner = self._tuners[tunable_name]
            if tuner is None:
                LOGGER.warning(
                    "Skipping record for Tunable %s with cardinality 1",
                    tunable_name,
                )
            else:
                tuner.record(config, normalized)

        except Exception:
            LOGGER.exception(
                "Could not record configuration and score for tuner %s.",
                tunable_name,
            )

    def run(self, iterations=None):
        """Run the selection and tuning loop for the given number of iterations.

        At each iteration, the ``BTBSession`` will generate a new proposal calling
        ``self.propose``, score it using the ``self._scorer``, and finally record the
        obtained score back to the tuner calling ``self.record``.

        If no iterations are given, run infinitely until interrupted or until all the
        tuner proposals are exhausted.

        Scoring errors will also be captured and recorded as a ``None`` score.

        Args:
            iterations (int):
                Number of iterations to run. Defaults to ``None``, which
                means no limit.

        Returns:
            best_proposal (dict):
                Best configuration found with the name of the tunable and the hyperparameters
                and crossvalidated score obtained for it.

        Raises:
            StopTuning:
                Propagated from ``self.propose`` when there are no tunables left to try.
        """
        if iterations is None:
            iterator = itertools.count()
        else:
            iterator = self._range(iterations)

        for _ in iterator:
            self.iterations += 1
            tunable_name, config = self.propose()

            try:
                LOGGER.debug(
                    "Scoring proposal %s - %s: %s",
                    self.iterations,
                    tunable_name,
                    config,
                )
                score = self._scorer(tunable_name, config)

            except Exception:
                params = "\n".join("{}: {}".format(k, v) for k, v in config.items())
                LOGGER.exception(
                    "Proposal %s - %s crashed with the following configuration: %s",
                    self.iterations,
                    tunable_name,
                    params,
                )
                score = None

            self.record(tunable_name, config, score)

        return self.best_proposal