Source code for jesterTOV.inference.likelihoods.combined

r"""Combined and utility likelihood classes"""

import jax.numpy as jnp
from jaxtyping import Array, Float

from jesterTOV.inference.base import LikelihoodBase


class CombinedLikelihood(LikelihoodBase):
    """
    Sum several likelihoods into one total log-likelihood

    Parameters
    ----------
    likelihoods_list : list[LikelihoodBase]
        Likelihood objects whose log-likelihoods are added together

    Attributes
    ----------
    likelihoods_list : list[LikelihoodBase]
        The list passed at construction, stored by reference (not copied)
    counter : int
        Evaluation counter (for debugging/monitoring); set to 0 in
        ``__init__`` and not modified anywhere in this class
    """

    likelihoods_list: list[LikelihoodBase]
    counter: int

    def __init__(self, likelihoods_list: list[LikelihoodBase]) -> None:
        super().__init__()
        # Keep a reference to the caller's list rather than a copy.
        self.likelihoods_list = likelihoods_list
        self.counter = 0

    def evaluate(self, params: dict[str, Float | Array]) -> Float:
        """
        Compute the total log-likelihood

        Every stored likelihood is evaluated, in list order, on the same
        ``params`` dictionary and the resulting values are summed.

        Parameters
        ----------
        params : dict[str, Float | Array]
            Parameter dictionary forwarded unchanged to each likelihood

        Returns
        -------
        Float
            Sum of the individual log-likelihoods (0.0 when
            ``likelihoods_list`` is empty)
        """
        # TODO: this is a Python-level loop over the likelihoods; a vmap- or
        # pytree-based formulation might perform better.
        log_likelihood_terms = tuple(
            likelihood_term.evaluate(params)
            for likelihood_term in self.likelihoods_list
        )
        stacked: Float[Array, " n_likelihoods"] = jnp.array(log_likelihood_terms)
        return stacked.sum()
class ZeroLikelihood(LikelihoodBase):
    """
    Trivial likelihood whose log-likelihood is identically zero

    Intended as a placeholder for testing/debugging.

    Attributes
    ----------
    counter : int
        Evaluation counter (for debugging/monitoring); set to 0 in
        ``__init__`` and not modified anywhere in this class
    """

    counter: int

    def __init__(self) -> None:
        super().__init__()
        self.counter = 0

    def evaluate(self, params: dict[str, Float | Array]) -> Float:
        """
        Return a constant zero log-likelihood

        Parameters
        ----------
        params : dict[str, Float | Array]
            Parameter dictionary; accepted for interface compatibility but
            never read

        Returns
        -------
        Float
            The Python float ``0.0``, independent of ``params``
        """
        del params  # intentionally unused: the result does not depend on it
        return 0.0