From 01850af2c9021ff13b8b75aed43059aa4e7be303 Mon Sep 17 00:00:00 2001 From: Vanya Belyaev Date: Fri, 24 Nov 2023 11:31:53 +0100 Subject: [PATCH] 1. add test for `ostap.stats.ustat` module 1. change the interface for functions from `ostap.stats.ustat` module 1. change the interface for `Ostap::UStat` class --- ReleaseNotes/release_notes.md | 4 + ostap/stats/tests/test_stats_ustat.py | 126 ++++++++++++++++++++++++++ ostap/stats/ustat.py | 79 ++++++++-------- source/include/Ostap/UStat.h | 10 +- source/src/UStat.cpp | 10 +- 5 files changed, 183 insertions(+), 46 deletions(-) create mode 100644 ostap/stats/tests/test_stats_ustat.py diff --git a/ReleaseNotes/release_notes.md b/ReleaseNotes/release_notes.md index 8d1d85be..9549afc8 100644 --- a/ReleaseNotes/release_notes.md +++ b/ReleaseNotes/release_notes.md @@ -5,9 +5,13 @@ 1. Further optimisation in `Ostap::Math::ChebyshedSum` 1. add new test `ostap/math/tests/test_math.poly.py` 1. Reduce usage of `Ostap::Utils::Iterator` + 1. add test for `ostap.stats.ustat` module ## Backward incompatible: + 1. change the interface for functions from `ostap.stats.ustat` module + 1. change the interface for `Ostap::UStat` class + ## Bug fixes: diff --git a/ostap/stats/tests/test_stats_ustat.py b/ostap/stats/tests/test_stats_ustat.py new file mode 100644 index 00000000..b643950a --- /dev/null +++ b/ostap/stats/tests/test_stats_ustat.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# ============================================================================= +# @file ostap/stats/tests/test_stats_ustat.py +# Test uStatistics for Goodness-Of-Fit tests +# Copyright (c) Ostap developpers. 
+# ============================================================================= +""" Test uStatistics for goodness-of-fit tests +""" +# ============================================================================= +import ostap.stats.ustat as uStat +import ostap.fitting.models as Models +from ostap.utils.timing import timing +from ostap.core.pyrouts import SE +from ostap.plotting.canvas import use_canvas +import ROOT, random, math +# ============================================================================= +from ostap.logger.logger import getLogger +if '__main__' == __name__ : logger = getLogger ( 'tests_stats_ustat' ) +else : logger = getLogger ( __name__ ) +# ============================================================================== + + +histos = set() + +def test_stats_ustat_G2D () : + + logger = getLogger ( "test_stats_ustat_G2D" ) + + x = ROOT.RooRealVar ( 'x' , 'x-variable' , 0 , 10 ) + y = ROOT.RooRealVar ( 'y' , 'y-variable' , 0 , 10 ) + + pdf = Models.Gauss2D_pdf ( 'G2D' , x , y , + muX = ( 5 , 4 , 6 ) , + muY = ( 5 , 4 , 6 ) , + sigmaX = ( 1 , 0.5 , 1.5 ) , + sigmaY = ( 2 , 1.5 , 2.5 ) , + theta = ( math.pi/4 , math.pi/8 , math.pi/2 ) ) + + + + for n in ( 100 , 200 , 500 , 1000 ) : ## , 5000 , 10000 ) : + + + title = "2D: N=%4d test" % n + with timing ( title , logger = logger ) , use_canvas ( title , wait = 5 ) : + + pdf.muX = 5 + pdf.muY = 5 + pdf.sigmaX = 1 + pdf.sigmaY = 2 + pdf.theta = math.pi/4 + + + data = pdf.generate ( n ) + + pdf.fitTo ( data , silent = True ) + + t , h , r = uStat.uPlot ( pdf , data ) + + histos.add ( h ) + h.draw() + + data.clear() + del data + +def test_stats_ustat_G3D () : + + logger = getLogger ( "test_stats_ustat_G3D" ) + + x = ROOT.RooRealVar ( 'x1' , 'x-variable' , 0 , 10 ) + y = ROOT.RooRealVar ( 'y2' , 'y-variable' , 0 , 10 ) + z = ROOT.RooRealVar ( 'z1' , 'z-variable' , 0 , 10 ) + + pdf = Models.Gauss3D_pdf ( 'G3D' , x , y , z , + muX = ( 5 , 4 , 6 ) , + muY = ( 5 , 4 , 6 ) , + muZ = ( 5 , 4 , 6 ) , + sigmaX 
= ( 1 , 0.5 , 1.5 ) , + sigmaY = ( 2 , 1.5 , 2.5 ) , + sigmaZ = ( 3 , 2.5 , 3.5 ) , + phi = ( math.pi/4 , math.pi/8 , math.pi/2 ) , + theta = ( math.pi/4 , math.pi/8 , math.pi/2 ) , + psi = ( math.pi/4 , math.pi/8 , math.pi/2 ) ) + + for n in ( 100 , 200 , 500 , 1000 ) : ## , 5000 , 10000 ) : + + + title = "3D: N=%4d test" % n + with timing ( title , logger = logger ) , use_canvas ( title , wait = 5 ) : + + pdf.muX = 5 + pdf.muY = 5 + pdf.muZ = 5 + pdf.sigmaX = 1 + pdf.sigmaY = 2 + pdf.sigmaZ = 3 + pdf.phi = math.pi/4 + pdf.theta = math.pi/4 + pdf.psi = math.pi/4 + + data = pdf.generate ( n ) + + pdf.fitTo ( data , silent = True ) + + t , h , r = uStat.uPlot ( pdf , data ) + + histos.add ( h ) + h.draw() + + data.clear() + del data + + + + +# ============================================================================= +if '__main__' == __name__ : + + test_stats_ustat_G2D () + test_stats_ustat_G3D () + +# ============================================================================= +## The END +# ============================================================================= + diff --git a/ostap/stats/ustat.py b/ostap/stats/ustat.py index f5ad9b15..8cf8d120 100755 --- a/ostap/stats/ustat.py +++ b/ostap/stats/ustat.py @@ -13,12 +13,12 @@ # # >>> pdf = ... ## pdf # >>> data = ... ## dataset -# >>> pdf.fitTo( data , ... ) ## fit it! +# >>> pdf.fitTo ( data , ... ) ## fit it! 
# # >>> import ostap.stats.ustat as uStat # # >>> r,histo = uStat.uPlot ( pdf , data ) -# >>> print r ## print fit results +# >>> print ( r ) ## print fit results # >>> histo.Draw() ## plot the results # # @endcode @@ -30,7 +30,7 @@ # @date 2011-09-21 # # ============================================================================ -""" ``U-statistics'' useful for ``Goodness-Of-Fit'' tests +""" `U-statistics' useful for `Goodness-Of-Fit' tests This is a simple translation of the original C++ lines written by Greig Cowan into Python @@ -51,17 +51,16 @@ >>> histo.Draw() ## plot the results """ # ============================================================================ -from __future__ import print_function __author__ = "Vanya BELYAEV Ivan.Belyaev@cern.ch" __date__ = "2010-09-21" -__version__ = "$Revision$" +__version__ = "$Revision:$" # ============================================================================ __all__ = ( "uPlot" , ## make plot of U-statistics "uCalc" , ## calculate U-statistics ) # ============================================================================ -from ostap.core.core import cpp, Ostap, hID +from ostap.core.core import Ostap, hID import ostap.histos.histos import ROOT, math, ctypes # ============================================================================= @@ -83,22 +82,35 @@ # @see Analysis::UStat::calculate # @date 2011-09-21 def uCalc ( pdf , - args , data , - histo , + args = None , + histo = None , silent = False ) : """Calculate U-statistics """ - import sys + if not isinstance ( pdf , ROOT.RooAbsPdf ) or not pdf : + from ostap.fitting.pdfbasic import APDF1 + assert pdf and isinstance ( pdf , APDF1 ) , "Invalid type of `pdf'!" 
+ pdf = pdf.pdf + + if not args : args = pdf.getObservables ( data ) + if not histo : histo = ROOT.nullptr + + ## tStat = ctypes.c_double (-1) sc = Ostap.UStat.calculate ( pdf , data , - histo , tStat , + histo , args ) + if sc.isFailure() : + logger.error ( "Error from Ostap::UStat::Calculate %s" % sc ) + + if not histo : histo = None + tStat = float ( tStat.value ) - return histo, tStat + return tStat, histo # ============================================================================= ## make the plot of U-statistics @@ -112,8 +124,8 @@ def uCalc ( pdf , # # >>> import ostap.stats.ustat as uStat # -# >>> r,histo = uStat.uPlot ( pdf , data ) -# >>> print r ## print fit results +# >>> t , histo , res = uStat.uPlot ( pdf , data ) +# >>> print ( res ) ## print fit results # >>> histo.Draw() ## plot the results # # @endcode @@ -136,24 +148,21 @@ def uPlot ( pdf , >>> import ostap.stats.ustat as uStat - >>> r,histo = uStat.uPlot ( pdf , data ) - >>> print r ## print fit results + >>> t, histo , res = uStat.uPlot ( pdf , data ) + >>> print ( res ) ## print fit results >>> histo.Draw() ## plot the results """ - + if not bins or bins <= 0 : nEntries = float(data.numEntries()) - bins = 10 - for nbins in ( 100 , - 50 , - 40 , - 25 , - 20 , - 16 , - 10 , - 8 , - 5 ) : - if nEntries/nbins < 100 : continue + bins = 10 + for nbins in ( 1000 , 500 , + 200 , 100 , + 50 , 40 , + 25 , 20 , + 16 , 10 , + 8 , 5 ) : + if nEntries/float(nbins) < 100 : continue bins = nbins break @@ -161,14 +170,12 @@ def uPlot ( pdf , histo.Sumw2 ( ) histo.SetMinimum ( 0 ) - if not args : args = pdf.getObservables ( data ) - - h,tStat = uCalc ( pdf , - args , - data , - histo , - silent ) - + tStat , hh = uCalc ( pdf , + data , + args , + histo , + silent ) + res = histo.Fit ( 'pol0' , 'SLQ0+' ) func = histo.GetFunction ( 'pol0' ) if func : @@ -176,7 +183,7 @@ def uPlot ( pdf , func.SetLineColor ( 2 ) func.ResetBit ( 1 << 9 ) - return res , histo, float(tStat) + return float ( tStat ) , histo , res # 
=========================================================================== diff --git a/source/include/Ostap/UStat.h b/source/include/Ostap/UStat.h index 1766f3a2..04196697 100644 --- a/source/include/Ostap/UStat.h +++ b/source/include/Ostap/UStat.h @@ -45,11 +45,11 @@ namespace Ostap * @param args (input) the arguments */ static Ostap::StatusCode calculate - ( const RooAbsPdf& pdf , - const RooDataSet& data , - TH1& hist , - double& tStat , - RooArgSet * args = 0 ) ; + ( const RooAbsPdf& pdf , + const RooDataSet& data , + double& tStat , + TH1* hist = nullptr , + RooArgSet* args = nullptr ) ; // ======================================================================== }; // ========================================================================== diff --git a/source/src/UStat.cpp b/source/src/UStat.cpp index 8f32d44e..e421d829 100644 --- a/source/src/UStat.cpp +++ b/source/src/UStat.cpp @@ -118,8 +118,8 @@ namespace Ostap::StatusCode Ostap::UStat::calculate ( const RooAbsPdf& pdf , const RooDataSet& data , - TH1& hist , double& tStat , + TH1* hist , RooArgSet* args ) { // @@ -187,7 +187,7 @@ Ostap::StatusCode Ostap::UStat::calculate { return Ostap::StatusCode ( InvalidItem2 ) ; } // RETURN // const double distance = getDistance ( event_i.get() , event_j.get() ) ; - if ( 0 > distance ) { return Ostap::StatusCode( InvalidDist ) ; } // RETURN + if ( 0 > distance ) { return Ostap::StatusCode ( InvalidDist ) ; } // RETURN // if ( 0 == j || distance < min_distance ) { min_distance = distance ; } @@ -199,7 +199,7 @@ Ostap::StatusCode Ostap::UStat::calculate // const double value = std::exp ( -val1 * num * pdfValue ) ; // - hist.Fill ( value ) ; + if ( hist ) { hist -> Fill ( value ) ; } // tstat.push_back ( value ) ; // @@ -208,7 +208,7 @@ Ostap::StatusCode Ostap::UStat::calculate // // calculate T-statistics // - std::sort ( tstat.begin() , tstat.end() ) ; + std::stable_sort ( tstat.begin() , tstat.end() ) ; double tS = 0 ; double nD = tstat.size() ; for ( 
TStat::const_iterator t = tstat.begin() ; tstat.end() != t ; ++t ) @@ -224,5 +224,5 @@ Ostap::StatusCode Ostap::UStat::calculate return Ostap::StatusCode::SUCCESS ; } // ============================================================================ -// The END +// The END // ============================================================================