diff --git a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java index 992dad68cf..b29923df75 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java +++ b/common/src/main/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParams.java @@ -249,7 +249,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(BETA1_FIELD, beta1); } if (beta2 != null) { - builder.field(BETA1_FIELD, beta2); + builder.field(BETA2_FIELD, beta2); } if (decayRate != null) { builder.field(DECAY_RATE_FIELD, decayRate); diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java index a3a14441bf..927b7dd8a1 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensor.java @@ -27,7 +27,7 @@ public class ModelTensor implements Writeable, ToXContentObject { @Builder public ModelTensor(String name, Number[] data, long[] shape, MLResultDataType dataType, ByteBuffer byteBuffer) { - if (this.data != null && (dataType == null || dataType == MLResultDataType.UNKNOWN)) { + if (data != null && (dataType == null || dataType == MLResultDataType.UNKNOWN)) { throw new IllegalArgumentException("data type is null"); } this.name = name; diff --git a/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParamsTest.java b/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParamsTest.java new file mode 100644 index 0000000000..820ed5c93a --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/input/parameter/regression/LogisticRegressionParamsTest.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.input.parameter.regression; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.ml.common.TestHelper; + +import java.io.IOException; +import java.util.function.Function; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.TestHelper.contentObjectToString; +import static org.opensearch.ml.common.TestHelper.testParseFromString; +import static org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams.PARSE_FIELD_NAME; + +public class LogisticRegressionParamsTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private Function function = parser -> { + try { + return (LogisticRegressionParams) LogisticRegressionParams.parse(parser); + } catch (IOException e) { + throw new RuntimeException("failed to parse LogisticRegressionParams", e); + } + }; + + private LogisticRegressionParams logisticRegressionParams; + + @Before + public void setUp() { + logisticRegressionParams = LogisticRegressionParams + .builder() + .objectiveType(LogisticRegressionParams.ObjectiveType.LOGMULTICLASS) + .optimizerType(LogisticRegressionParams.OptimizerType.ADA_GRAD) + .learningRate(0.1) + .momentumType(LogisticRegressionParams.MomentumType.STANDARD) + .momentumFactor(0.2) + .epsilon(0.3) + .beta1(0.4) + .beta2(0.5) + .decayRate(0.6) + .epochs(1) + .batchSize(2) + .seed(3L) + .target("test_target") + .build(); + } + + @Test + public void readInputStream_Success() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + logisticRegressionParams.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + LogisticRegressionParams params = new LogisticRegressionParams(streamInput); + assertEquals(params, logisticRegressionParams); + } + + @Test + public void parse_PassIntValueToDoubleField() throws IOException { + String paramsStr = contentObjectToString(logisticRegressionParams); + testParseFromString(logisticRegressionParams, paramsStr, function); + } + + @Test + public void parse_InvalidParam_InvalidDoubleValue() throws IOException { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Double value passed as String"); + String paramsStr = contentObjectToString(logisticRegressionParams); + testParseFromString(logisticRegressionParams, paramsStr.replace("\"epsilon\":0.3,", "\"epsilon\":\"0.3\","), function); + } + + @Test + public void test_GetWriteableName() { + assertEquals(logisticRegressionParams.getWriteableName(), PARSE_FIELD_NAME); + } + + @Test + public void test_GetVersion() { + assertEquals(logisticRegressionParams.getVersion(), 1); + } + + @Test + public void readInputStream_Success_Empty() throws IOException { + LogisticRegressionParams params = LogisticRegressionParams.builder().build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + params.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + LogisticRegressionParams parsedParams = new LogisticRegressionParams(streamInput); + assertEquals(params, parsedParams); + } + + @Test + public void parse_LogisticRegressionParams() throws IOException { + TestHelper.testParse(logisticRegressionParams, function); + } + + @Test + public void parse_EmptyLogisticRegressionParams() throws IOException { + TestHelper.testParse(LogisticRegressionParams.builder().build(), function); + } + + @Test + public void parse_LogisticRegressionParams_WrongExtraField() throws IOException { + TestHelper.testParseFromString(logisticRegressionParams, "{\"objective\":\"LOGMULTICLASS\",\"learning_rate\":0.1,\"wrong_field\":1.0}", function); + } + +} + diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java new file mode 100644 index 0000000000..86ab1a1739 --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorTest.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.output.model; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.ml.common.TestHelper; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS; + +public class ModelTensorTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + private ModelTensor modelTensor; + + @Before + public void setUp() { + modelTensor = ModelTensor.builder() + .name("model_tensor") + .data(new Number[]{1, 2, 3}) + .shape(new long[]{1, 2, 3,}) + .dataType(MLResultDataType.INT32) + .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})) + .build(); + } + + @Test + public void test_StreamInAndOut() throws IOException { + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + modelTensor.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + ModelTensor parsedTensor = new ModelTensor(streamInput); + assertEquals(parsedTensor, modelTensor); + } + + @Test + public void test_ModelTensorSuccess() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + modelTensor.toXContent(builder, EMPTY_PARAMS); + String modelTensorContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}", modelTensorContent); + } + + @Test + public void toXContent_NullValue() throws IOException { + ModelTensor tensor = ModelTensor.builder().build(); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + tensor.toXContent(builder, EMPTY_PARAMS); + String modelTensorContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{}", modelTensorContent); + } + + @Test + public void test_StreamInAndOut_NullValue() throws IOException { + ModelTensor tensor = ModelTensor.builder().build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + tensor.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + ModelTensor parsedTensor = new ModelTensor(streamInput); + assertEquals(parsedTensor, tensor); + } + + @Test + public void test_UnknownDataType() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("data type is null"); + ModelTensor tensor = new ModelTensor("null_data", new Number[]{1, 2, 3}, null, MLResultDataType.UNKNOWN, ByteBuffer.wrap(new byte[]{0,1,0,1})); + } + + @Test + public void test_NullDataType() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("data type is null"); + ModelTensor tensor = new ModelTensor("null_data", new Number[]{1, 2, 3}, null, null, ByteBuffer.wrap(new byte[]{0,1,0,1})); + } +} + diff --git a/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java new file mode 100644 index 0000000000..8864689f7e --- /dev/null +++ b/common/src/test/java/org/opensearch/ml/common/output/model/ModelTensorsTest.java @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.output.model; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.ml.common.TestHelper; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; +import static org.opensearch.common.xcontent.ToXContent.EMPTY_PARAMS; + +public class ModelTensorsTest { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + private ModelTensors modelTensors; + private ModelResultFilter modelResultFilter; + + @Before + public void setUp() { + String sentence = "test sentence"; + String column = "model_tensor"; + Integer position = 1; + modelResultFilter = ModelResultFilter.builder() + .targetResponse(Arrays.asList(column)) + .targetResponsePositions(Arrays.asList(position)) + .build(); + + ModelTensor modelTensor = ModelTensor.builder() + .name("model_tensor") + .data(new Number[]{1, 2, 3}) + .shape(new long[]{1, 2, 3,}) + .dataType(MLResultDataType.INT32) + .byteBuffer(ByteBuffer.wrap(new byte[]{0,1,0,1})) + .build(); + + modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + } + + @Test + public void test_ModelTensortoXContent() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + modelTensors.toXContent(builder, EMPTY_PARAMS); + String modelTensorContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{\"output\":[{\"name\":\"model_tensor\",\"data_type\":\"INT32\",\"shape\":[1,2,3],\"data\":[1,2,3],\"byte_buffer\":{\"array\":\"AAEAAQ==\",\"order\":\"BIG_ENDIAN\"}}]}", modelTensorContent); + } + + @Test + public void test_ModelTensortoXContent_NullValue() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + ModelTensors tensors = ModelTensors.builder().build(); + tensors.toXContent(builder, EMPTY_PARAMS); + String modelTensorContent = TestHelper.xContentBuilderToString(builder); + assertEquals("{}", modelTensorContent); + } + + @Test + public void test_StreamInAndOut_NullValue() throws IOException { + ModelTensors tensors = ModelTensors.builder().build(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + tensors.writeTo(bytesStreamOutput); + + StreamInput streamInput = bytesStreamOutput.bytes().streamInput(); + ModelTensors parsedTensors = new ModelTensors(streamInput); + assertEquals(parsedTensors.getMlModelTensors(), tensors.getMlModelTensors()); + } + + @Test + public void test_Filter() { + ModelTensor modelTensorFiltered = ModelTensor.builder() + .name("model_tensor") + .shape(new long[]{1, 2, 3,}) + .dataType(MLResultDataType.INT32) + .build(); + modelTensors.filter(modelResultFilter); + assertEquals(modelTensors.getMlModelTensors().size(), 1); + assertEquals(modelTensors.getMlModelTensors().get(0), modelTensorFiltered); + } + + @Test + public void test_Filter_NullTargetResponse() { + ModelResultFilter resultFilter = ModelResultFilter.builder().build(); + modelTensors.filter(resultFilter); + assertEquals(modelTensors.getMlModelTensors().size(), 1); + } + + @Test + public void test_Filter_NullMLModelTensors() { + ModelTensors tensors = ModelTensors.builder().build(); + tensors.filter(modelResultFilter); + assertEquals(modelTensors.getMlModelTensors().size(), 1); + } + + @Test + public void test_ToAndFromBytes() throws IOException { + byte[] bytes = modelTensors.toBytes(); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + modelTensors.writeTo(bytesStreamOutput); + assertEquals(bytes.length, bytesStreamOutput.bytes().toBytesRef().bytes.length); + + ModelTensors tensors = ModelTensors.fromBytes(bytes); + assertEquals(modelTensors.getMlModelTensors(), tensors.getMlModelTensors()); + } +} +