Coverage for /home/ubuntu/Documents/Research/mut_p6/sacred/sacred/observers/sql_bases.py: 0%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import hashlib
2import json
3import os
5import sqlalchemy as sa
6from sqlalchemy.ext.declarative import declarative_base
8from sacred.dependencies import get_digest
9from sacred.serializer import restore
# Declarative base shared by this module: every ORM model below subclasses it
# and the association tables are registered on ``Base.metadata``.
Base = declarative_base()
class Source(Base):
    """ORM model for a row of the ``source`` table.

    A row stores a file name, the digest recorded for it and the file's
    text content.
    """

    __tablename__ = "source"

    @classmethod
    def get_or_create(cls, filename, md5sum, session, basedir):
        """Return the stored ``Source`` matching the arguments, or build one.

        Args:
            filename: File name, joined onto ``basedir`` to locate the file.
            md5sum: Digest the caller expects the file to have.
            session: Session used to look up an existing row.
            basedir: Directory that ``filename`` is resolved against.

        Returns:
            The first existing row with the same ``filename`` and ``md5sum``,
            otherwise a new instance holding the file content. The new
            instance is NOT added to ``session`` here.

        Raises:
            AssertionError: If the digest computed by ``get_digest`` for the
                file on disk differs from ``md5sum``.
        """
        instance = (
            session.query(cls).filter_by(filename=filename, md5sum=md5sum).first()
        )
        if instance:
            return instance
        full_path = os.path.join(basedir, filename)
        md5sum_ = get_digest(full_path)
        if md5sum_ != md5sum:
            # Raised explicitly instead of via an ``assert`` statement so the
            # integrity check still runs under ``python -O``; the exception
            # type and message are unchanged for existing callers.
            raise AssertionError(
                "found md5 mismatch for {}: {} != {}".format(
                    filename, md5sum, md5sum_
                )
            )
        with open(full_path, "r") as f:
            return cls(filename=filename, md5sum=md5sum, content=f.read())

    source_id = sa.Column(sa.Integer, primary_key=True)
    filename = sa.Column(sa.String(256))
    md5sum = sa.Column(sa.String(32))
    content = sa.Column(sa.Text)

    def to_json(self):
        """Return a dict with the ``filename`` and ``md5sum`` of this row."""
        return {"filename": self.filename, "md5sum": self.md5sum}
class Repository(Base):
    """ORM model for a row of the ``repository`` table.

    A row consists of a URL, a commit identifier and a ``dirty`` flag.
    """

    __tablename__ = "repository"

    repository_id = sa.Column(sa.Integer, primary_key=True)
    url = sa.Column(sa.String(2048))
    commit = sa.Column(sa.String(40))
    dirty = sa.Column(sa.Boolean)

    @classmethod
    def get_or_create(cls, url, commit, dirty, session):
        """Return the stored row matching all three values, or a new instance.

        A newly built instance is returned without being added to ``session``.
        """
        fields = {"url": url, "commit": commit, "dirty": dirty}
        existing = session.query(cls).filter_by(**fields).first()
        return existing or cls(**fields)

    def to_json(self):
        """Return the row as a dict with ``url``, ``commit`` and ``dirty``."""
        return dict(url=self.url, commit=self.commit, dirty=self.dirty)
class Dependency(Base):
    """ORM model for a row of the ``dependency`` table (name plus version)."""

    __tablename__ = "dependency"

    dependency_id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String(32))
    version = sa.Column(sa.String(16))

    @classmethod
    def get_or_create(cls, dep, session):
        """Return the row for a ``"name==version"`` string, or a new instance.

        ``dep`` is split at the first ``"=="``; when the separator is absent
        the whole string becomes the name and the version is ``""``. A newly
        built instance is returned without being added to ``session``.
        """
        pieces = dep.partition("==")
        spec = {"name": pieces[0], "version": pieces[2]}
        found = session.query(cls).filter_by(**spec).first()
        return found or cls(**spec)

    def to_json(self):
        """Return the dependency formatted as ``"name==version"``."""
        return "{name}=={version}".format(name=self.name, version=self.version)
class Artifact(Base):
    """ORM model for a row of the ``artifact`` table.

    Holds a file's binary content together with a display name and a link
    to the owning run via ``run.run_id``.
    """

    __tablename__ = "artifact"

    artifact_id = sa.Column(sa.Integer, primary_key=True)
    filename = sa.Column(sa.String(64))
    content = sa.Column(sa.LargeBinary)

    run_id = sa.Column(sa.String(24), sa.ForeignKey("run.run_id"))
    # The backref exposes the reverse collection as ``Run.artifacts``.
    run = sa.orm.relationship("Run", backref=sa.orm.backref("artifacts"))

    @classmethod
    def create(cls, name, filename):
        """Build an artifact from the file at path ``filename``.

        The file is read in binary mode; ``name`` (not the path) is what gets
        stored in the ``filename`` column.
        """
        with open(filename, "rb") as handle:
            payload = handle.read()
        return cls(filename=name, content=payload)

    def to_json(self):
        """Return a dict with the artifact id (as ``_id``) and ``filename``."""
        return dict(_id=self.artifact_id, filename=self.filename)
