Coverage for /home/ubuntu/Documents/Research/mut_p6/sacred/sacred/run.py: 71%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

206 statements  

1#!/usr/bin/env python 

2# coding=utf-8 

3 

4import datetime 

5import os.path 

6import sys 

7import traceback as tb 

8 

9from sacred import metrics_logger 

10from sacred.metrics_logger import linearize_metrics 

11from sacred.randomness import set_global_seed 

12from sacred.utils import SacredInterrupt, join_paths, IntervalTimer 

13from sacred.stdout_capturing import get_stdcapturer 

14 

15 

class Run:
    """Represent and manage a single run of an experiment."""

    def __init__(
        self,
        config,
        config_modifications,
        main_function,
        observers,
        root_logger,
        run_logger,
        experiment_info,
        host_info,
        pre_run_hooks,
        post_run_hooks,
        captured_out_filter=None,
    ):
        # The ID of this run as assigned by the first observer.
        self._id = None
        # Captured stdout and stderr.
        self.captured_out = ""
        # The final configuration used for this run.
        self.config = config
        # A ConfigSummary object with information about config changes.
        self.config_modifications = config_modifications
        # A dictionary with information about the experiment.
        self.experiment_info = experiment_info
        # A dictionary with information about the host.
        self.host_info = host_info
        # Custom info dict that will be sent to the observers.
        self.info = {}
        # The root logger that was used to create all the others.
        self.root_logger = root_logger
        # The logger that is used for this run.
        self.run_logger = run_logger
        # The main function that is executed with this run.
        self.main_function = main_function
        # A list of all observers that observe this run.
        self.observers = observers
        # Pre-run hooks (captured functions called before this run).
        self.pre_run_hooks = pre_run_hooks
        # Post-run hooks (captured functions called after this run).
        self.post_run_hooks = post_run_hooks
        # The return value of the main function.
        self.result = None
        # The current status of the run, from QUEUED to COMPLETED.
        self.status = None
        # The datetime when this run was started.
        self.start_time = None
        # The datetime when this run stopped.
        self.stop_time = None
        # Determines whether this run is executed in debug mode.
        self.debug = False
        # If true the pdb debugger is automatically started after a failure.
        self.pdb = False
        # Meta information about this run (e.g. a custom comment).
        self.meta_info = {}
        # The time between two heartbeat events measured in seconds.
        self.beat_interval = 10.0  # sec
        # Indicates whether this run should be unobserved.
        self.unobserved = False
        # Disable warnings about suspicious changes.
        self.force = False
        # If true then this run will only fire the queued_event and quit.
        self.queue_only = False
        # Filter function to be applied to captured output.
        self.captured_out_filter = captured_out_filter
        # A stacktrace, in case the run failed.
        self.fail_trace = None
        # Determines the way the stdout/stderr are captured.
        self.capture_mode = None
        # Internal bookkeeping: heartbeat timer thread, observers that have
        # raised during this run, and the captured-output file object.
        self._heartbeat = None
        self._failed_observers = []
        self._output_file = None
        # Collects metrics logged via log_scalar between heartbeats.
        self._metrics = metrics_logger.MetricsLogger()

120 

121 def open_resource(self, filename, mode="r"): 

122 """Open a file and also save it as a resource. 

123 

124 Opens a file, reports it to the observers as a resource, and returns 

125 the opened file. 

126 

127 In Sacred terminology a resource is a file that the experiment needed 

128 to access during a run. In case of a MongoObserver that means making 

129 sure the file is stored in the database (but avoiding duplicates) along 

130 its path and md5 sum. 

131 

132 See also :py:meth:`sacred.Experiment.open_resource`. 

133 

134 Parameters 

135 ---------- 

136 filename : str 

137 name of the file that should be opened 

138 mode : str 

139 mode that file will be open 

140 

141 Returns 

142 ------- 

143 file 

144 the opened file-object 

145 

146 """ 

147 filename = os.path.abspath(filename) 

