r"""flowMC sampler implementation and setup"""
from typing import Any
import time
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray
import equinox as eqx
from flowMC.resource.base import Resource
from flowMC.resource.buffers import Buffer
from flowMC.resource.states import State
from flowMC.resource.logPDF import LogPDF
from flowMC.resource.local_kernel.MALA import MALA
from flowMC.resource.local_kernel.Gaussian_random_walk import GaussianRandomWalk
from flowMC.resource.nf_model.NF_proposal import NFProposal
from flowMC.resource.nf_model.rqSpline import MaskedCouplingRQSpline
from flowMC.resource.optimizer import Optimizer
from flowMC.strategy.lambda_function import Lambda
from flowMC.strategy.take_steps import TakeSerialSteps, TakeGroupSteps
from flowMC.strategy.train_model import TrainModel
from flowMC.strategy.update_state import UpdateState
from flowMC.Sampler import Sampler
from .jester_sampler import JesterSampler, SamplerOutput
from ..config.schema import FlowMCSamplerConfig
from ..base import LikelihoodBase, Prior, BijectiveTransform, NtoMTransform
from jesterTOV.logging_config import get_logger
logger = get_logger("jester")
[docs]
class FlowMCSampler(JesterSampler):
"""
FlowMC-specific sampler implementation.

This class inherits from JesterSampler and adds flowMC-specific
initialization and configuration. It creates a flowMC Sampler with:

- Local sampler (MALA or GaussianRandomWalk)
- Normalizing flow model (MaskedCouplingRQSpline)
- Training and production sampling loops

Parameters
----------
likelihood : LikelihoodBase
    Likelihood object with evaluate(params, data) method
prior : Prior
    Prior object with sample() and log_prob() methods
config : FlowMCSamplerConfig
    Configuration object from YAML (contains n_chains, n_loop_training,
    learning_rate, etc.)
sample_transforms : list[BijectiveTransform], optional
    Bijective transforms applied during sampling (with Jacobians)
likelihood_transforms : list[NtoMTransform], optional
    N-to-M transforms applied before likelihood evaluation
seed : int, optional
    Random seed (default: 0)
local_sampler_name : str, optional
    Name of the local sampler: "MALA" or "GaussianRandomWalk"
    (default: "GaussianRandomWalk")
local_sampler_arg : dict[str, Any], optional
    Arguments for local sampler (e.g., {"step_size": ...}). Only the
    "step_size" entry is read. If it is missing, the step size defaults
    to 1e-1 for MALA and to 1e-3 per dimension otherwise; a 2-D
    step-size array is reduced to its diagonal.
num_layers : int, optional
    Number of coupling layers in normalizing flow (default: 10)
hidden_size : list[int], optional
    Hidden layer sizes for normalizing flow (default: [128, 128])
num_bins : int, optional
    Number of bins for rational quadratic splines (default: 8)

Attributes
----------
sampler : Sampler
    FlowMC sampler instance
config : FlowMCSamplerConfig
    Configuration object
"""
sampler: Sampler
[docs]
def __init__(
    self,
    likelihood: LikelihoodBase,
    prior: Prior,
    config: FlowMCSamplerConfig,
    sample_transforms: list[BijectiveTransform] | None = None,
    likelihood_transforms: list[NtoMTransform] | None = None,
    seed: int = 0,
    local_sampler_name: str = "GaussianRandomWalk",
    local_sampler_arg: dict[str, Any] | None = None,
    num_layers: int = 10,
    hidden_size: list[int] | None = None,
    num_bins: int = 8,
) -> None:
    """
    Assemble the flowMC ``Sampler`` used by this class.

    The constructor builds, in order: the log-pdf wrapper, the training
    and production buffers, the local sampler, the normalizing flow with
    its proposal and optimizer, the strategies (steppers, model trainer,
    state updates, progress logging) and finally the strategy order
    (training loops followed by production loops). The result is stored
    in ``self.sampler``; ``config`` is stored in ``self.config``.

    Raises
    ------
    ValueError
        If ``config.train_thinning`` or ``config.output_thinning`` is
        smaller than 1 or larger than the number of local/global steps
        per loop, or if ``local_sampler_name`` is neither "MALA" nor
        "GaussianRandomWalk".
    """
    # Handle None defaults (fresh containers, never shared between instances)
    if sample_transforms is None:
        sample_transforms = []
    if likelihood_transforms is None:
        likelihood_transforms = []
    if local_sampler_arg is None:
        local_sampler_arg = {}
    if hidden_size is None:
        hidden_size = [128, 128]

    # Initialize base class (sets up transforms and parameter names)
    super().__init__(likelihood, prior, sample_transforms, likelihood_transforms)

    self.config = config
    rng_key = jax.random.PRNGKey(seed)

    def logpdf_func(x: Float[Array, " n_dim"], data: dict) -> Float:
        """Log PDF function for flowMC 0.4.5 API."""
        # Convert array to dict with parameter names.
        # NOTE: Do NOT use float() on JAX traced values - causes ConcretizationTypeError
        params_dict = {name: x[i] for i, name in enumerate(self.parameter_names)}
        return self.posterior_from_dict(params_dict, data)

    # ------------------------------------------------------------------
    # Run dimensions taken from the config
    # ------------------------------------------------------------------
    n_dims = self.prior.n_dim
    n_chains = config.n_chains
    n_local_steps = config.n_local_steps
    n_global_steps = config.n_global_steps
    n_training_loops = config.n_loop_training
    n_production_loops = config.n_loop_production
    n_epochs = config.n_epochs
    learning_rate = config.learning_rate
    # The same thinning is applied to local and global steps in each phase.
    train_thinning = config.train_thinning
    output_thinning = config.output_thinning

    thinning_values = (
        ("train_thinning", train_thinning),
        ("output_thinning", output_thinning),
    )
    step_counts = (
        ("n_local_steps", n_local_steps),
        ("n_global_steps", n_global_steps),
    )

    # A thinning below 1 would otherwise surface as a ZeroDivisionError or a
    # negative buffer size in the buffer-size computation below.
    non_positive = [
        f"{label} ({value})" for label, value in thinning_values if value < 1
    ]
    if non_positive:
        raise ValueError(
            "Thinning values must be positive integers: " + ", ".join(non_positive)
        )

    # Validate thinning values to prevent zero-length buffers
    thinning_errors = [
        f"{label} ({value}) exceeds {steps_label} ({steps})"
        for label, value in thinning_values
        for steps_label, steps in step_counts
        if value > steps
    ]
    if thinning_errors:
        raise ValueError(
            "Thinning values exceed step counts, which would produce zero-length buffers:\n "
            + "\n ".join(thinning_errors)
            + "\nPlease reduce thinning values or increase step counts in your config."
        )

    # Buffer sizes: thinned local + thinned global steps, once per loop
    n_training_steps = (
        n_local_steps // train_thinning + n_global_steps // train_thinning
    ) * n_training_loops
    n_production_steps = (
        n_local_steps // output_thinning + n_global_steps // output_thinning
    ) * n_production_loops
    n_total_epochs = n_training_loops * n_epochs

    # ------------------------------------------------------------------
    # Buffers
    # ------------------------------------------------------------------
    training_shape = (n_chains, n_training_steps)
    production_shape = (n_chains, n_production_steps)
    positions_training = Buffer(
        "positions_training", (n_chains, n_training_steps, n_dims), 1
    )
    log_prob_training = Buffer("log_prob_training", training_shape, 1)
    local_accs_training = Buffer("local_accs_training", training_shape, 1)
    global_accs_training = Buffer("global_accs_training", training_shape, 1)
    loss_buffer = Buffer("loss_buffer", (n_total_epochs,), 0)
    positions_production = Buffer(
        "positions_production", (n_chains, n_production_steps, n_dims), 1
    )
    log_prob_production = Buffer("log_prob_production", production_shape, 1)
    local_accs_production = Buffer("local_accs_production", production_shape, 1)
    global_accs_production = Buffer("global_accs_production", production_shape, 1)

    # ------------------------------------------------------------------
    # Local sampler
    # ------------------------------------------------------------------
    step_size = local_sampler_arg.get("step_size")
    if step_size is None:
        # Provide default step_size
        if local_sampler_name == "MALA":
            step_size = 1e-1
        else:  # GaussianRandomWalk
            step_size = jnp.ones(n_dims) * 1e-3
    elif isinstance(step_size, jnp.ndarray) and step_size.ndim == 2:
        # Extract diagonal from DxD matrix for GaussianRandomWalk
        step_size = jnp.diag(step_size)

    local_sampler_classes = {
        "MALA": MALA,
        "GaussianRandomWalk": GaussianRandomWalk,
    }
    if local_sampler_name not in local_sampler_classes:
        raise ValueError(
            f"Unknown local_sampler_name: {local_sampler_name}. "
            f"Supported options: 'MALA', 'GaussianRandomWalk'"
        )
    local_sampler = local_sampler_classes[local_sampler_name](step_size=step_size)

    # ------------------------------------------------------------------
    # Normalizing flow, global proposal, optimizer, log-pdf resource
    # ------------------------------------------------------------------
    rng_key, subkey = jax.random.split(rng_key)
    model = MaskedCouplingRQSpline(
        n_dims, num_layers, hidden_size, num_bins, subkey
    )
    global_sampler = NFProposal(model, n_NFproposal_batch_size=10000)
    optimizer = Optimizer(model=model, learning_rate=learning_rate)
    logpdf_resource = LogPDF(logpdf_func, n_dims=n_dims)

    # The sampler state starts out pointing at the training buffers; the
    # "update_state" strategy below redirects it to the production buffers.
    sampler_state = State(
        {
            "target_positions": "positions_training",
            "target_log_prob": "log_prob_training",
            "target_local_accs": "local_accs_training",
            "target_global_accs": "global_accs_training",
            "training": True,
        },
        name="sampler_state",
    )

    resources = {
        "logpdf": logpdf_resource,
        "positions_training": positions_training,
        "log_prob_training": log_prob_training,
        "local_accs_training": local_accs_training,
        "global_accs_training": global_accs_training,
        "loss_buffer": loss_buffer,
        "positions_production": positions_production,
        "log_prob_production": log_prob_production,
        "local_accs_production": local_accs_production,
        "global_accs_production": global_accs_production,
        "local_sampler": local_sampler,
        "global_sampler": global_sampler,
        "model": model,
        "optimizer": optimizer,
        "sampler_state": sampler_state,
    }

    # ------------------------------------------------------------------
    # Strategies
    # ------------------------------------------------------------------
    local_stepper = TakeSerialSteps(
        "logpdf",
        "local_sampler",
        "sampler_state",
        ["target_positions", "target_log_prob", "target_local_accs"],
        n_local_steps,
        thinning=train_thinning,
        chain_batch_size=0,
        verbose=False,
    )
    global_stepper = TakeGroupSteps(
        "logpdf",
        "global_sampler",
        "sampler_state",
        ["target_positions", "target_log_prob", "target_global_accs"],
        n_global_steps,
        thinning=train_thinning,
        chain_batch_size=0,
        verbose=False,
    )
    model_trainer = TrainModel(
        "model",
        "positions_training",
        "optimizer",
        loss_buffer_name="loss_buffer",
        n_epochs=n_epochs,
        batch_size=10000,
        n_max_examples=10000,
        verbose=False,
    )
    update_state = UpdateState(
        "sampler_state",
        [
            "target_positions",
            "target_log_prob",
            "target_local_accs",
            "target_global_accs",
            "training",
        ],
        [
            "positions_production",
            "log_prob_production",
            "local_accs_production",
            "global_accs_production",
            False,
        ],
    )

    def update_production_thinning(
        rng_key: PRNGKeyArray,
        resources: dict[str, Resource],
        initial_position: Float[Array, "n_chains n_dim"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray,
        dict[str, Resource],
        Float[Array, "n_chains n_dim"],
    ]:
        """Update thinning for production phase."""
        local_stepper.thinning = output_thinning
        global_stepper.thinning = output_thinning
        return rng_key, resources, initial_position

    def reset_steppers(
        rng_key: PRNGKeyArray,
        resources: dict[str, Resource],
        initial_position: Float[Array, "n_chains n_dim"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray,
        dict[str, Resource],
        Float[Array, "n_chains n_dim"],
    ]:
        """Reset the write position of both steppers to 0."""
        local_stepper.set_current_position(0)
        global_stepper.set_current_position(0)
        return rng_key, resources, initial_position

    def update_model(
        rng_key: PRNGKeyArray,
        resources: dict[str, Resource],
        initial_position: Float[Array, "n_chains n_dim"],
        data: dict,
    ) -> tuple[
        PRNGKeyArray,
        dict[str, Resource],
        Float[Array, "n_chains n_dim"],
    ]:
        """Copy the current ``model`` resource into the global sampler."""
        model = resources["model"]
        resources["global_sampler"] = eqx.tree_at(
            lambda x: x.model,
            resources["global_sampler"],
            model,
        )
        return rng_key, resources, initial_position

    # These callbacks already have the (rng_key, resources, initial_position,
    # data) signature, so they are handed to Lambda directly.
    update_production_thinning_lambda = Lambda(update_production_thinning)
    reset_steppers_lambda = Lambda(reset_steppers)
    update_model_lambda = Lambda(update_model)

    # The local and global steppers write into the same position/log-prob
    # buffers, so each one continues from where the other stopped.
    update_global_step = Lambda(
        lambda rng_key, resources, initial_position, data: global_stepper.set_current_position(
            local_stepper.current_position
        )
    )
    update_local_step = Lambda(
        lambda rng_key, resources, initial_position, data: local_stepper.set_current_position(
            global_stepper.current_position
        )
    )

    # ------------------------------------------------------------------
    # Progress logging
    # ------------------------------------------------------------------
    def _fmt_duration(seconds: float) -> str:
        """Format a duration in seconds as a human-readable string."""
        seconds = max(0.0, seconds)
        if seconds < 60:
            return f"{seconds:.1f}s"
        if seconds < 3600:
            m, s = divmod(int(seconds), 60)
            return f"{m}m {s:02d}s"
        h, rem = divmod(int(seconds), 3600)
        m, s = divmod(rem, 60)
        return f"{h}h {m:02d}m {s:02d}s"

    def _make_progress_callbacks(tag: str, header: str, details: str, n_loops: int):
        """Return the (start, progress) logging callbacks for one phase."""
        tracker = {"start": 0.0, "done": 0}
        bar_length = 20

        def _start() -> None:
            """Record the phase start time and print the phase header."""
            # Reset both fields so that running the sampler again does not
            # reuse the start time or loop count of a previous run.
            tracker["start"] = time.monotonic()
            tracker["done"] = 0
            logger.info("=" * 70)
            logger.info(header)
            logger.info(details)
            logger.info("=" * 70)

        def _progress() -> None:
            """Print progress after each loop iteration of the phase."""
            tracker["done"] += 1
            loop = tracker["done"]
            elapsed = time.monotonic() - tracker["start"]
            avg_per_loop = elapsed / loop
            remaining = (n_loops - loop) * avg_per_loop
            filled = int(loop / n_loops * bar_length)
            bar = "█" * filled + "░" * (bar_length - filled)
            logger.info(
                f"[{tag}] Loop {loop:3d}/{n_loops} [{bar}] | "
                f"Elapsed: {_fmt_duration(elapsed):>10} | ETA: {_fmt_duration(remaining):>10} | "
                f"{_fmt_duration(avg_per_loop)}/loop"
            )

        return _start, _progress

    _cb_start_training, _cb_progress_training = _make_progress_callbacks(
        "Training",
        f"FLOWMC TRAINING PHASE ({n_training_loops} loops)",
        f" Chains: {n_chains} | Local steps/loop: {n_local_steps} | "
        f"Global steps/loop: {n_global_steps} | Epochs/loop: {n_epochs}",
        n_training_loops,
    )
    _cb_start_production, _cb_progress_production = _make_progress_callbacks(
        "Production",
        f"FLOWMC PRODUCTION PHASE ({n_production_loops} loops)",
        f" Chains: {n_chains} | Local steps/loop: {n_local_steps} | "
        f"Global steps/loop: {n_global_steps}",
        n_production_loops,
    )

    start_training_lambda = Lambda(lambda _rk, _r, _ip, _d: _cb_start_training())
    progress_training_lambda = Lambda(
        lambda _rk, _r, _ip, _d: _cb_progress_training()
    )
    start_production_lambda = Lambda(
        lambda _rk, _r, _ip, _d: _cb_start_production()
    )
    progress_production_lambda = Lambda(
        lambda _rk, _r, _ip, _d: _cb_progress_production()
    )

    strategies = {
        "local_stepper": local_stepper,
        "global_stepper": global_stepper,
        "model_trainer": model_trainer,
        "update_state": update_state,
        "update_global_step": update_global_step,
        "update_local_step": update_local_step,
        "reset_steppers": reset_steppers_lambda,
        "update_model": update_model_lambda,
        "update_production_thinning": update_production_thinning_lambda,
        "start_training": start_training_lambda,
        "progress_training": progress_training_lambda,
        "start_production": start_production_lambda,
        "progress_production": progress_production_lambda,
    }

    # ------------------------------------------------------------------
    # Strategy order: training loops, switch to production, production loops
    # ------------------------------------------------------------------
    training_phase = [
        "local_stepper",
        "update_global_step",
        "model_trainer",
        "update_model",
        "global_stepper",
        "update_local_step",
    ]
    production_phase = [
        "local_stepper",
        "update_global_step",
        "global_stepper",
        "update_local_step",
    ]
    strategy_order: list[str] = ["start_training"]
    for _ in range(n_training_loops):
        strategy_order.extend(training_phase)
        strategy_order.append("progress_training")
    strategy_order.extend(
        [
            "reset_steppers",
            "update_state",
            "update_production_thinning",
            "start_production",
        ]
    )
    for _ in range(n_production_loops):
        strategy_order.extend(production_phase)
        strategy_order.append("progress_production")

    # Create flowMC sampler
    self.sampler = Sampler(
        n_dim=n_dims,
        n_chains=n_chains,
        rng_key=rng_key,
        resources=resources,
        strategies=strategies,
        strategy_order=strategy_order,
    )
[docs]
def sample(self, key):
    """
    Run flowMC sampling.

    Initial chain positions are drawn from the prior and pushed through
    the sample transforms. Draws are repeated until every chain has a
    finite starting position; the flowMC sampler is then run on those
    positions with an empty data dict.

    Parameters
    ----------
    key : PRNGKeyArray
        JAX random key

    Notes
    -----
    Parameter ordering is preserved when converting the sampled dictionary
    to an array: the columns are built by iterating over
    ``self.parameter_names`` instead of using ``jax.tree.leaves()``.
    """
    n_chains = self.sampler.n_chains

    # Every chain starts as "not yet initialised".
    # Use jnp.inf instead of jnp.nan for initialization.
    initial_position = jnp.full((n_chains, self.prior.n_dim), jnp.inf)

    while True:
        # Chains that still have at least one non-finite coordinate
        pending = jnp.where(~jnp.all(jnp.isfinite(initial_position), axis=1))[0]
        if pending.size == 0:
            break

        key, subkey = jax.random.split(key)
        guess = self.prior.sample(subkey, n_chains)
        for transform in self.sample_transforms:
            guess = jax.vmap(transform.forward)(guess)

        # CRITICAL: Preserve parameter order when converting dict to array.
        # Do NOT use jax.tree.leaves() as it doesn't preserve dictionary order.
        guess = jnp.array(
            [guess[param_name] for param_name in self.parameter_names]
        ).T

        # Rows of this draw that are usable as a starting position
        finite_guess = jnp.where(jnp.all(jnp.isfinite(guess), axis=1))[0]

        # Fill as many pending chains as there are finite draws. The rows
        # are selected through ``finite_guess`` so that only finite draws
        # are written; taking the first rows of ``guess`` instead could
        # copy non-finite rows and throw away finite ones.
        common_length = min(len(finite_guess), len(pending))
        initial_position = initial_position.at[pending[:common_length]].set(
            guess[finite_guess[:common_length]]
        )

    self.sampler.sample(initial_position, {})  # Empty data dict
[docs]
def get_samples(self) -> dict:
    """
    Return the production samples as a dictionary.

    The production position buffer is flattened over chains and steps,
    each row is labelled through ``self.add_name``, and the sample
    transforms are undone in reverse order.

    Returns
    -------
    dict
        Dictionary of production samples
    """
    from flowMC.resource.buffers import Buffer

    production_buffer = self.sampler.resources["positions_production"]
    assert isinstance(production_buffer, Buffer)

    # One row per stored sample, one column per sampled dimension
    flat_positions = production_buffer.data.reshape(-1, self.prior.n_dim)
    named_samples = jax.vmap(self.add_name)(flat_positions)

    # Undo the sample transforms, last applied first
    for transform in self.sample_transforms[::-1]:
        named_samples = jax.vmap(transform.backward)(named_samples)
    return named_samples
[docs]
def get_log_prob(self) -> Array:
    """
    Return the log probabilities stored for the production samples.

    Returns
    -------
    Array
        Log posterior probability values (1D array, flattened across chains)
    """
    from flowMC.resource.buffers import Buffer

    production_log_prob = self.sampler.resources["log_prob_production"]
    assert isinstance(production_log_prob, Buffer)
    stored_values = production_log_prob.data
    return stored_values.reshape(-1)
[docs]
def get_n_samples(self) -> int:
    """
    Count the production samples held by the flowMC sampler.

    Returns
    -------
    int
        Number of production samples (total across all chains), i.e. the
        length of the array returned by ``get_log_prob()``
    """
    return len(self.get_log_prob())
[docs]
def get_n_training_samples(self) -> int:
    """
    Count the training samples held by the flowMC sampler.

    This is a FlowMC-specific method for diagnostic purposes.

    Returns
    -------
    int
        Number of training samples (total across all chains), i.e. the
        number of entries in the training log-probability buffer
    """
    from flowMC.resource.buffers import Buffer

    training_log_prob = self.sampler.resources["log_prob_training"]
    assert isinstance(training_log_prob, Buffer)
    flattened = training_log_prob.data.reshape(-1)
    return len(flattened)
[docs]
def get_sampler_output(self) -> SamplerOutput:
    """
    Bundle the production results into a ``SamplerOutput``.

    Returns
    -------
    SamplerOutput
        - samples: Parameter samples (dict of arrays), from ``get_samples()``
        - log_prob: Log posterior probability, from ``get_log_prob()``
        - metadata: {} (empty, MCMC has equal weights)
    """
    # Keyword arguments are evaluated left to right, so the samples are
    # fetched before the log probabilities.
    return SamplerOutput(
        samples=self.get_samples(),
        log_prob=self.get_log_prob(),
        metadata={},
    )
[docs]
def get_training_sampler_output(self) -> SamplerOutput:
    """
    Bundle the training-phase results into a ``SamplerOutput``.

    This is a FlowMC-specific method for diagnostic purposes.

    Returns
    -------
    SamplerOutput
        - samples: Parameter samples from training phase (dict of arrays)
        - log_prob: Log posterior probability from training phase
        - metadata: {} (empty, MCMC has equal weights)
    """
    from flowMC.resource.buffers import Buffer

    resources = self.sampler.resources

    # Training positions: flatten, label, then undo the sample transforms
    training_positions = resources["positions_training"]
    assert isinstance(training_positions, Buffer)
    flat_positions = training_positions.data.reshape(-1, self.prior.n_dim)
    samples = jax.vmap(self.add_name)(flat_positions)
    for transform in self.sample_transforms[::-1]:
        samples = jax.vmap(transform.backward)(samples)

    # Training log probabilities, flattened across chains
    training_log_prob = resources["log_prob_training"]
    assert isinstance(training_log_prob, Buffer)

    return SamplerOutput(
        samples=samples,
        log_prob=training_log_prob.data.reshape(-1),
        metadata={},
    )
[docs]
def print_summary(self, transform: bool = True):
    """
    Log a summary of the flowMC run.

    For the training and the production phase this logs the mean and
    standard deviation of every parameter, of the log probability and of
    the local/global acceptance. The training section additionally
    reports the maximum and minimum of the loss buffer.

    Parameters
    ----------
    transform : bool, optional
        Whether to apply inverse sample transforms to results (default: True)
    """
    from flowMC.resource.buffers import Buffer

    resources = self.sampler.resources
    n_dim = self.prior.n_dim

    def _buffer_data(name):
        """Return the data of the named resource, which must be a Buffer."""
        entry = resources[name]
        assert isinstance(entry, Buffer)
        return entry.data

    def _only_filled(values):
        """Drop the -inf entries of an acceptance array."""
        # flowMC 0.4.5: Buffer is initialized with -inf; local and global
        # acceptance buffers are only half-filled (local/global steppers
        # share current_position but write to separate buffers).
        return values[values > -jnp.inf]

    def _collect(positions_name, log_prob_name, local_name, global_name):
        """Gather chain dict, log-prob and filtered acceptances of a phase."""
        positions = _buffer_data(positions_name)
        log_prob = _buffer_data(log_prob_name)
        local_accs = _buffer_data(local_name)
        global_accs = _buffer_data(global_name)

        # Transposed so that add_name sees one row per parameter
        chain = self.add_name(positions.reshape(-1, n_dim).T)
        if transform:
            for sample_transform in reversed(self.sample_transforms):
                chain = jax.vmap(sample_transform.backward)(chain)
        return (
            chain,
            log_prob.flatten(),
            _only_filled(local_accs.flatten()),
            _only_filled(global_accs.flatten()),
        )

    def _log_mean_std(label, values):
        """Log ``label: mean +/- std`` with three decimals."""
        logger.info(f"{label}: {values.mean():.3f} +/- {values.std():.3f}")

    def _report(title, phase):
        """Log the statistics block of one phase."""
        chain, log_prob, local_acceptance, global_acceptance = phase
        logger.info(title)
        logger.info("=" * 10)
        for name, values in chain.items():
            _log_mean_std(name, values)
        _log_mean_std("Log probability", log_prob)
        _log_mean_std("Local acceptance", local_acceptance)
        _log_mean_std("Global acceptance", global_acceptance)

    # All buffers are read and checked before anything is logged
    training = _collect(
        "positions_training",
        "log_prob_training",
        "local_accs_training",
        "global_accs_training",
    )
    loss_vals = _buffer_data("loss_buffer")
    production = _collect(
        "positions_production",
        "log_prob_production",
        "local_accs_production",
        "global_accs_production",
    )

    _report("Training summary", training)
    logger.info(f"Max loss: {loss_vals.max():.3f}, Min loss: {loss_vals.min():.3f}")
    _report("Production summary", production)