Coverage for sacred/sacred/observers/mongo.py: 27%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

299 statements  

1from typing import Optional, Union 

2import mimetypes 

3import os.path 

4import pickle 

5import re 

6import sys 

7import time 

8from tempfile import NamedTemporaryFile 

9import warnings 

10 

11import sacred.optional as opt 

12from sacred.commandline_options import cli_option 

13from sacred.dependencies import get_digest 

14from sacred.observers.base import RunObserver 

15from sacred.observers.queue import QueueObserver 

16from sacred.serializer import flatten 

17from sacred.utils import ObserverError, PathType 

18import pkg_resources 

19 

# Default priority for MongoObserver instances (see RunObserver ordering).
DEFAULT_MONGO_PRIORITY = 30


# Use the mime.types database bundled with sacred rather than the OS one.
# This ensures consistent mimetype detection across platforms.
mimetype_detector = mimetypes.MimeTypes(
    filenames=[pkg_resources.resource_filename("sacred", "data/mime.types")]
)

26 

27 

def force_valid_bson_key(key):
    """Coerce *key* into a string that BSON accepts as a document key.

    A leading ``$`` is replaced with ``@`` and every ``.`` is replaced
    with ``,``, since BSON forbids both in document keys.
    """
    key = str(key)
    prefix = "@" if key.startswith("$") else ""
    body = key[1:] if prefix else key
    return (prefix + body).replace(".", ",")

34 

35 

def force_bson_encodeable(obj):
    """Return a version of *obj* that BSON can encode.

    Dicts that fail encoding are sanitized recursively (keys made valid,
    values coerced); numpy arrays are returned unchanged; any other value
    BSON rejects is converted to its string representation.
    """
    import bson

    if isinstance(obj, dict):
        try:
            # check_keys=True also rejects keys containing '$' or '.'.
            bson.BSON.encode(obj, check_keys=True)
        except bson.InvalidDocument:
            return {
                force_valid_bson_key(key): force_bson_encodeable(value)
                for key, value in obj.items()
            }
        return obj

    if opt.has_numpy and isinstance(obj, opt.np.ndarray):
        # NOTE(review): ndarrays are passed through untouched — presumably
        # serialized elsewhere; confirm before changing.
        return obj

    try:
        # BSON only encodes documents, so wrap the bare value for the check.
        bson.BSON.encode({"dict_just_for_testing": obj})
    except bson.InvalidDocument:
        return str(obj)
    return obj

57 

58 

