Slower inference time when using grouped convolutions compared to regular convolutions #20471
Comments
I am facing a similar issue.
cuDNN is highly optimized for specific operations and configurations, and it performs best when layer parameters align with its pre-tuned algorithms. When values such as the number of groups in a convolution layer are changed, the algorithm cuDNN selects may no longer be optimal for that configuration, and performance can degrade. Have you tried the JAX backend? What are you observing there?
How do I know which values of the number of groups align well with cuDNN? On the JAX backend the difference is not as stark, but the grouped convolutions are still slower:
The code that I used to produce these numbers:
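Essentially the same benchmark as in the reproduction script below, just with the backend forced to JAX before Keras is imported:

```python
# Force the JAX backend; this must happen before `import keras`.
import os
os.environ["KERAS_BACKEND"] = "jax"

import keras
print(keras.backend.backend())  # sanity check: should print "jax"
```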
As the title says, for some parameter values the inference time for grouped convolutions is slower than for regular convolutions (i.e. number of groups = 1). Standalone code to reproduce the issue:
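A minimal sketch along these lines (the shapes, batch size, and iteration count here are illustrative placeholders, not my exact values):

```python
import time

import numpy as np
import keras

def bench(groups, iters=50):
    # One grouped convolution; groups=1 is a regular convolution.
    # Both the input channels (64) and the filters (64) must be
    # divisible by `groups`.
    model = keras.Sequential([
        keras.Input(shape=(64, 64, 64)),
        keras.layers.Conv2D(64, kernel_size=3, groups=groups, padding="same"),
    ])
    x = np.random.rand(8, 64, 64, 64).astype("float32")
    model.predict(x, verbose=0)  # warm-up so compilation/autotuning is not timed
    start = time.perf_counter()
    for _ in range(iters):
        model.predict(x, verbose=0)
    return (time.perf_counter() - start) / iters

for g in (1, 2):
    print(f"groups={g}: {bench(g) * 1000:.2f} ms per batch")
```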
The output on my setup:
The inference time is 25 times slower when I use number of groups = 2, even though grouping should cut the required number of FLOPs in half.
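For context on the FLOP claim: in a grouped convolution each output channel only convolves C_in / groups input channels, so the standard cost estimate scales as 1 / groups. A back-of-the-envelope check (my own estimate, not profiler output):

```python
# Rough FLOP count for a 2D convolution producing an H x W x C_out output
# from C_in input channels with a K x K kernel split into `groups` groups.
def conv2d_flops(h, w, c_in, c_out, k, groups=1):
    # Each output channel sees only c_in // groups input channels;
    # the factor 2 counts one multiply and one add per kernel tap.
    return 2 * h * w * c_out * (c_in // groups) * k * k

print(conv2d_flops(64, 64, 64, 64, 3, groups=1))  # regular convolution
print(conv2d_flops(64, 64, 64, 64, 3, groups=2))  # exactly half the FLOPs
```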
If I use 3D convolutions, the difference is even larger, and it throws an additional XLA warning:
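The 3D case is the same benchmark with Conv3D (again, the shapes and counts are illustrative placeholders):

```python
import time

import numpy as np
import keras

def bench3d(groups, iters=20):
    # One grouped 3D convolution; groups=1 is a regular Conv3D.
    model = keras.Sequential([
        keras.Input(shape=(32, 32, 32, 32)),
        keras.layers.Conv3D(32, kernel_size=3, groups=groups, padding="same"),
    ])
    x = np.random.rand(4, 32, 32, 32, 32).astype("float32")
    model.predict(x, verbose=0)  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        model.predict(x, verbose=0)
    return (time.perf_counter() - start) / iters

for g in (1, 2):
    print(f"groups={g}: {bench3d(g) * 1000:.2f} ms per batch")
```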
The output:
Some of the system parameters, in case they are needed: