diff --git a/ml_logger/README b/ml_logger/README
index b1f7945..d41c88b 100644
--- a/ml_logger/README
+++ b/ml_logger/README
@@ -75,12 +75,12 @@ The logging server uses the ``sanic`` framework, which means defaults to
 ``sanic``, such as maximum request size are carried over. When
 using ``ml-logger`` to save and load **very
-large** ``pytorch`` checkpoints, you need  to raise ``sanic``\ ’s
+large** ``pytorch`` checkpoints, you need to raise \ ``sanic``\ ’s
 default request size limit from 100MB to something like a
 gigabyte or even larger.
 
 The file upload is done using multi-part form upload, where
 each query is kept small. However sanic will throw if the
 overall size of the query exceeds this
-parameter ``SANIC_REQUEST_MAX_SIZE=1000_000_000``. The default is
+parameter \ ``SANIC_REQUEST_MAX_SIZE=1000_000_000``. The default is
 ``100_000_000``, or 100MB.
 
 Use ssh tunnel if you are running on a managed cluster.
diff --git a/ml_logger/VERSION b/ml_logger/VERSION
index c006218..e7c7d3c 100644
--- a/ml_logger/VERSION
+++ b/ml_logger/VERSION
@@ -1 +1 @@
-0.7.6
+0.7.8
diff --git a/ml_logger/ml_logger/__init__.py b/ml_logger/ml_logger/__init__.py
index 6a80c25..5cab1f9 100644
--- a/ml_logger/ml_logger/__init__.py
+++ b/ml_logger/ml_logger/__init__.py
@@ -1,5 +1,6 @@
-from .struts import ALLOWED_TYPES
-from .log_client import LogClient
-from .helpers.print_utils import PrintHelper
 from .caches.summary_cache import SummaryCache
+from .helpers.print_utils import PrintHelper
+from .log_client import LogClient
 from .ml_logger import *
+from .struts import ALLOWED_TYPES
+
diff --git a/ml_logger/ml_logger/log_client.py b/ml_logger/ml_logger/log_client.py
index 4ce8c9f..f2e279a 100644
--- a/ml_logger/ml_logger/log_client.py
+++ b/ml_logger/ml_logger/log_client.py
@@ -45,7 +45,7 @@ class LogClient:
     sync_pool = None
     async_pool = None
 
-    def __init__(self, url: str = None, asynchronous=None, max_workers=None):
+    def __init__(self, root: str = None, user=None, access_token=None, asynchronous=None, max_workers=None):
         """
         When max_workers is 0, the HTTP requests are synchronous. This allows one
         to make synchronous requests procedurally.
@@ -54,7 +54,7 @@ def __init__(self, url: str = None, asynchronous=None, max_workers=None):
         Mujoco-py for example, would have trouble with forked processes if multiple
         threads are started before forking the subprocesses.
 
-        :param url:
+        :param root: the server address (e.g. http://host:port), or a local directory
         :param asynchronous: If this is not None, we create a request pool. This way
             we can use the (A)SyncContext call right after construction.
         :param max_workers:
@@ -62,16 +62,17 @@
         if asynchronous is not None:
             self.set_session(asynchronous, max_workers)
 
-        if url.startswith("file://"):
-            self.local_server = LoggingServer(data_dir=url[6:], silent=True)
-        elif os.path.isabs(url):
-            self.local_server = LoggingServer(data_dir=url, silent=True)
-        elif url.startswith('http://'):
+        if root.startswith("file://"):
+            self.local_server = LoggingServer(cwd=root[6:], silent=True)
+        elif os.path.isabs(root):
+            self.local_server = LoggingServer(cwd=root, silent=True)
+        elif root.startswith('http://'):
             self.local_server = None  # remove local server to use sessions.
-        self.url = url
-        self.stream_url = os.path.join(url, "stream")
-        self.ping_url = os.path.join(url, "ping")
-        self.glob_url = os.path.join(url, "glob")
+        self.url = os.path.join(root, user)
+        self.access_token = access_token
+        self.stream_url = os.path.join(root, user, "stream")
+        self.ping_url = os.path.join(root, user, "ping")
+        self.glob_url = os.path.join(root, user, "glob")
         # when setting sessions the first time, default to use Asynchronous Session.
         if self.session is None:
             asynchronous = True if asynchronous is None else asynchronous
diff --git a/ml_logger/ml_logger/ml_logger.py b/ml_logger/ml_logger/ml_logger.py
index a607901..8eaeb7e 100644
--- a/ml_logger/ml_logger/ml_logger.py
+++ b/ml_logger/ml_logger/ml_logger.py
@@ -24,6 +24,15 @@
 from .helpers.print_utils import PrintHelper
 from .log_client import LogClient
 
+# environment defaults; fall back gracefully when $PWD or $USER is not set
+CWD = os.environ.get("PWD", os.getcwd())
+USER = os.environ.get("USER", None)
+
+# ML_Logger defaults
+ROOT = os.environ.get("ML_LOGGER_ROOT", CWD)
+USER = os.environ.get("ML_LOGGER_USER", USER)
+ACCESS_TOKEN = os.environ.get("ML_LOGGER_ACCESS_TOKEN", None)
+
 
 def pJoin(*args):
     from os.path import join
@@ -198,7 +207,9 @@ def __repr__(self):
     # noinspection PyInitNewSignature
     # todo: use prefixes as opposed to prefix. (add *prefixae after prefix=None)
     # todo: resolve path segment with $env variables.
-    def __init__(self, root_dir: str = None, prefix=None, *prefixae, buffer_size=2048, max_workers=None,
+    def __init__(self, prefix="", *prefixae,
+                 log_dir=ROOT, user=USER, access_token=ACCESS_TOKEN,
+                 buffer_size=2048, max_workers=None,
                  asynchronous=None, summary_cache_opts: dict = None):
         """ logger constructor.
 
