Skip to content

Commit

Permalink
fix?
Browse files Browse the repository at this point in the history
  • Loading branch information
VanyaBelyaev committed Dec 5, 2023
1 parent 5900819 commit eda9f15
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 26 deletions.
6 changes: 3 additions & 3 deletions ostap/fitting/ds2numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def ds2numpy ( dataset , var_lst , silent = True ) :
dtypes = []
for v in vnames :
if v in doubles : dtypes.append ( ( v , np.float64 ) )
elif v in vategories : dtypes.append ( ( v , np.int64 ) )
elif v in categories : dtypes.append ( ( v , np.int64 ) )
if weight : dtypes.append ( ( weight , np.float64 ) )

## get data in batches
Expand Down Expand Up @@ -152,7 +152,7 @@ def ds2numpy ( dataset , var_lst , silent = True ) :
cpart = source.getCategoryBatches ( first , num )
for c in cpart :
cname = c.first.name
if cname in categroies :
if cname in categories :
part [ cname ] = c.second
del cpart

Expand Down Expand Up @@ -223,7 +223,7 @@ def ds2numpy ( dataset , var_lst , silent = False ) :
dtypes = []
for v in vnames :
if v in doubles : dtypes.append ( ( v , np.float64 ) )
elif v in vategories : dtypes.append ( ( v , np.int64 ) )
elif v in categories : dtypes.append ( ( v , np.int64 ) )
if weight : dtypes.append ( ( weight , np.float64 ) )


Expand Down
44 changes: 26 additions & 18 deletions ostap/fitting/tests/test_fitting_ds2numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,28 @@ def test_small_ds():

NN = 1000
# Создаем переменные RooRealVar
x = ROOT.RooRealVar("x", "x", 0, 10)
y = ROOT.RooRealVar("y", "y", 0, 10)

x = ROOT.RooRealVar ("x", "x" , 0, 10)
y = ROOT.RooRealVar ("y", "y" , 0, 10)
i = ROOT.RooCategory ("i", 'category' , 'odd' , 'even' )
varset = ROOT.RooArgSet ( x , y , i )

# Создаем RooDataSet
data = ROOT.RooDataSet("data", "data", ROOT.RooArgSet(x, y))
data = ROOT.RooDataSet("data", "data", varset )

# Заполняем датасет случайными данными
random_generator = ROOT.TRandom3(42) # устанавливаем seed
for _ in range( NN ):
for e in range( NN ):
x_val = random_generator.Uniform(0, 10)
y_val = random_generator.Uniform(0, 10)
x.setVal(x_val)
y.setVal(y_val)
data.add(ROOT.RooArgSet(x, y))
x.setVal ( x_val )
y.setVal ( y_val )
f = random.randint ( 0 , 100 )
i.setVal ( f % 2 )
data.add ( varset )


ws = ds2numpy ( data, ['x', 'y'] )

ws = ds2numpy ( data, ['x', 'y' , 'i' ] )

# =============================================================================
def test_small_ds_with_weights():
Expand All @@ -50,25 +54,29 @@ def test_small_ds_with_weights():

NN = 1000
# Создаем переменные RooRealVar
x = ROOT.RooRealVar("x", "x", 0, 10)
y = ROOT.RooRealVar("y", "y", 0, 10)
varset = ROOT.RooArgSet(x, y )
x = ROOT.RooRealVar("x", "x", 0, 10)
y = ROOT.RooRealVar("y", "y", 0, 10)
i = ROOT.RooCategory ("i", 'category' , 'a' , 'b' , 'c' , 'd' , 'e' )
varset = ROOT.RooArgSet(x, y , i )

# Создаем RooDataSet
data = ROOT.RooDataSet("data", "data", varset )

# Заполняем датасет случайными данными
random_generator = ROOT.TRandom3(42) # устанавливаем seed
for _ in range(NN):
for e in range(NN):
x_val = random_generator.Uniform(0, 10)
y_val = random_generator.Uniform(0, 10)
x.setVal(x_val)
y.setVal(y_val)
data.add( varset )
x.setVal ( x_val )
y.setVal ( y_val )
f = random.randint ( 0 , 100 )
i.setVal ( f % 5 )
data.add ( varset )

ds = data.makeWeighted('x+y')

ws = ds2numpy ( ds, ['x', 'y' ] )
ws = ds2numpy ( ds, ['x', 'y' , 'i' ] )
print ( ws )

# =============================================================================
def test_ds_with_weights():
Expand Down
23 changes: 18 additions & 5 deletions ostap/fitting/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ def _racat_items_ ( cat ) :
ROOT.RooAbsCategory.__iter__ = _racat_items_

# =========================================================================
## Is given label (or index) defined fron this categroy?
## Is given label (or index) defined fron this category?
# @see RooAbsCategory.isValidIndex
# @see RooAbsCategory.isValidLabel
def _racat_contains_ ( cat , item ) :
Expand Down Expand Up @@ -839,7 +839,7 @@ def _racat_getitem_ ( cat , item ) :
return result.GetName()
elif isinstance ( item , ROOT.RooCatType ) :
result = cat.lookupType ( item , False )
if not result : raise IndexError("No '%s' categroy is defined!" % item )
if not result : raise IndexError("No '%s' category is defined!" % item )
return result

raise TypeError("No '%s' label/index is defined!" % item )
Expand Down Expand Up @@ -870,7 +870,7 @@ def _racat_items_ ( cat ) :
ROOT.RooAbsCategory.__iter__ = _racat_items_

# =========================================================================
## Is given label (or index) defined fron this categroy?
## Is given label (or index) defined fron this category?
# @see RooAbsCategory.hasIndex
# @see RooAbsCategory.hasLabel
def _racat_contains_ ( cat , item ) :
Expand Down Expand Up @@ -950,6 +950,16 @@ def _rcat_str_ ( cat ) :
"""Print RooCategory instance"""
return "'%s' : '%s'/%d" % ( cat.name , cat.getLabel() , cat.getIndex() )

# ==============================================================================
## set current label/index
def _rcat_setval_( cat , value ) :
"""Set current label/index
"""
if isinstance ( value , integer_types ) : return cat.setIndex ( value )
elif isinstance ( value , string_types ) : return cat.setLabel ( value )
raise TypeError ( 'Invalid value type!' )


ROOT.RooAbsCategory.items = _racat_items_
ROOT.RooAbsCategory.iteritems = _racat_items_
ROOT.RooAbsCategory.__contains__ = _racat_contains_
Expand All @@ -959,8 +969,11 @@ def _rcat_str_ ( cat ) :
ROOT.RooAbsCategory.names = _racat_labels_
ROOT.RooAbsCategory.keys = _racat_labels_
ROOT.RooCategory .__str__ = _rcat_str_
ROOT.RooCategory .__repr__ = _rcat_str_
ROOT.RooCategory .__int__ = lambda s : s.getCurrentIndex()
ROOT.RooCategory .__repr__ = _rcat_str_

ROOT.RooCategory .__int__ = lambda s : s.getCurrentIndex ()
ROOT.RooCategory .getVal = lambda s : s.getCurrentIndex ()
ROOT.RooCategory .setVal = _rcat_setval_

_new_methods_ += [
ROOT.RooAbsCategory.__iter__ ,
Expand Down

0 comments on commit eda9f15

Please sign in to comment.