class Resource(Base):
    """ORM model for a row of the ``resource`` table.

    A row stores a file name, the digest computed for it and the file's
    binary content.
    """

    __tablename__ = "resource"

    resource_id = sa.Column(sa.Integer, primary_key=True)
    filename = sa.Column(sa.String(256))
    md5sum = sa.Column(sa.String(32))
    content = sa.Column(sa.LargeBinary)

    @classmethod
    def get_or_create(cls, filename, session):
        """Return the stored row for this file and digest, or a new instance.

        The digest is computed with ``get_digest`` before the lookup. When no
        row matches, the file is read in binary mode and a new instance is
        returned without being added to ``session``.
        """
        digest = get_digest(filename)
        key = {"filename": filename, "md5sum": digest}
        existing = session.query(cls).filter_by(**key).first()
        if existing:
            return existing
        with open(filename, "rb") as handle:
            data = handle.read()
        return cls(content=data, **key)

    def to_json(self):
        """Return a dict with the ``filename`` and ``md5sum`` of this row."""
        return dict(filename=self.filename, md5sum=self.md5sum)
class Host(Base):
    """ORM model for a row of the ``host`` table (machine description)."""

    __tablename__ = "host"

    host_id = sa.Column(sa.Integer, primary_key=True)
    cpu = sa.Column(sa.String(64))
    hostname = sa.Column(sa.String(64))
    os = sa.Column(sa.String(16))
    os_info = sa.Column(sa.String(64))
    python_version = sa.Column(sa.String(16))

    @classmethod
    def get_or_create(cls, host_info, session):
        """Return the stored host matching ``host_info``, or a new instance.

        ``host_info`` must provide ``hostname``, ``cpu``, ``python_version``
        and an ``os`` sequence whose first two items populate the ``os`` and
        ``os_info`` columns. A newly built instance is returned without being
        added to ``session``.
        """
        fields = {
            "hostname": host_info["hostname"],
            "cpu": host_info["cpu"],
            "os": host_info["os"][0],
            "os_info": host_info["os"][1],
            "python_version": host_info["python_version"],
        }
        existing = session.query(cls).filter_by(**fields).first()
        if existing:
            return existing
        return cls(**fields)

    def to_json(self):
        """Return the host as a dict; ``os`` is the pair ``[os, os_info]``."""
        return dict(
            cpu=self.cpu,
            hostname=self.hostname,
            os=[self.os, self.os_info],
            python_version=self.python_version,
        )
# Link tables used as ``secondary=`` by the Experiment relationships
# (``sources``, ``repositories`` and ``dependencies``). Each one pairs an
# experiment id with the primary key of the related table.
experiment_source_association = sa.Table(
    "experiments_sources",
    Base.metadata,
    sa.Column(
        "experiment_id",
        sa.Integer,
        sa.ForeignKey("experiment.experiment_id"),
    ),
    sa.Column(
        "source_id",
        sa.Integer,
        sa.ForeignKey("source.source_id"),
    ),
)

experiment_repository_association = sa.Table(
    "experiments_repositories",
    Base.metadata,
    sa.Column(
        "experiment_id",
        sa.Integer,
        sa.ForeignKey("experiment.experiment_id"),
    ),
    sa.Column(
        "repository_id",
        sa.Integer,
        sa.ForeignKey("repository.repository_id"),
    ),
)

experiment_dependency_association = sa.Table(
    "experiments_dependencies",
    Base.metadata,
    sa.Column(
        "experiment_id",
        sa.Integer,
        sa.ForeignKey("experiment.experiment_id"),
    ),
    sa.Column(
        "dependency_id",
        sa.Integer,
        sa.ForeignKey("dependency.dependency_id"),
    ),
)
class Experiment(Base):
    """ORM model for a row of the ``experiment`` table.

    Besides its own columns, an experiment is linked to sources, repositories
    and dependencies through the three ``experiments_*`` association tables.
    """

    __tablename__ = "experiment"

    experiment_id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String(32))
    md5sum = sa.Column(sa.String(32))
    base_dir = sa.Column(sa.String(64))
    sources = sa.orm.relationship(
        "Source",
        secondary=experiment_source_association,
        backref="experiments",
    )
    repositories = sa.orm.relationship(
        "Repository",
        secondary=experiment_repository_association,
        backref="experiments",
    )
    dependencies = sa.orm.relationship(
        "Dependency",
        secondary=experiment_dependency_association,
        backref="experiments",
    )

    @classmethod
    def get_or_create(cls, ex_info, session):
        """Return the stored experiment matching ``ex_info``, or a new one.

        An experiment is identified by its name plus the MD5 hex digest of
        ``json.dumps(ex_info)``. When no row matches, the related
        dependencies, sources and repositories are resolved through their own
        ``get_or_create`` helpers and a new, not-yet-added instance is
        returned.

        ``ex_info`` must provide the keys ``name``, ``dependencies``,
        ``sources`` (pairs of file name and digest), ``base_dir`` and
        ``repositories`` (mappings with ``url``, ``commit`` and ``dirty``).
        """
        name = ex_info["name"]
        # The MD5 of the serialized ex_info decides whether this experiment
        # has been stored before.
        digest = hashlib.md5(json.dumps(ex_info).encode()).hexdigest()
        existing = session.query(cls).filter_by(name=name, md5sum=digest).first()
        if existing:
            return existing

        dependencies = []
        for dep in ex_info["dependencies"]:
            dependencies.append(Dependency.get_or_create(dep, session))

        sources = []
        for source_name, source_md5 in ex_info["sources"]:
            sources.append(
                Source.get_or_create(
                    source_name, source_md5, session, ex_info["base_dir"]
                )
            )

        # Repositories are added to the session as they are resolved and are
        # collected in a set so the same instance is only listed once.
        unique_repositories = set()
        for repo_info in ex_info["repositories"]:
            repo = Repository.get_or_create(
                repo_info["url"], repo_info["commit"], repo_info["dirty"], session
            )
            session.add(repo)
            unique_repositories.add(repo)

        return cls(
            name=name,
            dependencies=dependencies,
            sources=sources,
            repositories=list(unique_repositories),
            md5sum=digest,
            base_dir=ex_info["base_dir"],
        )

    def to_json(self):
        """Return the experiment and its related rows as a JSON-style dict."""
        return dict(
            name=self.name,
            base_dir=self.base_dir,
            sources=[source.to_json() for source in self.sources],
            repositories=[repo.to_json() for repo in self.repositories],
            dependencies=[dep.to_json() for dep in self.dependencies],
        )
