Weights sharding for Keras saving #19286

Draft
Wants to merge 7 commits into master
14 changes: 11 additions & 3 deletions keras/saving/saving_api.py
@@ -52,6 +52,8 @@ def save_model(model, filepath, overwrite=True, **kwargs):
"""
include_optimizer = kwargs.pop("include_optimizer", True)
save_format = kwargs.pop("save_format", False)
sharded = kwargs.pop("sharded", False)
shard_size = kwargs.pop("shard_size", None)
if save_format:
if str(filepath).endswith((".h5", ".hdf5")) or str(filepath).endswith(
".keras"
@@ -97,7 +99,12 @@ def save_model(model, filepath, overwrite=True, **kwargs):
return

if str(filepath).endswith(".keras"):
-saving_lib.save_model(model, filepath)
+saving_lib.save_model(
+    model,
+    filepath,
+    sharded=sharded,
+    shard_size=shard_size,
+)
elif str(filepath).endswith((".h5", ".hdf5")):
legacy_h5_format.save_model_to_hdf5(
model, filepath, overwrite, include_optimizer
@@ -204,17 +211,18 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):


def load_weights(model, filepath, skip_mismatch=False, **kwargs):
sharded = kwargs.pop("sharded", False)
if str(filepath).endswith(".keras"):
if kwargs:
raise ValueError(f"Invalid keyword arguments: {kwargs}")
saving_lib.load_weights_only(
-model, filepath, skip_mismatch=skip_mismatch
+model, filepath, sharded=sharded, skip_mismatch=skip_mismatch
)
elif str(filepath).endswith(".weights.h5"):
if kwargs:
raise ValueError(f"Invalid keyword arguments: {kwargs}")
saving_lib.load_weights_only(
-model, filepath, skip_mismatch=skip_mismatch
+model, filepath, sharded=sharded, skip_mismatch=skip_mismatch
)
elif str(filepath).endswith(".h5") or str(filepath).endswith(".hdf5"):
by_name = kwargs.pop("by_name", False)
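Taken together, the saving_api.py changes thread two new keyword arguments, `sharded` and `shard_size`, from the public entry points into `saving_lib`. A minimal usage sketch: the file names and the toy model are illustrative, and `Model.save` forwarding extra kwargs to `save_model` is an assumption, while `model.load_weights(..., sharded=True)` is exercised by the PR's own test further down:

import keras
from keras.saving import saving_lib

# Illustrative toy model; a stand-in for anything big enough to shard.
model = keras.Sequential(
    [
        keras.layers.Input(shape=(4,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(1),
    ]
)

# Weights-only saving: shards of at most 1 MB each, plus a
# mymodel.weights.json shard map written alongside the .h5 files.
saving_lib.save_weights_only(
    model, "mymodel.weights.h5", sharded=True, shard_size="1MB"
)

# Loading currently has to be told the file is sharded (see the
# reviewer question on load_weights_only below).
model.load_weights("mymodel.weights.h5", sharded=True)

# Whole-model saving accepts the same kwargs via save_model.
model.save("mymodel.keras", sharded=True, shard_size="100MB")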
197 changes: 187 additions & 10 deletions keras/saving/saving_lib.py
@@ -3,6 +3,8 @@
import datetime
import io
import json
import os
import re
import tempfile
import warnings
import zipfile
@@ -35,7 +37,9 @@
_ASSETS_DIRNAME = "assets"


-def save_model(model, filepath, weights_format="h5"):
+def save_model(
+    model, filepath, weights_format="h5", sharded=False, shard_size=None
+):
"""Save a zip-archive representing a Keras model to the given filepath.

The zip-based archive contains the following structure:
@@ -67,6 +71,12 @@ def save_model(model, filepath, weights_format="h5"):
)
if weights_format == "h5" and h5py is None:
raise ImportError("h5py must be installed in order to save a model.")
if weights_format != "h5" and sharded:
raise NotImplementedError(
"Sharding is only currently supported in the H5 weights format. "
"Please pass `sharded=False` or switch to `weights_format=h5`. "
f"Received: weights_format={weights_format}, sharded={sharded}."
)

if not model.built:
warnings.warn(
@@ -99,7 +109,18 @@ def save_model(model, filepath, weights_format="h5"):
f.write(config_json.encode())

if weights_format == "h5":
-weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="w")
+if sharded:
+    max_size = shard_size if shard_size is not None else "10GB"
+    weights_store = ShardedH5IOStore(
+        _VARS_FNAME + ".h5",
+        archive=zf,
+        mode="w",
+        max_size=max_size,
+    )
+else:
+    weights_store = H5IOStore(
+        _VARS_FNAME + ".h5", archive=zf, mode="w"
+    )
elif weights_format == "npz":
weights_store = NpzIOStore(
_VARS_FNAME + ".npz", archive=zf, mode="w"
@@ -158,7 +179,16 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):

all_filenames = zf.namelist()
if _VARS_FNAME + ".h5" in all_filenames:
-weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="r")
+if _VARS_FNAME + ".json" in all_filenames:
+    weights_store = ShardedH5IOStore(
+        _VARS_FNAME + ".h5",
+        archive=zf,
+        mode="r",
+    )
+else:
+    weights_store = H5IOStore(
+        _VARS_FNAME + ".h5", archive=zf, mode="r"
+    )
elif _VARS_FNAME + ".npz" in all_filenames:
weights_store = NpzIOStore(
_VARS_FNAME + ".npz", archive=zf, mode="r"
@@ -193,7 +223,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
return model


-def save_weights_only(model, filepath):
+def save_weights_only(model, filepath, sharded=False, shard_size=None):
"""Save only the weights of a model to a target filepath (.weights.h5).

Note: only supports h5 for now.
@@ -206,7 +236,11 @@ def save_weights_only(model, filepath):
"Invalid `filepath` argument: expected a `.weights.h5` extension. "
f"Received: filepath={filepath}"
)
weights_store = H5IOStore(filepath, mode="w")
if sharded:
max_size = shard_size if shard_size is not None else "10GB"
weights_store = ShardedH5IOStore(filepath, mode="w", max_size=max_size)
else:
weights_store = H5IOStore(filepath, mode="w")
_save_state(
model,
weights_store=weights_store,
@@ -217,7 +251,7 @@ def save_weights_only(model, filepath):
weights_store.close()


-def load_weights_only(model, filepath, skip_mismatch=False):
+def load_weights_only(model, filepath, sharded=False, skip_mismatch=False):
[Review comment, Member] Why should sharded be configurable here -- wouldn't it just depend on the file and the model?

"""Load the weights of a model from a filepath (.keras or .weights.h5).

Note: only supports h5 for now.
@@ -227,12 +261,23 @@ def load_weights_only(model, filepath, skip_mismatch=False):
filepath = str(filepath)
if filepath.endswith(".weights.h5"):
# TODO: download file if h5 filepath is remote
weights_store = H5IOStore(filepath, mode="r")
if sharded:
weights_store = ShardedH5IOStore(filepath, mode="r")
else:
weights_store = H5IOStore(filepath, mode="r")
elif filepath.endswith(".keras"):
archive = zipfile.ZipFile(filepath, "r")
-weights_store = H5IOStore(
-    _VARS_FNAME + ".h5", archive=archive, mode="r"
-)
+all_filenames = archive.namelist()
+if _VARS_FNAME + ".json" in all_filenames:
+    weights_store = ShardedH5IOStore(
+        _VARS_FNAME + ".h5",
+        archive=archive,
+        mode="r",
+    )
+else:
+    weights_store = H5IOStore(
+        _VARS_FNAME + ".h5", archive=archive, mode="r"
+    )

failed_trackables = set()
error_msgs = {}
@@ -617,6 +662,138 @@ def close(self):
self.io_file.close()


class ShardedH5IOStore:
def __init__(self, root_path, max_size="10GB", archive=None, mode="r"):
[Review comment, Member] Shouldn't max_size be an int? e.g. in MB?

self.shard_list = []
self.root_path = root_path
self.mode = mode
self.archive = archive
self.io_file = None
self.max_size = convert_str_bytes_to_int(max_size)
self.current_shard_size = 0

self.var_shard_map_filename = str(root_path).replace(
".weights.h5", ".weights.json"
)
if not os.path.exists(self.var_shard_map_filename):
if self.mode == "w":
self.var_shard_map = {}
if self.mode == "r":
raise FileNotFoundError(
f"Loading a sharded `.weights.h5` file requires "
"its corresponding sharding map JSON file "
f"{self.var_shard_map_filename} in the same directory. "
"Please ensure all weights files and the sharding map "
"JSON file are in the same directory when loading a "
"sharded weights file."
)
else:
with open(self.var_shard_map_filename, "r") as map_file:
self.var_shard_map = json.load(map_file)

self.h5_file = self._create_new_file(root_path)

def _create_new_file(self, path):
if path in self.shard_list:
path = resolve_duplicate_filename(str(path), self.shard_list)
self.root_path = path
if self.archive:
if self.mode == "w":
self.io_file = io.BytesIO()
else:
self.io_file = self.archive.open(path, "r")
return h5py.File(self.io_file, mode=self.mode)
else:
return h5py.File(path, mode=self.mode)

def _change_access_file(self, filename): # Read-only
self.close()
if self.archive:
self.io_file = self.archive.open(filename, "r")
return h5py.File(self.io_file, mode=self.mode)
else:
return h5py.File(filename, mode=self.mode)

def make(self, path):
def _get_size(key):
if isinstance(self.h5_file[key], h5py.Dataset):
self.current_shard_size += self.h5_file[key].nbytes

self.current_shard_size = 0
self.h5_file.visit(_get_size)
if self.current_shard_size > self.max_size:
self.shard_list.append(self.h5_file.filename)
self.close()
self.h5_file = self._create_new_file(self.root_path)
if not path:
group = self.h5_file.create_group("vars")
else:
group = self.h5_file.create_group(path).create_group("vars")
self.var_shard_map[group.name] = self.root_path
return group

def get(self, path):
if not path:
return self.h5_file["vars"]
if path in self.h5_file and "vars" in self.h5_file[path]:
return self.h5_file[path]["vars"]

# If not found, check shard map and switch files
filename = self.var_shard_map.get(path) or self.var_shard_map.get(
"/" + path + "/vars"
)
# h5py's File.name is the root group ("/"); the file path the shard
# map stores lives on File.filename, as used in make() above.
if filename is not None and self.h5_file.filename != filename:
new_file = self._change_access_file(filename)
if "vars" in new_file[path]:
self.h5_file = new_file
return self.h5_file[path]["vars"]
return {}

def close(self):
self.h5_file.close()
if self.mode == "w":
with open(self.var_shard_map_filename, "w") as map_file:
map_file.write(json.dumps(self.var_shard_map))
if self.archive:
self.archive.writestr(self.root_path, self.io_file.getvalue())
if self.io_file:
self.io_file.close()
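
For orientation: `make` recomputes the open shard's size on each call and rolls over to a new file once it exceeds `max_size`, `var_shard_map` records which shard holds each variable group, and `close` writes that map out as JSON when saving. A hedged sketch of the on-disk result of a two-shard save (file names and variable paths are illustrative, inferred from `_create_new_file` and `make` above):

mymodel.weights.h5        # first shard
mymodel_1.weights.h5      # opened once the first shard crossed max_size
mymodel.weights.json      # shard map, roughly:
#   {"/dense/vars": "mymodel.weights.h5",
#    "/dense_1/vars": "mymodel_1.weights.h5"}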


def convert_str_bytes_to_int(size):
if size.upper().endswith("GB"):
return int(size[:-2]) * (10**9)
if size.upper().endswith("MB"):
return int(size[:-2]) * (10**6)
if size.upper().endswith("KB"):
return int(size[:-2]) * (10**3)
raise ValueError(
"Invalid format for `size`. Use an integer followed by the unit "
"(GB, MB, or KB). For example, '5GB' or '15MB'."
)
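
A quick worked check of this helper. Note the units are decimal, so a "1MB" shard cap is 10**6 bytes rather than 2**20, and lowercase suffixes also parse because the comparison uppercases first:

assert convert_str_bytes_to_int("10GB") == 10 * 10**9
assert convert_str_bytes_to_int("15MB") == 15 * 10**6
assert convert_str_bytes_to_int("512KB") == 512 * 10**3
assert convert_str_bytes_to_int("1mb") == 10**6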


def resolve_duplicate_filename(path, path_list):
pattern = re.compile(r"_\d\.weights\.h5")
pre_duplicate = pattern.split(path)[0] # Check for pre-existing duplicate
if not pre_duplicate.endswith(".weights.h5"):
match_list = list(
filter(lambda x: x.startswith(pre_duplicate), path_list)
)
if len(match_list) > 1:
return pre_duplicate + "_" + str(len(match_list)) + ".weights.h5"
return path.replace(".weights.h5", "_1.weights.h5")
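
Tracing the helper on illustrative names: a first collision appends `_1`, and an already-suffixed name is renumbered by counting how many existing shards share its prefix. (The `_\d` pattern matches a single digit only, so names past `_9` would fall back to the plain `_1` rename.)

resolve_duplicate_filename(
    "mymodel.weights.h5", ["mymodel.weights.h5"]
)  # -> "mymodel_1.weights.h5"

resolve_duplicate_filename(
    "mymodel_1.weights.h5", ["mymodel.weights.h5", "mymodel_1.weights.h5"]
)  # -> "mymodel_2.weights.h5"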


def dtype_to_bytes(dtype):
if "bool" in str(dtype):
return 1 / 8
bits = re.search(r"[^\d](\d+)$", str(dtype))
if bits is None:
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
return int(bits.groups()[0]) // 8 # Bit size in bytes
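
A worked check of the dtype helper: the trailing digits are parsed as a bit width and floor-divided down to bytes, with `bool` special-cased as one bit:

assert dtype_to_bytes("float32") == 4
assert dtype_to_bytes("bfloat16") == 2
assert dtype_to_bytes("int8") == 1
assert dtype_to_bytes("bool") == 1 / 8  # one bit, not one byte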


class H5Entry:
"""Leaf entry in a H5IOStore."""

68 changes: 67 additions & 1 deletion keras/saving/saving_lib_test.py
@@ -754,6 +754,70 @@ def call(self, inputs):
return self.first_layer(self.second_layer(inputs))


def _get_large_model():
model = keras.Sequential(
[Review comment, Member] Why pick a convnet for a large model?

[
keras.layers.Input(shape=[28, 28, 1], dtype="float32"),
keras.layers.Conv2D(
filters=12,
kernel_size=3,
padding="same",
name="conv1",
use_bias=False,
), # no bias necessary before batch norm
keras.layers.BatchNormalization(
scale=False, center=True
), # no batch norm scaling necessary before "relu"
keras.layers.Activation("relu"), # activation after batch norm
keras.layers.Conv2D(
filters=24,
kernel_size=6,
padding="same",
name="conv2",
use_bias=False,
strides=2,
),
keras.layers.BatchNormalization(scale=False, center=True),
keras.layers.Activation("relu"),
keras.layers.Conv2D(
filters=32,
kernel_size=6,
padding="same",
name="conv3",
use_bias=False,
strides=2,
),
keras.layers.BatchNormalization(scale=False, center=True),
keras.layers.Activation("relu"),
keras.layers.Flatten(),
keras.layers.Dense(200, name="dense1", use_bias=False),
keras.layers.BatchNormalization(scale=False, center=True),
keras.layers.Activation("relu"),
keras.layers.Dropout(0.4), # Dropout on dense layer only
keras.layers.Dense(10, name="dense2", activation="softmax"),
]
)
model.compile(
optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]
)
return model


class LargeModelTest(testing.TestCase):
def test_model_sharding(self):
model = _get_large_model()
temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.weights.h5")
ref_input = np.random.random((1, 28, 28, 1))
ref_output = model.predict(ref_input)
saving_lib.save_weights_only(
model, temp_filepath, sharded=True, shard_size="1MB"
)

model = _get_large_model()
model.load_weights(temp_filepath, sharded=True)
self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)


class SavingBattleTest(testing.TestCase):
def test_custom_object_without_from_config(self):
temp_filepath = os.path.join(
@@ -809,7 +873,9 @@ def dense(self):
def call(self, x):
return self.dense(x)

temp_filepath = "normal_model.weights.h5"
temp_filepath = os.path.join(
self.get_temp_dir(), "normal_model.weights.h5"
)
model_a = NormalModel()
model_a(np.random.random((2, 2)))
model_a.save_weights(temp_filepath)