Skip to content

Commit

Permalink
Add queryset save api (#37)
Browse files Browse the repository at this point in the history
* Add queryset save api

* Fix py35 errors and adapt ArtifactSet api
  • Loading branch information
JarnoRFB authored Apr 11, 2019
1 parent 6508af8 commit 2581127
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 109 deletions.
162 changes: 100 additions & 62 deletions demo.ipynb

Large diffs are not rendered by default.

18 changes: 11 additions & 7 deletions example_experiment/conduct.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# -*- coding: future_fstrings -*-
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import tensorflow as tf
from tensorflow.python.keras.callbacks import Callback

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FFMpegWriter
from sacred import Experiment
from sacred.observers import MongoObserver
from sacred.utils import apply_backspaces_and_linefeeds
from sklearn.metrics import confusion_matrix
from tensorflow.python.keras.callbacks import Callback


class MetricsLogger(Callback):
Expand All @@ -23,7 +25,7 @@ def on_epoch_end(self, epoch, logs):
self._run.log_scalar("training_acc", float(logs["acc"]), step=epoch)


def plot_confusion_matrix(confusion_matrix, class_names, figsize=(15, 12), fontsize=14):
def plot_confusion_matrix(confusion_matrix, class_names, figsize=(15, 12)):
"""Prints a confusion matrix, as returned by sklearn.metrics.confusion_matrix, as a heatmap.
Based on https://gist.github.com/shaypal5/94c53d765083101efc0240d776a23823
Expand Down Expand Up @@ -141,15 +143,17 @@ def conduct(epochs, optimizer, _run):

filename = "confusion_matrix.pdf"
fig.savefig(filename)
_run.add_artifact(filename, name="confusion_matrix_pdf")
_run.add_artifact(filename)

plot_accuracy_development(history, _run)
write_csv_as_text(history, _run)
scalar_results = model.evaluate(x_test, y_test, verbose=0)

filename = "model.hdf5"
model.save(filename)
_run.add_artifact(filename)

scalar_results = model.evaluate(x_test, y_test, verbose=0)
results = dict(zip(model.metrics_names, scalar_results))
print("Final test results")
print(results)
for metric, value in results.items():
_run.log_scalar(f"test_{metric}", value)

Expand Down
28 changes: 17 additions & 11 deletions incense/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pickle
import warnings
from copy import copy
from typing import *

import pandas as pd
from IPython import display
Expand All @@ -18,7 +19,7 @@ def __init__(self, name: str, file, content_type: str = None):
self.name = name
self.file = file
self.content_type = content_type
self.extension = None if self.content_type is None else self.content_type.split("/")[-1]
self.extension = "" if self.content_type is None else self.content_type.split("/")[-1]
self._content = None
self._rendered = None

Expand All @@ -43,12 +44,20 @@ def show(self):
)
return self.render()

def save(self, save_dir: str = ""):
"""Save artifact to disk."""
with open(os.path.join(save_dir, self._make_filename()), "wb") as file:
def save(self, to_dir: str = "") -> None:
"""
Save artifact to disk.
Args:
to_dir: Directory in which to save the artifact. Defaults to the current working directory.
"""
if to_dir:
os.makedirs(str(to_dir), exist_ok=True)
with open(os.path.join(str(to_dir), self._make_filename()), "wb") as file:
file.write(self.content)

def as_content_type(self, content_type):
def as_content_type(self, content_type) -> "Artifact":
"""Interpret artifact as being of content-type."""
try:
artifact_type = content_type_to_artifact_cls[content_type]
Expand All @@ -57,7 +66,7 @@ def as_content_type(self, content_type):
else:
return self.as_type(artifact_type)

def as_type(self, artifact_type):
def as_type(self, artifact_type) -> "Artifact":
self.file.seek(0)
return artifact_type(self.name, self.file)

Expand All @@ -69,8 +78,8 @@ def content(self):
return self._content

def _make_filename(self):
parts = self.file.filename.split("/")
return f"{parts[-2]}_{parts[-1]}.{self.extension}"
exp_id, artifact_name = self.file.filename.split("/")[-2:]
return f"{exp_id}_{artifact_name}" + ("" if artifact_name.endswith(self.extension) else f".{self.extension}")


class ImageArtifact(Artifact):
Expand Down Expand Up @@ -133,9 +142,6 @@ class PDFArtifact(Artifact):

content_type_to_artifact_cls = {}
for cls in copy(locals()).values():
# print(cls)
if isinstance(cls, type) and issubclass(cls, Artifact):
for content_type in cls.can_render:
content_type_to_artifact_cls[content_type] = cls

# print(content_type_to_artifact_cls)
41 changes: 41 additions & 0 deletions incense/query_set.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: future_fstrings -*-
from collections import OrderedDict, UserList, defaultdict
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from fnmatch import fnmatch
from functools import reduce
from typing import *

Expand Down Expand Up @@ -87,3 +89,42 @@ def _get(self, o, name):
return getattr(o, name)
except AttributeError:
return o[name]

@property
def artifacts(self):
return ArtifactIndexer(self)


class ArtifactIndexer:
def __init__(self, experiments: QuerySet):
self._experiments = experiments

def filter(self, pattern):
"""
Get all artifacts that match a name of pattern.
This method does not indicate whether the requested artifacts could be found
only on some artifacts.
Args:
pattern: glob pattern, that is matched against artifact name.
Returns:
"""
return ArtifactSet(
artifact
for exp in self._experiments
for artifact_name, artifact in exp.artifacts.items()
if fnmatch(artifact_name, pattern)
)

def __getitem__(self, item):
return ArtifactSet(exp.artifacts[item] for exp in self._experiments)


class ArtifactSet(UserList):
def save(self, to_dir, n_threads=None):
with ThreadPoolExecutor(max_workers=n_threads) as executer:
for artifact in self.data:
executer.submit(artifact.save, to_dir=to_dir)
60 changes: 31 additions & 29 deletions tests/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@ def test_png_artifact_render(loader):
assert isinstance(png_artifact.render(), IPython.core.display.Image)


def test_png_artifact_save(loader):
def test_png_artifact_save(loader, tmpdir):
exp = loader.find_by_id(3)
exp.artifacts["confusion_matrix"].save()
filename = "3_confusion_matrix.png"
assert os.path.isfile(filename)
assert imghdr.what(filename) == "png"
os.remove(filename)
exp.artifacts["confusion_matrix"].save(to_dir=tmpdir)
filepath = str(tmpdir / "3_confusion_matrix.png")
assert os.path.isfile(filepath)
assert imghdr.what(filepath) == "png"


def test_csv_artifact_render(loader):
Expand All @@ -43,20 +42,19 @@ def test_csv_artifact_render(loader):
assert isinstance(csv_artifact.render(), pd.DataFrame)


def test_csv_artifact_show_warning(loader):
def test_csv_artifact_render_warning(loader):
exp = loader.find_by_id(3)
csv_artifact = exp.artifacts["predictions"]
assert isinstance(csv_artifact, artifact.CSVArtifact)
with pytest.deprecated_call():
assert isinstance(csv_artifact.show(), pd.DataFrame)


def test_mp4_artifact_save(loader):
def test_mp4_artifact_save(loader, tmpdir):
exp = loader.find_by_id(2)
exp.artifacts["accuracy_movie"].save()
filename = "2_accuracy_movie.mp4"
assert os.path.isfile(filename)
os.remove(filename)
exp.artifacts["accuracy_movie"].save(to_dir=tmpdir)
filepath = str(tmpdir / "2_accuracy_movie.mp4")
assert os.path.isfile(filepath)


def test_mp4_artifact_render(loader):
Expand All @@ -67,13 +65,12 @@ def test_mp4_artifact_render(loader):
os.remove("2_accuracy_movie.mp4")


def test_csv_artifact_save(loader):
def test_csv_artifact_save(loader, tmpdir):
exp = loader.find_by_id(3)
exp.artifacts["predictions"].save()
filename = "3_predictions.csv"
assert os.path.isfile(filename)
assert isinstance(pd.read_csv(filename), pd.DataFrame)
os.remove(filename)
exp.artifacts["predictions"].save(to_dir=tmpdir)
filepath = str(tmpdir / "3_predictions.csv")
assert os.path.isfile(filepath)
assert isinstance(pd.read_csv(filepath), pd.DataFrame)


def test_pickle_artifact_render(loader):
Expand All @@ -85,23 +82,21 @@ def test_pickle_artifact_render(loader):
assert isinstance(pickle_artifact.render(), pd.DataFrame)


def test_pickle_artifact_save(loader):
def test_pickle_artifact_save(loader, tmpdir):
exp = loader.find_by_id(3)
pickle_artifact = exp.artifacts["predictions_df"].as_type(artifact.PickleArtifact)
pickle_artifact.save()
filename = "3_predictions_df.pickle"
pickle_artifact.save(to_dir=tmpdir)
filename = str(tmpdir / "3_predictions_df.pickle")
assert os.path.isfile(filename)
assert isinstance(pickle.load(open(filename, "rb")), pd.DataFrame)
os.remove(filename)


def test_pdf_artifact_save(loader):
def test_pdf_artifact_save(loader, tmpdir):
exp = loader.find_by_id(2)
pdf_artifact = exp.artifacts["confusion_matrix_pdf"]
pdf_artifact.save()
filename = "2_confusion_matrix_pdf.pdf"
assert os.path.isfile(filename)
os.remove(filename)
pdf_artifact = exp.artifacts["confusion_matrix.pdf"]
pdf_artifact.save(to_dir=tmpdir)
filepath = str(tmpdir / "2_confusion_matrix.pdf")
assert os.path.isfile(filepath)


def test_as_type(loader):
Expand Down Expand Up @@ -129,7 +124,14 @@ def test_as_content_type_with_unkwown_content_type(loader):
text_artifact_as_something_strange = exp.artifacts["history"].as_content_type("something/strange")


def test_artifact_render(loader):
def test_artifact_render_with_unknown_content_type(loader):
exp = loader.find_by_id(3)
with raises(NotImplementedError):
exp.artifacts["predictions_df"].render()


def test_artifact_save_with_unknown_content_type(loader, tmpdir):
exp = loader.find_by_id(3)
exp.artifacts["predictions_df"].save(to_dir=tmpdir)
filepath = str(tmpdir / "3_predictions_df")
assert os.path.isfile(filepath)
22 changes: 22 additions & 0 deletions tests/test_set_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# -*- coding: future_fstrings -*-
import os


def test_single_save(loader, tmpdir):
exp_ids = [1, 2, 3]
exps = loader.find_by_ids(exp_ids)
exps.artifacts["confusion_matrix"].save(to_dir=tmpdir)
for exp_id in exp_ids:
filepath = tmpdir / f"{exp_id}_confusion_matrix.png"
assert os.path.isfile(str(filepath))


def test_glob_save(loader, tmpdir):
exp_ids = [1, 2, 3]
exps = loader.find_by_ids(exp_ids)
exps.artifacts.filter("confusion_matrix*").save(to_dir=tmpdir)
for exp_id in exp_ids:
filepath = str(tmpdir / f"{exp_id}_confusion_matrix.png")
assert os.path.isfile(filepath)
filepath = str(tmpdir / f"{exp_id}_confusion_matrix.pdf")
assert os.path.isfile(filepath)

0 comments on commit 2581127

Please sign in to comment.