Skip to content

Commit

Permalink
Remove GNU PBDS dependence in stats classes Quantile/Rank/ArgMinMax; …
Browse files Browse the repository at this point in the history
…improved ArgMinMax implementation; using inefficient impl for Quantile/Rank
  • Loading branch information
AdamGlustein committed Feb 13, 2024
1 parent f4b8ac9 commit 2dad4fb
Showing 1 changed file with 120 additions and 17 deletions.
137 changes: 120 additions & 17 deletions cpp/csp/cppnodes/statsimpl.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
#include <csp/engine/CppNode.h>
#include <csp/engine/WindowBuffer.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

#include <functional>
#include <numeric>
#include <set>
#include <type_traits>

#ifdef __GNUC__
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#endif

namespace csp::cppnodes
{

Expand Down Expand Up @@ -1079,6 +1084,7 @@ class WeightedKurtosis
bool m_excess;
};

#ifdef __GNUC__
template<typename Comparator>
using ost = __gnu_pbds::tree<double, __gnu_pbds::null_type, Comparator, __gnu_pbds::rb_tree_tag,
__gnu_pbds::tree_order_statistics_node_update>;
Expand All @@ -1090,6 +1096,7 @@ void ost_erase( ost<Comparator> &t, double & v )
auto it = t.find_by_order( rank );
t.erase( it );
}
#endif

class Quantile
{
Expand Down Expand Up @@ -1132,7 +1139,11 @@ class Quantile

void remove( double x )
{
#ifdef __GNUC__
ost_erase( m_tree, x );
#else
m_tree.erase( m_tree.find( x ) );
#endif
}

void reset()
Expand All @@ -1153,7 +1164,7 @@ class Quantile
int ct = ceil( target );

double qtl;

#ifdef __GNUC__
switch ( m_interpolation )
{
case LINEAR:
Expand Down Expand Up @@ -1199,13 +1210,65 @@ class Quantile
default:
break;
}

#else
auto it = m_tree.begin();
std::advance( it, ft );
switch ( m_interpolation )
{
case LINEAR:
if( ft == target )
{
qtl = *it;
}
else
{
double lower = *it;
double higher = *++it;
qtl = ( 1 - target + ft ) * lower + ( 1 - ct + target ) * higher;
}
break;
case LOWER:
qtl = *it;
break;
case HIGHER:
qtl = ( ft == ct ? *it : *++it );
break;
case MIDPOINT:
if( ft == target )
{
qtl = *it;
}
else
{
double lower = *it;
double higher = *++it;
qtl = ( higher+lower ) / 2;
}
break;
case NEAREST:
if( target - ft <= ct - target )
{
qtl = *it;
}
else
{
qtl = *++it;
}
break;
default:
break;
}
#endif
return qtl;
}

private:


#ifdef __GNUC__
ost<std::less_equal<double>> m_tree;
#else
std::multiset<double> m_tree;
#endif
std::vector<Dictionary::Data> m_quants;
int64_t m_interpolation;
};
Expand Down Expand Up @@ -1293,36 +1356,49 @@ class Rank
else
{
m_lastval = x;
#ifdef __GNUC__
if( m_method == MAX )
m_maxtree.insert( x );
else
m_mintree.insert( x );
#else
m_tree.insert( x );
#endif
}
}

void remove( double x )
{
if( likely( !isnan( x ) ) )
{
#ifdef __GNUC__
if( m_method == MAX )
ost_erase( m_maxtree, x );
else
ost_erase( m_mintree, x );
#else
m_tree.erase( m_tree.find( x ) );
#endif
}
}

void reset()
{
#ifdef __GNUC__
if( m_method == MAX )
m_maxtree.clear();
else
m_mintree.clear();
#else
m_tree.clear();
#endif
}

double compute() const
{
// Verify tree is not empty and lastValue is valid
// Last value can only ever be NaN if the "keep" nan option is used
#ifdef __GNUC__
if( likely( !isnan( m_lastval ) && ( ( m_method == MAX && m_maxtree.size() > 0 ) || m_mintree.size() > 0 ) ) )
{
switch( m_method )
Expand Down Expand Up @@ -1357,13 +1433,42 @@ class Rank
break;
}
}
#else
if( likely( !isnan( m_lastval ) && m_tree.size() > 0 ) )
{
switch( m_method )
{
case MIN:
{
return std::distance( m_tree.begin(), m_tree.find( m_lastval ) );
}
case MAX:
{
auto end_range = m_tree.equal_range( m_lastval ).second;
return std::distance( m_tree.begin(), std::prev( end_range ) );
}
case AVG:
{
auto range = m_tree.equal_range( m_lastval );
return std::distance( m_tree.begin(), range.first ) + ( double )std::distance( range.first, std::prev( range.second ) ) / 2;
}
default:
break;
}
}
#endif

return std::numeric_limits<double>::quiet_NaN();
}

private:

#ifdef __GNUC__
ost<std::less_equal<double>> m_mintree;
ost<std::greater_equal<double>> m_maxtree;
#else
std::multiset<double> m_tree;
#endif
double m_lastval;

int64_t m_method;
Expand All @@ -1374,18 +1479,18 @@ class ArgMinMax
{
public:
ArgMinMax() = default;
ArgMinMax( bool max, bool recent ) : m_max{max}, m_recent{recent} {};
ArgMinMax( bool max, bool recent ) : m_recent( recent ), m_monoQueue( max ) {}

ArgMinMax( ArgMinMax && rhs ) = default;
ArgMinMax & operator=( ArgMinMax && rhs ) = default;

// no copy as always
// no copy
ArgMinMax( const ArgMinMax & rhs ) = delete;
ArgMinMax & operator=( const ArgMinMax & rhs ) = delete;

void add( double x, DateTime t )
{
m_tree.insert( x );
m_monoQueue.add( x );
auto & it = m_treemap[x];
it.m_count++;
if( m_recent )
Expand All @@ -1396,7 +1501,7 @@ class ArgMinMax

void remove( double x )
{
ost_erase( m_tree, x );
m_monoQueue.remove( x );
auto it = m_treemap.find( x );
it->second.m_count--;
if( !it->second.m_count ) // don't let map grow unbounded
Expand All @@ -1407,20 +1512,19 @@ class ArgMinMax

void reset()
{
m_tree.clear();
m_monoQueue.reset();
m_treemap.clear();
}

DateTime compute()
{
if( m_tree.size() > 0 )
if( m_treemap.size() > 0 )
{
int target = m_max ? m_tree.size()-1 : 0;
double max_val = *m_tree.find_by_order( target );
double arg_val = m_monoQueue.compute();
if( m_recent )
return m_treemap[max_val].m_lasttime;
return m_treemap[arg_val].m_lasttime;
else
return m_treemap[max_val].m_alltimes[-1];
return m_treemap[arg_val].m_alltimes[-1];
}

return DateTime::fromNanoseconds( 0 );
Expand All @@ -1437,9 +1541,8 @@ class ArgMinMax
VariableSizeWindowBuffer<DateTime> m_alltimes;
};

bool m_max;
bool m_recent;
ost<std::less_equal<double>> m_tree;
AscendingMinima m_monoQueue;
std::map<double, TreeData> m_treemap;
};

Expand Down

0 comments on commit 2dad4fb

Please sign in to comment.