Skip to content

Commit

Permalink
1. add RooDataSet -> TTree transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Oct 29, 2024
1 parent 4ec6609 commit 4ab2745
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 41 deletions.
41 changes: 28 additions & 13 deletions ostap/fitting/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3362,22 +3362,34 @@ def _rad_rows_ ( dataset ,
# data = dataset.ds2tree ( name = 'my_tree' , filename= 'aa.root' )
# @endcode
# - result of the type <code>ostap.tree.data_utils.Data</code>
# @param name tree name (if not specified dataset name wil lbe used)
# @param name tree name (if not specified, dataset name will be used)
# @param filename file name (if not specified, temporary file will be used)
def _ds_2tree_ ( dataset , name = '' , filename = '' ) :
def _ds_2tree_ ( dataset , name = '' , filename = '' , cuts = '' , vars = () , cut_range = '' ) :
""" Convert `ROOT.RooDataSet` to `ROOT.TTree`
>>> dataset = ...
>>> data = dataset.ds2tree ( name = 'my_tree' , filename= 'aa.root' )
- name : tree name (if not specified dataset name wil lbe used)
- name : tree name (if not specified, dataset name will be used)
- filename : file name (if not specified, temporary file will be used)
- result of the type `ostap.tree.data_utils.Data`
"""

if not name : name = dataset.GetName()
if not filename :
import ostap.utils.cleanup as CU
filename = CU.CleanUp.tempfile ( suffix = '.root' )

if not name : name = dataset.GetName()


cuts = str ( cuts ).strip() if cuts else ''
cut_range = str ( cut_range ).strip() if cut_range else ''
if cuts or cut_range or vars :
dsaux = dataset.subset ( variables = vars ,
cuts = cuts ,
cut_range = cut_range )
result = _ds_2tree_ ( dsaux , name = name , filename = filename )
if dsaux and isinstance ( dsaux , ROOT.RooDataSet ) :
dsaux = Ostap.MoreRooFit.delete_data ( dsaux)
del dsaux
return result

import ostap.io.root_file
from ostap.trees.data_utils import Data

Expand All @@ -3387,11 +3399,14 @@ def _ds_2tree_ ( dataset , name = '' , filename = '' ) :
with ROOT.TFile ( filename , 'c' ) as rfile :
rfile.cd()
if not isinstance ( store , ROOT.RooTreeDataStore ) :
dstmp = ROOT.RooDataSet ( dataset , dsID() )
dstmp.convertToTreeStore()
store = dstmp.store()
with useStorage ( ROOT.RooAbsData.Tree ) :
dstmp = ROOT.RooDataSet ( dataset , dsID() )
store = dstmp.store()
if not isinstance ( store , ROOT.RooTreeDataStore ) :
dstmp.convertToTreeStore()
store = dstmp.store()
assert isinstance ( store , ROOT.RooTreeDataStore ) , \
'Store is not RooTreeDataStore: %s' % ( type ( store ) .__name__ )
'Store type %s is not RooTreeDataStore!' % ( type ( store ) .__name__ )
rfile [ name ] = store.tree()

## with ROOT.TFile ( filename , 'r' ) as rfile : rfile.ls()
Expand All @@ -3400,10 +3415,10 @@ def _ds_2tree_ ( dataset , name = '' , filename = '' ) :
dstmp = Ostap.MoreRooFit.delete_data ( dstmp )
del dstmp

return Data ( chain = name ,
files = [ filename ] ,
return Data ( chain = name ,
files = [ filename ] ,
description = "TTree from dataset %s/%s " % ( dataset.name , dataset.title ) ,
silent = True )
silent = True )

ROOT.RooDataSet.ds2tree = _ds_2tree_

Expand Down
52 changes: 24 additions & 28 deletions ostap/tools/tmva.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def files ( self ) :
# =============================================================================
## some manipulations with TMVA options
def opts_replace ( opts , expr , direct = True ) :
"""some manipulations with TMVA options"""
""" Some manipulations with TMVA options"""
if direct :
if 0 <= opts.find ( '!' + expr ) : opts = opts.replace ( '!' + expr , expr )
elif 0 <= opts.find ( expr ) : pass
Expand All @@ -183,7 +183,7 @@ def opts_replace ( opts , expr , direct = True ) :
# t = make_tarfile ( 'output.tgz' , files , verbose = True , tmp = True )
# @endcode
def make_tarfile ( output , files , verbose = False , tmp = False ) :
"""Create the tar file from components, optionally create it as tmp,
""" Create the tar file from components, optionally create it as tmp,
and later copy to the final destination.
(Sometimes for unreliable file systems (like EOS via fsmount)
normal creation of tar-file raises OSError
Expand Down Expand Up @@ -267,7 +267,7 @@ def make_tarfile ( output , files , verbose = False , tmp = False ) :
# @author Vanya BELYAEV [email protected]
# Thanks to Albert PUIG
class Trainer(object):
"""Helper class to train TMVA:
""" Helper class to train TMVA:
>>> from ostap.tools.tmva import Trainer
>>> t = Trainer(
Expand Down Expand Up @@ -351,7 +351,7 @@ def __init__( self ,
multithread = False ,
logger = None ) :

"""Constructor with list of methods
""" Constructor with list of methods
>>> from ostap.tools.tmva import Trainer
>>> t = Trainer(
Expand All @@ -378,8 +378,6 @@ def __init__( self ,
if logging :
if isinstance ( logging , string_types ) : self.__logging = logging
else : self.__logging = "%s.log" % self.name



for i , e in enumerate ( methods ) :
t = e [ 0 ]
Expand Down Expand Up @@ -496,7 +494,6 @@ def __init__( self ,
## self.__verbose = True if verbose else False
self.__verbose = True if ( verbose and self.category <= 0 ) else False


self.__bookingoptions = bookingoptions

import os
Expand Down Expand Up @@ -551,7 +548,6 @@ def __init__( self ,
self.__methods = tuple ( _methods )

self.__bookingoptions = str ( opts )


pattern_xml = pattern_XML % ( self.dirname , self.dirname )
pattern_C = pattern_CLASS % ( self.dirname , self.dirname )
Expand Down Expand Up @@ -872,13 +868,13 @@ def train ( self ) :
return self._train()

# =========================================================================
## train TMVA
## Train TMVA
# @code
# trainer.train ()
# @endcode
# @return the name of output XML file with the weights
def _train ( self ) :
""" train TMVA
""" Train TMVA
>>> trainer.train ()
return the name of output XML files with the weights
"""
Expand Down Expand Up @@ -1029,7 +1025,7 @@ def _train ( self ) :
# @endcode
# @return the names of output XML files with the weights
def __train ( self ) :
"""Train the TMVA:
""" Train the TMVA:
- returns the names of output XML file with the weights
>>> trainer.train ()
"""
Expand Down Expand Up @@ -1536,7 +1532,7 @@ def __train ( self ) :
# =========================================================================
## make selected standard TMVA plots
def makePlots ( self , name = None , output = None , ) :
"""Make selected standard TMVA plots"""
""" Make selected standard TMVA plots"""

self.logger.warning ( "makePlots: method is (temporarily?) disabled!" )
return
Expand Down Expand Up @@ -1615,7 +1611,7 @@ def makePlots ( self , name = None , output = None , ) :
# =============================================================================
## make selected standard TMVA plots
def make_Plots ( name , output , show_plots = True ) :
"""Make selected standard TMVA plots"""
""" Make selected standard TMVA plots"""

## if (6,29) <= root_info :
## logger.warning ("function is disabled")
Expand Down Expand Up @@ -1843,7 +1839,7 @@ def __init__ ( self ,
options = '' ,
logger = None ,
verbose = True ) :
"""Construct the TMVA reader
""" Construct the TMVA reader
>>> from ostap.tools.tmva import Reader
>>> r = Reader ( 'MyTMVA' ,
... variables = [ ## name accessor
Expand Down Expand Up @@ -1982,7 +1978,7 @@ def variables ( self ) :
# val = var ( entry )
# @endcode
class Var (object) :
"""Helper class to get TMVA decision for the certain method
""" Helper class to get TMVA decision for the certain method
>>> reader = ...
>>> var = reader[ method ]
>>> val = var ( entry )
Expand Down Expand Up @@ -2027,7 +2023,7 @@ def eval ( self , entry , cut_efficiency = 0.9 ) :
# =========================================================================
## helper class to get TMVA decision for certain method
class Method (Var) :
"""Helper class to get TMVA decision for certain method
""" Helper class to get TMVA decision for certain method
>>> reader = ...
>>> var = reader[ method ]
>>> val = var ( entry )
Expand Down Expand Up @@ -2056,7 +2052,7 @@ def __call__ ( self , arg , *args ) :
# print('Response is %s' % method.evaluate ( pt , eta , phi ) )
# @endcode
def evaluate ( self , *args ) :
"""Evaluate the method from parameters
""" Evaluate the method from parameters
>>> method = ...
>>> pt, eta, phi = 5 , 3.0 , 0 ## variables
>>> print('Response is %s' % method.evaluate ( pt , eta , phi ) )
Expand All @@ -2075,7 +2071,7 @@ def evaluate ( self , *args ) :
# ... print('MLP/BDTG for this event are %s/%s' % (mlp , bdtg))
# @endcode
def __getitem__ ( self , method ) :
"""Helper utility to get the corresponding function from the reader:
""" Helper utility to get the corresponding function from the reader:
- Use the reader
>>> tree = .... ## TTree/TChain/RooDataSet with data
>>> mlp_fun = reader['MLP'] ## <-- here!
Expand All @@ -2101,7 +2097,7 @@ def __getitem__ ( self , method ) :
# ... print('MLP/BDTG for this event are %s/%s' % (mlp , bdtg))
# @endcode
def __getattr__ ( self , method ) :
"""Helper utility to get the correspondig function from the reader:
""" Helper utility to get the corresponding function from the reader:
- Use the reader
>>> tree = .... ## TTree/TChain/RooDataSet with data
>>> mlp_fun = reader.MLP ## <-- here!
Expand Down Expand Up @@ -2129,7 +2125,7 @@ def __getattr__ ( self , method ) :
# Ugly trick with arrays is needed due to some technical problems
# (actually TMVA reader needs the address of 'float' (in C++ sense) variable)
def __call__ ( self , method , entry , cut_efficiency = 0.90 ) :
"""Evaluate TMVA
""" Evaluate TMVA
- Use the reader
>>> tree = .... ## TTree/TChain/RooDataSet with data
>>> for entry in tree :
Expand Down Expand Up @@ -2157,7 +2153,7 @@ def __call__ ( self , method , entry , cut_efficiency = 0.90 ) :
# print ('MLP response is: ', reader.MLP.evaluate ( pt , y ))
# @endcode
def evaluate ( self , method , *args ) :
"""Evaluate TMVA
""" Evaluate TMVA
>>> reader = ...
>>> pt, y = ... ##
>>> print('MLP response is: ', reader.MLP.evaluate ( pt , y ))
Expand Down Expand Up @@ -2187,7 +2183,7 @@ def evaluate ( self , method , *args ) :
# =============================================================================
## start TMVA gui
def tmvaGUI ( filename , new_canvas = True ) :
"""Start TMVA-GUI
""" Start TMVA-GUI
"""
## ROOT.gROOT.LoadMacro('TMVAGui.C')
if new_canvas :
Expand All @@ -2201,7 +2197,7 @@ def tmvaGUI ( filename , new_canvas = True ) :
# =============================================================================
## convert input structure to Ostap.TMVA.MAPS
def _inputs2map_ ( inputs ) :
"""Convert input structure to Ostap.TMVA.MAPS
""" Convert input structure to Ostap.TMVA.MAPS
"""
from ostap.core.core import cpp, std, Ostap
MAP = std.map ( 'std::string', 'std::string' )
Expand Down Expand Up @@ -2235,6 +2231,8 @@ def _inputs2map_ ( inputs ) :
# =============================================================================
## convert weights structure to Ostap.TMVA.MAP
def _weights2map_ ( weights ) :
""" Convert weights structure to Ostap.TMVA.MAP
"""

from ostap.core.core import cpp, std, Ostap
MAP = std.map ( 'std::string', 'std::string' )
Expand All @@ -2254,7 +2252,7 @@ def _weights2map_ ( weights ) :

# =============================================================================
def _add_response_tree ( tree , verbose , *args ) :
"""Specific action to ROOT.TChain
""" Specific action to ROOT.TTree
"""

import ostap.trees.trees
Expand Down Expand Up @@ -2311,7 +2309,7 @@ def _add_response_tree ( tree , verbose , *args ) :

# =============================================================================
def _add_response_chain ( chain , verbose , *args ) :
"""Specific action to ROOT.TChain
""" Specific action to ROOT.TChain
"""

import ostap.trees.trees
Expand Down Expand Up @@ -2379,8 +2377,7 @@ def addTMVAResponse ( dataset , ## input dataset to be updated
options = '' , ## TMVA-reader options
verbose = True , ## verbosity flag
aux = 0.9 ) : ## for Cuts method : efficiency cut-off
"""
Helper function to add TMVA response into dataset
""" Helper function to add TMVA response into dataset
>>> tar_file = trainer.tar_file
>>> dataset = ...
>>> inputs = [ 'var1' , 'var2' , 'var2' ]
Expand Down Expand Up @@ -2443,7 +2440,6 @@ def addTMVAResponse ( dataset , ## input dataset to be updated

return newdata


# =============================================================================
if '__main__' == __name__ :

Expand Down

0 comments on commit 4ab2745

Please sign in to comment.