-
Notifications
You must be signed in to change notification settings - Fork 294
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added initial implementation of Microsoft.KernelMemory.AI.ONNX package (#788)

## Motivation and Context (Why the change? What's the scenario?)

This change adds an entirely new extension that enables text generation with ONNX models. This scenario allows users to interact with locally cloned ONNX models.

## High level description (Approach, Design)

This extension abstracts the general usage of the OnnxRuntimeGenAI library for text generation. For more examples of using the OnnxGenAI libraries, see https://github.com/microsoft/onnxruntime-genai

---------

Co-authored-by: Justin Ridings <[email protected]>
Co-authored-by: Devis Lucato <[email protected]>
- Loading branch information
1 parent
6ad1d81
commit 59e9f5d
Showing
14 changed files
with
698 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
extensions/ONNX/Onnx.FunctionalTests/Onnx.FunctionalTests.csproj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
<Project Sdk="Microsoft.NET.Sdk.Web"> | ||
|
||
<PropertyGroup> | ||
<AssemblyName>Microsoft.Onnx.FunctionalTests</AssemblyName> | ||
<RootNamespace>Microsoft.Onnx.FunctionalTests</RootNamespace> | ||
<TargetFramework>net8.0</TargetFramework> | ||
<RollForward>LatestMajor</RollForward> | ||
<IsTestProject>true</IsTestProject> | ||
<ImplicitUsings>enable</ImplicitUsings> | ||
<Nullable>enable</Nullable> | ||
<IsPackable>false</IsPackable> | ||
<NoWarn>$(NoWarn);KMEXP01;</NoWarn> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\..\..\service\tests\TestHelpers\TestHelpers.csproj" /> | ||
</ItemGroup> | ||
|
||
<ItemGroup> | ||
<PackageReference Include="Microsoft.Extensions.DependencyInjection" /> | ||
<PackageReference Include="Microsoft.NET.Test.Sdk" /> | ||
<PackageReference Include="Xunit.DependencyInjection" /> | ||
<PackageReference Include="Xunit.DependencyInjection.Logging" /> | ||
<PackageReference Include="xunit" /> | ||
<PackageReference Include="xunit.abstractions" /> | ||
<PackageReference Include="xunit.runner.visualstudio"> | ||
<PrivateAssets>all</PrivateAssets> | ||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||
</PackageReference> | ||
<PackageReference Include="coverlet.collector"> | ||
<PrivateAssets>all</PrivateAssets> | ||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | ||
</PackageReference> | ||
</ItemGroup> | ||
|
||
</Project> |
67 changes: 67 additions & 0 deletions
67
extensions/ONNX/Onnx.FunctionalTests/OnnxTextGeneratorTest.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System.Diagnostics; | ||
using System.Text; | ||
using Microsoft.KernelMemory.AI; | ||
using Microsoft.KernelMemory.AI.Onnx; | ||
using Microsoft.KM.TestHelpers; | ||
using Xunit.Abstractions; | ||
|
||
namespace Microsoft.Onnx.FunctionalTests; | ||
|
||
/// <summary>
/// Functional test for <see cref="OnnxTextGenerator"/>: loads the model found in the
/// configured ONNX model directory and checks that generated text echoes a date
/// supplied through the system prompt.
/// </summary>
public sealed class OnnxTextGeneratorTest : BaseFunctionalTestCase
{
    private readonly OnnxTextGenerator _target;
    private readonly Stopwatch _timer;

    /// <summary>
    /// Validates the ONNX configuration, reports which model is in use and creates the generator under test.
    /// </summary>
    /// <param name="cfg">Test configuration, forwarded to the base test case.</param>
    /// <param name="output">xUnit output helper, forwarded to the base test case.</param>
    public OnnxTextGeneratorTest(
        IConfiguration cfg,
        ITestOutputHelper output) : base(cfg, output)
    {
        this._timer = new Stopwatch();

        this.OnnxConfig.Validate();

        // Probe the model directory before creating the generator: Directory.GetFiles throws
        // when the directory is missing, and a constructor failure after the disposable
        // generator exists would leave it undisposed.
        var modelDirectory = Path.GetFullPath(this.OnnxConfig.TextModelDir);
        var modelFile = Directory.GetFiles(modelDirectory)
            .FirstOrDefault(file => string.Equals(Path.GetExtension(file), ".ONNX", StringComparison.OrdinalIgnoreCase));

        // modelFile is null when the directory holds no *.onnx file; the name is then printed as empty.
        Console.WriteLine($"Using model {Path.GetFileNameWithoutExtension(modelFile)} from: {modelDirectory}");

        this._target = new OnnxTextGenerator(this.OnnxConfig, loggerFactory: null);
    }

    /// <summary>
    /// Streams a completion for a prompt that states today's UTC date and asserts the answer contains it.
    /// </summary>
    [Fact]
    [Trait("Category", "Onnx")]
    public async Task ItGeneratesText()
    {
        // Arrange
        // "/" in a custom format string is the culture's date separator, so the invariant culture
        // is required for the text to really match the "MM/dd/yyyy" format named in the prompt.
        var utcDate = DateTime.UtcNow.Date.ToString("MM/dd/yyyy", System.Globalization.CultureInfo.InvariantCulture);
        var systemPrompt = $"Following the format \"MM/dd/yyyy\", the current date is {utcDate}.";
        var question = $"What is the current date?";
        var prompt = $"<|system|>{systemPrompt}<|end|><|user|>{question}<|end|><|assistant|>";

        var options = new TextGenerationOptions();

        // Act
        this._timer.Restart();
        var tokens = this._target.GenerateTextAsync(prompt, options);
        var result = new StringBuilder();
        await foreach (string token in tokens)
        {
            result.Append(token);
        }

        this._timer.Stop();
        var answer = result.ToString();

        // Assert
        Console.WriteLine($"Model Output:\n=============================\n{answer}\n=============================");
        Console.WriteLine($"Time: {this._timer.ElapsedMilliseconds / 1000} secs");
        Assert.Contains(utcDate, answer, StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>
    /// Releases the generator under test, then lets the base test case clean up.
    /// </summary>
    /// <param name="disposing">True when called from Dispose(), false when called from a finalizer.</param>
    protected override void Dispose(bool disposing)
    {
        // Managed disposables are only touched on the explicit Dispose() path,
        // and the derived class releases its own resources before the base does.
        if (disposing)
        {
            this._target.Dispose();
        }

        base.Dispose(disposing);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
/* IMPORTANT: the Startup class must be at the root of the namespace and | ||
* the namespace must match exactly (required by Xunit.DependencyInjection) */ | ||
|
||
namespace Microsoft.Onnx.FunctionalTests; | ||
|
||
/// <summary>
/// Host bootstrap class for the functional tests; per the note at the top of this file,
/// Xunit.DependencyInjection requires it at the root of the test namespace.
/// </summary>
public class Startup
{
    /// <summary>
    /// Assembles the test configuration and adds it to the host configuration.
    /// </summary>
    /// <param name="hostBuilder">Host builder receiving the configuration.</param>
    public void ConfigureHost(IHostBuilder hostBuilder)
    {
        var sources = new ConfigurationBuilder();

        // Sources are registered in this order: base settings file, optional development
        // overrides, user secrets, environment variables.
        sources.AddJsonFile("appsettings.json");

        // Both casings of the development file are registered as optional
        // (presumably to cope with case-sensitive file systems - TODO confirm).
        sources.AddJsonFile("appsettings.development.json", optional: true);
        sources.AddJsonFile("appsettings.Development.json", optional: true);

        sources.AddUserSecrets<Startup>();
        sources.AddEnvironmentVariables();

        var settings = sources.Build();

        hostBuilder.ConfigureHostConfiguration(hostConfig => hostConfig.AddConfiguration(settings));
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
global using Xunit; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
{ | ||
"KernelMemory": { | ||
"ServiceAuthorization": { | ||
"Endpoint": "http://127.0.0.1:9001/", | ||
"AccessKey": "" | ||
}, | ||
"Services": { | ||
"Onnx": { | ||
// Path to directory containing ONNX Model, e.g. "C:\\....\\Phi-3-mini-128k-instruct-onnx\\....\\cpu-int4-rtn-block-32" | ||
"TextModelDir": "Z:\\tools\\LocalModels\\Phi-3-mini-128k-instruct-onnx\\cpu_and_mobile\\cpu-int4-rtn-block-32" | ||
}, | ||
"SimpleVectorDb": { | ||
// Options: "Disk" or "Volatile". Volatile data is lost after each execution. | ||
"StorageType": "Volatile", | ||
// Directory where files are stored. | ||
"Directory": "_vectors" | ||
}, | ||
"AzureAISearch": { | ||
// "ApiKey" or "AzureIdentity". For other options see <AzureAISearchConfig>. | ||
// AzureIdentity: use automatic AAD authentication mechanism. You can test locally | ||
// using the env vars AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET. | ||
"Auth": "AzureIdentity", | ||
"Endpoint": "https://<...>", | ||
"APIKey": "" | ||
}, | ||
"Postgres": { | ||
// Postgres instance connection string | ||
"ConnectionString": "Host=localhost;Port=5432;Username=public;Password=;Database=public", | ||
// Mandatory prefix to add to the name of table managed by KM, | ||
// e.g. to exclude other tables in the same schema. | ||
"TableNamePrefix": "tests-" | ||
}, | ||
"Qdrant": { | ||
"Endpoint": "http://127.0.0.1:6333", | ||
"APIKey": "" | ||
}, | ||
"OpenAI": { | ||
// Name of the model used to generate text (text completion or chat completion) | ||
"TextModel": "gpt-4o-mini", | ||
// The max number of tokens supported by the text model. | ||
"TextModelMaxTokenTotal": 16384, | ||
// What type of text generation, by default autodetect using the model name. | ||
// Possible values: "Auto", "TextCompletion", "Chat" | ||
"TextGenerationType": "Auto", | ||
// Name of the model used to generate text embeddings | ||
"EmbeddingModel": "text-embedding-ada-002", | ||
// The max number of tokens supported by the embedding model | ||
// See https://platform.openai.com/docs/guides/embeddings/what-are-embeddings | ||
"EmbeddingModelMaxTokenTotal": 8191, | ||
// OpenAI API Key | ||
"APIKey": "", | ||
// OpenAI Organization ID (usually empty, unless you have multiple accounts on different orgs) | ||
"OrgId": "", | ||
// Endpoint to use. By default the system uses 'https://api.openai.com/v1'. | ||
// Change this to use proxies or services compatible with OpenAI HTTP protocol like LM Studio. | ||
"Endpoint": "", | ||
// How many times to retry in case of throttling | ||
"MaxRetries": 10 | ||
} | ||
} | ||
}, | ||
"Logging": { | ||
"LogLevel": { | ||
"Default": "Information" | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using Microsoft.Extensions.DependencyInjection; | ||
using Microsoft.Extensions.Logging; | ||
using Microsoft.KernelMemory.AI; | ||
using Microsoft.KernelMemory.AI.Onnx; | ||
|
||
#pragma warning disable IDE0130 // reduce number of "using" statements | ||
// ReSharper disable once CheckNamespace - reduce number of "using" statements | ||
namespace Microsoft.KernelMemory; | ||
|
||
/// <summary>
/// Kernel Memory builder extensions for registering ONNX based text generation.
/// </summary>
public static partial class KernelMemoryBuilderExtensions
{
    /// <summary>
    /// Registers ONNX text generation using the model stored in the given directory.
    /// </summary>
    /// <param name="builder">Kernel Memory builder</param>
    /// <param name="modelPath">Value assigned to <see cref="OnnxConfig.TextModelDir"/></param>
    /// <param name="maxTokenTotal">
    /// NOTE(review): accepted but not forwarded to the <see cref="OnnxConfig"/> created here,
    /// so it currently has no effect - confirm whether it should be applied.
    /// </param>
    /// <param name="textTokenizer">Optional tokenizer passed to the text generator</param>
    /// <returns>The same builder instance, to allow chaining</returns>
    public static IKernelMemoryBuilder WithOnnxTextGeneration(
        this IKernelMemoryBuilder builder,
        string modelPath,
        uint maxTokenTotal,
        ITextTokenizer? textTokenizer = null)
    {
        // Wrap the path in a config and reuse the config-based overload for the registration.
        var onnxSettings = new OnnxConfig { TextModelDir = modelPath };
        return builder.WithOnnxTextGeneration(onnxSettings, textTokenizer);
    }

    /// <summary>
    /// Registers ONNX text generation using the given configuration.
    /// </summary>
    /// <param name="builder">Kernel Memory builder</param>
    /// <param name="config">ONNX settings; validated while the service is being registered</param>
    /// <param name="textTokenizer">Optional tokenizer passed to the text generator</param>
    /// <returns>The same builder instance, to allow chaining</returns>
    public static IKernelMemoryBuilder WithOnnxTextGeneration(
        this IKernelMemoryBuilder builder,
        OnnxConfig config,
        ITextTokenizer? textTokenizer = null)
    {
        builder.Services.AddOnnxTextGeneration(config, textTokenizer);

        return builder;
    }
}
|
||
/// <summary>
/// .NET IServiceCollection dependency injection extensions for ONNX text generation.
/// </summary>
public static partial class DependencyInjection
{
    /// <summary>
    /// Validates the configuration and registers <see cref="OnnxTextGenerator"/> as the
    /// singleton <see cref="ITextGenerator"/> implementation.
    /// </summary>
    /// <param name="services">Service collection receiving the registration</param>
    /// <param name="config">ONNX settings; validated immediately, before the registration</param>
    /// <param name="textTokenizer">Optional tokenizer passed to the text generator</param>
    /// <returns>The same service collection, to allow chaining</returns>
    public static IServiceCollection AddOnnxTextGeneration(
        this IServiceCollection services,
        OnnxConfig config,
        ITextTokenizer? textTokenizer = null)
    {
        config.Validate();

        services.AddSingleton<ITextGenerator, OnnxTextGenerator>(provider =>
        {
            // GetService (not GetRequiredService): the logger factory may be null if none is registered.
            var loggers = provider.GetService<ILoggerFactory>();

            return new OnnxTextGenerator(
                config: config,
                textTokenizer: textTokenizer,
                loggerFactory: loggers);
        });

        return services;
    }
}
Oops, something went wrong.