Skip to content

Commit

Permalink
1. add argument to function
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Dec 19, 2023
1 parent e77468d commit 784f334
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 17 deletions.
8 changes: 8 additions & 0 deletions ReleaseNotes/release_notes.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
## New features:

1. add `more_vars`argument to `ostap.fitting.ds2numpy.ds2numpy` function
`
## Backward incompatible:

## Bug fixes:

# v1.10.0.4

## New features:
Expand Down
55 changes: 49 additions & 6 deletions ostap/fitting/ds2numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# @see RooAbsData
# @see RooDataSet
# @author Artem Egorychev [email protected]
# @date 2023-12-04
# @date 2023-12-12
# =============================================================================
""" Helper module to convert RooDtaSet to numpy array
Module with decoration for RooAbsData and related RooFit classes
Expand All @@ -21,9 +21,11 @@
)
# =============================================================================
from ostap.core.meta_info import root_info
from ostap.core.ostap_types import string_types
from ostap.core.ostap_types import string_types, dictlike_types
from ostap.core.core import loop_items
from ostap.utils.utils import split_range
from ostap.fitting.dataset import useStorage
from ostap.fitting.funbasic import AFUN1
from ostap.utils.progress_bar import progress_bar
import ostap.fitting.roocollections
import ROOT
Expand Down Expand Up @@ -59,7 +61,7 @@
# @see ROOT.RooAbsDataStore.getCategoryBatches
# @see ROOT.RooAbsDataStore.getWeightBatche
# @attention conversion to ROOT.RooVectorDataStore is used!
def ds2numpy ( dataset , var_lst , silent = True ) :
def ds2numpy ( dataset , var_lst , silent = True , more_vars = {} ) :
""" Convert dataset into numpy array using `ROOT.RooAbsData` iterface
- see ROOT.RooAbsData.getBatches
- see ROOT.RooAbsData.getCategoryBatches
Expand All @@ -70,7 +72,7 @@ def ds2numpy ( dataset , var_lst , silent = True ) :
- attention: Conversion to `ROOT.RooVectorDataStore` is used!
"""

## 1) check that all variables are present in dataset
## 1) get names of all requested variables
if all ( isinstance ( v , string_types ) for v in var_lst ) :
vnames = [ v for v in var_lst ]
elif all ( isinstance ( v , ROOT.RooAbsArg ) for v in var_lst ) :
Expand All @@ -81,6 +83,19 @@ def ds2numpy ( dataset , var_lst , silent = True ) :
## 2) check that all variables are present in the dataset
assert all ( ( v in dataset ) for v in var_lst ) , 'Not all variables are in dataset!'

funcs = []
if more_vars and isinstance ( more_vars , dictlike_types ) :
for name , fun in loop_items ( more_vars ) :
if isinstance ( fun , AFUN1 ) :
absreal = fun.fun
elif isinstance( fun , ROOT.RooAbsPdf ) : absreal = fun
elif isinstance( fun , ROOT.RooAbsReal ) : absreal = fun
else :
raise TypeError ( "Invald type ofun/pdf" )
obsvars = absreal.getObservables ( dataset )
item = name , absreal , obsvars
funcs.append ( item )

## 3) reduce dataset if only a small subset of variables is requested
nvars = len ( dataset.get() )
if 2 * len ( vnames ) <= nvars :
Expand Down Expand Up @@ -119,6 +134,8 @@ def ds2numpy ( dataset , var_lst , silent = True ) :
for v in vnames :
if v in doubles : dtypes.append ( ( v , np.float64 ) )
elif v in categories : dtypes.append ( ( v , np.int64 ) )

for f in funcs : dtypes.append ( ( f[0] , np.float64 ) )
if weight : dtypes.append ( ( weight , np.float64 ) )

## get data in batches
Expand Down Expand Up @@ -166,6 +183,13 @@ def ds2numpy ( dataset , var_lst , silent = True ) :
else :
data = np.concatenate ( [ data , part ] )
del part

## add PDF values
if funcs :
for i, entry in enumerate ( source ) :
for vname , func , obsvars in funcs :
obsvars.assign ( entry )
data [ vname ] [ i ] = func.getVal()

if delsource :
source.reset()
Expand All @@ -188,7 +212,7 @@ def ds2numpy ( dataset , var_lst , silent = True ) :

# =========================================================================
## Convert dataset into numpy array using (slow) explicit loops
def ds2numpy ( dataset , var_lst , silent = False ) :
def ds2numpy ( dataset , var_lst , silent = False , more_vars = {} ) :
""" Convert dataset into numpy array using (slow) explicit loops
"""

Expand All @@ -202,7 +226,20 @@ def ds2numpy ( dataset , var_lst , silent = False ) :

## 2) check that all variables are present in the dataset
assert all ( ( v in dataset ) for v in var_lst ) , 'Not all variables are in dataset!'


funcs = []
if more_vars and isinstance ( more_vars , dictlike_types ) :
for name , fun in loop_items ( more_vars ) :
if isinstance ( fun , AFUN1 ) :
absreal = fun.fun
elif isinstance( fun , ROOT.RooAbsPdf ) : absreal = fun
elif isinstance( fun , ROOT.RooAbsReal ) : absreal = fun
else :
raise TypeError ( "Invald type ofun/pdf" )
obsvars = absreal.getObservables ( dataset )
item = name , absreal , obsvars
funcs.append ( item )

## 3) reduce dataset if only a small subset of variables is requested
nvars = len ( dataset.get() )
if 2 * len ( vnames ) <= nvars :
Expand All @@ -224,6 +261,7 @@ def ds2numpy ( dataset , var_lst , silent = False ) :
for v in vnames :
if v in doubles : dtypes.append ( ( v , np.float64 ) )
elif v in categories : dtypes.append ( ( v , np.int64 ) )
for f in funcs : dtypes.append ( ( f[0] , np.float64 ) )
if weight : dtypes.append ( ( weight , np.float64 ) )


Expand All @@ -239,6 +277,11 @@ def ds2numpy ( dataset , var_lst , silent = False ) :
elif vname in categories : data [ vname ] [ i ] = int ( v )

if weight : data [ weight ] [ i ] = dataset.weight()

## add PDF values
for vname , func , obsvars in funcs :
obsvars.assign ( evt )
data [ vname ] [ i ] = func.getVal()

return data

Expand Down
27 changes: 16 additions & 11 deletions ostap/fitting/tests/test_fitting_ds2numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ostap.utils.timing import timing
from builtins import range
from ostap.fitting.ds2numpy import ds2numpy
import ostap.fitting.models as Models
import ostap.fitting.roofit
import ROOT, random
# =============================================================================
Expand Down Expand Up @@ -43,10 +44,14 @@ def test_small_ds():
i.setVal ( f % 2 )
data.add ( varset )


ws = ds2numpy ( data, ['x', 'y' , 'i' ] )
g1 = Models.Gauss_pdf ( 'G1' , xvar = x , mean = 5 , sigma = 1 )
g2 = Models.Gauss_pdf ( 'G2' , xvar = y , mean = 5 , sigma = 1 )

ws = ds2numpy ( data, ['x', 'y' , 'i' ] , more_vars = { 'gaus1' : g1 ,
'gaus2' : g2 } )

print ( ws )


# =============================================================================
def test_small_ds_with_weights():

Expand Down Expand Up @@ -185,17 +190,17 @@ def test_large_ds_without_weights():
with timing ('Test small ds' , logger ) :
test_small_ds()

with timing ('Test small dataset with weights', logger ) :
test_small_ds_with_weights()
## with timing ('Test small dataset with weights', logger ) :
## test_small_ds_with_weights()

with timing ('Test large dataset with weights', logger ) :
test_ds_with_weights()
## with timing ('Test large dataset with weights', logger ) :
## test_ds_with_weights()

with timing ('Test large dataset with weights', logger ) :
test_large_ds_with_weights()
## with timing ('Test large dataset with weights', logger ) :
## test_large_ds_with_weights()

with timing ('Test large dataset without weights', logger ) :
test_large_ds_without_weights()
## with timing ('Test large dataset without weights', logger ) :
## test_large_ds_without_weights()


# =============================================================================
Expand Down

0 comments on commit 784f334

Please sign in to comment.