Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

forward port flaky test fix and add forecasting security tests #1329

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,15 @@ integTest {
filter {
includeTestsMatching "org.opensearch.ad.rest.*IT"
includeTestsMatching "org.opensearch.ad.e2e.*IT"
includeTestsMatching "org.opensearch.forecast.rest.*IT"
includeTestsMatching "org.opensearch.forecast.e2e.*IT"
}
}

if (System.getProperty("https") == null || System.getProperty("https") == "false") {
filter {
excludeTestsMatching "org.opensearch.ad.rest.SecureADRestIT"
excludeTestsMatching "org.opensearch.forecast.rest.SecureForecastRestIT"
}
}

Expand Down Expand Up @@ -468,6 +471,7 @@ task integTestRemote(type: RestIntegTestTask) {
if (System.getProperty("https") == null || System.getProperty("https") == "false") {
filter {
excludeTestsMatching "org.opensearch.ad.rest.SecureADRestIT"
excludeTestsMatching "org.opensearch.forecast.rest.SecureForecastRestIT"
}
}
}
Expand Down Expand Up @@ -696,10 +700,7 @@ List<String> jacocoExclusions = [

// TODO: add test coverage (kaituo)
'org.opensearch.forecast.*',
'org.opensearch.timeseries.ml.TimeSeriesSingleStreamCheckpointDao',
'org.opensearch.timeseries.transport.JobRequest',
'org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler',
'org.opensearch.timeseries.ml.Inferencer',
'org.opensearch.timeseries.transport.SingleStreamResultRequest',
'org.opensearch.timeseries.rest.handler.IndexJobActionHandler.1',
'org.opensearch.timeseries.transport.SuggestConfigParamResponse',
Expand Down Expand Up @@ -727,7 +728,6 @@ List<String> jacocoExclusions = [
'org.opensearch.timeseries.ratelimit.RateLimitedRequestWorker',
'org.opensearch.timeseries.util.TimeUtil',
'org.opensearch.ad.transport.ADHCImputeTransportAction',
'org.opensearch.timeseries.ml.RealTimeInferencer',
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME;

import java.time.Clock;

import org.opensearch.ad.caching.ADCacheProvider;
import org.opensearch.ad.caching.ADPriorityCache;
import org.opensearch.ad.indices.ADIndex;
Expand All @@ -32,7 +34,8 @@ public ADRealTimeInferencer(
ADColdStartWorker coldStartWorker,
ADSaveResultStrategy resultWriteWorker,
ADCacheProvider cache,
ThreadPool threadPool
ThreadPool threadPool,
Clock clock
) {
super(
modelManager,
Expand All @@ -43,7 +46,8 @@ public ADRealTimeInferencer(
resultWriteWorker,
cache,
threadPool,
AD_THREAD_POOL_NAME
AD_THREAD_POOL_NAME,
clock
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.model.IntervalTimeConfiguration;
import org.opensearch.timeseries.util.ActionListenerExecutor;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -129,14 +128,12 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest
return;
}
Config config = configOptional.get();
long windowDelayMillis = ((IntervalTimeConfiguration) config.getWindowDelay()).toDuration().toMillis();
int featureSize = config.getEnabledFeatureIds().size();
long dataEndMillis = nodeRequest.getRequest().getDataEndMillis();
long dataStartMillis = nodeRequest.getRequest().getDataStartMillis();
long executionEndTime = dataEndMillis + windowDelayMillis;
String taskId = nodeRequest.getRequest().getTaskId();
for (ModelState<ThresholdedRandomCutForest> modelState : cache.get().getAllModels(configId)) {
if (shouldProcessModelState(modelState, executionEndTime, clusterService, hashRing)) {
if (shouldProcessModelState(modelState, dataEndMillis, clusterService, hashRing)) {
double[] nanArray = new double[featureSize];
Arrays.fill(nanArray, Double.NaN);
adInferencer
Expand All @@ -163,8 +160,8 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest
* Determines whether the model state should be processed based on various conditions.
*
* Conditions checked:
* - The model's last seen execution end time is not the minimum Instant value.
* - The current execution end time is greater than or equal to the model's last seen execution end time,
* - The model's last seen data end time is not the minimum Instant value. This means the model hasn't been initialized yet.
* - The current data end time is greater than the model's last seen data end time,
* indicating that the model state was updated in previous intervals.
* - The entity associated with the model state is present.
* - The owning node for real-time processing of the entity, with the same local version, is present in the hash ring.
Expand All @@ -175,14 +172,14 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest
* concurrently (e.g., during tests when multiple threads may operate quickly).
*
* @param modelState The current state of the model.
* @param executionEndTime The end time of the current execution interval.
* @param dataEndTime The data end time of current interval.
* @param clusterService The service providing information about the current cluster node.
* @param hashRing The hash ring used to determine the owning node for real-time processing of entities.
* @return true if the model state should be processed; otherwise, false.
*/
private boolean shouldProcessModelState(
ModelState<ThresholdedRandomCutForest> modelState,
long executionEndTime,
long dataEndTime,
ClusterService clusterService,
HashRing hashRing
) {
Expand All @@ -194,8 +191,8 @@ private boolean shouldProcessModelState(
// Check if the model state conditions are met for processing
// We cannot use last used time as it will be updated whenever we update its priority in CacheBuffer.update when there is a
// PriorityCache.get.
return modelState.getLastSeenExecutionEndTime() != Instant.MIN
&& executionEndTime >= modelState.getLastSeenExecutionEndTime().toEpochMilli()
return modelState.getLastSeenDataEndTime() != Instant.MIN
&& dataEndTime > modelState.getLastSeenDataEndTime().toEpochMilli()
&& modelState.getEntity().isPresent()
&& owningNode.isPresent()
&& owningNode.get().getId().equals(clusterService.localNode().getId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME;

import java.time.Clock;

import org.opensearch.forecast.caching.ForecastCacheProvider;
import org.opensearch.forecast.caching.ForecastPriorityCache;
import org.opensearch.forecast.indices.ForecastIndex;
Expand All @@ -32,7 +34,8 @@ public ForecastRealTimeInferencer(
ForecastColdStartWorker coldStartWorker,
ForecastSaveResultStrategy resultWriteWorker,
ForecastCacheProvider cache,
ThreadPool threadPool
ThreadPool threadPool,
Clock clock
) {
super(
modelManager,
Expand All @@ -43,7 +46,8 @@ public ForecastRealTimeInferencer(
resultWriteWorker,
cache,
threadPool,
FORECAST_THREAD_POOL_NAME
FORECAST_THREAD_POOL_NAME,
clock
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
all,
RestHandlerUtils.buildEntity(request, forecasterId)
);

return channel -> client.execute(GetForecasterAction.INSTANCE, getForecasterRequest, new RestToXContentListener<>(channel));
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException(Encode.forHtml(e.getMessage()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,6 @@ private QueryBuilder generateBuildInSubFilter(SearchTopForecastResultRequest req
*/
private RangeQueryBuilder generateDateFilter(SearchTopForecastResultRequest request, Forecaster forecaster) {
// forecast from is data end time for forecast
// return QueryBuilders.termQuery(CommonName.DATA_END_TIME_FIELD, request.getForecastFrom().toEpochMilli());
long startInclusive = request.getForecastFrom().toEpochMilli();
long endExclusive = startInclusive + forecaster.getIntervalInMilliseconds();
return QueryBuilders.rangeQuery(CommonName.DATA_END_TIME_FIELD).gte(startInclusive).lt(endExclusive);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,8 @@ public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) {
adColdstartQueue,
adSaveResultStrategy,
adCacheProvider,
threadPool
threadPool,
getClock()
);

ADCheckpointReadWorker adCheckpointReadQueue = new ADCheckpointReadWorker(
Expand Down Expand Up @@ -1230,7 +1231,8 @@ public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) {
forecastColdstartQueue,
forecastSaveResultStrategy,
forecastCacheProvider,
threadPool
threadPool,
getClock()
);

ForecastCheckpointReadWorker forecastCheckpointReadQueue = new ForecastCheckpointReadWorker(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ public ModelState<RCFModelType> get(String modelId, Config config) {
// reset every 60 intervals
return new DoorKeeper(
TimeSeriesSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION,
config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ),
config.getIntervalDuration().multipliedBy(TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ),
clock,
TimeSeriesSettings.CACHE_DOOR_KEEPER_COUNT_THRESHOLD
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ private void coldStart(
// reset every 60 intervals
return new DoorKeeper(
TimeSeriesSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION,
config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ),
config.getIntervalDuration().multipliedBy(TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ),
clock,
TimeSeriesSettings.COLD_START_DOOR_KEEPER_COUNT_THRESHOLD
);
Expand All @@ -251,7 +251,7 @@ private void coldStart(
logger
.info(
"Won't retry real-time cold start within {} intervals for model {}",
TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ,
TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ,
modelId
);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ public <RCFDescriptor extends AnomalyDescriptor> IntermediateResultType score(
throw e;
} finally {
modelState.setLastUsedTime(clock.instant());
modelState.setLastSeenExecutionEndTime(clock.instant());
modelState.setLastSeenDataEndTime(sample.getDataEndTime());
}
return createEmptyResult();
}
Expand Down
12 changes: 6 additions & 6 deletions src/main/java/org/opensearch/timeseries/ml/ModelState.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class ModelState<T> implements org.opensearch.timeseries.ExpiringState {
// time when the ML model was used last time
protected Instant lastUsedTime;
protected Instant lastCheckpointTime;
protected Instant lastSeenExecutionEndTime;
protected Instant lastSeenDataEndTime;
protected Clock clock;
protected float priority;
protected Deque<Sample> samples;
Expand Down Expand Up @@ -75,7 +75,7 @@ public ModelState(
this.priority = priority;
this.entity = entity;
this.samples = samples;
this.lastSeenExecutionEndTime = Instant.MIN;
this.lastSeenDataEndTime = Instant.MIN;
}

/**
Expand Down Expand Up @@ -252,11 +252,11 @@ public Map<String, Object> getModelStateAsMap() {
};
}

public Instant getLastSeenExecutionEndTime() {
return lastSeenExecutionEndTime;
public Instant getLastSeenDataEndTime() {
return lastSeenDataEndTime;
}

public void setLastSeenExecutionEndTime(Instant lastSeenExecutionEndTime) {
this.lastSeenExecutionEndTime = lastSeenExecutionEndTime;
public void setLastSeenDataEndTime(Instant lastSeenExecutionEndTime) {
this.lastSeenDataEndTime = lastSeenExecutionEndTime;
}
}
Loading
Loading