Skip to content

Commit

Permalink
Build system updates; JAX compat
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Jun 23, 2024
1 parent bbc15cc commit 96ac22a
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 385 deletions.
37 changes: 26 additions & 11 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
cmake_minimum_required(VERSION 3.12...3.18)
project(exoplanet_core LANGUAGES CXX CUDA)
cmake_minimum_required(VERSION 3.15...3.27)
project(
${SKBUILD_PROJECT_NAME}
VERSION ${SKBUILD_PROJECT_VERSION}
LANGUAGES CXX)

find_package(Python COMPONENTS Interpreter Development REQUIRED)
set(PYBIND11_NEWPYTHON ON)
find_package(pybind11 CONFIG REQUIRED)

include_directories(${CMAKE_CURRENT_LIST_DIR}/src/exoplanet_core/lib/include)
include_directories(${CMAKE_CURRENT_LIST_DIR}/src/exoplanet_core/jax)
include_directories("src/exoplanet_core/lib/include")

include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
pybind11_add_module(
gpu_driver
${CMAKE_CURRENT_LIST_DIR}/src/exoplanet_core/jax/cuda_kernels.cc.cu
${CMAKE_CURRENT_LIST_DIR}/src/exoplanet_core/jax/gpu_driver.cpp)
install(TARGETS gpu_driver DESTINATION jax)
pybind11_add_module(driver "src/exoplanet_core/driver.cpp")
target_compile_features(driver PUBLIC cxx_std_14)
install(TARGETS driver LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})

pybind11_add_module(cpu_driver "src/exoplanet_core/jax/cpu_driver.cpp")
target_include_directories(cpu_driver PRIVATE "src/exoplanet_core/jax")
target_compile_features(driver PUBLIC cxx_std_14)
install(TARGETS driver LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})

if ("$ENV{EXOPLANET_CORE_CUDA}" STREQUAL "yes")
enable_language(CUDA)
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
pybind11_add_module(
gpu_driver
"src/exoplanet_core/jax/cuda_kernels.cc.cu"
"src/exoplanet_core/jax/gpu_driver.cpp")
target_compile_features(driver PUBLIC cxx_std_14)
install(TARGETS gpu_driver LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
endif()
59 changes: 50 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,17 +1,58 @@
[build-system]
requires = [
"setuptools>=42",
"wheel",
"setuptools_scm[toml]>=3.4",
"numpy>=1.13.0",
"pybind11>=2.6",
"cmake",
[project]
name = "exoplanet_core"
description = "The compiled backend for exoplanet"
authors = [{ name = "Dan Foreman-Mackey", email = "[email protected]" }]
readme = "README.md"
requires-python = ">=3.9"
license = { text = "MIT License" }
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
]
dynamic = ["version"]
dependencies = ["numpy"]

[project.optional-dependencies]
pymc3 = ["pymc3>=3.9", "numpy<1.22"]
pymc = ["pymc>=5.0.0"]
jax = ["jax", "jaxlib"]
test = ["pytest"]
comparison = ["batman-package", "starry", "numpy<1.22", "xarray<2023.10.0"]
benchmark = [
"pytest",
"pytest-benchmark",
"radvel",
"kepler.py",
"batman-package",
"starry",
"exoplanet==0.4.5",
]
build-backend = "setuptools.build_meta"

[project.urls]
"Homepage" = "https://docs.exoplanet.codes"
"Source" = "https://github.com/exoplanet-dev/exoplanet-core"
"Bug Tracker" = "https://github.com/exoplanet-dev/exoplanet-core/issues"

[build-system]
requires = ["scikit-build-core", "numpy", "pybind11"]
build-backend = "scikit_build_core.build"

[tool.scikit-build]
sdist.exclude = []
sdist.include = ["src/exoplanet_core/exoplanet_core_version.py"]
metadata.version.provider = "scikit_build_core.metadata.setuptools_scm"

[tool.setuptools_scm]
write_to = "src/exoplanet_core/exoplanet_core_version.py"

[tool.cibuildwheel]
skip = "pp* *-win32 *-musllinux_* *-manylinux_i686"

[tool.black]
line-length = 79

Expand Down
188 changes: 0 additions & 188 deletions setup.py

This file was deleted.

Loading

0 comments on commit 96ac22a

Please sign in to comment.