Skip to content

Commit

Permalink
Merge branch 'dtime-warp' of https://github.com/lewardo/flucoma-core
Browse files Browse the repository at this point in the history
…into dtime-warp
  • Loading branch information
tremblap committed Sep 10, 2023
2 parents d2c3011 + 8981f98 commit 902cdd7
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 97 deletions.
4 changes: 4 additions & 0 deletions include/algorithms/public/DTW.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ class DTW
return col < 0 ? 0 : col > mCols - 1 ? mCols - 1 : col;
}
}

return 0;
};

index lastCol(index row)
Expand All @@ -226,6 +228,8 @@ class DTW
return col < 0 ? 0 : col > mCols - 1 ? mCols - 1 : col;
}
}

return mCols - 1;
};
}; // struct Constraint
};
Expand Down
2 changes: 2 additions & 0 deletions include/clients/nrt/CommonResults.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ static const std::string LargeK{"k is too large"};
static const std::string SmallDim{"Number of dimensions is too small"};
static const std::string LargeDim{"Number of dimensions is too large"};
static const std::string EmptyDataSet{"DataSet is empty"};
static const std::string EmptyDataSeries{"DataSeries is empty"};
static const std::string EmptyLabelSet{"LabelSet is empty"};
static const std::string NoDataSet{"DataSet does not exist"};
static const std::string NoDataSeries{"DataSeries does not exist"};
static const std::string NoLabelSet{"LabelSet does not exist"};
static const std::string NoDataFitted{"No data fitted"};
static const std::string NotEnoughData{"Not enough data"};
Expand Down
31 changes: 9 additions & 22 deletions include/clients/nrt/DTWClassifierClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,16 @@ constexpr auto DTWClassifierParams = defineParameters(
LongParam("numNeighbours", "Number of Nearest Neighbours", 3, Min(1)),
EnumParam("constraint", "Constraint Type", 0, "Unconstrained", "Ikatura",
"Sakoe-Chiba"),
FloatParam("radius", "Sakoe-Chiba Constraint Radius", 2, Min(0)),
FloatParam("gradient", "Ikatura Parallelogram max gradient", 1, Min(1)));
FloatParam("constraintParam", "Sakoe-Chiba radius or Ikatura max gradient",
3, Min(0)));

