Coverage for sacred/sacred/observers/sql_bases.py: 0%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

154 statements  

1import hashlib 

2import json 

3import os 

4 

5import sqlalchemy as sa 

6from sqlalchemy.ext.declarative import declarative_base 

7 

8from sacred.dependencies import get_digest 

9from sacred.serializer import restore 

10 

11 

# Declarative base shared by every ORM model below; each model's table is
# registered on ``Base.metadata`` (also used by the association tables).
Base = declarative_base(name="Base")

13 

14 

class Source(Base):
    """A source file of an experiment, stored with its digest and text content."""

    __tablename__ = "source"

    @classmethod
    def get_or_create(cls, filename, md5sum, session, basedir):
        """Return the stored ``Source`` for ``filename``/``md5sum`` or build one.

        Args:
            filename: Path of the source file, relative to ``basedir``.
            md5sum: Expected digest of the file; checked against
                ``get_digest`` of the file on disk before a new row is built.
            session: SQLAlchemy session used for the lookup.
            basedir: Directory that ``filename`` is resolved against.

        Returns:
            The first matching row if one exists, otherwise a new ``Source``
            holding the file's text content. A new instance is NOT added to
            ``session`` here.

        Raises:
            AssertionError: If the file on disk does not match ``md5sum``.
        """
        instance = (
            session.query(cls).filter_by(filename=filename, md5sum=md5sum).first()
        )
        if instance:
            return instance
        full_path = os.path.join(basedir, filename)
        md5sum_ = get_digest(full_path)
        # Explicit check instead of ``assert`` so it still runs under
        # ``python -O``; AssertionError is kept for backward compatibility.
        if md5sum_ != md5sum:
            raise AssertionError(
                "found md5 mismatch for {}: {} != {}".format(
                    filename, md5sum, md5sum_
                )
            )
        with open(full_path, "r") as f:
            return cls(filename=filename, md5sum=md5sum, content=f.read())

    source_id = sa.Column(sa.Integer, primary_key=True)
    filename = sa.Column(sa.String(256))
    md5sum = sa.Column(sa.String(32))
    content = sa.Column(sa.Text)

    def to_json(self):
        """Return ``filename`` and ``md5sum`` (the file content is omitted)."""
        return {"filename": self.filename, "md5sum": self.md5sum}

40 

41 

class Repository(Base):
    """A repository record: URL, commit and whether the working tree was dirty."""

    __tablename__ = "repository"

    @classmethod
    def get_or_create(cls, url, commit, dirty, session):
        """Find a row with exactly this ``url``/``commit``/``dirty``, else build one.

        A newly built instance is returned without being added to ``session``.
        """
        criteria = dict(url=url, commit=commit, dirty=dirty)
        existing = session.query(cls).filter_by(**criteria).first()
        return existing if existing else cls(**criteria)

    repository_id = sa.Column(sa.Integer, primary_key=True)
    url = sa.Column(sa.String(2048))
    commit = sa.Column(sa.String(40))
    dirty = sa.Column(sa.Boolean)

    def to_json(self):
        """Return a plain dict with ``url``, ``commit`` and ``dirty``."""
        return {"url": self.url, "commit": self.commit, "dirty": self.dirty}

61 

62 

class Dependency(Base):
    """A package dependency identified by name and version."""

    __tablename__ = "dependency"

    @classmethod
    def get_or_create(cls, dep, session):
        """Return the row for requirement string ``dep`` or build a new one.

        ``dep`` is split on the first ``"=="`` into name and version; a string
        without ``"=="`` yields an empty version. A newly built instance is not
        added to ``session``.
        """
        name, _sep, version = dep.partition("==")
        lookup = session.query(cls).filter_by(name=name, version=version)
        return lookup.first() or cls(name=name, version=version)

    dependency_id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String(32))
    version = sa.Column(sa.String(16))

    def to_json(self):
        """Render back into ``"name==version"`` form."""
        return "{}=={}".format(self.name, self.version)

80 

81 

class Artifact(Base):
    """A binary artifact attached to a run."""

    __tablename__ = "artifact"

    @classmethod
    def create(cls, name, filename):
        """Build (without persisting) an artifact from a file on disk.

        Args:
            name: Value stored in the ``filename`` column.
            filename: Path whose raw bytes become the artifact content.
        """
        with open(filename, "rb") as handle:
            payload = handle.read()
        return cls(filename=name, content=payload)

    artifact_id = sa.Column(sa.Integer, primary_key=True)
    filename = sa.Column(sa.String(64))
    content = sa.Column(sa.LargeBinary)

    # Links to ``Run.run_id`` (the string id, not ``Run.id``); the backref
    # exposes ``Run.artifacts``.
    run_id = sa.Column(sa.String(24), sa.ForeignKey("run.run_id"))
    run = sa.orm.relationship("Run", backref=sa.orm.backref("artifacts"))

    def to_json(self):
        """Return the artifact's primary key (as ``_id``) and stored name."""
        return {"_id": self.artifact_id, "filename": self.filename}

