Weights sharding for Keras saving #19286
Draft · nkovela1 wants to merge 7 commits into keras-team:master from nkovela1:weights
Changes from all commits (7 commits, all by nkovela1):
4058a18  Add initial commit for sharded H5 store
51ea848  Finish ShardedH5IOStore initial implementation
a7d29aa  Add sharding to saving and loading API logic with associated errors
de397ed  Fix sharding API and add size check
933c387  Add large model test and debug sharding algorithm
ad7ca7d  Merge update with master saving changes
cfbb761  Fix formatting
File: keras/saving/saving_lib.py
@@ -3,6 +3,8 @@
 import datetime
 import io
 import json
+import os
+import re
 import tempfile
 import warnings
 import zipfile
@@ -35,7 +37,9 @@
 _ASSETS_DIRNAME = "assets"


-def save_model(model, filepath, weights_format="h5"):
+def save_model(
+    model, filepath, weights_format="h5", sharded=False, shard_size=None
+):
     """Save a zip-archive representing a Keras model to the given filepath.

     The zip-based archive contains the following structure:
@@ -67,6 +71,12 @@ def save_model(model, filepath, weights_format="h5"):
         )
     if weights_format == "h5" and h5py is None:
         raise ImportError("h5py must be installed in order to save a model.")
+    if weights_format != "h5" and sharded:
+        raise NotImplementedError(
+            "Sharding is only currently supported in the H5 weights format. "
+            "Please pass `sharded=False` or switch to `weights_format=h5`. "
+            f"Received: weights_format={weights_format}, sharded={sharded}."
+        )

     if not model.built:
         warnings.warn(
@@ -99,7 +109,18 @@ def save_model(model, filepath, weights_format="h5"):
         f.write(config_json.encode())

     if weights_format == "h5":
-        weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="w")
+        if sharded:
+            max_size = shard_size if shard_size is not None else "10GB"
+            weights_store = ShardedH5IOStore(
+                _VARS_FNAME + ".h5",
+                archive=zf,
+                mode="w",
+                max_size=max_size,
+            )
+        else:
+            weights_store = H5IOStore(
+                _VARS_FNAME + ".h5", archive=zf, mode="w"
+            )
     elif weights_format == "npz":
         weights_store = NpzIOStore(
             _VARS_FNAME + ".npz", archive=zf, mode="w"
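
For orientation, a minimal usage sketch of the extended whole-model path as defined by this diff. The `keras.saving.saving_lib` import path and the tiny shard size are illustrative assumptions, not taken from the PR:

    # Sketch only: exercises saving_lib.save_model as modified in this diff.
    import keras
    from keras.saving import saving_lib

    model = keras.Sequential([keras.Input((50,)), keras.layers.Dense(100)])

    # Weights overflowing shard_size are split across multiple H5 files
    # inside the .keras archive, plus a JSON map from variable paths to
    # shard filenames.
    saving_lib.save_model(model, "model.keras", sharded=True, shard_size="1MB")

    # Loading detects sharding from the archive contents (next hunk), so no
    # flag is needed on this path.
    restored = saving_lib.load_model("model.keras")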
@@ -158,7 +179,16 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):

         all_filenames = zf.namelist()
         if _VARS_FNAME + ".h5" in all_filenames:
-            weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="r")
+            if _VARS_FNAME + ".json" in all_filenames:
+                weights_store = ShardedH5IOStore(
+                    _VARS_FNAME + ".h5",
+                    archive=zf,
+                    mode="r",
+                )
+            else:
+                weights_store = H5IOStore(
+                    _VARS_FNAME + ".h5", archive=zf, mode="r"
+                )
         elif _VARS_FNAME + ".npz" in all_filenames:
             weights_store = NpzIOStore(
                 _VARS_FNAME + ".npz", archive=zf, mode="r"
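
To make the detection concrete: a sharded `.keras` archive written by this branch would contain something like the following (filenames inferred from `_VARS_FNAME`, the `.weights.json` naming, and `resolve_duplicate_filename` further down; illustrative):

    config.json
    metadata.json
    model.weights.h5        # first shard
    model_1.weights.h5      # overflow shard(s)
    model.weights.json      # sharding map; its presence selects ShardedH5IOStore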
@@ -193,7 +223,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
     return model


-def save_weights_only(model, filepath):
+def save_weights_only(model, filepath, sharded=False, shard_size=None):
     """Save only the weights of a model to a target filepath (.weights.h5).

     Note: only supports h5 for now.
@@ -206,7 +236,11 @@ def save_weights_only(model, filepath):
             "Invalid `filepath` argument: expected a `.weights.h5` extension. "
             f"Received: filepath={filepath}"
         )
-    weights_store = H5IOStore(filepath, mode="w")
+    if sharded:
+        max_size = shard_size if shard_size is not None else "10GB"
+        weights_store = ShardedH5IOStore(filepath, mode="w", max_size=max_size)
+    else:
+        weights_store = H5IOStore(filepath, mode="w")
     _save_state(
         model,
         weights_store=weights_store,
@@ -217,7 +251,7 @@ def save_weights_only(model, filepath):
     weights_store.close()


-def load_weights_only(model, filepath, skip_mismatch=False):
+def load_weights_only(model, filepath, sharded=False, skip_mismatch=False):
     """Load the weights of a model from a filepath (.keras or .weights.h5).

     Note: only supports h5 for now.
@@ -227,12 +261,23 @@ def load_weights_only(model, filepath, skip_mismatch=False):
     filepath = str(filepath)
     if filepath.endswith(".weights.h5"):
         # TODO: download file if h5 filepath is remote
-        weights_store = H5IOStore(filepath, mode="r")
+        if sharded:
+            weights_store = ShardedH5IOStore(filepath, mode="r")
+        else:
+            weights_store = H5IOStore(filepath, mode="r")
    elif filepath.endswith(".keras"):
         archive = zipfile.ZipFile(filepath, "r")
-        weights_store = H5IOStore(
-            _VARS_FNAME + ".h5", archive=archive, mode="r"
-        )
+        all_filenames = archive.namelist()
+        if _VARS_FNAME + ".json" in all_filenames:
+            weights_store = ShardedH5IOStore(
+                _VARS_FNAME + ".h5",
+                archive=archive,
+                mode="r",
+            )
+        else:
+            weights_store = H5IOStore(
+                _VARS_FNAME + ".h5", archive=archive, mode="r"
+            )

     failed_trackables = set()
     error_msgs = {}
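
The weights-only path, continuing the sketch above (note the asymmetry: loading a bare `.weights.h5` file needs an explicit `sharded=True`, since there is no archive listing to inspect):

    # Sketch: weights-only save/load with sharding, per this diff's signatures.
    saving_lib.save_weights_only(
        model, "model.weights.h5", sharded=True, shard_size="100MB"
    )
    # Produces model.weights.h5, model_1.weights.h5, ... plus model.weights.json.

    fresh = keras.Sequential([keras.Input((50,)), keras.layers.Dense(100)])
    saving_lib.load_weights_only(fresh, "model.weights.h5", sharded=True)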
@@ -617,6 +662,138 @@ def close(self):
         self.io_file.close()


+class ShardedH5IOStore:
+    def __init__(self, root_path, max_size="10GB", archive=None, mode="r"):

[Review comment on the `__init__` signature] Shouldn't max_size be an int? e.g. in MB?

+        self.shard_list = []
+        self.root_path = root_path
+        self.mode = mode
+        self.archive = archive
+        self.io_file = None
+        self.max_size = convert_str_bytes_to_int(max_size)
+        self.current_shard_size = 0
+
+        self.var_shard_map_filename = str(root_path).replace(
+            ".weights.h5", ".weights.json"
+        )
+        if not os.path.exists(self.var_shard_map_filename):
+            if self.mode == "w":
+                self.var_shard_map = {}
+            if self.mode == "r":
+                raise FileNotFoundError(
+                    f"Loading a sharded `.weights.h5` file requires "
+                    "its corresponding sharding map JSON file "
+                    f"{self.var_shard_map_filename} in the same directory. "
+                    "Please ensure all weights files and the sharding map "
+                    "JSON file are in the same directory when loading a "
+                    "sharded weights file."
+                )
+        else:
+            with open(self.var_shard_map_filename, "r") as map_file:
+                self.var_shard_map = json.load(map_file)
+
+        self.h5_file = self._create_new_file(root_path)
+
+    def _create_new_file(self, path):
+        if path in self.shard_list:
+            path = resolve_duplicate_filename(str(path), self.shard_list)
+        self.root_path = path
+        if self.archive:
+            if self.mode == "w":
+                self.io_file = io.BytesIO()
+            else:
+                self.io_file = self.archive.open(path, "r")
+            return h5py.File(self.io_file, mode=self.mode)
+        else:
+            return h5py.File(path, mode=self.mode)
+
+    def _change_access_file(self, filename):  # Read-only
+        self.close()
+        if self.archive:
+            self.io_file = self.archive.open(filename, "r")
+            return h5py.File(self.io_file, mode=self.mode)
+        else:
+            return h5py.File(filename, mode=self.mode)
+
+    def make(self, path):
+        def _get_size(key):
+            if isinstance(self.h5_file[key], h5py.Dataset):
+                self.current_shard_size += self.h5_file[key].nbytes
+
+        self.current_shard_size = 0
+        self.h5_file.visit(_get_size)
+        if self.current_shard_size > self.max_size:
+            self.shard_list.append(self.h5_file.filename)
+            self.close()
+            self.h5_file = self._create_new_file(self.root_path)
+        if not path:
+            group = self.h5_file.create_group("vars")
+        else:
+            group = self.h5_file.create_group(path).create_group("vars")
+        self.var_shard_map[group.name] = self.root_path
+        return group
+
+    def get(self, path):
+        if not path:
+            return self.h5_file["vars"]
+        if path in self.h5_file and "vars" in self.h5_file[path]:
+            return self.h5_file[path]["vars"]
+
+        # If not found, check shard map and switch files
+        filename = self.var_shard_map.get(path) or self.var_shard_map.get(
+            "/" + path + "/vars"
+        )
+        if filename is not None and self.h5_file.name != filename:
+            new_file = self._change_access_file(filename)
+            if "vars" in new_file[path]:
+                self.h5_file = new_file
+                return self.h5_file[path]["vars"]
+        return {}
+
+    def close(self):
+        self.h5_file.close()
+        if self.mode == "w":
+            with open(self.var_shard_map_filename, "w") as map_file:
+                map_file.write(json.dumps(self.var_shard_map))
+        if self.archive:
+            self.archive.writestr(self.root_path, self.io_file.getvalue())
+        if self.io_file:
+            self.io_file.close()
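
Given that `make()` above records `group.name` against the current shard file, the `.weights.json` map for a small sharded model would look roughly like this (hypothetical variable paths):

    {
        "/dense/vars": "model.weights.h5",
        "/dense_1/vars": "model_1.weights.h5",
        "/dense_2/vars": "model_1.weights.h5"
    }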
+
+
+def convert_str_bytes_to_int(size):
+    if size.upper().endswith("GB"):
+        return int(size[:-2]) * (10**9)
+    if size.upper().endswith("MB"):
+        return int(size[:-2]) * (10**6)
+    if size.upper().endswith("KB"):
+        return int(size[:-2]) * (10**3)
+    raise ValueError(
+        "Invalid format for `size`. Use an integer followed by the unit "
+        "(GB, MB, or KB). For example, '5GB' or '15MB'."
+    )
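
Note the units are decimal, not binary (1 GB here is 10**9 bytes, not 2**30), and only integer counts parse:

    assert convert_str_bytes_to_int("10GB") == 10 * 10**9
    assert convert_str_bytes_to_int("15MB") == 15 * 10**6
    # "1.5GB" raises ValueError from int("1.5"); fractional sizes are rejected.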
+
+
+def resolve_duplicate_filename(path, path_list):
+    pattern = re.compile(r"_\d\.weights\.h5")
+    pre_duplicate = pattern.split(path)[0]  # Check for pre-existing duplicate
+    if not pre_duplicate.endswith(".weights.h5"):
+        match_list = list(
+            filter(lambda x: x.startswith(pre_duplicate), path_list)
+        )
+        if len(match_list) > 1:
+            return pre_duplicate + "_" + str(len(match_list)) + ".weights.h5"
+    return path.replace(".weights.h5", "_1.weights.h5")
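
A quick trace of the shard naming this produces for a root path of "model.weights.h5" (illustrative inputs):

    resolve_duplicate_filename("model.weights.h5", ["model.weights.h5"])
    # -> "model_1.weights.h5"
    resolve_duplicate_filename(
        "model_1.weights.h5", ["model.weights.h5", "model_1.weights.h5"]
    )
    # -> "model_2.weights.h5"  (pre-duplicate "model" matches two entries)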
+
+
+def dtype_to_bytes(dtype):
+    if "bool" in str(dtype):
+        return 1 / 8
+    bits = re.search(r"[^\d](\d+)$", str(dtype))
+    if bits is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    return int(bits.groups()[0]) // 8  # Bit size in bytes
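
For example, with the helper as written:

    assert dtype_to_bytes("float32") == 4
    assert dtype_to_bytes("int8") == 1
    assert dtype_to_bytes("bool") == 1 / 8  # bools are counted as one bit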


 class H5Entry:
     """Leaf entry in a H5IOStore."""
File: keras/saving/saving_lib_test.py
@@ -754,6 +754,70 @@ def call(self, inputs):
         return self.first_layer(self.second_layer(inputs))


+def _get_large_model():
+    model = keras.Sequential(

[Review comment on this line] Why pick a convnet for a large model?

+        [
+            keras.layers.Input(shape=[28, 28, 1], dtype="float32"),
+            keras.layers.Conv2D(
+                filters=12,
+                kernel_size=3,
+                padding="same",
+                name="conv1",
+                use_bias=False,
+            ),  # no bias necessary before batch norm
+            keras.layers.BatchNormalization(
+                scale=False, center=True
+            ),  # no batch norm scaling necessary before "relu"
+            keras.layers.Activation("relu"),  # activation after batch norm
+            keras.layers.Conv2D(
+                filters=24,
+                kernel_size=6,
+                padding="same",
+                name="conv2",
+                use_bias=False,
+                strides=2,
+            ),
+            keras.layers.BatchNormalization(scale=False, center=True),
+            keras.layers.Activation("relu"),
+            keras.layers.Conv2D(
+                filters=32,
+                kernel_size=6,
+                padding="same",
+                name="conv3",
+                use_bias=False,
+                strides=2,
+            ),
+            keras.layers.BatchNormalization(scale=False, center=True),
+            keras.layers.Activation("relu"),
+            keras.layers.Flatten(),
+            keras.layers.Dense(200, name="dense1", use_bias=False),
+            keras.layers.BatchNormalization(scale=False, center=True),
+            keras.layers.Activation("relu"),
+            keras.layers.Dropout(0.4),  # Dropout on dense layer only
+            keras.layers.Dense(10, name="dense2", activation="softmax"),
+        ]
+    )
+    model.compile(
+        optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]
+    )
+    return model
+
+
+class LargeModelTest(testing.TestCase):
+    def test_model_sharding(self):
+        model = _get_large_model()
+        temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.weights.h5")
+        ref_input = np.random.random((1, 28, 28, 1))
+        ref_output = model.predict(ref_input)
+        saving_lib.save_weights_only(
+            model, temp_filepath, sharded=True, shard_size="1MB"
+        )
+
+        model = _get_large_model()
+        model.load_weights(temp_filepath, sharded=True)
+        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)


 class SavingBattleTest(testing.TestCase):
     def test_custom_object_without_from_config(self):
         temp_filepath = os.path.join(
@@ -809,7 +873,9 @@ def dense(self):
             def call(self, x):
                 return self.dense(x)

-        temp_filepath = "normal_model.weights.h5"
+        temp_filepath = os.path.join(
+            self.get_temp_dir(), "normal_model.weights.h5"
+        )
         model_a = NormalModel()
         model_a(np.random.random((2, 2)))
         model_a.save_weights(temp_filepath)
[Review comment] Why should `sharded` be configurable here -- wouldn't it just depend on the file and the model?