Skip to content

Commit

Permalink
Added initial implementation of Microsoft.KernelMemory.AI.ONNX package (
Browse files Browse the repository at this point in the history
#788)

## Motivation and Context (Why the change? What's the scenario?)

This change includes an entirely new extension that allows
TextGeneration with ONNX Models. This scenario allows users to interact
with local cloned ONNX Models

## High level description (Approach, Design)

This is a scenario that abstracts the general implementation of the
OnnxRuntimeGenAI library to work with text generation. More examples on
using OnnxGenAI Libraries, please see
https://github.com/microsoft/onnxruntime-genai

---------

Co-authored-by: Justin Ridings <[email protected]>
Co-authored-by: Devis Lucato <[email protected]>
  • Loading branch information
3 people authored Oct 5, 2024
1 parent 6ad1d81 commit 59e9f5d
Show file tree
Hide file tree
Showing 14 changed files with 698 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
<PackageVersion Include="Microsoft.Extensions.Http" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Logging" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0.1" />
<PackageVersion Include="Microsoft.Extensions.Logging.TraceSource" Version="8.0.0" />
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI" Version="0.4.0" />
<PackageVersion Include="Microsoft.ML.Tokenizers" Version="0.22.0-preview.24378.1" />
<PackageVersion Include="Microsoft.KernelMemory.Core" Version="0.75.240924.1" />
<PackageVersion Include="Microsoft.KernelMemory.Service.AspNetCore" Version="0.75.240924.1" />
Expand Down
29 changes: 21 additions & 8 deletions KernelMemory.sln
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{7BA7F1B2-19E2-46EB-B000-513EE2F65769}"
ProjectSection(SolutionItems) = preProject
docs\404.html = docs\404.html
docs\azure.md = docs\azure.md
docs\concepts.md = docs\concepts.md
docs\csharp.png = docs\csharp.png
docs\extensions.md = docs\extensions.md
Expand All @@ -31,7 +32,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{7BA7F1B2-1
docs\service.md = docs\service.md
docs\_config.local.yml = docs\_config.local.yml
docs\_config.yml = docs\_config.yml
docs\azure.md = docs\azure.md
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "examples", "examples", "{0A43C65C-6007-4BB4-B3FE-8D439FC91841}"
Expand Down Expand Up @@ -70,18 +70,18 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "root", "root", "{6EF76FD8-4
.editorconfig = .editorconfig
.gitattributes = .gitattributes
.gitignore = .gitignore
azure.yaml = azure.yaml
CODE_OF_CONDUCT.md = CODE_OF_CONDUCT.md
COMMUNITY.md = COMMUNITY.md
CONTRIBUTING.md = CONTRIBUTING.md
Directory.Build.props = Directory.Build.props
Directory.Packages.props = Directory.Packages.props
Dockerfile = Dockerfile
KernelMemory.sln.DotSettings = KernelMemory.sln.DotSettings
LICENSE = LICENSE
nuget.config = nuget.config
README.md = README.md
SECURITY.md = SECURITY.md
KernelMemory.sln.DotSettings = KernelMemory.sln.DotSettings
azure.yaml = azure.yaml
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = ".github", ".github", "{B8976338-7CDC-47AE-8502-C2FBAFBEBD68}"
Expand All @@ -94,10 +94,10 @@ EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "workflows", "workflows", "{48E79819-1E9E-4075-90DA-BAEC761C89B2}"
ProjectSection(SolutionItems) = preProject
.github\workflows\docker-build-push.yml = .github\workflows\docker-build-push.yml
.github\workflows\dotnet-build.yml = .github\workflows\dotnet-build.yml
.github\workflows\dotnet-unit-tests.yml = .github\workflows\dotnet-unit-tests.yml
.github\workflows\github-pages-jekyll.yml = .github\workflows\github-pages-jekyll.yml
.github\workflows\spell-check-with-typos.yml = .github\workflows\spell-check-with-typos.yml
.github\workflows\dotnet-unit-tests.yml = .github\workflows\dotnet-unit-tests.yml
.github\workflows\dotnet-build.yml = .github\workflows\dotnet-build.yml
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "service", "service", "{87DEAE8D-138C-4FDD-B4C9-11C3A7817E8F}"
Expand All @@ -116,9 +116,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools", "tools", "{CA49F1A1
tools\run-qdrant.sh = tools\run-qdrant.sh
tools\run-rabbitmq.sh = tools\run-rabbitmq.sh
tools\run-redis.sh = tools\run-redis.sh
tools\run-s3ninja.sh = tools\run-s3ninja.sh
tools\search.sh = tools\search.sh
tools\upload-file.sh = tools\upload-file.sh
tools\run-s3ninja.sh = tools\run-s3ninja.sh
tools\dockerize-amd64.sh = tools\dockerize-amd64.sh
tools\dockerize-arm64.sh = tools\dockerize-arm64.sh
EndProjectSection
Expand Down Expand Up @@ -243,12 +243,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "109-dotnet-custom-webscrape
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "infra", "infra", "{B488168B-AD86-4CC5-9D89-324B6EB743D9}"
ProjectSection(SolutionItems) = preProject
infra\AZD.md = infra\AZD.md
infra\build-main.json.sh = infra\build-main.json.sh
infra\main.bicep = infra\main.bicep
infra\main.json = infra\main.json
infra\README.md = infra\README.md
infra\AZD.md = infra\AZD.md
infra\main.parameters.json = infra\main.parameters.json
infra\README.md = infra\README.md
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "modules", "modules", "{C2D3A947-B6F9-4306-BD42-21D8D1F42750}"
Expand Down Expand Up @@ -323,6 +323,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "212-dotnet-ollama", "exampl
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Ollama", "extensions\Ollama\Ollama\Ollama.csproj", "{F192513B-265B-4943-A2A9-44E23B15BA18}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Onnx.FunctionalTests", "extensions\ONNX\Onnx.FunctionalTests\Onnx.FunctionalTests.csproj", "{7BBD348E-CDD9-4462-B8C9-47613C5EC682}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Onnx", "extensions\ONNX\Onnx\Onnx.csproj", "{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -591,6 +595,13 @@ Global
{F192513B-265B-4943-A2A9-44E23B15BA18}.Debug|Any CPU.Build.0 = Debug|Any CPU
{F192513B-265B-4943-A2A9-44E23B15BA18}.Release|Any CPU.ActiveCfg = Release|Any CPU
{F192513B-265B-4943-A2A9-44E23B15BA18}.Release|Any CPU.Build.0 = Release|Any CPU
{7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Debug|Any CPU.Build.0 = Debug|Any CPU
{7BBD348E-CDD9-4462-B8C9-47613C5EC682}.Release|Any CPU.ActiveCfg = Release|Any CPU
{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Debug|Any CPU.Build.0 = Debug|Any CPU
{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Release|Any CPU.ActiveCfg = Release|Any CPU
{345DEF9B-6EE1-49DF-B46A-25E38CE9B151}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -685,6 +696,8 @@ Global
{84AEC1DD-CBAE-400A-949C-91BA373C587D} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841}
{B303885D-F64F-4EEB-B085-0014E863AF61} = {0A43C65C-6007-4BB4-B3FE-8D439FC91841}
{F192513B-265B-4943-A2A9-44E23B15BA18} = {155DA079-E267-49AF-973A-D1D44681970F}
{7BBD348E-CDD9-4462-B8C9-47613C5EC682} = {3C17F42B-CFC8-4900-8CFB-88936311E919}
{345DEF9B-6EE1-49DF-B46A-25E38CE9B151} = {155DA079-E267-49AF-973A-D1D44681970F}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {CC136C62-115C-41D1-B414-F9473EFF6EA8}
Expand Down
36 changes: 36 additions & 0 deletions extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
<Project Sdk="Microsoft.NET.Sdk.Web">

<PropertyGroup>
<AssemblyName>Microsoft.Onnx.FunctionalTests</AssemblyName>
<RootNamespace>Microsoft.Onnx.FunctionalTests</RootNamespace>
<TargetFramework>net8.0</TargetFramework>
<RollForward>LatestMajor</RollForward>
<IsTestProject>true</IsTestProject>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
<NoWarn>$(NoWarn);KMEXP01;</NoWarn>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\service\tests\TestHelpers\TestHelpers.csproj" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.DependencyInjection" />
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="Xunit.DependencyInjection" />
<PackageReference Include="Xunit.DependencyInjection.Logging" />
<PackageReference Include="xunit" />
<PackageReference Include="xunit.abstractions" />
<PackageReference Include="xunit.runner.visualstudio">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="coverlet.collector">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
</ItemGroup>

</Project>
67 changes: 67 additions & 0 deletions extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics;
using System.Text;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.AI.Onnx;
using Microsoft.KM.TestHelpers;
using Xunit.Abstractions;

namespace Microsoft.Onnx.FunctionalTests;

public sealed class OnnxTextGeneratorTest : BaseFunctionalTestCase
{
private readonly OnnxTextGenerator _target;
private readonly Stopwatch _timer;

public OnnxTextGeneratorTest(
IConfiguration cfg,
ITestOutputHelper output) : base(cfg, output)
{
this._timer = new Stopwatch();

this.OnnxConfig.Validate();
this._target = new OnnxTextGenerator(this.OnnxConfig, loggerFactory: null);

var modelDirectory = Path.GetFullPath(this.OnnxConfig.TextModelDir);
var modelFile = Directory.GetFiles(modelDirectory)
.FirstOrDefault(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase));

Console.WriteLine($"Using model {Path.GetFileNameWithoutExtension(modelFile)} from: {modelDirectory}");
}

[Fact]
[Trait("Category", "Onnx")]
public async Task ItGeneratesText()
{
var utcDate = DateTime.UtcNow.Date.ToString("MM/dd/yyyy");
var systemPrompt = $"Following the format \"MM/dd/yyyy\", the current date is {utcDate}.";
var question = $"What is the current date?";
var prompt = $"<|system|>{systemPrompt}<|end|><|user|>{question}<|end|><|assistant|>";

var options = new TextGenerationOptions();

// Act
this._timer.Restart();
var tokens = this._target.GenerateTextAsync(prompt, options);
var result = new StringBuilder();
await foreach (string token in tokens)
{
result.Append(token);
}

this._timer.Stop();
var answer = result.ToString();

// Assert
Console.WriteLine($"Model Output:\n=============================\n{answer}\n=============================");
Console.WriteLine($"Time: {this._timer.ElapsedMilliseconds / 1000} secs");
Assert.Contains(utcDate.ToString(), answer, StringComparison.OrdinalIgnoreCase);
}

protected override void Dispose(bool disposing)
{
base.Dispose(disposing);
this._target.Dispose();
}
}
22 changes: 22 additions & 0 deletions extensions/ONNX/Onnx.FunctionalTests/Startup.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft. All rights reserved.

/* IMPORTANT: the Startup class must be at the root of the namespace and
* the namespace must match exactly (required by Xunit.DependencyInjection) */

namespace Microsoft.Onnx.FunctionalTests;

public class Startup
{
public void ConfigureHost(IHostBuilder hostBuilder)
{
var config = new ConfigurationBuilder()
.AddJsonFile("appsettings.json")
.AddJsonFile("appsettings.development.json", optional: true)
.AddJsonFile("appsettings.Development.json", optional: true)
.AddUserSecrets<Startup>()
.AddEnvironmentVariables()
.Build();

hostBuilder.ConfigureHostConfiguration(builder => builder.AddConfiguration(config));
}
}
3 changes: 3 additions & 0 deletions extensions/ONNX/Onnx.FunctionalTests/Usings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Copyright (c) Microsoft. All rights reserved.

global using Xunit;
67 changes: 67 additions & 0 deletions extensions/ONNX/Onnx.FunctionalTests/appsettings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
{
"KernelMemory": {
"ServiceAuthorization": {
"Endpoint": "http://127.0.0.1:9001/",
"AccessKey": ""
},
"Services": {
"Onnx": {
// Path to directory containing ONNX Model, e.g. "C:\\....\\Phi-3-mini-128k-instruct-onnx\\....\\cpu-int4-rtn-block-32"
"TextModelDir": "Z:\\tools\\LocalModels\\Phi-3-mini-128k-instruct-onnx\\cpu_and_mobile\\cpu-int4-rtn-block-32"
},
"SimpleVectorDb": {
// Options: "Disk" or "Volatile". Volatile data is lost after each execution.
"StorageType": "Volatile",
// Directory where files are stored.
"Directory": "_vectors"
},
"AzureAISearch": {
// "ApiKey" or "AzureIdentity". For other options see <AzureAISearchConfig>.
// AzureIdentity: use automatic AAD authentication mechanism. You can test locally
// using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET.
"Auth": "AzureIdentity",
"Endpoint": "https://<...>",
"APIKey": ""
},
"Postgres": {
// Postgres instance connection string
"ConnectionString": "Host=localhost;Port=5432;Username=public;Password=;Database=public",
// Mandatory prefix to add to the name of table managed by KM,
// e.g. to exclude other tables in the same schema.
"TableNamePrefix": "tests-"
},
"Qdrant": {
"Endpoint": "http://127.0.0.1:6333",
"APIKey": ""
},
"OpenAI": {
// Name of the model used to generate text (text completion or chat completion)
"TextModel": "gpt-4o-mini",
// The max number of tokens supported by the text model.
"TextModelMaxTokenTotal": 16384,
// What type of text generation, by default autodetect using the model name.
// Possible values: "Auto", "TextCompletion", "Chat"
"TextGenerationType": "Auto",
// Name of the model used to generate text embeddings
"EmbeddingModel": "text-embedding-ada-002",
// The max number of tokens supported by the embedding model
// See https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
"EmbeddingModelMaxTokenTotal": 8191,
// OpenAI API Key
"APIKey": "",
// OpenAI Organization ID (usually empty, unless you have multiple accounts on different orgs)
"OrgId": "",
// Endpoint to use. By default the system uses 'https://api.openai.com/v1'.
// Change this to use proxies or services compatible with OpenAI HTTP protocol like LM Studio.
"Endpoint": "",
// How many times to retry in case of throttling
"MaxRetries": 10
}
}
},
"Logging": {
"LogLevel": {
"Default": "Information"
}
}
}
60 changes: 60 additions & 0 deletions extensions/ONNX/Onnx/DependencyInjection.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.AI.Onnx;

#pragma warning disable IDE0130 // reduce number of "using" statements
// ReSharper disable once CheckNamespace - reduce number of "using" statements
namespace Microsoft.KernelMemory;

/// <summary>
/// Kernel Memory builder extensions
/// </summary>
public static partial class KernelMemoryBuilderExtensions
{
public static IKernelMemoryBuilder WithOnnxTextGeneration(
this IKernelMemoryBuilder builder,
string modelPath,
uint maxTokenTotal,
ITextTokenizer? textTokenizer = null)
{
var config = new OnnxConfig
{
TextModelDir = modelPath
};

builder.Services.AddOnnxTextGeneration(config, textTokenizer);

return builder;
}

public static IKernelMemoryBuilder WithOnnxTextGeneration(
this IKernelMemoryBuilder builder,
OnnxConfig config,
ITextTokenizer? textTokenizer = null)
{
builder.Services.AddOnnxTextGeneration(config, textTokenizer);
return builder;
}
}

/// <summary>
/// .NET IServiceCollection dependency injection extensions.
/// </summary>
public static partial class DependencyInjection
{
public static IServiceCollection AddOnnxTextGeneration(
this IServiceCollection services,
OnnxConfig config,
ITextTokenizer? textTokenizer = null)
{
config.Validate();
return services
.AddSingleton<ITextGenerator, OnnxTextGenerator>(serviceProvider => new OnnxTextGenerator(
config: config,
textTokenizer: textTokenizer,
loggerFactory: serviceProvider.GetService<ILoggerFactory>()));
}
}
Loading

0 comments on commit 59e9f5d

Please sign in to comment.