Hi, I was trying to accomplish the same thing with the same code. Instead, I created a subclass of the ViT class that overrides the forward pass to skip the last two layers, which are only used for classification.
import pytorch_pretrained_vit as ptv
from pytorch_pretrained_vit.model import PositionalEmbedding1D

class EncoderVit(ptv.ViT):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # 576 patches (24 x 24 for a 384 x 384 input with 16 x 16 patches),
        # 768-dim embeddings; no extra slot since the class token is dropped.
        self.positional_embedding = PositionalEmbedding1D(576, 768)

    def forward(self, x):
        x = self.patch_embedding(x).flatten(2).transpose(1, 2)  # (b, 576, 768)
        # x = torch.cat((self.class_token.expand(1, -1, -1), x), dim=1)
        x = self.positional_embedding(x)  # add learned position embeddings
        x = self.transformer(x)           # per-patch encodings, (b, 576, 768)
        return x
I needed this to use the ViT as an encoder, and I'm guessing you do too. Hope this helps!
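For reference, here is a minimal usage sketch. The checkpoint name and the image_size keyword are assumptions on my end, chosen so that a 384 x 384 input with 16 x 16 patches yields the 576 patch tokens hard-coded above:

import torch

# 'B_16_imagenet1k' and image_size=384 are assumptions matching the
# 576-patch positional embedding in EncoderVit above.
encoder = EncoderVit(name='B_16_imagenet1k', pretrained=True, image_size=384)
encoder.eval()

with torch.no_grad():
    feats = encoder(torch.randn(1, 3, 384, 384))
print(feats.shape)  # expected: torch.Size([1, 576, 768])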
I want to extract the output of an intermediate transformer layer. I used the following code, but it does not work:
nn.Sequential(*list(model.children())). How should I do this?
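nn.Sequential(*model.children()) breaks here because ViT's forward does reshaping between its modules (patch embedding, flatten/transpose, class-token concat), so the children cannot simply be chained. A forward hook on the block you want avoids that. A minimal sketch, assuming the model exposes its encoder layers as model.transformer.blocks (as pytorch_pretrained_vit does):

import torch
import pytorch_pretrained_vit as ptv

model = ptv.ViT('B_16_imagenet1k', pretrained=True)
model.eval()

captured = {}

def hook(module, inputs, output):
    # store the output of the hooked transformer block
    captured['block_output'] = output

# hook an intermediate block, e.g. the 6th of 12 (the index is up to you)
handle = model.transformer.blocks[5].register_forward_hook(hook)

with torch.no_grad():
    model(torch.randn(1, 3, 384, 384))  # assuming a 384 x 384 checkpoint

handle.remove()
print(captured['block_output'].shape)  # expected (1, 577, 768): 576 patches + class token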