class DTWClassifierClient : public FluidBaseClient,
OfflineIn,
OfflineOut,
ModelObject,
public DataClient<DTWClassifierData>
{
enum { kName, kNumNeighbors, kConstraint, kRadius, kGradient };
enum { kName, kNumNeighbors, kConstraint, kParam };

public:
using string = std::string;
Expand Down Expand Up @@ -113,13 +113,13 @@ class DTWClassifierClient : public FluidBaseClient,
InputLabelSetClientRef labelSetClient)
{
auto dataSeriesClientPtr = dataSeriesClient.get().lock();
if (!dataSeriesClientPtr) return Error(NoDataSet);
if (!dataSeriesClientPtr) return Error(NoDataSeries);

auto labelSetPtr = labelSetClient.get().lock();
if (!labelSetPtr) return Error(NoLabelSet);

auto dataSeries = dataSeriesClientPtr->getDataSeries();
if (dataSeries.size() == 0) return Error(EmptyDataSet);
if (dataSeries.size() == 0) return Error(EmptyDataSeries);

auto labelSet = labelSetPtr->getLabelSet();
if (labelSet.size() == 0) return Error(EmptyLabelSet);
Expand Down Expand Up @@ -160,13 +160,13 @@ class DTWClassifierClient : public FluidBaseClient,
{

auto sourcePtr = source.get().lock();
if (!sourcePtr) return Error(NoDataSet);
if (!sourcePtr) return Error(NoDataSeries);

auto destPtr = dest.get().lock();
if (!destPtr) return Error(NoLabelSet);

auto dataSeries = sourcePtr->getDataSeries();
if (dataSeries.size() == 0) return Error(EmptyDataSet);
if (dataSeries.size() == 0) return Error(EmptyDataSeries);

if (dataSeries.pointSize() != mAlgorithm.dims())
return Error(WrongPointSize);
Expand Down Expand Up @@ -202,19 +202,6 @@ class DTWClassifierClient : public FluidBaseClient,
}

private:
float constraintParam(algorithm::DTWConstraint constraint) const
{
using namespace algorithm;

switch (constraint)
{
case DTWConstraint::kIkatura: return get<kGradient>();
case DTWConstraint::kSakoeChiba: return get<kRadius>();
}

return 0.0;
}

MessageResult<string> kNearestModeLabel(InputRealMatrixView series) const
{
index k = get<kNumNeighbors>();
Expand All @@ -236,8 +223,8 @@ class DTWClassifierClient : public FluidBaseClient,
std::transform(
indices.begin(), indices.end(), distances.begin(),
[&series, &ds, &constraint, this](index i) {
double dist = mAlgorithm.dtw.process(series, ds[i], constraint,
constraintParam(constraint));
double dist =
mAlgorithm.dtw.process(series, ds[i], constraint, get<kParam>());
return std::max(std::numeric_limits<double>::epsilon(), dist);
});

Expand Down
30 changes: 7 additions & 23 deletions include/clients/nrt/DTWClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,16 @@ constexpr auto DTWParams = defineParameters(
StringParam<Fixed<true>>("name", "Name"),
EnumParam("constraint", "Constraint Type", 0, "Unconstrained", "Ikatura",
"Sakoe-Chiba"),
LongParam("radius", "Sakoe-Chiba Constraint Radius", 2, Min(0)),
FloatParam("gradient", "Ikatura Parallelogram max gradient", 1.0,
Min(1.0)));
FloatParam("constraintParam", "Sakoe-Chiba radius or Ikatura max gradient",
3, Min(0)));

class DTWClient : public FluidBaseClient,
OfflineIn,
OfflineOut,
ModelObject,
public DataClient<algorithm::DTW>
{
enum { kName, kConstraint, kRadius, kGradient };
enum { kName, kConstraint, kParam };

public:
using string = std::string;
Expand Down Expand Up @@ -75,10 +74,10 @@ class DTWClient : public FluidBaseClient,
string id1, string id2)
{
auto dataseriesClientPtr = dataseriesClient.get().lock();
if (!dataseriesClientPtr) return Error<double>(NoDataSet);
if (!dataseriesClientPtr) return Error<double>(NoDataSeries);

auto srcDataSeries = dataseriesClientPtr->getDataSeries();
if (srcDataSeries.size() == 0) return Error<double>(EmptyDataSet);
if (srcDataSeries.size() == 0) return Error<double>(EmptyDataSeries);

index i1 = srcDataSeries.getIndex(id1), i2 = srcDataSeries.getIndex(id2);

Expand All @@ -90,8 +89,7 @@ class DTWClient : public FluidBaseClient,
algorithm::DTWConstraint constraint =
(algorithm::DTWConstraint) get<kConstraint>();

return mAlgorithm.process(series1, series2, constraint,
constraintParam(constraint));
return mAlgorithm.process(series1, series2, constraint, get<kParam>());
}

MessageResult<double> bufCost(InputBufferPtr data1, InputBufferPtr data2)
Expand All @@ -116,28 +114,14 @@ class DTWClient : public FluidBaseClient,
(algorithm::DTWConstraint) get<kConstraint>();

return mAlgorithm.process(buf1frames, buf2frames, constraint,
constraintParam(constraint));
get<kParam>());
}

static auto getMessageDescriptors()
{
return defineMessages(makeMessage("cost", &DTWClient::cost),
makeMessage("bufCost", &DTWClient::bufCost));
}

private:
float constraintParam(algorithm::DTWConstraint constraint)
{
using namespace algorithm;

switch (constraint)
{
case DTWConstraint::kIkatura: return get<kGradient>();
case DTWConstraint::kSakoeChiba: return get<kRadius>();
}

return 0.0;
}
};

