Skip to content

Commit

Permalink
Update file editor
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Sep 23, 2024
1 parent cbb5ec6 commit f1da798
Show file tree
Hide file tree
Showing 2 changed files with 354 additions and 13 deletions.
291 changes: 278 additions & 13 deletions keras/src/saving/file_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import h5py
import rich.console

from keras.src import backend
from keras.src.saving import saving_lib
from keras.src.saving.saving_lib import H5IOStore
from keras.src.utils import naming
from keras.src.utils import summary_utils

try:
Expand All @@ -30,6 +32,26 @@ def is_ipython_notebook():


class KerasFileEditor:
"""Utility to inspect, edit, and resave Keras weights files.
Args:
filepath: The path to a local file to inspect and edit.
Examples:
```python
editor = KerasFileEditor("my_model.weights.h5")
# Displays current contents
editor.summary()
# Remove the weights of an existing layer
editor.delete_object("layers/dense_2")
# Add the weights of a new layer
editor.add_object("layers/einsum_dense", weights={"0": ..., "1": ...})
# Save the weights of the edited model
editor.resave_weights("edited_model.weights.h5")
```
"""

def __init__(
self,
filepath,
Expand Down Expand Up @@ -73,15 +95,138 @@ def __init__(
if self.metadata is not None:
self.console.print(self._generate_metadata_info(rich_style=True))

def weights_summary(self):
def summary(self):
    """Prints the weight structure of the opened file.

    Dispatches to the interactive renderer when `is_ipython_notebook()`
    is true, and to the command-line renderer otherwise.
    """
    render = (
        self._weights_summary_iteractive
        if is_ipython_notebook()
        else self._weights_summary_cli
    )
    render()

def compare_to_reference(self, model):
# TODO
raise NotImplementedError()
def compare_to(self, reference_model):
    """Compares the opened file to a reference model.

    This method will list all mismatches between the
    currently opened file and the provided reference model.

    Two passes are made: the first walks the reference model's weight
    spec and checks the saved file against it (missing entries and
    shape mismatches); the second walks the saved file and only reports
    entries that are missing from the reference model.

    Args:
        reference_model: Model instance to compare to.

    Returns:
        Dict with the following keys:
        `'status'`, `'error_count'`, `'match_count'`.
        Status can be `'success'` or `'error'`.
        `'error_count'` is the number of mismatches found.
        `'match_count'` is the number of matching weights found.
    """
    self.console.print("Running comparison")
    ref_spec = {}
    get_weight_spec_of_saveable(reference_model, ref_spec)

    def _compare(
        target,
        ref_spec,
        inner_path,
        target_name,
        ref_name,
        check_shapes=True,
    ):
        """Walk `ref_spec` and report what `target` lacks or mismatches.

        Returns `(error_count, match_count)` for this subtree only, so
        that callers can simply add the result to their own totals.
        """
        error_count = 0
        match_count = 0
        for ref_key, ref_val in ref_spec.items():
            path = inner_path + "/" + ref_key
            if ref_key not in target:
                # Both missing objects and missing weights are errors.
                error_count += 1
                if isinstance(ref_val, dict):
                    self.console.print(
                        f"[color(160)]...Object [bold]{path}[/] "
                        f"present in {ref_name}, "
                        f"missing from {target_name}[/]"
                    )
                    self.console.print(
                        f"    In {ref_name}, {path} contains "
                        f"the following keys: {list(ref_val.keys())}"
                    )
                else:
                    self.console.print(
                        f"[color(160)]...Weight [bold]{path}[/] "
                        f"present in {ref_name}, "
                        f"missing from {target_name}[/]"
                    )
            elif isinstance(ref_val, dict):
                # NOTE(review): assumes `target[ref_key]` is also a
                # dict-like container here -- confirm for files whose
                # structure diverges from the reference.
                sub_errors, sub_matches = _compare(
                    target[ref_key],
                    ref_val,
                    path,
                    target_name,
                    ref_name,
                    check_shapes=check_shapes,
                )
                error_count += sub_errors
                match_count += sub_matches
            elif not check_shapes:
                # Shapes were already checked by the forward pass.
                continue
            elif target[ref_key].shape != ref_val.shape:
                error_count += 1
                self.console.print(
                    f"[color(160)]...Weight shape mismatch "
                    f"for [bold]{path}[/][/]\n"
                    f"    In {ref_name}: "
                    f"shape={tuple(ref_val.shape)}\n"
                    f"    In {target_name}: "
                    f"shape={target[ref_key].shape}"
                )
            else:
                match_count += 1
        return error_count, match_count

    # Forward pass: everything the reference model expects.
    error_count, match_count = _compare(
        self.weights_dict,
        ref_spec,
        inner_path="",
        target_name="saved file",
        ref_name="reference model",
    )
    # Reverse pass: entries present in the file but unknown to the
    # reference model. Shape checks are skipped to avoid reporting
    # (and counting) every mismatch a second time.
    extra_errors, _ = _compare(
        ref_spec,
        self.weights_dict,
        inner_path="",
        target_name="reference model",
        ref_name="saved file",
        check_shapes=False,
    )
    error_count += extra_errors

    self.console.print("─────────────────────")
    if error_count == 0:
        status = "success"
        self.console.print(
            "[color(28)][bold]Comparison successful:[/] "
            "saved file is compatible with the reference model[/]"
        )
        plural = "" if match_count == 1 else "s"
        self.console.print(
            f"    Found {match_count} matching weight{plural}"
        )
    else:
        status = "error"
        plural = "" if error_count == 1 else "s"
        self.console.print(
            f"[color(160)][bold]Found {error_count} error{plural}:[/] "
            "saved file is not compatible with the reference model[/]"
        )
    return {
        "status": status,
        "error_count": error_count,
        "match_count": match_count,
    }

def _edit_object(self, edit_fn, source_name, target_name=None):
if target_name is not None and "/" in target_name:
Expand Down Expand Up @@ -137,20 +282,46 @@ def _edit(d):

_edit(self.weights_dict)

def rename_object(self, source_name, target_name):
def rename_object(self, object_name, new_name):
    """Rename an object in the file (e.g. a layer).

    Args:
        object_name: String, name or path of the
            object to rename (e.g. `"dense_2"` or
            `"layers/dense_2"`).
        new_name: String, new name of the object.
    """

    def rename_fn(weights_dict, source_name, target_name):
        # Pop first, then assign: renaming an object to its current
        # name must keep it, not delete it.
        weights_dict[target_name] = weights_dict.pop(source_name)

    self._edit_object(rename_fn, object_name, new_name)

def delete_object(self, object_name):
    """Removes an object from the file (e.g. a layer).

    Args:
        object_name: String, name or path of the
            object to delete (e.g. `"dense_2"` or
            `"layers/dense_2"`).
    """

    def delete_fn(weights_dict, source_name, target_name=None):
        # `target_name` is unused; it is accepted so the callback has
        # the same signature as the other edit functions.
        weights_dict.pop(source_name)

    self._edit_object(delete_fn, object_name)

def add_object(self, name, weights):
def add_object(self, object_path, weights):
"""Add a new object to the file (e.g. a layer).
Args:
object_path: String, full path of the
object to add (e.g. `"layers/dense_2"`).
weights: Dict mapping weight names to weight
values (arrays),
e.g. `{"0": kernel_value, "1": bias_value}`.
"""
if not isinstance(weights, dict):
raise ValueError(
"Argument `weights` should be a dict "
Expand All @@ -159,9 +330,9 @@ def add_object(self, name, weights):
f"Received: type(weights)={type(weights)}"
)

if "/" in name:
if "/" in object_path:
# It's a path
elements = name.split("/")
elements = object_path.split("/")
partial_path = "/".join(elements[:-1])
weights_dict = self.weights_dict
for e in elements[:-1]:
Expand All @@ -172,9 +343,19 @@ def add_object(self, name, weights):
weights_dict = weights_dict[e]
weights_dict[elements[-1]] = weights
else:
self.weights_dict[name] = weights
self.weights_dict[object_path] = weights

def delete_weight(self, object_name, weight_name):
"""Removes a weight from an existing object.
Args:
object_name: String, name or path of the
object from which to remove the weight
(e.g. `"dense_2"` or `"layers/dense_2"`).
weight_name: String, name of the weight to
delete (e.g. `"0"`).
"""

def delete_weight_fn(weights_dict, source_name, target_name=None):
if weight_name not in weights_dict[source_name]:
raise ValueError(
Expand All @@ -188,6 +369,16 @@ def delete_weight_fn(weights_dict, source_name, target_name=None):
self._edit_object(delete_weight_fn, object_name)

def add_weights(self, object_name, weights):
"""Add one or more new weights to an existing object.
Args:
object_name: String, name or path of the
object to add the weights to
(e.g. `"dense_2"` or `"layers/dense_2"`).
weights: Dict mapping weight names to weight
values (arrays),
e.g. `{"0": kernel_value, "1": bias_value}`.
"""
if not isinstance(weights, dict):
raise ValueError(
"Argument `weights` should be a dict "
Expand All @@ -202,6 +393,12 @@ def add_weight_fn(weights_dict, source_name, target_name=None):
self._edit_object(add_weight_fn, object_name)

def resave_weights(self, filepath):
"""Save the edited weights file.
Args:
filepath: Path to save the file to.
Must be a `.weights.h5` file.
"""
filepath = str(filepath)
if not filepath.endswith(".weights.h5"):
raise ValueError(
Expand Down Expand Up @@ -319,7 +516,7 @@ def _print_weights_structure(
inner_path=inner_path,
)
else:
if isinstance(value, h5py.Dataset):
if hasattr(value, "shape"):
bold_key = summary_utils.bold_text(key)
self.console.print(
f"{prefix}{connector}{bold_key}:"
Expand Down Expand Up @@ -370,3 +567,71 @@ def _generate_html_weights(dictionary, margin_left=0, font_size=20):

if is_ipython_notebook():
ipython.display.display(ipython.display.HTML(output))


def get_weight_spec_of_saveable(saveable, spec, visited_saveables=None):
    """Recursively collect the weight spec of `saveable` into `spec`.

    For every entry that `saveable.save_own_variables` writes, a
    `backend.KerasTensor` carrying the entry's shape and dtype is stored
    in `spec` under the same key. Children yielded by
    `saving_lib._walk_saveable` are then visited, and each non-empty
    child spec is stored under the child's attribute name.

    Args:
        saveable: Object to inspect.
        spec: Dict, filled in place.
        visited_saveables: Optional set of `id()`s of objects that have
            already been visited; those objects are skipped. A set
            passed by the caller is updated in place.
    """
    from keras.src.saving.keras_saveable import KerasSaveable

    # Test against `None` (not truthiness) so that an empty set passed
    # by the caller is the one that gets updated.
    if visited_saveables is None:
        visited_saveables = set()

    # If the saveable has already been visited, skip it.
    if id(saveable) in visited_saveables:
        return

    if hasattr(saveable, "save_own_variables"):
        store = {}
        saveable.save_own_variables(store)
        for key in sorted(store):
            val = store[key]
            spec[key] = backend.KerasTensor(
                shape=val.shape, dtype=val.dtype
            )

    # Mark as visited before descending so that reference cycles stop.
    visited_saveables.add(id(saveable))

    for child_attr, child_obj in saving_lib._walk_saveable(saveable):
        if isinstance(child_obj, KerasSaveable):
            sub_spec = {}
            get_weight_spec_of_saveable(
                child_obj,
                sub_spec,
                visited_saveables=visited_saveables,
            )
            if sub_spec:
                spec[child_attr] = sub_spec
        elif isinstance(child_obj, (list, dict, tuple, set)):
            sub_spec = {}
            get_weight_spec_of_container(
                child_obj,
                sub_spec,
                visited_saveables=visited_saveables,
            )
            if sub_spec:
                spec[child_attr] = sub_spec


def get_weight_spec_of_container(container, spec, visited_saveables):
    """Collect the weight specs of the saveables held in a container.

    Each `KerasSaveable` element is keyed in `spec` by the snake-cased
    name of its class; repeated class names get a `_1`, `_2`, ...
    suffix. Elements whose own spec comes back empty are left out.

    Args:
        container: Iterable of candidate objects. For a dict, its
            values are inspected.
        spec: Dict, filled in place.
        visited_saveables: Set of `id()`s of already-visited objects,
            forwarded to `get_weight_spec_of_saveable`.
    """
    from keras.src.saving.keras_saveable import KerasSaveable

    name_counts = {}
    if isinstance(container, dict):
        items = list(container.values())
    else:
        items = container

    for item in items:
        if not isinstance(item, KerasSaveable):
            continue

        base_name = naming.to_snake_case(item.__class__.__name__)
        if base_name in name_counts:
            name_counts[base_name] += 1
            key = f"{base_name}_{name_counts[base_name]}"
        else:
            name_counts[base_name] = 0
            key = base_name

        item_spec = {}
        get_weight_spec_of_saveable(
            item,
            item_spec,
            visited_saveables=visited_saveables,
        )
        if item_spec:
            spec[key] = item_spec
Loading

0 comments on commit f1da798

Please sign in to comment.