
p1ch8: "ResBlock" objects are most likely identical #89

Open
aallahyar opened this issue May 10, 2022 · 2 comments

Comments


aallahyar commented May 10, 2022

In the scripts of p1ch8 (section 8.5.3, page 227 to be specific), the convolutional sub-blocks are created with the following code:

self.resblocks = nn.Sequential(
            *(n_blocks * [ResBlock(n_chans=n_chans1)]))

However, since Python lists store references and n_blocks * [...] merely repeats them, every entry of this list points to the same ResBlock instance, so the weight matrices of the "created" ResBlocks are not just identical but literally one and the same.

I think the code should be changed to:

self.resblocks = nn.Sequential(
            *[ResBlock(n_chans=n_chans1) for _ in range(n_blocks)])
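
To see the difference at the Python level (a small sketch, independent of the book's code): multiplying a list repeats references to one module, while a comprehension builds a new module per entry.

import torch.nn as nn

shared = [nn.Linear(4, 4)] * 3                   # three references to a single module
separate = [nn.Linear(4, 4) for _ in range(3)]   # three independent modules

print(shared[0] is shared[1])      # True  -> same object, shared weights
print(separate[0] is separate[1])  # False -> each module owns its weights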

t-vi commented May 10, 2022

Absolutely, thank you for spotting this and reporting it.


ftianRF commented May 19, 2022

@aallahyar Yes.

Here is an example for other readers:

import torch
import torch.nn as nn
import torch.nn.functional as F    # needed for F.max_pool2d below

class ResBlock(nn.Module):    # ResBlock from section 8.5.3 (custom weight init omitted for brevity)
    def __init__(self, n_chans):
        super().__init__()
        self.conv = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1, bias=False)
        self.batch_norm = nn.BatchNorm2d(num_features=n_chans)

    def forward(self, x):
        out = self.conv(x)
        out = self.batch_norm(out)
        out = torch.relu(out)
        return out + x

class NetResDeep(nn.Module):
    def __init__(self, n_chans1=32, n_blocks=10):
        super().__init__()
        self.n_chans1 = n_chans1
        self.conv1 = nn.Conv2d(3, n_chans1, kernel_size=3, padding=1)
        self.resblocks = nn.Sequential(*([ResBlock(n_chans=n_chans1)] * n_blocks))    # shown in the book
        #self.resblocks = nn.Sequential(*[ResBlock(n_chans=n_chans1) for _ in range(n_blocks)])    # the right version
        self.fc1 = nn.Linear(n_chans1 * 8 * 8, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        out = F.max_pool2d(torch.relu(self.conv1(x)), 2)
        out = F.max_pool2d(self.resblocks(out), 2)    # kernel size added so forward() actually runs
        out = out.view(-1, self.n_chans1 * 8 * 8)
        out = torch.relu(self.fc1(out))
        out = self.fc2(out)
        return out

netresdeep = NetResDeep()
id(netresdeep.resblocks[0].conv.weight) == id(netresdeep.resblocks[1].conv.weight)

Output:

True

This result shows that the two ResBlocks share one and the same weight tensor in memory.
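
As a quick counter-check (a minimal sketch reusing the definitions above), building the blocks with a comprehension gives each ResBlock its own tensors, and counting distinct parameter tensors makes the sharing obvious:

fixed = nn.Sequential(*[ResBlock(n_chans=32) for _ in range(10)])
print(fixed[0].conv.weight is fixed[1].conv.weight)       # False: independent weights
print(sum(1 for _ in fixed.parameters()))                 # 30 = 3 parameter tensors per block
print(sum(1 for _ in netresdeep.resblocks.parameters()))  # only 3: .parameters() deduplicates the shared block

In other words, the version shown in the book trains a single ResBlock that is applied ten times, rather than ten independent blocks.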