# Link table used as ``secondary=`` by ``Run.resources``; pairs a run's
# ``run_id`` with the primary key of a resource.
run_resource_association = sa.Table(
    "runs_resources",
    Base.metadata,
    sa.Column(
        "run_id",
        sa.String(24),
        sa.ForeignKey("run.run_id"),
    ),
    sa.Column(
        "resource_id",
        sa.Integer,
        sa.ForeignKey("resource.resource_id"),
    ),
)
class Run(Base):
    """ORM model for a row of the ``run`` table.

    A run has an integer surrogate key (``id``) and a unique string
    ``run_id`` that other tables reference. It links to one host, one
    experiment, any number of resources and (through the ``Artifact.run``
    backref) its artifacts.
    """

    __tablename__ = "run"
    id = sa.Column(sa.Integer, primary_key=True)

    run_id = sa.Column(sa.String(24), unique=True)

    command = sa.Column(sa.String(64))

    # times
    start_time = sa.Column(sa.DateTime)
    heartbeat = sa.Column(sa.DateTime)
    stop_time = sa.Column(sa.DateTime)
    queue_time = sa.Column(sa.DateTime)

    # meta info
    priority = sa.Column(sa.Float)
    comment = sa.Column(sa.Text)

    fail_trace = sa.Column(sa.Text)

    # Captured out
    # TODO: move to separate table?
    captured_out = sa.Column(sa.Text)

    # Configuration & info
    # TODO: switch type to json if possible
    config = sa.Column(sa.Text)
    info = sa.Column(sa.Text)

    status = sa.Column(
        sa.Enum(
            "RUNNING",
            "COMPLETED",
            "INTERRUPTED",
            "TIMEOUT",
            "FAILED",
            name="status_enum",
        )
    )

    host_id = sa.Column(sa.Integer, sa.ForeignKey("host.host_id"))
    host = sa.orm.relationship("Host", backref=sa.orm.backref("runs"))

    experiment_id = sa.Column(sa.Integer, sa.ForeignKey("experiment.experiment_id"))
    experiment = sa.orm.relationship("Experiment", backref=sa.orm.backref("runs"))

    # artifacts = backref
    resources = sa.orm.relationship(
        "Resource", secondary=run_resource_association, backref="runs"
    )

    result = sa.Column(sa.Float)

    def to_json(self):
        """Return the run as a JSON-style dict.

        ``config`` is stored as JSON text; it is parsed and passed through
        ``restore`` before being returned. The ``info`` column is not part of
        the output.

        ``host``, ``experiment`` and ``config`` are all nullable, so a run
        that lacks any of them yields ``None`` for that key instead of
        raising ``AttributeError`` / ``TypeError``.
        """
        host = self.host.to_json() if self.host is not None else None
        experiment = (
            self.experiment.to_json() if self.experiment is not None else None
        )
        config = (
            restore(json.loads(self.config)) if self.config is not None else None
        )
        return {
            "_id": self.run_id,
            "command": self.command,
            "start_time": self.start_time,
            "heartbeat": self.heartbeat,
            "stop_time": self.stop_time,
            "queue_time": self.queue_time,
            "status": self.status,
            "result": self.result,
            "meta": {"comment": self.comment, "priority": self.priority},
            "resources": [r.to_json() for r in self.resources],
            "artifacts": [a.to_json() for a in self.artifacts],
            "host": host,
            "experiment": experiment,
            "config": config,
            "captured_out": self.captured_out,
            "fail_trace": self.fail_trace,
        }