Generating a model without functions? #1935

Open
noahcoolboy opened this issue Nov 7, 2024 · 1 comment

@noahcoolboy

Hello! I've been trying to port a model from PyTorch to ONNX manually using onnxscript.
I've tried to come up with an elegant way of doing this by creating "custom blocks" with attributes.
However, because of how onnxscript currently works, there are some issues.

This is my current code:

```python
from onnxscript import FLOAT, script
from onnxscript import opset18 as op

# Context not shown here: `weights` is the PyTorch model's state_dict,
# and `Identity` is a block defined elsewhere.


def GConv2D(key: str, kernel_size: int, padding: int):
    # Weights are loaded in Python when GConv2D(...) is called,
    # before the script function below is defined.
    weight = weights[key + ".weight"].numpy()
    bias = weights[key + ".bias"].numpy()

    @script()
    def GConv2D(r: FLOAT[...]):
        r = op.Conv(
            r,
            weight,
            bias,
            kernel_shape=[kernel_size, kernel_size],
            pads=[padding, padding, padding, padding],
        )

        return r

    return GConv2D


def GroupResBlock(key: str, in_dim: int, out_dim: int):
    # The in_dim/out_dim check runs in Python, so it never appears in the model.
    downsample = GConv2D(key + ".downsample", 1, 0) if in_dim != out_dim else Identity()
    conv1 = GConv2D(key + ".conv1", 3, 1)
    conv2 = GConv2D(key + ".conv2", 3, 1)

    # Every call produces a script function with the same name, "GroupResBlock".
    @script()
    def GroupResBlock(x: FLOAT[...]):
        x = conv1(op.Relu(x))
        x = conv2(op.Relu(x))
        x = downsample(x)
        return x

    return GroupResBlock


def MaskDecoderBlock(key: str):
    up_16_8 = GroupResBlock(key + ".up_16_8.out_conv", 256, 128)
    up_8_4 = GroupResBlock(key + ".up_8_4.out_conv", 128, 128)

    @script()
    def MaskDecoderBlock(x: FLOAT[...]):
        x = up_16_8(x)
        x = up_8_4(x)
        return x

    return MaskDecoderBlock


model = MaskDecoderBlock("mask_decoder").to_model_proto()
```

`downsample` in GroupResBlock is set conditionally: I want it to downsample only if in_dim and out_dim differ. To avoid putting this if statement in the model itself, the check is done beforehand in Python so the result is baked into the model as is.

The issue is that when up_16_8 is created, the function GroupResBlock gets defined with the downsample block. When up_8_4 is created, a function named GroupResBlock is already defined, so it gets reused (with the downsample block and the wrong weights!).
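The collision can be modeled in a few lines of plain Python. This is a toy, not onnxscript's implementation: it only relies on the fact that ONNX nodes refer to model-local functions by domain and name, so two `GroupResBlock` variants built with different `in_dim`/`out_dim` cannot both live under one name. `ToyModel`, `make_block` and the `"my.domain"` string are hypothetical, and keeping the first definition is a choice made to mirror the behavior reported above.

```python
from dataclasses import dataclass, field


@dataclass
class ToyModel:
    """Toy stand-in for a model's function table (NOT onnxscript internals)."""
    functions: dict = field(default_factory=dict)

    def register(self, domain, name, body):
        # Functions are keyed by (domain, name); a second definition with the
        # same key is ignored here, mirroring the reuse described in the issue.
        self.functions.setdefault((domain, name), body)

    def call(self, domain, name):
        return self.functions[(domain, name)]


def make_block(in_dim, out_dim):
    # Like the issue's code: the Python-side check decides the body,
    # but the generated function is always named "GroupResBlock".
    body = ["conv1", "conv2"] + (["downsample"] if in_dim != out_dim else [])
    return "GroupResBlock", body


model = ToyModel()
for dims in [(256, 128), (128, 128)]:  # up_16_8, then up_8_4
    name, body = make_block(*dims)
    model.register("my.domain", name, body)

print(model.call("my.domain", "GroupResBlock"))  # ['conv1', 'conv2', 'downsample']
```

Both blocks resolve to the first body, which is why up_8_4 ends up with a downsample step it should not have.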

Is there a way to generate a model proto without functions, so that blocks are not reused and a flat graph is generated instead?

@justinchuby
Collaborator

We have plans to build a non-script mode, but that's not done yet. In the meantime, you may take a look at ONNX IR (https://github.com/microsoft/onnxscript/tree/main/onnxscript/ir) and at how we use it in https://github.com/pytorch/pytorch/blob/5f4a21dc58c7b0a732ae0dec8fdbf2dfbda4e7d5/torch/onnx/_internal/exporter/_building.py#L498 to capture the op calls and construct the graph with ONNX IR.
