From 04fe7a1cc12da7ab5b430730c485b889bd10c875 Mon Sep 17 00:00:00 2001 From: Robin Manhaeve Date: Wed, 3 Jul 2024 12:23:21 +0200 Subject: [PATCH] Small update --- .../tests/test_neural_predicate.py | 25 ++++++++++++++++--- src/deepproblog/utils/standard_networks.py | 21 +++++++++++++--- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/deepproblog/tests/test_neural_predicate.py b/src/deepproblog/tests/test_neural_predicate.py index be143c6..156a3c0 100644 --- a/src/deepproblog/tests/test_neural_predicate.py +++ b/src/deepproblog/tests/test_neural_predicate.py @@ -1,15 +1,18 @@ import pytest -from deepproblog.utils.standard_networks import DummyNet -from problog.logic import Term, Var +from deepproblog.utils.standard_networks import DummyNet, DummyTensorNet +from problog.logic import Term, Var, Constant from deepproblog.engines import ExactEngine, ApproximateEngine from deepproblog.model import Model from deepproblog.query import Query from deepproblog.network import Network +import torch + program = """ nn(dummy1,[X],Y,[a,b,c]) :: net1(X,Y). nn(dummy2,[X]) :: net2(X). nn(dummy3,[X],Y) :: net3(X,Y). +nn(dummy4,[X,Y],Z,[a,b]) :: net4(X,Y,Z). test1(X1,Y1,X2,Y2) :- net1(X1,Y1), net1(X2,Y2). test2(X1,X2) :- net2(X1), net2(X2). @@ -25,6 +28,10 @@ dummy_values3 = {Term("i1"): [1.0, 2.0, 3.0, 4.0], Term("i2"): [-1.0, 0.0, 1.0]} dummy_net3 = Network(DummyNet(dummy_values3), "dummy3") +dummy_net4 = Network(DummyTensorNet(batching=True), "dummy4", batching=True) + +tensors = {(Constant(0),): torch.Tensor([0.2]), (Constant(1),): torch.Tensor([0.8])} + @pytest.fixture( params=[ @@ -43,9 +50,10 @@ def model(request) -> Model: """Simple fixture creating both the approximate and the exact engine""" if ApproximateEngine is None and request.param["name"] == "approximate": pytest.skip("ApproximateEngine is not available as PySWIP is not installed") - model = Model(program, [dummy_net1, dummy_net2, dummy_net3], load=False) + model = Model(program, [dummy_net1, dummy_net2, dummy_net3, dummy_net4], load=False) engine = request.param["engine_factory"](model) model.set_engine(engine, cache=request.param["cache"]) + model.add_tensor_source('dummy', tensors) return model @@ -99,3 +107,14 @@ def test_det_network_substitution(model: Model): r2 = model.get_tensor(r2) assert all(r1.detach().numpy() == [1.0, 2.0, 3.0, 4.0]) assert all(r2.detach().numpy() == [-1.0, 0.0, 1.0]) + +def test_double_input(model: Model): + terms = lambda x: Term("net4", + Term("tensor",Term("dummy", Constant(0))), + Term("tensor",Term("dummy", Constant(1))), + x) + results = model.solve([Query(terms(Var("X")))]) + r1 = float(results[0].result[terms(Term("a"))]) + r2 = float(results[0].result[terms(Term("b"))]) + assert pytest.approx(0.2) == r1 + assert pytest.approx(0.8) == r2 \ No newline at end of file diff --git a/src/deepproblog/utils/standard_networks.py b/src/deepproblog/utils/standard_networks.py index 0c53b99..d00bfc5 100644 --- a/src/deepproblog/utils/standard_networks.py +++ b/src/deepproblog/utils/standard_networks.py @@ -32,15 +32,30 @@ def forward(self, x): class DummyNet(nn.Module): - def __init__(self, values: Dict[Term, Union[list, torch.Tensor]]): + def __init__(self, values: Dict[Union[Term, tuple[Term, ...]], Union[list, torch.Tensor]]): super().__init__() self.values = values - def forward(self, x): - output = self.values[x] + def forward(self, *x: Term): + if len(x) == 1: + output = self.values[x[0]] + else: + output = self.values[x] return torch.tensor(output, requires_grad=True) +class DummyTensorNet(nn.Module): + def __init__(self, batching=False): + super().__init__() + self.batching = batching + + def forward(self, *x: torch.Tensor): + if self.batching: + return torch.stack([torch.tensor(y, requires_grad=True) for y in x], dim=0) + else: + return torch.tensor(x, requires_grad=True) + + class SmallNet(nn.Module): def __init__(self, num_classes=1000, size=None): super(SmallNet, self).__init__()