diff --git a/PhysicsTools/TensorFlow/BuildFile.xml b/PhysicsTools/TensorFlow/BuildFile.xml
index 90d48104241aa..0fa15c43a0f7b 100644
--- a/PhysicsTools/TensorFlow/BuildFile.xml
+++ b/PhysicsTools/TensorFlow/BuildFile.xml
@@ -3,6 +3,7 @@
+
diff --git a/PhysicsTools/TensorFlow/interface/TensorFlow.h b/PhysicsTools/TensorFlow/interface/TensorFlow.h
index edb6c86f9bc4f..371c49d4aad04 100644
--- a/PhysicsTools/TensorFlow/interface/TensorFlow.h
+++ b/PhysicsTools/TensorFlow/interface/TensorFlow.h
@@ -25,6 +25,8 @@
namespace tensorflow {
+ enum class Backend { cpu, cuda, rocm, intel, best };
+
typedef std::pair<std::string, Tensor> NamedTensor;
typedef std::vector<NamedTensor> NamedTensorList;
@@ -39,6 +41,10 @@ namespace tensorflow {
// since the threading configuration is done per run() call as of 2.1
void setThreading(SessionOptions& sessionOptions, int nThreads, const std::string& singleThreadPool);
+ // sets the backend option (cpu, cuda, rocm, intel, or best)
+ // for GPU backends, memory is set to "allow_growth" to avoid TF allocating all the GPU memory at once
+ void setBackend(SessionOptions& sessionOptions, Backend backend = Backend::cpu);
+
// loads a meta graph definition saved at exportDir using the SavedModel interface for a tag and
// predefined sessionOptions
// transfers ownership
@@ -52,11 +58,13 @@ namespace tensorflow {
// transfers ownership
MetaGraphDef* loadMetaGraphDef(const std::string& exportDir,
const std::string& tag = kSavedModelTagServe,
+ Backend backend = Backend::cpu,
int nThreads = 1);
// deprecated in favor of loadMetaGraphDef
MetaGraphDef* loadMetaGraph(const std::string& exportDir,
const std::string& tag = kSavedModelTagServe,
+ Backend backend = Backend::cpu,
int nThreads = 1);
// loads a graph definition saved as a protobuf file at pbFile
@@ -67,9 +75,9 @@ namespace tensorflow {
// transfers ownership
Session* createSession(SessionOptions& sessionOptions);
- // return a new, empty session with nThreads
+ // return a new, empty session with nThreads and selected backend
// transfers ownership
- Session* createSession(int nThreads = 1);
+ Session* createSession(Backend backend = Backend::cpu, int nThreads = 1);
// return a new session that will contain an already loaded meta graph whose exportDir must be
// given in order to load and initialize the variables, sessionOptions are predefined
@@ -83,7 +91,10 @@ namespace tensorflow {
// in order to load and initialize the variables, threading options are inferred from nThreads
// an error is thrown when metaGraphDef is a nullptr or when the graph has no nodes
// transfers ownership
- Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads = 1);
+ Session* createSession(const MetaGraphDef* metaGraphDef,
+ const std::string& exportDir,
+ Backend backend = Backend::cpu,
+ int nThreads = 1);
// return a new session that will contain an already loaded graph def, sessionOptions are predefined
// an error is thrown when graphDef is a nullptr or when the graph has no nodes
@@ -94,7 +105,7 @@ namespace tensorflow {
// inferred from nThreads
// an error is thrown when graphDef is a nullptr or when the graph has no nodes
// transfers ownership
- Session* createSession(const GraphDef* graphDef, int nThreads = 1);
+ Session* createSession(const GraphDef* graphDef, Backend backend = Backend::cpu, int nThreads = 1);
// closes a session, calls its destructor, resets the pointer, and returns true on success
bool closeSession(Session*& session);
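For illustration only (not part of this diff): with the declarations above, a client can pick the backend when creating a session. A minimal sketch, assuming a hypothetical helper `runInference` and an illustrative `graphPath`:

#include <string>
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

void runInference(const std::string& graphPath) {
  // load the frozen graph definition from a protobuf file
  tensorflow::GraphDef* graphDef = tensorflow::loadGraphDef(graphPath);

  // create a session on the best available backend (NVIDIA GPU if present, CPU otherwise), single-threaded
  tensorflow::Session* session = tensorflow::createSession(graphDef, tensorflow::Backend::best, 1);

  // ... fill input tensors and call tensorflow::run(...) here ...

  // cleanup: close the session and delete the graph definition
  tensorflow::closeSession(session);
  delete graphDef;
}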
diff --git a/PhysicsTools/TensorFlow/plugins/TfGraphDefProducer.cc b/PhysicsTools/TensorFlow/plugins/TfGraphDefProducer.cc
index a749a9f70bfde..2468a74737a6f 100644
--- a/PhysicsTools/TensorFlow/plugins/TfGraphDefProducer.cc
+++ b/PhysicsTools/TensorFlow/plugins/TfGraphDefProducer.cc
@@ -48,7 +48,7 @@ TfGraphDefProducer::TfGraphDefProducer(const edm::ParameterSet& iConfig)
// ------------ method called to produce the data ------------
TfGraphDefProducer::ReturnType TfGraphDefProducer::produce(const TfGraphRecord& iRecord) {
auto* graph = tensorflow::loadGraphDef(filename_);
- return std::make_unique<TfGraphDefWrapper>(tensorflow::createSession(graph, 1), graph);
+ return std::make_unique<TfGraphDefWrapper>(tensorflow::createSession(graph), graph);
}
void TfGraphDefProducer::fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
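Note on the change above: `Backend` now precedes `nThreads` in `createSession(const GraphDef*, Backend, int)`, and because `Backend` is an `enum class` the old literal `1` no longer converts implicitly, so the call falls back to the defaults. A sketch of the fully explicit equivalent:

// equivalent explicit call: CPU backend and one thread (both are the defaults)
return std::make_unique<TfGraphDefWrapper>(
    tensorflow::createSession(graph, tensorflow::Backend::cpu, 1), graph);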
diff --git a/PhysicsTools/TensorFlow/src/TensorFlow.cc b/PhysicsTools/TensorFlow/src/TensorFlow.cc
index 37d3c5183a243..3dffc4c90ab39 100644
--- a/PhysicsTools/TensorFlow/src/TensorFlow.cc
+++ b/PhysicsTools/TensorFlow/src/TensorFlow.cc
@@ -6,8 +6,9 @@
*/
#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"
-
#include "FWCore/MessageLogger/interface/MessageLogger.h"
+#include "FWCore/ServiceRegistry/interface/Service.h"
+#include "FWCore/Utilities/interface/ResourceInformation.h"
namespace tensorflow {
@@ -25,6 +26,65 @@ namespace tensorflow {
setThreading(sessionOptions, nThreads);
}
+ void setBackend(SessionOptions& sessionOptions, Backend backend) {
+ /*
+  * The TensorFlow backend configures the available devices using the options provided in the sessionOptions proto
+  * (options from https://github.com/tensorflow/tensorflow/blob/c53dab9fbc9de4ea8b1df59041a5ffd3987328c3/tensorflow/core/protobuf/config.proto).
+  *
+  * If device_count["GPU"] is set to 0, GPUs are not used.
+  * The visible_device_list configuration maps the `visible` devices (from CUDA_VISIBLE_DEVICES) to `virtual` devices.
+  * If Backend::cpu is requested, the GPU devices are disallowed via the device_count configuration.
+  * If Backend::cuda is requested:
+  *   - if ResourceInformation reports an available NVIDIA GPU device:
+  *     the device is used with the allow_growth configuration (not allocating all the GPU memory at once);
+  *   - if no device is present, an exception is raised.
+  */
+
+ edm::Service<edm::ResourceInformation> ri;
+ if (backend == Backend::cpu) {
+ // disable GPU usage
+ (*sessionOptions.config.mutable_device_count())["GPU"] = 0;
+ sessionOptions.config.mutable_gpu_options()->set_visible_device_list("");
+ }
+ // NVidia GPU
+ else if (backend == Backend::cuda) {
+ if (not ri->nvidiaDriverVersion().empty()) {
+ // Take only the first GPU in the CUDA_VISIBLE_DEVICES list
+ (*sessionOptions.config.mutable_device_count())["GPU"] = 1;
+ sessionOptions.config.mutable_gpu_options()->set_visible_device_list("0");
+ // Do not allocate all the memory on the GPU at the beginning.
+ sessionOptions.config.mutable_gpu_options()->set_allow_growth(true);
+ } else {
+ edm::Exception ex(edm::errors::UnavailableAccelerator);
+ ex << "Cuda backend requested, but no NVIDIA GPU available in the job";
+ ex.addContext("Calling tensorflow::setBackend()");
+ throw ex;
+ }
+ }
+ // ROCm and Intel GPUs are not supported yet
+ else if ((backend == Backend::rocm) || (backend == Backend::intel)) {
+ edm::Exception ex(edm::errors::UnavailableAccelerator);
+ ex << "ROCm/Intel GPU backend requested, but TF is not compiled yet for this platform";
+ ex.addContext("Calling tensorflow::setBackend()");
+ throw ex;
+ }
+ // Use an NVIDIA GPU if available, otherwise fall back to CPU
+ else if (backend == Backend::best) {
+ // Check if an NVIDIA GPU is available
+ if (not ri->nvidiaDriverVersion().empty()) {
+ // Take only the first GPU in the CUDA_VISIBLE_DEVICES list
+ (*sessionOptions.config.mutable_device_count())["GPU"] = 1;
+ sessionOptions.config.mutable_gpu_options()->set_visible_device_list("0");
+ // Do not allocate all the memory on the GPU at the beginning.
+ sessionOptions.config.mutable_gpu_options()->set_allow_growth(true);
+ } else {
+ // Just CPU support
+ (*sessionOptions.config.mutable_device_count())["GPU"] = 0;
+ sessionOptions.config.mutable_gpu_options()->set_visible_device_list("");
+ }
+ }
+ }
+
MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, SessionOptions& sessionOptions) {
// objects to load the graph
Status status;
@@ -49,19 +109,20 @@ namespace tensorflow {
return loadMetaGraphDef(exportDir, tag, sessionOptions);
}
- MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, int nThreads) {
+ MetaGraphDef* loadMetaGraphDef(const std::string& exportDir, const std::string& tag, Backend backend, int nThreads) {
// create session options and set thread options
SessionOptions sessionOptions;
setThreading(sessionOptions, nThreads);
+ setBackend(sessionOptions, backend);
return loadMetaGraphDef(exportDir, tag, sessionOptions);
}
- MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, int nThreads) {
+ MetaGraphDef* loadMetaGraph(const std::string& exportDir, const std::string& tag, Backend backend, int nThreads) {
edm::LogInfo("PhysicsTools/TensorFlow")
<< "tensorflow::loadMetaGraph() is deprecated, use tensorflow::loadMetaGraphDef() instead";
- return loadMetaGraphDef(exportDir, tag, nThreads);
+ return loadMetaGraphDef(exportDir, tag, backend, nThreads);
}
GraphDef* loadGraphDef(const std::string& pbFile) {
@@ -95,10 +156,11 @@ namespace tensorflow {
return session;
}
- Session* createSession(int nThreads) {
+ Session* createSession(Backend backend, int nThreads) {
// create session options and set thread options
SessionOptions sessionOptions;
setThreading(sessionOptions, nThreads);
+ setBackend(sessionOptions, backend);
return createSession(sessionOptions);
}
@@ -152,10 +214,11 @@ namespace tensorflow {
return session;
}
- Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, int nThreads) {
+ Session* createSession(const MetaGraphDef* metaGraphDef, const std::string& exportDir, Backend backend, int nThreads) {
// create session options and set thread options
SessionOptions sessionOptions;
setThreading(sessionOptions, nThreads);
+ setBackend(sessionOptions, backend);
return createSession(metaGraphDef, exportDir, sessionOptions);
}
@@ -186,10 +249,11 @@ namespace tensorflow {
return session;
}
- Session* createSession(const GraphDef* graphDef, int nThreads) {
+ Session* createSession(const GraphDef* graphDef, Backend backend, int nThreads) {
// create session options and set thread options
SessionOptions sessionOptions;
setThreading(sessionOptions, nThreads);
+ setBackend(sessionOptions, backend);
return createSession(graphDef, sessionOptions);
}
diff --git a/PhysicsTools/TensorFlow/test/BuildFile.xml b/PhysicsTools/TensorFlow/test/BuildFile.xml
index b5ed26c3fedf5..1270225afc9f0 100644
--- a/PhysicsTools/TensorFlow/test/BuildFile.xml
+++ b/PhysicsTools/TensorFlow/test/BuildFile.xml
@@ -2,38 +2,166 @@
+  <!-- additional <bin> test definitions (content elided) -->
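The added test entries are not reproduced in this excerpt. As a rough idea of what a backend test could exercise, here is a hypothetical standalone sketch, assuming the CPU backend needs no additional framework services:

#include <cassert>

#include "PhysicsTools/TensorFlow/interface/TensorFlow.h"

int main() {
  // create an empty session on the CPU backend with a single thread
  tensorflow::Session* session = tensorflow::createSession(tensorflow::Backend::cpu, 1);
  assert(session != nullptr);

  // closing must succeed and reset the pointer
  bool closed = tensorflow::closeSession(session);
  assert(closed);
  assert(session == nullptr);

  return 0;
}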