Skip to content

Commit

Permalink
Enable wheels for MacOS (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
mavenlin authored Dec 4, 2023
1 parent d91309a commit 0ba84a5
Show file tree
Hide file tree
Showing 13 changed files with 206 additions and 42 deletions.
17 changes: 14 additions & 3 deletions .bazelrc
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
build --copt=-g0 --copt=-O3 --copt=-DNDEBUG
build --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-lm
build --action_env=BAZEL_LINKOPTS=-static-libgcc
build:linux --copt=-g0 --copt=-O3 --copt=-DNDEBUG
build:macos_x86_64 --copt=-g0 --copt=-O3 --copt=-DNDEBUG
build:macos_arm64 --copt=-g0 --copt=-O3 --copt=-DNDEBUG
build:windows_x86_64 -c opt --compiler=clang-cl

build:linux --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-lm
build:linux --action_env=BAZEL_LINKOPTS=-static-libgcc
build:linux --define os=linux --define cpu=x86_64

build:macos_x86_64 --define os=macos --define cpu=x86_64
build:macos_arm64 --define os=macos --define cpu=arm64

build:windows_x86_64 --define os=windows --define cpu=x86_64
build:windows_arm64 --define os=windows --define cpu=arm64
19 changes: 10 additions & 9 deletions .github/workflows/release.yml → .github/workflows/linux.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Action name
name: Release Wheel
name: Release for Linux

on:
push:
Expand All @@ -11,24 +11,25 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
container:
image: ghcr.io/sail-sg/jax-xc-image:latest
python-version: ['3.9', '3.10', '3.11', '3.12']
steps:
- name: Cancel previous run
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Setup Python-${{ matrix.python-version }} and Build
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Build
run: |
eval "$(pyenv init -)" && pyenv global ${{ matrix.python-version }}-dev
pip install -r requirements.txt
bazel build --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 @maple2jax//:jax_xc_wheel
bazel build --config=linux --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 @maple2jax//:jax_xc_wheel
- name: Upload artifact
uses: actions/upload-artifact@main
with:
name: wheel
name: linux_wheel
path: bazel-bin/external/maple2jax/*.whl
publish:
runs-on: ubuntu-latest
Expand All @@ -37,7 +38,7 @@ jobs:
- name: Download artifact
uses: actions/download-artifact@main
with:
name: wheel
name: linux_wheel
path: dist
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
Expand Down
45 changes: 45 additions & 0 deletions .github/workflows/macos.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: Release for Macos

on:
push:
tags:
- 'v[0-9]+\.[0-9]+\.[0-9]+'

jobs:
release:
runs-on: macos-11
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
steps:
- name: Cancel previous run
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Build
run: |
pip install -r requirements.txt
bazel build --config=macos_x86_64 --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 @maple2jax//:jax_xc_wheel
- name: Upload artifact
uses: actions/upload-artifact@main
with:
name: macos_wheel
path: bazel-bin/external/maple2jax/*.whl
publish:
runs-on: ubuntu-latest
needs: release
steps:
- name: Download artifact
uses: actions/download-artifact@main
with:
name: macos_wheel
path: dist
- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
11 changes: 6 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ on: [push, pull_request]
jobs:
test:
runs-on: ubuntu-latest
container:
image: ghcr.io/sail-sg/jax-xc-image:latest
steps:
- name: Cancel previous run
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.12
- name: Test
run: |
eval "$(pyenv init -)" && pyenv global 3.11-dev
pip install --upgrade -r requirements.txt
bazel test --test_output=all --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 //tests/...
pip install -r requirements.txt
bazel test --config=linux --test_output=all --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 //tests/...
31 changes: 31 additions & 0 deletions .github/workflows/windows.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Release for Windows

# Don't run yet, it fails at linking, very difficult to debug with only a remote runner.
# on: [push]

jobs:
release:
runs-on: windows-2019
strategy:
matrix:
python-version: ['3.9']
steps:
- name: Cancel previous run
uses: styfle/[email protected]
with:
access_token: ${{ github.token }}
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Build
run: |
pip install -r requirements.txt
bazel build --config=windows_x86_64 --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 @maple2jax//:jax_xc_wheel
shell: pwsh
- name: Upload artifact
uses: actions/upload-artifact@main
with:
name: macos_wheel
path: bazel-bin/external/maple2jax/*.whl
13 changes: 13 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@ workspace(name = "jax_xc")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
name = "bazel_skylib",
sha256 = "cd55a062e763b9349921f0f5db8c3933288dc8ba4f76dd9416aac68acee3cb94",
urls = [
"https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.5.0/bazel-skylib-1.5.0.tar.gz",
"https://github.com/bazelbuild/bazel-skylib/releases/download/1.5.0/bazel-skylib-1.5.0.tar.gz",
],
)

load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")

bazel_skylib_workspace()

http_archive(
name = "rules_python",
sha256 = "8c8fe44ef0a9afc256d1e75ad5f448bb59b81aba149b8958f02f7b3a98f5d9b4",
Expand Down
2 changes: 1 addition & 1 deletion maple2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
from .functionals import * # noqa
from . import experimental # noqa

__version__ = "0.0.9"
__version__ = "0.0.10"
50 changes: 35 additions & 15 deletions maple2jax/libxc/build.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ cc_library(
visibility = ["//visibility:public"],
)

cc_library(
name = "register",
hdrs = [
"src_cc/register.h",
],
deps = [
":xc_inc",
"@visit_struct",
"@pybind11",
],
)

{% for src in c_file_basenames %}
genrule(
name = "gen_{{ src }}c",
Expand All @@ -44,16 +56,14 @@ genrule(
tools = [":wrap", ":wrap.cc.jinja"],
)

{% endfor %}

cc_binary(
name = "libxc.so",
copts = ["-std=c++14", "-fexceptions"],
features = [
"-use_header_modules", # Required for pybind11.
"-parse_headers",
cc_library(
name = "{{ src }}c.obj",
srcs = ["src_cc/{{ src }}c"],
features = ["windows_export_all_symbols"],
deps = [
":xc_inc",
":register",
],
linkshared = 1,
includes = [
".",
"src",
Expand All @@ -64,19 +74,29 @@ cc_binary(
"XC_DONT_COMPILE_KXC",
"XC_DONT_COMPILE_LXC",
],
alwayslink = True,
)

{% endfor %}

pybind_extension(
name = "libxc",
copts = select({
"@platforms//os:windows": [],
"//conditions:default": ["-std=c++14"],
}),
features = ["windows_export_all_symbols"],
deps = [
":xc_inc",
"@visit_struct",
"@pybind11",
"@local_config_python//:python_headers",
":register",
{% for basename in c_file_basenames %}
":{{ basename }}c.obj",
{% endfor %}
],
srcs = [
"src_cc/register.h",
"src_cc/register.cc",
"src_cc/libxc.cc",
{% for basename in c_file_basenames %}
"src_cc/{{ basename }}c",
{% endfor %}
],
visibility = ["//visibility:public"],
)
Expand Down
4 changes: 2 additions & 2 deletions maple2jax/libxc/gen_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def main(_):
else:
c_file_basenames.append(basename)

with open(FLAGS.template, "r") as f:
with open(FLAGS.template, "r", encoding="utf8") as f:
template = Template(f.read(), trim_blocks=True, lstrip_blocks=True)
build = template.render(c_file_basenames=c_file_basenames)
with open(FLAGS.build, "w") as out:
with open(FLAGS.build, "w", encoding="utf8") as out:
out.write(build)


Expand Down
6 changes: 3 additions & 3 deletions maple2jax/libxc/wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


def wrap_file(filename, out):
with open(filename, "r") as f:
with open(filename, "r", encoding="utf8") as f:
content = f.read()
# find all init function and the corresponding param struct name
results = re.findall(
Expand Down Expand Up @@ -74,14 +74,14 @@ def wrap_file(filename, out):
fields.extend(members)
register_struct.append((s, fields, struct_to_init[s]))

with open(FLAGS.template, "r") as f:
with open(FLAGS.template, "r", encoding="utf8") as f:
t = Template(f.read(), trim_blocks=True, lstrip_blocks=True)
content = t.render(
filename=os.path.basename(filename),
register_struct=register_struct,
register_maple=register_maple,
)
with open(FLAGS.out, "wt") as fout:
with open(FLAGS.out, "wt", encoding="utf8") as fout:
fout.write(content)


Expand Down
6 changes: 5 additions & 1 deletion maple2jax/python.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def _get_abi_tag(rctx, python_bin):
"version = platform.python_version_tuple();" +
"print(f'cp{version[0]}{version[1]}{sys.abiflags}')",
])
return result.stdout.splitlines()[0]
lines = result.stdout.splitlines()
if len(lines) == 0:
return ""
else:
return lines[0]

def _declare_python_abi_impl(rctx):
python_bin = _get_python_bin(rctx)
Expand Down
43 changes: 41 additions & 2 deletions maple2jax/wheel.BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,39 @@
load("@bazel_skylib//lib:selects.bzl", "selects")
load("@python_abi//:abi.bzl", "abi_tag", "python_tag")
load("@rules_python//python:packaging.bzl", "py_wheel")

selects.config_setting_group(
name = "macos_arm64",
match_all = [
"@platforms//os:macos",
"@platforms//cpu:arm64",
],
)

selects.config_setting_group(
name = "macos_x86_64",
match_all = [
"@platforms//os:macos",
"@platforms//cpu:x86_64",
],
)

selects.config_setting_group(
name = "linux_x86_64",
match_all = [
"@platforms//os:linux",
"@platforms//cpu:x86_64",
],
)

selects.config_setting_group(
name = "windows_x86_64",
match_all = [
"@platforms//os:windows",
"@platforms//cpu:x86_64",
],
)

py_wheel(
name = "jax_xc_wheel",
abi = abi_tag(),
Expand All @@ -12,7 +45,13 @@ py_wheel(
],
description_file = "@jax_xc//:README.rst",
distribution = "jax_xc",
platform = "manylinux_2_17_x86_64",
platform = select({
":macos_arm64": "macosx_11_0_arm64",
":macos_x86_64": "macosx_11_0_x86_64",
":linux_x86_64": "manylinux_2_17_x86_64",
":windows_x86_64": "win_amd64",
"//conditions:default": "manylinux_2_17_x86_64",
}),
python_requires = ">=3.9",
python_tag = python_tag(),
requires = [
Expand All @@ -21,7 +60,7 @@ py_wheel(
"tensorflow-probability",
"autofd",
],
version = "0.0.9",
version = "0.0.10",
deps = [
"@maple2jax//jax_xc",
"@maple2jax//jax_xc:experimental",
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ tensorflow-probability
jinja2
absl-py
numpy
pyscf
regex
jaxtyping
autofd

0 comments on commit 0ba84a5

Please sign in to comment.