Skip to content

Commit

Permalink
extend 'more_vars' argument
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Dec 20, 2023
1 parent 966875e commit f7bec16
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions ostap/fitting/ds2numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,29 @@ def ds2numpy ( dataset , var_lst , silent = True , more_vars = {} ) :
obsvars = absreal.getObservables ( dataset )
item = name , absreal , obsvars
funcs.append ( item )

elif more_vars and all ( ( isinstance ( v , ( ROOT.RooAbsReal , AFUN1 ) ) for v in more_vars ) ) :
for v in more_vars :
if isinstance ( v , AFUN1 ) : absreal = v.fun
else : absreal = v
obsvars = absreal.getObservables ( dataset )
item = v.name , absreal , obsvars
funcs.append ( item )
elif more_vars :
for name, var in more_vars :
if isinstance ( var , AFUN1 ) : absreal = var.fun
elif isinstance ( var , ROOT.RooAbsReal ) : absreal = var
else :
raise TypeError ( "Invalid content of 'more_vars'!" )
obsvars = absreal.getObservables ( dataset )
item = name , absreal , obsvars
funcs.append ( item )

## 3) reduce dataset if only a small subset of variables is requested
nvars = len ( dataset.get() )
if 2 * len ( vnames ) <= nvars :
with useStorage ( ROOT.RooAbsData.Vector ) :
dstmp = dataset.subset ( vnames )
result = ds2numpy ( dstmp , vnames )
result = ds2numpy ( dstmp , vnames , more_vars = more_vars )
dstmp.erase()
del dstmp
return result
Expand Down Expand Up @@ -239,12 +255,28 @@ def ds2numpy ( dataset , var_lst , silent = False , more_vars = {} ) :
obsvars = absreal.getObservables ( dataset )
item = name , absreal , obsvars
funcs.append ( item )
elif more_vars and all ( ( isinstance ( v , ( ROOT.RooAbsReal , AFUN1 ) ) for v in more_vars ) ) :
for v in more_vars :
if isinstance ( v , AFUN1 ) : absreal = v.fun
else : absreal = v
obsvars = absreal.getObservables ( dataset )
item = v.name , absreal , obsvars
funcs.append ( item )
elif more_vars :
for name, var in more_vars :
if isinstance ( var , AFUN1 ) : absreal = var.fun
elif isinstance ( var , ROOT.RooAbsReal ) : absreal = var
else :
raise TypeError ( "Invalid content of 'more_vars'!" )
obsvars = absreal.getObservables ( dataset )
item = name , absreal , obsvars
funcs.append ( item )

## 3) reduce dataset if only a small subset of variables is requested
nvars = len ( dataset.get() )
if 2 * len ( vnames ) <= nvars :
dstmp = dataset.subset ( vnames )
result = ds2numpy ( dstmp , vnames )
result = ds2numpy ( dstmp , vnames , more_vars = more_vars )
dstmp.erase()
del dstmp
return result
Expand Down

0 comments on commit f7bec16

Please sign in to comment.