Skip to content

Commit

Permalink
feat: Add support for loading / writing from files (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xbe7a authored Aug 23, 2024
1 parent 1982224 commit 9461551
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "slim-trees"
description = "A python package for efficient pickling of ML models."
version = "0.2.11"
version = "0.2.12"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.8"
Expand Down
18 changes: 11 additions & 7 deletions slim_trees/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import importlib.metadata
import warnings
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, BinaryIO, Optional, Union

from slim_trees.pickling import (
dump_compressed,
Expand Down Expand Up @@ -42,22 +42,24 @@


def dump_sklearn_compressed(
    model: Any,
    file: Union[str, Path, BinaryIO],
    compression: Optional[Union[str, dict]] = None,
):
    """
    Pickle a scikit-learn model in slimmed form and write it out compressed.

    Tree parameters are stored as int16 and float32 instead of int64 and
    float64, shrinking the payload before compression is applied.
    :param model: the model to save
    :param file: where to save the model, either a path or a file object
    :param compression: the compression method used. Either a string or a dict
                        with key 'method' set to the compression method; any
                        other key-value pairs are forwarded to `open` of the
                        compression library.
                        Options: ["no", "lzma", "gzip", "bz2"]
    """
    # Imported lazily so scikit-learn is only required when this entry point is used.
    from slim_trees.sklearn_tree import dump as sklearn_dump

    dump_compressed(model, file, compression, sklearn_dump)


def dumps_sklearn_compressed(
Expand All @@ -79,22 +81,24 @@ def dumps_sklearn_compressed(


def dump_lgbm_compressed(
    model: Any,
    file: Union[str, Path, BinaryIO],
    compression: Optional[Union[str, dict]] = None,
):
    """
    Pickle a LightGBM model in slimmed form and write it out compressed.

    Tree parameters are stored as int16 and float32 instead of int64 and
    float64, shrinking the payload before compression is applied.
    :param model: the model to save
    :param file: where to save the model, either a path or a file object
    :param compression: the compression method used. Either a string or a dict
                        with key 'method' set to the compression method; any
                        other key-value pairs are forwarded to `open` of the
                        compression library.
                        Options: ["no", "lzma", "gzip", "bz2"]
    """
    # Imported lazily so LightGBM is only required when this entry point is used.
    from slim_trees.lgbm_booster import dump as lgbm_dump

    dump_compressed(model, file, compression, lgbm_dump)


def dumps_lgbm_compressed(
Expand Down
34 changes: 17 additions & 17 deletions slim_trees/pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pathlib
import pickle
from collections.abc import Callable
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union


class _NoCompression:
Expand Down Expand Up @@ -51,7 +51,7 @@ def _get_default_kwargs(compression_method: str) -> Dict[str, Any]:

def _unpack_compression_args(
compression: Optional[Union[str, Dict[str, Any]]] = None,
path: Optional[Union[str, pathlib.Path]] = None,
file: Optional[Union[str, pathlib.Path, BinaryIO]] = None,
) -> Tuple[str, dict]:
if compression is not None:
if isinstance(compression, str):
Expand All @@ -61,24 +61,24 @@ def _unpack_compression_args(
k: compression[k] for k in compression if k != "method"
}
raise ValueError("compression must be either a string or a dict")
if path is not None:
if file is not None and isinstance(file, (str, pathlib.Path)):
# try to find out the compression using the file extension
compression_method = _get_compression_from_path(path)
compression_method = _get_compression_from_path(file)
return compression_method, _get_default_kwargs(compression_method)
raise ValueError("path or compression must not be None.")
raise ValueError("File must be a path or compression must not be None.")


def dump_compressed(
obj: Any,
path: Union[str, pathlib.Path],
file: Union[str, pathlib.Path, BinaryIO],
compression: Optional[Union[str, dict]] = None,
dump_function: Optional[Callable] = None,
):
"""
Pickles a model and saves it to the disk. If compression is not specified,
the compression method will be determined by the file extension.
:param obj: the object to pickle
:param path: where to save the object
:param file: where to save the object, either a path or a file object
:param compression: the compression method used. Either a string or a dict with key 'method' set
to the compression method and other key-value pairs are forwarded to open()
of the compression library.
Expand All @@ -89,11 +89,11 @@ def dump_compressed(
if dump_function is None:
dump_function = pickle.dump

compression_method, kwargs = _unpack_compression_args(compression, path)
compression_method, kwargs = _unpack_compression_args(compression, file)
with _get_compression_library(compression_method).open(
path, mode="wb", **kwargs
) as file:
dump_function(obj, file)
file, mode="wb", **kwargs
) as fd:
dump_function(obj, fd)


def dumps_compressed(
Expand Down Expand Up @@ -124,13 +124,13 @@ def dumps_compressed(


def load_compressed(
path: Union[str, pathlib.Path],
file: Union[str, pathlib.Path, BinaryIO],
compression: Optional[Union[str, dict]] = None,
unpickler_class: type = pickle.Unpickler,
) -> Any:
"""
Loads a compressed model.
:param path: where to load the object from
:param file: where to load the object from, either a path or a file object
:param compression: the compression method used. Either a string or a dict with key 'method'
set to the compression method and other key-value pairs which are forwarded
to open() of the compression library.
Expand All @@ -139,11 +139,11 @@ def load_compressed(
This is useful to restrict possible imports or to allow unpickling
when required module or function names have been refactored.
"""
compression_method, kwargs = _unpack_compression_args(compression, path)
compression_method, kwargs = _unpack_compression_args(compression, file)
with _get_compression_library(compression_method).open(
path, mode="rb", **kwargs
) as file:
return unpickler_class(file).load()
file, mode="rb", **kwargs
) as fd:
return unpickler_class(fd).load()


def loads_compressed(
Expand Down
4 changes: 2 additions & 2 deletions slim_trees/sklearn_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def dumps(model: Any) -> bytes:
def _tree_pickle(tree: Tree):
    """Custom pickle reducer for sklearn ``Tree`` objects.

    Swaps the tree's pickled state for a compressed representation, tagged
    with the slim-trees version so it can be validated when unpickling.
    """
    assert isinstance(tree, Tree)
    reconstructor, args, state = tree.__reduce__()
    tagged_state = (slim_trees_version, _compress_tree_state(state))  # type: ignore
    return _tree_unpickle, (reconstructor, args, tagged_state)


Expand Down Expand Up @@ -113,7 +113,7 @@ def _compress_tree_state(state: Dict) -> Dict:
"values": values,
},
**(
{"missing_go_to_left": np.packbits(missing_go_to_left)}
{"missing_go_to_left": np.packbits(missing_go_to_left)} # type: ignore
if sklearn_version_ge_130
else {}
),
Expand Down
15 changes: 15 additions & 0 deletions tests/test_lgbm_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,19 @@ def test_loads_compressed_custom_unpickler(lgbm_regressor):
loads_compressed(compressed, unpickler_class=_TestUnpickler)


def test_dump_and_load_from_file(tmp_path, lgbm_regressor):
    """Round-trip a model through an already-open file object and check the
    error paths when no compression method can be determined."""
    model_path = tmp_path / "model.pickle.lzma"

    with model_path.open("wb") as file:
        dump_lgbm_compressed(lgbm_regressor, file, compression="lzma")

    with model_path.open("rb") as file:
        loaded = load_compressed(file, compression="lzma")
    # The dump above must yield a loadable object, not just writable bytes.
    assert loaded is not None

    # A bare file object carries no filename, so there is no extension to
    # infer the compression from: omitting `compression` must raise.
    with pytest.raises(ValueError), model_path.open("rb") as file:
        load_compressed(file)

    with pytest.raises(ValueError), model_path.open("wb") as file:
        dump_lgbm_compressed(lgbm_regressor, file)


# todo add tests for large models
15 changes: 15 additions & 0 deletions tests/test_sklearn_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,19 @@ def test_loads_compressed_custom_unpickler(random_forest_regressor):
loads_compressed(compressed, unpickler_class=_TestUnpickler)


def test_dump_and_load_from_file(tmp_path, random_forest_regressor):
    """Round-trip a model through an already-open file object and check the
    error paths when no compression method can be determined."""
    model_path = tmp_path / "model.pickle.lzma"

    with model_path.open("wb") as file:
        dump_sklearn_compressed(random_forest_regressor, file, compression="lzma")

    with model_path.open("rb") as file:
        loaded = load_compressed(file, compression="lzma")
    # The dump above must yield a loadable object, not just writable bytes.
    assert loaded is not None

    # A bare file object carries no filename, so there is no extension to
    # infer the compression from: omitting `compression` must raise.
    with pytest.raises(ValueError), model_path.open("rb") as file:
        load_compressed(file)

    with pytest.raises(ValueError), model_path.open("wb") as file:
        dump_sklearn_compressed(random_forest_regressor, file)


# todo add tests for large models

0 comments on commit 9461551

Please sign in to comment.