Skip to content

Commit

Permalink
Adds error handling for invalid input names on model call.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 557585594
  • Loading branch information
nkovela1 authored and tensorflower-gardener committed Aug 16, 2023
1 parent 6ce11ac commit 29b1384
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
15 changes: 15 additions & 0 deletions keras/engine/base_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import contextlib
import functools
import itertools
import re
import textwrap
import threading
import warnings
Expand Down Expand Up @@ -71,6 +72,8 @@
# Prefix that is added to the TF op layer names.
_TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_"

_VALID_INPUT_NAME_REGEX = r"^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$"

# TODO(mdan): Should we have a single generic type for types that can be passed
# to tf.cast?
_AUTOCAST_TYPES = (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)
Expand Down Expand Up @@ -1052,6 +1055,18 @@ def __call__(self, *args, **kwargs):
inputs, args, kwargs = self._call_spec.split_out_first_arg(args, kwargs)
input_list = tf.nest.flatten(inputs)

def _check_valid_input_names(x):
if not re.match(_VALID_INPUT_NAME_REGEX, x):
raise ValueError(
"Received an invalid input name: "
f"`{x}`. Please ensure that all input names do "
"not contain invalid characters such as spaces, "
"semicolons, etc."
)

if isinstance(inputs, dict):
tf.nest.map_structure(_check_valid_input_names, inputs.keys())

# Functional Model construction mode is invoked when `Layer`s are called
# on symbolic `KerasTensor`s, i.e.:
# >> inputs = tf.keras.Input(10)
Expand Down
18 changes: 16 additions & 2 deletions keras/engine/base_layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,6 @@ def call(self, inputs):
@test_combinations.generate(test_combinations.combine(mode=["eager"]))
def test_composite_variable_assignment(self):
class Spec(tf.TypeSpec):

value_type = property(lambda self: CompositeVariable)

def _component_specs(self):
Expand Down Expand Up @@ -527,6 +526,22 @@ def test_exception_if_name_not_string_or_none(self):
):
base_layer.Layer(name=0)

def test_exception_if_call_invalid_input(self):
class MyModel(training_lib.Model):
def call(self, inputs):
return inputs["a feature"] + inputs["b_feature"]

inputs = {
"a feature": tf.constant([1.0]),
"b_feature": tf.constant([2.0]),
}

model = MyModel()
with self.assertRaisesRegex(
ValueError, "Received an invalid input name"
):
_ = model(inputs)

@test_combinations.generate(
test_combinations.combine(mode=["graph", "eager"])
)
Expand Down Expand Up @@ -1649,7 +1664,6 @@ def wrapper():
)
class AutographControlFlowTest(test_combinations.TestCase):
def test_disabling_in_context_is_matched(self):

test_obj = self

class MyLayer(base_layer.Layer):
Expand Down

0 comments on commit 29b1384

Please sign in to comment.