From 59e9f5d6d56ce6a547a9600efe817a88c9a98582 Mon Sep 17 00:00:00 2001 From: Justin Ridings <49916830+JustinRidings@users.noreply.github.com> Date: Fri, 4 Oct 2024 21:09:21 -0700 Subject: [PATCH] Added initial implementation of Microsoft.KernelMemory.AI.ONNX package (#788) ## Motivation and Context (Why the change? What's the scenario?) This change includes an entirely new extension that allows TextGeneration with ONNX Models. This scenario allows users to interact with local cloned ONNX Models ## High level description (Approach, Design) This is a scenario that abstracts the general implementation of the OnnxRuntimeGenAI library to work with text generation. More examples on using OnnxGenAI Libraries, please see https://github.com/microsoft/onnxruntime-genai --------- Co-authored-by: Justin Ridings Co-authored-by: Devis Lucato --- Directory.Packages.props | 2 +- KernelMemory.sln | 29 ++- .../Onnx.FunctionalTests.csproj | 36 ++++ .../OnnxTextGeneratorTest.cs | 67 ++++++ .../ONNX/Onnx.FunctionalTests/Startup.cs | 22 ++ .../ONNX/Onnx.FunctionalTests/Usings.cs | 3 + .../Onnx.FunctionalTests/appsettings.json | 67 ++++++ extensions/ONNX/Onnx/DependencyInjection.cs | 60 ++++++ extensions/ONNX/Onnx/Onnx.csproj | 32 +++ extensions/ONNX/Onnx/OnnxConfig.cs | 171 ++++++++++++++++ extensions/ONNX/Onnx/OnnxTextGenerator.cs | 190 ++++++++++++++++++ extensions/ONNX/README.md | 25 +++ .../TestHelpers/BaseFunctionalTestCase.cs | 2 + service/tests/TestHelpers/TestHelpers.csproj | 1 + 14 files changed, 698 insertions(+), 9 deletions(-) create mode 100644 extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj create mode 100644 extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs create mode 100644 extensions/ONNX/Onnx.FunctionalTests/Startup.cs create mode 100644 extensions/ONNX/Onnx.FunctionalTests/Usings.cs create mode 100644 extensions/ONNX/Onnx.FunctionalTests/appsettings.json create mode 100644 extensions/ONNX/Onnx/DependencyInjection.cs create mode 100644 extensions/ONNX/Onnx/Onnx.csproj create mode 100644 extensions/ONNX/Onnx/OnnxConfig.cs create mode 100644 extensions/ONNX/Onnx/OnnxTextGenerator.cs create mode 100644 extensions/ONNX/README.md diff --git a/Directory.Packages.props b/Directory.Packages.props index 41efaafd0..2834bf71d 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -34,7 +34,7 @@ - + diff --git a/KernelMemory.sln b/KernelMemory.sln index 15db02fe8..7dc6a100e 100644 --- a/KernelMemory.sln +++ b/KernelMemory.sln @@ -7,6 +7,7 @@ EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{7BA7F1B2-19E2-46EB-B000-513EE2F65769}" ProjectSection(SolutionItems) = preProject docs\404.html = docs\404.html + docs\azure.md = docs\azure.md docs\concepts.md = docs\concepts.md docs\csharp.png = docs\csharp.png docs\extensions.md = docs\extensions.md @@ -31,7 +32,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{7BA7F1B2-1 docs\service.md = docs\service.md docs\_config.local.yml = docs\_config.local.yml docs\_config.yml = docs\_config.yml - docs\azure.md = docs\azure.md EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "examples", "examples", "{0A43C65C-6007-4BB4-B3FE-8D439FC91841}" @@ -70,18 +70,18 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{6EF76FD8-4 .editorconfig = .editorconfig .gitattributes = .gitattributes .gitignore = .gitignore + azure.yaml = azure.yaml CODE_OF_CONDUCT.md = CODE_OF_CONDUCT.md COMMUNITY.md = COMMUNITY.md CONTRIBUTING.md = CONTRIBUTING.md Directory.Build.props = Directory.Build.props Directory.Packages.props = Directory.Packages.props Dockerfile = Dockerfile + KernelMemory.sln.DotSettings = KernelMemory.sln.DotSettings LICENSE = LICENSE nuget.config = nuget.config README.md = README.md SECURITY.md = SECURITY.md - KernelMemory.sln.DotSettings = KernelMemory.sln.DotSettings - azure.yaml = azure.yaml EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = ".github", ".github", "{B8976338-7CDC-47AE-8502-C2FBAFBEBD68}" @@ -94,10 +94,10 @@ EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "workflows", "workflows", "{48E79819-1E9E-4075-90DA-BAEC761C89B2}" ProjectSection(SolutionItems) = preProject .github\workflows\docker-build-push.yml = .github\workflows\docker-build-push.yml + .github\workflows\dotnet-build.yml = .github\workflows\dotnet-build.yml + .github\workflows\dotnet-unit-tests.yml = .github\workflows\dotnet-unit-tests.yml .github\workflows\github-pages-jekyll.yml = .github\workflows\github-pages-jekyll.yml .github\workflows\spell-check-with-typos.yml = .github\workflows\spell-check-with-typos.yml - .github\workflows\dotnet-unit-tests.yml = .github\workflows\dotnet-unit-tests.yml - .github\workflows\dotnet-build.yml = .github\workflows\dotnet-build.yml EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "service", "service", "{87DEAE8D-138C-4FDD-B4C9-11C3A7817E8F}" @@ -116,9 +116,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{CA49F1A1 tools\run-qdrant.sh = tools\run-qdrant.sh tools\run-rabbitmq.sh = tools\run-rabbitmq.sh tools\run-redis.sh = tools\run-redis.sh + tools\run-s3ninja.sh = tools\run-s3ninja.sh tools\search.sh = tools\search.sh tools\upload-file.sh = tools\upload-file.sh - tools\run-s3ninja.sh = tools\run-s3ninja.sh tools\dockerize-amd64.sh = tools\dockerize-amd64.sh tools\dockerize-arm64.sh = tools\dockerize-arm64.sh EndProjectSection @@ -243,12 +243,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "109-dotnet-custom-webscrape EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "infra", "infra", "{B488168B-AD86-4CC5-9D89-324B6EB743D9}" ProjectSection(SolutionItems) = preProject + infra\AZD.md = infra\AZD.md infra\build-main.json.sh = infra\build-main.json.sh infra\main.bicep = infra\main.bicep infra\main.json = infra\main.json - infra\README.md = infra\README.md - infra\AZD.md = infra\AZD.md infra\main.parameters.json = infra\main.parameters.json + infra\README.md = infra\README.md EndProjectSection EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "modules", "modules", "{C2D3A947-B6F9-4306-BD42-21D8D1F42750}" @@ -323,6 +323,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "212-dotnet-ollama", "exampl EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Ollama", "extensions\Ollama\Ollama\Ollama.csproj", "{F192513B-265B-4943-A2A9-44E23B15BA18}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Onnx.FunctionalTests", "extensions\ONNX\Onnx.FunctionalTests\Onnx.FunctionalTests.csproj", "{7BBD348E-CDD9-4462-B8C9-47613C5EC682}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Onnx", "extensions\ONNX\Onnx\Onnx.csproj", "{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -591,6 +595,13 @@ Global {F192513B-265B-4943-A2A9-44E23B15BA18}.Debug|Any CPU.Build.0 = Debug|Any CPU {F192513B-265B-4943-A2A9-44E23B15BA18}.Release|Any CPU.ActiveCfg = Release|Any CPU {F192513B-265B-4943-A2A9-44E23B15BA18}.Release|Any CPU.Build.0 = Release|Any CPU + {7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Debug|Any CPU.Build.0 = Debug|Any CPU + {7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Release|Any CPU.ActiveCfg = Release|Any CPU + {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Debug|Any CPU.Build.0 = Debug|Any CPU + {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Release|Any CPU.ActiveCfg = Release|Any CPU + {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -685,6 +696,8 @@ Global {84AEC1DD-CBAE-400A-949C-91BA373C587D} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841} {B303885D-F64F-4EEB-B085-0014E863AF61} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841} {F192513B-265B-4943-A2A9-44E23B15BA18} = {155DA079-E267-49AF-973A-D1D44681970F} + {7BBD348E-CDD9-4462-B8C9-47613C5EC682} = {3C17F42B-CFC8-4900-8CFB-88936311E919} + {345DEF9B-6EE1-49DF-B46A-25E38CE9B151} = {155DA079-E267-49AF-973A-D1D44681970F} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {CC136C62-115C-41D1-B414-F9473EFF6EA8} diff --git a/extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj b/extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj new file mode 100644 index 000000000..6143d9130 --- /dev/null +++ b/extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj @@ -0,0 +1,36 @@ + + + + Microsoft.Onnx.FunctionalTests + Microsoft.Onnx.FunctionalTests + net8.0 + LatestMajor + true + enable + enable + false + $(NoWarn);KMEXP01; + + + + + + + + + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + diff --git a/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs b/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs new file mode 100644 index 000000000..a2e491b6d --- /dev/null +++ b/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics; +using System.Text; +using Microsoft.KernelMemory.AI; +using Microsoft.KernelMemory.AI.Onnx; +using Microsoft.KM.TestHelpers; +using Xunit.Abstractions; + +namespace Microsoft.Onnx.FunctionalTests; + +public sealed class OnnxTextGeneratorTest : BaseFunctionalTestCase +{ + private readonly OnnxTextGenerator _target; + private readonly Stopwatch _timer; + + public OnnxTextGeneratorTest( + IConfiguration cfg, + ITestOutputHelper output) : base(cfg, output) + { + this._timer = new Stopwatch(); + + this.OnnxConfig.Validate(); + this._target = new OnnxTextGenerator(this.OnnxConfig, loggerFactory: null); + + var modelDirectory = Path.GetFullPath(this.OnnxConfig.TextModelDir); + var modelFile = Directory.GetFiles(modelDirectory) + .FirstOrDefault(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase)); + + Console.WriteLine($"Using model {Path.GetFileNameWithoutExtension(modelFile)} from: {modelDirectory}"); + } + + [Fact] + [Trait("Category", "Onnx")] + public async Task ItGeneratesText() + { + var utcDate = DateTime.UtcNow.Date.ToString("MM/dd/yyyy"); + var systemPrompt = $"Following the format \"MM/dd/yyyy\", the current date is {utcDate}."; + var question = $"What is the current date?"; + var prompt = $"<|system|>{systemPrompt}<|end|><|user|>{question}<|end|><|assistant|>"; + + var options = new TextGenerationOptions(); + + // Act + this._timer.Restart(); + var tokens = this._target.GenerateTextAsync(prompt, options); + var result = new StringBuilder(); + await foreach (string token in tokens) + { + result.Append(token); + } + + this._timer.Stop(); + var answer = result.ToString(); + + // Assert + Console.WriteLine($"Model Output:\n=============================\n{answer}\n============================="); + Console.WriteLine($"Time: {this._timer.ElapsedMilliseconds / 1000} secs"); + Assert.Contains(utcDate.ToString(), answer, StringComparison.OrdinalIgnoreCase); + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + this._target.Dispose(); + } +} diff --git a/extensions/ONNX/Onnx.FunctionalTests/Startup.cs b/extensions/ONNX/Onnx.FunctionalTests/Startup.cs new file mode 100644 index 000000000..fe9b40a95 --- /dev/null +++ b/extensions/ONNX/Onnx.FunctionalTests/Startup.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. + +/* IMPORTANT: the Startup class must be at the root of the namespace and + * the namespace must match exactly (required by Xunit.DependencyInjection) */ + +namespace Microsoft.Onnx.FunctionalTests; + +public class Startup +{ + public void ConfigureHost(IHostBuilder hostBuilder) + { + var config = new ConfigurationBuilder() + .AddJsonFile("appsettings.json") + .AddJsonFile("appsettings.development.json", optional: true) + .AddJsonFile("appsettings.Development.json", optional: true) + .AddUserSecrets() + .AddEnvironmentVariables() + .Build(); + + hostBuilder.ConfigureHostConfiguration(builder => builder.AddConfiguration(config)); + } +} diff --git a/extensions/ONNX/Onnx.FunctionalTests/Usings.cs b/extensions/ONNX/Onnx.FunctionalTests/Usings.cs new file mode 100644 index 000000000..38b0ca7bb --- /dev/null +++ b/extensions/ONNX/Onnx.FunctionalTests/Usings.cs @@ -0,0 +1,3 @@ +// Copyright (c) Microsoft. All rights reserved. + +global using Xunit; diff --git a/extensions/ONNX/Onnx.FunctionalTests/appsettings.json b/extensions/ONNX/Onnx.FunctionalTests/appsettings.json new file mode 100644 index 000000000..ec84442d4 --- /dev/null +++ b/extensions/ONNX/Onnx.FunctionalTests/appsettings.json @@ -0,0 +1,67 @@ +{ + "KernelMemory": { + "ServiceAuthorization": { + "Endpoint": "http://127.0.0.1:9001/", + "AccessKey": "" + }, + "Services": { + "Onnx": { + // Path to directory containing ONNX Model, e.g. "C:\\....\\Phi-3-mini-128k-instruct-onnx\\....\\cpu-int4-rtn-block-32" + "TextModelDir": "Z:\\tools\\LocalModels\\Phi-3-mini-128k-instruct-onnx\\cpu_and_mobile\\cpu-int4-rtn-block-32" + }, + "SimpleVectorDb": { + // Options: "Disk" or "Volatile". Volatile data is lost after each execution. + "StorageType": "Volatile", + // Directory where files are stored. + "Directory": "_vectors" + }, + "AzureAISearch": { + // "ApiKey" or "AzureIdentity". For other options see . + // AzureIdentity: use automatic AAD authentication mechanism. You can test locally + // using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET. + "Auth": "AzureIdentity", + "Endpoint": "https://<...>", + "APIKey": "" + }, + "Postgres": { + // Postgres instance connection string + "ConnectionString": "Host=localhost;Port=5432;Username=public;Password=;Database=public", + // Mandatory prefix to add to the name of table managed by KM, + // e.g. to exclude other tables in the same schema. + "TableNamePrefix": "tests-" + }, + "Qdrant": { + "Endpoint": "http://127.0.0.1:6333", + "APIKey": "" + }, + "OpenAI": { + // Name of the model used to generate text (text completion or chat completion) + "TextModel": "gpt-4o-mini", + // The max number of tokens supported by the text model. + "TextModelMaxTokenTotal": 16384, + // What type of text generation, by default autodetect using the model name. + // Possible values: "Auto", "TextCompletion", "Chat" + "TextGenerationType": "Auto", + // Name of the model used to generate text embeddings + "EmbeddingModel": "text-embedding-ada-002", + // The max number of tokens supported by the embedding model + // See https://platform.openai.com/docs/guides/embeddings/what-are-embeddings + "EmbeddingModelMaxTokenTotal": 8191, + // OpenAI API Key + "APIKey": "", + // OpenAI Organization ID (usually empty, unless you have multiple accounts on different orgs) + "OrgId": "", + // Endpoint to use. By default the system uses 'https://api.openai.com/v1'. + // Change this to use proxies or services compatible with OpenAI HTTP protocol like LM Studio. + "Endpoint": "", + // How many times to retry in case of throttling + "MaxRetries": 10 + } + } + }, + "Logging": { + "LogLevel": { + "Default": "Information" + } + } +} \ No newline at end of file diff --git a/extensions/ONNX/Onnx/DependencyInjection.cs b/extensions/ONNX/Onnx/DependencyInjection.cs new file mode 100644 index 000000000..55e3b5018 --- /dev/null +++ b/extensions/ONNX/Onnx/DependencyInjection.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.KernelMemory.AI; +using Microsoft.KernelMemory.AI.Onnx; + +#pragma warning disable IDE0130 // reduce number of "using" statements +// ReSharper disable once CheckNamespace - reduce number of "using" statements +namespace Microsoft.KernelMemory; + +/// +/// Kernel Memory builder extensions +/// +public static partial class KernelMemoryBuilderExtensions +{ + public static IKernelMemoryBuilder WithOnnxTextGeneration( + this IKernelMemoryBuilder builder, + string modelPath, + uint maxTokenTotal, + ITextTokenizer? textTokenizer = null) + { + var config = new OnnxConfig + { + TextModelDir = modelPath + }; + + builder.Services.AddOnnxTextGeneration(config, textTokenizer); + + return builder; + } + + public static IKernelMemoryBuilder WithOnnxTextGeneration( + this IKernelMemoryBuilder builder, + OnnxConfig config, + ITextTokenizer? textTokenizer = null) + { + builder.Services.AddOnnxTextGeneration(config, textTokenizer); + return builder; + } +} + +/// +/// .NET IServiceCollection dependency injection extensions. +/// +public static partial class DependencyInjection +{ + public static IServiceCollection AddOnnxTextGeneration( + this IServiceCollection services, + OnnxConfig config, + ITextTokenizer? textTokenizer = null) + { + config.Validate(); + return services + .AddSingleton(serviceProvider => new OnnxTextGenerator( + config: config, + textTokenizer: textTokenizer, + loggerFactory: serviceProvider.GetService())); + } +} diff --git a/extensions/ONNX/Onnx/Onnx.csproj b/extensions/ONNX/Onnx/Onnx.csproj new file mode 100644 index 000000000..96ec73dff --- /dev/null +++ b/extensions/ONNX/Onnx/Onnx.csproj @@ -0,0 +1,32 @@ + + + + net8.0 + LatestMajor + Microsoft.KernelMemory.AI.Onnx + Microsoft.KernelMemory.AI.Onnx + $(NoWarn);KMEXP00;KMEXP01;CA1724; + + + + true + Microsoft.KernelMemory.AI.Onnx + ONNX LLM connector for Kernel Memory + Provide access to ONNX LLM models in Kernel Memory to generate text + ONNX, Memory, Kernel Memory, Semantic Memory, Episodic Memory, Declarative Memory, AI, Artificial Intelligence, Semantic Search, Memory DB + bin/$(Configuration)/$(TargetFramework)/$(AssemblyName).xml + + + + + + + + + + + + + + + diff --git a/extensions/ONNX/Onnx/OnnxConfig.cs b/extensions/ONNX/Onnx/OnnxConfig.cs new file mode 100644 index 000000000..4a0ce66fe --- /dev/null +++ b/extensions/ONNX/Onnx/OnnxConfig.cs @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.IO; +using System.Linq; + +#pragma warning disable IDE0130 // reduce number of "using" statements + +#pragma warning disable IDE0130 // reduce number of "using" statements +// ReSharper disable once CheckNamespace - reduce number of "using" statements +namespace Microsoft.KernelMemory; + +public class OnnxConfig +{ + /// + /// An enum representing the possible text generation search types used by OnnxTextGenerator. + /// See https://onnxruntime.ai/docs/genai/reference/config.html#search-combinations for more details. + /// + public enum OnnxSearchType + { + /// + /// A decoding algorithm that keeps track of the top K sequences at each step. It explores + /// multiple paths simultaneously, balancing exploration and exploitation. Often results in more + /// coherent and higher quality text generation than Greedy Search would. + /// + BeamSearch, + + /// + /// The default and simplest decoding algorithm. At each step, a token is selected with the highest + /// probability as the next word in the sequence. + /// + GreedySearch, + + /// + /// Combined Top-P (Nucleus) and Top-K Sampling: A decoding algorithm that samples from the top k tokens + /// with the highest probabilities, while also considering the smallest set of tokens whose cumulative + /// probability exceeds a threshold p. This approach dynamically balances diversity and coherence in + /// text generation by adjusting the sampling pool based on both fixed and cumulative probability criteria. + /// + TopN + } + + /// + /// Path to the directory containing the .ONNX file for Text Generation. + /// + public string TextModelDir { get; set; } = string.Empty; + + /// + /// The maximum length of the response that the model will generate. See https://onnxruntime.ai/docs/genai/reference/config.html + /// + public int MaxTokens { get; set; } = 2048; + + /// + /// The minimum length of the response that the model will generate. See https://onnxruntime.ai/docs/genai/reference/config.html + /// + public uint MinLength { get; set; } = 0; + + /// + /// The algorithm used in text generation. Defaults to GreedySearch. + /// + public OnnxSearchType SearchType { get; set; } = OnnxSearchType.GreedySearch; + + /// + /// The number of beams to apply when generating the output sequence using beam search. + /// If NumBeams=1, then generation is performed using greedy search. If NumBeans > 1, then + /// generation is performed using beam search. A null value implies using TopN search. + /// + public uint? NumBeams { get; set; } = 1; + + /// + /// Only includes the most probable tokens with probabilities that add up to P or higher. + /// Defaults to 1, which includes all of the tokens. Range is 0 to 1, exclusive of 0. + /// + public double NucleusSampling { get; set; } = 1.0; + + /// + /// Whether to stop the beam search when at least NumBeams sentences are finished per batch or not. Defaults to false. + /// + public bool EarlyStopping { get; set; } = false; + + /// + /// The number of sequences (responses) to generate. Returns the sequences with the highest scores in order. + /// + public int ResultsPerPrompt { get; set; } = 1; + + /// + /// Only includes tokens that fall within the list of the K most probable tokens. Range is 1 to the vocabulary size. + /// Defaults to 50. + /// + public uint TopK { get; set; } = 50; + + /// + /// Discounts the scores of previously generated tokens if set to a value greater than 1. + /// Defaults to 1. + /// + public double RepetitionPenalty { get; set; } = 1.0; + + /// + /// Controls the length of the output generated. Value less than 1 encourages the generation + /// to produce shorter sequences. Values greater than 1 encourages longer sequences. Defaults to 1. + /// + public double LengthPenalty { get; set; } = 1.0; + + /// + /// Verify that the current state is valid. + /// + public void Validate(bool allowIO = true) + { + if (string.IsNullOrEmpty(this.TextModelDir)) + { + throw new ConfigurationException($"Onnx: {nameof(this.TextModelDir)} is a required field."); + } + + var modelDir = Path.GetFullPath(this.TextModelDir); + + if (allowIO) + { + if (!Directory.Exists(modelDir)) + { + throw new ConfigurationException($"Onnx: {this.TextModelDir} does not exist."); + } + + if (Directory.GetFiles(modelDir).Length == 0) + { + throw new ConfigurationException($"Onnx: {this.TextModelDir} is an empty directory."); + } + + var modelFiles = Directory.GetFiles(modelDir) + .Where(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase)); + + if (modelFiles == null) + { + throw new ConfigurationException($"Onnx: {this.TextModelDir} does not contain a valid .ONNX model."); + } + } + + if (this.SearchType == OnnxSearchType.GreedySearch) + { + if (this.NumBeams != 1) + { + throw new ConfigurationException($"Onnx: {nameof(this.NumBeams)} is only used with Beam Search. Change {nameof(this.NumBeams)} to 1, or change {nameof(this.SearchType)} to BeamSearch."); + } + + if (this.EarlyStopping != false) + { + throw new ConfigurationException($"Onnx: {nameof(this.EarlyStopping)} is only used with Beam Search. Change {nameof(this.EarlyStopping)} to false, or change {nameof(this.SearchType)} to BeamSearch."); + } + } + + if (this.SearchType == OnnxSearchType.BeamSearch) + { + if (this.NumBeams == null) + { + throw new ConfigurationException($"Onnx: {nameof(this.NumBeams)} is required for Beam Search. Change {nameof(this.NumBeams)} to a value >= 1, or change the {nameof(this.SearchType)}."); + } + } + + if (this.SearchType == OnnxSearchType.TopN) + { + if (this.NumBeams != null) + { + throw new ConfigurationException($"Onnx: {nameof(this.NumBeams)} isn't required with TopN Search. Change {nameof(this.NumBeams)} to null, or change the {nameof(this.SearchType)}."); + } + + if (this.EarlyStopping != false) + { + throw new ConfigurationException($"Onnx: {nameof(this.EarlyStopping)} is only used with Beam Search. Change {nameof(this.EarlyStopping)} to false, or change {nameof(this.SearchType)} to BeamSearch."); + } + } + } +} diff --git a/extensions/ONNX/Onnx/OnnxTextGenerator.cs b/extensions/ONNX/Onnx/OnnxTextGenerator.cs new file mode 100644 index 000000000..f0dbab49c --- /dev/null +++ b/extensions/ONNX/Onnx/OnnxTextGenerator.cs @@ -0,0 +1,190 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.KernelMemory.AI.OpenAI; +using Microsoft.KernelMemory.Diagnostics; +using Microsoft.ML.OnnxRuntimeGenAI; +using static Microsoft.KernelMemory.OnnxConfig; + +namespace Microsoft.KernelMemory.AI.Onnx; + +/// +/// Text generator based on ONNX models, via OnnxRuntimeGenAi +/// See https://github.com/microsoft/onnxruntime-genai +/// +[Experimental("KMEXP01")] +public sealed class OnnxTextGenerator : ITextGenerator, IDisposable +{ + /// + /// The ONNX Model used for text generation + /// + private readonly Model _model; + + /// + /// Tokenizer used with the Onnx Generator and Model classes to produce tokens. + /// This has the potential to contain a null value, depending on the contents of the Model Directory. + /// + private readonly Tokenizer? _tokenizer = default; + + /// + /// Tokenizer used for GetTokens() and CountTokens() + /// + private readonly ITextTokenizer _textTokenizer; + + private readonly ILogger _log; + + private readonly OnnxConfig _config; + + /// + public int MaxTokenTotal { get; internal set; } + + /// + /// Create a new instance + /// + /// Configuration settings + /// Text Tokenizer + /// Application Logger instance + public OnnxTextGenerator( + OnnxConfig config, + ITextTokenizer? textTokenizer = null, + ILoggerFactory? loggerFactory = null) + { + this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger(); + if (textTokenizer == null) + { + this._log.LogWarning( + "Tokenizer not specified, will use {0}. The token count might be incorrect, causing unexpected errors", + nameof(GPT4oTokenizer)); + textTokenizer = new GPT4oTokenizer(); + } + + config.Validate(); + this._config = config; + this.MaxTokenTotal = (int)config.MaxTokens; + this._textTokenizer = textTokenizer; + + var modelDir = Path.GetFullPath(config.TextModelDir); + var modelFile = Directory.GetFiles(modelDir) + .FirstOrDefault(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase)); + + this._log.LogDebug("Loading Onnx model: {1} from directory {0}", modelDir, Path.GetFileNameWithoutExtension(modelFile)); + this._model = new Model(config.TextModelDir); + this._tokenizer = new Tokenizer(this._model); + this._log.LogDebug("Onnx model loaded"); + } + + /// + public async IAsyncEnumerable GenerateTextAsync( + string prompt, + TextGenerationOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + var tokens = this._tokenizer?.Encode(prompt); + using var generatorParams = new GeneratorParams(this._model); + + generatorParams.SetSearchOption("max_length", this.MaxTokenTotal); + generatorParams.SetSearchOption("min_length", this._config.MinLength); + generatorParams.SetSearchOption("num_return_sequences", this._config.ResultsPerPrompt); + generatorParams.SetSearchOption("repetition_penalty", this._config.RepetitionPenalty); + generatorParams.SetSearchOption("length_penalty", this._config.LengthPenalty); + generatorParams.SetSearchOption("temperature", 0); + + if (options != null) + { + generatorParams.SetSearchOption("num_return_sequences", options.ResultsPerPrompt); + generatorParams.SetSearchOption("temperature", options.Temperature); + + if (options.MaxTokens > 0) + { + generatorParams.SetSearchOption("max_length", (int)options.MaxTokens); + } + } + + switch (this._config.SearchType) + { + case OnnxSearchType.BeamSearch: + generatorParams.SetSearchOption("do_sample", false); + generatorParams.SetSearchOption("early_stopping", this._config.EarlyStopping); + + if (this._config.NumBeams != null) + { + generatorParams.SetSearchOption("num_beams", (double)this._config.NumBeams); + } + + break; + + case OnnxSearchType.TopN: + generatorParams.SetSearchOption("do_sample", true); + generatorParams.SetSearchOption("top_k", this._config.TopK); + + generatorParams.SetSearchOption("top_p", options is { NucleusSampling: > 0 and <= 1 } + ? options.NucleusSampling + : this._config.NucleusSampling); + + break; + + default: + + generatorParams.SetSearchOption("do_sample", false); + + if (this._config.NumBeams != null) + { + generatorParams.SetSearchOption("num_beams", (double)this._config.NumBeams); + } + + break; + } + + generatorParams.SetInputSequences(tokens); + + using (var generator = new Generator(this._model, generatorParams)) + { + List outputTokens = new(); + + while (!generator.IsDone() && cancellationToken.IsCancellationRequested == false) + { + generator.ComputeLogits(); + generator.GenerateNextToken(); + + outputTokens.AddRange(generator.GetSequence(0)); + + if (outputTokens.Count > 0 && this._tokenizer != null) + { + var newToken = outputTokens[^1]; + yield return this._tokenizer.Decode(new int[] { newToken }); + } + } + } + + await Task.CompletedTask.ConfigureAwait(false); + } + + /// + public int CountTokens(string text) + { + // TODO: Implement with _tokenizer and remove _textTokenizer + return this._textTokenizer.CountTokens(text); + } + + /// + public IReadOnlyList GetTokens(string text) + { + // TODO: Implement with _tokenizer and remove _textTokenizer + return this._textTokenizer.GetTokens(text); + } + + /// + public void Dispose() + { + this._model?.Dispose(); + this._tokenizer?.Dispose(); + } +} diff --git a/extensions/ONNX/README.md b/extensions/ONNX/README.md new file mode 100644 index 000000000..2ab365705 --- /dev/null +++ b/extensions/ONNX/README.md @@ -0,0 +1,25 @@ +# Kernel Memory with ONNX + +[![Nuget package](https://img.shields.io/nuget/v/Microsoft.KernelMemory.AI.Onnx)](https://www.nuget.org/packages/Microsoft.KernelMemory.AI.Onnx/) +[![Discord](https://img.shields.io/discord/1063152441819942922?label=Discord&logo=discord&logoColor=white&color=d82679)](https://aka.ms/KMdiscord) + +This project contains the +[ONNX](https://onnxruntime.ai/docs/genai/) +LLM connector to access to LLM models via Onnx service to generate text. + +Sample code: + +```csharp +var config = new OnnxConfig +{ + ModelPath = "C:\\....\\Phi-3-mini-128k-instruct-onnx\\....\\cpu-int4-rtn-block-32" +}; + +var memory = new KernelMemoryBuilder() + .WithOnnxTextGeneration(config) + .Build(); + +await memory.ImportTextAsync("Today is October 32nd, 2476"); + +var answer = await memory.AskAsync("What's the current date (don't check for validity)?"); +``` diff --git a/service/tests/TestHelpers/BaseFunctionalTestCase.cs b/service/tests/TestHelpers/BaseFunctionalTestCase.cs index 856bde08c..d73110713 100644 --- a/service/tests/TestHelpers/BaseFunctionalTestCase.cs +++ b/service/tests/TestHelpers/BaseFunctionalTestCase.cs @@ -31,6 +31,7 @@ public abstract class BaseFunctionalTestCase : IDisposable protected readonly SimpleVectorDbConfig SimpleVectorDbConfig; protected readonly LlamaSharpConfig LlamaSharpConfig; protected readonly ElasticsearchConfig ElasticsearchConfig; + protected readonly OnnxConfig OnnxConfig; // IMPORTANT: install Xunit.DependencyInjection package protected BaseFunctionalTestCase(IConfiguration cfg, ITestOutputHelper output) @@ -50,6 +51,7 @@ protected BaseFunctionalTestCase(IConfiguration cfg, ITestOutputHelper output) this.SimpleVectorDbConfig = cfg.GetSection("KernelMemory:Services:SimpleVectorDb").Get() ?? new(); this.LlamaSharpConfig = cfg.GetSection("KernelMemory:Services:LlamaSharp").Get() ?? new(); this.ElasticsearchConfig = cfg.GetSection("KernelMemory:Services:Elasticsearch").Get() ?? new(); + this.OnnxConfig = cfg.GetSection("KernelMemory:Services:Onnx").Get() ?? new(); } protected IKernelMemory GetMemoryWebClient() diff --git a/service/tests/TestHelpers/TestHelpers.csproj b/service/tests/TestHelpers/TestHelpers.csproj index 3996f22f9..b0b991041 100644 --- a/service/tests/TestHelpers/TestHelpers.csproj +++ b/service/tests/TestHelpers/TestHelpers.csproj @@ -16,6 +16,7 @@ +