Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forwarding port changes in 2.4 to main branch (Add more unit test coverage to output.model and input.parameter in commons pakage) #601

Merged
merged 1 commit into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(BETA1_FIELD, beta1);
}
if (beta2 != null) {
builder.field(BETA1_FIELD, beta2);
builder.field(BETA2_FIELD, beta2);
}
if (decayRate != null) {
builder.field(DECAY_RATE_FIELD, decayRate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class ModelTensor implements Writeable, ToXContentObject {

@Builder
public ModelTensor(String name, Number[] data, long[] shape, MLResultDataType dataType, ByteBuffer byteBuffer) {
if (this.data != null && (dataType == null || dataType == MLResultDataType.UNKNOWN)) {
if (data != null && (dataType == null || dataType == MLResultDataType.UNKNOWN)) {
throw new IllegalArgumentException("data type is null");
}
this.name = name;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.input.parameter.regression;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;
import java.util.function.Function;

import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.common.TestHelper.contentObjectToString;
import static org.opensearch.ml.common.TestHelper.testParseFromString;
import static org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams.PARSE_FIELD_NAME;

public class LogisticRegressionParamsTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

private Function<XContentParser, LogisticRegressionParams> function = parser -> {
try {
return (LogisticRegressionParams) LogisticRegressionParams.parse(parser);
} catch (IOException e) {
throw new RuntimeException("failed to parse LogisticRegressionParams", e);
}
};

private LogisticRegressionParams logisticRegressionParams;

@Before
public void setUp() {
logisticRegressionParams = LogisticRegressionParams
.builder()
.objectiveType(LogisticRegressionParams.ObjectiveType.LOGMULTICLASS)
.optimizerType(LogisticRegressionParams.OptimizerType.ADA_GRAD)
.learningRate(0.1)
.momentumType(LogisticRegressionParams.MomentumType.STANDARD)
.momentumFactor(0.2)
.epsilon(0.3)
.beta1(0.4)
.beta2(0.5)
.decayRate(0.6)
.epochs(1)
.batchSize(2)
.seed(3L)
.target("test_target")
.build();
}

@Test
public void readInputStream_Success() throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
logisticRegressionParams.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
LogisticRegressionParams params = new LogisticRegressionParams(streamInput);
assertEquals(params, logisticRegressionParams);
}

@Test
public void parse_PassIntValueToDoubleField() throws IOException {
String paramsStr = contentObjectToString(logisticRegressionParams);
testParseFromString(logisticRegressionParams, paramsStr, function);
}

@Test
public void parse_InvalidParam_InvalidDoubleValue() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Double value passed as String");
String paramsStr = contentObjectToString(logisticRegressionParams);
testParseFromString(logisticRegressionParams, paramsStr.replace("\"epsilon\":0.3,", "\"epsilon\":\"0.3\","), function);
}

@Test
public void test_GetWriteableName() {
assertEquals(logisticRegressionParams.getWriteableName(), PARSE_FIELD_NAME);
}

@Test
public void test_GetVersion() {
assertEquals(logisticRegressionParams.getVersion(), 1);
}

@Test
public void readInputStream_Success_Empty() throws IOException {
LogisticRegressionParams params = LogisticRegressionParams.builder().build();
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
params.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
LogisticRegressionParams parsedParams = new LogisticRegressionParams(streamInput);
assertEquals(params, parsedParams);
}

@Test
public void parse_LogisticRegressionParams() throws IOException {
TestHelper.testParse(logisticRegressionParams, function);
}

@Test
public void parse_EmptyLogisticRegressionParams() throws IOException {
TestHelper.testParse(LogisticRegressionParams.builder().build(), function);
}

@Test
public void parse_LogisticRegressionParams_WrongExtraField() throws IOException {
TestHelper.testParseFromString(logisticRegressionParams, "{\"objective\":\"LOGMULTICLASS\",\"learning_rate\":0.1,\"wrong_field\":1.0}", function);
}

}

Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.output.model;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;
import java.nio.ByteBuffer;

import static org.junit.Assert.assertEquals;
import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS;

public class ModelTensorTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

private ModelTensor modelTensor;

@Before
public void setUp() {
modelTensor = ModelTensor.builder()
.name("model_tensor")
.data(new Number[]{1, 2, 3})
.shape(new long[]{1, 2, 3,})
.dataType(MLResultDataType.INT32)
.byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1}))
.build();
}

