Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove GNU PBDS dependence in stats classes and improve ArgMinMax implementation #75

Merged
merged 2 commits into from
Feb 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 120 additions & 17 deletions cpp/csp/cppnodes/statsimpl.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
#include <csp/engine/CppNode.h>
#include <csp/engine/WindowBuffer.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

#include <functional>
#include <numeric>
#include <set>
#include <type_traits>

#ifndef __clang__
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#endif

namespace csp::cppnodes
{

Expand Down Expand Up @@ -1079,6 +1084,7 @@ class WeightedKurtosis
bool m_excess;
};

#ifndef __clang__
// Alias for a GNU PBDS red-black tree keyed on double (mapped type is null_type, i.e. keys only),
// using the tree_order_statistics_node_update policy and a caller-supplied Comparator.
// Code in this file looks elements up by position via find_by_order().
template<typename Comparator>
using ost = __gnu_pbds::tree<double, __gnu_pbds::null_type, Comparator, __gnu_pbds::rb_tree_tag,
__gnu_pbds::tree_order_statistics_node_update>;
Expand All @@ -1090,6 +1096,7 @@ void ost_erase( ost<Comparator> &t, double & v )
auto it = t.find_by_order( rank );
t.erase( it );
}
#endif

class Quantile
{
Expand Down Expand Up @@ -1132,7 +1139,11 @@ class Quantile

// Removes a single occurrence of x from the ordered container backing the quantile computation.
void remove( double x )
{
#ifndef __clang__
    ost_erase( m_tree, x );
#else
    // std::multiset::erase( iterator ) requires a dereferenceable iterator: passing end()
    // (which find() returns when x is absent) is undefined behavior, so guard the lookup.
    auto it = m_tree.find( x );
    if( it != m_tree.end() )
        m_tree.erase( it );
#endif
}

void reset()
Expand All @@ -1153,7 +1164,7 @@ class Quantile
int ct = ceil( target );

double qtl;

#ifndef __clang__
switch ( m_interpolation )
{
case LINEAR:
Expand Down Expand Up @@ -1199,13 +1210,65 @@ class Quantile
default:
break;
}

#else
auto it = m_tree.begin();
std::advance( it, ft );
switch ( m_interpolation )
{
case LINEAR:
if( ft == target )
{
qtl = *it;
}
else
{
double lower = *it;
double higher = *++it;
qtl = ( 1 - target + ft ) * lower + ( 1 - ct + target ) * higher;
}
break;
case LOWER:
qtl = *it;
break;
case HIGHER:
qtl = ( ft == ct ? *it : *++it );
break;
case MIDPOINT:
if( ft == target )
{
qtl = *it;
}
else
{
double lower = *it;
double higher = *++it;
qtl = ( higher+lower ) / 2;
}
break;
case NEAREST:
if( target - ft <= ct - target )
{
qtl = *it;
}
else
{
qtl = *++it;
}
break;
default:
break;
}
#endif
return qtl;
}

private:


#ifndef __clang__
ost<std::less_equal<double>> m_tree;
#else
std::multiset<double> m_tree;
#endif
std::vector<Dictionary::Data> m_quants;
int64_t m_interpolation;
};
Expand Down Expand Up @@ -1293,36 +1356,49 @@ class Rank
else
{
m_lastval = x;
#ifndef __clang__
if( m_method == MAX )
m_maxtree.insert( x );
else
m_mintree.insert( x );
#else
m_tree.insert( x );
#endif
}
}

// Removes a single occurrence of x from the tree(s) used for ranking.
// NaN inputs are ignored here (presumably because add() never stores them -- confirm against add()).
void remove( double x )
{
    if( likely( !isnan( x ) ) )
    {
#ifndef __clang__
        if( m_method == MAX )
            ost_erase( m_maxtree, x );
        else
            ost_erase( m_mintree, x );
#else
        // std::multiset::erase( iterator ) requires a dereferenceable iterator: passing end()
        // (which find() returns when x is absent) is undefined behavior, so guard the lookup.
        auto it = m_tree.find( x );
        if( it != m_tree.end() )
            m_tree.erase( it );
#endif
    }
}