using DTWRef = SharedClientRef<const DTWClient>;
Expand Down
33 changes: 10 additions & 23 deletions include/clients/nrt/DTWRegressorClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,16 @@ constexpr auto DTWRegressorParams = defineParameters(
LongParam("numNeighbours", "Number of Nearest Neighbours", 3, Min(1)),
EnumParam("constraint", "Constraint Type", 0, "Unconstrained", "Ikatura",
"Sakoe-Chiba"),
FloatParam("radius", "Sakoe-Chiba Constraint Radius", 2, Min(0)),
FloatParam("gradient", "Ikatura Parallelogram max gradient", 1, Min(1)));
FloatParam("constraintParam", "Sakoe-Chiba radius or Ikatura max gradient",
3, Min(0)));

class DTWRegressorClient : public FluidBaseClient,
OfflineIn,
OfflineOut,
ModelObject,
public DataClient<DTWRegressorData>
{
enum { kName, kNumNeighbors, kConstraint, kRadius, kGradient };
enum { kName, kNumNeighbors, kConstraint, kParam };

public:
using string = std::string;
Expand Down Expand Up @@ -117,16 +117,16 @@ class DTWRegressorClient : public FluidBaseClient,
InputDataSetClientRef dataSetClient)
{
auto dataSeriesClientPtr = dataSeriesClient.get().lock();
if (!dataSeriesClientPtr) return Error(NoDataSet);
if (!dataSeriesClientPtr) return Error(NoDataSeries);

auto dataSetPtr = dataSetClient.get().lock();
if (!dataSetPtr) return Error(NoDataSet);

auto dataSeries = dataSeriesClientPtr->getDataSeries();
if (dataSeries.size() == 0) return Error(EmptyDataSet);
if (dataSeries.size() == 0) return Error(EmptyDataSeries);

auto dataSet = dataSetPtr->getDataSet();
if (dataSet.size() == 0) return Error(EmptyLabelSet);
if (dataSet.size() == 0) return Error(EmptyDataSet);

if (dataSeries.size() != dataSet.size()) return Error(SizesDontMatch);

Expand Down Expand Up @@ -180,13 +180,13 @@ class DTWRegressorClient : public FluidBaseClient,
{

auto sourcePtr = source.get().lock();
if (!sourcePtr) return Error(NoDataSet);
if (!sourcePtr) return Error(NoDataSeries);

auto destPtr = dest.get().lock();
if (!destPtr) return Error(NoLabelSet);

auto dataSeries = sourcePtr->getDataSeries();
if (dataSeries.size() == 0) return Error(EmptyDataSet);
if (dataSeries.size() == 0) return Error(EmptyDataSeries);

if (dataSeries.pointSize() != mAlgorithm.series.dims())
return Error(WrongPointSize);
Expand Down Expand Up @@ -230,19 +230,6 @@ class DTWRegressorClient : public FluidBaseClient,
}

private:
float constraintParam(algorithm::DTWConstraint constraint) const
{
using namespace algorithm;

switch (constraint)
{
case DTWConstraint::kIkatura: return get<kGradient>();
case DTWConstraint::kSakoeChiba: return get<kRadius>();
}

return 0.0;
}

MessageResult<RealVector>
kNearestWeightedSum(InputRealMatrixView series,
Allocator& alloc = FluidDefaultAllocator()) const
Expand All @@ -269,8 +256,8 @@ class DTWRegressorClient : public FluidBaseClient,
std::transform(
indices.begin(), indices.end(), distances.begin(),
[&series, &ds, &constraint, this](index i) {
double dist = mAlgorithm.dtw.process(series, ds[i], constraint,
constraintParam(constraint));
double dist =
mAlgorithm.dtw.process(series, ds[i], constraint, get<kParam>());
return std::max(std::numeric_limits<double>::epsilon(), dist);
});

Expand Down
40 changes: 20 additions & 20 deletions include/clients/nrt/DataSeriesClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,10 @@ class DataSeriesClient
bool overwrite)
{
auto dataseriesClientPtr = dataseriesClient.get().lock();
if (!dataseriesClientPtr) return Error(NoDataSet);
if (!dataseriesClientPtr) return Error(NoDataSeries);

auto srcDataSeries = dataseriesClientPtr->getDataSeries();
if (srcDataSeries.size() == 0) return Error(EmptyDataSet);
if (srcDataSeries.size() == 0) return Error(EmptyDataSeries);
if (srcDataSeries.pointSize() != mAlgorithm.pointSize())
return Error(WrongPointSize);

Expand All @@ -265,7 +265,7 @@ class DataSeriesClient
return OK();
}