148 self._emit_resource_added(filename) # TODO: maybe non-blocking? 

149 return open(filename, mode) 

150 

151 def add_resource(self, filename): 

152 """Add a file as a resource. 

153 

154 In Sacred terminology a resource is a file that the experiment needed 

155 to access during a run. In case of a MongoObserver that means making 

156 sure the file is stored in the database (but avoiding duplicates) along 

157 its path and md5 sum. 

158 

159 See also :py:meth:`sacred.Experiment.add_resource`. 

160 

161 Parameters 

162 ---------- 

163 filename : str 

164 name of the file to be stored as a resource 

165 """ 

166 filename = os.path.abspath(filename) 

167 self._emit_resource_added(filename) 

168 

169 def add_artifact(self, filename, name=None, metadata=None, content_type=None): 

170 """Add a file as an artifact. 

171 

172 In Sacred terminology an artifact is a file produced by the experiment 

173 run. In case of a MongoObserver that means storing the file in the 

174 database. 

175 

176 See also :py:meth:`sacred.Experiment.add_artifact`. 

177 

178 Parameters 

179 ---------- 

180 filename : str 

181 name of the file to be stored as artifact 

182 name : str, optional 

183 optionally set the name of the artifact. 

184 Defaults to the filename. 

185 metadata: dict 

186 optionally attach metadata to the artifact. 

187 This only has an effect when using the MongoObserver. 

188 content_type: str, optional 

189 optionally attach a content-type to the artifact. 

190 This only has an effect when using the MongoObserver. 

191 """ 

192 filename = os.path.abspath(filename) 

193 name = os.path.basename(filename) if name is None else name 

194 self._emit_artifact_added(name, filename, metadata, content_type) 

195 

196 def __call__(self, *args): 

197 r"""Start this run. 

198 

199 Parameters 

200 ---------- 

201 \*args 

202 parameters passed to the main function 

203 

204 Returns 

205 ------- 

206 the return value of the main function 

207 

208 """ 

209 if self.start_time is not None: 

210 raise RuntimeError( 

211 "A run can only be started once. " 

212 "(Last start was {})".format(self.start_time) 

213 ) 

214 

215 if self.unobserved: 

216 self.observers = [] 

217 else: 

218 self.observers = sorted(self.observers, key=lambda x: -x.priority) 

219 

220 self.warn_if_unobserved() 

221 set_global_seed(self.config["seed"]) 

222 

223 if self.capture_mode is None and not self.observers: 

224 capture_mode = "no" 

225 else: 

226 capture_mode = self.capture_mode 

227 capture_mode, capture_stdout = get_stdcapturer(capture_mode) 

228 self.run_logger.debug('Using capture mode "%s"', capture_mode) 

229 

230 if self.queue_only: 

231 self._emit_queued() 

232 return 

233 try: 

234 with capture_stdout() as self._output_file: 

235 self._emit_started() 

236 self._start_heartbeat() 

237 self._execute_pre_run_hooks() 

238 self.result = self.main_function(*args) 

239 self._execute_post_run_hooks() 

240 if self.result is not None: 

241 self.run_logger.info("Result: {}".format(self.result)) 

242 elapsed_time = self._stop_time() 

243 self.run_logger.info("Completed after %s", elapsed_time) 

244 self._get_captured_output() 

245 self._stop_heartbeat() 

246 self._emit_completed(self.result) 

247 except (SacredInterrupt, KeyboardInterrupt) as e: 

248 self._stop_heartbeat() 

249 status = getattr(e, "STATUS", "INTERRUPTED") 

250 self._emit_interrupted(status) 

251 raise 

252 except BaseException: 

253 exc_type, exc_value, trace = sys.exc_info() 

254 self._stop_heartbeat() 

255 self._emit_failed(exc_type, exc_value, trace.tb_next) 

256 raise 

257 finally: 

258 self._warn_about_failed_observers() 

259 self._wait_for_observers() 

