Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize Equalization #201

Merged
merged 8 commits into from
Apr 5, 2022

Conversation

quantumalaviya
Copy link
Contributor

Port Equalization to BaseImageAugmentationLayer
closes #192

@quantumalaviya
Copy link
Contributor Author

Currently non-vectorized, tf.vectorized_map throws InvalidArgumentError. I would love any suggestions as to how to resolve this.

This is what happens when I remove self.auto_vectorize = False (basically switch from tf.map_fn to tf.vectorized_map). Here's the stack trace:

InvalidArgumentError                      Traceback (most recent call last)
[<ipython-input-9-882645fcdc75>](https://localhost:8080/#) in <module>()
----> 1 test = E(img)

1 frames
[/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py](https://localhost:8080/#) in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

[/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py](https://localhost:8080/#) in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     53     ctx.ensure_initialized()
     54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55                                         inputs, attrs, num_outputs)
     56   except core._NotOkStatusException as e:
     57     if name is not None:

InvalidArgumentError: Exception encountered when calling layer "equalization" (type Equalization).

Graph execution error:

Detected at node 'loop_body/Where/pfor/TensorListConcatV2' defined at (most recent call last):
    File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
      "__main__", mod_spec)
    File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
      exec(code, run_globals)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py", line 846, in launch_instance
      app.start()
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py", line 499, in start
      self.io_loop.start()
    File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 132, in start
      self.asyncio_loop.run_forever()
    File "/usr/lib/python3.7/asyncio/base_events.py", line 541, in run_forever
      self._run_once()
    File "/usr/lib/python3.7/asyncio/base_events.py", line 1786, in _run_once
      handle._run()
    File "/usr/lib/python3.7/asyncio/events.py", line 88, in _run
      self._context.run(self._callback, *self._args)
    File "/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py", line 122, in _handle_events
      handler_func(fileobj, events)
    File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 452, in _handle_events
      self._handle_recv()
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 481, in _handle_recv
      self._run_callback(callback, msg)
    File "/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py", line 431, in _run_callback
      callback(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py", line 300, in null_wrapper
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
      return self.dispatch_shell(stream, msg)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
      handler(stream, idents, msg)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
      user_expressions, allow_stdin)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
      return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
      interactivity=interactivity, compiler=compiler, result=result)
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
      if self.run_code(code, result):
    File "/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "<ipython-input-9-882645fcdc75>", line 1, in <module>
      test = E(img)
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/base_layer.py", line 1014, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler
      return fn(*args, **kwargs)
    File "/usr/local/lib/python3.7/dist-packages/keras/layers/preprocessing/image_preprocessing.py", line 382, in call
      return self._format_output(self._batch_augment(inputs), is_dict)
    File "/usr/local/lib/python3.7/dist-packages/keras/layers/preprocessing/image_preprocessing.py", line 407, in _batch_augment
      return self._map_fn(self._augment, inputs)
    File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/parallel_for/control_flow_ops.py", line 185, in f
      iters,
Node: 'loop_body/Where/pfor/TensorListConcatV2'
PartialTensorShape: Incompatible shapes during merge: [202,1] vs. [256,1]
	 [[{{node loop_body/Where/pfor/TensorListConcatV2}}]] [Op:__inference_f_1684]

Call arguments received by layer "equalization" (type Equalization):
  • inputs=tf.Tensor(shape=(2, 672, 504, 3), dtype=int32)
  • training=True

@kartik4949
Copy link
Contributor

I can second this, the same error for tf.vectorized_map on my latest mask_to_boxes pr

@qlzh727
Copy link
Member

qlzh727 commented Mar 23, 2022

Thanks for the PR. I think the error might be similar as #146 (comment). Basically there are some intermidiate tensor has the dynamic shape that depends on the value of the input (eg the bucketing part). We should make it have the fixed shape, and then try the vectorized map, which should be way faster than the map_fn.

@bhack
Copy link
Contributor

bhack commented Mar 23, 2022

Thanks for the PR. I think the error might be similar as #146 (comment). Basically there are some intermidiate tensor has the dynamic shape that depends on the value of the input (eg the bucketing part). We should make it have the fixed shape, and then try the vectorized map, which should be way faster than the map_fn.

/cc @wangpengmit

@quantumalaviya
Copy link
Contributor Author

I believe the problem-causing intermediate tensors here are nonzero and nonzero_histogram, right?

@quantumalaviya
Copy link
Contributor Author

Thanks for the help @qlzh727! The function is now vectorized.

@quantumalaviya
Copy link
Contributor Author

quantumalaviya commented Mar 24, 2022

It does give out these warnings while computing but I assume it's not a big problem. Let me know if these need to be fixed, however.

WARNING:tensorflow:Using a while_loop for converting HistogramFixedWidth
WARNING:tensorflow:Using a while_loop for converting HistogramFixedWidth
WARNING:tensorflow:Using a while_loop for converting HistogramFixedWidth