class MongoObserver(RunObserver):
    """Observer that stores run information in a MongoDB database.

    Run documents are written to a runs collection, logged metrics to a
    metrics collection, and sources/resources/artifacts to GridFS.
    """

    # Collection names reserved by GridFS / system collections / search
    # spaces; runs and metrics must not be stored under any of these.
    COLLECTION_NAME_BLACKLIST = {
        "fs.files",
        "fs.chunks",
        "_properties",
        "system.indexes",
        "search_space",
        "search_spaces",
    }
    # Version tag stored with each run entry under the "format" key.
    VERSION = "MongoObserver-0.7.0"

    @classmethod
    def create(cls, *args, **kwargs):
        """Deprecated alternate constructor; use ``MongoObserver(...)``."""
        warnings.warn(
            "MongoObserver.create(...) is deprecated. "
            "Please use MongoObserver(...) instead.",
            DeprecationWarning,
        )
        return cls(*args, **kwargs)

    def __init__(
        self,
        url: Optional[str] = None,
        db_name: str = "sacred",
        collection: str = "runs",
        collection_prefix: str = "",
        overwrite: Optional[Union[int, str]] = None,
        priority: int = DEFAULT_MONGO_PRIORITY,
        client: Optional["pymongo.MongoClient"] = None,
        failure_dir: Optional[PathType] = None,
        **kwargs,
    ):
        """Initializer for MongoObserver.

        Parameters
        ----------
        url
            Mongo URI to connect to.
        db_name
            Database to connect to.
        collection
            Collection to write the runs to. (default: "runs").
            **DEPRECATED**, please use collection_prefix instead.
        collection_prefix
            Prefix the runs and metrics collection,
            i.e. runs will be stored to PREFIX_runs, metrics to PREFIX_metrics.
            If empty runs are stored to 'runs', metrics to 'metrics'.
        overwrite
            _id of a run that should be overwritten.
        priority
            (default 30)
        client
            Client to connect to. Do not use client and URL together.
        failure_dir
            Directory to save the run of a failed observer to.
        **kwargs
            Passed on to ``pymongo.MongoClient`` when no client is supplied.
        """
        # Imported lazily so the module can be loaded without pymongo.
        import pymongo
        import gridfs

        if client is not None:
            if not isinstance(client, pymongo.MongoClient):
                raise ValueError(
                    "client needs to be a pymongo.MongoClient, "
                    "but is {} instead".format(type(client))
                )
            if url is not None:
                raise ValueError("Cannot pass both a client and a url.")
        else:
            client = pymongo.MongoClient(url, **kwargs)

        database = client[db_name]
        if collection != "runs":
            # the 'old' way of setting a custom collection name
            # still works as before for backward compatibility
            warnings.warn(
                'Argument "collection" is deprecated. '
                'Please use "collection_prefix" instead.',
                DeprecationWarning,
            )
            if collection_prefix != "":
                raise ValueError("Cannot pass both collection and a collection prefix.")
            runs_collection_name = collection
            metrics_collection_name = "metrics"
        else:
            if collection_prefix != "":
                # separate prefix from 'runs' / 'collections' by an underscore.
                collection_prefix = "{}_".format(collection_prefix)

            runs_collection_name = "{}runs".format(collection_prefix)
            metrics_collection_name = "{}metrics".format(collection_prefix)

        # Refuse to write into reserved collections (see blacklist above).
        if runs_collection_name in MongoObserver.COLLECTION_NAME_BLACKLIST:
            raise KeyError(
                'Collection name "{}" is reserved. '
                "Please use a different one.".format(runs_collection_name)
            )

        if metrics_collection_name in MongoObserver.COLLECTION_NAME_BLACKLIST:
            raise KeyError(
                'Collection name "{}" is reserved. '
                "Please use a different one.".format(metrics_collection_name)
            )

        runs_collection = database[runs_collection_name]
        metrics_collection = database[metrics_collection_name]
        fs = gridfs.GridFS(database)
        self.initialize(
            runs_collection,
            fs,
            overwrite=overwrite,
            metrics_collection=metrics_collection,
            failure_dir=failure_dir,
            priority=priority,
        )

    def initialize(
        self,
        runs_collection,
        fs,
        overwrite=None,
        metrics_collection=None,
        failure_dir=None,
        priority=DEFAULT_MONGO_PRIORITY,
    ):
        """Set up observer state from already-constructed collections.

        Called by ``__init__`` and ``create_from``.  If ``overwrite`` is
        given it must be the (integer-convertible) ``_id`` of an existing
        run, whose document is loaded and later updated in place.
        """
        self.runs = runs_collection
        self.metrics = metrics_collection
        self.fs = fs
        if overwrite is not None:
            overwrite = int(overwrite)
            run = self.runs.find_one({"_id": overwrite})
            if run is None:
                raise RuntimeError(
                    "Couldn't find run to overwrite with " "_id='{}'".format(overwrite)
                )
            else:
                # Keep the full existing document so started_event can
                # update it rather than replace it.
                overwrite = run
        self.overwrite = overwrite
        self.run_entry = None
        self.priority = priority
        self.failure_dir = failure_dir

    @classmethod
    def create_from(cls, *args, **kwargs):
        """Build an observer from existing collections, skipping __init__."""
        self = cls.__new__(cls)  # skip __init__ call
        self.initialize(*args, **kwargs)
        return self

    def queued_event(
        self, ex_info, command, host_info, queue_time, config, meta_info, _id
    ):
        """Insert a QUEUED run entry and return its ``_id``."""
        if self.overwrite is not None:
            raise RuntimeError("Can't overwrite with QUEUED run.")
        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "config": flatten(config),
            "meta": meta_info,
            "status": "QUEUED",
        }
        # set ID if given
        if _id is not None:
            self.run_entry["_id"] = _id
        # save sources
        self.run_entry["experiment"]["sources"] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry["_id"]

    def started_event(
        self, ex_info, command, host_info, start_time, config, meta_info, _id
    ):
        """Create (or overwrite) the RUNNING run entry and return its ``_id``."""
        if self.overwrite is None:
            self.run_entry = {"_id": _id}
        else:
            if self.run_entry is not None:
                raise RuntimeError("Cannot overwrite more than once!")
            # TODO sanity checks
            self.run_entry = self.overwrite

        self.run_entry.update(
            {
                "experiment": dict(ex_info),
                "format": self.VERSION,
                "command": command,
                "host": dict(host_info),
                "start_time": start_time,
                "config": flatten(config),
                "meta": meta_info,
                "status": "RUNNING",
                "resources": [],
                "artifacts": [],
                "captured_out": "",
                "info": {},
                "heartbeat": None,
            }
        )

        # save sources
        self.run_entry["experiment"]["sources"] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry["_id"]

    def heartbeat_event(self, info, captured_out, beat_time, result):
        """Sync info, captured output and intermediate result to the DB."""
        self.run_entry["info"] = flatten(info)
        self.run_entry["captured_out"] = captured_out
        self.run_entry["heartbeat"] = beat_time
        self.run_entry["result"] = flatten(result)
        self.save()

    def completed_event(self, stop_time, result):
        """Mark the run COMPLETED and persist it (up to 10 attempts)."""
        self.run_entry["stop_time"] = stop_time
        self.run_entry["result"] = flatten(result)
        self.run_entry["status"] = "COMPLETED"
        self.final_save(attempts=10)

    def interrupted_event(self, interrupt_time, status):
        """Record an interruption with the given status (3 save attempts)."""
        self.run_entry["stop_time"] = interrupt_time
        self.run_entry["status"] = status
        self.final_save(attempts=3)

    def failed_event(self, fail_time, fail_trace):
        """Mark the run FAILED with its traceback (single save attempt)."""
        self.run_entry["stop_time"] = fail_time
        self.run_entry["status"] = "FAILED"
        self.run_entry["fail_trace"] = fail_trace
        self.final_save(attempts=1)

    def resource_event(self, filename):
        """Record a resource file, storing it in GridFS unless a file with
        the same name and md5 digest is already present."""
        if self.fs.exists(filename=filename):
            md5hash = get_digest(filename)
            if self.fs.exists(filename=filename, md5=md5hash):
                # Identical content already stored: only reference it.
                resource = (filename, md5hash)
                if resource not in self.run_entry["resources"]:
                    self.run_entry["resources"].append(resource)
                    self.save()
                return
        with open(filename, "rb") as f:
            file_id = self.fs.put(f, filename=filename)
        # Use the md5 computed by GridFS for the stored file.
        md5hash = self.fs.get(file_id).md5
        self.run_entry["resources"].append((filename, md5hash))
        self.save()

    def artifact_event(self, name, filename, metadata=None, content_type=None):
        """Store an artifact file in GridFS and reference it in the run."""
        with open(filename, "rb") as f:
            run_id = self.run_entry["_id"]
            # Namespaced GridFS filename so artifacts of different runs
            # (and collections) cannot collide.
            db_filename = "artifact://{}/{}/{}".format(self.runs.name, run_id, name)
            if content_type is None:
                content_type = self._try_to_detect_content_type(filename)

            file_id = self.fs.put(
                f, filename=db_filename, metadata=metadata, content_type=content_type
            )

        self.run_entry["artifacts"].append({"name": name, "file_id": file_id})
        self.save()

    @staticmethod
    def _try_to_detect_content_type(filename):
        """Guess the mimetype of *filename*; returns None when unknown."""
        mime_type, _ = mimetype_detector.guess_type(filename)
        if mime_type is not None:
            print(
                "Added {} as content-type of artifact {}.".format(mime_type, filename)
            )
        else:
            print(
                "Failed to detect content-type automatically for "
                "artifact {}.".format(filename)
            )
        return mime_type

    def log_metrics(self, metrics_by_name, info):
        """Store new measurements to the database.

        Take measurements and store them into
        the metrics collection in the database.
        Additionally, reference the metrics
        in the info["metrics"] dictionary.
        """
        if self.metrics is None:
            # If, for whatever reason, the metrics collection has not been set
            # do not try to save anything there.
            return
        for key in metrics_by_name:
            query = {"run_id": self.run_entry["_id"], "name": key}
            # $each appends all new steps/values/timestamps to the arrays
            # of the (possibly newly upserted) metric document.
            push = {
                "steps": {"$each": metrics_by_name[key]["steps"]},
                "values": {"$each": metrics_by_name[key]["values"]},
                "timestamps": {"$each": metrics_by_name[key]["timestamps"]},
            }
            update = {"$push": push}
            result = self.metrics.update_one(query, update, upsert=True)
            if result.upserted_id is not None:
                # This is the first time we are storing this metric
                info.setdefault("metrics", []).append(
                    {"name": key, "id": str(result.upserted_id)}
                )

    def insert(self):
        """Insert the run entry, auto-incrementing ``_id`` if none is set."""
        import pymongo.errors

        if self.overwrite:
            # Overwritten runs already exist in the DB: update instead.
            return self.save()

        autoinc_key = self.run_entry.get("_id") is None
        while True:
            if autoinc_key:
                # Pick max(_id) + 1.  NOTE(review): not atomic — concurrent
                # observers may race for the same _id; the DuplicateKeyError
                # handler below retries with a fresh value in that case.
                c = self.runs.find({}, {"_id": 1})
                c = c.sort("_id", pymongo.DESCENDING).limit(1)
                self.run_entry["_id"] = (
                    c.next()["_id"] + 1 if self.runs.count_documents({}, limit=1) else 1
                )
            try:
                self.runs.insert_one(self.run_entry)
                return
            except pymongo.errors.InvalidDocument as e:
                raise ObserverError(
                    "Run contained an unserializable entry."
                    "(most likely in the info)\n{}".format(e)
                )
            except pymongo.errors.DuplicateKeyError:
                if not autoinc_key:
                    raise

    def save(self):
        """Best-effort update of the run entry; connection drops are ignored."""
        import pymongo.errors

        try:
            self.runs.update_one(
                {"_id": self.run_entry["_id"]}, {"$set": self.run_entry}
            )
        except pymongo.errors.AutoReconnect:
            pass  # just wait for the next save
        except pymongo.errors.InvalidDocument:
            raise ObserverError(
                "Run contained an unserializable entry." "(most likely in the info)"
            )

    def final_save(self, attempts):
        """Persist the final run entry, retrying up to *attempts* times.

        On repeated failure the entry is pickled into ``failure_dir`` so
        the run data is not lost.
        """
        import pymongo.errors

        for i in range(attempts):
            try:
                self.runs.update_one(
                    {"_id": self.run_entry["_id"]},
                    {"$set": self.run_entry},
                    upsert=True,
                )
                return
            except pymongo.errors.AutoReconnect:
                if i < attempts - 1:
                    time.sleep(1)
            except pymongo.errors.ConnectionFailure:
                pass
            except pymongo.errors.InvalidDocument:
                # Coerce the entry into something storable and retry.
                self.run_entry = force_bson_encodeable(self.run_entry)
                print(
                    "Warning: Some of the entries of the run were not "
                    "BSON-serializable!\n They have been altered such that "
                    "they can be stored, but you should fix your experiment!"
                    "Most likely it is either the 'info' or the 'result'.",
                    file=sys.stderr,
                )

        # All attempts failed: dump the entry to disk as a last resort.
        # NOTE(review): failure_dir may be None here, which would make
        # os.makedirs raise — confirm callers always configure it.
        os.makedirs(self.failure_dir, exist_ok=True)
        with NamedTemporaryFile(
            suffix=".pickle",
            delete=False,
            prefix="sacred_mongo_fail_{}_".format(self.run_entry["_id"]),
            dir=self.failure_dir,
        ) as f:
            pickle.dump(self.run_entry, f)
            print(
                "Warning: saving to MongoDB failed! "
                "Stored experiment entry in '{}'".format(f.name),
                file=sys.stderr,
            )

    def save_sources(self, ex_info):
        """Store source files in GridFS (deduplicated by filename and md5).

        Returns a list of ``[source_name, file_id]`` pairs for the run entry.
        """
        base_dir = ex_info["base_dir"]
        source_info = []
        for source_name, md5 in ex_info["sources"]:
            abs_path = os.path.join(base_dir, source_name)
            file = self.fs.find_one({"filename": abs_path, "md5": md5})
            if file:
                _id = file._id
            else:
                with open(abs_path, "rb") as f:
                    _id = self.fs.put(f, filename=abs_path)
            source_info.append([source_name, _id])
        return source_info

    def __eq__(self, other):
        """Two MongoObservers are equal when they share the runs collection."""
        if isinstance(other, MongoObserver):
            return self.runs == other.runs
        return False

