Fixed error with multi-input neural networks and batches.
Also added a test for this case.
smistad committed Dec 19, 2023
1 parent ca375a6 commit 6d2c81c
Showing 2 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion source/FAST/Algorithms/NeuralNetwork/NeuralNetwork.cpp
@@ -179,7 +179,7 @@ std::unordered_map<std::string, Tensor::pointer> NeuralNetwork::processInputData
 
             if(m_batchSize == -1) {
                 m_batchSize = dataList.getSize();
-            } else {
+            } else if(m_batchSize != dataList.getSize()) {
                 throw Exception("Inconsistent batch size accross input nodes");
             }
         } else {
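
To illustrate the change: the old batch-size check used a plain else, so the second (and every later) input node threw the exception even when its batch size matched the one set by the first node. Below is a minimal, self-contained sketch of the corrected logic; the function name resolveBatchSize and the std::vector<int> input are hypothetical stand-ins for FAST's per-node data lists, not the library's actual API.

    // Minimal sketch of the fixed batch-size check (hypothetical helper, not FAST's API).
    #include <stdexcept>
    #include <vector>

    int resolveBatchSize(const std::vector<int>& perNodeBatchSizes) {
        int batchSize = -1; // -1 means the batch size has not been determined yet
        for(int size : perNodeBatchSizes) {
            if(batchSize == -1) {
                batchSize = size; // the first input node defines the batch size
            } else if(batchSize != size) {
                // Before the fix this branch was a plain `else`, so any second
                // input node triggered the exception even with a matching size.
                throw std::runtime_error("Inconsistent batch size across input nodes");
            }
        }
        return batchSize;
    }

With two input nodes each carrying a batch of size 2 (as in the new test below), the old check threw on the second node; the fixed check accepts it and only rejects genuinely mismatched batch sizes.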
25 changes: 25 additions & 0 deletions source/FAST/Algorithms/NeuralNetwork/Tests.cpp
@@ -152,6 +152,31 @@ TEST_CASE("Multi input single output network", "[fast][neuralnetwork]") {
     }
 }
 
+TEST_CASE("Multi input single output network with batch", "[fast][neuralnetwork]") {
+    for(auto& engine : InferenceEngineManager::getEngineList()) {
+        auto importer = ImageFileImporter::New();
+        importer->setFilename(Config::getTestDataPath() + "US/JugularVein/US-2D_0.mhd");
+        auto image = importer->runAndGetOutputData<Image>();
+        auto batch1 = Batch::create({image, image});
+        auto batch2 = Batch::create({image, image});
+
+        auto network = NeuralNetwork::New();
+        network->setInferenceEngine(engine);
+        network->load(join(Config::getTestDataPath(),
+                           "NeuralNetworkModels/multi_input_single_output." +
+                           getModelFileExtension(network->getInferenceEngine()->getPreferredModelFormat())));
+        network->connect(0, batch1);
+        network->connect(1, batch2);
+        auto batch = network->runAndGetOutputData<Batch>();
+        auto list = batch->get().getTensors();
+        REQUIRE(list.size() == 2);
+        auto data = list[0];
+        // We are expecting a tensor output with dimensions (6)
+        REQUIRE(data->getShape().getDimensions() == 1);
+        CHECK(data->getShape()[0] == 6);
+    }
+}
+
 TEST_CASE("Single input multi output network", "[fast][neuralnetwork]") {
     for(auto& engine : InferenceEngineManager::getEngineList()) {
 #ifdef WIN32
