From 457f4a204184f3089710201c3e4e82bc9bf48219 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Thu, 14 Nov 2024 18:00:51 -0800 Subject: [PATCH] Refactor CUDA library regex patterns for Windows environments to search for nvidia libraries --- onnxruntime/__init__.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index dc5b04a800fb9..416998d2ef005 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -98,16 +98,21 @@ if platform.system() == "Windows": # # Define the list of DLL patterns cuda_libs = ( - "cublas.dll", - "cudnn.dll", - "cudart.dll", - "nvrtc.dll", - "cufft.dll", - "curand.dll", - "nvJitLink.dll", + "cublas", + "cudnn", + "cudart", + "nvrtc", + "cufft", + "curand", + "nvJitLink", ) - # Convert patterns to regex with case-insensitivity - pattern_regex = {pattern: re.compile(rf"^{re.escape(pattern)}$", re.IGNORECASE) for pattern in cuda_libs} + # Construct a regex pattern for each library name with optional parts + # Pattern explanation: + # - `libname`: Match the base library name (e.g., "cudart") + # - `(?:64)?`: Optionally match "64" + # - `(?:_\d+)*`: Match zero or more occurrences of "_n" where "n" is one or more digits + # - `.dll$`: End with ".dll" ignoring case + lib_pattern = {lib: re.compile(rf"{lib}(?:64)?(?:_\d+)*\.dll$", re.IGNORECASE) for lib in cuda_libs} # Collect all directories under site-packages/nvidia that contain .dll files (for Windows) for root, _, files in os.walk(nvidia_path): # Add the current directory to the DLL search path @@ -115,8 +120,8 @@ with os.add_dll_directory(root): # Find all .dll files in the current directory for file in files: - for regex in pattern_regex.items().values(): - if regex.match(file): + for pattern in lib_pattern.items().values(): + if pattern.match(file): dll_path = os.path.join(root, file) _ = ctypes.CDLL(dll_path) elif platform.system() == "Linux": @@ -132,13 +137,13 @@ ) # Regular expression to match .so files with optional versioning (e.g., .so, .so.1, .so.2.3) - pattern_regex = {pattern: re.compile(rf"{re.escape(pattern)}(\.\d+)*$", re.IGNORECASE) for pattern in cuda_libs} + lib_pattern = {pattern: re.compile(rf"{re.escape(pattern)}(\.\d+)*$", re.IGNORECASE) for pattern in cuda_libs} # Traverse the directory and subdirectories for root, _, files in os.walk(nvidia_path): for file in files: # Check if the file matches the .so pattern - for regex in pattern_regex.items().values(): + for regex in lib_pattern.items().values(): if regex.match(file): # Check if the file matches the pattern so_path = os.path.join(root, file) _ = ctypes.CDLL(so_path)