453 

454 

@cli_option("-m", "--mongo_db")
def mongo_db_option(args, run):
    """Add a MongoDB Observer to the experiment.

    The argument value is the database specification.
    Should be in the form:

    `[host:port:]db_name[.collection[:id]][!priority]`
    """
    observer = MongoObserver(**parse_mongo_db_arg(args))
    run.observers.append(observer)

467 

468 

def get_pattern():
    """Return the regex used to parse a ``--mongo_db`` specification.

    The result is an alternation of two anchored forms: a bare
    ``host:port`` and the full
    ``[host:port:]db_name[.collection[:id]][!priority]`` spelling.
    """

    def host_group(name):
        # Dot-separated hostname labels: letters/digits with interior
        # dashes, each label at most 63 characters.
        return (
            r"(?P<" + name + r">"
            r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?"
            r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}"
            r"[0-9A-Za-z])?)*)"
        )

    def port_group(name):
        # Port: one to five digits.
        return r"(?P<" + name + r">\d{1,5})"

    def name_group(name):
        # Database / collection identifier: no leading digit, at most
        # 64 characters overall.
        return (
            r"(?P<" + name + r">[_A-Za-z]" r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})"
        )

    host_only = r"^(?:{host}:{port})$".format(
        host=host_group("host1"), port=port_group("port1")
    )
    full = (
        r"^(?:{host}:{port}:)?{db}(?:\.{collection}(?::{rid})?)?"
        r"(?:!{priority})?$".format(
            host=host_group("host2"),
            port=port_group("port2"),
            db=name_group("db_name"),
            collection=name_group("collection"),
            rid=r"(?P<overwrite>\d{1,12})",
            priority=r"(?P<priority>-?\d+)?",
        )
    )

    return r"{host_only}|{full}".format(host_only=host_only, full=full)