// Empties whichever container is in use: the single multiset on clang builds,
// otherwise the min- or max-ordered tree selected by m_method.
void reset()
{
#ifdef __clang__
    m_tree.clear();
#else
    if( m_method != MAX )
    {
        m_mintree.clear();
    }
    else
    {
        m_maxtree.clear();
    }
#endif
}

double compute() const
{
// Verify tree is not empty and lastValue is valid
// Last value can only ever be NaN if the "keep" nan option is used
#ifndef __clang__
if( likely( !isnan( m_lastval ) && ( ( m_method == MAX && m_maxtree.size() > 0 ) || m_mintree.size() > 0 ) ) )
{
switch( m_method )
Expand Down Expand Up @@ -1357,13 +1433,42 @@ class Rank
break;
}
}
#else
if( likely( !isnan( m_lastval ) && m_tree.size() > 0 ) )
{
switch( m_method )
{
case MIN:
{
return std::distance( m_tree.begin(), m_tree.find( m_lastval ) );
}
case MAX:
{
auto end_range = m_tree.equal_range( m_lastval ).second;
return std::distance( m_tree.begin(), std::prev( end_range ) );
}
case AVG:
{
auto range = m_tree.equal_range( m_lastval );
return std::distance( m_tree.begin(), range.first ) + ( double )std::distance( range.first, std::prev( range.second ) ) / 2;
}
default:
break;
}
}
#endif

return std::numeric_limits<double>::quiet_NaN();
}

private:

#ifndef __clang__
ost<std::less_equal<double>> m_mintree;
ost<std::greater_equal<double>> m_maxtree;
#else
std::multiset<double> m_tree;
#endif
double m_lastval;

int64_t m_method;
Expand All @@ -1374,18 +1479,18 @@ class ArgMinMax
{
public:
ArgMinMax() = default;
ArgMinMax( bool max, bool recent ) : m_max{max}, m_recent{recent} {};
ArgMinMax( bool max, bool recent ) : m_recent( recent ), m_monoQueue( max ) {}

ArgMinMax( ArgMinMax && rhs ) = default;
ArgMinMax & operator=( ArgMinMax && rhs ) = default;

// no copy as always
// no copy
ArgMinMax( const ArgMinMax & rhs ) = delete;
ArgMinMax & operator=( const ArgMinMax & rhs ) = delete;

void add( double x, DateTime t )
{
m_tree.insert( x );
m_monoQueue.add( x );
auto & it = m_treemap[x];
it.m_count++;
if( m_recent )
Expand All @@ -1396,7 +1501,7 @@ class ArgMinMax

void remove( double x )
{
ost_erase( m_tree, x );
m_monoQueue.remove( x );
auto it = m_treemap.find( x );
it->second.m_count--;
if( !it->second.m_count ) // don't let map grow unbounded
Expand All @@ -1407,20 +1512,19 @@ class ArgMinMax

void reset()
{
m_tree.clear();
m_monoQueue.reset();
m_treemap.clear();
}

DateTime compute()
{
if( m_tree.size() > 0 )
if( m_treemap.size() > 0 )
{
int target = m_max ? m_tree.size()-1 : 0;
double max_val = *m_tree.find_by_order( target );
double arg_val = m_monoQueue.compute();
if( m_recent )
return m_treemap[max_val].m_lasttime;
return m_treemap[arg_val].m_lasttime;
else
return m_treemap[max_val].m_alltimes[-1];
return m_treemap[arg_val].m_alltimes[-1];
}

return DateTime::fromNanoseconds( 0 );
Expand All @@ -1437,9 +1541,8 @@ class ArgMinMax
VariableSizeWindowBuffer<DateTime> m_alltimes;
};

bool m_max;
bool m_recent;
ost<std::less_equal<double>> m_tree;
AscendingMinima m_monoQueue;
std::map<double, TreeData> m_treemap;
};

Expand Down