260 

261 return self.result 

262 

263 def _get_captured_output(self): 

264 if self._output_file.closed: 

265 return 

266 text = self._output_file.get() 

267 if isinstance(text, bytes): 

268 text = text.decode("utf-8", "replace") 

269 if self.captured_out: 

270 text = self.captured_out + text 

271 if self.captured_out_filter is not None: 

272 text = self.captured_out_filter(text) 

273 self.captured_out = text 

274 

275 def _start_heartbeat(self): 

276 self.run_logger.debug("Starting Heartbeat") 

277 if self.beat_interval > 0: 

278 self._stop_heartbeat_event, self._heartbeat = IntervalTimer.create( 

279 self._emit_heartbeat, self.beat_interval 

280 ) 

281 self._heartbeat.start() 

282 

283 def _stop_heartbeat(self): 

284 self.run_logger.debug("Stopping Heartbeat") 

285 # only stop if heartbeat was started 

286 if self._heartbeat is not None: 

287 self._stop_heartbeat_event.set() 

288 self._heartbeat.join(timeout=2) 

289 

290 def _emit_queued(self): 

291 self.status = "QUEUED" 

292 queue_time = datetime.datetime.utcnow() 

293 self.meta_info["queue_time"] = queue_time 

294 command = join_paths( 

295 self.main_function.prefix, self.main_function.signature.name 

296 ) 

297 self.run_logger.info("Queuing-up command '%s'", command) 

298 for observer in self.observers: 

299 _id = observer.queued_event( 

300 ex_info=self.experiment_info, 

301 command=command, 

302 host_info=self.host_info, 

303 queue_time=queue_time, 

304 config=self.config, 

305 meta_info=self.meta_info, 

306 _id=self._id, 

307 ) 

308 if self._id is None: 

309 self._id = _id 

310 # do not catch any exceptions on startup: 

311 # the experiment SHOULD fail if any of the observers fails 

312 

313 if self._id is None: 

314 self.run_logger.info("Queued") 

315 else: 

316 self.run_logger.info('Queued-up run with ID "{}"'.format(self._id)) 

317 

318 def _emit_started(self): 

319 self.status = "RUNNING" 

320 self.start_time = datetime.datetime.utcnow() 

321 command = join_paths( 

322 self.main_function.prefix, self.main_function.signature.name 

323 ) 

324 self.run_logger.info("Running command '%s'", command) 

325 for observer in self.observers: 

326 _id = observer.started_event( 

327 ex_info=self.experiment_info, 

328 command=command, 

329 host_info=self.host_info, 

330 start_time=self.start_time, 

331 config=self.config, 

332 meta_info=self.meta_info, 

333 _id=self._id, 

334 ) 

335 if self._id is None: 

336 self._id = _id 

337 # do not catch any exceptions on startup: 

338 # the experiment SHOULD fail if any of the observers fails 

339 if self._id is None: 

340 self.run_logger.info("Started") 

341 else: 

342 self.run_logger.info('Started run with ID "{}"'.format(self._id)) 

343 

344 def _emit_heartbeat(self): 

345 beat_time = datetime.datetime.utcnow() 

346 self._get_captured_output() 

347 # Read all measured metrics since last heartbeat 

348 logged_metrics = self._metrics.get_last_metrics() 

349 metrics_by_name = linearize_metrics(logged_metrics) 

350 for observer in self.observers: 

351 self._safe_call( 

352 observer, "log_metrics", metrics_by_name=metrics_by_name, info=self.info 

353 ) 

354 self._safe_call( 

355 observer, 

356 "heartbeat_event", 

357 info=self.info, 

358 captured_out=self.captured_out, 

359 beat_time=beat_time, 

360 result=self.result, 

361 ) 

362 

363 def _stop_time(self): 

364 self.stop_time = datetime.datetime.utcnow() 

365 elapsed_time = datetime.timedelta( 

366 seconds=round((self.stop_time - self.start_time).total_seconds()) 

367 ) 