507 

508 

def parse_mongo_db_arg(mongo_db):
    """Parse a ``--mongo_db`` specification into MongoObserver kwargs.

    Parameters
    ----------
    mongo_db
        Specification of the form
        ``[host:port:]db_name[.collection[:id]][!priority]``.

    Returns
    -------
    dict
        Keyword arguments (``url``, ``db_name``, ``collection``,
        ``overwrite``, ``priority``) suitable for ``MongoObserver(**kwargs)``.

    Raises
    ------
    ValueError
        If the specification does not match the expected format.
    """
    # Bug fix: re.match returns None when nothing matches, so calling
    # .groupdict() unconditionally raised AttributeError before the old
    # "if g is None" check could ever fire.  Check the match object first
    # so malformed input raises the intended ValueError.
    match = re.match(get_pattern(), mongo_db)
    if match is None:
        raise ValueError(
            'mongo_db argument must have the form "db_name" '
            'or "host:port[:db_name]" but was {}'.format(mongo_db)
        )
    g = match.groupdict()

    kwargs = {}
    # Only one of the host1/host2 alternatives can have matched.
    if g["host1"]:
        kwargs["url"] = "{}:{}".format(g["host1"], g["port1"])
    elif g["host2"]:
        kwargs["url"] = "{}:{}".format(g["host2"], g["port2"])

    if g["priority"] is not None:
        kwargs["priority"] = int(g["priority"])

    # Pass through the optional named groups that matched.
    for p in ["db_name", "collection", "overwrite"]:
        if g[p] is not None:
            kwargs[p] = g[p]

    return kwargs

