Skip to content

Commit

Permalink
Fix Ruff rule B008
Browse files Browse the repository at this point in the history
  • Loading branch information
smokestacklightnin committed Oct 27, 2024
1 parent f5224b2 commit f011508
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
5 changes: 4 additions & 1 deletion tfx/dsl/component/experimental/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Tests for tfx.dsl.components.base.decorators."""


from __future__ import annotations
import pytest
import os
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -141,8 +142,10 @@ def verify_beam_pipeline_arg(a: int) -> OutputDict(b=float): # pytype: disable=

def verify_beam_pipeline_arg_non_none_default_value(
a: int,
beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(),
beam_pipeline: BeamComponentParameter[beam.Pipeline] = None,
) -> OutputDict(b=float): # pytype: disable=invalid-annotation,wrong-arg-types
if beam_pipeline is None:
beam_pipeline = beam.Pipeline()
del beam_pipeline
return {'b': float(a)}

Expand Down
5 changes: 4 additions & 1 deletion tfx/dsl/component/experimental/decorators_typeddict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Tests for tfx.dsl.components.base.decorators."""


from __future__ import annotations
import pytest
import os
from typing import Any, Dict, List, Optional, TypedDict
Expand Down Expand Up @@ -141,8 +142,10 @@ def verify_beam_pipeline_arg(a: int) -> TypedDict('Output6', dict(b=float)): #

def verify_beam_pipeline_arg_non_none_default_value(
a: int,
beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(),
beam_pipeline: BeamComponentParameter[beam.Pipeline] | None = None,
) -> TypedDict('Output7', dict(b=float)): # pytype: disable=wrong-arg-types
if beam_pipeline is None:
beam_pipeline = beam.Pipeline()
del beam_pipeline
return {'b': float(a)}

Expand Down
7 changes: 5 additions & 2 deletions tfx/examples/bert/utils/bert_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Configurable fine-tuning BERT models for various tasks."""

from __future__ import annotations
from typing import Optional, List, Union

import tensorflow as tf
Expand Down Expand Up @@ -59,8 +60,7 @@ def build_bert_classifier(bert_layer: tf.keras.layers.Layer,

def compile_bert_classifier(
model: tf.keras.Model,
loss: tf.keras.losses.Loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True),
loss: tf.keras.losses.Loss | None = None,
learning_rate: float = 2e-5,
metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None):
"""Compile the BERT classifier using suggested parameters.
Expand All @@ -79,6 +79,9 @@ def compile_bert_classifier(
Returns:
None.
"""
if loss is None:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

if metrics is None:
metrics = ["sparse_categorical_accuracy"]

Expand Down

0 comments on commit f011508

Please sign in to comment.