368 return elapsed_time 

369 

370 def _emit_completed(self, result): 

371 self.status = "COMPLETED" 

372 for observer in self.observers: 

373 self._final_call( 

374 observer, "completed_event", stop_time=self.stop_time, result=result 

375 ) 

376 

377 def _emit_interrupted(self, status): 

378 self.status = status 

379 elapsed_time = self._stop_time() 

380 self.run_logger.warning("Aborted after %s!", elapsed_time) 

381 for observer in self.observers: 

382 self._final_call( 

383 observer, 

384 "interrupted_event", 

385 interrupt_time=self.stop_time, 

386 status=status, 

387 ) 

388 

389 def _emit_failed(self, exc_type, exc_value, trace): 

390 self.status = "FAILED" 

391 elapsed_time = self._stop_time() 

392 self.run_logger.error("Failed after %s!", elapsed_time) 

393 self.fail_trace = tb.format_exception(exc_type, exc_value, trace) 

394 for observer in self.observers: 

395 self._final_call( 

396 observer, 

397 "failed_event", 

398 fail_time=self.stop_time, 

399 fail_trace=self.fail_trace, 

400 ) 

401 

402 def _emit_resource_added(self, filename): 

403 for observer in self.observers: 

404 self._safe_call(observer, "resource_event", filename=filename) 

405 

406 def _emit_artifact_added(self, name, filename, metadata, content_type): 

407 for observer in self.observers: 

408 self._safe_call( 

409 observer, 

410 "artifact_event", 

411 name=name, 

412 filename=filename, 

413 metadata=metadata, 

414 content_type=content_type, 

415 ) 

416 

417 def _safe_call(self, obs, method, **kwargs): 

418 if obs not in self._failed_observers: 

419 try: 

420 getattr(obs, method)(**kwargs) 

421 except Exception as e: 

422 self._failed_observers.append(obs) 

423 self.run_logger.warning( 

424 "An error ocurred in the '{}' " "observer: {}".format(obs, e) 

425 ) 

426 

427 def _final_call(self, observer, method, **kwargs): 

428 try: 

429 getattr(observer, method)(**kwargs) 

430 except Exception: 

431 # Feels dirty to catch all exceptions, but it is just for 

432 # finishing up, so we don't want one observer to kill the 

433 # others 

434 self.run_logger.error(tb.format_exc()) 

435 

436 def _wait_for_observers(self): 

437 """Block until all observers finished processing.""" 

438 for observer in self.observers: 

439 self._safe_call(observer, "join") 

440 

441 def _warn_about_failed_observers(self): 

442 for observer in self._failed_observers: 

443 self.run_logger.warning( 

444 "The observer '{}' failed at some point " 

445 "during the run.".format(observer) 

446 ) 

447 

448 def _execute_pre_run_hooks(self): 

449 for pr in self.pre_run_hooks: 

450 pr() 

451 

452 def _execute_post_run_hooks(self): 

453 for pr in self.post_run_hooks: 

454 pr() 

455 

456 def warn_if_unobserved(self): 

457 if not self.observers and not self.debug and not self.unobserved: 

458 self.run_logger.warning("No observers have been added to this run") 

459 

460 def log_scalar(self, metric_name, value, step=None): 

461 """ 

462 Add a new measurement. 

463 

464 The measurement will be processed by the MongoDB observer 

465 during a heartbeat event. 

466 Other observers are not yet supported. 

467 

468 :param metric_name: The name of the metric, e.g. training.loss 

469 :param value: The measured value 

470 :param step: The step number (integer), e.g. the iteration number 

471 If not specified, an internal counter for each metric 

472 is used, incremented by one. 

473 """ 

474 # Method added in change https://github.com/chovanecm/sacred/issues/4 

475 # The same as Experiment.log_scalar (if something changes, 

476 # update the docstring too!) 

477 

478 self._metrics.log_scalar_metric(metric_name, value, step)