531 

532 

class QueueCompatibleMongoObserver(MongoObserver):
    """MongoObserver variant intended to run behind a QueueObserver.

    Unlike the parent class it does not retry or swallow connection
    problems itself: failures surface as ObserverError so the queue can
    re-schedule the event.
    """

    def log_metrics(self, metric_name, metrics_values, info):
        """Store new measurements to the database.

        Take measurements and store them into
        the metrics collection in the database.
        Additionally, reference the metrics
        in the info["metrics"] dictionary.
        """
        # NOTE(review): the signature differs from
        # MongoObserver.log_metrics (one metric per call instead of a
        # dict of metrics) — presumably the queue unpacks them; confirm.
        if self.metrics is None:
            # If, for whatever reason, the metrics collection has not been set
            # do not try to save anything there.
            return
        query = {"run_id": self.run_entry["_id"], "name": metric_name}
        # $each appends all new steps/values/timestamps to the arrays of
        # the (possibly newly upserted) metric document.
        push = {
            "steps": {"$each": metrics_values["steps"]},
            "values": {"$each": metrics_values["values"]},
            "timestamps": {"$each": metrics_values["timestamps"]},
        }
        update = {"$push": push}
        result = self.metrics.update_one(query, update, upsert=True)
        if result.upserted_id is not None:
            # This is the first time we are storing this metric
            info.setdefault("metrics", []).append(
                {"name": metric_name, "id": str(result.upserted_id)}
            )

    def save(self):
        """Update the run entry; unserializable entries raise ObserverError."""
        import pymongo

        try:
            self.runs.update_one(
                {"_id": self.run_entry["_id"]}, {"$set": self.run_entry}
            )
        except pymongo.errors.InvalidDocument as exc:
            raise ObserverError(
                "Run contained an unserializable entry. (most likely in the info)"
            ) from exc

    def final_save(self, attempts):
        """Persist the final run entry once; dump to a pickle on failure.

        ``attempts`` is accepted for interface compatibility with
        MongoObserver.final_save but is not used here — retrying is the
        queue's responsibility.
        """
        import pymongo

        try:
            self.runs.update_one(
                {"_id": self.run_entry["_id"]}, {"$set": self.run_entry}, upsert=True
            )
            return

        except pymongo.errors.InvalidDocument:
            # Coerce the entry into something storable, then still fail
            # loudly below so the caller knows the original was altered.
            self.run_entry = force_bson_encodeable(self.run_entry)
            print(
                "Warning: Some of the entries of the run were not "
                "BSON-serializable!\n They have been altered such that "
                "they can be stored, but you should fix your experiment!"
                "Most likely it is either the 'info' or the 'result'.",
                file=sys.stderr,
            )

            with NamedTemporaryFile(
                suffix=".pickle", delete=False, prefix="sacred_mongo_fail_"
            ) as f:
                pickle.dump(self.run_entry, f)
                print(
                    "Warning: saving to MongoDB failed! "
                    "Stored experiment entry in '{}'".format(f.name),
                    file=sys.stderr,
                )

            raise ObserverError("Warning: saving to MongoDB failed!")

