Skip to content

Commit

Permalink
Refactor CUDA library regex patterns for Windows environments to sear…
Browse files Browse the repository at this point in the history
…ch for nvidia libraries
  • Loading branch information
jchen351 committed Nov 15, 2024
1 parent 053d0a3 commit 457f4a2
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,25 +98,30 @@
if platform.system() == "Windows": #
# Define the list of DLL patterns
cuda_libs = (
"cublas.dll",
"cudnn.dll",
"cudart.dll",
"nvrtc.dll",
"cufft.dll",
"curand.dll",
"nvJitLink.dll",
"cublas",
"cudnn",
"cudart",
"nvrtc",
"cufft",
"curand",
"nvJitLink",
)
# Convert patterns to regex with case-insensitivity
pattern_regex = {pattern: re.compile(rf"^{re.escape(pattern)}$", re.IGNORECASE) for pattern in cuda_libs}
# Construct a regex pattern for each library name with optional parts
# Pattern explanation:
# - `libname`: Match the base library name (e.g., "cudart")
# - `(?:64)?`: Optionally match "64"
# - `(?:_\d+)*`: Match zero or more occurrences of "_n" where "n" is one or more digits
# - `.dll$`: End with ".dll" ignoring case
lib_pattern = {lib: re.compile(rf"{lib}(?:64)?(?:_\d+)*\.dll$", re.IGNORECASE) for lib in cuda_libs}
# Collect all directories under site-packages/nvidia that contain .dll files (for Windows)
for root, _, files in os.walk(nvidia_path):
# Add the current directory to the DLL search path

with os.add_dll_directory(root):
# Find all .dll files in the current directory
for file in files:
for regex in pattern_regex.items().values():
if regex.match(file):
for pattern in lib_pattern.items().values():
if pattern.match(file):
dll_path = os.path.join(root, file)
_ = ctypes.CDLL(dll_path)
elif platform.system() == "Linux":
Expand All @@ -132,13 +137,13 @@
)

# Regular expression to match .so files with optional versioning (e.g., .so, .so.1, .so.2.3)
pattern_regex = {pattern: re.compile(rf"{re.escape(pattern)}(\.\d+)*$", re.IGNORECASE) for pattern in cuda_libs}
lib_pattern = {pattern: re.compile(rf"{re.escape(pattern)}(\.\d+)*$", re.IGNORECASE) for pattern in cuda_libs}

# Traverse the directory and subdirectories
for root, _, files in os.walk(nvidia_path):
for file in files:
# Check if the file matches the .so pattern
for regex in pattern_regex.items().values():
for regex in lib_pattern.items().values():
if regex.match(file): # Check if the file matches the pattern
so_path = os.path.join(root, file)
_ = ctypes.CDLL(so_path)
Expand Down

0 comments on commit 457f4a2

Please sign in to comment.