Coverage for sacred/sacred/run.py: 17%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python
2# coding=utf-8
4import datetime
5import os.path
6import sys
7import traceback as tb
9from sacred import metrics_logger
10from sacred.metrics_logger import linearize_metrics
11from sacred.randomness import set_global_seed
12from sacred.utils import SacredInterrupt, join_paths, IntervalTimer
13from sacred.stdout_capturing import get_stdcapturer
class Run:
    """Represent and manage a single run of an experiment.

    A run bundles the configuration, the main function, the observers and
    the hooks of one execution. Calling the instance executes the main
    function once and notifies all observers about the life-cycle events
    (queued/started/heartbeat/completed/interrupted/failed).
    """

    def __init__(
        self,
        config,
        config_modifications,
        main_function,
        observers,
        root_logger,
        run_logger,
        experiment_info,
        host_info,
        pre_run_hooks,
        post_run_hooks,
        captured_out_filter=None,
    ):

        self._id = None
        """The ID of this run as assigned by the first observer"""

        self.captured_out = ""
        """Captured stdout and stderr"""

        self.config = config
        """The final configuration used for this run"""

        self.config_modifications = config_modifications
        """A ConfigSummary object with information about config changes"""

        self.experiment_info = experiment_info
        """A dictionary with information about the experiment"""

        self.host_info = host_info
        """A dictionary with information about the host"""

        self.info = {}
        """Custom info dict that will be sent to the observers"""

        self.root_logger = root_logger
        """The root logger that was used to create all the others"""

        self.run_logger = run_logger
        """The logger that is used for this run"""

        self.main_function = main_function
        """The main function that is executed with this run"""

        self.observers = observers
        """A list of all observers that observe this run"""

        self.pre_run_hooks = pre_run_hooks
        """List of pre-run hooks (captured functions called before this run)"""

        self.post_run_hooks = post_run_hooks
        """List of post-run hooks (captured functions called after this run)"""

        self.result = None
        """The return value of the main function"""

        self.status = None
        """The current status of the run, from QUEUED to COMPLETED"""

        self.start_time = None
        """The datetime when this run was started"""

        self.stop_time = None
        """The datetime when this run stopped"""

        self.debug = False
        """Determines whether this run is executed in debug mode"""

        self.pdb = False
        """If true the pdb debugger is automatically started after a failure"""

        self.meta_info = {}
        """Meta information about this run that is sent to the observers
        with the queued/started events (e.g. ``queue_time``)"""

        self.beat_interval = 10.0  # sec
        """The time between two heartbeat events measured in seconds"""

        self.unobserved = False
        """Indicates whether this run should be unobserved"""

        self.force = False
        """Disable warnings about suspicious changes"""

        self.queue_only = False
        """If true then this run will only fire the queued_event and quit"""

        self.captured_out_filter = captured_out_filter
        """Filter function to be applied to captured output"""

        self.fail_trace = None
        """A stacktrace, in case the run failed"""

        self.capture_mode = None
        """Determines the way the stdout/stderr are captured"""

        # Heartbeat thread and the event used to stop it; both stay None
        # until _start_heartbeat() actually creates them.
        self._heartbeat = None
        self._stop_heartbeat_event = None
        # Observers that raised in _safe_call(); they are skipped afterwards.
        self._failed_observers = []
        self._output_file = None

        self._metrics = metrics_logger.MetricsLogger()

    def open_resource(self, filename, mode="r"):
        """Open a file and also save it as a resource.

        Opens a file, reports it to the observers as a resource, and returns
        the opened file.

        In Sacred terminology a resource is a file that the experiment needed
        to access during a run. In case of a MongoObserver that means making
        sure the file is stored in the database (but avoiding duplicates) along
        its path and md5 sum.

        See also :py:meth:`sacred.Experiment.open_resource`.

        Parameters
        ----------
        filename : str
            name of the file that should be opened
        mode : str
            mode that file will be open

        Returns
        -------
        file
            the opened file-object

        """
        filename = os.path.abspath(filename)
        self._emit_resource_added(filename)  # TODO: maybe non-blocking?
        return open(filename, mode)

    def add_resource(self, filename):
        """Add a file as a resource.

        In Sacred terminology a resource is a file that the experiment needed
        to access during a run. In case of a MongoObserver that means making
        sure the file is stored in the database (but avoiding duplicates) along
        its path and md5 sum.

        See also :py:meth:`sacred.Experiment.add_resource`.

        Parameters
        ----------
        filename : str
            name of the file to be stored as a resource
        """
        filename = os.path.abspath(filename)
        self._emit_resource_added(filename)

    def add_artifact(self, filename, name=None, metadata=None, content_type=None):
        """Add a file as an artifact.

        In Sacred terminology an artifact is a file produced by the experiment
        run. In case of a MongoObserver that means storing the file in the
        database.

        See also :py:meth:`sacred.Experiment.add_artifact`.

        Parameters
        ----------
        filename : str
            name of the file to be stored as artifact
        name : str, optional
            optionally set the name of the artifact.
            Defaults to the basename of the file.
        metadata: dict, optional
            optionally attach metadata to the artifact.
            This only has an effect when using the MongoObserver.
        content_type: str, optional
            optionally attach a content-type to the artifact.
            This only has an effect when using the MongoObserver.
        """
        filename = os.path.abspath(filename)
        name = os.path.basename(filename) if name is None else name
        self._emit_artifact_added(name, filename, metadata, content_type)

    def __call__(self, *args):
        r"""Start this run.

        Parameters
        ----------
        \*args
            parameters passed to the main function

        Returns
        -------
            the return value of the main function, or None if the run was
            only queued (``queue_only``)

        Raises
        ------
        RuntimeError
            if this run has already been started once

        """
        if self.start_time is not None:
            raise RuntimeError(
                "A run can only be started once. "
                "(Last start was {})".format(self.start_time)
            )

        if self.unobserved:
            self.observers = []
        else:
            # highest priority first
            self.observers = sorted(self.observers, key=lambda x: -x.priority)

        self.warn_if_unobserved()
        set_global_seed(self.config["seed"])

        # Without observers and without an explicit mode nothing is captured.
        if self.capture_mode is None and not self.observers:
            capture_mode = "no"
        else:
            capture_mode = self.capture_mode
        capture_mode, capture_stdout = get_stdcapturer(capture_mode)
        self.run_logger.debug('Using capture mode "%s"', capture_mode)

        if self.queue_only:
            self._emit_queued()
            return
        try:
            with capture_stdout() as self._output_file:
                self._emit_started()
                self._start_heartbeat()
                self._execute_pre_run_hooks()
                self.result = self.main_function(*args)
                self._execute_post_run_hooks()
                if self.result is not None:
                    self.run_logger.info("Result: {}".format(self.result))
                elapsed_time = self._stop_time()
                self.run_logger.info("Completed after %s", elapsed_time)
                # Collect the remaining output while the capture is still open.
                self._get_captured_output()
            self._stop_heartbeat()
            self._emit_completed(self.result)
        except (SacredInterrupt, KeyboardInterrupt) as e:
            self._stop_heartbeat()
            # SacredInterrupt subclasses may define their own STATUS
            status = getattr(e, "STATUS", "INTERRUPTED")
            self._emit_interrupted(status)
            raise
        except BaseException:
            exc_type, exc_value, trace = sys.exc_info()
            self._stop_heartbeat()
            # tb_next skips the frame of this method in the reported trace
            self._emit_failed(exc_type, exc_value, trace.tb_next)
            raise
        finally:
            self._warn_about_failed_observers()
            self._wait_for_observers()

        return self.result

    def _get_captured_output(self):
        """Append newly captured output to ``captured_out`` (filtered)."""
        if self._output_file.closed:
            return
        text = self._output_file.get()
        if isinstance(text, bytes):
            text = text.decode("utf-8", "replace")
        if self.captured_out:
            text = self.captured_out + text
        # The filter is applied to the whole accumulated text, not only to
        # the newly read part.
        if self.captured_out_filter is not None:
            text = self.captured_out_filter(text)
        self.captured_out = text

    def _start_heartbeat(self):
        """Start the periodic heartbeat if ``beat_interval`` is positive."""
        self.run_logger.debug("Starting Heartbeat")
        if self.beat_interval > 0:
            self._stop_heartbeat_event, self._heartbeat = IntervalTimer.create(
                self._emit_heartbeat, self.beat_interval
            )
            self._heartbeat.start()

    def _stop_heartbeat(self):
        """Signal the heartbeat to stop and wait up to 2 seconds for it."""
        self.run_logger.debug("Stopping Heartbeat")
        # only stop if heartbeat was started
        if self._heartbeat is not None:
            self._stop_heartbeat_event.set()
            self._heartbeat.join(timeout=2)

    def _get_command(self):
        """Return the full name (prefix + name) of the main function."""
        return join_paths(
            self.main_function.prefix, self.main_function.signature.name
        )

    def _emit_queued(self):
        """Set status to QUEUED and fire ``queued_event`` on all observers."""
        self.status = "QUEUED"
        queue_time = datetime.datetime.utcnow()
        self.meta_info["queue_time"] = queue_time
        command = self._get_command()
        self.run_logger.info("Queuing-up command '%s'", command)
        for observer in self.observers:
            _id = observer.queued_event(
                ex_info=self.experiment_info,
                command=command,
                host_info=self.host_info,
                queue_time=queue_time,
                config=self.config,
                meta_info=self.meta_info,
                _id=self._id,
            )
            # the first observer that returns an ID determines the run ID
            if self._id is None:
                self._id = _id
            # do not catch any exceptions on startup:
            # the experiment SHOULD fail if any of the observers fails

        if self._id is None:
            self.run_logger.info("Queued")
        else:
            self.run_logger.info('Queued-up run with ID "{}"'.format(self._id))

    def _emit_started(self):
        """Set status to RUNNING and fire ``started_event`` on all observers."""
        self.status = "RUNNING"
        self.start_time = datetime.datetime.utcnow()
        command = self._get_command()
        self.run_logger.info("Running command '%s'", command)
        for observer in self.observers:
            _id = observer.started_event(
                ex_info=self.experiment_info,
                command=command,
                host_info=self.host_info,
                start_time=self.start_time,
                config=self.config,
                meta_info=self.meta_info,
                _id=self._id,
            )
            # the first observer that returns an ID determines the run ID
            if self._id is None:
                self._id = _id
            # do not catch any exceptions on startup:
            # the experiment SHOULD fail if any of the observers fails
        if self._id is None:
            self.run_logger.info("Started")
        else:
            self.run_logger.info('Started run with ID "{}"'.format(self._id))

    def _emit_heartbeat(self):
        """Send new metrics, info, captured output and result to observers."""
        beat_time = datetime.datetime.utcnow()
        self._get_captured_output()
        # Read all measured metrics since last heartbeat
        logged_metrics = self._metrics.get_last_metrics()
        metrics_by_name = linearize_metrics(logged_metrics)
        for observer in self.observers:
            self._safe_call(
                observer, "log_metrics", metrics_by_name=metrics_by_name, info=self.info
            )
            self._safe_call(
                observer,
                "heartbeat_event",
                info=self.info,
                captured_out=self.captured_out,
                beat_time=beat_time,
                result=self.result,
            )

    def _stop_time(self):
        """Record ``stop_time`` and return the elapsed time.

        Returns
        -------
        datetime.timedelta
            time since ``start_time`` rounded to whole seconds; zero if the
            run failed or was interrupted before it was started.
        """
        self.stop_time = datetime.datetime.utcnow()
        if self.start_time is None:
            # The run stopped before _emit_started() ran. Without this guard
            # the subtraction below raises TypeError inside the exception
            # handlers of __call__ and hides the real error.
            return datetime.timedelta(0)
        elapsed_time = datetime.timedelta(
            seconds=round((self.stop_time - self.start_time).total_seconds())
        )
        return elapsed_time

    def _emit_completed(self, result):
        """Set status to COMPLETED and fire ``completed_event``."""
        self.status = "COMPLETED"
        for observer in self.observers:
            self._final_call(
                observer, "completed_event", stop_time=self.stop_time, result=result
            )

    def _emit_interrupted(self, status):
        """Store the given status and fire ``interrupted_event``."""
        self.status = status
        elapsed_time = self._stop_time()
        self.run_logger.warning("Aborted after %s!", elapsed_time)
        for observer in self.observers:
            self._final_call(
                observer,
                "interrupted_event",
                interrupt_time=self.stop_time,
                status=status,
            )

    def _emit_failed(self, exc_type, exc_value, trace):
        """Set status to FAILED, store the trace and fire ``failed_event``."""
        self.status = "FAILED"
        elapsed_time = self._stop_time()
        self.run_logger.error("Failed after %s!", elapsed_time)
        self.fail_trace = tb.format_exception(exc_type, exc_value, trace)
        for observer in self.observers:
            self._final_call(
                observer,
                "failed_event",
                fail_time=self.stop_time,
                fail_trace=self.fail_trace,
            )

    def _emit_resource_added(self, filename):
        """Fire ``resource_event`` on all observers."""
        for observer in self.observers:
            self._safe_call(observer, "resource_event", filename=filename)

    def _emit_artifact_added(self, name, filename, metadata, content_type):
        """Fire ``artifact_event`` on all observers."""
        for observer in self.observers:
            self._safe_call(
                observer,
                "artifact_event",
                name=name,
                filename=filename,
                metadata=metadata,
                content_type=content_type,
            )

    def _safe_call(self, obs, method, **kwargs):
        """Call ``obs.method(**kwargs)`` unless the observer failed before.

        An exception marks the observer as failed and is logged as a
        warning instead of being propagated.
        """
        if obs not in self._failed_observers:
            try:
                getattr(obs, method)(**kwargs)
            except Exception as e:
                self._failed_observers.append(obs)
                self.run_logger.warning(
                    "An error ocurred in the '{}' " "observer: {}".format(obs, e)
                )

    def _final_call(self, observer, method, **kwargs):
        """Call ``observer.method(**kwargs)`` and only log any exception.

        Unlike ``_safe_call`` this also calls observers that failed before.
        """
        try:
            getattr(observer, method)(**kwargs)
        except Exception:
            # Feels dirty to catch all exceptions, but it is just for
            # finishing up, so we don't want one observer to kill the
            # others
            self.run_logger.error(tb.format_exc())

    def _wait_for_observers(self):
        """Block until all observers finished processing."""
        for observer in self.observers:
            self._safe_call(observer, "join")

    def _warn_about_failed_observers(self):
        """Log a warning for every observer that failed during the run."""
        for observer in self._failed_observers:
            self.run_logger.warning(
                "The observer '{}' failed at some point "
                "during the run.".format(observer)
            )

    def _execute_pre_run_hooks(self):
        """Call all pre-run hooks in order."""
        for pr in self.pre_run_hooks:
            pr()

    def _execute_post_run_hooks(self):
        """Call all post-run hooks in order."""
        for pr in self.post_run_hooks:
            pr()

    def warn_if_unobserved(self):
        """Warn if there are no observers (unless debug or unobserved)."""
        if not self.observers and not self.debug and not self.unobserved:
            self.run_logger.warning("No observers have been added to this run")

    def log_scalar(self, metric_name, value, step=None):
        """Add a new measurement.

        The measurement is stored in the metrics logger of this run and is
        passed to the ``log_metrics`` method of the observers during the
        next heartbeat event.

        Parameters
        ----------
        metric_name : str
            The name of the metric, e.g. training.loss
        value
            The measured value
        step : int, optional
            The step number, e.g. the iteration number.
            If not specified, an internal counter for each metric
            is used, incremented by one.
        """
        # Method added in change https://github.com/chovanecm/sacred/issues/4
        # The same as Experiment.log_scalar (if something changes,
        # update the docstring too!)

        self._metrics.log_scalar_metric(metric_name, value, step)