diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index be157a0419fc0..168c7137e57d9 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -327,9 +327,29 @@ public struct OrtApi public IntPtr RunOptionsAddActiveLoraAdapter; } +#if !__MOBILE__ + [StructLayout(LayoutKind.Sequential)] +#if NETSTANDARD2_0 + public class OrtDmlApi +#else + public struct OrtDmlApi +#endif + { + public IntPtr SessionOptionsAppendExecutionProvider_DML; + public IntPtr SessionOptionsAppendExecutionProvider_DML1; + public IntPtr CreateGPUAllocationFromD3DResource; + public IntPtr FreeGPUAllocation; + public IntPtr GetD3D12ResourceFromAllocation; + public IntPtr SessionOptionsAppendExecutionProvider_DML2; + } +#endif + internal static class NativeMethods { static OrtApi api_; +#if !__MOBILE__ + static OrtDmlApi dmlApi_; +#endif #if NETSTANDARD2_0 [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -507,6 +527,14 @@ static NativeMethods() OrtGetSymbolicDimensions = (DOrtGetSymbolicDimensions)Marshal.GetDelegateForFunctionPointer(api_.GetSymbolicDimensions, typeof(DOrtGetSymbolicDimensions)); OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount)); OrtGetTensorMemoryInfo = (DOrtGetTensorMemoryInfo)Marshal.GetDelegateForFunctionPointer(api_.GetTensorMemoryInfo, typeof(DOrtGetTensorMemoryInfo)); + OrtGetExecutionProviderApi = (DOrtGetExecutionProviderApi)Marshal.GetDelegateForFunctionPointer(api_.GetExecutionProviderApi, typeof(DOrtGetExecutionProviderApi)); +#if !__MOBILE__ + var utf8ProviderName = NativeOnnxValueHelper.StringToZeroTerminatedUtf8("DML"); + NativeApiStatus.VerifySuccess(OrtGetExecutionProviderApi(utf8ProviderName, ORT_API_VERSION, out var ortDmlApiPtr)); + dmlApi_ = (OrtDmlApi)Marshal.PtrToStructure(ortDmlApiPtr, typeof(OrtDmlApi)); + OrtSessionOptionsAppendExecutionProvider_DML1 = (DOrtSessionOptionsAppendExecutionProvider_DML1)Marshal.GetDelegateForFunctionPointer( + dmlApi_.SessionOptionsAppendExecutionProvider_DML1, typeof(DOrtSessionOptionsAppendExecutionProvider_DML1)); +#endif // MapTypeInfo OrtGetMapKeyType = (DGetMapKeyType)Marshal.GetDelegateForFunctionPointer(api_.GetMapKeyType, typeof(DGetMapKeyType)); OrtCastTypeInfoToMapTypeInfo = (DCastTypeInfoToMapTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.CastTypeInfoToMapTypeInfo, typeof(DCastTypeInfoToMapTypeInfo)); @@ -2093,6 +2121,22 @@ out IntPtr lora_adapter public static DOrtGetTensorMemoryInfo OrtGetTensorMemoryInfo; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetExecutionProviderApi( + byte[] /*(const char*)*/ provider_name, + uint /*(uint32_t)*/ version, + out IntPtr /* const OrtMemoryInfo** */ provider_api + ); + + public static DOrtGetExecutionProviderApi OrtGetExecutionProviderApi; + +#if !__MOBILE__ + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionOptionsAppendExecutionProvider_DML1(IntPtr /*(OrtSessionOptions*) */ options, IntPtr dml_device, IntPtr cmd_queue); + + public static DOrtSessionOptionsAppendExecutionProvider_DML1 OrtSessionOptionsAppendExecutionProvider_DML1; +#endif + /// Map Type API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DCastTypeInfoToMapTypeInfo(IntPtr /*(const struct OrtTypeInfo*)*/ typeInfo, out IntPtr /*const OrtMapTypeInfo** */ mapTypeInfo); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 3acd84b3016de..b916d812de651 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -275,6 +275,19 @@ public void AppendExecutionProvider_DML(int deviceId = 0) #endif } + /// + /// Use only if you have the onnxruntime package specific to this Execution Provider. + /// + /// A IDMLDevice reference + /// A ID3D12CommandQueue reference + public void AppendExecutionProvider_DML1(IntPtr dmlDevice, IntPtr commandQueue) + { +#if __MOBILE__ + throw new NotSupportedException("The DML Execution Provider is not supported in this build"); +#else + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_DML1(handle, dmlDevice, commandQueue)); +#endif + } /// /// Use only if you have the onnxruntime package specific to this Execution Provider.