602 

603 

class QueuedMongoObserver(QueueObserver):
    """MongoObserver that uses a fault-tolerant background process."""

    @classmethod
    def create(cls, *args, **kwargs):
        """Deprecated alternate constructor; use ``QueuedMongoObserver(...)``."""
        warnings.warn(
            "QueuedMongoObserver.create(...) is deprecated. "
            "Please use QueuedMongoObserver(...) instead.",
            DeprecationWarning,
        )
        return cls(*args, **kwargs)

    def __init__(
        self,
        interval: float = 20.0,
        retry_interval: float = 10.0,
        url: Optional[str] = None,
        db_name: str = "sacred",
        collection: str = "runs",
        overwrite: Optional[Union[int, str]] = None,
        priority: int = DEFAULT_MONGO_PRIORITY,
        client: Optional["pymongo.MongoClient"] = None,
        **kwargs,
    ):
        """Initializer for MongoObserver.

        Parameters
        ----------
        interval
            The interval in seconds at which the background thread is woken up to
            process new events.
        retry_interval
            The interval in seconds to wait if an event failed to be processed.
        url
            Mongo URI to connect to.
        db_name
            Database to connect to.
        collection
            Collection to write the runs to. (default: "runs").
        overwrite
            _id of a run that should be overwritten.
        priority
            (default 30)
        client
            Client to connect to. Do not use client and URL together.
        failure_dir
            Directory to save the run of a failed observer to.
        """
        # NOTE(review): failure_dir is documented above but is not an
        # explicit parameter; it only reaches the underlying observer via
        # **kwargs — confirm this is intended.
        super().__init__(
            QueueCompatibleMongoObserver(
                url=url,
                db_name=db_name,
                collection=collection,
                overwrite=overwrite,
                priority=priority,
                client=client,
                **kwargs,
            ),
            interval=interval,
            retry_interval=retry_interval,
        )