Coverage for sacred/sacred/observers/s3_observer.py: 18% (214 statements)

import json
import os
import os.path
import re
import socket

from sacred.commandline_options import cli_option
from sacred.dependencies import get_digest
from sacred.observers.base import RunObserver
from sacred.serializer import flatten

DEFAULT_S3_PRIORITY = 20


def _is_valid_bucket(bucket_name):
    # See https://docs.aws.amazon.com/awscloudtrail/latest/userguide/
    # cloudtrail-s3-bucket-naming-requirements.html
    if len(bucket_name) < 3 or len(bucket_name) > 63:
        return False

    labels = bucket_name.split(".")
    # A bucket name consists of "labels" separated by periods
    for label in labels:
        if len(label) == 0 or label[0] == "-" or label[-1] == "-":
            # Labels must be of nonzero length,
            # and cannot begin or end with a hyphen
            return False
        for char in label:
            # Labels can only contain digits, lowercase letters, or hyphens.
            # Anything else fails here
            if not (char.isdigit() or char.islower() or char == "-"):
                return False
    try:
        socket.inet_aton(bucket_name)
    except socket.error:
        # The name did not parse as an IP address, so it is acceptable
        return True
    # If a name is a valid IP address, it cannot be a bucket name
    return False

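# Illustrative checks of the rules above (examples added for clarity; not part
# of the original module):
#
#   _is_valid_bucket("my-experiment-logs")  # -> True
#   _is_valid_bucket("Has_Uppercase")       # -> False (uppercase and "_" not allowed)
#   _is_valid_bucket("192.168.5.4")         # -> False (valid IP addresses are rejected)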

def s3_join(*args):
    # S3 object keys always use "/" as the separator, so the
    # platform-dependent os.path.join is deliberately avoided here
    return "/".join(args)


class S3Observer(RunObserver):
    VERSION = "S3Observer-0.1.0"

    def __init__(
        self,
        bucket,
        basedir,
        resource_dir=None,
        source_dir=None,
        priority=DEFAULT_S3_PRIORITY,
        region=None,
    ):
        """Construct an S3Observer.

        Run when the object is first created,
        before it's used within an experiment.

        Parameters
        ----------
        bucket
            The name of the bucket to store results in. Should not
            contain the `s3://` prefix, and must be a valid bucket name.
        basedir
            The relative path inside the bucket where this experiment
            stores its results.
        resource_dir
            Where to store resources for this experiment. By
            default, will be <basedir>/_resources.
        source_dir
            Where to store code sources for this experiment. By
            default, will be <basedir>/_sources.
        priority
            The priority to assign to this observer if
            multiple observers are present.
        region
            The AWS region in which to create and access buckets.
            Needs to be either set here or configured in your AWS
            config file.
        """
        import boto3

        if not _is_valid_bucket(bucket):
            raise ValueError(
                "Your chosen bucket name doesn't follow AWS bucket naming rules"
            )
        resource_dir = resource_dir or s3_join(basedir, "_resources")
        source_dir = source_dir or s3_join(basedir, "_sources")

        self.basedir = basedir
        self.bucket = bucket
        # Keeping the convention of referring to locations in S3 as `dir`
        # because that is a useful mental model and there isn't a better word
        self.resource_dir = resource_dir
        self.source_dir = source_dir
        self.priority = priority
        self.dir = None
        self.run_entry = None
        self.config = None
        self.info = None
        self.cout = ""
        self.cout_write_cursor = 0
        self.saved_metrics = {}
        if region is not None:
            self.region = region
            self.s3 = boto3.resource("s3", region_name=region)
        else:
            session = boto3.session.Session()
            if session.region_name is not None:
                self.region = session.region_name
                self.s3 = boto3.resource("s3")
            else:
                raise ValueError(
                    "You must either pass in an AWS region name, or have a "
                    "region name specified in your AWS config file"
                )

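    # Example wiring (illustrative only; the experiment name, bucket, and
    # paths below are made up):
    #
    #   from sacred import Experiment
    #   from sacred.observers import S3Observer
    #
    #   ex = Experiment("my_experiment")
    #   ex.observers.append(
    #       S3Observer(bucket="my-result-bucket", basedir="runs", region="us-west-2")
    #   )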

    def _objects_exist_in_dir(self, prefix):
        # This should only be called after confirming that the bucket
        # exists; it will raise an error if the bucket does not
        bucket = self.s3.Bucket(self.bucket)
        all_keys = [el.key for el in bucket.objects.filter(Prefix=prefix)]
        return len(all_keys) > 0


    def _bucket_exists(self):
        from botocore.errorfactory import ClientError

        try:
            self.s3.meta.client.head_bucket(Bucket=self.bucket)
        except ClientError as er:
            if er.response["Error"]["Code"] == "404":
                return False
        return True


    def _list_s3_subdirs(self, prefix=None):
        if prefix is None:
            prefix = self.basedir
        bucket = self.s3.Bucket(self.bucket)
        all_keys = [obj.key for obj in bucket.objects.filter(Prefix=prefix)]
        # Escape the prefix so any regex metacharacters in path names
        # are matched literally
        subdir_match = r"{prefix}/(.*)/".format(prefix=re.escape(prefix))
        subdirs = []
        for key in all_keys:
            match_obj = re.match(subdir_match, key)
            if match_obj is not None:
                subdirs.append(match_obj.group(1))
        distinct_subdirs = set(subdirs)
        return list(distinct_subdirs)


    def _create_bucket(self):
        bucket_response = self.s3.create_bucket(
            Bucket=self.bucket,
            CreateBucketConfiguration={"LocationConstraint": self.region},
        )
        return bucket_response


    def _determine_run_dir(self, _id):
        if _id is None:
            if not self._bucket_exists():
                self._create_bucket()
                bucket_path_subdirs = []
            else:
                bucket_path_subdirs = self._list_s3_subdirs()

            if not bucket_path_subdirs:
                max_run_id = 0
            else:
                # If there are directories under basedir that aren't
                # numeric run directories, ignore those
                integer_directories = [
                    int(d) for d in bucket_path_subdirs if d.isdigit()
                ]
                max_run_id = max(integer_directories, default=0)

            _id = max_run_id + 1

        self.dir = s3_join(self.basedir, str(_id))
        if self._objects_exist_in_dir(self.dir):
            raise FileExistsError("S3 dir at {} already exists".format(self.dir))
        return _id

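    # For reference, a run stored by this observer ends up under keys like
    # the following (illustrative layout, assuming basedir="runs" and a
    # source file named train.py):
    #
    #   runs/1/run.json
    #   runs/1/config.json
    #   runs/1/cout.txt
    #   runs/_sources/train_<md5>.py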

    def queued_event(
        self, ex_info, command, host_info, queue_time, config, meta_info, _id
    ):
        _id = self._determine_run_dir(_id)

        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "meta": meta_info,
            "status": "QUEUED",
        }
        self.config = config
        self.info = {}

        self.save_json(self.run_entry, "run.json")
        self.save_json(self.config, "config.json")

        for source, _digest in ex_info["sources"]:
            self.save_file(source)

        return _id


    def save_sources(self, ex_info):
        base_dir = ex_info["base_dir"]
        source_info = []
        for source, _digest in ex_info["sources"]:
            abspath = os.path.join(base_dir, source)
            store_path, md5sum = self.find_or_save(abspath, self.source_dir)
            source_info.append([source, os.path.relpath(store_path, self.basedir)])
        return source_info


    def started_event(
        self, ex_info, command, host_info, start_time, config, meta_info, _id
    ):
        _id = self._determine_run_dir(_id)

        ex_info["sources"] = self.save_sources(ex_info)

        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "start_time": start_time.isoformat(),
            "meta": meta_info,
            "status": "RUNNING",
            "resources": [],
            "artifacts": [],
            "heartbeat": None,
        }
        self.config = config
        self.info = {}
        self.cout = ""
        self.cout_write_cursor = 0

        self.save_json(self.run_entry, "run.json")
        self.save_json(self.config, "config.json")
        self.save_cout()

        return _id


    def find_or_save(self, filename, store_dir):
        source_name, ext = os.path.splitext(os.path.basename(filename))
        md5sum = get_digest(filename)
        # Content-address the stored copy as <name>_<md5><ext> so that
        # identical files are only uploaded once
        store_name = source_name + "_" + md5sum + ext
        store_path = s3_join(store_dir, store_name)
        if len(self._list_s3_subdirs(prefix=store_path)) == 0:
            self.save_file(filename, store_path)
        return store_path, md5sum


    def put_data(self, key, binary_data):
        self.s3.Object(self.bucket, key).put(Body=binary_data)

    def save_json(self, obj, filename):
        key = s3_join(self.dir, filename)
        self.put_data(key, json.dumps(flatten(obj), sort_keys=True, indent=2))

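    # For instance (hypothetical call), save_json({"seed": 42}, "config.json")
    # uploads the flattened, pretty-printed JSON to <basedir>/<run_id>/config.json.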

    def save_file(self, filename, target_name=None):
        target_name = target_name or os.path.basename(filename)
        key = s3_join(self.dir, target_name)
        # Use a context manager so the file handle is always closed
        with open(filename, "rb") as f:
            self.put_data(key, f)


    def save_directory(self, source_dir, target_name):
        import boto3

        # Adapted from:
        # https://github.com/boto/boto3/issues/358#issuecomment-346093506
        target_name = target_name or os.path.basename(source_dir)
        all_files = []
        for root, dirs, files in os.walk(source_dir):
            all_files += [os.path.join(root, f) for f in files]
        s3_resource = boto3.resource("s3")

        for filename in all_files:
            file_location = s3_join(
                self.dir, target_name, os.path.relpath(filename, source_dir)
            )
            with open(filename, "rb") as body:
                s3_resource.Object(self.bucket, file_location).put(Body=body)


    def save_cout(self):
        # S3 objects cannot be appended to: each put replaces the object,
        # so upload the full captured output rather than just the
        # unwritten tail (uploading only the tail would leave cout.txt
        # containing the latest chunk alone)
        binary_data = self.cout.encode("utf-8")
        key = s3_join(self.dir, "cout.txt")
        self.put_data(key, binary_data)
        self.cout_write_cursor = len(self.cout)


    def heartbeat_event(self, info, captured_out, beat_time, result):
        self.info = info
        self.run_entry["heartbeat"] = beat_time.isoformat()
        self.run_entry["result"] = result
        self.cout = captured_out
        self.save_cout()
        self.save_json(self.run_entry, "run.json")
        if self.info:
            self.save_json(self.info, "info.json")


    def completed_event(self, stop_time, result):
        self.run_entry["stop_time"] = stop_time.isoformat()
        self.run_entry["result"] = result
        self.run_entry["status"] = "COMPLETED"

        self.save_json(self.run_entry, "run.json")

    def interrupted_event(self, interrupt_time, status):
        self.run_entry["stop_time"] = interrupt_time.isoformat()
        self.run_entry["status"] = status
        self.save_json(self.run_entry, "run.json")

    def failed_event(self, fail_time, fail_trace):
        self.run_entry["stop_time"] = fail_time.isoformat()
        self.run_entry["status"] = "FAILED"
        self.run_entry["fail_trace"] = fail_trace
        self.save_json(self.run_entry, "run.json")


    def resource_event(self, filename):
        store_path, md5sum = self.find_or_save(filename, self.resource_dir)
        self.run_entry["resources"].append([filename, store_path])
        self.save_json(self.run_entry, "run.json")

    def artifact_event(self, name, filename, metadata=None, content_type=None):
        self.save_file(filename, name)
        self.run_entry["artifacts"].append(name)
        self.save_json(self.run_entry, "run.json")


    def log_metrics(self, metrics_by_name, info):
        """Store new measurements into metrics.json."""
        for metric_name, metric_ptr in metrics_by_name.items():
            if metric_name not in self.saved_metrics:
                self.saved_metrics[metric_name] = {
                    "values": [],
                    "steps": [],
                    "timestamps": [],
                }

            self.saved_metrics[metric_name]["values"] += metric_ptr["values"]
            self.saved_metrics[metric_name]["steps"] += metric_ptr["steps"]

            timestamps_norm = [ts.isoformat() for ts in metric_ptr["timestamps"]]
            self.saved_metrics[metric_name]["timestamps"] += timestamps_norm

        self.save_json(self.saved_metrics, "metrics.json")

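    # The resulting metrics.json holds one entry per metric name, e.g.
    # (illustrative values):
    #
    #   {
    #     "loss": {
    #       "steps": [0, 1],
    #       "values": [0.9, 0.7],
    #       "timestamps": ["2020-01-01T00:00:00", "2020-01-01T00:00:10"]
    #     }
    #   }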

    def __eq__(self, other):
        if isinstance(other, S3Observer):
            return self.bucket == other.bucket and self.basedir == other.basedir
        return False


@cli_option("-S", "--s3")
def s3_option(args, run):
    """Add an S3 file observer to the experiment.

    The argument value should be `s3://<bucket>/path/to/exp`.
    """
    match_obj = re.match(r"s3://([^/]*)/(.*)", args)
    if match_obj is None or len(match_obj.groups()) != 2:
        raise ValueError(
            "Valid bucket specification not found. "
            "Enter bucket and directory path like: "
            "s3://<bucket>/path/to/exp"
        )
    bucket, basedir = match_obj.groups()
    run.observers.append(S3Observer(bucket=bucket, basedir=basedir))
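
# Command-line usage of the option above (illustrative; the script name and
# bucket are made up):
#
#   python my_experiment.py -S s3://my-result-bucket/runs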