Skip to content

Commit

Permalink
Support for Homebrew on ARM64 (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu authored Aug 30, 2023
1 parent 4033dee commit b499593
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_build_wheels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jobs:
matrix:
include:
- torch: '2.0.1'
py: '3.10'
py: '3.11'
with:
torch: ${{ matrix.torch }}
py: ${{ matrix.py }}
94 changes: 53 additions & 41 deletions fairseq2n/python/src/fairseq2n/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,62 +16,38 @@
__version__ = "0.2.0+devel"


# We import `torch` to ensure that libtorch and libtorch_python are loaded into
# the process before our extension module.
import torch # noqa: F401

# Holds the shared libraries that we loaded using our own extended lookup logic
# Keeps the shared libraries that we load using our own extended lookup logic
# in memory.
_libs: List[CDLL] = []


def _load_shared_library(lib_name: str) -> Optional[CDLL]:
# If we are not in a Conda environment, try to load the library using the
# default lookup rules of the dynamic linker first. In Conda environments
# we always expect native libraries to be part of the environment.
if not "CONDA_PREFIX" in environ:
try:
# Use the global namespace to ensure that all modules use the same
# library instance.
return CDLL(lib_name, mode=RTLD_GLOBAL)
except OSError:
pass

if site.ENABLE_USER_SITE:
site_packages = [site.getusersitepackages()]
else:
site_packages = []

site_packages += site.getsitepackages()

# If the system does not have the library, try to load it from the site
# packages of the current Python environment.
for packages_dir in site_packages:
lib_path = Path(packages_dir).parent.parent.joinpath(lib_name)
def _load_shared_libraries() -> None:
# We import `torch` to ensure that libtorch and libtorch_python are loaded
# into the process before our extension module.
import torch

try:
return CDLL(str(lib_path), mode=RTLD_GLOBAL)
except OSError:
pass
# Intel oneTBB is only available on x86_64 systems.
if platform.machine() == "x86_64":
_load_tbb()

return None
_load_sndfile()


def _load_tbb() -> None:
if sys.platform == "darwin":
if platform.system() == "Darwin":
lib_name = "libtbb.12.dylib"
else:
lib_name = "libtbb.so.12"

libtbb = _load_shared_library(lib_name)
if libtbb is None:
raise OSError("Intel oneTBB is not found! Check your fairseq2n installation!")
raise OSError("Intel oneTBB is not found! Check your fairseq2 installation!")

_libs.append(libtbb)


def _load_sndfile() -> None:
if sys.platform == "darwin":
if platform.system() == "Darwin":
lib_name = "libsndfile.1.dylib"
else:
lib_name = "libsndfile.so.1"
Expand All @@ -90,12 +66,48 @@ def _load_sndfile() -> None:
_libs.append(libsndfile)


# We load the shared libraries we depend on using our own extended lookup logic
# since they might be located in a virtual environment.
if (uname := platform.uname()).machine == "x86_64":
_load_tbb()
def _load_shared_library(lib_name: str) -> Optional[CDLL]:
# In Conda environments, we always expect native libraries to be part of the
# environment, so we skip the default lookup rules of the dynamic linker.
if not "CONDA_PREFIX" in environ:
try:
# Use the global namespace to ensure that all modules use the same
# library instance.
return CDLL(lib_name, mode=RTLD_GLOBAL)
except OSError:
pass

# On macOS, we also explicitly check the standard Homebrew locations.
if platform.system() == "Darwin":
for brew_path in ["/usr/local/lib", "/opt/homebrew/lib"]:
try:
return CDLL(str(Path(brew_path, lib_name)), mode=RTLD_GLOBAL)
except OSError:
pass

if site.ENABLE_USER_SITE:
site_packages = [site.getusersitepackages()]
else:
site_packages = []

site_packages += site.getsitepackages()

# If the system does not have the library, try to load it from the site
# packages of the current Python environment.
for packages_dir in site_packages:
lib_path = Path(packages_dir).parent.parent.joinpath(lib_name)

try:
return CDLL(str(lib_path), mode=RTLD_GLOBAL)
except OSError:
pass

return None


_load_sndfile()
# We load shared libraries that we depend on using our own extended lookup logic
# since they might be located in non-default locations.
_load_shared_libraries()


def get_lib() -> Path:
Expand Down
4 changes: 4 additions & 0 deletions src/fairseq2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# We import fairseq2n to report any initialization error eagerly.
import fairseq2n

__version__ = "0.2.0+devel"


Expand Down

0 comments on commit b499593

Please sign in to comment.