Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added initial implementation of Microsoft.KernelMemory.AI.ONNX package #788

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
<PackageVersion Include="Microsoft.Extensions.Http" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Logging" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0.1" />
<PackageVersion Include="Microsoft.Extensions.Logging.TraceSource" Version="8.0.0" />
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI" Version="0.4.0" />
<PackageVersion Include="Microsoft.ML.Tokenizers" Version="0.22.0-preview.24378.1" />
<PackageVersion Include="Microsoft.KernelMemory.Core" Version="0.75.240924.1" />
<PackageVersion Include="Microsoft.KernelMemory.Service.AspNetCore" Version="0.75.240924.1" />
Expand Down
29 changes: 21 additions & 8 deletions KernelMemory.sln
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{7BA7F1B2-19E2-46EB-B000-513EE2F65769}"
ProjectSection(SolutionItems) = preProject
docs\404.html = docs\404.html
docs\azure.md = docs\azure.md
docs\concepts.md = docs\concepts.md
docs\csharp.png = docs\csharp.png
docs\extensions.md = docs\extensions.md
Expand All @@ -31,7 +32,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{7BA7F1B2-1
docs\service.md = docs\service.md
docs\_config.local.yml = docs\_config.local.yml
docs\_config.yml = docs\_config.yml
docs\azure.md = docs\azure.md
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "examples", "examples", "{0A43C65C-6007-4BB4-B3FE-8D439FC91841}"
Expand Down Expand Up @@ -70,18 +70,18 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{6EF76FD8-4
.editorconfig = .editorconfig
.gitattributes = .gitattributes
.gitignore = .gitignore
azure.yaml = azure.yaml
CODE_OF_CONDUCT.md = CODE_OF_CONDUCT.md
COMMUNITY.md = COMMUNITY.md
CONTRIBUTING.md = CONTRIBUTING.md
Directory.Build.props = Directory.Build.props
Directory.Packages.props = Directory.Packages.props
Dockerfile = Dockerfile
KernelMemory.sln.DotSettings = KernelMemory.sln.DotSettings
LICENSE = LICENSE
nuget.config = nuget.config
README.md = README.md
SECURITY.md = SECURITY.md
KernelMemory.sln.DotSettings = KernelMemory.sln.DotSettings
azure.yaml = azure.yaml
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = ".github", ".github", "{B8976338-7CDC-47AE-8502-C2FBAFBEBD68}"
Expand All @@ -94,10 +94,10 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "workflows", "workflows", "{48E79819-1E9E-4075-90DA-BAEC761C89B2}"
ProjectSection(SolutionItems) = preProject
.github\workflows\docker-build-push.yml = .github\workflows\docker-build-push.yml
.github\workflows\dotnet-build.yml = .github\workflows\dotnet-build.yml
.github\workflows\dotnet-unit-tests.yml = .github\workflows\dotnet-unit-tests.yml
.github\workflows\github-pages-jekyll.yml = .github\workflows\github-pages-jekyll.yml
.github\workflows\spell-check-with-typos.yml = .github\workflows\spell-check-with-typos.yml
.github\workflows\dotnet-unit-tests.yml = .github\workflows\dotnet-unit-tests.yml
.github\workflows\dotnet-build.yml = .github\workflows\dotnet-build.yml
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "service", "service", "{87DEAE8D-138C-4FDD-B4C9-11C3A7817E8F}"
Expand All @@ -116,9 +116,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{CA49F1A1
tools\run-qdrant.sh = tools\run-qdrant.sh
tools\run-rabbitmq.sh = tools\run-rabbitmq.sh
tools\run-redis.sh = tools\run-redis.sh
tools\run-s3ninja.sh = tools\run-s3ninja.sh
tools\search.sh = tools\search.sh
tools\upload-file.sh = tools\upload-file.sh
tools\run-s3ninja.sh = tools\run-s3ninja.sh
tools\dockerize-amd64.sh = tools\dockerize-amd64.sh
tools\dockerize-arm64.sh = tools\dockerize-arm64.sh
EndProjectSection
Expand Down Expand Up @@ -243,12 +243,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "109-dotnet-custom-webscrape
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "infra", "infra", "{B488168B-AD86-4CC5-9D89-324B6EB743D9}"
ProjectSection(SolutionItems) = preProject
infra\AZD.md = infra\AZD.md
infra\build-main.json.sh = infra\build-main.json.sh
infra\main.bicep = infra\main.bicep
infra\main.json = infra\main.json
infra\README.md = infra\README.md
infra\AZD.md = infra\AZD.md
infra\main.parameters.json = infra\main.parameters.json
infra\README.md = infra\README.md
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "modules", "modules", "{C2D3A947-B6F9-4306-BD42-21D8D1F42750}"
Expand Down Expand Up @@ -323,6 +323,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "212-dotnet-ollama", "exampl
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Ollama", "extensions\Ollama\Ollama\Ollama.csproj", "{F192513B-265B-4943-A2A9-44E23B15BA18}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Onnx.FunctionalTests", "extensions\ONNX\Onnx.FunctionalTests\Onnx.FunctionalTests.csproj", "{7BBD348E-CDD9-4462-B8C9-47613C5EC682}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Onnx", "extensions\ONNX\Onnx\Onnx.csproj", "{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -591,6 +595,13 @@ Global
{F192513B-265B-4943-A2A9-44E23B15BA18}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F192513B-265B-4943-A2A9-44E23B15BA18}.Release|Any CPU.ActiveCfg = Release|Any CPU
{F192513B-265B-4943-A2A9-44E23B15BA18}.Release|Any CPU.Build.0 = Release|Any CPU
{7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Debug|Any CPU.Build.0 = Debug|Any CPU
{7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Release|Any CPU.ActiveCfg = Release|Any CPU
{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Debug|Any CPU.Build.0 = Debug|Any CPU
{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Release|Any CPU.ActiveCfg = Release|Any CPU
{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -685,6 +696,8 @@ Global
{84AEC1DD-CBAE-400A-949C-91BA373C587D} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841}
{B303885D-F64F-4EEB-B085-0014E863AF61} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841}
{F192513B-265B-4943-A2A9-44E23B15BA18} = {155DA079-E267-49AF-973A-D1D44681970F}
{7BBD348E-CDD9-4462-B8C9-47613C5EC682} = {3C17F42B-CFC8-4900-8CFB-88936311E919}
{345DEF9B-6EE1-49DF-B46A-25E38CE9B151} = {155DA079-E267-49AF-973A-D1D44681970F}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {CC136C62-115C-41D1-B414-F9473EFF6EA8}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<!-- Functional test project for the ONNX extension (xunit + Xunit.DependencyInjection). -->
<!-- NOTE(review): the Web SDK is used although this is a test project; presumably for its implicit usings
     (the test sources use IConfiguration / IHostBuilder without using directives). Confirm. -->

<PropertyGroup>
<!-- NOTE(review): Startup.cs states that its namespace "must match exactly" for Xunit.DependencyInjection;
     presumably it has to stay in sync with the names below. Confirm before renaming. -->
<AssemblyName>Microsoft.Onnx.FunctionalTests</AssemblyName>
<RootNamespace>Microsoft.Onnx.FunctionalTests</RootNamespace>
<TargetFramework>net8.0</TargetFramework>
<RollForward>LatestMajor</RollForward>
<IsTestProject>true</IsTestProject>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<!-- Test assembly: never published as a NuGet package. -->
<IsPackable>false</IsPackable>
<!-- Suppresses diagnostic KMEXP01 for this project. NOTE(review): presumably the Kernel Memory
     "experimental API" warning; confirm against the analyzer definition. -->
<NoWarn>$(NoWarn);KMEXP01;</NoWarn>
</PropertyGroup>

<ItemGroup>
<!-- Shared test helpers; the test sources import Microsoft.KM.TestHelpers (BaseFunctionalTestCase).
     NOTE(review): the ONNX extension itself is not referenced directly here, so it presumably
     arrives transitively through this project. Confirm. -->
<ProjectReference Include="..\..\..\service\tests\TestHelpers\TestHelpers.csproj" />
</ItemGroup>

<ItemGroup>
<!-- No Version attributes here: versions are presumably supplied centrally by Directory.Packages.props. -->
<PackageReference Include="Microsoft.Extensions.DependencyInjection" />
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="Xunit.DependencyInjection" />
<PackageReference Include="Xunit.DependencyInjection.Logging" />
<PackageReference Include="xunit" />
<PackageReference Include="xunit.abstractions" />
<PackageReference Include="xunit.runner.visualstudio">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="coverlet.collector">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
</ItemGroup>

</Project>
67 changes: 67 additions & 0 deletions extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics;
using System.Text;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.AI.Onnx;
using Microsoft.KM.TestHelpers;
using Xunit.Abstractions;

namespace Microsoft.Onnx.FunctionalTests;

/// <summary>
/// Functional test for <see cref="OnnxTextGenerator"/>: loads the model from the directory
/// configured in <c>OnnxConfig.TextModelDir</c> and checks that a prompt containing today's
/// date produces an answer that repeats that date.
/// </summary>
public sealed class OnnxTextGeneratorTest : BaseFunctionalTestCase
{
    /// <summary>Date format announced to the model in the system prompt and used for the assertion.</summary>
    private const string DateFormat = "MM/dd/yyyy";

    private readonly OnnxTextGenerator _target;
    private readonly Stopwatch _timer;

    /// <summary>
    /// Validates the ONNX configuration, logs which model file is going to be used
    /// and creates the generator under test.
    /// </summary>
    /// <param name="cfg">Test configuration, forwarded to the base test case.</param>
    /// <param name="output">xunit output helper, forwarded to the base test case.</param>
    public OnnxTextGeneratorTest(
        IConfiguration cfg,
        ITestOutputHelper output) : base(cfg, output)
    {
        this._timer = new Stopwatch();

        this.OnnxConfig.Validate();

        // Inspect the model directory BEFORE creating the generator: Directory.GetFiles can throw,
        // and an exception thrown after the generator exists would leave it undisposed.
        var modelDirectory = Path.GetFullPath(this.OnnxConfig.TextModelDir);
        var modelFile = Directory.GetFiles(modelDirectory)
            .FirstOrDefault(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase));

        // modelFile is null when the directory holds no *.onnx file; the name is then printed as empty.
        Console.WriteLine($"Using model {Path.GetFileNameWithoutExtension(modelFile)} from: {modelDirectory}");

        this._target = new OnnxTextGenerator(this.OnnxConfig, loggerFactory: null);
    }

    /// <summary>
    /// Asks the model for the current date (provided in the system prompt) and expects
    /// the generated answer to contain it.
    /// </summary>
    [Fact]
    [Trait("Category", "Onnx")]
    public async Task ItGeneratesText()
    {
        // Arrange
        // Invariant culture is required: with the current culture the "/" specifier is replaced by
        // the culture's date separator (e.g. "." or "-"), contradicting the format named in the prompt.
        var utcDate = DateTime.UtcNow.Date.ToString(DateFormat, System.Globalization.CultureInfo.InvariantCulture);
        var systemPrompt = $"Following the format \"MM/dd/yyyy\", the current date is {utcDate}.";
        var question = "What is the current date?";
        var prompt = $"<|system|>{systemPrompt}<|end|><|user|>{question}<|end|><|assistant|>";

        var options = new TextGenerationOptions();

        // Act
        this._timer.Restart();
        var tokens = this._target.GenerateTextAsync(prompt, options);
        var result = new StringBuilder();
        await foreach (string token in tokens)
        {
            result.Append(token);
        }

        this._timer.Stop();
        var answer = result.ToString();

        // Assert
        Console.WriteLine($"Model Output:\n=============================\n{answer}\n=============================");
        Console.WriteLine($"Time: {this._timer.ElapsedMilliseconds / 1000} secs");
        Assert.Contains(utcDate, answer, StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>
    /// Releases the base test case resources and, when called from <c>Dispose()</c>,
    /// the generator under test.
    /// </summary>
    /// <param name="disposing">True when called from <c>Dispose()</c>, false when called from a finalizer.</param>
    protected override void Dispose(bool disposing)
    {
        base.Dispose(disposing);

        // Managed objects must only be touched on the explicit Dispose() path.
        if (disposing)
        {
            this._target.Dispose();
        }
    }
}
22 changes: 22 additions & 0 deletions extensions/ONNX/Onnx.FunctionalTests/Startup.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft. All rights reserved.

/* IMPORTANT: the Startup class must be at the root of the namespace and
* the namespace must match exactly (required by Xunit.DependencyInjection) */

namespace Microsoft.Onnx.FunctionalTests;

/// <summary>
/// Startup hook for Xunit.DependencyInjection: supplies the host configuration used by the tests.
/// </summary>
public class Startup
{
    /// <summary>
    /// Builds the test configuration from JSON files, user secrets and environment variables
    /// (registered in that order) and adds it to the host configuration.
    /// </summary>
    /// <param name="hostBuilder">Host builder receiving the configuration.</param>
    public void ConfigureHost(IHostBuilder hostBuilder)
    {
        var sources = new ConfigurationBuilder();

        // Mandatory base settings.
        sources.AddJsonFile("appsettings.json");

        // Optional developer overrides; both casings are listed, presumably for case-sensitive file systems.
        sources.AddJsonFile("appsettings.development.json", optional: true);
        sources.AddJsonFile("appsettings.Development.json", optional: true);

        sources.AddUserSecrets<Startup>();
        sources.AddEnvironmentVariables();

        var testConfiguration = sources.Build();

        hostBuilder.ConfigureHostConfiguration(hostConfig => hostConfig.AddConfiguration(testConfiguration));
    }
}
3 changes: 3 additions & 0 deletions extensions/ONNX/Onnx.FunctionalTests/Usings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Copyright (c) Microsoft. All rights reserved.

global using Xunit;
67 changes: 67 additions & 0 deletions extensions/ONNX/Onnx.FunctionalTests/appsettings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
{
"KernelMemory": {
"ServiceAuthorization": {
"Endpoint": "http://127.0.0.1:9001/",
"AccessKey": ""
},
"Services": {
"Onnx": {
// Path to the directory containing the ONNX model files, e.g. "C:\\....\\Phi-3-mini-128k-instruct-onnx\\....\\cpu-int4-rtn-block-32"
"TextModelDir": "Z:\\tools\\LocalModels\\Phi-3-mini-128k-instruct-onnx\\cpu_and_mobile\\cpu-int4-rtn-block-32"
},
"SimpleVectorDb": {
// Options: "Disk" or "Volatile". Volatile data is lost after each execution.
"StorageType": "Volatile",
// Directory where files are stored.
"Directory": "_vectors"
},
"AzureAISearch": {
// "ApiKey" or "AzureIdentity". For other options see <AzureAISearchConfig>.
// AzureIdentity: use automatic AAD authentication mechanism. You can test locally
// using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET.
"Auth": "AzureIdentity",
"Endpoint": "https://<...>",
"APIKey": ""
},
"Postgres": {
// Postgres instance connection string
"ConnectionString": "Host=localhost;Port=5432;Username=public;Password=;Database=public",
// Mandatory prefix to add to the name of table managed by KM,
// e.g. to exclude other tables in the same schema.
"TableNamePrefix": "tests-"
},
"Qdrant": {
"Endpoint": "http://127.0.0.1:6333",
"APIKey": ""
},
"OpenAI": {
// Name of the model used to generate text (text completion or chat completion)
"TextModel": "gpt-4o-mini",
// The max number of tokens supported by the text model.
"TextModelMaxTokenTotal": 16384,
// What type of text generation, by default autodetect using the model name.
// Possible values: "Auto", "TextCompletion", "Chat"
"TextGenerationType": "Auto",
// Name of the model used to generate text embeddings
"EmbeddingModel": "text-embedding-ada-002",
// The max number of tokens supported by the embedding model
// See https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
"EmbeddingModelMaxTokenTotal": 8191,
// OpenAI API Key
"APIKey": "",
// OpenAI Organization ID (usually empty, unless you have multiple accounts on different orgs)
"OrgId": "",
// Endpoint to use. By default the system uses 'https://api.openai.com/v1'.
// Change this to use proxies or services compatible with OpenAI HTTP protocol like LM Studio.
"Endpoint": "",
// How many times to retry in case of throttling
"MaxRetries": 10
}
}
},
"Logging": {
"LogLevel": {
"Default": "Information"
}
}
}
60 changes: 60 additions & 0 deletions extensions/ONNX/Onnx/DependencyInjection.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.AI.Onnx;

#pragma warning disable IDE0130 // reduce number of "using" statements
// ReSharper disable once CheckNamespace - reduce number of "using" statements
namespace Microsoft.KernelMemory;

/// <summary>
/// Kernel Memory builder extensions
/// </summary>
public static partial class KernelMemoryBuilderExtensions
{
    /// <summary>
    /// Registers the ONNX text generator using a model directory path.
    /// </summary>
    /// <param name="builder">Kernel Memory builder.</param>
    /// <param name="modelPath">Directory containing the ONNX model; stored in <c>OnnxConfig.TextModelDir</c>.</param>
    /// <param name="maxTokenTotal">
    /// NOTE(review): this value is accepted but never applied to the configuration built here;
    /// confirm whether it should be mapped to an <see cref="OnnxConfig"/> setting.
    /// </param>
    /// <param name="textTokenizer">Optional tokenizer handed to the generator.</param>
    /// <returns>The same builder, to allow call chaining.</returns>
    public static IKernelMemoryBuilder WithOnnxTextGeneration(
        this IKernelMemoryBuilder builder,
        string modelPath,
        uint maxTokenTotal,
        ITextTokenizer? textTokenizer = null)
    {
        var onnxConfig = new OnnxConfig { TextModelDir = modelPath };

        // Same registration path as the config-based overload.
        return builder.WithOnnxTextGeneration(onnxConfig, textTokenizer);
    }

    /// <summary>
    /// Registers the ONNX text generator using an explicit configuration.
    /// </summary>
    /// <param name="builder">Kernel Memory builder.</param>
    /// <param name="config">ONNX settings; validated during registration.</param>
    /// <param name="textTokenizer">Optional tokenizer handed to the generator.</param>
    /// <returns>The same builder, to allow call chaining.</returns>
    public static IKernelMemoryBuilder WithOnnxTextGeneration(
        this IKernelMemoryBuilder builder,
        OnnxConfig config,
        ITextTokenizer? textTokenizer = null)
    {
        var services = builder.Services;
        services.AddOnnxTextGeneration(config, textTokenizer);

        return builder;
    }
}

/// <summary>
/// .NET IServiceCollection dependency injection extensions.
/// </summary>
public static partial class DependencyInjection
{
    /// <summary>
    /// Validates <paramref name="config"/> and registers <see cref="OnnxTextGenerator"/>
    /// as the singleton <see cref="ITextGenerator"/>.
    /// </summary>
    /// <param name="services">Service collection to add the registration to.</param>
    /// <param name="config">ONNX settings; <c>Validate()</c> is called before registering.</param>
    /// <param name="textTokenizer">Optional tokenizer handed to the generator.</param>
    /// <returns>The service collection, to allow call chaining.</returns>
    public static IServiceCollection AddOnnxTextGeneration(
        this IServiceCollection services,
        OnnxConfig config,
        ITextTokenizer? textTokenizer = null)
    {
        // Validate at registration time, before the factory below is ever invoked.
        config.Validate();

        return services.AddSingleton<ITextGenerator, OnnxTextGenerator>(serviceProvider =>
        {
            // GetService returns null when no logger factory has been registered.
            var loggerFactory = serviceProvider.GetService<ILoggerFactory>();

            return new OnnxTextGenerator(
                config: config,
                textTokenizer: textTokenizer,
                loggerFactory: loggerFactory);
        });
    }
}
Loading
Loading