
Commit

Small update
Robin Manhaeve committed Jul 3, 2024
1 parent ab0f6bd commit 04fe7a1
Showing 2 changed files with 40 additions and 6 deletions.
25 changes: 22 additions & 3 deletions src/deepproblog/tests/test_neural_predicate.py
@@ -1,15 +1,18 @@
 import pytest
-from deepproblog.utils.standard_networks import DummyNet
-from problog.logic import Term, Var
+from deepproblog.utils.standard_networks import DummyNet, DummyTensorNet
+from problog.logic import Term, Var, Constant
 from deepproblog.engines import ExactEngine, ApproximateEngine
 from deepproblog.model import Model
 from deepproblog.query import Query
 from deepproblog.network import Network
 
+import torch
+
 program = """
 nn(dummy1,[X],Y,[a,b,c]) :: net1(X,Y).
 nn(dummy2,[X]) :: net2(X).
 nn(dummy3,[X],Y) :: net3(X,Y).
+nn(dummy4,[X,Y],Z,[a,b]) :: net4(X,Y,Z).
 test1(X1,Y1,X2,Y2) :- net1(X1,Y1), net1(X2,Y2).
 test2(X1,X2) :- net2(X1), net2(X2).
@@ -25,6 +28,10 @@
 dummy_values3 = {Term("i1"): [1.0, 2.0, 3.0, 4.0], Term("i2"): [-1.0, 0.0, 1.0]}
 dummy_net3 = Network(DummyNet(dummy_values3), "dummy3")
 
+dummy_net4 = Network(DummyTensorNet(batching=True), "dummy4", batching=True)
+
+tensors = {(Constant(0),): torch.Tensor([0.2]), (Constant(1),): torch.Tensor([0.8])}
+
 
 @pytest.fixture(
     params=[
@@ -43,9 +50,10 @@ def model(request) -> Model:
     """Simple fixture creating both the approximate and the exact engine"""
     if ApproximateEngine is None and request.param["name"] == "approximate":
        pytest.skip("ApproximateEngine is not available as PySWIP is not installed")
-    model = Model(program, [dummy_net1, dummy_net2, dummy_net3], load=False)
+    model = Model(program, [dummy_net1, dummy_net2, dummy_net3, dummy_net4], load=False)
     engine = request.param["engine_factory"](model)
     model.set_engine(engine, cache=request.param["cache"])
+    model.add_tensor_source('dummy', tensors)
     return model
 
 
@@ -99,3 +107,14 @@ def test_det_network_substitution(model: Model):
     r2 = model.get_tensor(r2)
     assert all(r1.detach().numpy() == [1.0, 2.0, 3.0, 4.0])
     assert all(r2.detach().numpy() == [-1.0, 0.0, 1.0])
+
+def test_double_input(model: Model):
+    terms = lambda x: Term("net4",
+                           Term("tensor", Term("dummy", Constant(0))),
+                           Term("tensor", Term("dummy", Constant(1))),
+                           x)
+    results = model.solve([Query(terms(Var("X")))])
+    r1 = float(results[0].result[terms(Term("a"))])
+    r2 = float(results[0].result[terms(Term("b"))])
+    assert pytest.approx(0.2) == r1
+    assert pytest.approx(0.8) == r2
21 changes: 18 additions & 3 deletions src/deepproblog/utils/standard_networks.py
@@ -32,15 +32,30 @@ def forward(self, x):
 
 
 class DummyNet(nn.Module):
-    def __init__(self, values: Dict[Term, Union[list, torch.Tensor]]):
+    def __init__(self, values: Dict[Union[Term, tuple[Term, ...]], Union[list, torch.Tensor]]):
         super().__init__()
         self.values = values
 
-    def forward(self, x):
-        output = self.values[x]
+    def forward(self, *x: Term):
+        if len(x) == 1:
+            output = self.values[x[0]]
+        else:
+            output = self.values[x]
         return torch.tensor(output, requires_grad=True)
 
 
+class DummyTensorNet(nn.Module):
+    def __init__(self, batching=False):
+        super().__init__()
+        self.batching = batching
+
+    def forward(self, *x: torch.Tensor):
+        if self.batching:
+            return torch.stack([torch.tensor(y, requires_grad=True) for y in x], dim=0)
+        else:
+            return torch.tensor(x, requires_grad=True)
+
+
 class SmallNet(nn.Module):
     def __init__(self, num_classes=1000, size=None):
         super(SmallNet, self).__init__()
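As a side note (not part of the commit), here is a minimal usage sketch of the new DummyTensorNet, assuming it is imported from deepproblog.utils.standard_networks and fed arbitrary illustrative tensors: with batching=True it simply echoes its inputs, stacked along a new batch dimension.

import torch
from deepproblog.utils.standard_networks import DummyTensorNet

# With batching enabled, the input tensors are stacked along dim 0.
net = DummyTensorNet(batching=True)
out = net(torch.tensor([0.2]), torch.tensor([0.8]))
print(out.shape)  # torch.Size([2, 1])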
