Enabling custom MoE routing distributions at runtime #2497

Closed
Mutinifni opened this issue Nov 25, 2024 · 1 comment
Labels: feature request (New feature or request), triaged (Issue has been triaged by maintainers)

Comments

Mutinifni commented Nov 25, 2024

For MoE models such as Mixtral, I want to override router expert selection to use custom routing distributions (as mentioned in #2331). To run my benchmarks, I am using gptManagerBenchmark, which requires a pre-built engine. However, I want to pass the routing distribution as a runtime config parameter that I can modify across runs without needing to rebuild the engine. This would be similar to the routing string in the MoE layer microbenchmarks, but for the entire model instead of a single layer. Is there any way to do this dynamically? Thanks!

My current approach requires rebuilding the engine for different routing distributions. Specifically, I pass the routing distribution for each layer as a list of probabilities to the MixtureOfExperts module. The distribution is read from an external file.

class MixtureOfExperts(Module):

    def __init__(self,
                 [...]
                 experts_distribution: Optional[List[float]] = None):  # per-expert routing probabilities
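
A minimal sketch of how the per-layer distributions might be read from the external file and handed to each MixtureOfExperts instance at build time; the JSON file name and layout are assumptions for illustration, not something specified in the issue:

import json

import torch

# Hypothetical file layout: one probability list per MoE layer, e.g.
# {"0": [0.25, 0.25, ...], "1": [...], ...}
with open("routing_distributions.json") as f:
    per_layer = json.load(f)

# Normalize each layer's list and keep it as a tensor for sampling later.
experts_distributions = {
    int(layer): torch.tensor(probs, dtype=torch.float32) / sum(probs)
    for layer, probs in per_layer.items()
}

# Each MixtureOfExperts would then receive
# experts_distribution=experts_distributions[layer_idx] at build time,
# which is why every new distribution currently forces an engine rebuild.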

I then override the router output in forward() by sampling according to the experts_distribution:

def adjust_routing_to_distribution(self, routing, experts_distribution, k):
    num_tokens, num_experts = routing.shape
    assert experts_distribution.shape[0] == num_experts, "Distribution size must match number of experts"

    # Sample top-k experts based on the distribution
    adjusted_routing = torch.zeros_like(routing)
    for i in range(num_tokens):
        sampled_experts = torch.multinomial(experts_distribution, k, replacement=False)
        adjusted_routing[i, sampled_experts] = routing[i, sampled_experts]

    # Normalize the adjusted routing probabilities to ensure they sum to 1
    adjusted_routing = adjusted_routing / adjusted_routing.sum(dim=1, keepdim=True)

    return adjusted_routing
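
As a quick sanity check of the sampling logic, the method can be exercised unbound on dummy router outputs; the shapes and the skewed distribution below are purely illustrative, and since the method never reads self, no full MixtureOfExperts instance is needed:

import torch

num_tokens, num_experts, k = 4, 8, 2

# Dummy router output: a softmax over experts for each token.
routing = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)

# Skewed target distribution over the eight experts, normalized to sum to 1.
experts_distribution = torch.tensor([4., 2., 1., 1., 1., 1., 1., 1.])
experts_distribution = experts_distribution / experts_distribution.sum()

adjusted = MixtureOfExperts.adjust_routing_to_distribution(
    None, routing, experts_distribution, k)

# Every token should keep exactly k non-zero expert weights that sum to 1.
assert (adjusted > 0).sum(dim=1).eq(k).all()
assert torch.allclose(adjusted.sum(dim=1), torch.ones(num_tokens))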
Mutinifni changed the title from "Enabling runtime model configs" to "Enabling custom MoE routing distributions at runtime" on Nov 25, 2024
hello-11 added the triaged (Issue has been triaged by maintainers) label on Dec 2, 2024
djns99 added the feature request (New feature or request) label on Dec 4, 2024
djns99 (Collaborator) commented Dec 4, 2024

Hi @Mutinifni, thanks for your request. We don't currently have a plan to implement a generic "fake routing" module like you describe, so I think your manual solution is probably the best approach for now.
To avoid the engine rebuild, the best option is likely to add a new input tensor to the TRT network that replaces the router output and feed it directly into the MoE plugin. You can then set it to whatever distribution you like at runtime. Either way, any approach will require some manual modifications for the time being.
I'll close this issue for now; feel free to comment further if you have anything to add.
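
For concreteness, a rough sketch of that suggestion, following the pattern TensorRT-LLM models use to declare network inputs; the tensor name routing_override, its shape, and how the MoE plugin would consume it are assumptions to adapt, not an existing interface:

import tensorrt as trt
from tensorrt_llm.functional import Tensor

num_experts = 8  # e.g. Mixtral 8x7B

# Extra network input carrying the routing scores for each token,
# declared alongside the other model inputs at build time.
routing_override = Tensor(name='routing_override',
                          dtype=trt.float32,
                          shape=[-1, num_experts])

# Inside MixtureOfExperts.forward(), this tensor would be fed to the MoE
# plugin in place of the router's output. At runtime the benchmark binds
# 'routing_override' to values sampled from the desired
# experts_distribution, so changing the distribution no longer requires
# rebuilding the engine.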