99 

100 

class Resource(Base):
    """A file resource used by runs, stored with its digest and raw bytes."""

    __tablename__ = "resource"

    @classmethod
    def get_or_create(cls, filename, session):
        """Return the row matching ``filename`` and its current digest, else build one.

        The digest is always computed from the file on disk first. A newly
        built instance holds the file's bytes and is not added to ``session``.
        """
        digest = get_digest(filename)
        found = (
            session.query(cls).filter_by(filename=filename, md5sum=digest).first()
        )
        if not found:
            with open(filename, "rb") as handle:
                found = cls(filename=filename, md5sum=digest, content=handle.read())
        return found

    resource_id = sa.Column(sa.Integer, primary_key=True)
    filename = sa.Column(sa.String(256))
    md5sum = sa.Column(sa.String(32))
    content = sa.Column(sa.LargeBinary)

    def to_json(self):
        """Return ``filename`` and ``md5sum`` (the content is omitted)."""
        return {"filename": self.filename, "md5sum": self.md5sum}

122 

123 

class Host(Base):
    """A host description: hostname, CPU, OS name/info and Python version."""

    __tablename__ = "host"

    @classmethod
    def get_or_create(cls, host_info, session):
        """Return the row matching ``host_info`` exactly, else build one.

        ``host_info["os"]`` is indexed as a sequence whose first two items are
        stored in ``os`` and ``os_info``. A newly built instance is not added
        to ``session``.
        """
        os_entry = host_info["os"]
        fields = {
            "hostname": host_info["hostname"],
            "cpu": host_info["cpu"],
            "os": os_entry[0],
            "os_info": os_entry[1],
            "python_version": host_info["python_version"],
        }
        existing = session.query(cls).filter_by(**fields).first()
        if existing:
            return existing
        return cls(**fields)

    host_id = sa.Column(sa.Integer, primary_key=True)
    cpu = sa.Column(sa.String(64))
    hostname = sa.Column(sa.String(64))
    os = sa.Column(sa.String(16))
    os_info = sa.Column(sa.String(64))
    python_version = sa.Column(sa.String(16))

    def to_json(self):
        """Return the host fields, re-packing ``os``/``os_info`` into one list."""
        return {
            "cpu": self.cpu,
            "hostname": self.hostname,
            "os": [self.os, self.os_info],
            "python_version": self.python_version,
        }

153 

154 

def _experiment_link_table(table_name, other_column, other_target):
    """Build a many-to-many association table between ``experiment`` and another table.

    Args:
        table_name: Name of the association table.
        other_column: Name of the integer column holding the other side's id.
        other_target: ``"table.column"`` that ``other_column`` references.
    """
    return sa.Table(
        table_name,
        Base.metadata,
        sa.Column(
            "experiment_id", sa.Integer, sa.ForeignKey("experiment.experiment_id")
        ),
        sa.Column(other_column, sa.Integer, sa.ForeignKey(other_target)),
    )


experiment_source_association = _experiment_link_table(
    "experiments_sources", "source_id", "source.source_id"
)

experiment_repository_association = _experiment_link_table(
    "experiments_repositories", "repository_id", "repository.repository_id"
)

experiment_dependency_association = _experiment_link_table(
    "experiments_dependencies", "dependency_id", "dependency.dependency_id"
)

175 

176 

