From 4ebf849ac981f94fcbfe689f06455a7cb00f9cd6 Mon Sep 17 00:00:00 2001 From: Giulio Eulisse <10544+ktf@users.noreply.github.com> Date: Thu, 19 Dec 2024 22:17:26 +0100 Subject: [PATCH] DPL: improve handling of RNTuple (#13818) - Support more integer types, including tests. - Add ability to support objects which are not grouped in a TDirectory --- .../AnalysisSupport/src/RNTuplePlugin.cxx | 27 +++++++++++++++++++ .../include/Framework/RootArrowFilesystem.h | 5 ++++ Framework/Core/src/Plugin.cxx | 12 +++++++-- Framework/Core/src/RootArrowFilesystem.cxx | 4 ++- Framework/Core/test/test_Root2ArrowTable.cxx | 12 ++++++--- 5 files changed, 53 insertions(+), 7 deletions(-) diff --git a/Framework/AnalysisSupport/src/RNTuplePlugin.cxx b/Framework/AnalysisSupport/src/RNTuplePlugin.cxx index 9f67785f1a069..f66723419c24e 100644 --- a/Framework/AnalysisSupport/src/RNTuplePlugin.cxx +++ b/Framework/AnalysisSupport/src/RNTuplePlugin.cxx @@ -187,6 +187,31 @@ struct RootNTupleVisitor : public ROOT::Experimental::Detail::RFieldVisitor { this->datatype = arrow::int32(); } + void VisitInt8Field(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::int8(); + } + + void VisitInt16Field(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::int16(); + } + + void VisitUInt32Field(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::uint32(); + } + + void VisitUInt8Field(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::uint8(); + } + + void VisitUInt16Field(const ROOT::Experimental::RField& field) override + { + this->datatype = arrow::int16(); + } + void VisitBoolField(const ROOT::Experimental::RField& field) override { this->datatype = arrow::boolean(); @@ -240,6 +265,8 @@ std::unique_ptr rootFieldFromArrow(std::shared_p return std::make_unique>(name); case arrow::Type::DOUBLE: return std::make_unique>(name); + case arrow::Type::STRING: + return std::make_unique>(name); default: throw runtime_error("Unsupported arrow column type"); } diff --git a/Framework/Core/include/Framework/RootArrowFilesystem.h b/Framework/Core/include/Framework/RootArrowFilesystem.h index 8744656e7d55d..feab713b445fe 100644 --- a/Framework/Core/include/Framework/RootArrowFilesystem.h +++ b/Framework/Core/include/Framework/RootArrowFilesystem.h @@ -83,6 +83,11 @@ struct RootArrowFactoryPlugin { struct RootObjectReadingCapability { // The unique name of this capability std::string name = "unknown"; + // Convert a logical filename to an actual object to be read + // This can be used, e.g. to read an RNTuple stored in + // a flat directory structure in a TFile vs a TTree stored inside + // a TDirectory (e.g. /DF_1000/o2tracks). + std::function lfn2objectPath; // Given a TFile, return the object which this capability support // Use a void * in order not to expose the kind of object to the // generic reading code. This is also where we load the plugin diff --git a/Framework/Core/src/Plugin.cxx b/Framework/Core/src/Plugin.cxx index af71db4af3445..568908426c143 100644 --- a/Framework/Core/src/Plugin.cxx +++ b/Framework/Core/src/Plugin.cxx @@ -179,12 +179,12 @@ struct ImplementationContext { std::function getHandleByClass(char const* classname) { - return [classname](TDirectoryFile* file, std::string const& path) { return file->GetObjectChecked(path.c_str(), TClass::GetClass(classname)); }; + return [c = TClass::GetClass(classname)](TDirectoryFile* file, std::string const& path) { return file->GetObjectChecked(path.c_str(), c); }; } std::function getBufferHandleByClass(char const* classname) { - return [classname](TBufferFile* buffer, std::string const& path) { buffer->Reset(); return buffer->ReadObjectAny(TClass::GetClass(classname)); }; + return [c = TClass::GetClass(classname)](TBufferFile* buffer, std::string const& path) { buffer->Reset(); return buffer->ReadObjectAny(c); }; } void lazyLoadFactory(std::vector& implementations, char const* specs) @@ -210,6 +210,13 @@ struct RNTupleObjectReadingCapability : o2::framework::RootObjectReadingCapabili return new RootObjectReadingCapability{ .name = "rntuple", + .lfn2objectPath = [](std::string s) { + std::replace(s.begin()+1, s.end(), '/', '-'); + if (s.starts_with("/")) { + return s; + } else { + return "/" + s; + } }, .getHandle = getHandleByClass("ROOT::Experimental::RNTuple"), .getBufferHandle = getBufferHandleByClass("ROOT::Experimental::RNTuple"), .factory = [context]() -> RootArrowFactory& { @@ -226,6 +233,7 @@ struct TTreeObjectReadingCapability : o2::framework::RootObjectReadingCapability return new RootObjectReadingCapability{ .name = "ttree", + .lfn2objectPath = [](std::string s) { return s; }, .getHandle = getHandleByClass("TTree"), .getBufferHandle = getBufferHandleByClass("TTree"), .factory = [context]() -> RootArrowFactory& { diff --git a/Framework/Core/src/RootArrowFilesystem.cxx b/Framework/Core/src/RootArrowFilesystem.cxx index 545ba6f0afb71..4a1286515508c 100644 --- a/Framework/Core/src/RootArrowFilesystem.cxx +++ b/Framework/Core/src/RootArrowFilesystem.cxx @@ -47,7 +47,8 @@ std::shared_ptr TFileFileSystem::GetSubFilesystem(arr // file, so that we can support TTree and RNTuple at the same time // without having to depend on both. for (auto& capability : mObjectFactory.capabilities) { - void* handle = capability.getHandle(mFile, source.path()); + auto objectPath = capability.lfn2objectPath(source.path()); + void* handle = capability.getHandle(mFile, objectPath); if (!handle) { continue; } @@ -238,6 +239,7 @@ std::shared_ptr TBufferFileFS::GetSubFilesystem(arrow // file, so that we can support TTree and RNTuple at the same time // without having to depend on both. for (auto& capability : mObjectFactory.capabilities) { + void* handle = capability.getBufferHandle(mBuffer, source.path()); if (handle) { mFilesystem = capability.factory().getSubFilesystem(handle); diff --git a/Framework/Core/test/test_Root2ArrowTable.cxx b/Framework/Core/test/test_Root2ArrowTable.cxx index 8eb3a9825f0f7..04a8d91303f0e 100644 --- a/Framework/Core/test/test_Root2ArrowTable.cxx +++ b/Framework/Core/test/test_Root2ArrowTable.cxx @@ -369,7 +369,7 @@ bool validateContents(std::shared_ptr batch) bool validateSchema(std::shared_ptr schema) { - REQUIRE(schema->num_fields() == 10); + REQUIRE(schema->num_fields() == 11); REQUIRE(schema->field(0)->type()->id() == arrow::float32()->id()); REQUIRE(schema->field(1)->type()->id() == arrow::float32()->id()); REQUIRE(schema->field(2)->type()->id() == arrow::float32()->id()); @@ -380,6 +380,7 @@ bool validateSchema(std::shared_ptr schema) REQUIRE(schema->field(7)->type()->id() == arrow::boolean()->id()); REQUIRE(schema->field(8)->type()->id() == arrow::fixed_size_list(arrow::boolean(), 2)->id()); REQUIRE(schema->field(9)->type()->id() == arrow::list(arrow::int32())->id()); + REQUIRE(schema->field(10)->type()->id() == arrow::int8()->id()); return true; } @@ -435,6 +436,7 @@ TEST_CASE("RootTree2Dataset") bool manyBool[2]; int vla[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; int vlaSize = 0; + char byte; t->Branch("px", &px, "px/F"); t->Branch("py", &py, "py/F"); @@ -447,6 +449,7 @@ TEST_CASE("RootTree2Dataset") t->Branch("manyBools", &manyBool, "manyBools[2]/O"); t->Branch("vla_size", &vlaSize, "vla_size/I"); t->Branch("vla", vla, "vla[vla_size]/I"); + t->Branch("byte", &byte, "byte/B"); // fill the tree for (Int_t i = 0; i < 100; i++) { xyz[0] = 1; @@ -463,6 +466,7 @@ TEST_CASE("RootTree2Dataset") manyBool[0] = (i % 4 == 0); manyBool[1] = (i % 5 == 0); vlaSize = i % 10; + byte = i; t->Fill(); } } @@ -512,7 +516,7 @@ TEST_CASE("RootTree2Dataset") auto batches = (*scanner)(); auto result = batches.result(); REQUIRE(result.ok()); - REQUIRE((*result)->columns().size() == 10); + REQUIRE((*result)->columns().size() == 11); REQUIRE((*result)->num_rows() == 100); validateContents(*result); @@ -552,7 +556,7 @@ TEST_CASE("RootTree2Dataset") auto batchesWritten = (*scanner)(); auto resultWritten = batches.result(); REQUIRE(resultWritten.ok()); - REQUIRE((*resultWritten)->columns().size() == 10); + REQUIRE((*resultWritten)->columns().size() == 11); REQUIRE((*resultWritten)->num_rows() == 100); validateContents(*resultWritten); } @@ -586,7 +590,7 @@ TEST_CASE("RootTree2Dataset") auto rntupleBatchesWritten = (*rntupleScannerWritten)(); auto rntupleResultWritten = rntupleBatchesWritten.result(); REQUIRE(rntupleResultWritten.ok()); - REQUIRE((*rntupleResultWritten)->columns().size() == 10); + REQUIRE((*rntupleResultWritten)->columns().size() == 11); REQUIRE(validateSchema((*rntupleResultWritten)->schema())); REQUIRE((*rntupleResultWritten)->num_rows() == 100); REQUIRE(validateContents(*rntupleResultWritten));