Coverage for /home/ubuntu/Documents/Research/mut_p6/sacred/sacred/observers/gcs_observer.py: 20%

197 statements  

import json
import os
import os.path
import re
from typing import Optional

from sacred.commandline_options import cli_option
from sacred.dependencies import get_digest
from sacred.observers.base import RunObserver
from sacred.serializer import flatten
from sacred.utils import PathType

DEFAULT_GCS_PRIORITY = 20


def _is_valid_bucket(bucket_name: str):
    """Validate a bucket name against the Google Cloud Storage naming rules.

    Reference: https://cloud.google.com/storage/docs/naming
    """
    if bucket_name.startswith("gs://"):
        return False

    if len(bucket_name) < 3 or len(bucket_name) > 63:
        return False

    # IP address
    if re.match(r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$", bucket_name):
        return False

    if not re.fullmatch(r"([^A-Z]|-|_|[.]|)+", bucket_name):
        return False

    if ".." in bucket_name:
        return False

    if "goog" in bucket_name or "g00g" in bucket_name:
        return False

    return True
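

# Illustration, not part of the original module: how the naming rules above treat
# a few hypothetical bucket names.
if __name__ == "__main__":  # illustration only; never runs on import
    assert _is_valid_bucket("my-experiment-logs")  # plain lowercase name is accepted
    assert not _is_valid_bucket("gs://my-bucket")  # the scheme prefix must be stripped
    assert not _is_valid_bucket("MyBucket")  # uppercase letters are not allowed
    assert not _is_valid_bucket("192.168.0.1")  # dotted-IP form is not allowed
    assert not _is_valid_bucket("my..bucket")  # consecutive dots are not allowed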


def gcs_join(*args):
    return "/".join(args)


class GoogleCloudStorageObserver(RunObserver):
    VERSION = "GoogleCloudStorageObserver-0.1.0"

    def __init__(
        self,
        bucket: str,
        basedir: PathType,
        resource_dir: Optional[PathType] = None,
        source_dir: Optional[PathType] = None,
        priority: Optional[int] = DEFAULT_GCS_PRIORITY,
    ):
58 """Constructor for a GoogleCloudStorageObserver object. 

59 

60 Run when the object is first created, 

61 before it's used within an experiment. 

62 

63 Parameters 

64 ---------- 

65 bucket 

66 The name of the bucket you want to store results in. 

67 Needs to be a valid bucket name without 'gs://' 

68 basedir 

69 The relative path inside your bucket where you want this experiment to store results 

70 resource_dir 

71 Where to store resources for this experiment. By 

72 default, will be <basedir>/_resources 

73 source_dir 

74 Where to store code sources for this experiment. By 

75 default, will be <basedir>/sources 

76 priority 

77 The priority to assign to this observer if 

78 multiple observers are present 

79 """ 

        if not _is_valid_bucket(bucket):
            raise ValueError(
                "Your chosen bucket name doesn't follow Google Cloud Storage bucket naming rules"
            )
        resource_dir = resource_dir or "/".join([basedir, "_resources"])
        source_dir = source_dir or "/".join([basedir, "_sources"])

        self.basedir = basedir
        self.bucket_id = bucket

        self.resource_dir = resource_dir
        self.source_dir = source_dir
        self.priority = priority
        self.dir = None
        self.run_entry = None
        self.config = None
        self.info = None
        self.cout = ""
        self.cout_write_cursor = 0
        self.saved_metrics = {}

        from google.cloud import storage
        import google.auth.exceptions

        try:
            client = storage.Client()
        except google.auth.exceptions.DefaultCredentialsError:
            raise ConnectionError(
                "Could not create Google Cloud Storage observer, are you "
                "sure that you have set environment variable GOOGLE_APPLICATION_CREDENTIALS?"
            )

        self.bucket = client.bucket(bucket)

    def _objects_exist_in_dir(self, prefix):
        # This should be run after you've confirmed the bucket
        # exists, and will error out if it does not exist
        all_blobs = [blob for blob in self.bucket.list_blobs(prefix=prefix)]
        return len(all_blobs) > 0

    def _list_gcs_subdirs(self, prefix=None):
        if prefix is None:
            prefix = self.basedir

        iterator = self.bucket.list_blobs(prefix=prefix, delimiter="/")
        prefixes = set()
        for page in iterator.pages:
            prefixes.update(page.prefixes)

        return list(prefixes)

    def _determine_run_dir(self, _id):
        if _id is None:
            basepath = os.path.join(self.basedir, "")
            bucket_path_subdirs = self._list_gcs_subdirs(prefix=basepath)

            if not bucket_path_subdirs:
                max_run_id = 0
            else:
                relative_paths = [
                    path.replace(self.basedir, "").strip("/")
                    for path in bucket_path_subdirs
                ]
                integer_directories = [int(d) for d in relative_paths if d.isdigit()]
                if not integer_directories:
                    max_run_id = 0
                else:
                    # If there are directories under basedir that aren't
                    # numeric run directories, ignore those
                    max_run_id = max(integer_directories)

            _id = max_run_id + 1

        self.dir = gcs_join(self.basedir, str(_id))
        if self._objects_exist_in_dir(self.dir):
            raise FileExistsError("GCS dir at {} already exists".format(self.dir))
        return _id

    def queued_event(
        self, ex_info, command, host_info, queue_time, config, meta_info, _id
    ):
        _id = self._determine_run_dir(_id)

        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "meta": meta_info,
            "status": "QUEUED",
        }
        self.config = config
        self.info = {}

        self.save_json(self.run_entry, "run.json")
        self.save_json(self.config, "config.json")

        for s, m in ex_info["sources"]:
            self.save_file(s)

        return _id

    def save_sources(self, ex_info):
        base_dir = ex_info["base_dir"]
        source_info = []
        for s, m in ex_info["sources"]:
            abspath = os.path.join(base_dir, s)
            store_path, md5sum = self.find_or_save(abspath, self.source_dir)
            source_info.append([s, os.path.relpath(store_path, self.basedir)])
        return source_info

    def started_event(
        self, ex_info, command, host_info, start_time, config, meta_info, _id
    ):

        _id = self._determine_run_dir(_id)

        ex_info["sources"] = self.save_sources(ex_info)

        self.run_entry = {
            "experiment": dict(ex_info),
            "command": command,
            "host": dict(host_info),
            "start_time": start_time.isoformat(),
            "meta": meta_info,
            "status": "RUNNING",
            "resources": [],
            "artifacts": [],
            "heartbeat": None,
        }
        self.config = config
        self.info = {}
        self.cout = ""
        self.cout_write_cursor = 0

        self.save_json(self.run_entry, "run.json")
        self.save_json(self.config, "config.json")
        self.save_cout()

        return _id

    def find_or_save(self, filename, store_dir):
        source_name, ext = os.path.splitext(os.path.basename(filename))
        md5sum = get_digest(filename)
        store_name = source_name + "_" + md5sum + ext
        store_path = gcs_join(store_dir, store_name)
        if len(self._list_gcs_subdirs(prefix=store_path)) == 0:
            self.save_file_to_base(filename, store_path)
        return store_path, md5sum

    def put_data(self, key, binary_data):
        blob = self.bucket.blob(key)
        blob.upload_from_file(binary_data)

    def save_json(self, obj, filename):
        key = gcs_join(self.dir, filename)
        blob = self.bucket.blob(key)
        blob.upload_from_string(
            json.dumps(flatten(obj), sort_keys=True, indent=2), content_type="text/json"
        )

    def save_file(self, filename, target_name=None):
        target_name = target_name or os.path.basename(filename)
        key = gcs_join(self.dir, target_name)
        self.put_data(key, open(filename, "rb"))

    def save_file_to_base(self, filename, target_name=None):
        target_name = target_name or os.path.basename(filename)
        self.put_data(target_name, open(filename, "rb"))

    def save_directory(self, source_dir, target_name):
        target_name = target_name or os.path.basename(source_dir)
        all_files = []
        for root, dirs, files in os.walk(source_dir):
            all_files += [os.path.join(root, f) for f in files]

        for filename in all_files:
            file_location = gcs_join(
                self.dir, target_name, os.path.relpath(filename, source_dir)
            )
            self.put_data(file_location, open(filename, "rb"))

    def save_cout(self):
        binary_data = self.cout[self.cout_write_cursor :].encode("utf-8")
        key = gcs_join(self.dir, "cout.txt")
        blob = self.bucket.blob(key)
        blob.upload_from_string(binary_data, content_type="text/plain")
        self.cout_write_cursor = len(self.cout)

    def heartbeat_event(self, info, captured_out, beat_time, result):
        self.info = info
        self.run_entry["heartbeat"] = beat_time.isoformat()
        self.run_entry["result"] = result
        self.cout = captured_out
        self.save_cout()
        self.save_json(self.run_entry, "run.json")
        if self.info:
            self.save_json(self.info, "info.json")

    def completed_event(self, stop_time, result):
        self.run_entry["stop_time"] = stop_time.isoformat()
        self.run_entry["result"] = result
        self.run_entry["status"] = "COMPLETED"

        self.save_json(self.run_entry, "run.json")

    def interrupted_event(self, interrupt_time, status):
        self.run_entry["stop_time"] = interrupt_time.isoformat()
        self.run_entry["status"] = status
        self.save_json(self.run_entry, "run.json")

    def failed_event(self, fail_time, fail_trace):
        self.run_entry["stop_time"] = fail_time.isoformat()
        self.run_entry["status"] = "FAILED"
        self.run_entry["fail_trace"] = fail_trace
        self.save_json(self.run_entry, "run.json")

    def resource_event(self, filename):
        store_path, md5sum = self.find_or_save(filename, self.resource_dir)
        self.run_entry["resources"].append([filename, store_path])
        self.save_json(self.run_entry, "run.json")

    def artifact_event(self, name, filename, metadata=None, content_type=None):
        self.save_file(filename, name)
        self.run_entry["artifacts"].append(name)
        self.save_json(self.run_entry, "run.json")

    def log_metrics(self, metrics_by_name, info):
        """Store new measurements into metrics.json."""
        for metric_name, metric_ptr in metrics_by_name.items():

            if metric_name not in self.saved_metrics:
                self.saved_metrics[metric_name] = {
                    "values": [],
                    "steps": [],
                    "timestamps": [],
                }

            self.saved_metrics[metric_name]["values"] += metric_ptr["values"]
            self.saved_metrics[metric_name]["steps"] += metric_ptr["steps"]

            timestamps_norm = [ts.isoformat() for ts in metric_ptr["timestamps"]]
            self.saved_metrics[metric_name]["timestamps"] += timestamps_norm

        self.save_json(self.saved_metrics, "metrics.json")

    def __eq__(self, other):
        if isinstance(other, GoogleCloudStorageObserver):
            return self.bucket_id == other.bucket_id and self.basedir == other.basedir
        else:
            return False
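

# Usage sketch, not part of the original module: attaching the observer to a Sacred
# experiment by hand. The bucket name and base directory below are hypothetical, and
# constructing the observer requires working GOOGLE_APPLICATION_CREDENTIALS.
if __name__ == "__main__":  # illustration only; never runs on import
    from sacred import Experiment

    ex = Experiment("gcs_observer_demo")
    ex.observers.append(
        GoogleCloudStorageObserver(bucket="my-results-bucket", basedir="experiments/demo")
    )

    @ex.main
    def run_demo():
        return 42

    # ex.run() would store run.json, config.json, cout.txt (and metrics.json if any
    # metrics are logged) under experiments/demo/<run id>/ in the bucket, with sources
    # deduplicated into experiments/demo/_sources/ and resources into
    # experiments/demo/_resources/.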


@cli_option("-G", "--gcs")
def gcs_option(args, run):
    """Add a Google Cloud Storage observer to the experiment.

    The argument value should be `gs://<bucket>/path/to/exp`.
    """
    match_obj = re.match(r"gs:\/\/([^\/]*)\/(.*)", args)
    if match_obj is None or len(match_obj.groups()) != 2:
        raise ValueError(
            "Valid bucket specification not found. "
            "Enter bucket and directory path like: "
            "gs://<bucket>/path/to/exp"
        )
    bucket, basedir = match_obj.groups()
    run.observers.append(GoogleCloudStorageObserver(bucket=bucket, basedir=basedir))
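

# Command-line sketch, not part of the original module: the -G/--gcs flag registered
# above attaches the same observer from the shell, for example (hypothetical bucket
# and path):
#
#     python my_experiment.py -G gs://my-results-bucket/experiments/demo
#
# A minimal check of how such a flag value is split into bucket and basedir:
if __name__ == "__main__":  # illustration only; never runs on import
    example_arg = "gs://my-results-bucket/experiments/demo"
    example_match = re.match(r"gs:\/\/([^\/]*)\/(.*)", example_arg)
    assert example_match is not None
    assert example_match.groups() == ("my-results-bucket", "experiments/demo")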