diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index be157a0419fc0..168c7137e57d9 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -327,9 +327,29 @@ public struct OrtApi
public IntPtr RunOptionsAddActiveLoraAdapter;
}
+#if !__MOBILE__
+ [StructLayout(LayoutKind.Sequential)]
+#if NETSTANDARD2_0
+ public class OrtDmlApi
+#else
+ public struct OrtDmlApi
+#endif
+ {
+ public IntPtr SessionOptionsAppendExecutionProvider_DML;
+ public IntPtr SessionOptionsAppendExecutionProvider_DML1;
+ public IntPtr CreateGPUAllocationFromD3DResource;
+ public IntPtr FreeGPUAllocation;
+ public IntPtr GetD3D12ResourceFromAllocation;
+ public IntPtr SessionOptionsAppendExecutionProvider_DML2;
+ }
+#endif
+
internal static class NativeMethods
{
static OrtApi api_;
+#if !__MOBILE__
+ static OrtDmlApi dmlApi_;
+#endif
#if NETSTANDARD2_0
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
@@ -507,6 +527,14 @@ static NativeMethods()
OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions));
OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount));
OrtGetTensorMemoryInfo = (DOrtGetTensorMemoryInfo)Marshal.GetDelegateForFunctionPointer(api_.GetTensorMemoryInfo, typeof(DOrtGetTensorMemoryInfo));
+ OrtGetExecutionProviderApi = (DOrtGetExecutionProviderApi)Marshal.GetDelegateForFunctionPointer(api_.GetExecutionProviderApi, typeof(DOrtGetExecutionProviderApi));
+#if !__MOBILE__
+ var utf8ProviderName = NativeOnnxValueHelper.StringToZeroTerminatedUtf8("DML");
+ NativeApiStatus.VerifySuccess(OrtGetExecutionProviderApi(utf8ProviderName, ORT_API_VERSION, out var ortDmlApiPtr));
+ dmlApi_ = (OrtDmlApi)Marshal.PtrToStructure(ortDmlApiPtr, typeof(OrtDmlApi));
+ OrtSessionOptionsAppendExecutionProvider_DML1 = (DOrtSessionOptionsAppendExecutionProvider_DML1)Marshal.GetDelegateForFunctionPointer(
+ dmlApi_.SessionOptionsAppendExecutionProvider_DML1, typeof(DOrtSessionOptionsAppendExecutionProvider_DML1));
+#endif
// MapTypeInfo
OrtGetMapKeyType = (DGetMapKeyType)Marshal.GetDelegateForFunctionPointer(api_.GetMapKeyType, typeof(DGetMapKeyType));
OrtCastTypeInfoToMapTypeInfo = (DCastTypeInfoToMapTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToMapTypeInfo, typeof(DCastTypeInfoToMapTypeInfo));
@@ -2093,6 +2121,22 @@ out IntPtr lora_adapter
public static DOrtGetTensorMemoryInfo OrtGetTensorMemoryInfo;
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtGetExecutionProviderApi(
+ byte[] /*(const char*)*/ provider_name,
+ uint /*(uint32_t)*/ version,
+ out IntPtr /* const OrtMemoryInfo** */ provider_api
+ );
+
+ public static DOrtGetExecutionProviderApi OrtGetExecutionProviderApi;
+
+#if !__MOBILE__
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionOptionsAppendExecutionProvider_DML1(IntPtr /*(OrtSessionOptions*) */ options, IntPtr dml_device, IntPtr cmd_queue);
+
+ public static DOrtSessionOptionsAppendExecutionProvider_DML1 OrtSessionOptionsAppendExecutionProvider_DML1;
+#endif
+
/// Map Type API
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DCastTypeInfoToMapTypeInfo(IntPtr /*(const struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*const OrtMapTypeInfo** */ mapTypeInfo);
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
index 3acd84b3016de..b916d812de651 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -275,6 +275,19 @@ public void AppendExecutionProvider_DML(int deviceId = 0)
#endif
}
+ ///
+ /// Use only if you have the onnxruntime package specific to this Execution Provider.
+ ///
+ /// A IDMLDevice reference
+ /// A ID3D12CommandQueue reference
+ public void AppendExecutionProvider_DML1(IntPtr dmlDevice, IntPtr commandQueue)
+ {
+#if __MOBILE__
+ throw new NotSupportedException("The DML Execution Provider is not supported in this build");
+#else
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_DML1(handle, dmlDevice, commandQueue));
+#endif
+ }
///
/// Use only if you have the onnxruntime package specific to this Execution Provider.