diff --git a/deepxde/data/quadruple.py b/deepxde/data/quadruple.py index 2e9fc0ee4..efbe91767 100644 --- a/deepxde/data/quadruple.py +++ b/deepxde/data/quadruple.py @@ -94,7 +94,7 @@ def train_next_batch(self, batch_size=None): self.train_x[0][indices_branch], self.train_x[1][indices_branch], self.train_x[2][indices_trunk], - ), self.train_y[indices_branch, indices_trunk] + ), self.train_y[indices_branch][:, indices_trunk] def test(self): return self.test_x, self.test_y