Do not use "fork" as default mp_ctx when compiling JAX functions in the PyMC sampler #7668

Open
ricardoV94 opened this issue Jan 31, 2025 · 0 comments


ricardoV94 commented Jan 31, 2025

Description

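JAX does not play well with fork, which is the default multiprocessing start method we use on Linux and on ARM-based macOS. JAX is inherently multithreaded, and forking a multithreaded process can leave the child in a broken state.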

import pymc as pm
N_OBSERVATIONS = 50

with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    y = pm.Normal("y", mu=mu, sigma=sigma, shape=N_OBSERVATIONS)
    prior_trace = pm.sample_prior_predictive(random_seed=100)

data = prior_trace.prior.y.isel(chain=0, draw=0)
with pm.observe(model, {y: data}):
    pm.sample(compile_kwargs=dict(mode="JAX"), mp_ctx="forkserver")  # fine
    pm.sample(compile_kwargs=dict(mode="JAX"))  # hangs forever

Wherever we default to fork, we should switch to forkserver or spawn instead (whichever the platform supports) when compiling to JAX.
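
A minimal sketch of how "whichever is supported" could be decided, using only the standard library. The helper name pick_non_fork_start_method is hypothetical and not part of PyMC:

```python
import multiprocessing


def pick_non_fork_start_method(preferred=("forkserver", "spawn")):
    """Return the first start method in `preferred` that this platform supports.

    Hypothetical helper for illustration; not an existing PyMC function.
    """
    available = multiprocessing.get_all_start_methods()
    for method in preferred:
        if method in available:
            return method
    raise RuntimeError(f"None of {preferred} is available; found {available}")


# Build a context object from the chosen start method.
method = pick_non_fork_start_method()
ctx = multiprocessing.get_context(method)
```

spawn is available on every platform Python supports, so the RuntimeError branch should only be reachable if the preferred tuple is changed.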

Relevant code:

if mp_ctx is None or isinstance(mp_ctx, str):
    # Closes issue https://github.com/pymc-devs/pymc/issues/3849
    # Related issue https://github.com/pymc-devs/pymc/issues/5339
    if mp_ctx is None and platform.system() == "Darwin":
        if platform.processor() == "arm":
            mp_ctx = "fork"
            logger.debug(
                "mp_ctx is set to 'fork' for MacOS with ARM architecture. "
                + "This might cause unexpected behavior with JAX, which is inherently multithreaded."
            )
        else:
            mp_ctx = "forkserver"
    mp_ctx = multiprocessing.get_context(mp_ctx)

To detect which backend is being used, something like this could work:

from pytensor.compile.mode import get_mode
from pytensor.link.jax import JAXLinker
...
  # Somewhere inside/downstream of pm.sample
  mode = compile_kwargs.get("mode", None)
  using_jax = isinstance(get_mode(mode).linker, JAXLinker)
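
Putting the two pieces together, the default could be resolved roughly as follows. This is only a sketch: resolve_mp_ctx is a hypothetical name, and using_jax is assumed to be computed as in the snippet above and passed in as a plain boolean so the example needs nothing beyond the standard library.

```python
import multiprocessing


def resolve_mp_ctx(mp_ctx=None, using_jax=False):
    """Turn a user-supplied mp_ctx (None, str, or context) into a context object.

    Hypothetical helper for illustration; not the current PyMC implementation.
    """
    if mp_ctx is None or isinstance(mp_ctx, str):
        if mp_ctx is None and using_jax:
            # Avoid fork when JAX is involved: prefer forkserver, else spawn.
            available = multiprocessing.get_all_start_methods()
            mp_ctx = "forkserver" if "forkserver" in available else "spawn"
        # get_context(None) returns the platform's default context.
        mp_ctx = multiprocessing.get_context(mp_ctx)
    return mp_ctx


jax_ctx = resolve_mp_ctx(using_jax=True)
explicit_ctx = resolve_mp_ctx("spawn", using_jax=True)
```

An explicit mp_ctx passed by the user is left untouched here, so someone who deliberately asks for fork still gets it; only the None default changes when JAX is detected.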