diff --git a/Directory.Packages.props b/Directory.Packages.props index 41efaafd0..2834bf71d 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -34,7 +34,7 @@ - + diff --git a/KernelMemory.sln b/KernelMemory.sln index 15db02fe8..7dc6a100e 100644 --- a/KernelMemory.sln +++ b/KernelMemory.sln @@ -7,6 +7,7 @@ EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{7BA7F1B2-19E2-46EB-B000-513EE2F65769}" ProjectSection(SolutionItems) = preProject docs\404.html = docs\404.html + docs\azure.md = docs\azure.md docs\concepts.md = docs\concepts.md docs\csharp.png = docs\csharp.png docs\extensions.md = docs\extensions.md @@ -31,7 +32,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{7BA7F1B2-1 docs\service.md = docs\service.md docs\_config.local.yml = docs\_config.local.yml docs\_config.yml = docs\_config.yml - docs\azure.md = docs\azure.md EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "examples", "examples", "{0A43C65C-6007-4BB4-B3FE-8D439FC91841}" @@ -70,18 +70,18 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{6EF76FD8-4 .editorconfig = .editorconfig .gitattributes = .gitattributes .gitignore = .gitignore + azure.yaml = azure.yaml CODE_OF_CONDUCT.md = CODE_OF_CONDUCT.md COMMUNITY.md = COMMUNITY.md CONTRIBUTING.md = CONTRIBUTING.md Directory.Build.props = Directory.Build.props Directory.Packages.props = Directory.Packages.props Dockerfile = Dockerfile + KernelMemory.sln.DotSettings = KernelMemory.sln.DotSettings LICENSE = LICENSE nuget.config = nuget.config README.md = README.md SECURITY.md = SECURITY.md - KernelMemory.sln.DotSettings = KernelMemory.sln.DotSettings - azure.yaml = azure.yaml EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = ".github", ".github", "{B8976338-7CDC-47AE-8502-C2FBAFBEBD68}" @@ -94,10 +94,10 @@ EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "workflows", "workflows", 
"{48E79819-1E9E-4075-90DA-BAEC761C89B2}" ProjectSection(SolutionItems) = preProject .github\workflows\docker-build-push.yml = .github\workflows\docker-build-push.yml + .github\workflows\dotnet-build.yml = .github\workflows\dotnet-build.yml + .github\workflows\dotnet-unit-tests.yml = .github\workflows\dotnet-unit-tests.yml .github\workflows\github-pages-jekyll.yml = .github\workflows\github-pages-jekyll.yml .github\workflows\spell-check-with-typos.yml = .github\workflows\spell-check-with-typos.yml - .github\workflows\dotnet-unit-tests.yml = .github\workflows\dotnet-unit-tests.yml - .github\workflows\dotnet-build.yml = .github\workflows\dotnet-build.yml EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "service", "service", "{87DEAE8D-138C-4FDD-B4C9-11C3A7817E8F}" @@ -116,9 +116,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{CA49F1A1 tools\run-qdrant.sh = tools\run-qdrant.sh tools\run-rabbitmq.sh = tools\run-rabbitmq.sh tools\run-redis.sh = tools\run-redis.sh + tools\run-s3ninja.sh = tools\run-s3ninja.sh tools\search.sh = tools\search.sh tools\upload-file.sh = tools\upload-file.sh - tools\run-s3ninja.sh = tools\run-s3ninja.sh tools\dockerize-amd64.sh = tools\dockerize-amd64.sh tools\dockerize-arm64.sh = tools\dockerize-arm64.sh EndProjectSection @@ -243,12 +243,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "109-dotnet-custom-webscrape EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "infra", "infra", "{B488168B-AD86-4CC5-9D89-324B6EB743D9}" ProjectSection(SolutionItems) = preProject + infra\AZD.md = infra\AZD.md infra\build-main.json.sh = infra\build-main.json.sh infra\main.bicep = infra\main.bicep infra\main.json = infra\main.json - infra\README.md = infra\README.md - infra\AZD.md = infra\AZD.md infra\main.parameters.json = infra\main.parameters.json + infra\README.md = infra\README.md EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "modules", 
"modules", "{C2D3A947-B6F9-4306-BD42-21D8D1F42750}" @@ -323,6 +323,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "212-dotnet-ollama", "exampl EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Ollama", "extensions\Ollama\Ollama\Ollama.csproj", "{F192513B-265B-4943-A2A9-44E23B15BA18}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Onnx.FunctionalTests", "extensions\ONNX\Onnx.FunctionalTests\Onnx.FunctionalTests.csproj", "{7BBD348E-CDD9-4462-B8C9-47613C5EC682}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Onnx", "extensions\ONNX\Onnx\Onnx.csproj", "{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -591,6 +595,13 @@ Global {F192513B-265B-4943-A2A9-44E23B15BA18}.Debug|Any CPU.Build.0 = Debug|Any CPU {F192513B-265B-4943-A2A9-44E23B15BA18}.Release|Any CPU.ActiveCfg = Release|Any CPU {F192513B-265B-4943-A2A9-44E23B15BA18}.Release|Any CPU.Build.0 = Release|Any CPU + {7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Release|Any CPU.ActiveCfg = Release|Any CPU + {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Debug|Any CPU.Build.0 = Debug|Any CPU + {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Release|Any CPU.ActiveCfg = Release|Any CPU + {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -685,6 +696,8 @@ Global {84AEC1DD-CBAE-400A-949C-91BA373C587D} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841} {B303885D-F64F-4EEB-B085-0014E863AF61} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841} {F192513B-265B-4943-A2A9-44E23B15BA18} = {155DA079-E267-49AF-973A-D1D44681970F} + 
{7BBD348E-CDD9-4462-B8C9-47613C5EC682} = {3C17F42B-CFC8-4900-8CFB-88936311E919} + {345DEF9B-6EE1-49DF-B46A-25E38CE9B151} = {155DA079-E267-49AF-973A-D1D44681970F} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {CC136C62-115C-41D1-B414-F9473EFF6EA8} diff --git a/extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj b/extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj new file mode 100644 index 000000000..6143d9130 --- /dev/null +++ b/extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj @@ -0,0 +1,36 @@ + + + + Microsoft.Onnx.FunctionalTests + Microsoft.Onnx.FunctionalTests + net8.0 + LatestMajor + true + enable + enable + false + $(NoWarn);KMEXP01; + + + + + + + + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + diff --git a/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs b/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs new file mode 100644 index 000000000..a2e491b6d --- /dev/null +++ b/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Diagnostics; +using System.Text; +using Microsoft.KernelMemory.AI; +using Microsoft.KernelMemory.AI.Onnx; +using Microsoft.KM.TestHelpers; +using Xunit.Abstractions; + +namespace Microsoft.Onnx.FunctionalTests; + +public sealed class OnnxTextGeneratorTest : BaseFunctionalTestCase +{ + private readonly OnnxTextGenerator _target; + private readonly Stopwatch _timer; + + public OnnxTextGeneratorTest( + IConfiguration cfg, + ITestOutputHelper output) : base(cfg, output) + { + this._timer = new Stopwatch(); + + this.OnnxConfig.Validate(); + this._target = new OnnxTextGenerator(this.OnnxConfig, loggerFactory: null); + + var modelDirectory = Path.GetFullPath(this.OnnxConfig.TextModelDir); + var modelFile = Directory.GetFiles(modelDirectory) + .FirstOrDefault(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase)); + + Console.WriteLine($"Using model {Path.GetFileNameWithoutExtension(modelFile)} from: {modelDirectory}"); + } + + [Fact] + [Trait("Category", "Onnx")] + public async Task ItGeneratesText() + { + var utcDate = DateTime.UtcNow.Date.ToString("MM/dd/yyyy"); + var systemPrompt = $"Following the format \"MM/dd/yyyy\", the current date is {utcDate}."; + var question = $"What is the current date?"; + var prompt = $"<|system|>{systemPrompt}<|end|><|user|>{question}<|end|><|assistant|>"; + + var options = new TextGenerationOptions(); + + // Act + this._timer.Restart(); + var tokens = this._target.GenerateTextAsync(prompt, options); + var result = new StringBuilder(); + await foreach (string token in tokens) + { + result.Append(token); + } + + this._timer.Stop(); + var answer = result.ToString(); + + // Assert + Console.WriteLine($"Model Output:\n=============================\n{answer}\n============================="); + Console.WriteLine($"Time: {this._timer.ElapsedMilliseconds / 1000} secs"); + Assert.Contains(utcDate.ToString(), answer, StringComparison.OrdinalIgnoreCase); + } + + protected override void 
Dispose(bool disposing) + { + base.Dispose(disposing); + this._target.Dispose(); + } +} diff --git a/extensions/ONNX/Onnx.FunctionalTests/Startup.cs b/extensions/ONNX/Onnx.FunctionalTests/Startup.cs new file mode 100644 index 000000000..fe9b40a95 --- /dev/null +++ b/extensions/ONNX/Onnx.FunctionalTests/Startup.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. + +/* IMPORTANT: the Startup class must be at the root of the namespace and + * the namespace must match exactly (required by Xunit.DependencyInjection) */ + +namespace Microsoft.Onnx.FunctionalTests; + +public class Startup +{ + public void ConfigureHost(IHostBuilder hostBuilder) + { + var config = new ConfigurationBuilder() + .AddJsonFile("appsettings.json") + .AddJsonFile("appsettings.development.json", optional: true) + .AddJsonFile("appsettings.Development.json", optional: true) + .AddUserSecrets() + .AddEnvironmentVariables() + .Build(); + + hostBuilder.ConfigureHostConfiguration(builder => builder.AddConfiguration(config)); + } +} diff --git a/extensions/ONNX/Onnx.FunctionalTests/Usings.cs b/extensions/ONNX/Onnx.FunctionalTests/Usings.cs new file mode 100644 index 000000000..38b0ca7bb --- /dev/null +++ b/extensions/ONNX/Onnx.FunctionalTests/Usings.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +global using Xunit; diff --git a/extensions/ONNX/Onnx.FunctionalTests/appsettings.json b/extensions/ONNX/Onnx.FunctionalTests/appsettings.json new file mode 100644 index 000000000..ec84442d4 --- /dev/null +++ b/extensions/ONNX/Onnx.FunctionalTests/appsettings.json @@ -0,0 +1,67 @@ +{ + "KernelMemory": { + "ServiceAuthorization": { + "Endpoint": "http://127.0.0.1:9001/", + "AccessKey": "" + }, + "Services": { + "Onnx": { + // Path to directory containing ONNX Model, e.g. 
"C:\\....\\Phi-3-mini-128k-instruct-onnx\\....\\cpu-int4-rtn-block-32" + "TextModelDir": "Z:\\tools\\LocalModels\\Phi-3-mini-128k-instruct-onnx\\cpu_and_mobile\\cpu-int4-rtn-block-32" + }, + "SimpleVectorDb": { + // Options: "Disk" or "Volatile". Volatile data is lost after each execution. + "StorageType": "Volatile", + // Directory where files are stored. + "Directory": "_vectors" + }, + "AzureAISearch": { + // "ApiKey" or "AzureIdentity". For other options see . + // AzureIdentity: use automatic AAD authentication mechanism. You can test locally + // using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET. + "Auth": "AzureIdentity", + "Endpoint": "https://<...>", + "APIKey": "" + }, + "Postgres": { + // Postgres instance connection string + "ConnectionString": "Host=localhost;Port=5432;Username=public;Password=;Database=public", + // Mandatory prefix to add to the name of table managed by KM, + // e.g. to exclude other tables in the same schema. + "TableNamePrefix": "tests-" + }, + "Qdrant": { + "Endpoint": "http://127.0.0.1:6333", + "APIKey": "" + }, + "OpenAI": { + // Name of the model used to generate text (text completion or chat completion) + "TextModel": "gpt-4o-mini", + // The max number of tokens supported by the text model. + "TextModelMaxTokenTotal": 16384, + // What type of text generation, by default autodetect using the model name. + // Possible values: "Auto", "TextCompletion", "Chat" + "TextGenerationType": "Auto", + // Name of the model used to generate text embeddings + "EmbeddingModel": "text-embedding-ada-002", + // The max number of tokens supported by the embedding model + // See https://platform.openai.com/docs/guides/embeddings/what-are-embeddings + "EmbeddingModelMaxTokenTotal": 8191, + // OpenAI API Key + "APIKey": "", + // OpenAI Organization ID (usually empty, unless you have multiple accounts on different orgs) + "OrgId": "", + // Endpoint to use. By default the system uses 'https://api.openai.com/v1'. 
+ // Change this to use proxies or services compatible with OpenAI HTTP protocol like LM Studio. + "Endpoint": "", + // How many times to retry in case of throttling + "MaxRetries": 10 + } + } + }, + "Logging": { + "LogLevel": { + "Default": "Information" + } + } +} \ No newline at end of file diff --git a/extensions/ONNX/Onnx/DependencyInjection.cs b/extensions/ONNX/Onnx/DependencyInjection.cs new file mode 100644 index 000000000..55e3b5018 --- /dev/null +++ b/extensions/ONNX/Onnx/DependencyInjection.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.KernelMemory.AI; +using Microsoft.KernelMemory.AI.Onnx; + +#pragma warning disable IDE0130 // reduce number of "using" statements +// ReSharper disable once CheckNamespace - reduce number of "using" statements +namespace Microsoft.KernelMemory; + +/// +/// Kernel Memory builder extensions +/// +public static partial class KernelMemoryBuilderExtensions +{ + public static IKernelMemoryBuilder WithOnnxTextGeneration( + this IKernelMemoryBuilder builder, + string modelPath, + uint maxTokenTotal, + ITextTokenizer? textTokenizer = null) + { + var config = new OnnxConfig + { + TextModelDir = modelPath + }; + + builder.Services.AddOnnxTextGeneration(config, textTokenizer); + + return builder; + } + + public static IKernelMemoryBuilder WithOnnxTextGeneration( + this IKernelMemoryBuilder builder, + OnnxConfig config, + ITextTokenizer? textTokenizer = null) + { + builder.Services.AddOnnxTextGeneration(config, textTokenizer); + return builder; + } +} + +/// +/// .NET IServiceCollection dependency injection extensions. +/// +public static partial class DependencyInjection +{ + public static IServiceCollection AddOnnxTextGeneration( + this IServiceCollection services, + OnnxConfig config, + ITextTokenizer? 
textTokenizer = null) + { + config.Validate(); + return services + .AddSingleton(serviceProvider => new OnnxTextGenerator( + config: config, + textTokenizer: textTokenizer, + loggerFactory: serviceProvider.GetService())); + } +} diff --git a/extensions/ONNX/Onnx/Onnx.csproj b/extensions/ONNX/Onnx/Onnx.csproj new file mode 100644 index 000000000..96ec73dff --- /dev/null +++ b/extensions/ONNX/Onnx/Onnx.csproj @@ -0,0 +1,32 @@ + + + + net8.0 + LatestMajor + Microsoft.KernelMemory.AI.Onnx + Microsoft.KernelMemory.AI.Onnx + $(NoWarn);KMEXP00;KMEXP01;CA1724; + + + + true + Microsoft.KernelMemory.AI.Onnx + ONNX LLM connector for Kernel Memory + Provide access to ONNX LLM models in Kernel Memory to generate text + ONNX, Memory, Kernel Memory, Semantic Memory, Episodic Memory, Declarative Memory, AI, Artificial Intelligence, Semantic Search, Memory DB + bin/$(Configuration)/$(TargetFramework)/$(AssemblyName).xml + + + + + + + + + + + + + + + diff --git a/extensions/ONNX/Onnx/OnnxConfig.cs b/extensions/ONNX/Onnx/OnnxConfig.cs new file mode 100644 index 000000000..4a0ce66fe --- /dev/null +++ b/extensions/ONNX/Onnx/OnnxConfig.cs @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.IO; +using System.Linq; + +#pragma warning disable IDE0130 // reduce number of "using" statements + +#pragma warning disable IDE0130 // reduce number of "using" statements +// ReSharper disable once CheckNamespace - reduce number of "using" statements +namespace Microsoft.KernelMemory; + +public class OnnxConfig +{ + /// + /// An enum representing the possible text generation search types used by OnnxTextGenerator. + /// See https://onnxruntime.ai/docs/genai/reference/config.html#search-combinations for more details. + /// + public enum OnnxSearchType + { + /// + /// A decoding algorithm that keeps track of the top K sequences at each step. It explores + /// multiple paths simultaneously, balancing exploration and exploitation. 
Often results in more + /// coherent and higher quality text generation than Greedy Search would. + /// + BeamSearch, + + /// + /// The default and simplest decoding algorithm. At each step, a token is selected with the highest + /// probability as the next word in the sequence. + /// + GreedySearch, + + /// + /// Combined Top-P (Nucleus) and Top-K Sampling: A decoding algorithm that samples from the top k tokens + /// with the highest probabilities, while also considering the smallest set of tokens whose cumulative + /// probability exceeds a threshold p. This approach dynamically balances diversity and coherence in + /// text generation by adjusting the sampling pool based on both fixed and cumulative probability criteria. + /// + TopN + } + + /// + /// Path to the directory containing the .ONNX file for Text Generation. + /// + public string TextModelDir { get; set; } = string.Empty; + + /// + /// The maximum length of the response that the model will generate. See https://onnxruntime.ai/docs/genai/reference/config.html + /// + public int MaxTokens { get; set; } = 2048; + + /// + /// The minimum length of the response that the model will generate. See https://onnxruntime.ai/docs/genai/reference/config.html + /// + public uint MinLength { get; set; } = 0; + + /// + /// The algorithm used in text generation. Defaults to GreedySearch. + /// + public OnnxSearchType SearchType { get; set; } = OnnxSearchType.GreedySearch; + + /// + /// The number of beams to apply when generating the output sequence using beam search. + /// If NumBeams=1, then generation is performed using greedy search. If NumBeans > 1, then + /// generation is performed using beam search. A null value implies using TopN search. + /// + public uint? NumBeams { get; set; } = 1; + + /// + /// Only includes the most probable tokens with probabilities that add up to P or higher. + /// Defaults to 1, which includes all of the tokens. Range is 0 to 1, exclusive of 0. 
+ /// + public double NucleusSampling { get; set; } = 1.0; + + /// + /// Whether to stop the beam search when at least NumBeams sentences are finished per batch or not. Defaults to false. + /// + public bool EarlyStopping { get; set; } = false; + + /// + /// The number of sequences (responses) to generate. Returns the sequences with the highest scores in order. + /// + public int ResultsPerPrompt { get; set; } = 1; + + /// + /// Only includes tokens that fall within the list of the K most probable tokens. Range is 1 to the vocabulary size. + /// Defaults to 50. + /// + public uint TopK { get; set; } = 50; + + /// + /// Discounts the scores of previously generated tokens if set to a value greater than 1. + /// Defaults to 1. + /// + public double RepetitionPenalty { get; set; } = 1.0; + + /// + /// Controls the length of the output generated. Value less than 1 encourages the generation + /// to produce shorter sequences. Values greater than 1 encourages longer sequences. Defaults to 1. + /// + public double LengthPenalty { get; set; } = 1.0; + + /// + /// Verify that the current state is valid. 
+ /// + public void Validate(bool allowIO = true) + { + if (string.IsNullOrEmpty(this.TextModelDir)) + { + throw new ConfigurationException($"Onnx: {nameof(this.TextModelDir)} is a required field."); + } + + var modelDir = Path.GetFullPath(this.TextModelDir); + + if (allowIO) + { + if (!Directory.Exists(modelDir)) + { + throw new ConfigurationException($"Onnx: {this.TextModelDir} does not exist."); + } + + if (Directory.GetFiles(modelDir).Length == 0) + { + throw new ConfigurationException($"Onnx: {this.TextModelDir} is an empty directory."); + } + + var modelFiles = Directory.GetFiles(modelDir) + .Where(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase)); + + if (modelFiles == null) + { + throw new ConfigurationException($"Onnx: {this.TextModelDir} does not contain a valid .ONNX model."); + } + } + + if (this.SearchType == OnnxSearchType.GreedySearch) + { + if (this.NumBeams != 1) + { + throw new ConfigurationException($"Onnx: {nameof(this.NumBeams)} is only used with Beam Search. Change {nameof(this.NumBeams)} to 1, or change {nameof(this.SearchType)} to BeamSearch."); + } + + if (this.EarlyStopping != false) + { + throw new ConfigurationException($"Onnx: {nameof(this.EarlyStopping)} is only used with Beam Search. Change {nameof(this.EarlyStopping)} to false, or change {nameof(this.SearchType)} to BeamSearch."); + } + } + + if (this.SearchType == OnnxSearchType.BeamSearch) + { + if (this.NumBeams == null) + { + throw new ConfigurationException($"Onnx: {nameof(this.NumBeams)} is required for Beam Search. Change {nameof(this.NumBeams)} to a value >= 1, or change the {nameof(this.SearchType)}."); + } + } + + if (this.SearchType == OnnxSearchType.TopN) + { + if (this.NumBeams != null) + { + throw new ConfigurationException($"Onnx: {nameof(this.NumBeams)} isn't required with TopN Search. 
Change {nameof(this.NumBeams)} to null, or change the {nameof(this.SearchType)}."); + } + + if (this.EarlyStopping != false) + { + throw new ConfigurationException($"Onnx: {nameof(this.EarlyStopping)} is only used with Beam Search. Change {nameof(this.EarlyStopping)} to false, or change {nameof(this.SearchType)} to BeamSearch."); + } + } + } +} diff --git a/extensions/ONNX/Onnx/OnnxTextGenerator.cs b/extensions/ONNX/Onnx/OnnxTextGenerator.cs new file mode 100644 index 000000000..f0dbab49c --- /dev/null +++ b/extensions/ONNX/Onnx/OnnxTextGenerator.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.KernelMemory.AI.OpenAI; +using Microsoft.KernelMemory.Diagnostics; +using Microsoft.ML.OnnxRuntimeGenAI; +using static Microsoft.KernelMemory.OnnxConfig; + +namespace Microsoft.KernelMemory.AI.Onnx; + +/// +/// Text generator based on ONNX models, via OnnxRuntimeGenAi +/// See https://github.com/microsoft/onnxruntime-genai +/// +[Experimental("KMEXP01")] +public sealed class OnnxTextGenerator : ITextGenerator, IDisposable +{ + /// + /// The ONNX Model used for text generation + /// + private readonly Model _model; + + /// + /// Tokenizer used with the Onnx Generator and Model classes to produce tokens. + /// This has the potential to contain a null value, depending on the contents of the Model Directory. + /// + private readonly Tokenizer? 
_tokenizer = default; + + /// + /// Tokenizer used for GetTokens() and CountTokens() + /// + private readonly ITextTokenizer _textTokenizer; + + private readonly ILogger _log; + + private readonly OnnxConfig _config; + + /// + public int MaxTokenTotal { get; internal set; } + + /// + /// Create a new instance + /// + /// Configuration settings + /// Text Tokenizer + /// Application Logger instance + public OnnxTextGenerator( + OnnxConfig config, + ITextTokenizer? textTokenizer = null, + ILoggerFactory? loggerFactory = null) + { + this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger(); + if (textTokenizer == null) + { + this._log.LogWarning( + "Tokenizer not specified, will use {0}. The token count might be incorrect, causing unexpected errors", + nameof(GPT4oTokenizer)); + textTokenizer = new GPT4oTokenizer(); + } + + config.Validate(); + this._config = config; + this.MaxTokenTotal = (int)config.MaxTokens; + this._textTokenizer = textTokenizer; + + var modelDir = Path.GetFullPath(config.TextModelDir); + var modelFile = Directory.GetFiles(modelDir) + .FirstOrDefault(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase)); + + this._log.LogDebug("Loading Onnx model: {1} from directory {0}", modelDir, Path.GetFileNameWithoutExtension(modelFile)); + this._model = new Model(config.TextModelDir); + this._tokenizer = new Tokenizer(this._model); + this._log.LogDebug("Onnx model loaded"); + } + + /// + public async IAsyncEnumerable GenerateTextAsync( + string prompt, + TextGenerationOptions? 
options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var tokens = this._tokenizer?.Encode(prompt); + using var generatorParams = new GeneratorParams(this._model); + + generatorParams.SetSearchOption("max_length", this.MaxTokenTotal); + generatorParams.SetSearchOption("min_length", this._config.MinLength); + generatorParams.SetSearchOption("num_return_sequences", this._config.ResultsPerPrompt); + generatorParams.SetSearchOption("repetition_penalty", this._config.RepetitionPenalty); + generatorParams.SetSearchOption("length_penalty", this._config.LengthPenalty); + generatorParams.SetSearchOption("temperature", 0); + + if (options != null) + { + generatorParams.SetSearchOption("num_return_sequences", options.ResultsPerPrompt); + generatorParams.SetSearchOption("temperature", options.Temperature); + + if (options.MaxTokens > 0) + { + generatorParams.SetSearchOption("max_length", (int)options.MaxTokens); + } + } + + switch (this._config.SearchType) + { + case OnnxSearchType.BeamSearch: + generatorParams.SetSearchOption("do_sample", false); + generatorParams.SetSearchOption("early_stopping", this._config.EarlyStopping); + + if (this._config.NumBeams != null) + { + generatorParams.SetSearchOption("num_beams", (double)this._config.NumBeams); + } + + break; + + case OnnxSearchType.TopN: + generatorParams.SetSearchOption("do_sample", true); + generatorParams.SetSearchOption("top_k", this._config.TopK); + + generatorParams.SetSearchOption("top_p", options is { NucleusSampling: > 0 and <= 1 } + ? 
options.NucleusSampling + : this._config.NucleusSampling); + + break; + + default: + + generatorParams.SetSearchOption("do_sample", false); + + if (this._config.NumBeams != null) + { + generatorParams.SetSearchOption("num_beams", (double)this._config.NumBeams); + } + + break; + } + + generatorParams.SetInputSequences(tokens); + + using (var generator = new Generator(this._model, generatorParams)) + { + List outputTokens = new(); + + while (!generator.IsDone() && cancellationToken.IsCancellationRequested == false) + { + generator.ComputeLogits(); + generator.GenerateNextToken(); + + outputTokens.AddRange(generator.GetSequence(0)); + + if (outputTokens.Count > 0 && this._tokenizer != null) + { + var newToken = outputTokens[^1]; + yield return this._tokenizer.Decode(new int[] { newToken }); + } + } + } + + await Task.CompletedTask.ConfigureAwait(false); + } + + /// + public int CountTokens(string text) + { + // TODO: Implement with _tokenizer and remove _textTokenizer + return this._textTokenizer.CountTokens(text); + } + + /// + public IReadOnlyList GetTokens(string text) + { + // TODO: Implement with _tokenizer and remove _textTokenizer + return this._textTokenizer.GetTokens(text); + } + + /// + public void Dispose() + { + this._model?.Dispose(); + this._tokenizer?.Dispose(); + } +} diff --git a/extensions/ONNX/README.md b/extensions/ONNX/README.md new file mode 100644 index 000000000..2ab365705 --- /dev/null +++ b/extensions/ONNX/README.md @@ -0,0 +1,25 @@ +# Kernel Memory with ONNX + +[![Nuget package](https://img.shields.io/nuget/v/Microsoft.KernelMemory.AI.Onnx)](https://www.nuget.org/packages/Microsoft.KernelMemory.AI.Onnx/) +[![Discord](https://img.shields.io/discord/1063152441819942922?label=Discord&logo=discord&logoColor=white&color=d82679)](https://aka.ms/KMdiscord) + +This project contains the +[ONNX](https://onnxruntime.ai/docs/genai/) +LLM connector to access to LLM models via Onnx service to generate text. 
+ +Sample code: + +```csharp +var config = new OnnxConfig +{ + TextModelDir = "C:\\....\\Phi-3-mini-128k-instruct-onnx\\....\\cpu-int4-rtn-block-32" +}; + +var memory = new KernelMemoryBuilder() + .WithOnnxTextGeneration(config) + .Build(); + +await memory.ImportTextAsync("Today is October 32nd, 2476"); + +var answer = await memory.AskAsync("What's the current date (don't check for validity)?"); +``` diff --git a/service/tests/TestHelpers/BaseFunctionalTestCase.cs b/service/tests/TestHelpers/BaseFunctionalTestCase.cs index 856bde08c..d73110713 100644 --- a/service/tests/TestHelpers/BaseFunctionalTestCase.cs +++ b/service/tests/TestHelpers/BaseFunctionalTestCase.cs @@ -31,6 +31,7 @@ public abstract class BaseFunctionalTestCase : IDisposable protected readonly SimpleVectorDbConfig SimpleVectorDbConfig; protected readonly LlamaSharpConfig LlamaSharpConfig; protected readonly ElasticsearchConfig ElasticsearchConfig; + protected readonly OnnxConfig OnnxConfig; // IMPORTANT: install Xunit.DependencyInjection package protected BaseFunctionalTestCase(IConfiguration cfg, ITestOutputHelper output) @@ -50,6 +51,7 @@ protected BaseFunctionalTestCase(IConfiguration cfg, ITestOutputHelper output) this.SimpleVectorDbConfig = cfg.GetSection("KernelMemory:Services:SimpleVectorDb").Get() ?? new(); this.LlamaSharpConfig = cfg.GetSection("KernelMemory:Services:LlamaSharp").Get() ?? new(); this.ElasticsearchConfig = cfg.GetSection("KernelMemory:Services:Elasticsearch").Get() ?? new(); + this.OnnxConfig = cfg.GetSection("KernelMemory:Services:Onnx").Get() ?? new(); } protected IKernelMemory GetMemoryWebClient() diff --git a/service/tests/TestHelpers/TestHelpers.csproj b/service/tests/TestHelpers/TestHelpers.csproj index 3996f22f9..b0b991041 100644 --- a/service/tests/TestHelpers/TestHelpers.csproj +++ b/service/tests/TestHelpers/TestHelpers.csproj @@ -16,6 +16,7 @@ +