Skip to content

Commit

Permalink
refactor: Avoid changing the backend in KerasModule's __call__ if it'…
Browse files Browse the repository at this point in the history
…s already tensorflow.
  • Loading branch information
AnnaTz committed Nov 30, 2023
1 parent 423ebed commit a385c0b
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions ivy/stateful/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,12 @@ def call(self, *args, training=None, **kwargs):
return ret

def __call__(self, *args, **kwargs):
ivy.set_backend("tensorflow")
args, kwargs = ivy.args_to_new_backend(*args, native=True, **kwargs)
ivy.previous_backend()

if ivy.backend != "tensorflow":
ivy.set_backend("tensorflow")
args, kwargs = ivy.args_to_new_backend(*args, native=True, **kwargs)
ivy.previous_backend()
else:
args, kwargs = ivy.args_to_new_backend(*args, native=True, **kwargs)
return super(KerasModel, self).__call__(*args, **kwargs)

def to_device(self, device):
Expand Down

0 comments on commit a385c0b

Please sign in to comment.