#12461: Torch reference implementation of PointNet model
HariniMohan0102 committed Sep 10, 2024
1 parent 723bbd4 commit ad7ea6c
Showing 5 changed files with 196 additions and 0 deletions.
32 changes: 32 additions & 0 deletions models/experimental/functional_pointnet/reference/PointNetDenseCls.py
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.experimental.functional_pointnet.reference.PointNetfeat import PointNetfeat


class PointNetDenseCls(nn.Module):
def __init__(self, k=2, feature_transform=False):
super(PointNetDenseCls, self).__init__()
self.k = k
self.feature_transform = feature_transform
self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
self.conv1 = torch.nn.Conv1d(1088, 512, 1)
self.conv2 = torch.nn.Conv1d(512, 256, 1)
self.conv3 = torch.nn.Conv1d(256, 128, 1)
self.conv4 = torch.nn.Conv1d(128, self.k, 1)
self.bn1 = nn.BatchNorm1d(512)
self.bn2 = nn.BatchNorm1d(256)
self.bn3 = nn.BatchNorm1d(128)

def forward(self, x):
batchsize = x.size()[0]
n_pts = x.size()[2]
x, trans, trans_feat = self.feat(x)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = self.conv4(x)
x = x.transpose(2, 1).contiguous()
x = F.log_softmax(x.view(-1, self.k), dim=-1)
x = x.view(batchsize, n_pts, self.k)
return x, trans, trans_feat
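
PointNetDenseCls is the per-point segmentation head: PointNetfeat supplies 1088-channel per-point features (the 64-dim local features concatenated with the repeated 1024-dim global vector), four 1x1 convolutions reduce them to k class scores, and a log-softmax is applied per point. A minimal shape check, not part of the commit, using the module path from the test's imports:

import torch
from models.experimental.functional_pointnet.reference.PointNetDenseCls import PointNetDenseCls

# Untrained module in eval mode; this only verifies tensor shapes.
model = PointNetDenseCls(k=4, feature_transform=False).eval()
points = torch.randn(2, 3, 1024)  # (batch, xyz, n_points)
with torch.no_grad():
    scores, trans, trans_feat = model(points)
print(scores.shape)  # torch.Size([2, 1024, 4]) -- per-point log-probabilities over k classes
print(trans.shape)   # torch.Size([2, 3, 3])    -- input transform predicted by STN3d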
48 changes: 48 additions & 0 deletions models/experimental/functional_pointnet/reference/PointNetfeat.py
@@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.experimental.functional_pointnet.reference.STN3d import STN3d
from models.experimental.functional_pointnet.reference.STNkd import STNkd


class PointNetfeat(nn.Module):
def __init__(self, global_feat=True, feature_transform=False):
super(PointNetfeat, self).__init__()
self.stn = STN3d()
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.global_feat = global_feat
self.feature_transform = feature_transform
if self.feature_transform:
self.fstn = STNkd(k=64)

def forward(self, x):
n_pts = x.size()[2]
trans = self.stn(x)
x = x.transpose(2, 1)
x = torch.bmm(x, trans)
x = x.transpose(2, 1)
x = F.relu(self.bn1(self.conv1(x)))

if self.feature_transform:
trans_feat = self.fstn(x)
x = x.transpose(2, 1)
x = torch.bmm(x, trans_feat)
x = x.transpose(2, 1)
else:
trans_feat = None

pointfeat = x
x = F.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
if self.global_feat:
return x, trans, trans_feat
else:
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
return torch.cat([x, pointfeat], 1), trans, trans_feat
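
PointNetfeat is the shared encoder: STN3d aligns the input points, three 1x1 convolutions lift them to 1024 channels, and a max over points yields a global descriptor. With global_feat=True it returns the 1024-dim global vector (classification use); with global_feat=False it repeats that vector across points and concatenates it with the 64-dim per-point features, giving the 1088 channels consumed by PointNetDenseCls. A hedged shape sketch, not part of the commit:

import torch
from models.experimental.functional_pointnet.reference.PointNetfeat import PointNetfeat

pts = torch.randn(2, 3, 256)
global_enc = PointNetfeat(global_feat=True).eval()
dense_enc = PointNetfeat(global_feat=False, feature_transform=True).eval()
with torch.no_grad():
    g, _, _ = global_enc(pts)
    d, _, tf = dense_enc(pts)
print(g.shape)   # torch.Size([2, 1024])      -- one global descriptor per cloud
print(d.shape)   # torch.Size([2, 1088, 256]) -- global + per-point features
print(tf.shape)  # torch.Size([2, 64, 64])    -- feature transform from STNkd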
46 changes: 46 additions & 0 deletions models/experimental/functional_pointnet/reference/STN3d.py
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F


class STN3d(nn.Module):
def __init__(self):
super(STN3d, self).__init__()
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 9)
self.relu = nn.ReLU()

self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.bn4 = nn.BatchNorm1d(512)
self.bn5 = nn.BatchNorm1d(256)

def forward(self, x):
batchsize = x.size()[0]
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)

x = F.relu(self.bn4(self.fc1(x)))
x = F.relu(self.bn5(self.fc2(x)))
x = self.fc3(x)

iden = (
Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32)))
.view(1, 9)
.repeat(batchsize, 1)
)
if x.is_cuda:
iden = iden.cuda()
x = x + iden
x = x.view(-1, 3, 3)
return x
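
STN3d regresses a 3x3 spatial transform for the raw point cloud. The flattened identity matrix is added to the fully connected output, so the network only has to learn an offset from the identity; Variable is a legacy wrapper and a no-op in current PyTorch. A small sketch, not part of the commit, that makes the identity behaviour explicit by zeroing the final layer:

import torch
from models.experimental.functional_pointnet.reference.STN3d import STN3d

stn = STN3d().eval()
# With fc3 zeroed, the regressed offset is zero and only the added identity remains.
torch.nn.init.zeros_(stn.fc3.weight)
torch.nn.init.zeros_(stn.fc3.bias)
with torch.no_grad():
    t = stn(torch.randn(2, 3, 500))
print(t.shape)                                          # torch.Size([2, 3, 3])
print(torch.allclose(t, torch.eye(3).expand(2, 3, 3)))  # True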
48 changes: 48 additions & 0 deletions models/experimental/functional_pointnet/reference/STNkd.py
@@ -0,0 +1,48 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F


class STNkd(nn.Module):
def __init__(self, k=64):
super(STNkd, self).__init__()
self.conv1 = torch.nn.Conv1d(k, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k * k)
self.relu = nn.ReLU()

self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.bn4 = nn.BatchNorm1d(512)
self.bn5 = nn.BatchNorm1d(256)

self.k = k

def forward(self, x):
batchsize = x.size()[0]
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)

x = F.relu(self.bn4(self.fc1(x)))
x = F.relu(self.bn5(self.fc2(x)))
x = self.fc3(x)

iden = (
Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32)))
.view(1, self.k * self.k)
.repeat(batchsize, 1)
)
if x.is_cuda:
iden = iden.cuda()
x = x + iden
x = x.view(-1, self.k, self.k)
return x
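
STNkd generalizes the same idea to a k x k transform over intermediate features (k=64 when instantiated by PointNetfeat). During training the returned trans_feat is usually kept near-orthogonal by a regularization term; that regularizer is not part of this commit, but a common formulation from PointNet implementations, shown here only as a hedged sketch, is:

import torch

def feature_transform_regularizer(trans):
    # Mean Frobenius norm of (A A^T - I) over the batch; penalizes non-orthogonal transforms.
    d = trans.size(1)
    identity = torch.eye(d, device=trans.device).unsqueeze(0)
    return torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - identity, dim=(1, 2)))

In the original PointNet training code this term is added to the classification loss with a small weight (e.g. 0.001).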
@@ -0,0 +1,22 @@
import torch
import pytest
from torch.autograd import Variable
from models.experimental.functional_pointnet.reference.PointNetDenseCls import PointNetDenseCls
from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
def test_pointnet_model(device, reset_seeds):
input = torch.randn(32, 3, 2500, requires_grad=True)
reference_model = PointNetDenseCls(k=3)

new_state_dict = {}
keys = [name for name, parameter in reference_model.state_dict().items()]
ds_state_dict = {k: v for k, v in reference_model.state_dict().items()}
values = [parameter for name, parameter in ds_state_dict.items()]
for i in range(len(keys)):
new_state_dict[keys[i]] = values[i]
reference_model.load_state_dict(new_state_dict)
reference_model.eval()

output, _, _ = reference_model(input)
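
The test builds the reference model, round-trips its own state dict, and runs a forward pass on a (32, 3, 2500) cloud; it does not yet assert anything, and assert_with_pcc is imported but unused. A hedged sketch of the PCC-style comparison this would enable against a future ttnn implementation (ttnn_output is hypothetical, and the assert_with_pcc signature is assumed from other ttnn tests rather than confirmed here):

import torch

def pcc(expected, actual):
    # Pearson correlation between flattened tensors; 1.0 means a perfect match.
    e = expected.detach().flatten().double()
    a = actual.detach().flatten().double()
    return torch.corrcoef(torch.stack([e, a]))[0, 1].item()

# Once a ttnn PointNet exists, the check might look like:
# assert pcc(output, ttnn_output) > 0.99
# or, with the repo helper (signature assumed): assert_with_pcc(output, ttnn_output, 0.99)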
