diff --git a/Directory.Packages.props b/Directory.Packages.props
index 41efaafd0..2834bf71d 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -34,7 +34,7 @@
-
+
diff --git a/KernelMemory.sln b/KernelMemory.sln
index 15db02fe8..7dc6a100e 100644
--- a/KernelMemory.sln
+++ b/KernelMemory.sln
@@ -7,6 +7,7 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{7BA7F1B2-19E2-46EB-B000-513EE2F65769}"
ProjectSection(SolutionItems) = preProject
docs\404.html = docs\404.html
+ docs\azure.md = docs\azure.md
docs\concepts.md = docs\concepts.md
docs\csharp.png = docs\csharp.png
docs\extensions.md = docs\extensions.md
@@ -31,7 +32,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{7BA7F1B2-1
docs\service.md = docs\service.md
docs\_config.local.yml = docs\_config.local.yml
docs\_config.yml = docs\_config.yml
- docs\azure.md = docs\azure.md
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "examples", "examples", "{0A43C65C-6007-4BB4-B3FE-8D439FC91841}"
@@ -70,18 +70,18 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{6EF76FD8-4
.editorconfig = .editorconfig
.gitattributes = .gitattributes
.gitignore = .gitignore
+ azure.yaml = azure.yaml
CODE_OF_CONDUCT.md = CODE_OF_CONDUCT.md
COMMUNITY.md = COMMUNITY.md
CONTRIBUTING.md = CONTRIBUTING.md
Directory.Build.props = Directory.Build.props
Directory.Packages.props = Directory.Packages.props
Dockerfile = Dockerfile
+ KernelMemory.sln.DotSettings = KernelMemory.sln.DotSettings
LICENSE = LICENSE
nuget.config = nuget.config
README.md = README.md
SECURITY.md = SECURITY.md
- KernelMemory.sln.DotSettings = KernelMemory.sln.DotSettings
- azure.yaml = azure.yaml
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = ".github", ".github", "{B8976338-7CDC-47AE-8502-C2FBAFBEBD68}"
@@ -94,10 +94,10 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "workflows", "workflows", "{48E79819-1E9E-4075-90DA-BAEC761C89B2}"
ProjectSection(SolutionItems) = preProject
.github\workflows\docker-build-push.yml = .github\workflows\docker-build-push.yml
+ .github\workflows\dotnet-build.yml = .github\workflows\dotnet-build.yml
+ .github\workflows\dotnet-unit-tests.yml = .github\workflows\dotnet-unit-tests.yml
.github\workflows\github-pages-jekyll.yml = .github\workflows\github-pages-jekyll.yml
.github\workflows\spell-check-with-typos.yml = .github\workflows\spell-check-with-typos.yml
- .github\workflows\dotnet-unit-tests.yml = .github\workflows\dotnet-unit-tests.yml
- .github\workflows\dotnet-build.yml = .github\workflows\dotnet-build.yml
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "service", "service", "{87DEAE8D-138C-4FDD-B4C9-11C3A7817E8F}"
@@ -116,9 +116,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{CA49F1A1
tools\run-qdrant.sh = tools\run-qdrant.sh
tools\run-rabbitmq.sh = tools\run-rabbitmq.sh
tools\run-redis.sh = tools\run-redis.sh
+ tools\run-s3ninja.sh = tools\run-s3ninja.sh
tools\search.sh = tools\search.sh
tools\upload-file.sh = tools\upload-file.sh
- tools\run-s3ninja.sh = tools\run-s3ninja.sh
tools\dockerize-amd64.sh = tools\dockerize-amd64.sh
tools\dockerize-arm64.sh = tools\dockerize-arm64.sh
EndProjectSection
@@ -243,12 +243,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "109-dotnet-custom-webscrape
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "infra", "infra", "{B488168B-AD86-4CC5-9D89-324B6EB743D9}"
ProjectSection(SolutionItems) = preProject
+ infra\AZD.md = infra\AZD.md
infra\build-main.json.sh = infra\build-main.json.sh
infra\main.bicep = infra\main.bicep
infra\main.json = infra\main.json
- infra\README.md = infra\README.md
- infra\AZD.md = infra\AZD.md
infra\main.parameters.json = infra\main.parameters.json
+ infra\README.md = infra\README.md
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "modules", "modules", "{C2D3A947-B6F9-4306-BD42-21D8D1F42750}"
@@ -323,6 +323,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "212-dotnet-ollama", "exampl
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Ollama", "extensions\Ollama\Ollama\Ollama.csproj", "{F192513B-265B-4943-A2A9-44E23B15BA18}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Onnx.FunctionalTests", "extensions\ONNX\Onnx.FunctionalTests\Onnx.FunctionalTests.csproj", "{7BBD348E-CDD9-4462-B8C9-47613C5EC682}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Onnx", "extensions\ONNX\Onnx\Onnx.csproj", "{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -591,6 +595,13 @@ Global
{F192513B-265B-4943-A2A9-44E23B15BA18}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F192513B-265B-4943-A2A9-44E23B15BA18}.Release|Any CPU.ActiveCfg = Release|Any CPU
{F192513B-265B-4943-A2A9-44E23B15BA18}.Release|Any CPU.Build.0 = Release|Any CPU
+ {7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -685,6 +696,8 @@ Global
{84AEC1DD-CBAE-400A-949C-91BA373C587D} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841}
{B303885D-F64F-4EEB-B085-0014E863AF61} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841}
{F192513B-265B-4943-A2A9-44E23B15BA18} = {155DA079-E267-49AF-973A-D1D44681970F}
+ {7BBD348E-CDD9-4462-B8C9-47613C5EC682} = {3C17F42B-CFC8-4900-8CFB-88936311E919}
+ {345DEF9B-6EE1-49DF-B46A-25E38CE9B151} = {155DA079-E267-49AF-973A-D1D44681970F}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {CC136C62-115C-41D1-B414-F9473EFF6EA8}
diff --git a/extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj b/extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj
new file mode 100644
index 000000000..6143d9130
--- /dev/null
+++ b/extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj
@@ -0,0 +1,36 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <AssemblyName>Microsoft.Onnx.FunctionalTests</AssemblyName>
+    <RootNamespace>Microsoft.Onnx.FunctionalTests</RootNamespace>
+    <TargetFramework>net8.0</TargetFramework>
+    <RollForward>LatestMajor</RollForward>
+    <IsTestProject>true</IsTestProject>
+    <ImplicitUsings>enable</ImplicitUsings>
+    <Nullable>enable</Nullable>
+    <IsPackable>false</IsPackable>
+    <NoWarn>$(NoWarn);KMEXP01;</NoWarn>
+  </PropertyGroup>
+
+
+
+
+
+
+
+
+
+
+
+
+
+ all
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+
+
+ all
+ runtime; build; native; contentfiles; analyzers; buildtransitive
+
+
+
+
diff --git a/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs b/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs
new file mode 100644
index 000000000..a2e491b6d
--- /dev/null
+++ b/extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs
@@ -0,0 +1,67 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Diagnostics;
+using System.Text;
+using Microsoft.Extensions.Configuration;
+using Microsoft.KernelMemory.AI;
+using Microsoft.KernelMemory.AI.Onnx;
+using Microsoft.KM.TestHelpers;
+using Xunit.Abstractions;
+
+namespace Microsoft.Onnx.FunctionalTests;
+
+public sealed class OnnxTextGeneratorTest : BaseFunctionalTestCase
+{
+ private readonly OnnxTextGenerator _target;
+ private readonly Stopwatch _timer;
+
+ public OnnxTextGeneratorTest(
+ IConfiguration cfg,
+ ITestOutputHelper output) : base(cfg, output)
+ {
+ this._timer = new Stopwatch();
+
+ this.OnnxConfig.Validate();
+ this._target = new OnnxTextGenerator(this.OnnxConfig, loggerFactory: null);
+
+ var modelDirectory = Path.GetFullPath(this.OnnxConfig.TextModelDir);
+ var modelFile = Directory.GetFiles(modelDirectory)
+ .FirstOrDefault(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase));
+
+ Console.WriteLine($"Using model {Path.GetFileNameWithoutExtension(modelFile)} from: {modelDirectory}");
+ }
+
+ [Fact]
+ [Trait("Category", "Onnx")]
+ public async Task ItGeneratesText()
+ {
+ var utcDate = DateTime.UtcNow.Date.ToString("MM/dd/yyyy");
+ var systemPrompt = $"Following the format \"MM/dd/yyyy\", the current date is {utcDate}.";
+ var question = $"What is the current date?";
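+        // Prompt uses the Phi-3 chat template markers, matching the instruct model configured in appsettings.json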
+ var prompt = $"<|system|>{systemPrompt}<|end|><|user|>{question}<|end|><|assistant|>";
+
+ var options = new TextGenerationOptions();
+
+ // Act
+ this._timer.Restart();
+ var tokens = this._target.GenerateTextAsync(prompt, options);
+ var result = new StringBuilder();
+ await foreach (string token in tokens)
+ {
+ result.Append(token);
+ }
+
+ this._timer.Stop();
+ var answer = result.ToString();
+
+ // Assert
+ Console.WriteLine($"Model Output:\n=============================\n{answer}\n=============================");
+ Console.WriteLine($"Time: {this._timer.ElapsedMilliseconds / 1000} secs");
+ Assert.Contains(utcDate.ToString(), answer, StringComparison.OrdinalIgnoreCase);
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ base.Dispose(disposing);
+ this._target.Dispose();
+ }
+}
diff --git a/extensions/ONNX/Onnx.FunctionalTests/Startup.cs b/extensions/ONNX/Onnx.FunctionalTests/Startup.cs
new file mode 100644
index 000000000..fe9b40a95
--- /dev/null
+++ b/extensions/ONNX/Onnx.FunctionalTests/Startup.cs
@@ -0,0 +1,22 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.Hosting;
+
+/* IMPORTANT: the Startup class must be at the root of the namespace and
+ * the namespace must match exactly (required by Xunit.DependencyInjection) */
+
+namespace Microsoft.Onnx.FunctionalTests;
+
+public class Startup
+{
+ public void ConfigureHost(IHostBuilder hostBuilder)
+ {
+ var config = new ConfigurationBuilder()
+ .AddJsonFile("appsettings.json")
+ .AddJsonFile("appsettings.development.json", optional: true)
+ .AddJsonFile("appsettings.Development.json", optional: true)
+            .AddUserSecrets<Startup>()
+ .AddEnvironmentVariables()
+ .Build();
+
+ hostBuilder.ConfigureHostConfiguration(builder => builder.AddConfiguration(config));
+ }
+}
diff --git a/extensions/ONNX/Onnx.FunctionalTests/Usings.cs b/extensions/ONNX/Onnx.FunctionalTests/Usings.cs
new file mode 100644
index 000000000..38b0ca7bb
--- /dev/null
+++ b/extensions/ONNX/Onnx.FunctionalTests/Usings.cs
@@ -0,0 +1,3 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+global using Xunit;
diff --git a/extensions/ONNX/Onnx.FunctionalTests/appsettings.json b/extensions/ONNX/Onnx.FunctionalTests/appsettings.json
new file mode 100644
index 000000000..ec84442d4
--- /dev/null
+++ b/extensions/ONNX/Onnx.FunctionalTests/appsettings.json
@@ -0,0 +1,67 @@
+{
+ "KernelMemory": {
+ "ServiceAuthorization": {
+ "Endpoint": "http://127.0.0.1:9001/",
+ "AccessKey": ""
+ },
+ "Services": {
+ "Onnx": {
+ // Path to directory containing ONNX Model, e.g. "C:\\....\\Phi-3-mini-128k-instruct-onnx\\....\\cpu-int4-rtn-block-32"
+ "TextModelDir": "Z:\\tools\\LocalModels\\Phi-3-mini-128k-instruct-onnx\\cpu_and_mobile\\cpu-int4-rtn-block-32"
+ },
+ "SimpleVectorDb": {
+ // Options: "Disk" or "Volatile". Volatile data is lost after each execution.
+ "StorageType": "Volatile",
+ // Directory where files are stored.
+ "Directory": "_vectors"
+ },
+ "AzureAISearch": {
+ // "ApiKey" or "AzureIdentity". For other options see .
+ // AzureIdentity: use automatic AAD authentication mechanism. You can test locally
+ // using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET.
+ "Auth": "AzureIdentity",
+ "Endpoint": "https://<...>",
+ "APIKey": ""
+ },
+ "Postgres": {
+ // Postgres instance connection string
+ "ConnectionString": "Host=localhost;Port=5432;Username=public;Password=;Database=public",
+ // Mandatory prefix to add to the name of table managed by KM,
+ // e.g. to exclude other tables in the same schema.
+ "TableNamePrefix": "tests-"
+ },
+ "Qdrant": {
+ "Endpoint": "http://127.0.0.1:6333",
+ "APIKey": ""
+ },
+ "OpenAI": {
+ // Name of the model used to generate text (text completion or chat completion)
+ "TextModel": "gpt-4o-mini",
+ // The max number of tokens supported by the text model.
+ "TextModelMaxTokenTotal": 16384,
+ // What type of text generation, by default autodetect using the model name.
+ // Possible values: "Auto", "TextCompletion", "Chat"
+ "TextGenerationType": "Auto",
+ // Name of the model used to generate text embeddings
+ "EmbeddingModel": "text-embedding-ada-002",
+ // The max number of tokens supported by the embedding model
+ // See https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
+ "EmbeddingModelMaxTokenTotal": 8191,
+ // OpenAI API Key
+ "APIKey": "",
+ // OpenAI Organization ID (usually empty, unless you have multiple accounts on different orgs)
+ "OrgId": "",
+ // Endpoint to use. By default the system uses 'https://api.openai.com/v1'.
+ // Change this to use proxies or services compatible with OpenAI HTTP protocol like LM Studio.
+ "Endpoint": "",
+ // How many times to retry in case of throttling
+ "MaxRetries": 10
+ }
+ }
+ },
+ "Logging": {
+ "LogLevel": {
+ "Default": "Information"
+ }
+ }
+}
\ No newline at end of file
diff --git a/extensions/ONNX/Onnx/DependencyInjection.cs b/extensions/ONNX/Onnx/DependencyInjection.cs
new file mode 100644
index 000000000..55e3b5018
--- /dev/null
+++ b/extensions/ONNX/Onnx/DependencyInjection.cs
@@ -0,0 +1,60 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using Microsoft.KernelMemory.AI;
+using Microsoft.KernelMemory.AI.Onnx;
+
+#pragma warning disable IDE0130 // reduce number of "using" statements
+// ReSharper disable once CheckNamespace - reduce number of "using" statements
+namespace Microsoft.KernelMemory;
+
+/// <summary>
+/// Kernel Memory builder extensions
+/// </summary>
+public static partial class KernelMemoryBuilderExtensions
+{
+ public static IKernelMemoryBuilder WithOnnxTextGeneration(
+ this IKernelMemoryBuilder builder,
+ string modelPath,
+ uint maxTokenTotal,
+ ITextTokenizer? textTokenizer = null)
+ {
+        var config = new OnnxConfig
+        {
+            TextModelDir = modelPath,
+            MaxTokens = (int)maxTokenTotal
+        };
+
+ builder.Services.AddOnnxTextGeneration(config, textTokenizer);
+
+ return builder;
+ }
+
+ public static IKernelMemoryBuilder WithOnnxTextGeneration(
+ this IKernelMemoryBuilder builder,
+ OnnxConfig config,
+ ITextTokenizer? textTokenizer = null)
+ {
+ builder.Services.AddOnnxTextGeneration(config, textTokenizer);
+ return builder;
+ }
+}
+
+/// <summary>
+/// .NET IServiceCollection dependency injection extensions.
+/// </summary>
+public static partial class DependencyInjection
+{
+ public static IServiceCollection AddOnnxTextGeneration(
+ this IServiceCollection services,
+ OnnxConfig config,
+ ITextTokenizer? textTokenizer = null)
+ {
+ config.Validate();
+ return services
+            .AddSingleton<ITextGenerator>(serviceProvider => new OnnxTextGenerator(
+                config: config,
+                textTokenizer: textTokenizer,
+                loggerFactory: serviceProvider.GetService<ILoggerFactory>()));
+ }
+}
diff --git a/extensions/ONNX/Onnx/Onnx.csproj b/extensions/ONNX/Onnx/Onnx.csproj
new file mode 100644
index 000000000..96ec73dff
--- /dev/null
+++ b/extensions/ONNX/Onnx/Onnx.csproj
@@ -0,0 +1,32 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFramework>net8.0</TargetFramework>
+    <RollForward>LatestMajor</RollForward>
+    <AssemblyName>Microsoft.KernelMemory.AI.Onnx</AssemblyName>
+    <RootNamespace>Microsoft.KernelMemory.AI.Onnx</RootNamespace>
+    <NoWarn>$(NoWarn);KMEXP00;KMEXP01;CA1724;</NoWarn>
+  </PropertyGroup>
+
+  <PropertyGroup>
+    <IsPackable>true</IsPackable>
+    <PackageId>Microsoft.KernelMemory.AI.Onnx</PackageId>
+    <Product>ONNX LLM connector for Kernel Memory</Product>
+    <Description>Provide access to ONNX LLM models in Kernel Memory to generate text</Description>
+    <PackageTags>ONNX, Memory, Kernel Memory, Semantic Memory, Episodic Memory, Declarative Memory, AI, Artificial Intelligence, Semantic Search, Memory DB</PackageTags>
+    <DocumentationFile>bin/$(Configuration)/$(TargetFramework)/$(AssemblyName).xml</DocumentationFile>
+  </PropertyGroup>
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/extensions/ONNX/Onnx/OnnxConfig.cs b/extensions/ONNX/Onnx/OnnxConfig.cs
new file mode 100644
index 000000000..4a0ce66fe
--- /dev/null
+++ b/extensions/ONNX/Onnx/OnnxConfig.cs
@@ -0,0 +1,171 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.IO;
+using System.Linq;
+
+#pragma warning disable IDE0130 // reduce number of "using" statements
+// ReSharper disable once CheckNamespace - reduce number of "using" statements
+namespace Microsoft.KernelMemory;
+
+public class OnnxConfig
+{
+    /// <summary>
+    /// An enum representing the possible text generation search types used by OnnxTextGenerator.
+    /// See https://onnxruntime.ai/docs/genai/reference/config.html#search-combinations for more details.
+    /// </summary>
+ public enum OnnxSearchType
+ {
+        /// <summary>
+        /// A decoding algorithm that keeps track of the top K sequences at each step. It explores
+        /// multiple paths simultaneously, balancing exploration and exploitation. Often results in more
+        /// coherent and higher quality text generation than Greedy Search would.
+        /// </summary>
+ BeamSearch,
+
+        /// <summary>
+        /// The default and simplest decoding algorithm. At each step, a token is selected with the highest
+        /// probability as the next word in the sequence.
+        /// </summary>
+ GreedySearch,
+
+        /// <summary>
+        /// Combined Top-P (Nucleus) and Top-K Sampling: A decoding algorithm that samples from the top k tokens
+        /// with the highest probabilities, while also considering the smallest set of tokens whose cumulative
+        /// probability exceeds a threshold p. This approach dynamically balances diversity and coherence in
+        /// text generation by adjusting the sampling pool based on both fixed and cumulative probability criteria.
+        /// </summary>
+ TopN
+ }
+
+    /// <summary>
+    /// Path to the directory containing the .ONNX file for Text Generation.
+    /// </summary>
+ public string TextModelDir { get; set; } = string.Empty;
+
+    /// <summary>
+    /// The maximum length of the response that the model will generate. See https://onnxruntime.ai/docs/genai/reference/config.html
+    /// </summary>
+ public int MaxTokens { get; set; } = 2048;
+
+    /// <summary>
+    /// The minimum length of the response that the model will generate. See https://onnxruntime.ai/docs/genai/reference/config.html
+    /// </summary>
+ public uint MinLength { get; set; } = 0;
+
+    /// <summary>
+    /// The algorithm used in text generation. Defaults to GreedySearch.
+    /// </summary>
+ public OnnxSearchType SearchType { get; set; } = OnnxSearchType.GreedySearch;
+
+    /// <summary>
+    /// The number of beams to apply when generating the output sequence using beam search.
+    /// If NumBeams = 1, generation is performed using greedy search; if NumBeams > 1, beam
+    /// search is used. A null value implies using TopN search.
+    /// </summary>
+ public uint? NumBeams { get; set; } = 1;
+
+    /// <summary>
+    /// Only includes the most probable tokens with probabilities that add up to P or higher.
+    /// Defaults to 1, which includes all of the tokens. Range is 0 to 1, exclusive of 0.
+    /// </summary>
+ public double NucleusSampling { get; set; } = 1.0;
+
+    /// <summary>
+    /// Whether to stop the beam search when at least NumBeams sentences are finished per batch. Defaults to false.
+    /// </summary>
+ public bool EarlyStopping { get; set; } = false;
+
+    /// <summary>
+    /// The number of sequences (responses) to generate. Returns the sequences with the highest scores in order.
+    /// </summary>
+ public int ResultsPerPrompt { get; set; } = 1;
+
+    /// <summary>
+    /// Only includes tokens that fall within the list of the K most probable tokens. Range is 1 to the vocabulary size.
+    /// Defaults to 50.
+    /// </summary>
+ public uint TopK { get; set; } = 50;
+
+    /// <summary>
+    /// Discounts the scores of previously generated tokens if set to a value greater than 1.
+    /// Defaults to 1.
+    /// </summary>
+ public double RepetitionPenalty { get; set; } = 1.0;
+
+    /// <summary>
+    /// Controls the length of the output generated. Values less than 1 encourage shorter
+    /// sequences; values greater than 1 encourage longer sequences. Defaults to 1.
+    /// </summary>
+ public double LengthPenalty { get; set; } = 1.0;
+
+    /// <summary>
+    /// Verify that the current state is valid.
+    /// </summary>
+ public void Validate(bool allowIO = true)
+ {
+ if (string.IsNullOrEmpty(this.TextModelDir))
+ {
+ throw new ConfigurationException($"Onnx: {nameof(this.TextModelDir)} is a required field.");
+ }
+
+ var modelDir = Path.GetFullPath(this.TextModelDir);
+
+ if (allowIO)
+ {
+ if (!Directory.Exists(modelDir))
+ {
+ throw new ConfigurationException($"Onnx: {this.TextModelDir} does not exist.");
+ }
+
+ if (Directory.GetFiles(modelDir).Length == 0)
+ {
+ throw new ConfigurationException($"Onnx: {this.TextModelDir} is an empty directory.");
+ }
+
+ var modelFiles = Directory.GetFiles(modelDir)
+ .Where(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase));
+
+            if (!modelFiles.Any())
+ {
+ throw new ConfigurationException($"Onnx: {this.TextModelDir} does not contain a valid .ONNX model.");
+ }
+ }
+
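+        // Cross-check decoding settings: NumBeams and EarlyStopping apply to beam search only; TopN requires NumBeams to be null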
+ if (this.SearchType == OnnxSearchType.GreedySearch)
+ {
+ if (this.NumBeams != 1)
+ {
+ throw new ConfigurationException($"Onnx: {nameof(this.NumBeams)} is only used with Beam Search. Change {nameof(this.NumBeams)} to 1, or change {nameof(this.SearchType)} to BeamSearch.");
+ }
+
+ if (this.EarlyStopping != false)
+ {
+ throw new ConfigurationException($"Onnx: {nameof(this.EarlyStopping)} is only used with Beam Search. Change {nameof(this.EarlyStopping)} to false, or change {nameof(this.SearchType)} to BeamSearch.");
+ }
+ }
+
+ if (this.SearchType == OnnxSearchType.BeamSearch)
+ {
+ if (this.NumBeams == null)
+ {
+ throw new ConfigurationException($"Onnx: {nameof(this.NumBeams)} is required for Beam Search. Change {nameof(this.NumBeams)} to a value >= 1, or change the {nameof(this.SearchType)}.");
+ }
+ }
+
+ if (this.SearchType == OnnxSearchType.TopN)
+ {
+ if (this.NumBeams != null)
+ {
+ throw new ConfigurationException($"Onnx: {nameof(this.NumBeams)} isn't required with TopN Search. Change {nameof(this.NumBeams)} to null, or change the {nameof(this.SearchType)}.");
+ }
+
+ if (this.EarlyStopping != false)
+ {
+ throw new ConfigurationException($"Onnx: {nameof(this.EarlyStopping)} is only used with Beam Search. Change {nameof(this.EarlyStopping)} to false, or change {nameof(this.SearchType)} to BeamSearch.");
+ }
+ }
+ }
+}
diff --git a/extensions/ONNX/Onnx/OnnxTextGenerator.cs b/extensions/ONNX/Onnx/OnnxTextGenerator.cs
new file mode 100644
index 000000000..f0dbab49c
--- /dev/null
+++ b/extensions/ONNX/Onnx/OnnxTextGenerator.cs
@@ -0,0 +1,190 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.IO;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Extensions.Logging;
+using Microsoft.KernelMemory.AI.OpenAI;
+using Microsoft.KernelMemory.Diagnostics;
+using Microsoft.ML.OnnxRuntimeGenAI;
+using static Microsoft.KernelMemory.OnnxConfig;
+
+namespace Microsoft.KernelMemory.AI.Onnx;
+
+/// <summary>
+/// Text generator based on ONNX models, via the ONNX Runtime GenAI library.
+/// See https://github.com/microsoft/onnxruntime-genai
+/// </summary>
+[Experimental("KMEXP01")]
+public sealed class OnnxTextGenerator : ITextGenerator, IDisposable
+{
+    /// <summary>
+    /// The ONNX Model used for text generation
+    /// </summary>
+ private readonly Model _model;
+
+    /// <summary>
+    /// Tokenizer used with the Onnx Generator and Model classes to produce tokens.
+    /// This has the potential to contain a null value, depending on the contents of the Model Directory.
+    /// </summary>
+ private readonly Tokenizer? _tokenizer = default;
+
+    /// <summary>
+    /// Tokenizer used for GetTokens() and CountTokens()
+    /// </summary>
+ private readonly ITextTokenizer _textTokenizer;
+
+ private readonly ILogger _log;
+
+ private readonly OnnxConfig _config;
+
+    /// <inheritdoc />
+ public int MaxTokenTotal { get; internal set; }
+
+    /// <summary>
+    /// Create a new instance
+    /// </summary>
+    /// <param name="config">Configuration settings</param>
+    /// <param name="textTokenizer">Text Tokenizer</param>
+    /// <param name="loggerFactory">Application Logger instance</param>
+ public OnnxTextGenerator(
+ OnnxConfig config,
+ ITextTokenizer? textTokenizer = null,
+ ILoggerFactory? loggerFactory = null)
+ {
+        this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<OnnxTextGenerator>();
+ if (textTokenizer == null)
+ {
+ this._log.LogWarning(
+ "Tokenizer not specified, will use {0}. The token count might be incorrect, causing unexpected errors",
+ nameof(GPT4oTokenizer));
+ textTokenizer = new GPT4oTokenizer();
+ }
+
+ config.Validate();
+ this._config = config;
+ this.MaxTokenTotal = (int)config.MaxTokens;
+ this._textTokenizer = textTokenizer;
+
+ var modelDir = Path.GetFullPath(config.TextModelDir);
+ var modelFile = Directory.GetFiles(modelDir)
+ .FirstOrDefault(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase));
+
+        this._log.LogDebug("Loading Onnx model {ModelName} from directory {ModelDir}", Path.GetFileNameWithoutExtension(modelFile), modelDir);
+ this._model = new Model(config.TextModelDir);
+ this._tokenizer = new Tokenizer(this._model);
+ this._log.LogDebug("Onnx model loaded");
+ }
+
+    /// <inheritdoc />
+    public async IAsyncEnumerable<string> GenerateTextAsync(
+ string prompt,
+ TextGenerationOptions? options = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ var tokens = this._tokenizer?.Encode(prompt);
+ using var generatorParams = new GeneratorParams(this._model);
+
+ generatorParams.SetSearchOption("max_length", this.MaxTokenTotal);
+ generatorParams.SetSearchOption("min_length", this._config.MinLength);
+ generatorParams.SetSearchOption("num_return_sequences", this._config.ResultsPerPrompt);
+ generatorParams.SetSearchOption("repetition_penalty", this._config.RepetitionPenalty);
+ generatorParams.SetSearchOption("length_penalty", this._config.LengthPenalty);
+ generatorParams.SetSearchOption("temperature", 0);
+
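+        // Per-request options, when provided, override the defaults taken from OnnxConfig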
+ if (options != null)
+ {
+ generatorParams.SetSearchOption("num_return_sequences", options.ResultsPerPrompt);
+ generatorParams.SetSearchOption("temperature", options.Temperature);
+
+ if (options.MaxTokens > 0)
+ {
+ generatorParams.SetSearchOption("max_length", (int)options.MaxTokens);
+ }
+ }
+
+ switch (this._config.SearchType)
+ {
+ case OnnxSearchType.BeamSearch:
+ generatorParams.SetSearchOption("do_sample", false);
+ generatorParams.SetSearchOption("early_stopping", this._config.EarlyStopping);
+
+ if (this._config.NumBeams != null)
+ {
+ generatorParams.SetSearchOption("num_beams", (double)this._config.NumBeams);
+ }
+
+ break;
+
+ case OnnxSearchType.TopN:
+ generatorParams.SetSearchOption("do_sample", true);
+ generatorParams.SetSearchOption("top_k", this._config.TopK);
+
+ generatorParams.SetSearchOption("top_p", options is { NucleusSampling: > 0 and <= 1 }
+ ? options.NucleusSampling
+ : this._config.NucleusSampling);
+
+ break;
+
+ default:
+
+ generatorParams.SetSearchOption("do_sample", false);
+
+ if (this._config.NumBeams != null)
+ {
+ generatorParams.SetSearchOption("num_beams", (double)this._config.NumBeams);
+ }
+
+ break;
+ }
+
+ generatorParams.SetInputSequences(tokens);
+
+ using (var generator = new Generator(this._model, generatorParams))
+ {
+            List<int> outputTokens = new();
+
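+            // Stream the response: after each step, decode the newest token in the sequence and yield it to the caller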
+ while (!generator.IsDone() && cancellationToken.IsCancellationRequested == false)
+ {
+ generator.ComputeLogits();
+ generator.GenerateNextToken();
+
+ outputTokens.AddRange(generator.GetSequence(0));
+
+ if (outputTokens.Count > 0 && this._tokenizer != null)
+ {
+ var newToken = outputTokens[^1];
+ yield return this._tokenizer.Decode(new int[] { newToken });
+ }
+ }
+ }
+
+ await Task.CompletedTask.ConfigureAwait(false);
+ }
+
+    /// <inheritdoc />
+ public int CountTokens(string text)
+ {
+ // TODO: Implement with _tokenizer and remove _textTokenizer
+ return this._textTokenizer.CountTokens(text);
+ }
+
+    /// <inheritdoc />
+    public IReadOnlyList<string> GetTokens(string text)
+ {
+ // TODO: Implement with _tokenizer and remove _textTokenizer
+ return this._textTokenizer.GetTokens(text);
+ }
+
+    /// <inheritdoc />
+ public void Dispose()
+ {
+ this._model?.Dispose();
+ this._tokenizer?.Dispose();
+ }
+}
diff --git a/extensions/ONNX/README.md b/extensions/ONNX/README.md
new file mode 100644
index 000000000..2ab365705
--- /dev/null
+++ b/extensions/ONNX/README.md
@@ -0,0 +1,25 @@
+# Kernel Memory with ONNX
+
+[![Nuget package](https://img.shields.io/nuget/v/Microsoft.KernelMemory.AI.Onnx)](https://www.nuget.org/packages/Microsoft.KernelMemory.AI.Onnx/)
+[![Discord](https://img.shields.io/discord/1063152441819942922?label=Discord&logo=discord&logoColor=white&color=d82679)](https://aka.ms/KMdiscord)
+
+This project contains the
+[ONNX](https://onnxruntime.ai/docs/genai/)
+LLM connector, allowing Kernel Memory to generate text with locally hosted ONNX models.
+
+Sample code:
+
+```csharp
+var config = new OnnxConfig
+{
+    TextModelDir = "C:\\....\\Phi-3-mini-128k-instruct-onnx\\....\\cpu-int4-rtn-block-32"
+};
+
+var memory = new KernelMemoryBuilder()
+ .WithOnnxTextGeneration(config)
+ .Build();
+
+await memory.ImportTextAsync("Today is October 32nd, 2476");
+
+var answer = await memory.AskAsync("What's the current date (don't check for validity)?");
+```
diff --git a/service/tests/TestHelpers/BaseFunctionalTestCase.cs b/service/tests/TestHelpers/BaseFunctionalTestCase.cs
index 856bde08c..d73110713 100644
--- a/service/tests/TestHelpers/BaseFunctionalTestCase.cs
+++ b/service/tests/TestHelpers/BaseFunctionalTestCase.cs
@@ -31,6 +31,7 @@ public abstract class BaseFunctionalTestCase : IDisposable
protected readonly SimpleVectorDbConfig SimpleVectorDbConfig;
protected readonly LlamaSharpConfig LlamaSharpConfig;
protected readonly ElasticsearchConfig ElasticsearchConfig;
+ protected readonly OnnxConfig OnnxConfig;
// IMPORTANT: install Xunit.DependencyInjection package
protected BaseFunctionalTestCase(IConfiguration cfg, ITestOutputHelper output)
@@ -50,6 +51,7 @@ protected BaseFunctionalTestCase(IConfiguration cfg, ITestOutputHelper output)
    this.SimpleVectorDbConfig = cfg.GetSection("KernelMemory:Services:SimpleVectorDb").Get<SimpleVectorDbConfig>() ?? new();
    this.LlamaSharpConfig = cfg.GetSection("KernelMemory:Services:LlamaSharp").Get<LlamaSharpConfig>() ?? new();
    this.ElasticsearchConfig = cfg.GetSection("KernelMemory:Services:Elasticsearch").Get<ElasticsearchConfig>() ?? new();
+    this.OnnxConfig = cfg.GetSection("KernelMemory:Services:Onnx").Get<OnnxConfig>() ?? new();
}
protected IKernelMemory GetMemoryWebClient()
diff --git a/service/tests/TestHelpers/TestHelpers.csproj b/service/tests/TestHelpers/TestHelpers.csproj
index 3996f22f9..b0b991041 100644
--- a/service/tests/TestHelpers/TestHelpers.csproj
+++ b/service/tests/TestHelpers/TestHelpers.csproj
@@ -16,6 +16,7 @@
+    <ProjectReference Include="..\..\..\extensions\ONNX\Onnx\Onnx.csproj" />