diff --git a/python/mpact/models/train.py b/python/mpact/models/train.py index 791edcc..2fdaf68 100644 --- a/python/mpact/models/train.py +++ b/python/mpact/models/train.py @@ -2,6 +2,16 @@ import torch.nn.functional as F +def num_all_parameters(model): + """Returns the number of all parameters in a model.""" + return sum(p.numel() for p in model.parameters()) + + +def num_parameters(model): + """Returns the number of trainable parameters in a model.""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + def training_loop(model, optimizer, loss_function, train, validation, epochs=10): """A rudimentary PyTorch training loop for classification with training and validation data.""" for epoch in range(epochs): @@ -27,7 +37,9 @@ def training_loop(model, optimizer, loss_function, train, validation, epochs=10) output = model(inp) loss = loss_function(output, target) vloss += loss.data.item() - correct = torch.eq(torch.max(F.softmax(output), dim=1)[1], target).view(-1) + correct = torch.eq( + torch.max(F.softmax(output, dim=1), dim=1)[1], target + ).view(-1) num_correct += torch.sum(correct).item() num_total += correct.shape[0] diff --git a/test/python/train_simple.py b/test/python/train_simple.py index dd4b0f3..8658e4d 100644 --- a/test/python/train_simple.py +++ b/test/python/train_simple.py @@ -7,7 +7,7 @@ from mpact.mpactbackend import mpact_jit from mpact.models.kernels import SimpleNet -from mpact.models.train import training_loop +from mpact.models.train import training_loop, num_all_parameters, num_parameters A = torch.tensor( @@ -94,6 +94,12 @@ def __getitem__(self, index): validation = DataLoader(validation_data, batch_size=2) +# CHECK-LABEL: parameters +# CHECK-COUNT-2: 182 +print("parameters") +print(num_all_parameters(net)) +print(num_parameters(net)) + # Run it with PyTorch. # CHECK-LABEL: pytorch # CHECK: Epoch 9