@@ -213,8 +224,11 @@ def __init__(self, root_dir: str = None, prefix=None, *prefixae, buffer_size=204
         | 1. prefix="causal_infogan" => logs to "/tmp/some_dir/causal_infogan"
         | 2. prefix="" => logs to "/tmp/some_dir"
 
-        :param root_dir: the server host and port number
         :param prefix: the prefix path
+        :param *prefixae: the rest of the prefix segments
+        :param log_dir: the log server address (e.g. http://host:port), or a local directory
+        :param user: the user name; defaults to the environment variable $ML_LOGGER_USER
+        :param access_token: defaults to the environment variable $ML_LOGGER_ACCESS_TOKEN
         :param asynchronous: When this is not None, we create a http thread pool.
         :param buffer_size: The string buffer size for the print buffer.
         :param max_workers: the number of request-session workers for the async http requests.
@@ -236,18 +250,19 @@ def __init__(self, root_dir: str = None, prefix=None, *prefixae, buffer_size=204
         self.summary_caches = defaultdict(partial(SummaryCache, **(summary_cache_opts or {})))
 
         # todo: add https support
-        self.root_dir = interpolate(root_dir) or "/"
-        self.prefix = interpolate(prefix) or os.getcwd()[1:]
-        if prefix is not None:
-            self.prefix = os.path.join(*[interpolate(p) for p in (prefix, *prefixae) if p is not None])
+        self.root_dir = interpolate(log_dir) or ROOT
 
-        # logger client contains thread pools, should not be re-created lightly.
-        self.client = LogClient(url=self.root_dir, asynchronous=asynchronous, max_workers=max_workers)
+        prefixae = [interpolate(p) for p in (prefix or "", *prefixae) if p is not None]
+        self.prefix = os.path.join(*prefixae) if prefixae else ""
+        self.client = LogClient(root=self.root_dir, user=user, access_token=access_token,
+                                asynchronous=asynchronous, max_workers=max_workers)
 
     def configure(self,
-                  root_dir: str = None,
                   prefix=None,
                   *prefixae,
+                  log_dir: str = None,
+                  user=None,
+                  access_token=None,
                   asynchronous=None,
                   max_workers=None,
                   buffer_size=None,
@@ -293,8 +308,11 @@ def configure,
         todo: the table at the moment seems a bit verbose. I'm considering
         making this just a single line print.
 
-        :param log_directory:
-        :param prefix:
+        :param prefix: the first prefix segment
+        :param *prefixae: a list of additional prefix segments
+        :param log_dir: the log server address, or a local root directory
+        :param user: overrides the environment variable $ML_LOGGER_USER
+        :param access_token: overrides the environment variable $ML_LOGGER_ACCESS_TOKEN
         :param buffer_size:
         :param summary_cache_opts:
         :param asynchronous:
@@ -305,9 +323,11 @@
         """
         # path logic
-        root_dir = interpolate(root_dir) or os.getcwd()
+        log_dir = interpolate(log_dir) or os.getcwd()
         if prefix is not None:
-            self.prefix = os.path.join(*[interpolate(p) for p in (prefix, *prefixae) if p is not None])
+            prefixae = [interpolate(p) for p in (prefix, *prefixae) if p is not None]
+            if prefixae:  # a list is never None; only join when segments remain
+                self.prefix = os.path.join(*prefixae)
 
         if buffer_size is not None:
             self.print_buffer_size = buffer_size
@@ -318,17 +338,18 @@
         self.summary_caches.clear()
         self.summary_caches = defaultdict(partial(SummaryCache, **(summary_cache_opts or {})))
 
-        if root_dir != self.root_dir or asynchronous is not None or max_workers is not None:
-            # note: logger.configure shouldn't be called too often, so it is okay to assume
-            # that we can discard the old logClient.
-            # To quickly switch back and forth between synchronous and asynchronous calls,
-            # use the `SyncContext` and `AsyncContext` instead.
+        if log_dir:
+            self.root_dir = interpolate(log_dir) or ROOT
+        if log_dir or asynchronous is not None or max_workers is not None:
+            # note: logger.configure shouldn't be called too often. To quickly switch back
+            # and forth between synchronous and asynchronous calls, use the `SyncContext`
+            # and `AsyncContext` instead.
            if not silent:
-                cprint('creating new logging client...', color='yellow', end=' ')
-            self.root_dir = root_dir
-            self.client.__init__(url=self.root_dir, asynchronous=asynchronous, max_workers=max_workers)
+                cprint('creating new logging client...', color='yellow', end='\r')
+            self.client.__init__(root=self.root_dir, user=user, access_token=access_token,
+                                 asynchronous=asynchronous, max_workers=max_workers)
             if not silent:
-                cprint('✓ done', color="green")
+                cprint('✓ created a new logging client', color="green")
 
         if not silent:
             from urllib.parse import quote
diff --git a/ml_logger/ml_logger/server.py b/ml_logger/ml_logger/server.py
index 4b6144d..d0f1b3b 100644
--- a/ml_logger/ml_logger/server.py
+++ b/ml_logger/ml_logger/server.py
@@ -12,14 +12,20 @@ class LoggingServer:
     silent = None
 
-    def __init__(self, data_dir, silent=False):
-        assert os.path.isabs(data_dir)
-        self.data_dir = data_dir
-        os.makedirs(data_dir, exist_ok=True)
+    def abs_path(self, key):
+        log_dir = os.path.join(self.cwd, key)
+        return os.path.join(self.root, log_dir[1:])
+
+    def __init__(self, cwd="/", root="/", silent=False):
+        assert os.path.isabs(root)
+        assert os.path.isabs(cwd)
+        self.root = root
+        self.cwd = cwd
+        os.makedirs(root, exist_ok=True)
         self.silent = silent
         if not silent:
-            print('logging data to {}'.format(data_dir))
+            cprint(f'logging data to {root}', 'green')
 
     configure = __init__
 
@@ -43,7 +49,7 @@ async def stream_handler(self, req):
             return sanic.response.text(msg)
         load_entry = LoadEntry(**req.json)
         print(f"streaming: {load_entry.key}")
-        path = os.path.join(self.data_dir, load_entry.key)
+        path = os.path.join(self.root, load_entry.key)
         return await sanic.response.file_stream(path)
 
     async def ping_handler(self, req):
@@ -141,8 +147,9 @@ def glob(self, query, wd, recursive, start, stop):
         """
         Glob under the work directory. (so that the wd is not included in the
         file paths that are returned.)
 
-        :param query: we remove the leading slash so that //home/directory allows you to access absolute path of the server host environment. single leanding slash accesses w.r.t. the data_dir.
-        :param wd: we remove the leading slash so that //home/directory allows you to access absolute path of the server host environment. single leanding slash accesses w.r.t. the data_dir.
+        :param query: the glob pattern, resolved relative to wd
+        :param wd: Use a double slash (//home/directory) to access the absolute path on the
+            server host. A single leading slash accesses paths relative to the data_dir.
         :param recursive:
         :param start:
         :param stop:
@@ -153,14 +160,11 @@ def glob(self, query, wd, recursive, start, stop):
         from itertools import islice
         from ml_logger.helpers.file_helpers import CwdContext
 
-        wd = wd[1:] if wd and wd.startswith("/") else wd
-        query = query[1:] if query and query.startswith("/") else query
         try:
-            with CwdContext(os.path.join(self.data_dir, wd or "")):
-                file_paths = list(islice(iglob(query, recursive=recursive), start, stop))
-                return file_paths
+            with CwdContext(self.abs_path(wd or "")):
+                return list(islice(iglob(query, recursive=recursive), start, stop))
         except PermissionError:
-            print('PermissionError:', os.path.join(self.data_dir, wd or ""))
+            print('PermissionError:', os.path.join(self.root, wd or ""))
             return None
         except FileNotFoundError:
             return None
@@ -192,8 +196,7 @@ def load(self, key, dtype, start=None, stop=None):
         :param stop: start index
         :return: None, or a tuple of each one of the data chunks logged into the file.
""" - key = key[1:] if key.startswith("/") else key - abs_path = os.path.join(self.data_dir, key) + abs_path = self.abs_path(key) if dtype == 'byte': try: return list(load_from_file(abs_path))[start:stop] @@ -240,7 +243,7 @@ def remove(self, key): :param key: the path from the logging directory. :return: None """ - abs_path = os.path.join(self.data_dir, key) + abs_path = self.abs_path(key) try: os.remove(abs_path) except FileNotFoundError as e: @@ -253,31 +256,25 @@ def copy(self, src, target): import shutil assert isinstance(src, str), "src needs to be a string" - if target.startswith('/'): - target = target[1:] - if src.startswith('/'): - src = src[1:] + abs_target = self.abs_path(target) + abs_src = self.abs_path(src) - abs_target = os.path.join(self.data_dir, target) - abs_src = os.path.join(self.data_dir, src) os.makedirs(os.path.dirname(abs_target), exist_ok=True) shutil.copyfile(abs_src, abs_target, follow_symlinks=True) return target def save_buffer(self, key, buff): - assert isinstance(src, BytesIO), f"buff needs to be a BytesIO object." - if target.startswith('/'): - target = target[1:] + assert isinstance(buff, BytesIO), f"buff needs to be a BytesIO object." - abs_target = os.path.join(self.data_dir, target) + abs_path = self.abs_path(key) - with open(abs_target, 'wb') as t: + with open(abs_path, 'wb') as t: while True: - content = src.read() + content = buff.read() if content == b"": break t.write(content) - return target + return key def log(self, key, data, dtype, options: LogOptions = None): """ @@ -291,36 +288,31 @@ def log(self, key, data, dtype, options: LogOptions = None): """ # todo: overwrite mode is not tested and not in-use. write_mode = "w" if options and options.overwrite else "a" - if key.startswith('/'): - key = key[1:] + abs_path = self.abs_path(key) + parent_dir = os.path.dirname(abs_path) + # fixme: There is a race condition with multiple requests if dtype == "log": - abs_path = os.path.join(self.data_dir, key) - # fixme: There is a race condition with multiple requests try: with open(abs_path, write_mode + 'b') as f: dill.dump(data, f) except FileNotFoundError: - os.makedirs(os.path.dirname(abs_path), exist_ok=True) + os.makedirs(parent_dir, exist_ok=True) with open(abs_path, write_mode + 'b') as f: dill.dump(data, f) if dtype == "byte": - abs_path = os.path.join(self.data_dir, key) - # fixme: There is a race condition with multiple requests try: with open(abs_path, write_mode + 'b') as f: f.write(data) except FileNotFoundError: - os.makedirs(os.path.dirname(abs_path), exist_ok=True) + os.makedirs(parent_dir, exist_ok=True) with open(abs_path, write_mode + 'b') as f: f.write(data) elif dtype.startswith("text"): - abs_path = os.path.join(self.data_dir, key) - # fixme: There is a race condition with multiple requests try: with open(abs_path, write_mode + "+") as f: f.write(data) except FileNotFoundError: - os.makedirs(os.path.dirname(abs_path), exist_ok=True) + os.makedirs(parent_dir, exist_ok=True) with open(abs_path, write_mode + "+") as f: f.write(data) elif dtype.startswith("yaml"): @@ -338,7 +330,6 @@ def log(self, key, data, dtype, options: LogOptions = None): stream = StringIO() yaml.dump(data, stream) output = stream.getvalue() - abs_path = os.path.join(self.data_dir, key) try: with open(abs_path, write_mode + "+") as f: if options.write_mode == 'key': @@ -348,7 +339,7 @@ def log(self, key, data, dtype, options: LogOptions = None): output = d f.write(output) except FileNotFoundError: - os.makedirs(os.path.dirname(abs_path), exist_ok=True) + 
os.makedirs(parent_dir, exist_ok=True) with open(abs_path, write_mode + "+") as f: if options.write_mode == 'key': d = load_fn('\n'.join(f)) @@ -357,7 +348,7 @@ def log(self, key, data, dtype, options: LogOptions = None): output = d f.write(output) elif dtype.startswith("image"): - abs_path = os.path.join(self.data_dir, key) + abs_path = self.abs_path(key) if "." not in key: abs_path = abs_path + ".png" from PIL import Image @@ -368,7 +359,7 @@ def log(self, key, data, dtype, options: LogOptions = None): try: im.save(abs_path) except FileNotFoundError: - os.makedirs(os.path.dirname(abs_path), exist_ok=True) + os.makedirs(parent_dir, exist_ok=True) im.save(abs_path) @@ -390,5 +381,5 @@ class Params: v = pkg_resources.get_distribution("ml_logger").version print('running ml_logger.server version {}'.format(v)) - server = LoggingServer(data_dir=Params.logdir) + server = LoggingServer(root=Params.logdir) server.serve(host=Params.host, port=Params.port, workers=Params.workers) diff --git a/ml_logger/ml_logger_tests/test_ml_logger.py b/ml_logger/ml_logger_tests/test_ml_logger.py index 640b180..e5d75f0 100644 --- a/ml_logger/ml_logger_tests/test_ml_logger.py +++ b/ml_logger/ml_logger_tests/test_ml_logger.py @@ -36,7 +36,7 @@ def log_dir(request): @pytest.fixture(scope="session") def setup(log_dir): - logger.configure(log_dir, prefix='main_test_script') + logger.configure('main_test_script', log_dir=log_dir) logger.remove('') print(f"logging to {pathJoin(logger.root_dir, logger.prefix)}") @@ -65,11 +65,12 @@ def test_log_data(setup): def test_save_pkl_abs_path(setup): import numpy + d1 = numpy.random.randn(20, 10) - logger.save_pkl(d1, '/test_file_1.pkl') + logger.save_pkl(d1, "/tmp/ml-logger-test/test_file_1.pkl") sleep(0.1) - data = logger.load_pkl('/test_file_1.pkl') + data = logger.load_pkl("/tmp/ml-logger-test/test_file_1.pkl") assert len(data) == 1, "data should contain only one array because we overwrote it." assert numpy.array_equal(data[0], d1), "first should be the same as d2" @@ -135,8 +136,8 @@ def test_json(setup): def test_json_abs(setup): a = dict(a=0) - logger.save_json(dict(a=0), "/data/d.json") - b = logger.load_json("/data/d.json") + logger.save_json(dict(a=0), "/tmp/ml-logger-test/data/d.json") + b = logger.load_json("/tmp/ml-logger-test/data/d.json") assert a == b, "a and b should be the same" @@ -218,8 +219,9 @@ def im(x, y): frames = [im(100 + i, 80) for i in range(20)] logger.save_video(frames, "test_video.gif") - logger.save_video(frames, "/videos/test_video.gif") - assert '/videos/test_video.gif' in logger.glob('/videos/*.gif') + assert 'test_video.gif' in logger.glob('*.gif') + logger.save_video(frames, "/tmp/ml-logger-test/videos/test_video.gif") + assert 'ml-logger-test/videos/test_video.gif' in logger.glob('**/videos/*.gif', wd="/tmp") def test_load_params(setup): diff --git a/ml_logger/ml_logger_tests/test_prefix_cases.py b/ml_logger/ml_logger_tests/test_prefix_cases.py new file mode 100644 index 0000000..9e63d05 --- /dev/null +++ b/ml_logger/ml_logger_tests/test_prefix_cases.py @@ -0,0 +1,51 @@ +import os + +""" +Details of the logging prefix + +root = $PWD, prefix = "" -> log to $PWD locally. + -> log to server $LOGDIR remotely. +root = $PWD, prefix = "/" -> log to "/" locally. + -> log to server $LOGDIR remotely. (removes the '/') +logger client contains thread pools, should not be re-created lightly. 
+""" + + +def test_local_relative(): + root = "/" + cwd = "/cwd" + prefix = "wat" + log_dir = os.path.join(cwd, prefix) + assert os.path.join(root, log_dir[1:]) == "/cwd/wat" + + +def test_local_absolute(): + root = "/" + cwd = "/cwd" + prefix = "/wat" + log_dir = os.path.join(cwd, prefix) + assert os.path.join(root, log_dir[1:]) == "/wat" + + +def test_server_relative(): + root = "/data_dir" + cwd = "/" + prefix = "wat" + log_dir = os.path.join(cwd, prefix) + assert os.path.join(root, log_dir[1:]) == "/data_dir/wat" + + +def test_server_absolute(): + root = "/data_dir" + cwd = "/" + prefix = "/wat" + log_dir = os.path.join(cwd, prefix) + assert os.path.join(root, log_dir[1:]) == "/data_dir/wat" + + +def test_server_system_absolute(): + root = "/" + cwd = "/" + prefix = "//wat" + log_dir = os.path.join(cwd, prefix) + assert os.path.join(root, log_dir[1:]) == "/wat"