MessageResult<void> getDataSet(DataSetClientRef dest, index time) const
MessageResult<void> getDataSet(index time, DataSetClientRef dest) const
{
auto destPtr = dest.get().lock();
if (!destPtr) return Error(NoDataSet);
Expand All @@ -279,18 +279,20 @@ class DataSeriesClient
MessageResult<void> getIds(LabelSetClientRef dest)
{
auto destPtr = dest.get().lock();
if (!destPtr) return Error(NoDataSet);
if (!destPtr) return Error(NoLabelSet);
destPtr->setLabelSet(getIdsLabelSet());

return OK();
}

MessageResult<FluidTensor<rt::string, 1>>
kNearest(InputBufferPtr data, index nNeighbours, index p = 2) const
MessageResult<FluidTensor<rt::string, 1>> kNearest(InputBufferPtr data,
index nNeighbours) const
{
// check for nNeighbours > 0 and < size of DS
if (mAlgorithm.size() == 0)
return Error<FluidTensor<rt::string, 1>>(EmptyDataSeries);
if (nNeighbours > mAlgorithm.size())
return Error<FluidTensor<rt::string, 1>>(SmallDataSet);
return Error<FluidTensor<rt::string, 1>>(LargeK);
if (nNeighbours <= 0) return Error<FluidTensor<rt::string, 1>>(SmallK);

BufferAdaptor::ReadAccess buf(data.get());
Expand All @@ -307,10 +309,9 @@ class DataSeriesClient

auto ds = mAlgorithm.getData();

std::transform(indices.begin(), indices.end(), distances.begin(),
[&series, &ds, &p, this](index i) {
return distance(series, ds[i], p);
});
std::transform(
indices.begin(), indices.end(), distances.begin(),
[&series, &ds, this](index i) { return distance(series, ds[i], 2); });

std::sort(indices.begin(), indices.end(), [&distances](index a, index b) {
return distances[asUnsigned(a)] < distances[asUnsigned(b)];
Expand All @@ -328,12 +329,14 @@ class DataSeriesClient
return labels;
}

MessageResult<FluidTensor<double, 1>>
kNearestDist(InputBufferPtr data, index nNeighbours, index p = 2) const
MessageResult<FluidTensor<double, 1>> kNearestDist(InputBufferPtr data,
index nNeighbours) const
{
// check for nNeighbours > 0 and < size of DS
if (mAlgorithm.size() == 0)
return Error<FluidTensor<double, 1>>(EmptyDataSeries);
if (nNeighbours > mAlgorithm.size())
return Error<FluidTensor<double, 1>>(SmallDataSet);
return Error<FluidTensor<double, 1>>(LargeK);
if (nNeighbours <= 0) return Error<FluidTensor<double, 1>>(SmallK);

BufferAdaptor::ReadAccess buf(data.get());
Expand All @@ -350,10 +353,9 @@ class DataSeriesClient

auto ds = mAlgorithm.getData();

std::transform(indices.begin(), indices.end(), distances.begin(),
[&series, &ds, &p, this](index i) {
return distance(series, ds[i], p);
});
std::transform(
indices.begin(), indices.end(), distances.begin(),
[&series, &ds, this](index i) { return distance(series, ds[i], 2); });

std::sort(indices.begin(), indices.end(), [&distances](index a, index b) {
return distances[asUnsigned(a)] < distances[asUnsigned(b)];
Expand Down Expand Up @@ -407,8 +409,6 @@ class DataSeriesClient
makeMessage("read", &DataSeriesClient::read),
makeMessage("kNearest", &DataSeriesClient::kNearest),
makeMessage("kNearestDist", &DataSeriesClient::kNearestDist),
makeMessage("toBuffer", &DataSeriesClient::getSeries),
makeMessage("fromBuffer", &DataSeriesClient::setSeries),
makeMessage("getIds", &DataSeriesClient::getIds),
makeMessage("getDataSet", &DataSeriesClient::getDataSet));
}
Expand Down
Loading

0 comments on commit 902cdd7

Please sign in to comment.