-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e77468d
commit 784f334
Showing
3 changed files
with
73 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
# @see RooAbsData | ||
# @see RooDataSet | ||
# @author Artem Egorychev [email protected] | ||
# @date 2023-12-04 | ||
# @date 2023-12-12 | ||
# ============================================================================= | ||
""" Helper module to convert RooDataSet to numpy array | ||
Module with decoration for RooAbsData and related RooFit classes | ||
|
@@ -21,9 +21,11 @@ | |
) | ||
# ============================================================================= | ||
from ostap.core.meta_info import root_info | ||
from ostap.core.ostap_types import string_types | ||
from ostap.core.ostap_types import string_types, dictlike_types | ||
from ostap.core.core import loop_items | ||
from ostap.utils.utils import split_range | ||
from ostap.fitting.dataset import useStorage | ||
from ostap.fitting.funbasic import AFUN1 | ||
from ostap.utils.progress_bar import progress_bar | ||
import ostap.fitting.roocollections | ||
import ROOT | ||
|
@@ -59,7 +61,7 @@ | |
# @see ROOT.RooAbsDataStore.getCategoryBatches | ||
# @see ROOT.RooAbsDataStore.getWeightBatch | ||
# @attention conversion to ROOT.RooVectorDataStore is used! | ||
def ds2numpy ( dataset , var_lst , silent = True ) : | ||
def ds2numpy ( dataset , var_lst , silent = True , more_vars = {} ) : | ||
""" Convert dataset into numpy array using `ROOT.RooAbsData` interface | ||
- see ROOT.RooAbsData.getBatches | ||
- see ROOT.RooAbsData.getCategoryBatches | ||
|
@@ -70,7 +72,7 @@ def ds2numpy ( dataset , var_lst , silent = True ) : | |
- attention: Conversion to `ROOT.RooVectorDataStore` is used! | ||
""" | ||
|
||
## 1) check that all variables are present in dataset | ||
## 1) get names of all requested variables | ||
if all ( isinstance ( v , string_types ) for v in var_lst ) : | ||
vnames = [ v for v in var_lst ] | ||
elif all ( isinstance ( v , ROOT.RooAbsArg ) for v in var_lst ) : | ||
|
@@ -81,6 +83,19 @@ def ds2numpy ( dataset , var_lst , silent = True ) : | |
## 2) check that all variables are present in the dataset | ||
assert all ( ( v in dataset ) for v in var_lst ) , 'Not all variables are in dataset!' | ||
|
||
funcs = [] | ||
if more_vars and isinstance ( more_vars , dictlike_types ) : | ||
for name , fun in loop_items ( more_vars ) : | ||
if isinstance ( fun , AFUN1 ) : | ||
absreal = fun.fun | ||
elif isinstance( fun , ROOT.RooAbsPdf ) : absreal = fun | ||
elif isinstance( fun , ROOT.RooAbsReal ) : absreal = fun | ||
else : | ||
raise TypeError ( "Invald type ofun/pdf" ) | ||
obsvars = absreal.getObservables ( dataset ) | ||
item = name , absreal , obsvars | ||
funcs.append ( item ) | ||
|
||
## 3) reduce dataset if only a small subset of variables is requested | ||
nvars = len ( dataset.get() ) | ||
if 2 * len ( vnames ) <= nvars : | ||
|
@@ -119,6 +134,8 @@ def ds2numpy ( dataset , var_lst , silent = True ) : | |
for v in vnames : | ||
if v in doubles : dtypes.append ( ( v , np.float64 ) ) | ||
elif v in categories : dtypes.append ( ( v , np.int64 ) ) | ||
|
||
for f in funcs : dtypes.append ( ( f[0] , np.float64 ) ) | ||
if weight : dtypes.append ( ( weight , np.float64 ) ) | ||
|
||
## get data in batches | ||
|
@@ -166,6 +183,13 @@ def ds2numpy ( dataset , var_lst , silent = True ) : | |
else : | ||
data = np.concatenate ( [ data , part ] ) | ||
del part | ||
|
||
## add PDF values | ||
if funcs : | ||
for i, entry in enumerate ( source ) : | ||
for vname , func , obsvars in funcs : | ||
obsvars.assign ( entry ) | ||
data [ vname ] [ i ] = func.getVal() | ||
|
||
if delsource : | ||
source.reset() | ||
|
@@ -188,7 +212,7 @@ def ds2numpy ( dataset , var_lst , silent = True ) : | |
|
||
# ========================================================================= | ||
## Convert dataset into numpy array using (slow) explicit loops | ||
def ds2numpy ( dataset , var_lst , silent = False ) : | ||
def ds2numpy ( dataset , var_lst , silent = False , more_vars = {} ) : | ||
""" Convert dataset into numpy array using (slow) explicit loops | ||
""" | ||
|
||
|
@@ -202,7 +226,20 @@ def ds2numpy ( dataset , var_lst , silent = False ) : | |
|
||
## 2) check that all variables are present in the dataset | ||
assert all ( ( v in dataset ) for v in var_lst ) , 'Not all variables are in dataset!' | ||
|
||
|
||
funcs = [] | ||
if more_vars and isinstance ( more_vars , dictlike_types ) : | ||
for name , fun in loop_items ( more_vars ) : | ||
if isinstance ( fun , AFUN1 ) : | ||
absreal = fun.fun | ||
elif isinstance( fun , ROOT.RooAbsPdf ) : absreal = fun | ||
elif isinstance( fun , ROOT.RooAbsReal ) : absreal = fun | ||
else : | ||
raise TypeError ( "Invald type ofun/pdf" ) | ||
obsvars = absreal.getObservables ( dataset ) | ||
item = name , absreal , obsvars | ||
funcs.append ( item ) | ||
|
||
## 3) reduce dataset if only a small subset of variables is requested | ||
nvars = len ( dataset.get() ) | ||
if 2 * len ( vnames ) <= nvars : | ||
|
@@ -224,6 +261,7 @@ def ds2numpy ( dataset , var_lst , silent = False ) : | |
for v in vnames : | ||
if v in doubles : dtypes.append ( ( v , np.float64 ) ) | ||
elif v in categories : dtypes.append ( ( v , np.int64 ) ) | ||
for f in funcs : dtypes.append ( ( f[0] , np.float64 ) ) | ||
if weight : dtypes.append ( ( weight , np.float64 ) ) | ||
|
||
|
||
|
@@ -239,6 +277,11 @@ def ds2numpy ( dataset , var_lst , silent = False ) : | |
elif vname in categories : data [ vname ] [ i ] = int ( v ) | ||
|
||
if weight : data [ weight ] [ i ] = dataset.weight() | ||
|
||
## add PDF values | ||
for vname , func , obsvars in funcs : | ||
obsvars.assign ( evt ) | ||
data [ vname ] [ i ] = func.getVal() | ||
|
||
return data | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters