Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DML] Add support for AppendExecutionProvider_DML1 with C#. #22291

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,29 @@ public struct OrtApi
public IntPtr RunOptionsAddActiveLoraAdapter;
}

#if !__MOBILE__
[StructLayout(LayoutKind.Sequential)]
#if NETSTANDARD2_0
public class OrtDmlApi
#else
public struct OrtDmlApi
#endif
{
public IntPtr SessionOptionsAppendExecutionProvider_DML;
public IntPtr SessionOptionsAppendExecutionProvider_DML1;
public IntPtr CreateGPUAllocationFromD3DResource;
public IntPtr FreeGPUAllocation;
public IntPtr GetD3D12ResourceFromAllocation;
public IntPtr SessionOptionsAppendExecutionProvider_DML2;
}
#endif

internal static class NativeMethods
{
static OrtApi api_;
#if !__MOBILE__
static OrtDmlApi dmlApi_;
#endif

#if NETSTANDARD2_0
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
Expand Down Expand Up @@ -507,6 +527,14 @@ static NativeMethods()
OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions));
OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount));
OrtGetTensorMemoryInfo = (DOrtGetTensorMemoryInfo)Marshal.GetDelegateForFunctionPointer(api_.GetTensorMemoryInfo, typeof(DOrtGetTensorMemoryInfo));
OrtGetExecutionProviderApi = (DOrtGetExecutionProviderApi)Marshal.GetDelegateForFunctionPointer(api_.GetExecutionProviderApi, typeof(DOrtGetExecutionProviderApi));
#if !__MOBILE__
var utf8ProviderName = NativeOnnxValueHelper.StringToZeroTerminatedUtf8("DML");
NativeApiStatus.VerifySuccess(OrtGetExecutionProviderApi(utf8ProviderName, ORT_API_VERSION, out var ortDmlApiPtr));
dmlApi_ = (OrtDmlApi)Marshal.PtrToStructure(ortDmlApiPtr, typeof(OrtDmlApi));
OrtSessionOptionsAppendExecutionProvider_DML1 = (DOrtSessionOptionsAppendExecutionProvider_DML1)Marshal.GetDelegateForFunctionPointer(
dmlApi_.SessionOptionsAppendExecutionProvider_DML1, typeof(DOrtSessionOptionsAppendExecutionProvider_DML1));
#endif
// MapTypeInfo
OrtGetMapKeyType = (DGetMapKeyType)Marshal.GetDelegateForFunctionPointer(api_.GetMapKeyType, typeof(DGetMapKeyType));
OrtCastTypeInfoToMapTypeInfo = (DCastTypeInfoToMapTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToMapTypeInfo, typeof(DCastTypeInfoToMapTypeInfo));
Expand Down Expand Up @@ -2093,6 +2121,22 @@ out IntPtr lora_adapter

public static DOrtGetTensorMemoryInfo OrtGetTensorMemoryInfo;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtGetExecutionProviderApi(
byte[] /*(const char*)*/ provider_name,
uint /*(uint32_t)*/ version,
out IntPtr /* const OrtMemoryInfo** */ provider_api
);

public static DOrtGetExecutionProviderApi OrtGetExecutionProviderApi;

#if !__MOBILE__
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionOptionsAppendExecutionProvider_DML1(IntPtr /*(OrtSessionOptions*) */ options, IntPtr dml_device, IntPtr cmd_queue);

public static DOrtSessionOptionsAppendExecutionProvider_DML1 OrtSessionOptionsAppendExecutionProvider_DML1;
#endif

/// Map Type API
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DCastTypeInfoToMapTypeInfo(IntPtr /*(const struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*const OrtMapTypeInfo** */ mapTypeInfo);
Expand Down
13 changes: 13 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,19 @@ public void AppendExecutionProvider_DML(int deviceId = 0)
#endif
}

/// <summary>
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="dmlDevice">A IDMLDevice reference</param>
/// <param name="commandQueue">A ID3D12CommandQueue reference</param>
public void AppendExecutionProvider_DML1(IntPtr dmlDevice, IntPtr commandQueue)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, we do not expose native IntPtr handles via public API. The preferred way is to expose an object of DmlDevice that has a handle.
That object is likely to be IDisposable.
Then that object can expose an internal property Handle that can be used for NativeApi.
The same probably goes for commandQueue.

{
#if __MOBILE__
throw new NotSupportedException("The DML Execution Provider is not supported in this build");
#else
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_DML1(handle, dmlDevice, commandQueue));
#endif
}

/// <summary>
/// Use only if you have the onnxruntime package specific to this Execution Provider.
Expand Down
Loading