class Experiment(Base):
    """An experiment definition with its sources, repositories and dependencies.

    Experiments are identified by ``name`` plus an MD5 digest of their ``ex_info``.
    """

    __tablename__ = "experiment"

    @classmethod
    def get_or_create(cls, ex_info, session):
        """Return the stored experiment matching ``ex_info`` or build a new one.

        For a new experiment the related ``Dependency``, ``Source`` and
        ``Repository`` objects are looked up or built. Repositories are added
        to ``session`` here; the returned experiment itself is not.

        Args:
            ex_info: Dict read for the keys ``name``, ``base_dir``,
                ``dependencies`` (``"name==version"`` strings), ``sources``
                (``(filename, md5sum)`` pairs) and ``repositories`` (dicts with
                ``url``, ``commit`` and ``dirty``).
            session: SQLAlchemy session used for lookups.
        """
        name = ex_info["name"]
        # Uniqueness key: MD5 of ex_info as serialized by json.dumps (no key
        # sorting, so dict insertion order affects the digest).
        md5 = hashlib.md5(json.dumps(ex_info).encode()).hexdigest()
        existing = session.query(cls).filter_by(name=name, md5sum=md5).first()
        if existing:
            return existing

        dependencies = [
            Dependency.get_or_create(requirement, session)
            for requirement in ex_info["dependencies"]
        ]
        sources = [
            Source.get_or_create(filename, digest, session, ex_info["base_dir"])
            for filename, digest in ex_info["sources"]
        ]

        # A set collapses entries that resolve to the same Repository object.
        unique_repos = set()
        for repo_info in ex_info["repositories"]:
            repo = Repository.get_or_create(
                repo_info["url"], repo_info["commit"], repo_info["dirty"], session
            )
            session.add(repo)
            unique_repos.add(repo)

        return cls(
            name=name,
            dependencies=dependencies,
            sources=sources,
            repositories=list(unique_repos),
            md5sum=md5,
            base_dir=ex_info["base_dir"],
        )

    experiment_id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String(32))
    md5sum = sa.Column(sa.String(32))
    base_dir = sa.Column(sa.String(64))
    sources = sa.orm.relationship(
        "Source", secondary=experiment_source_association, backref="experiments"
    )
    repositories = sa.orm.relationship(
        "Repository", secondary=experiment_repository_association, backref="experiments"
    )
    dependencies = sa.orm.relationship(
        "Dependency", secondary=experiment_dependency_association, backref="experiments"
    )

    def to_json(self):
        """Return name, base_dir and the JSON forms of all related rows."""
        source_list = [src.to_json() for src in self.sources]
        repo_list = [repo.to_json() for repo in self.repositories]
        dep_list = [dep.to_json() for dep in self.dependencies]
        return {
            "name": self.name,
            "base_dir": self.base_dir,
            "sources": source_list,
            "repositories": repo_list,
            "dependencies": dep_list,
        }

238 

239 

# Many-to-many link between runs and resources. The run side references the
# string ``run.run_id`` column rather than the integer ``run.id`` primary key.
run_resource_association = sa.Table(
    "runs_resources",
    Base.metadata,
    sa.Column(
        "run_id",
        sa.String(24),
        sa.ForeignKey("run.run_id"),
    ),
    sa.Column(
        "resource_id",
        sa.Integer,
        sa.ForeignKey("resource.resource_id"),
    ),
)

246 

247 

class Run(Base):
    """A single run: timing, status, output, config and links to host/experiment."""

    __tablename__ = "run"
    id = sa.Column(sa.Integer, primary_key=True)

    # String identifier referenced by Artifact.run_id and runs_resources.run_id.
    run_id = sa.Column(sa.String(24), unique=True)

    command = sa.Column(sa.String(64))

    # times
    start_time = sa.Column(sa.DateTime)
    heartbeat = sa.Column(sa.DateTime)
    stop_time = sa.Column(sa.DateTime)
    queue_time = sa.Column(sa.DateTime)

    # meta info
    priority = sa.Column(sa.Float)
    comment = sa.Column(sa.Text)

    fail_trace = sa.Column(sa.Text)

    # Captured out
    # TODO: move to separate table?
    captured_out = sa.Column(sa.Text)

    # Configuration & info, stored as JSON text (config is json-decoded in to_json)
    # TODO: switch type to json if possible
    config = sa.Column(sa.Text)
    info = sa.Column(sa.Text)

    status = sa.Column(
        sa.Enum(
            "RUNNING",
            "COMPLETED",
            "INTERRUPTED",
            "TIMEOUT",
            "FAILED",
            name="status_enum",
        )
    )

    host_id = sa.Column(sa.Integer, sa.ForeignKey("host.host_id"))
    host = sa.orm.relationship("Host", backref=sa.orm.backref("runs"))

    experiment_id = sa.Column(sa.Integer, sa.ForeignKey("experiment.experiment_id"))
    experiment = sa.orm.relationship("Experiment", backref=sa.orm.backref("runs"))

    # artifacts = backref
    resources = sa.orm.relationship(
        "Resource", secondary=run_resource_association, backref="runs"
    )

    result = sa.Column(sa.Float)

    def to_json(self):
        """Serialize the run and its related rows into a plain dict.

        ``config`` is decoded from its JSON text and passed through
        ``restore``. The ``host``/``experiment`` foreign keys and the
        ``config`` column are nullable. When one of them is empty, the
        corresponding entry is ``None`` instead of raising.
        """
        host = self.host.to_json() if self.host is not None else None
        experiment = (
            self.experiment.to_json() if self.experiment is not None else None
        )
        config = (
            restore(json.loads(self.config)) if self.config is not None else None
        )
        return {
            "_id": self.run_id,
            "command": self.command,
            "start_time": self.start_time,
            "heartbeat": self.heartbeat,
            "stop_time": self.stop_time,
            "queue_time": self.queue_time,
            "status": self.status,
            "result": self.result,
            "meta": {"comment": self.comment, "priority": self.priority},
            "resources": [r.to_json() for r in self.resources],
            "artifacts": [a.to_json() for a in self.artifacts],
            "host": host,
            "experiment": experiment,
            "config": config,
            "captured_out": self.captured_out,
            "fail_trace": self.fail_trace,
        }