diff --git a/CHANGELOG.md b/CHANGELOG.md index 57a2f7a4d..2b732b446 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.12...2.x) ### Features +- Implement document chunking processor with fixed token length and delimiter algorithm ([#607](https://github.com/opensearch-project/neural-search/pull/607/)) - Enabled support for applying default modelId in neural sparse query ([#614](https://github.com/opensearch-project/neural-search/pull/614) ### Enhancements - Adding aggregations in hybrid query ([#630](https://github.com/opensearch-project/neural-search/pull/630)) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 0182ff4d3..d54c644c4 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -31,9 +31,11 @@ import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.TextChunkingProcessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; @@ -114,14 +116,21 @@ public Map getProcessors(Processor.Parameters paramet SparseEncodingProcessor.TYPE, new SparseEncodingProcessorFactory(clientAccessor, parameters.env), TextImageEmbeddingProcessor.TYPE, - new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()) + new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()), + TextChunkingProcessor.TYPE, + new TextChunkingProcessorFactory( + parameters.env, + parameters.ingestService.getClusterService(), + parameters.indicesService, + parameters.analysisRegistry + ) ); } @Override public Optional getQueryPhaseSearcher() { // we're using "is_disabled" flag as there are no proper implementation of FeatureFlags.isDisabled(). Both - // cases when flag is not set or it is "false" are interpretted in the same way. In such case core is reading + // cases when flag is not set, or it is "false" are interpreted in the same way. In such case core is reading // the actual value from settings. 
if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED.getKey())) { log.info( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java new file mode 100644 index 000000000..50a9d4b7b --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextChunkingProcessor.java @@ -0,0 +1,310 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Locale; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.env.Environment; +import org.opensearch.index.IndexService; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.index.analysis.AnalysisRegistry; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.indices.IndicesService; +import org.opensearch.index.IndexSettings; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.processor.chunker.Chunker; +import org.opensearch.index.mapper.IndexFieldMapper; +import org.opensearch.neuralsearch.processor.chunker.ChunkerFactory; +import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker; + +import static org.opensearch.neuralsearch.processor.chunker.Chunker.MAX_CHUNK_LIMIT_FIELD; +import static org.opensearch.neuralsearch.processor.chunker.Chunker.DEFAULT_MAX_CHUNK_LIMIT; +import static org.opensearch.neuralsearch.processor.chunker.Chunker.DISABLED_MAX_CHUNK_LIMIT; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter; + +/** + * This processor is used for text chunking. + * The text chunking results could be fed to downstream embedding processor. + * The processor needs two fields: algorithm and field_map, + * where algorithm defines chunking algorithm and parameters, + * and field_map specifies which fields needs chunking and the corresponding keys for the chunking results. 
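As an illustration of the two fields described above, the processor is configured with maps of roughly the following shape (a sketch, not part of this changeset; the field names "body"/"body_chunk" and the token_limit value are hypothetical):

    import java.util.HashMap;
    import java.util.Map;

    // Sketch of the configuration shape the text_chunking processor consumes.
    public class TextChunkingConfigSketch {
        public static Map<String, Object> exampleConfig() {
            // field_map: chunk the "body" field and write the passages to "body_chunk"
            Map<String, Object> fieldMap = new HashMap<>();
            fieldMap.put("body", "body_chunk");

            // algorithm: a single algorithm name mapped to its parameter object
            Map<String, Object> fixedTokenLengthParameters = new HashMap<>();
            fixedTokenLengthParameters.put("token_limit", 128);
            fixedTokenLengthParameters.put("tokenizer", "standard");
            Map<String, Object> algorithm = new HashMap<>();
            algorithm.put("fixed_token_length", fixedTokenLengthParameters);

            Map<String, Object> config = new HashMap<>();
            config.put("field_map", fieldMap);
            config.put("algorithm", algorithm);
            return config;
        }
    }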
+ */ +public final class TextChunkingProcessor extends AbstractProcessor { + + public static final String TYPE = "text_chunking"; + public static final String FIELD_MAP_FIELD = "field_map"; + public static final String ALGORITHM_FIELD = "algorithm"; + private static final String DEFAULT_ALGORITHM = FixedTokenLengthChunker.ALGORITHM_NAME; + + private int maxChunkLimit; + private Chunker chunker; + private final Map fieldMap; + private final ClusterService clusterService; + private final IndicesService indicesService; + private final AnalysisRegistry analysisRegistry; + private final Environment environment; + + public TextChunkingProcessor( + final String tag, + final String description, + final Map fieldMap, + final Map algorithmMap, + final Environment environment, + final ClusterService clusterService, + final IndicesService indicesService, + final AnalysisRegistry analysisRegistry + ) { + super(tag, description); + this.fieldMap = fieldMap; + this.environment = environment; + this.clusterService = clusterService; + this.indicesService = indicesService; + this.analysisRegistry = analysisRegistry; + parseAlgorithmMap(algorithmMap); + } + + public String getType() { + return TYPE; + } + + @SuppressWarnings("unchecked") + private void parseAlgorithmMap(final Map algorithmMap) { + if (algorithmMap.size() > 1) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Unable to create %s processor as [%s] contains multiple algorithms", TYPE, ALGORITHM_FIELD) + ); + } + + String algorithmKey; + Object algorithmValue; + if (algorithmMap.isEmpty()) { + algorithmKey = DEFAULT_ALGORITHM; + algorithmValue = new HashMap<>(); + } else { + Entry algorithmEntry = algorithmMap.entrySet().iterator().next(); + algorithmKey = algorithmEntry.getKey(); + algorithmValue = algorithmEntry.getValue(); + if (!(algorithmValue instanceof Map)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Unable to create %s processor as parameters for [%s] algorithm must be an object", + TYPE, + algorithmKey + ) + ); + } + } + + if (!ChunkerFactory.CHUNKER_ALGORITHMS.contains(algorithmKey)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Chunking algorithm [%s] is not supported. 
Supported chunking algorithms are %s", + algorithmKey, + ChunkerFactory.CHUNKER_ALGORITHMS + ) + ); + } + Map chunkerParameters = (Map) algorithmValue; + // parse processor level max chunk limit + this.maxChunkLimit = parseIntegerParameter(chunkerParameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT); + if (maxChunkLimit < 0 && maxChunkLimit != DISABLED_MAX_CHUNK_LIMIT) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Parameter [%s] must be positive or %s to disable this parameter", + MAX_CHUNK_LIMIT_FIELD, + DISABLED_MAX_CHUNK_LIMIT + ) + ); + } + // fixed token length algorithm needs analysis registry for tokenization + chunkerParameters.put(FixedTokenLengthChunker.ANALYSIS_REGISTRY_FIELD, analysisRegistry); + this.chunker = ChunkerFactory.create(algorithmKey, chunkerParameters); + } + + @SuppressWarnings("unchecked") + private boolean isListOfString(final Object value) { + // an empty list is also List + if (!(value instanceof List)) { + return false; + } + for (Object element : (List) value) { + if (!(element instanceof String)) { + return false; + } + } + return true; + } + + private int getMaxTokenCount(final Map sourceAndMetadataMap) { + String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); + IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName); + if (Objects.isNull(indexMetadata)) { + return IndexSettings.MAX_TOKEN_COUNT_SETTING.get(environment.settings()); + } + // if the index is specified in the metadata, read maxTokenCount from the index setting + IndexService indexService = indicesService.indexServiceSafe(indexMetadata.getIndex()); + return indexService.getIndexSettings().getMaxTokenCount(); + } + + /** + * This method will be invoked by PipelineService to perform chunking and then write back chunking results to the document. + * @param ingestDocument {@link IngestDocument} which is the document passed to processor. 
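A minimal usage sketch for execute (not from this diff). It assumes a processor already created with field_map {"body": "body_chunk"}; the index name is hypothetical:

    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import org.opensearch.ingest.IngestDocument;

    // Build a document carrying a "body" field plus the _index metadata the processor reads,
    // run the chunking processor, and read the chunked passages back.
    static List<?> chunkOnce(TextChunkingProcessor processor) {
        Map<String, Object> sourceAndMetadata = new HashMap<>();
        sourceAndMetadata.put("body", "First passage.\n\nSecond passage.");
        sourceAndMetadata.put("_index", "my-index"); // hypothetical index name
        IngestDocument document = processor.execute(new IngestDocument(sourceAndMetadata, new HashMap<>()));
        return (List<?>) document.getSourceAndMetadata().get("body_chunk");
    }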
+ */ + @Override + public IngestDocument execute(final IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + validateFieldsValue(sourceAndMetadataMap); + // fixed token length algorithm needs runtime parameter max_token_count for tokenization + Map runtimeParameters = new HashMap<>(); + int maxTokenCount = getMaxTokenCount(sourceAndMetadataMap); + runtimeParameters.put(FixedTokenLengthChunker.MAX_TOKEN_COUNT_FIELD, maxTokenCount); + runtimeParameters.put(MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); + chunkMapType(sourceAndMetadataMap, fieldMap, runtimeParameters, 0); + return ingestDocument; + } + + private void validateFieldsValue(final Map sourceAndMetadataMap) { + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + if (Objects.nonNull(sourceValue)) { + String sourceKey = embeddingFieldsEntry.getKey(); + if (sourceValue instanceof List || sourceValue instanceof Map) { + validateNestedTypeValue(sourceKey, sourceValue, 1); + } else if (!(sourceValue instanceof String)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "field [%s] is neither string nor nested type, cannot process it", sourceKey) + ); + } + } + } + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue(final String sourceKey, final Object sourceValue, final int maxDepth) { + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "map type field [%s] reached max depth limit, cannot process it", sourceKey) + ); + } else if (sourceValue instanceof List) { + validateListTypeValue(sourceKey, sourceValue, maxDepth); + } else if (sourceValue instanceof Map) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, maxDepth + 1)); + } else if (!(sourceValue instanceof String)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "map type field [%s] has non-string type, cannot process it", sourceKey) + ); + } + } + + @SuppressWarnings({ "rawtypes" }) + private void validateListTypeValue(final String sourceKey, final Object sourceValue, final int maxDepth) { + for (Object value : (List) sourceValue) { + if (value instanceof Map) { + validateNestedTypeValue(sourceKey, value, maxDepth + 1); + } else if (value == null) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "list type field [%s] has null, cannot process it", sourceKey) + ); + } else if (!(value instanceof String)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "list type field [%s] has non-string value, cannot process it", sourceKey) + ); + } + } + } + + @SuppressWarnings("unchecked") + private int chunkMapType( + Map sourceAndMetadataMap, + final Map fieldMap, + final Map runtimeParameters, + final int chunkCount + ) { + int updatedChunkCount = chunkCount; + for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { + String originalKey = fieldMapEntry.getKey(); + Object targetKey = fieldMapEntry.getValue(); + if (targetKey instanceof Map) { + // call this method recursively when target key is a map + Object sourceObject = sourceAndMetadataMap.get(originalKey); + if (sourceObject instanceof List) { + List sourceObjectList = (List) sourceObject; + for (Object source : sourceObjectList) { + if (source instanceof Map) { + updatedChunkCount = chunkMapType( + (Map) 
source, + (Map) targetKey, + runtimeParameters, + updatedChunkCount + ); + } + } + } else if (sourceObject instanceof Map) { + updatedChunkCount = chunkMapType( + (Map) sourceObject, + (Map) targetKey, + runtimeParameters, + updatedChunkCount + ); + } + } else { + // chunk the object when target key is of leaf type (null, string and list of string) + Object chunkObject = sourceAndMetadataMap.get(originalKey); + List chunkedResult = chunkLeafType(chunkObject, runtimeParameters); + sourceAndMetadataMap.put(String.valueOf(targetKey), chunkedResult); + } + } + return updatedChunkCount; + } + + /** + * Chunk the content, update the runtime max_chunk_limit and return the result + */ + private List chunkString(final String content, final Map runTimeParameters) { + // update runtime max_chunk_limit if not disabled + List contentResult = chunker.chunk(content, runTimeParameters); + int runtimeMaxChunkLimit = parseIntegerParameter(runTimeParameters, MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); + if (runtimeMaxChunkLimit != DISABLED_MAX_CHUNK_LIMIT) { + runTimeParameters.put(MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit - contentResult.size()); + } + return contentResult; + } + + private List chunkList(final List contentList, final Map runTimeParameters) { + // flatten original output format from List> to List + List result = new ArrayList<>(); + for (String content : contentList) { + result.addAll(chunkString(content, runTimeParameters)); + } + return result; + } + + @SuppressWarnings("unchecked") + private List chunkLeafType(final Object value, final Map runTimeParameters) { + // leaf type means null, String or List + // the result should be an empty list when the input is null + List result = new ArrayList<>(); + if (value instanceof String) { + result = chunkString(value.toString(), runTimeParameters); + } else if (isListOfString(value)) { + result = chunkList((List) value, runTimeParameters); + } + return result; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/Chunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/Chunker.java new file mode 100644 index 000000000..fb6712c76 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/Chunker.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.chunker; + +import java.util.Map; +import java.util.List; + +/** + * The interface for all chunking algorithms. + * All algorithms need to parse parameters and chunk the content. + */ +public interface Chunker { + + String MAX_CHUNK_LIMIT_FIELD = "max_chunk_limit"; + int DEFAULT_MAX_CHUNK_LIMIT = 100; + int DISABLED_MAX_CHUNK_LIMIT = -1; + + /** + * Parse the parameters for chunking algorithm. + * Throw IllegalArgumentException when parameters are invalid. 
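To make the contract concrete, a toy implementation (not in this changeset) could look like the following; it takes no parameters and splits on single newlines, ignoring runtime parameters:

    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;

    // Toy Chunker: nothing to validate, chunks on single newlines.
    public class NewlineChunker implements Chunker {
        @Override
        public void parseParameters(Map<String, Object> parameters) {
            // no parameters for this toy algorithm; real implementations validate and store them here
        }

        @Override
        public List<String> chunk(String content, Map<String, Object> runtimeParameters) {
            // runtime parameters such as max_chunk_limit are ignored in this sketch
            return Arrays.asList(content.split("\n"));
        }
    }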
+ * + * @param parameters a map containing non-runtime parameters for chunking algorithms + */ + void parseParameters(Map parameters); + + /** + * Chunk the input string according to parameters and return chunked passages + * + * @param content input string + * @param runtimeParameters a map containing runtime parameters for chunking algorithms + * @return chunked passages + */ + List chunk(String content, Map runtimeParameters); +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerFactory.java new file mode 100644 index 000000000..aab9eaa3e --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerFactory.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.chunker; + +import com.google.common.collect.ImmutableMap; + +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; + +/** + * A factory to create different chunking algorithm objects. + */ +public final class ChunkerFactory { + + private ChunkerFactory() {} // no instance of this factory class + + private static final Map, Chunker>> CHUNKERS_CONSTRUCTORS = ImmutableMap.of( + FixedTokenLengthChunker.ALGORITHM_NAME, + FixedTokenLengthChunker::new, + DelimiterChunker.ALGORITHM_NAME, + DelimiterChunker::new + ); + + public static Set CHUNKER_ALGORITHMS = CHUNKERS_CONSTRUCTORS.keySet(); + + public static Chunker create(final String type, final Map parameters) { + Function, Chunker> chunkerConstructionFunction = CHUNKERS_CONSTRUCTORS.get(type); + // chunkerConstructionFunction is not null because we have validated the type in text chunking processor + Objects.requireNonNull(chunkerConstructionFunction); + return chunkerConstructionFunction.apply(parameters); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterParser.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterParser.java new file mode 100644 index 000000000..56a61a26f --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerParameterParser.java @@ -0,0 +1,91 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.chunker; + +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.math.NumberUtils; + +import java.util.Locale; +import java.util.Map; + +/** + * Parse the parameter for text chunking processor and algorithms. + * Throw IllegalArgumentException when parameters are invalid. + */ +public final class ChunkerParameterParser { + + private ChunkerParameterParser() {} // no instance of this util class + + /** + * Parse String type parameter. + * Throw IllegalArgumentException if parameter is not a string or an empty string. 
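A small usage sketch for the factory above (the delimiter value is illustrative). create() is only reached after the processor has validated the algorithm name against CHUNKER_ALGORITHMS, hence the non-null assertion:

    import java.util.HashMap;
    import java.util.Map;

    // Look up a chunking algorithm by name and instantiate it with its parameters.
    Map<String, Object> parameters = new HashMap<>();
    parameters.put("delimiter", "\n\n");
    Chunker chunker = ChunkerFactory.create("delimiter", parameters); // returns a DelimiterChunker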
+ */ + public static String parseStringParameter(final Map parameters, final String fieldName, final String defaultValue) { + if (!parameters.containsKey(fieldName)) { + // all string parameters are optional + return defaultValue; + } + Object fieldValue = parameters.get(fieldName); + if (!(fieldValue instanceof String)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, String.class.getName()) + ); + } + if (StringUtils.isEmpty(fieldValue.toString())) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Parameter [%s] should not be empty.", fieldName)); + } + return fieldValue.toString(); + } + + /** + * Parse integer type parameter. + * Throw IllegalArgumentException if parameter is not an integer. + */ + public static int parseIntegerParameter(final Map parameters, final String fieldName, final int defaultValue) { + if (!parameters.containsKey(fieldName)) { + // all integer parameters are optional + return defaultValue; + } + String fieldValueString = parameters.get(fieldName).toString(); + try { + return NumberUtils.createInteger(fieldValueString); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Integer.class.getName()) + ); + } + } + + /** + * Parse integer type parameter with positive value. + * Throw IllegalArgumentException if parameter is not a positive integer. + */ + public static int parsePositiveIntegerParameter(final Map parameters, final String fieldName, final int defaultValue) { + int fieldValueInt = parseIntegerParameter(parameters, fieldName, defaultValue); + if (fieldValueInt <= 0) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Parameter [%s] must be positive.", fieldName)); + } + return fieldValueInt; + } + + /** + * Parse double type parameter. + * Throw IllegalArgumentException if parameter is not a double. + */ + public static double parseDoubleParameter(final Map parameters, final String fieldName, final double defaultValue) { + if (!parameters.containsKey(fieldName)) { + // all double parameters are optional + return defaultValue; + } + String fieldValueString = parameters.get(fieldName).toString(); + try { + return NumberUtils.createDouble(fieldValueString); + } catch (NumberFormatException e) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", fieldName, Double.class.getName()) + ); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerUtil.java new file mode 100644 index 000000000..d4406f33e --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerUtil.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.chunker; + +import java.util.Locale; + +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.TYPE; +import static org.opensearch.neuralsearch.processor.chunker.Chunker.DISABLED_MAX_CHUNK_LIMIT; +import static org.opensearch.neuralsearch.processor.chunker.Chunker.MAX_CHUNK_LIMIT_FIELD; + +/** + * A util class used by chunking algorithms. + */ +public class ChunkerUtil { + + private ChunkerUtil() {} // no instance of this util class + + /** + * Checks whether the chunking results would exceed the max chunk limit. 
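Behavior sketch for the parser above (field names and values are illustrative): absent keys fall back to the supplied defaults, while malformed values raise IllegalArgumentException.

    import java.util.Map;

    Map<String, Object> params = Map.of("token_limit", "128", "overlap_rate", 0.2);
    int tokenLimit = ChunkerParameterParser.parsePositiveIntegerParameter(params, "token_limit", 384); // 128
    double overlapRate = ChunkerParameterParser.parseDoubleParameter(params, "overlap_rate", 0.0);     // 0.2
    String tokenizer = ChunkerParameterParser.parseStringParameter(params, "tokenizer", "standard");   // absent, so "standard"
    // A non-numeric value such as "abc" for token_limit would throw IllegalArgumentException.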
+ * If exceeds, then Throw IllegalStateException + * + * @param chunkResultSize the size of chunking result + * @param runtimeMaxChunkLimit runtime max_chunk_limit, used to check with chunkResultSize + * @param nonRuntimeMaxChunkLimit non-runtime max_chunk_limit, used to keep exception message consistent + */ + public static void checkRunTimeMaxChunkLimit(int chunkResultSize, int runtimeMaxChunkLimit, int nonRuntimeMaxChunkLimit) { + if (runtimeMaxChunkLimit != DISABLED_MAX_CHUNK_LIMIT && chunkResultSize >= runtimeMaxChunkLimit) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "The number of chunks produced by %s processor has exceeded the allowed maximum of [%s]. This limit can be set by changing the [%s] parameter.", + TYPE, + nonRuntimeMaxChunkLimit, + MAX_CHUNK_LIMIT_FIELD + ) + ); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java new file mode 100644 index 000000000..c688af436 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java @@ -0,0 +1,76 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.chunker; + +import java.util.Map; +import java.util.List; +import java.util.ArrayList; + +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringParameter; + +/** + * The implementation {@link Chunker} for delimiter algorithm + */ +public final class DelimiterChunker implements Chunker { + + public static final String ALGORITHM_NAME = "delimiter"; + + public static final String DELIMITER_FIELD = "delimiter"; + + public static final String DEFAULT_DELIMITER = "\n\n"; + + private String delimiter; + private int maxChunkLimit; + + public DelimiterChunker(final Map parameters) { + parseParameters(parameters); + } + + /** + * Parse the parameters for delimiter algorithm. + * Throw IllegalArgumentException if delimiter is not a string or an empty string. + * + * @param parameters a map with non-runtime parameters as the following: + * 1. delimiter A string as the paragraph split indicator + * 2. max_chunk_limit processor level max chunk limit + */ + @Override + public void parseParameters(Map parameters) { + this.delimiter = parseStringParameter(parameters, DELIMITER_FIELD, DEFAULT_DELIMITER); + this.maxChunkLimit = parseIntegerParameter(parameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT); + } + + /** + * Return the chunked passages for delimiter algorithm + * + * @param content input string + * @param runtimeParameters a map for runtime parameters, containing the following runtime parameters: + * 1. 
max_chunk_limit field level max chunk limit + */ + @Override + public List chunk(final String content, final Map runtimeParameters) { + int runtimeMaxChunkLimit = parseIntegerParameter(runtimeParameters, MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); + + List chunkResult = new ArrayList<>(); + int start = 0, end; + int nextDelimiterPosition = content.indexOf(delimiter); + + while (nextDelimiterPosition != -1) { + ChunkerUtil.checkRunTimeMaxChunkLimit(chunkResult.size(), runtimeMaxChunkLimit, maxChunkLimit); + end = nextDelimiterPosition + delimiter.length(); + chunkResult.add(content.substring(start, end)); + start = end; + nextDelimiterPosition = content.indexOf(delimiter, start); + } + + if (start < content.length()) { + ChunkerUtil.checkRunTimeMaxChunkLimit(chunkResult.size(), runtimeMaxChunkLimit, maxChunkLimit); + chunkResult.add(content.substring(start)); + } + + return chunkResult; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java new file mode 100644 index 000000000..cd630adf1 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.chunker; + +import java.util.Locale; +import java.util.Map; +import java.util.List; +import java.util.Set; +import java.util.ArrayList; + +import org.opensearch.index.analysis.AnalysisRegistry; +import org.opensearch.action.admin.indices.analyze.AnalyzeAction; +import org.opensearch.action.admin.indices.analyze.AnalyzeAction.AnalyzeToken; +import static org.opensearch.action.admin.indices.analyze.TransportAnalyzeAction.analyze; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringParameter; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseDoubleParameter; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parsePositiveIntegerParameter; + +/** + * The implementation {@link Chunker} for fixed token length algorithm. 
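A worked example for the delimiter algorithm above (the "." delimiter mirrors the integration-test pipeline; the input text is illustrative). Each chunk keeps its trailing delimiter, and whatever follows the last delimiter becomes the final chunk:

    import java.util.List;
    import java.util.Map;

    DelimiterChunker chunker = new DelimiterChunker(Map.<String, Object>of("delimiter", "."));
    List<String> chunks = chunker.chunk("One. Two. Three", Map.of());
    // chunks -> ["One.", " Two.", " Three"]

For the fixed token length algorithm that follows, the window advances by token_limit - floor(token_limit * overlap_rate) tokens per chunk; for example, with token_limit 10 and overlap_rate 0.2, each new chunk repeats the last 2 tokens of the previous one.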
+ */ +public final class FixedTokenLengthChunker implements Chunker { + + public static final String ALGORITHM_NAME = "fixed_token_length"; + + // field name for each parameter + public static final String ANALYSIS_REGISTRY_FIELD = "analysis_registry"; + public static final String TOKEN_LIMIT_FIELD = "token_limit"; + public static final String OVERLAP_RATE_FIELD = "overlap_rate"; + public static final String MAX_TOKEN_COUNT_FIELD = "max_token_count"; + public static final String TOKENIZER_FIELD = "tokenizer"; + + // default values for each parameter + private static final int DEFAULT_TOKEN_LIMIT = 384; + private static final double DEFAULT_OVERLAP_RATE = 0.0; + private static final int DEFAULT_MAX_TOKEN_COUNT = 10000; + private static final String DEFAULT_TOKENIZER = "standard"; + + // parameter restrictions + private static final double OVERLAP_RATE_LOWER_BOUND = 0.0; + private static final double OVERLAP_RATE_UPPER_BOUND = 0.5; + private static final Set WORD_TOKENIZERS = Set.of( + "standard", + "letter", + "lowercase", + "whitespace", + "uax_url_email", + "classic", + "thai" + ); + + // parameter value + private int tokenLimit; + private int maxChunkLimit; + private String tokenizer; + private double overlapRate; + private final AnalysisRegistry analysisRegistry; + + public FixedTokenLengthChunker(final Map parameters) { + parseParameters(parameters); + this.analysisRegistry = (AnalysisRegistry) parameters.get(ANALYSIS_REGISTRY_FIELD); + } + + /** + * Parse the parameters for fixed token length algorithm. + * Throw IllegalArgumentException when parameters are invalid. + * + * @param parameters a map with non-runtime parameters as the following: + * 1. tokenizer: the word tokenizer in opensearch + * 2. token_limit: the token limit for each chunked passage + * 3. overlap_rate: the overlapping degree for each chunked passage, indicating how many token comes from the previous passage + * 4. max_chunk_limit processor level max chunk level + * Here are requirements for non-runtime parameters: + * 1. token_limit must be a positive integer + * 2. overlap_rate must be within range [0, 0.5] + * 3. tokenizer must be a word tokenizer + * + */ + @Override + public void parseParameters(Map parameters) { + this.tokenLimit = parsePositiveIntegerParameter(parameters, TOKEN_LIMIT_FIELD, DEFAULT_TOKEN_LIMIT); + this.overlapRate = parseDoubleParameter(parameters, OVERLAP_RATE_FIELD, DEFAULT_OVERLAP_RATE); + this.tokenizer = parseStringParameter(parameters, TOKENIZER_FIELD, DEFAULT_TOKENIZER); + this.maxChunkLimit = parseIntegerParameter(parameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT); + if (overlapRate < OVERLAP_RATE_LOWER_BOUND || overlapRate > OVERLAP_RATE_UPPER_BOUND) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Parameter [%s] must be between %s and %s", + OVERLAP_RATE_FIELD, + OVERLAP_RATE_LOWER_BOUND, + OVERLAP_RATE_UPPER_BOUND + ) + ); + } + if (!WORD_TOKENIZERS.contains(tokenizer)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Tokenizer [%s] is not supported for [%s] algorithm. Supported tokenizers are %s", + tokenizer, + ALGORITHM_NAME, + WORD_TOKENIZERS + ) + ); + } + } + + /** + * Return the chunked passages for fixed token length algorithm. + * Throw IllegalArgumentException when runtime parameters are invalid. + * + * @param content input string + * @param runtimeParameters a map for runtime parameters, containing the following runtime parameters: + * 1. max_token_count the max token limit for the tokenizer + * 2. 
max_chunk_limit field level max chunk limit + */ + @Override + public List chunk(final String content, final Map runtimeParameters) { + int maxTokenCount = parsePositiveIntegerParameter(runtimeParameters, MAX_TOKEN_COUNT_FIELD, DEFAULT_MAX_TOKEN_COUNT); + int runtimeMaxChunkLimit = parseIntegerParameter(runtimeParameters, MAX_CHUNK_LIMIT_FIELD, this.maxChunkLimit); + + List tokens = tokenize(content, tokenizer, maxTokenCount); + List chunkResult = new ArrayList<>(); + + int startTokenIndex = 0; + int startContentPosition, endContentPosition; + int overlapTokenNumber = (int) Math.floor(tokenLimit * overlapRate); + + while (startTokenIndex < tokens.size()) { + ChunkerUtil.checkRunTimeMaxChunkLimit(chunkResult.size(), runtimeMaxChunkLimit, maxChunkLimit); + if (startTokenIndex == 0) { + // include all characters till the start if no previous passage + startContentPosition = 0; + } else { + startContentPosition = tokens.get(startTokenIndex).getStartOffset(); + } + if (startTokenIndex + tokenLimit >= tokens.size()) { + // include all characters till the end if no next passage + endContentPosition = content.length(); + chunkResult.add(content.substring(startContentPosition, endContentPosition)); + break; + } else { + // include gap characters between two passages + endContentPosition = tokens.get(startTokenIndex + tokenLimit).getStartOffset(); + chunkResult.add(content.substring(startContentPosition, endContentPosition)); + } + startTokenIndex += tokenLimit - overlapTokenNumber; + } + return chunkResult; + } + + private List tokenize(final String content, final String tokenizer, final int maxTokenCount) { + AnalyzeAction.Request analyzeRequest = new AnalyzeAction.Request(); + analyzeRequest.text(content); + analyzeRequest.tokenizer(tokenizer); + try { + AnalyzeAction.Response analyzeResponse = analyze(analyzeRequest, analysisRegistry, null, maxTokenCount); + return analyzeResponse.getTokens(); + } catch (Exception e) { + throw new IllegalStateException(String.format(Locale.ROOT, "analyzer %s throws exception: %s", tokenizer, e.getMessage()), e); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextChunkingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextChunkingProcessorFactory.java new file mode 100644 index 000000000..efffcc908 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextChunkingProcessorFactory.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import java.util.Map; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.env.Environment; +import org.opensearch.index.analysis.AnalysisRegistry; +import org.opensearch.indices.IndicesService; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.processor.TextChunkingProcessor; +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.TYPE; +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.ALGORITHM_FIELD; +import static org.opensearch.ingest.ConfigurationUtils.readMap; + +/** + * Factory for chunking ingest processor for ingestion pipeline. + * Instantiates processor based on user provided input, which includes: + * 1. field_map: the input and output fields specified by the user + * 2. 
algorithm: chunking algorithm and its parameters + */ +public class TextChunkingProcessorFactory implements Processor.Factory { + + private final Environment environment; + + private final ClusterService clusterService; + + private final IndicesService indicesService; + + private final AnalysisRegistry analysisRegistry; + + public TextChunkingProcessorFactory( + Environment environment, + ClusterService clusterService, + IndicesService indicesService, + AnalysisRegistry analysisRegistry + ) { + this.environment = environment; + this.clusterService = clusterService; + this.indicesService = indicesService; + this.analysisRegistry = analysisRegistry; + } + + @Override + public TextChunkingProcessor create( + Map registry, + String processorTag, + String description, + Map config + ) throws Exception { + Map fieldMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); + Map algorithmMap = readMap(TYPE, processorTag, config, ALGORITHM_FIELD); + return new TextChunkingProcessor( + processorTag, + description, + fieldMap, + algorithmMap, + environment, + clusterService, + indicesService, + analysisRegistry + ); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index 2a66f6992..cb3869868 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -5,11 +5,14 @@ package org.opensearch.neuralsearch.plugin; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.util.List; import java.util.Map; import java.util.Optional; +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; import org.opensearch.indices.IndicesService; import org.opensearch.ingest.IngestService; import org.opensearch.ingest.Processor; @@ -57,8 +60,11 @@ public void testQueryPhaseSearcher() { public void testProcessors() { NeuralSearch plugin = new NeuralSearch(); + Settings settings = Settings.builder().build(); + Environment environment = mock(Environment.class); + when(environment.settings()).thenReturn(settings); Processor.Parameters processorParams = new Processor.Parameters( - null, + environment, null, null, null, diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorIT.java new file mode 100644 index 000000000..dd517aa17 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorIT.java @@ -0,0 +1,241 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import com.google.common.collect.ImmutableList; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; + +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import static 
org.opensearch.neuralsearch.TestUtils.DEFAULT_USER_AGENT; + +public class TextChunkingProcessorIT extends BaseNeuralSearchIT { + private static final String INDEX_NAME = "text_chunking_test_index"; + + private static final String OUTPUT_FIELD = "body_chunk"; + + private static final String INTERMEDIATE_FIELD = "body_chunk_intermediate"; + + private static final String FIXED_TOKEN_LENGTH_PIPELINE_WITH_STANDARD_TOKENIZER_NAME = + "pipeline-text-chunking-fixed-token-length-standard-tokenizer"; + + private static final String FIXED_TOKEN_LENGTH_PIPELINE_WITH_LETTER_TOKENIZER_NAME = + "pipeline-text-chunking-fixed-token-length-letter-tokenizer"; + + private static final String FIXED_TOKEN_LENGTH_PIPELINE_WITH_LOWERCASE_TOKENIZER_NAME = + "pipeline-text-chunking-fixed-token-length-lowercase-tokenizer"; + + private static final String DELIMITER_PIPELINE_NAME = "pipeline-text-chunking-delimiter"; + + private static final String CASCADE_PIPELINE_NAME = "pipeline-text-chunking-cascade"; + + private static final String TEST_DOCUMENT = "processor/chunker/TextChunkingTestDocument.json"; + + private static final String TEST_LONG_DOCUMENT = "processor/chunker/TextChunkingTestLongDocument.json"; + + private static final Map PIPELINE_CONFIGS_BY_NAME = Map.of( + FIXED_TOKEN_LENGTH_PIPELINE_WITH_STANDARD_TOKENIZER_NAME, + "processor/chunker/PipelineForFixedTokenLengthChunkerWithStandardTokenizer.json", + FIXED_TOKEN_LENGTH_PIPELINE_WITH_LETTER_TOKENIZER_NAME, + "processor/chunker/PipelineForFixedTokenLengthChunkerWithLetterTokenizer.json", + FIXED_TOKEN_LENGTH_PIPELINE_WITH_LOWERCASE_TOKENIZER_NAME, + "processor/chunker/PipelineForFixedTokenLengthChunkerWithLowercaseTokenizer.json", + DELIMITER_PIPELINE_NAME, + "processor/chunker/PipelineForDelimiterChunker.json", + CASCADE_PIPELINE_NAME, + "processor/chunker/PipelineForCascadedChunker.json" + ); + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + } + + @SneakyThrows + public void testTextChunkingProcessor_withFixedTokenLengthAlgorithmStandardTokenizer_thenSucceed() { + try { + createPipelineProcessor(FIXED_TOKEN_LENGTH_PIPELINE_WITH_STANDARD_TOKENIZER_NAME); + createTextChunkingIndex(INDEX_NAME, FIXED_TOKEN_LENGTH_PIPELINE_WITH_STANDARD_TOKENIZER_NAME); + ingestDocument(TEST_DOCUMENT); + + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("standard tokenizer in OpenSearch."); + validateIndexIngestResults(INDEX_NAME, OUTPUT_FIELD, expectedPassages); + } finally { + wipeOfTestResources(INDEX_NAME, FIXED_TOKEN_LENGTH_PIPELINE_WITH_STANDARD_TOKENIZER_NAME, null, null); + } + } + + @SneakyThrows + public void testTextChunkingProcessor_withFixedTokenLengthAlgorithmLetterTokenizer_thenSucceed() { + try { + createPipelineProcessor(FIXED_TOKEN_LENGTH_PIPELINE_WITH_LETTER_TOKENIZER_NAME); + createTextChunkingIndex(INDEX_NAME, FIXED_TOKEN_LENGTH_PIPELINE_WITH_LETTER_TOKENIZER_NAME); + ingestDocument(TEST_DOCUMENT); + + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. 
The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by standard "); + expectedPassages.add("tokenizer in OpenSearch."); + validateIndexIngestResults(INDEX_NAME, OUTPUT_FIELD, expectedPassages); + } finally { + wipeOfTestResources(INDEX_NAME, FIXED_TOKEN_LENGTH_PIPELINE_WITH_LETTER_TOKENIZER_NAME, null, null); + } + } + + @SneakyThrows + public void testTextChunkingProcessor_withFixedTokenLengthAlgorithmLowercaseTokenizer_thenSucceed() { + try { + createPipelineProcessor(FIXED_TOKEN_LENGTH_PIPELINE_WITH_LOWERCASE_TOKENIZER_NAME); + createTextChunkingIndex(INDEX_NAME, FIXED_TOKEN_LENGTH_PIPELINE_WITH_LOWERCASE_TOKENIZER_NAME); + ingestDocument(TEST_DOCUMENT); + + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by standard "); + expectedPassages.add("tokenizer in OpenSearch."); + validateIndexIngestResults(INDEX_NAME, OUTPUT_FIELD, expectedPassages); + } finally { + wipeOfTestResources(INDEX_NAME, FIXED_TOKEN_LENGTH_PIPELINE_WITH_LOWERCASE_TOKENIZER_NAME, null, null); + } + } + + @SneakyThrows + public void testTextChunkingProcessor_withFixedTokenLengthAlgorithmStandardTokenizer_whenExceedMaxTokenCount_thenFail() { + try { + createPipelineProcessor(FIXED_TOKEN_LENGTH_PIPELINE_WITH_STANDARD_TOKENIZER_NAME); + createTextChunkingIndex(INDEX_NAME, FIXED_TOKEN_LENGTH_PIPELINE_WITH_STANDARD_TOKENIZER_NAME); + Exception exception = assertThrows(Exception.class, () -> ingestDocument(TEST_LONG_DOCUMENT)); + // max_token_count is 100 by index settings + assert (exception.getMessage() + .contains("The number of tokens produced by calling _analyze has exceeded the allowed maximum of [100].")); + assertEquals(0, getDocCount(INDEX_NAME)); + } finally { + wipeOfTestResources(INDEX_NAME, FIXED_TOKEN_LENGTH_PIPELINE_WITH_STANDARD_TOKENIZER_NAME, null, null); + } + } + + @SneakyThrows + public void testTextChunkingProcessor_withDelimiterAlgorithm_successful() { + try { + createPipelineProcessor(DELIMITER_PIPELINE_NAME); + createTextChunkingIndex(INDEX_NAME, DELIMITER_PIPELINE_NAME); + ingestDocument(TEST_DOCUMENT); + + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked."); + expectedPassages.add( + " The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch." + ); + validateIndexIngestResults(INDEX_NAME, OUTPUT_FIELD, expectedPassages); + } finally { + wipeOfTestResources(INDEX_NAME, DELIMITER_PIPELINE_NAME, null, null); + } + } + + @SneakyThrows + public void testTextChunkingProcessor_withCascadePipeline_successful() { + try { + createPipelineProcessor(CASCADE_PIPELINE_NAME); + createTextChunkingIndex(INDEX_NAME, CASCADE_PIPELINE_NAME); + ingestDocument(TEST_DOCUMENT); + + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked."); + expectedPassages.add(" The document contains a single paragraph, two sentences and 24 "); + expectedPassages.add("tokens by standard tokenizer in OpenSearch."); + validateIndexIngestResults(INDEX_NAME, OUTPUT_FIELD, expectedPassages); + + expectedPassages.clear(); + expectedPassages.add("This is an example document to be chunked."); + expectedPassages.add( + " The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch." 
+ ); + validateIndexIngestResults(INDEX_NAME, INTERMEDIATE_FIELD, expectedPassages); + } finally { + wipeOfTestResources(INDEX_NAME, CASCADE_PIPELINE_NAME, null, null); + } + } + + private void validateIndexIngestResults(String indexName, String fieldName, Object expected) { + assertEquals(1, getDocCount(indexName)); + MatchAllQueryBuilder query = new MatchAllQueryBuilder(); + Map searchResults = search(indexName, query, 10); + assertNotNull(searchResults); + Map document = getFirstInnerHit(searchResults); + assertNotNull(document); + Object documentSource = document.get("_source"); + assert (documentSource instanceof Map); + @SuppressWarnings("unchecked") + Map documentSourceMap = (Map) documentSource; + assert (documentSourceMap).containsKey(fieldName); + Object ingestOutputs = documentSourceMap.get(fieldName); + assertEquals(expected, ingestOutputs); + } + + private void createPipelineProcessor(String pipelineName) throws Exception { + URL pipelineURLPath = classLoader.getResource(PIPELINE_CONFIGS_BY_NAME.get(pipelineName)); + Objects.requireNonNull(pipelineURLPath); + String requestBody = Files.readString(Path.of(pipelineURLPath.toURI())); + Response pipelineCreateResponse = makeRequest( + client(), + "PUT", + "/_ingest/pipeline/" + pipelineName, + null, + toHttpEntity(String.format(LOCALE, requestBody)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map node = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(pipelineCreateResponse.getEntity()), + false + ); + assertEquals("true", node.get("acknowledged").toString()); + } + + private void createTextChunkingIndex(String indexName, String pipelineName) throws Exception { + URL indexSettingsURLPath = classLoader.getResource("processor/chunker/TextChunkingIndexSettings.json"); + Objects.requireNonNull(indexSettingsURLPath); + createIndexWithConfiguration(indexName, Files.readString(Path.of(indexSettingsURLPath.toURI())), pipelineName); + } + + private void ingestDocument(String documentPath) throws Exception { + URL documentURLPath = classLoader.getResource(documentPath); + Objects.requireNonNull(documentURLPath); + String ingestDocument = Files.readString(Path.of(documentURLPath.toURI())); + Response response = makeRequest( + client(), + "POST", + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(ingestDocument), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Map map = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(response.getEntity()), + false + ); + assertEquals("created", map.get("result")); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorTests.java new file mode 100644 index 000000000..934918e18 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorTests.java @@ -0,0 +1,668 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import org.apache.lucene.tests.analysis.MockTokenizer; +import org.junit.Before; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; 
+import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.mock; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; +import org.opensearch.env.TestEnvironment; +import org.opensearch.index.analysis.AnalysisRegistry; +import org.opensearch.index.analysis.TokenizerFactory; +import org.opensearch.index.mapper.IndexFieldMapper; +import org.opensearch.indices.IndicesService; +import org.opensearch.indices.analysis.AnalysisModule; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.processor.chunker.DelimiterChunker; +import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker; +import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory; +import org.opensearch.plugins.AnalysisPlugin; +import org.opensearch.test.OpenSearchTestCase; +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.TYPE; +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.ALGORITHM_FIELD; +import static org.opensearch.neuralsearch.processor.chunker.Chunker.MAX_CHUNK_LIMIT_FIELD; + +public class TextChunkingProcessorTests extends OpenSearchTestCase { + + private TextChunkingProcessorFactory textChunkingProcessorFactory; + + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + private static final String INPUT_FIELD = "body"; + private static final String INPUT_NESTED_FIELD_KEY = "nested"; + private static final String OUTPUT_FIELD = "body_chunk"; + private static final String INDEX_NAME = "_index"; + + @SneakyThrows + private AnalysisRegistry getAnalysisRegistry() { + Settings settings = Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build(); + Environment environment = TestEnvironment.newEnvironment(settings); + AnalysisPlugin plugin = new AnalysisPlugin() { + + @Override + public Map> getTokenizers() { + return singletonMap( + "keyword", + (indexSettings, environment, name, settings) -> TokenizerFactory.newFactory( + name, + () -> new MockTokenizer(MockTokenizer.KEYWORD, false) + ) + ); + } + }; + return new AnalysisModule(environment, singletonList(plugin)).getAnalysisRegistry(); + } + + @Before + public void setup() { + Metadata metadata = mock(Metadata.class); + Environment environment = mock(Environment.class); + Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); + when(environment.settings()).thenReturn(settings); + ClusterState clusterState = mock(ClusterState.class); + ClusterService clusterService = mock(ClusterService.class); + IndicesService indicesService = mock(IndicesService.class); + when(metadata.index(anyString())).thenReturn(null); + when(clusterState.metadata()).thenReturn(metadata); + when(clusterService.state()).thenReturn(clusterState); + textChunkingProcessorFactory = new TextChunkingProcessorFactory(environment, clusterService, indicesService, getAnalysisRegistry()); + } + + private Map createFixedTokenLengthParameters() { + Map parameters = new HashMap<>(); + 
parameters.put(FixedTokenLengthChunker.TOKEN_LIMIT_FIELD, 10); + return parameters; + } + + private List> createSourceDataListNestedMap() { + Map documents = new HashMap<>(); + documents.put(INPUT_FIELD, createSourceDataString()); + return List.of(documents, documents); + } + + private Map createFixedTokenLengthParametersWithMaxChunkLimit(int maxChunkLimit) { + Map parameters = new HashMap<>(); + parameters.put(FixedTokenLengthChunker.TOKEN_LIMIT_FIELD, 10); + parameters.put(MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); + return parameters; + } + + private Map createDelimiterParameters() { + Map parameters = new HashMap<>(); + parameters.put(DelimiterChunker.DELIMITER_FIELD, "."); + return parameters; + } + + private Map createStringFieldMap() { + Map fieldMap = new HashMap<>(); + fieldMap.put(INPUT_FIELD, OUTPUT_FIELD); + return fieldMap; + } + + private Map createNestedFieldMap() { + Map fieldMap = new HashMap<>(); + fieldMap.put(INPUT_NESTED_FIELD_KEY, Map.of(INPUT_FIELD, OUTPUT_FIELD)); + return fieldMap; + } + + @SneakyThrows + private TextChunkingProcessor createDefaultAlgorithmInstance(Map fieldMap) { + Map config = new HashMap<>(); + Map algorithmMap = new HashMap<>(); + config.put(FIELD_MAP_FIELD, fieldMap); + config.put(ALGORITHM_FIELD, algorithmMap); + Map registry = new HashMap<>(); + return textChunkingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + private TextChunkingProcessor createFixedTokenLengthInstance(Map fieldMap) { + Map config = new HashMap<>(); + Map algorithmMap = new HashMap<>(); + algorithmMap.put(FixedTokenLengthChunker.ALGORITHM_NAME, createFixedTokenLengthParameters()); + config.put(FIELD_MAP_FIELD, fieldMap); + config.put(ALGORITHM_FIELD, algorithmMap); + Map registry = new HashMap<>(); + return textChunkingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + private TextChunkingProcessor createFixedTokenLengthInstanceWithMaxChunkLimit(Map fieldMap, int maxChunkLimit) { + Map config = new HashMap<>(); + Map algorithmMap = new HashMap<>(); + algorithmMap.put(FixedTokenLengthChunker.ALGORITHM_NAME, createFixedTokenLengthParametersWithMaxChunkLimit(maxChunkLimit)); + config.put(FIELD_MAP_FIELD, fieldMap); + config.put(ALGORITHM_FIELD, algorithmMap); + Map registry = new HashMap<>(); + return textChunkingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + private TextChunkingProcessor createDelimiterInstance() { + Map config = new HashMap<>(); + Map fieldMap = new HashMap<>(); + Map algorithmMap = new HashMap<>(); + algorithmMap.put(DelimiterChunker.ALGORITHM_NAME, createDelimiterParameters()); + fieldMap.put(INPUT_FIELD, OUTPUT_FIELD); + config.put(FIELD_MAP_FIELD, fieldMap); + config.put(ALGORITHM_FIELD, algorithmMap); + Map registry = new HashMap<>(); + return textChunkingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + public void testCreate_whenAlgorithmFieldMissing_thenFail() { + Map config = new HashMap<>(); + Map fieldMap = new HashMap<>(); + config.put(FIELD_MAP_FIELD, fieldMap); + Map registry = new HashMap<>(); + OpenSearchParseException openSearchParseException = assertThrows( + OpenSearchParseException.class, + () -> textChunkingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals( + String.format(Locale.ROOT, "[%s] required property is missing", ALGORITHM_FIELD), + openSearchParseException.getMessage() + ); + } + + @SneakyThrows + public void 
testCreate_whenMaxChunkLimitInvalidValue_thenFail() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + Map fieldMap = new HashMap<>(); + Map algorithmMap = new HashMap<>(); + fieldMap.put(INPUT_FIELD, OUTPUT_FIELD); + algorithmMap.put(FixedTokenLengthChunker.ALGORITHM_NAME, createFixedTokenLengthParametersWithMaxChunkLimit(-2)); + config.put(FIELD_MAP_FIELD, fieldMap); + config.put(ALGORITHM_FIELD, algorithmMap); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> textChunkingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config) + ); + assert (illegalArgumentException.getMessage() + .contains(String.format(Locale.ROOT, "Parameter [%s] must be positive", MAX_CHUNK_LIMIT_FIELD))); + } + + @SneakyThrows + public void testCreate_whenMaxChunkLimitDisabledValue_thenSucceed() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + Map fieldMap = new HashMap<>(); + Map algorithmMap = new HashMap<>(); + fieldMap.put(INPUT_FIELD, OUTPUT_FIELD); + algorithmMap.put(FixedTokenLengthChunker.ALGORITHM_NAME, createFixedTokenLengthParametersWithMaxChunkLimit(-1)); + config.put(FIELD_MAP_FIELD, fieldMap); + config.put(ALGORITHM_FIELD, algorithmMap); + textChunkingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + public void testCreate_whenAlgorithmMapMultipleAlgorithms_thenFail() { + Map config = new HashMap<>(); + Map fieldMap = new HashMap<>(); + Map algorithmMap = new HashMap<>(); + fieldMap.put(INPUT_FIELD, OUTPUT_FIELD); + config.put(TextChunkingProcessor.FIELD_MAP_FIELD, fieldMap); + algorithmMap.put(FixedTokenLengthChunker.ALGORITHM_NAME, createFixedTokenLengthParameters()); + algorithmMap.put(DelimiterChunker.ALGORITHM_NAME, createDelimiterParameters()); + config.put(ALGORITHM_FIELD, algorithmMap); + Map registry = new HashMap<>(); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> textChunkingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals( + String.format(Locale.ROOT, "Unable to create %s processor as [%s] contains multiple algorithms", TYPE, ALGORITHM_FIELD), + illegalArgumentException.getMessage() + ); + } + + public void testCreate_wheAlgorithmMapInvalidAlgorithmName_thenFail() { + Map config = new HashMap<>(); + Map fieldMap = new HashMap<>(); + Map algorithmMap = new HashMap<>(); + String invalid_algorithm_type = "invalid algorithm"; + fieldMap.put(INPUT_FIELD, OUTPUT_FIELD); + config.put(TextChunkingProcessor.FIELD_MAP_FIELD, fieldMap); + algorithmMap.put(invalid_algorithm_type, createFixedTokenLengthParameters()); + config.put(ALGORITHM_FIELD, algorithmMap); + Map registry = new HashMap<>(); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> textChunkingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config) + ); + assert (illegalArgumentException.getMessage() + .contains(String.format(Locale.ROOT, "Chunking algorithm [%s] is not supported.", invalid_algorithm_type))); + } + + public void testCreate_whenAlgorithmMapInvalidAlgorithmType_thenFail() { + Map config = new HashMap<>(); + Map fieldMap = new HashMap<>(); + Map algorithmMap = new HashMap<>(); + fieldMap.put(INPUT_FIELD, OUTPUT_FIELD); + config.put(TextChunkingProcessor.FIELD_MAP_FIELD, fieldMap); + algorithmMap.put(FixedTokenLengthChunker.ALGORITHM_NAME, 1); + config.put(ALGORITHM_FIELD, algorithmMap); + Map registry = new 
HashMap<>(); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> textChunkingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals( + String.format( + Locale.ROOT, + "Unable to create %s processor as parameters for [%s] algorithm must be an object", + TYPE, + FixedTokenLengthChunker.ALGORITHM_NAME + ), + illegalArgumentException.getMessage() + ); + } + + @SneakyThrows + public void testGetType() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createStringFieldMap()); + String type = processor.getType(); + assertEquals(TYPE, type); + } + + private String createSourceDataString() { + return "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; + } + + private List createSourceDataListStrings() { + List documents = new ArrayList<>(); + documents.add( + "This is the first document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch." + ); + documents.add( + "This is the second document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch." + ); + return documents; + } + + private List createSourceDataListWithInvalidType() { + List documents = new ArrayList<>(); + documents.add( + "This is the first document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch." + ); + documents.add(1); + return documents; + } + + private List createSourceDataListWithHybridType() { + List documents = new ArrayList<>(); + documents.add( + "This is the first document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch." + ); + documents.add(ImmutableMap.of()); + return documents; + } + + private List createSourceDataListWithNull() { + List documents = new ArrayList<>(); + documents.add( + "This is the first document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch." 
+ ); + documents.add(null); + return documents; + } + + private Map createSourceDataNestedMap() { + Map documents = new HashMap<>(); + documents.put(INPUT_FIELD, createSourceDataString()); + return documents; + } + + private Map createSourceDataInvalidNestedMap() { + Map documents = new HashMap<>(); + documents.put(INPUT_FIELD, Map.of(INPUT_NESTED_FIELD_KEY, 1)); + return documents; + } + + private Map createMaxDepthLimitExceedMap(int maxDepth) { + if (maxDepth > 21) { + return null; + } + Map resultMap = new HashMap<>(); + Map innerMap = createMaxDepthLimitExceedMap(maxDepth + 1); + if (Objects.nonNull(innerMap)) { + resultMap.put(INPUT_FIELD, innerMap); + } + return resultMap; + } + + private IngestDocument createIngestDocumentWithNestedSourceData(Object sourceData) { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(INPUT_NESTED_FIELD_KEY, sourceData); + sourceAndMetadata.put(IndexFieldMapper.NAME, INDEX_NAME); + return new IngestDocument(sourceAndMetadata, new HashMap<>()); + } + + private IngestDocument createIngestDocumentWithSourceData(Object sourceData) { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(INPUT_FIELD, sourceData); + sourceAndMetadata.put(IndexFieldMapper.NAME, INDEX_NAME); + return new IngestDocument(sourceAndMetadata, new HashMap<>()); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataStringWithMaxChunkLimit_thenSucceed() { + TextChunkingProcessor processor = createFixedTokenLengthInstanceWithMaxChunkLimit(createStringFieldMap(), 5); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataString()); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey(OUTPUT_FIELD); + Object passages = document.getSourceAndMetadata().get(OUTPUT_FIELD); + assert (passages instanceof List); + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("standard tokenizer in OpenSearch."); + assertEquals(expectedPassages, passages); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataStringWithMaxChunkLimitTwice_thenSucceed() { + TextChunkingProcessor processor = createFixedTokenLengthInstanceWithMaxChunkLimit(createStringFieldMap(), 5); + for (int i = 0; i < 2; i++) { + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataString()); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey(OUTPUT_FIELD); + Object passages = document.getSourceAndMetadata().get(OUTPUT_FIELD); + assert (passages instanceof List); + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. 
The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("standard tokenizer in OpenSearch."); + assertEquals(expectedPassages, passages); + } + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataStringWithMaxChunkLimitDisabled_thenSucceed() { + TextChunkingProcessor processor = createFixedTokenLengthInstanceWithMaxChunkLimit(createStringFieldMap(), -1); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataString()); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey(OUTPUT_FIELD); + Object passages = document.getSourceAndMetadata().get(OUTPUT_FIELD); + assert (passages instanceof List); + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("standard tokenizer in OpenSearch."); + assertEquals(expectedPassages, passages); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataStringExceedMaxChunkLimit_thenFail() { + int maxChunkLimit = 1; + TextChunkingProcessor processor = createFixedTokenLengthInstanceWithMaxChunkLimit(createStringFieldMap(), maxChunkLimit); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataString()); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> processor.execute(ingestDocument) + ); + assert (illegalArgumentException.getMessage() + .contains( + String.format( + Locale.ROOT, + "The number of chunks produced by %s processor has exceeded the allowed maximum of [%s].", + TYPE, + maxChunkLimit + ) + )); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataListExceedMaxChunkLimit_thenFail() { + int maxChunkLimit = 5; + TextChunkingProcessor processor = createFixedTokenLengthInstanceWithMaxChunkLimit(createStringFieldMap(), maxChunkLimit); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataListStrings()); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> processor.execute(ingestDocument) + ); + assert (illegalArgumentException.getMessage() + .contains( + String.format( + Locale.ROOT, + "The number of chunks produced by %s processor has exceeded the allowed maximum of [%s].", + TYPE, + maxChunkLimit + ) + )); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataListDisabledMaxChunkLimit_thenFail() { + int maxChunkLimit = -1; + TextChunkingProcessor processor = createFixedTokenLengthInstanceWithMaxChunkLimit(createStringFieldMap(), maxChunkLimit); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataListStrings()); + processor.execute(ingestDocument); + } + + @SneakyThrows + public void testCreate_withDefaultAlgorithm_andSourceDataString_thenSucceed() { + TextChunkingProcessor processor = createDefaultAlgorithmInstance(createStringFieldMap()); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataString()); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey(OUTPUT_FIELD); + Object passages = document.getSourceAndMetadata().get(OUTPUT_FIELD); + assert (passages instanceof List); + List expectedPassages = 
new ArrayList<>(); + expectedPassages.add( + "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch." + ); + assertEquals(expectedPassages, passages); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataString_thenSucceed() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createStringFieldMap()); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataString()); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey(OUTPUT_FIELD); + Object passages = document.getSourceAndMetadata().get(OUTPUT_FIELD); + assert (passages instanceof List); + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("standard tokenizer in OpenSearch."); + assertEquals(expectedPassages, passages); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataInvalidType_thenFail() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createStringFieldMap()); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(INPUT_FIELD, 1); + sourceAndMetadata.put(IndexFieldMapper.NAME, INDEX_NAME); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> processor.execute(ingestDocument) + ); + assertEquals( + String.format(Locale.ROOT, "field [%s] is neither string nor nested type, cannot process it", INPUT_FIELD), + illegalArgumentException.getMessage() + ); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataListStrings_thenSucceed() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createStringFieldMap()); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataListStrings()); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey(OUTPUT_FIELD); + Object passages = document.getSourceAndMetadata().get(OUTPUT_FIELD); + assert (passages instanceof List); + + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is the first document to be chunked. The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("standard tokenizer in OpenSearch."); + expectedPassages.add("This is the second document to be chunked. 
The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("standard tokenizer in OpenSearch."); + assertEquals(expectedPassages, passages); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataListWithInvalidType_thenFail() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createStringFieldMap()); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataListWithInvalidType()); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> processor.execute(ingestDocument) + ); + assertEquals( + String.format(Locale.ROOT, "list type field [%s] has non-string value, cannot process it", INPUT_FIELD), + illegalArgumentException.getMessage() + ); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataListWithNull_thenFail() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createStringFieldMap()); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataListWithNull()); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> processor.execute(ingestDocument) + ); + assertEquals( + String.format(Locale.ROOT, "list type field [%s] has null, cannot process it", INPUT_FIELD), + illegalArgumentException.getMessage() + ); + } + + @SuppressWarnings("unchecked") + @SneakyThrows + public void testExecute_withFixedTokenLength_andFieldMapNestedMap_thenSucceed() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createNestedFieldMap()); + IngestDocument ingestDocument = createIngestDocumentWithNestedSourceData(createSourceDataNestedMap()); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey(INPUT_NESTED_FIELD_KEY); + Object nestedResult = document.getSourceAndMetadata().get(INPUT_NESTED_FIELD_KEY); + assert (nestedResult instanceof Map); + assert ((Map) nestedResult).containsKey(OUTPUT_FIELD); + Object passages = ((Map) nestedResult).get(OUTPUT_FIELD); + assert (passages instanceof List); + + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. 
The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("standard tokenizer in OpenSearch."); + assertEquals(expectedPassages, passages); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andMaxDepthLimitExceedFieldMap_thenFail() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createNestedFieldMap()); + IngestDocument ingestDocument = createIngestDocumentWithNestedSourceData(createMaxDepthLimitExceedMap(0)); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> processor.execute(ingestDocument) + ); + assertEquals( + String.format(Locale.ROOT, "map type field [%s] reached max depth limit, cannot process it", INPUT_NESTED_FIELD_KEY), + illegalArgumentException.getMessage() + ); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andFieldMapNestedMap_thenFail() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createNestedFieldMap()); + IngestDocument ingestDocument = createIngestDocumentWithNestedSourceData(createSourceDataInvalidNestedMap()); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> processor.execute(ingestDocument) + ); + assertEquals( + String.format(Locale.ROOT, "map type field [%s] has non-string type, cannot process it", INPUT_NESTED_FIELD_KEY), + illegalArgumentException.getMessage() + ); + } + + @SneakyThrows + @SuppressWarnings("unchecked") + public void testExecute_withFixedTokenLength_andFieldMapNestedMap_sourceDataList_thenSucceed() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createNestedFieldMap()); + IngestDocument ingestDocument = createIngestDocumentWithNestedSourceData(createSourceDataListNestedMap()); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey(INPUT_NESTED_FIELD_KEY); + Object nestedResult = document.getSourceAndMetadata().get(INPUT_NESTED_FIELD_KEY); + + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. 
The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("standard tokenizer in OpenSearch."); + assert (nestedResult instanceof List); + assertEquals(((List) nestedResult).size(), 2); + for (Object result : (List) nestedResult) { + assert (result instanceof Map); + assert ((Map) result).containsKey(OUTPUT_FIELD); + Object passages = ((Map) result).get(OUTPUT_FIELD); + assert (passages instanceof List); + assertEquals(expectedPassages, passages); + } + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataListWithHybridType_thenSucceed() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createStringFieldMap()); + List sourceDataList = createSourceDataListWithHybridType(); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(sourceDataList); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey(INPUT_FIELD); + Object listResult = document.getSourceAndMetadata().get(OUTPUT_FIELD); + assert (listResult instanceof List); + assertEquals(((List) listResult).size(), 0); + } + + @SneakyThrows + public void testExecute_withFixedTokenLength_andSourceDataNull_thenSucceed() { + TextChunkingProcessor processor = createFixedTokenLengthInstance(createStringFieldMap()); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(null); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey(INPUT_FIELD); + Object listResult = document.getSourceAndMetadata().get(OUTPUT_FIELD); + assert (listResult instanceof List); + assertEquals(((List) listResult).size(), 0); + } + + @SneakyThrows + public void testExecute_withDelimiter_andSourceDataString_thenSucceed() { + TextChunkingProcessor processor = createDelimiterInstance(); + IngestDocument ingestDocument = createIngestDocumentWithSourceData(createSourceDataString()); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey(OUTPUT_FIELD); + Object passages = document.getSourceAndMetadata().get(OUTPUT_FIELD); + assert (passages instanceof List); + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked."); + expectedPassages.add(" The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."); + assertEquals(expectedPassages, passages); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/chunker/ChunkerFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/chunker/ChunkerFactoryTests.java new file mode 100644 index 000000000..21859c24e --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/chunker/ChunkerFactoryTests.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.chunker; + +import org.mockito.Mock; +import org.opensearch.index.analysis.AnalysisRegistry; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.ANALYSIS_REGISTRY_FIELD; + +public class ChunkerFactoryTests extends OpenSearchTestCase { + + @Mock + private AnalysisRegistry analysisRegistry; + + public void testCreate_FixedTokenLength() { + Chunker chunker = 
ChunkerFactory.create(FixedTokenLengthChunker.ALGORITHM_NAME, createChunkParameters()); + assertNotNull(chunker); + assert (chunker instanceof FixedTokenLengthChunker); + } + + public void testCreate_Delimiter() { + Chunker chunker = ChunkerFactory.create(DelimiterChunker.ALGORITHM_NAME, createChunkParameters()); + assertNotNull(chunker); + assert (chunker instanceof DelimiterChunker); + } + + public void testCreate_Invalid() { + String invalidChunkerName = "Invalid Chunker Algorithm"; + assertThrows(NullPointerException.class, () -> ChunkerFactory.create(invalidChunkerName, createChunkParameters())); + } + + private Map createChunkParameters() { + Map parameters = new HashMap<>(); + parameters.put(ANALYSIS_REGISTRY_FIELD, analysisRegistry); + return parameters; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java b/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java new file mode 100644 index 000000000..54e296861 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java @@ -0,0 +1,125 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.chunker; + +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.junit.Assert; +import org.opensearch.test.OpenSearchTestCase; + +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.TYPE; +import static org.opensearch.neuralsearch.processor.chunker.Chunker.MAX_CHUNK_LIMIT_FIELD; +import static org.opensearch.neuralsearch.processor.chunker.DelimiterChunker.DELIMITER_FIELD; + +public class DelimiterChunkerTests extends OpenSearchTestCase { + + public void testCreate_withDelimiterFieldInvalidType_thenFail() { + Exception exception = assertThrows( + IllegalArgumentException.class, + () -> new DelimiterChunker(Map.of(DELIMITER_FIELD, List.of(""))) + ); + Assert.assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", DELIMITER_FIELD, String.class.getName()), + exception.getMessage() + ); + } + + public void testCreate_withDelimiterFieldEmptyString_thenFail() { + Exception exception = assertThrows(IllegalArgumentException.class, () -> new DelimiterChunker(Map.of(DELIMITER_FIELD, ""))); + Assert.assertEquals(String.format(Locale.ROOT, "Parameter [%s] should not be empty.", DELIMITER_FIELD), exception.getMessage()); + } + + public void testChunk_withNewlineDelimiter_thenSucceed() { + DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n")); + String content = "a\nb\nc\nd"; + List chunkResult = chunker.chunk(content, Map.of()); + assertEquals(List.of("a\n", "b\n", "c\n", "d"), chunkResult); + } + + public void testChunk_withDefaultDelimiter_thenSucceed() { + // default delimiter is \n\n + DelimiterChunker chunker = new DelimiterChunker(Map.of()); + String content = "a.b\n\nc.d"; + List chunkResult = chunker.chunk(content, Map.of()); + assertEquals(List.of("a.b\n\n", "c.d"), chunkResult); + } + + public void testChunk_withOnlyDelimiterContent_thenSucceed() { + DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n")); + String content = "\n"; + List chunkResult = chunker.chunk(content, Map.of()); + assertEquals(List.of("\n"), chunkResult); + } + + public void testChunk_WithAllDelimiterContent_thenSucceed() { + DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n")); + String content = "\n\n\n"; + List 
chunkResult = chunker.chunk(content, Map.of()); + assertEquals(List.of("\n", "\n", "\n"), chunkResult); + } + + public void testChunk_WithPeriodDelimiters_thenSucceed() { + DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, ".")); + String content = "a.b.cc.d."; + List chunkResult = chunker.chunk(content, Map.of()); + assertEquals(List.of("a.", "b.", "cc.", "d."), chunkResult); + } + + public void testChunk_withDoubleNewlineDelimiter_thenSucceed() { + DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n")); + String content = "\n\na\n\n\n"; + List chunkResult = chunker.chunk(content, Map.of()); + assertEquals(List.of("\n\n", "a\n\n", "\n"), chunkResult); + } + + public void testChunk_whenExceedMaxChunkLimit_thenFail() { + int maxChunkLimit = 2; + DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n", MAX_CHUNK_LIMIT_FIELD, maxChunkLimit)); + String content = "\n\na\n\n\n"; + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> chunker.chunk(content, Map.of()) + ); + assert (illegalArgumentException.getMessage() + .contains( + String.format( + Locale.ROOT, + "The number of chunks produced by %s processor has exceeded the allowed maximum of [%s].", + TYPE, + maxChunkLimit + ) + )); + } + + public void testChunk_whenWithinMaxChunkLimit_thenSucceed() { + int maxChunkLimit = 3; + DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n", MAX_CHUNK_LIMIT_FIELD, maxChunkLimit)); + String content = "\n\na\n\n\n"; + List chunkResult = chunker.chunk(content, Map.of()); + assertEquals(List.of("\n\n", "a\n\n", "\n"), chunkResult); + } + + public void testChunk_whenExceedRuntimeMaxChunkLimit_thenFail() { + int maxChunkLimit = 3; + DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, "\n\n", MAX_CHUNK_LIMIT_FIELD, maxChunkLimit)); + String content = "\n\na\n\n\n"; + int runtimeMaxChunkLimit = 2; + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> chunker.chunk(content, Map.of(MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit)) + ); + assert (illegalArgumentException.getMessage() + .contains( + String.format( + Locale.ROOT, + "The number of chunks produced by %s processor has exceeded the allowed maximum of [%s].", + TYPE, + maxChunkLimit + ) + )); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunkerTests.java b/src/test/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunkerTests.java new file mode 100644 index 000000000..bbcaa7069 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunkerTests.java @@ -0,0 +1,309 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.chunker; + +import lombok.SneakyThrows; +import org.apache.lucene.tests.analysis.MockTokenizer; +import org.junit.Before; +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; +import org.opensearch.env.TestEnvironment; +import org.opensearch.index.analysis.TokenizerFactory; +import org.opensearch.index.analysis.AnalysisRegistry; +import org.opensearch.indices.analysis.AnalysisModule; +import org.opensearch.plugins.AnalysisPlugin; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import 
java.util.Map; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.TYPE; +import static org.opensearch.neuralsearch.processor.chunker.Chunker.MAX_CHUNK_LIMIT_FIELD; +import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.ALGORITHM_NAME; +import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.ANALYSIS_REGISTRY_FIELD; +import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.TOKEN_LIMIT_FIELD; +import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.OVERLAP_RATE_FIELD; +import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.TOKENIZER_FIELD; +import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.MAX_TOKEN_COUNT_FIELD; + +public class FixedTokenLengthChunkerTests extends OpenSearchTestCase { + + private FixedTokenLengthChunker fixedTokenLengthChunker; + + @Before + public void setup() { + fixedTokenLengthChunker = createFixedTokenLengthChunker(Map.of()); + } + + @SneakyThrows + public FixedTokenLengthChunker createFixedTokenLengthChunker(Map parameters) { + Map nonRuntimeParameters = new HashMap<>(parameters); + Settings settings = Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build(); + Environment environment = TestEnvironment.newEnvironment(settings); + AnalysisPlugin plugin = new AnalysisPlugin() { + + @Override + public Map> getTokenizers() { + return singletonMap( + "keyword", + (indexSettings, environment, name, settings) -> TokenizerFactory.newFactory( + name, + () -> new MockTokenizer(MockTokenizer.KEYWORD, false) + ) + ); + } + }; + AnalysisRegistry analysisRegistry = new AnalysisModule(environment, singletonList(plugin)).getAnalysisRegistry(); + nonRuntimeParameters.put(ANALYSIS_REGISTRY_FIELD, analysisRegistry); + return new FixedTokenLengthChunker(nonRuntimeParameters); + } + + public void testParseParameters_whenNoParams_thenSuccessful() { + fixedTokenLengthChunker.parseParameters(Map.of()); + } + + public void testParseParameters_whenIllegalTokenLimitType_thenFail() { + Map parameters = new HashMap<>(); + parameters.put(TOKEN_LIMIT_FIELD, "invalid token limit"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> fixedTokenLengthChunker.parseParameters(parameters) + ); + assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", TOKEN_LIMIT_FIELD, Integer.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParseParameters_whenIllegalTokenLimitValue_thenFail() { + Map parameters = new HashMap<>(); + parameters.put(TOKEN_LIMIT_FIELD, -1); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> fixedTokenLengthChunker.parseParameters(parameters) + ); + assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be positive.", TOKEN_LIMIT_FIELD), + illegalArgumentException.getMessage() + ); + } + + public void testParseParameters_whenIllegalOverlapRateType_thenFail() { + Map parameters = new HashMap<>(); + parameters.put(OVERLAP_RATE_FIELD, "invalid overlap rate"); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> fixedTokenLengthChunker.parseParameters(parameters) + ); + assertEquals( + String.format(Locale.ROOT, 
"Parameter [%s] must be of %s type", OVERLAP_RATE_FIELD, Double.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParseParameters_whenTooLargeOverlapRate_thenFail() { + Map parameters = new HashMap<>(); + parameters.put(OVERLAP_RATE_FIELD, 0.6); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> fixedTokenLengthChunker.parseParameters(parameters) + ); + assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be between %s and %s", OVERLAP_RATE_FIELD, 0.0, 0.5), + illegalArgumentException.getMessage() + ); + } + + public void testParseParameters_whenTooSmallOverlapRateValue_thenFail() { + Map parameters = new HashMap<>(); + parameters.put(OVERLAP_RATE_FIELD, -1); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> fixedTokenLengthChunker.parseParameters(parameters) + ); + assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be between %s and %s", OVERLAP_RATE_FIELD, 0.0, 0.5), + illegalArgumentException.getMessage() + ); + } + + public void testParseParameters_whenIllegalTokenizerType_thenFail() { + Map parameters = new HashMap<>(); + parameters.put(TOKENIZER_FIELD, 111); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> fixedTokenLengthChunker.parseParameters(parameters) + ); + assertEquals( + String.format(Locale.ROOT, "Parameter [%s] must be of %s type", TOKENIZER_FIELD, String.class.getName()), + illegalArgumentException.getMessage() + ); + } + + public void testParseParameters_whenUnsupportedTokenizer_thenFail() { + String ngramTokenizer = "ngram"; + Map parameters = Map.of(TOKENIZER_FIELD, ngramTokenizer); + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> fixedTokenLengthChunker.parseParameters(parameters) + ); + assert (illegalArgumentException.getMessage() + .contains(String.format(Locale.ROOT, "Tokenizer [%s] is not supported for [%s] algorithm.", ngramTokenizer, ALGORITHM_NAME))); + } + + public void testChunk_whenTokenizationException_thenFail() { + // lowercase tokenizer is not supported in unit tests + String lowercaseTokenizer = "lowercase"; + Map parameters = Map.of(TOKENIZER_FIELD, lowercaseTokenizer); + FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); + String content = + "This is an example document to be chunked. 
The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; + IllegalStateException illegalStateException = assertThrows( + IllegalStateException.class, + () -> fixedTokenLengthChunker.chunk(content, parameters) + ); + assert (illegalStateException.getMessage() + .contains(String.format(Locale.ROOT, "analyzer %s throws exception", lowercaseTokenizer))); + } + + public void testChunk_withEmptyInput_thenSucceed() { + Map parameters = new HashMap<>(); + parameters.put(TOKEN_LIMIT_FIELD, 10); + parameters.put(TOKENIZER_FIELD, "standard"); + FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); + Map runtimeParameters = new HashMap<>(); + runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); + String content = ""; + List passages = fixedTokenLengthChunker.chunk(content, runtimeParameters); + assert (passages.isEmpty()); + } + + public void testChunk_withTokenLimit10_thenSucceed() { + Map parameters = new HashMap<>(); + parameters.put(TOKEN_LIMIT_FIELD, 10); + parameters.put(TOKENIZER_FIELD, "standard"); + FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); + Map runtimeParameters = new HashMap<>(); + runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); + String content = + "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; + List passages = fixedTokenLengthChunker.chunk(content, runtimeParameters); + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("standard tokenizer in OpenSearch."); + assertEquals(expectedPassages, passages); + } + + public void testChunk_withTokenLimit20_thenSucceed() { + Map parameters = new HashMap<>(); + parameters.put(TOKEN_LIMIT_FIELD, 20); + parameters.put(TOKENIZER_FIELD, "standard"); + FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); + Map runtimeParameters = new HashMap<>(); + runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); + String content = + "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; + List passages = fixedTokenLengthChunker.chunk(content, runtimeParameters); + List expectedPassages = new ArrayList<>(); + expectedPassages.add( + "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by " + ); + expectedPassages.add("standard tokenizer in OpenSearch."); + assertEquals(expectedPassages, passages); + } + + public void testChunk_withOverlapRateHalf_thenSucceed() { + Map parameters = new HashMap<>(); + parameters.put(TOKEN_LIMIT_FIELD, 10); + parameters.put(OVERLAP_RATE_FIELD, 0.5); + FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); + String content = + "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; + List passages = fixedTokenLengthChunker.chunk(content, Map.of()); + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. The document "); + expectedPassages.add("to be chunked. 
The document contains a single paragraph, two "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("sentences and 24 tokens by standard tokenizer in OpenSearch."); + assertEquals(expectedPassages, passages); + } + + public void testChunk_whenExceedMaxChunkLimit_thenFail() { + int maxChunkLimit = 2; + Map parameters = new HashMap<>(); + parameters.put(TOKEN_LIMIT_FIELD, 10); + parameters.put(TOKENIZER_FIELD, "standard"); + parameters.put(MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); + FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); + Map runtimeParameters = new HashMap<>(); + runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); + String content = + "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> fixedTokenLengthChunker.chunk(content, runtimeParameters) + ); + assert (illegalArgumentException.getMessage() + .contains( + String.format( + Locale.ROOT, + "The number of chunks produced by %s processor has exceeded the allowed maximum of [%s].", + TYPE, + maxChunkLimit + ) + )); + } + + public void testChunk_whenWithinMaxChunkLimit_thenSucceed() { + int maxChunkLimit = 3; + Map parameters = new HashMap<>(); + parameters.put(TOKEN_LIMIT_FIELD, 10); + parameters.put(TOKENIZER_FIELD, "standard"); + parameters.put(MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); + FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); + Map runtimeParameters = new HashMap<>(); + runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); + String content = + "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; + List passages = fixedTokenLengthChunker.chunk(content, runtimeParameters); + List expectedPassages = new ArrayList<>(); + expectedPassages.add("This is an example document to be chunked. The document "); + expectedPassages.add("contains a single paragraph, two sentences and 24 tokens by "); + expectedPassages.add("standard tokenizer in OpenSearch."); + assertEquals(expectedPassages, passages); + } + + public void testChunk_whenExceedRuntimeMaxChunkLimit_thenFail() { + int maxChunkLimit = 3, runtimeMaxChunkLimit = 2; + Map parameters = new HashMap<>(); + parameters.put(TOKEN_LIMIT_FIELD, 10); + parameters.put(TOKENIZER_FIELD, "standard"); + parameters.put(MAX_CHUNK_LIMIT_FIELD, maxChunkLimit); + FixedTokenLengthChunker fixedTokenLengthChunker = createFixedTokenLengthChunker(parameters); + Map runtimeParameters = new HashMap<>(); + runtimeParameters.put(MAX_TOKEN_COUNT_FIELD, 10000); + runtimeParameters.put(MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit); + String content = + "This is an example document to be chunked. 
The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; + IllegalArgumentException illegalArgumentException = assertThrows( + IllegalArgumentException.class, + () -> fixedTokenLengthChunker.chunk(content, runtimeParameters) + ); + assert (illegalArgumentException.getMessage() + .contains( + String.format( + Locale.ROOT, + "The number of chunks produced by %s processor has exceeded the allowed maximum of [%s].", + TYPE, + maxChunkLimit + ) + )); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextChunkingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextChunkingProcessorFactoryTests.java new file mode 100644 index 000000000..3d9993e7b --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextChunkingProcessorFactoryTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import lombok.SneakyThrows; +import org.apache.lucene.tests.analysis.MockTokenizer; +import org.junit.Before; +import java.util.HashMap; +import java.util.Map; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.mockito.Mockito.mock; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.env.Environment; +import org.opensearch.env.TestEnvironment; +import org.opensearch.index.analysis.AnalysisRegistry; +import org.opensearch.index.analysis.TokenizerFactory; +import org.opensearch.indices.IndicesService; +import org.opensearch.indices.analysis.AnalysisModule; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.processor.TextChunkingProcessor; +import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker; +import org.opensearch.plugins.AnalysisPlugin; +import org.opensearch.test.OpenSearchTestCase; +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.TYPE; +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.TextChunkingProcessor.ALGORITHM_FIELD; + +public class TextChunkingProcessorFactoryTests extends OpenSearchTestCase { + + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + private static final Map algorithmMap = Map.of(FixedTokenLengthChunker.ALGORITHM_NAME, new HashMap<>()); + + private TextChunkingProcessorFactory textChunkingProcessorFactory; + + @SneakyThrows + private AnalysisRegistry getAnalysisRegistry() { + Settings settings = Settings.builder().put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString()).build(); + Environment environment = TestEnvironment.newEnvironment(settings); + AnalysisPlugin plugin = new AnalysisPlugin() { + + @Override + public Map> getTokenizers() { + return singletonMap( + "keyword", + (indexSettings, environment, name, settings) -> TokenizerFactory.newFactory( + name, + () -> new MockTokenizer(MockTokenizer.KEYWORD, false) + ) + ); + } + }; + return new AnalysisModule(environment, singletonList(plugin)).getAnalysisRegistry(); + } + + @Before + public void setup() { + Environment environment = mock(Environment.class); + ClusterService clusterService = mock(ClusterService.class); + IndicesService indicesService = mock(IndicesService.class); + 
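+ // Note (explanatory comment, not part of the original change): the factory takes four
+ // collaborators: the node Environment, the ClusterService, the IndicesService, and the
+ // AnalysisRegistry. The first three can be plain mocks in this unit test, but a real
+ // AnalysisRegistry (built above with a keyword MockTokenizer) is supplied so that the
+ // tokenizer lookups performed by the fixed_token_length algorithm still resolve.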
this.textChunkingProcessorFactory = new TextChunkingProcessorFactory( + environment, + clusterService, + indicesService, + getAnalysisRegistry() + ); + } + + @SneakyThrows + public void testTextChunkingProcessorFactory_whenAllParamsPassed_thenSuccessful() { + final Map processorFactories = new HashMap<>(); + Map config = new HashMap<>(); + config.put(ALGORITHM_FIELD, algorithmMap); + config.put(FIELD_MAP_FIELD, new HashMap<>()); + TextChunkingProcessor textChunkingProcessor = textChunkingProcessorFactory.create( + processorFactories, + PROCESSOR_TAG, + DESCRIPTION, + config + ); + assertNotNull(textChunkingProcessor); + assertEquals(TYPE, textChunkingProcessor.getType()); + } + + @SneakyThrows + public void testTextChunkingProcessorFactory_whenOnlyFieldMap_thenFail() { + final Map processorFactories = new HashMap<>(); + Map config = new HashMap<>(); + config.put(FIELD_MAP_FIELD, new HashMap<>()); + Exception exception = assertThrows( + Exception.class, + () -> textChunkingProcessorFactory.create(processorFactories, PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[" + ALGORITHM_FIELD + "] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testTextChunkingProcessorFactory_whenOnlyAlgorithm_thenFail() { + final Map processorFactories = new HashMap<>(); + Map config = new HashMap<>(); + config.put(ALGORITHM_FIELD, algorithmMap); + Exception exception = assertThrows( + Exception.class, + () -> textChunkingProcessorFactory.create(processorFactories, PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[" + FIELD_MAP_FIELD + "] required property is missing", exception.getMessage()); + } +} diff --git a/src/test/resources/processor/chunker/PipelineForCascadedChunker.json b/src/test/resources/processor/chunker/PipelineForCascadedChunker.json new file mode 100644 index 000000000..e7ba380d4 --- /dev/null +++ b/src/test/resources/processor/chunker/PipelineForCascadedChunker.json @@ -0,0 +1,29 @@ +{ + "description": "An example cascaded pipeline with fixed token length algorithm after chunking algorithm", + "processors" : [ + { + "text_chunking": { + "field_map": { + "body": "body_chunk_intermediate" + }, + "algorithm": { + "delimiter": { + "delimiter": "." + } + } + } + }, + { + "text_chunking": { + "field_map": { + "body_chunk_intermediate": "body_chunk" + }, + "algorithm": { + "fixed_token_length": { + "token_limit": 10 + } + } + } + } + ] +} diff --git a/src/test/resources/processor/chunker/PipelineForDelimiterChunker.json b/src/test/resources/processor/chunker/PipelineForDelimiterChunker.json new file mode 100644 index 000000000..c4e66f58c --- /dev/null +++ b/src/test/resources/processor/chunker/PipelineForDelimiterChunker.json @@ -0,0 +1,17 @@ +{ + "description": "An example delimiter chunker pipeline", + "processors" : [ + { + "text_chunking": { + "field_map": { + "body": "body_chunk" + }, + "algorithm": { + "delimiter": { + "delimiter": "." 
+ } + } + } + } + ] +} diff --git a/src/test/resources/processor/chunker/PipelineForFixedTokenLengthChunkerWithLetterTokenizer.json b/src/test/resources/processor/chunker/PipelineForFixedTokenLengthChunkerWithLetterTokenizer.json new file mode 100644 index 000000000..7026676f8 --- /dev/null +++ b/src/test/resources/processor/chunker/PipelineForFixedTokenLengthChunkerWithLetterTokenizer.json @@ -0,0 +1,18 @@ +{ + "description": "An example fixed token length chunker pipeline with letter tokenizer", + "processors" : [ + { + "text_chunking": { + "field_map": { + "body": "body_chunk" + }, + "algorithm": { + "fixed_token_length": { + "token_limit": 10, + "tokenizer": "letter" + } + } + } + } + ] +} diff --git a/src/test/resources/processor/chunker/PipelineForFixedTokenLengthChunkerWithLowercaseTokenizer.json b/src/test/resources/processor/chunker/PipelineForFixedTokenLengthChunkerWithLowercaseTokenizer.json new file mode 100644 index 000000000..cd1c67fc5 --- /dev/null +++ b/src/test/resources/processor/chunker/PipelineForFixedTokenLengthChunkerWithLowercaseTokenizer.json @@ -0,0 +1,18 @@ +{ + "description": "An example fixed token length chunker pipeline with lowercase tokenizer", + "processors" : [ + { + "text_chunking": { + "field_map": { + "body": "body_chunk" + }, + "algorithm": { + "fixed_token_length": { + "token_limit": 10, + "tokenizer": "lowercase" + } + } + } + } + ] +} diff --git a/src/test/resources/processor/chunker/PipelineForFixedTokenLengthChunkerWithStandardTokenizer.json b/src/test/resources/processor/chunker/PipelineForFixedTokenLengthChunkerWithStandardTokenizer.json new file mode 100644 index 000000000..6c727b3b4 --- /dev/null +++ b/src/test/resources/processor/chunker/PipelineForFixedTokenLengthChunkerWithStandardTokenizer.json @@ -0,0 +1,18 @@ +{ + "description": "An example fixed token length chunker pipeline with standard tokenizer", + "processors" : [ + { + "text_chunking": { + "field_map": { + "body": "body_chunk" + }, + "algorithm": { + "fixed_token_length": { + "token_limit": 10, + "tokenizer": "standard" + } + } + } + } + ] +} diff --git a/src/test/resources/processor/chunker/TextChunkingIndexSettings.json b/src/test/resources/processor/chunker/TextChunkingIndexSettings.json new file mode 100644 index 000000000..a2b074e69 --- /dev/null +++ b/src/test/resources/processor/chunker/TextChunkingIndexSettings.json @@ -0,0 +1,6 @@ +{ + "settings":{ + "index.analyze.max_token_count" : 100, + "default_pipeline": "%s" + } +} diff --git a/src/test/resources/processor/chunker/TextChunkingTestDocument.json b/src/test/resources/processor/chunker/TextChunkingTestDocument.json new file mode 100644 index 000000000..673e8b1cf --- /dev/null +++ b/src/test/resources/processor/chunker/TextChunkingTestDocument.json @@ -0,0 +1,3 @@ +{ + "body": "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch." +} diff --git a/src/test/resources/processor/chunker/TextChunkingTestLongDocument.json b/src/test/resources/processor/chunker/TextChunkingTestLongDocument.json new file mode 100644 index 000000000..71927887b --- /dev/null +++ b/src/test/resources/processor/chunker/TextChunkingTestLongDocument.json @@ -0,0 +1,3 @@ +{ + "body": "This is an example long document to be chunked. The document has more than 100 tokens by standard tokenizer in OpenSearch. This is an example long document to be chunked. The document has more than 100 tokens by standard tokenizer in OpenSearch. 
This is an example long document to be chunked. The document has more than 100 tokens by standard tokenizer in OpenSearch. This is an example long document to be chunked. The document has more than 100 tokens by standard tokenizer in OpenSearch. This is an example long document to be chunked. The document has more than 100 tokens by standard tokenizer in OpenSearch. This is an example long document to be chunked. The document has more than 100 tokens by standard tokenizer in OpenSearch. This is an example long document to be chunked. The document has more than 100 tokens by standard tokenizer in OpenSearch. This is an example long document to be chunked. The document has more than 100 tokens by standard tokenizer in OpenSearch. This is an example long document to be chunked. The document has more than 100 tokens by standard tokenizer in OpenSearch. This is an example long document to be chunked. The document has more than 100 tokens by standard tokenizer in OpenSearch. This is an example long document to be chunked. The document has more than 100 tokens by standard tokenizer in OpenSearch." +}
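For reference, below is a minimal usage sketch distilled from the unit tests in this change. It exercises the delimiter algorithm directly through the DelimiterChunker class added by this PR; the example class, its main method, and the sample sentence are illustrative only and are not part of the change.

import java.util.List;
import java.util.Map;

import org.opensearch.neuralsearch.processor.chunker.DelimiterChunker;

import static org.opensearch.neuralsearch.processor.chunker.Chunker.MAX_CHUNK_LIMIT_FIELD;
import static org.opensearch.neuralsearch.processor.chunker.DelimiterChunker.DELIMITER_FIELD;

public class DelimiterChunkerExample {
    public static void main(String[] args) {
        // Split on "." and cap the output at 10 chunks per field; both values are illustrative.
        // The defaults in this PR are a "\n\n" delimiter and a disabled (-1) max_chunk_limit.
        DelimiterChunker chunker = new DelimiterChunker(Map.of(DELIMITER_FIELD, ".", MAX_CHUNK_LIMIT_FIELD, 10));

        String content = "This is an example document to be chunked. It contains two sentences.";
        // The second argument carries runtime parameters; none are needed for the delimiter algorithm.
        List<String> passages = chunker.chunk(content, Map.of());

        // Each passage keeps its trailing delimiter, e.g. "This is an example document to be chunked."
        passages.forEach(System.out::println);
    }
}

If the DelimiterChunker constructor is not accessible from outside its package, ChunkerFactory.create(DelimiterChunker.ALGORITHM_NAME, parameters) offers the same entry point, as ChunkerFactoryTests shows. The fixed_token_length algorithm is configured analogously through TOKEN_LIMIT_FIELD, OVERLAP_RATE_FIELD, and TOKENIZER_FIELD, with MAX_TOKEN_COUNT_FIELD supplied at runtime, as demonstrated in FixedTokenLengthChunkerTests above.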