@bhack
Copy link
Contributor

bhack commented Mar 24, 2022

I suppose that tf.histogram_fixed_width Is not supported so we could have some performance drain.

@quantumalaviya
Copy link
Contributor Author

Makes sense, I wonder if designing a supported util for computing histograms will make a huge difference

@bhack
Copy link
Contributor

bhack commented Mar 24, 2022

I think you can open a related ticket on tensorflow to track this with a very minimal example.

@bhack
Copy link
Contributor

bhack commented Mar 24, 2022

We are also trying to give some hints on the root cause of the fallback with tensorflow/tensorflow#55192

keras_cv/layers/preprocessing/equalization.py Outdated Show resolved Hide resolved
keras_cv/layers/preprocessing/equalization.py Outdated Show resolved Hide resolved
@LukeWood
Copy link
Contributor

Thanks for the PR!

keras_cv/layers/preprocessing/equalization.py Outdated Show resolved Hide resolved
keras_cv/layers/preprocessing/equalization.py Outdated Show resolved Hide resolved
@quantumalaviya quantumalaviya changed the title Port Equalization to BaseImageAugmentationLayer Vectorize Equalization Mar 29, 2022
keras_cv/layers/preprocessing/equalization.py Outdated Show resolved Hide resolved
keras_cv/layers/preprocessing/equalization.py Outdated Show resolved Hide resolved
@LukeWood
Copy link
Contributor

LukeWood commented Apr 4, 2022

Apologies for the delay @quantumalaviya I think this looks pretty good. I'm a little concerned about that hard coded value in there but if we can add a test case to ensure it doesn't show up I'll feel pretty confident in it. Can we also include a demo with a side by side image - one augmented one not?

@quantumalaviya
Copy link
Contributor Author

quantumalaviya commented Apr 5, 2022

I agree, I am also not a fan of having the hardcoded value. However, it just serves the purpose of being a really big number to exclude zeroes.

@LukeWood
Copy link
Contributor

LukeWood commented Apr 5, 2022

I agree, I am also not a fan of having the hardcoded value. However, it just serves the purpose of being a really big number to exclude zeroes.

Thanks for the explanation.

@LukeWood
Copy link
Contributor

LukeWood commented Apr 5, 2022

Ok @quantumalaviya - I updated the PR with some copyedits and minor fixes around dtype handling. Also added some new test cases to cover the correctness.

Thanks for the great PR! Good idea to add a big value to mask out 0s. I appreciate the contribution!

@LukeWood
Copy link
Contributor

LukeWood commented Apr 5, 2022

Thanks you @quantumalaviya !!!

@LukeWood LukeWood merged commit 72a774d into keras-team:master Apr 5, 2022
LukeWood added a commit that referenced this pull request Apr 5, 2022
* Port Equalization to BaseImageAugmentationLayer (non-vectorized)

* Vectorizing Equalization

* Vectorize Equalization

* Vectorize Equalization

* added changes

* Update equalization.py

* introduce equalization correctness tests

* Reformat equalization

Co-authored-by: Luke Wood <[email protected]>
LukeWood added a commit that referenced this pull request Apr 5, 2022
* Port Equalization to BaseImageAugmentationLayer (non-vectorized)

* Vectorizing Equalization

* Vectorize Equalization

* Vectorize Equalization

* added changes

* Update equalization.py

* introduce equalization correctness tests

* Reformat equalization

Co-authored-by: Luke Wood <[email protected]>
kartik4949 pushed a commit to kartik4949/keras-cv that referenced this pull request Apr 21, 2022
* Port Equalization to BaseImageAugmentationLayer (non-vectorized)

* Vectorizing Equalization

* Vectorize Equalization

* Vectorize Equalization

* added changes

* Update equalization.py

* introduce equalization correctness tests

* Reformat equalization

Co-authored-by: Luke Wood <[email protected]>
@quantumalaviya quantumalaviya deleted the equalize_to_baseimgaug branch May 2, 2022 11:04
ianstenbit pushed a commit to ianstenbit/keras-cv that referenced this pull request Aug 6, 2022
* Port Equalization to BaseImageAugmentationLayer (non-vectorized)

* Vectorizing Equalization

* Vectorize Equalization

* Vectorize Equalization

* added changes

* Update equalization.py

* introduce equalization correctness tests

* Reformat equalization

Co-authored-by: Luke Wood <[email protected]>
adhadse pushed a commit to adhadse/keras-cv that referenced this pull request Sep 17, 2022
* Port Equalization to BaseImageAugmentationLayer (non-vectorized)

* Vectorizing Equalization

* Vectorize Equalization

* Vectorize Equalization

* added changes

* Update equalization.py

* introduce equalization correctness tests

* Reformat equalization

Co-authored-by: Luke Wood <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Migrate Equalization to use BaseImageAugmentationLayer
6 participants