JAX does not play well with `fork`, which is the default start method we use on Linux and on ARM-based macOS. Minimal reproducer:
```python
import pymc as pm

N_OBSERVATIONS = 50

with pm.Model() as model:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    y = pm.Normal("y", mu=mu, sigma=sigma, shape=N_OBSERVATIONS)
    prior_trace = pm.sample_prior_predictive(random_seed=100)

data = prior_trace.prior.y.isel(chain=0, draw=0)

with pm.observe(model, {y: data}):
    pm.sample(compile_kwargs=dict(mode="JAX"), mp_ctx="forkserver")  # fine
    pm.sample(compile_kwargs=dict(mode="JAX"))  # hangs forever
```
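One plausible mechanism for the hang, sketched with the standard library only: `fork()` copies the parent's memory, including the state of every lock, but only the calling thread keeps running in the child. JAX's runtime is multithreaded, so if one of its internal threads holds a lock at fork time, the child can block on that lock forever. The sketch below does not touch JAX; it uses a plain `threading.Lock` as a hypothetical stand-in for such an internal lock, and all helper names are illustrative.

```python
import multiprocessing as mp
import sys
import threading
import time

# Hypothetical stand-in for a lock inside a multithreaded runtime.
_runtime_lock = threading.Lock()
_lock_is_held = threading.Event()


def _background_worker(hold_seconds: float) -> None:
    # A helper thread that happens to own the lock when the parent forks.
    with _runtime_lock:
        _lock_is_held.set()
        time.sleep(hold_seconds)


def _child() -> None:
    # After fork() only this thread exists in the child, so the lock's owner
    # is gone and nobody will ever release it. The timeout stands in for the
    # infinite hang seen with pm.sample.
    acquired = _runtime_lock.acquire(timeout=0.5)
    sys.exit(0 if acquired else 1)


def child_can_acquire_lock(start_method: str) -> bool:
    """Start a child with `start_method` while another thread holds the lock."""
    _lock_is_held.clear()
    worker = threading.Thread(target=_background_worker, args=(2.0,), daemon=True)
    worker.start()
    _lock_is_held.wait()  # make sure the lock is held before we start the child
    proc = mp.get_context(start_method).Process(target=_child)
    proc.start()
    proc.join()
    worker.join()
    return proc.exitcode == 0
```

On a POSIX system, `child_can_acquire_lock("fork")` returns `False`: the child inherits a lock that is already held. With `"spawn"` or `"forkserver"` the child does not inherit the parent's threads or lock state, which is consistent with `mp_ctx="forkserver"` working in the reproducer.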
Wherever we default to `fork`, we should switch to `forkserver` or `spawn` instead (whichever the platform supports).
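A minimal sketch of the "whichever is supported" fallback, using only the standard library's `multiprocessing.get_all_start_methods()` and `multiprocessing.get_context()`. The function name `preferred_mp_ctx` is hypothetical, not an existing PyMC helper.

```python
import multiprocessing as mp
from multiprocessing.context import BaseContext


def preferred_mp_ctx() -> BaseContext:
    """Prefer "forkserver", falling back to "spawn" where it is unavailable."""
    # "forkserver" exists only on POSIX platforms; "spawn" is available everywhere.
    available = mp.get_all_start_methods()
    method = "forkserver" if "forkserver" in available else "spawn"
    return mp.get_context(method)
```

`forkserver` is tried first because new workers are forked from a single clean server process instead of each starting a fresh interpreter. Both methods import the main module in the worker processes, so scripts that sample with either need the usual `if __name__ == "__main__":` guard.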
Relevant code: `pymc/pymc/sampling/parallel.py`, lines 437 to 450 at commit 268e13b.
To find out which backend is being used, something like this could work:
```python
from pytensor.compile.mode import get_mode
from pytensor.link.jax import JAXLinker

...

# Somewhere inside/downstream of pm.sample
mode = compile_kwargs.get("mode", None)
using_jax = isinstance(get_mode(mode).linker, JAXLinker)
```
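A hypothetical sketch of how that check could be used to override the default only when the JAX backend is detected. `resolve_start_method` and its `default_method` parameter are illustrative assumptions, not PyMC code; `using_jax` would come from a `JAXLinker` check like the one above, and an explicit `mp_ctx` from the user is always respected.

```python
import multiprocessing as mp
from typing import Optional


def resolve_start_method(
    requested: Optional[str], default_method: str, using_jax: bool
) -> str:
    """Pick a multiprocessing start method name for the sampler processes.

    requested:      the user's explicit mp_ctx string, or None
    default_method: whatever the sampler would otherwise default to (hypothetical)
    using_jax:      True if the compiled functions use the JAX linker
    """
    if requested is not None:
        # An explicit user choice always wins, even if it is "fork".
        return requested
    if using_jax and default_method == "fork":
        # fork is unsafe with JAX's threads; use a supported alternative.
        if "forkserver" in mp.get_all_start_methods():
            return "forkserver"
        return "spawn"
    return default_method
```

The returned name can be passed to `multiprocessing.get_context()`. Keeping `fork` for non-JAX backends preserves its lower start-up cost there; switching unconditionally, as proposed in the description, is simpler and also covers other multithreaded backends.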