Skip to content

Commit

Permalink
Add STDEV() aggregate function (#1553)
Browse files Browse the repository at this point in the history
Add a new aggregate function `STDEV(X)` which computes the (sample) standard deviation, such that a user will not have to repetitively type `math:sqrt(sum(math:pow((X - avg(X)), 2)) / (count(*) - 1))`. This is not part of the SPARQL standard, but also doesn't cause any conflicts.
  • Loading branch information
ullingerc authored Nov 13, 2024
1 parent 1bcfeeb commit 1a2fe17
Show file tree
Hide file tree
Showing 21 changed files with 2,749 additions and 2,420 deletions.
3 changes: 3 additions & 0 deletions src/engine/GroupBy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "engine/sparqlExpressions/SampleExpression.h"
#include "engine/sparqlExpressions/SparqlExpression.h"
#include "engine/sparqlExpressions/SparqlExpressionGenerators.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "global/RuntimeParameters.h"
#include "index/Index.h"
#include "index/IndexImpl.h"
Expand Down Expand Up @@ -1026,6 +1027,8 @@ GroupBy::isSupportedAggregate(sparqlExpression::SparqlExpression* expr) {
if (auto val = dynamic_cast<GroupConcatExpression*>(expr)) {
return H{GROUP_CONCAT, val->getSeparator()};
}
// NOTE: The STDEV function is not suitable for lazy and hash map
// optimizations.
if (dynamic_cast<SampleExpression*>(expr)) return H{SAMPLE};

// `expr` is an unsupported aggregate
Expand Down
6 changes: 6 additions & 0 deletions src/engine/sparqlExpressions/AggregateExpression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "engine/sparqlExpressions/AggregateExpression.h"

#include "engine/sparqlExpressions/GroupConcatExpression.h"
#include "engine/sparqlExpressions/StdevExpression.h"

namespace sparqlExpression::detail {

Expand Down Expand Up @@ -180,6 +181,11 @@ AggregateExpression<AggregateOperation, FinalOperation>::getVariableForCount()
// Explicit instantiation for the AVG expression.
template class AggregateExpression<AvgOperation, decltype(avgFinalOperation)>;

// Explicit instantiation for the STDEV expression.
template class AggregateExpression<AvgOperation, decltype(stdevFinalOperation)>;
template class DeviationAggExpression<AvgOperation,
decltype(stdevFinalOperation)>;

// Explicit instantiations for the other aggregate expressions.
#define INSTANTIATE_AGG_EXP(Function, ValueGetter) \
template class AggregateExpression< \
Expand Down
1 change: 1 addition & 0 deletions src/engine/sparqlExpressions/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_library(sparqlExpressions
SampleExpression.cpp
RelationalExpressions.cpp
AggregateExpression.cpp
StdevExpression.cpp
RegexExpression.cpp
NumericUnaryExpressions.cpp
NumericBinaryExpressions.cpp
Expand Down
74 changes: 74 additions & 0 deletions src/engine/sparqlExpressions/StdevExpression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2024, University of Freiburg,
// Chair of Algorithms and Data Structures.
// Author: Christoph Ullinger <[email protected]>

#include "engine/sparqlExpressions/StdevExpression.h"

namespace sparqlExpression {

namespace detail {

// _____________________________________________________________________________
ExpressionResult DeviationExpression::evaluate(
EvaluationContext* context) const {
auto impl = [context](SingleExpressionResult auto&& el) -> ExpressionResult {
// Prepare space for result
VectorWithMemoryLimit<IdOrLiteralOrIri> exprResult{context->_allocator};
std::fill_n(std::back_inserter(exprResult), context->size(),
IdOrLiteralOrIri{Id::makeUndefined()});
bool undef = false;

auto devImpl = [&undef, &exprResult, context](auto generator) {
double sum = 0.0;
// Intermediate storage of the results returned from the child
// expression
VectorWithMemoryLimit<double> childResults{context->_allocator};

// Collect values as doubles
for (auto& inp : generator) {
const auto& n = detail::NumericValueGetter{}(std::move(inp), context);
auto v = std::visit(
[]<typename T>(T&& value) -> std::optional<double> {
if constexpr (ad_utility::isSimilar<T, double> ||
ad_utility::isSimilar<T, int64_t>) {
return static_cast<double>(value);
} else {
return std::nullopt;
}
},
n);
if (v.has_value()) {
childResults.push_back(v.value());
sum += v.value();
} else {
// There is a non-numeric value in the input. Therefore the entire
// result will be undef.
undef = true;
return;
}
context->cancellationHandle_->throwIfCancelled();
}

// Calculate squared deviation and save for result
double avg = sum / static_cast<double>(context->size());
for (size_t i = 0; i < childResults.size(); i++) {
exprResult.at(i) = IdOrLiteralOrIri{
ValueId::makeFromDouble(std::pow(childResults.at(i) - avg, 2))};
}
};

auto generator =
detail::makeGenerator(AD_FWD(el), context->size(), context);
devImpl(std::move(generator));

if (undef) {
return IdOrLiteralOrIri{Id::makeUndefined()};
}
return exprResult;
};
auto childRes = child_->evaluate(context);
return std::visit(impl, std::move(childRes));
};

} // namespace detail
} // namespace sparqlExpression
100 changes: 100 additions & 0 deletions src/engine/sparqlExpressions/StdevExpression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2024, University of Freiburg,
// Chair of Algorithms and Data Structures.
// Author: Christoph Ullinger <[email protected]>

#pragma once

#include <cmath>
#include <functional>
#include <memory>
#include <variant>

#include "engine/sparqlExpressions/AggregateExpression.h"
#include "engine/sparqlExpressions/LiteralExpression.h"
#include "engine/sparqlExpressions/NaryExpression.h"
#include "engine/sparqlExpressions/SparqlExpression.h"
#include "engine/sparqlExpressions/SparqlExpressionTypes.h"
#include "engine/sparqlExpressions/SparqlExpressionValueGetters.h"
#include "global/ValueId.h"

namespace sparqlExpression {

namespace detail {

/// The STDEV Expression

// Helper expression: The individual deviation squares. A DeviationExpression
// over X corresponds to the value (X - AVG(X))^2.
class DeviationExpression : public SparqlExpression {
private:
Ptr child_;

public:
DeviationExpression(Ptr&& child) : child_{std::move(child)} {}

// __________________________________________________________________________
ExpressionResult evaluate(EvaluationContext* context) const override;

// __________________________________________________________________________
AggregateStatus isAggregate() const override {
return SparqlExpression::AggregateStatus::NoAggregate;
}

// __________________________________________________________________________
[[nodiscard]] string getCacheKey(
const VariableToColumnMap& varColMap) const override {
return absl::StrCat("[ SQ.DEVIATION ]", child_->getCacheKey(varColMap));
}

private:
// _________________________________________________________________________
std::span<SparqlExpression::Ptr> childrenImpl() override {
return {&child_, 1};
}
};

// Separate subclass of AggregateOperation, that replaces its child with a
// DeviationExpression of this child. Everything else is left untouched.
template <typename AggregateOperation,
typename FinalOperation = decltype(identity)>
class DeviationAggExpression
: public AggregateExpression<AggregateOperation, FinalOperation> {
public:
// __________________________________________________________________________
DeviationAggExpression(bool distinct, SparqlExpression::Ptr&& child,
AggregateOperation aggregateOp = AggregateOperation{})
: AggregateExpression<AggregateOperation, FinalOperation>(
distinct, std::make_unique<DeviationExpression>(std::move(child)),
aggregateOp){};
};

// The final operation for dividing by degrees of freedom and calculation square
// root after summing up the squared deviation
inline auto stdevFinalOperation = [](const NumericValue& aggregation,
size_t numElements) {
auto divAndRoot = [](double value, double degreesOfFreedom) {
if (degreesOfFreedom <= 0) {
return 0.0;
} else {
return std::sqrt(value / degreesOfFreedom);
}
};
return makeNumericExpressionForAggregate<decltype(divAndRoot)>()(
aggregation, NumericValue{static_cast<double>(numElements) - 1});
};

// The actual Standard Deviation Expression
// Mind the explicit instantiation of StdevExpressionBase in
// AggregateExpression.cpp
using StdevExpressionBase =
DeviationAggExpression<AvgOperation, decltype(stdevFinalOperation)>;
class StdevExpression : public StdevExpressionBase {
using StdevExpressionBase::StdevExpressionBase;
ValueId resultForEmptyGroup() const override { return Id::makeFromDouble(0); }
};

} // namespace detail

using detail::StdevExpression;

} // namespace sparqlExpression
3 changes: 3 additions & 0 deletions src/parser/sparqlParser/SparqlQleverVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "engine/sparqlExpressions/RegexExpression.h"
#include "engine/sparqlExpressions/RelationalExpressions.h"
#include "engine/sparqlExpressions/SampleExpression.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "engine/sparqlExpressions/UuidExpressions.h"
#include "parser/GraphPatternOperation.h"
#include "parser/RdfParser.h"
Expand Down Expand Up @@ -2372,6 +2373,8 @@ ExpressionPtr Visitor::visit(Parser::AggregateContext* ctx) {
}

return makePtr.operator()<GroupConcatExpression>(std::move(separator));
} else if (functionName == "stdev") {
return makePtr.operator()<StdevExpression>();
} else {
AD_CORRECTNESS_CHECK(functionName == "sample");
return makePtr.operator()<SampleExpression>();
Expand Down
1 change: 1 addition & 0 deletions src/parser/sparqlParser/SparqlQleverVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "engine/sparqlExpressions/AggregateExpression.h"
#include "engine/sparqlExpressions/NaryExpression.h"
#include "engine/sparqlExpressions/StdevExpression.h"
#include "parser/data/GraphRef.h"
#undef EOF
#include "parser/sparqlParser/generated/SparqlAutomaticVisitor.h"
Expand Down
2 changes: 2 additions & 0 deletions src/parser/sparqlParser/generated/SparqlAutomatic.g4
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ aggregate : COUNT '(' DISTINCT? ( '*' | expression ) ')'
| MIN '(' DISTINCT? expression ')'
| MAX '(' DISTINCT? expression ')'
| AVG '(' DISTINCT? expression ')'
| STDEV '(' DISTINCT? expression ')'
| SAMPLE '(' DISTINCT? expression ')'
| GROUP_CONCAT '(' DISTINCT? expression ( ';' SEPARATOR '=' string )? ')' ;

Expand Down Expand Up @@ -763,6 +764,7 @@ SUM : S U M;
MIN : M I N;
MAX : M A X;
AVG : A V G;
STDEV : S T D E V ;
SAMPLE : S A M P L E;
SEPARATOR : S E P A R A T O R;

Expand Down
4 changes: 3 additions & 1 deletion src/parser/sparqlParser/generated/SparqlAutomatic.interp

Large diffs are not rendered by default.

75 changes: 38 additions & 37 deletions src/parser/sparqlParser/generated/SparqlAutomatic.tokens
Original file line number Diff line number Diff line change
Expand Up @@ -136,43 +136,44 @@ SUM=135
MIN=136
MAX=137
AVG=138
SAMPLE=139
SEPARATOR=140
IRI_REF=141
PNAME_NS=142
PNAME_LN=143
BLANK_NODE_LABEL=144
VAR1=145
VAR2=146
LANGTAG=147
PREFIX_LANGTAG=148
INTEGER=149
DECIMAL=150
DOUBLE=151
INTEGER_POSITIVE=152
DECIMAL_POSITIVE=153
DOUBLE_POSITIVE=154
INTEGER_NEGATIVE=155
DECIMAL_NEGATIVE=156
DOUBLE_NEGATIVE=157
EXPONENT=158
STRING_LITERAL1=159
STRING_LITERAL2=160
STRING_LITERAL_LONG1=161
STRING_LITERAL_LONG2=162
ECHAR=163
NIL=164
ANON=165
PN_CHARS_U=166
VARNAME=167
PN_PREFIX=168
PN_LOCAL=169
PLX=170
PERCENT=171
HEX=172
PN_LOCAL_ESC=173
WS=174
COMMENTS=175
STDEV=139
SAMPLE=140
SEPARATOR=141
IRI_REF=142
PNAME_NS=143
PNAME_LN=144
BLANK_NODE_LABEL=145
VAR1=146
VAR2=147
LANGTAG=148
PREFIX_LANGTAG=149
INTEGER=150
DECIMAL=151
DOUBLE=152
INTEGER_POSITIVE=153
DECIMAL_POSITIVE=154
DOUBLE_POSITIVE=155
INTEGER_NEGATIVE=156
DECIMAL_NEGATIVE=157
DOUBLE_NEGATIVE=158
EXPONENT=159
STRING_LITERAL1=160
STRING_LITERAL2=161
STRING_LITERAL_LONG1=162
STRING_LITERAL_LONG2=163
ECHAR=164
NIL=165
ANON=166
PN_CHARS_U=167
VARNAME=168
PN_PREFIX=169
PN_LOCAL=170
PLX=171
PERCENT=172
HEX=173
PN_LOCAL_ESC=174
WS=175
COMMENTS=176
'*'=1
'('=2
')'=3
Expand Down
Loading

0 comments on commit 1a2fe17

Please sign in to comment.