@Test
public void test_StreamInAndOut() throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
modelTensor.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
ModelTensor parsedTensor = new ModelTensor(streamInput);
assertEquals(parsedTensor, modelTensor);
}

@Test
public void test_ModelTensorSuccess() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
modelTensor.toXContent(builder, EMPTY_PARAMS);
String modelTensorContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}", modelTensorContent);
}

@Test
public void toXContent_NullValue() throws IOException {
ModelTensor tensor = ModelTensor.builder().build();
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
tensor.toXContent(builder, EMPTY_PARAMS);
String modelTensorContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{}", modelTensorContent);
}

@Test
public void test_StreamInAndOut_NullValue() throws IOException {
ModelTensor tensor = ModelTensor.builder().build();
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
tensor.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
ModelTensor parsedTensor = new ModelTensor(streamInput);
assertEquals(parsedTensor, tensor);
}

@Test
public void test_UnknownDataType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("data type is null");
ModelTensor tensor = new ModelTensor("null_data", new Number[]{1, 2, 3}, null, MLResultDataType.UNKNOWN, ByteBuffer.wrap(new byte[]{0,1,0,1}));
}

@Test
public void test_NullDataType() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("data type is null");
ModelTensor tensor = new ModelTensor("null_data", new Number[]{1, 2, 3}, null, null, ByteBuffer.wrap(new byte[]{0,1,0,1}));
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.output.model;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.ml.common.TestHelper;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;

import static org.junit.Assert.assertEquals;
import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS;

public class ModelTensorsTest {

@Rule
public ExpectedException exceptionRule = ExpectedException.none();
private ModelTensors modelTensors;
private ModelResultFilter modelResultFilter;

@Before
public void setUp() {
String sentence = "test sentence";
String column = "model_tensor";
Integer position = 1;
modelResultFilter = ModelResultFilter.builder()
.targetResponse(Arrays.asList(column))
.targetResponsePositions(Arrays.asList(position))
.build();

ModelTensor modelTensor = ModelTensor.builder()
.name("model_tensor")
.data(new Number[]{1, 2, 3})
.shape(new long[]{1, 2, 3,})
.dataType(MLResultDataType.INT32)
.byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1}))
.build();

modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
}

@Test
public void test_ModelTensortoXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
modelTensors.toXContent(builder, EMPTY_PARAMS);
String modelTensorContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"output\":[{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}", modelTensorContent);
}

@Test
public void test_ModelTensortoXContent_NullValue() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
ModelTensors tensors = ModelTensors.builder().build();
tensors.toXContent(builder, EMPTY_PARAMS);
String modelTensorContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{}", modelTensorContent);
}

@Test
public void test_StreamInAndOut_NullValue() throws IOException {
ModelTensors tensors = ModelTensors.builder().build();
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
tensors.writeTo(bytesStreamOutput);

StreamInput streamInput = bytesStreamOutput.bytes().streamInput();
ModelTensors parsedTensors = new ModelTensors(streamInput);
assertEquals(parsedTensors.getMlModelTensors(), tensors.getMlModelTensors());
}

@Test
public void test_Filter() {
ModelTensor modelTensorFiltered = ModelTensor.builder()
.name("model_tensor")
.shape(new long[]{1, 2, 3,})
.dataType(MLResultDataType.INT32)
.build();
modelTensors.filter(modelResultFilter);
assertEquals(modelTensors.getMlModelTensors().size(), 1);
assertEquals(modelTensors.getMlModelTensors().get(0), modelTensorFiltered);
}

@Test
public void test_Filter_NullTargetResponse() {
ModelResultFilter resultFilter = ModelResultFilter.builder().build();
modelTensors.filter(resultFilter);
assertEquals(modelTensors.getMlModelTensors().size(), 1);
}

@Test
public void test_Filter_NullMLModelTensors() {
ModelTensors tensors = ModelTensors.builder().build();
tensors.filter(modelResultFilter);
assertEquals(modelTensors.getMlModelTensors().size(), 1);
}

@Test
public void test_ToAndFromBytes() throws IOException {
byte[] bytes = modelTensors.toBytes();
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
modelTensors.writeTo(bytesStreamOutput);
assertEquals(bytes.length, bytesStreamOutput.bytes().toBytesRef().bytes.length);

ModelTensors tensors = ModelTensors.fromBytes(bytes);
assertEquals(modelTensors.getMlModelTensors(), tensors.getMlModelTensors());
}
}