Skip to content

Commit

Permalink
fix gpu test (keras-team#1939)
Browse files Browse the repository at this point in the history
* fix gpu test

* cast input

* update dtype

* change to resnet preset

* remove arg
  • Loading branch information
divyashreepathihalli authored and ushareng committed Oct 24, 2024
1 parent f10c45d commit ed035d3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 18 deletions.
27 changes: 10 additions & 17 deletions keras_hub/src/layers/preprocessing/image_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
from keras import ops

from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.pali_gemma.pali_gemma_backbone import (
PaliGemmaBackbone,
)
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_hub.src.tests.test_case import TestCase


Expand Down Expand Up @@ -86,24 +84,19 @@ def test_from_preset_errors(self):
def test_save_to_preset(self):
save_dir = self.get_temp_dir()
converter = ImageConverter.from_preset(
"pali_gemma_3b_mix_224",
"resnet_50_imagenet",
interpolation="nearest",
)
converter.save_to_preset(save_dir)
# Save a tiny backbone so the preset is valid.
backbone = PaliGemmaBackbone(
vocabulary_size=100,
image_size=224,
num_layers=1,
num_query_heads=1,
num_key_value_heads=1,
hidden_dim=8,
intermediate_dim=16,
head_dim=8,
vit_patch_size=14,
vit_num_heads=1,
vit_hidden_dim=8,
vit_num_layers=1,
backbone = ResNetBackbone(
input_conv_filters=[64],
input_conv_kernel_sizes=[7],
stackwise_num_filters=[64, 64, 64],
stackwise_num_blocks=[2, 2, 2],
stackwise_num_strides=[1, 2, 2],
block_type="basic_block",
use_pre_activation=True,
)
backbone.save_to_preset(save_dir)

Expand Down
1 change: 0 additions & 1 deletion keras_hub/src/models/resnet/resnet_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ class ResNetBackbone(FeaturePyramidBackbone):
stackwise_num_strides=[1, 2, 2],
block_type="basic_block",
use_pre_activation=True,
pooling="avg",
)
model(input_data)
```
Expand Down

0 comments on commit ed035d3

Please sign in to comment.