#!/usr/bin/env python
r"""
Modular inference script for jesterTOV
"""
import os
import time
import warnings
import json
from pathlib import Path
import numpy as np
import jax
import jax.numpy as jnp
# Enable 64-bit precision by default
jax.config.update("jax_enable_x64", True)
from .config.parser import load_config
from .config.schema import (
InferenceConfig,
MetamodelCSEEOSConfig,
BaseMetamodelEOSConfig,
GWLikelihoodConfig,
GWEventConfig,
)
from .priors.parser import parse_prior_file
from .base.prior import CombinePrior
from .base.likelihood import LikelihoodBase
from .transforms import JesterTransform
from .likelihoods.factory import create_combined_likelihood
from .samplers import create_sampler, JesterSampler
from .result import InferenceResult
from jesterTOV.logging_config import get_logger
# Set up logger
logger = get_logger("jester")
# [docs] — Sphinx "view source" link artifact from documentation extraction; not code
def determine_keep_names(
    config: InferenceConfig,
    prior: CombinePrior,
    fixed_params: dict[str, float] | None = None,
) -> list[str] | None:
    """
    Determine which parameters must survive into the transform output.

    Inspects the enabled likelihoods and reports the prior parameters that
    the likelihood evaluation still needs after the EOS transform runs.

    Parameters
    ----------
    config : InferenceConfig
        Configuration object with likelihood settings.
    prior : CombinePrior
        Prior object whose ``parameter_names`` lists the sampled parameters.
    fixed_params : dict[str, float] | None
        Parameters pinned to constant values via ``Fixed(...)`` in the prior
        file; not part of ``prior.parameter_names`` but still acceptable.

    Returns
    -------
    list[str] | None
        Names to keep in the transform output, or None when nothing special
        is required.

    Raises
    ------
    ValueError
        If a required parameter is absent from both the prior and
        ``fixed_params``.
    """
    pinned = {} if fixed_params is None else fixed_params
    preserved: list[str] = []

    # The ChiEFT likelihood combined with a metamodel_cse EOS needs 'nbreak'
    # to stitch the CSE grid onto the metamodel segment.
    uses_chieft = any(
        lk.enabled and lk.type == "chieft" for lk in config.likelihoods
    )
    if uses_chieft and isinstance(config.eos, MetamodelCSEEOSConfig):
        is_sampled = "nbreak" in prior.parameter_names
        if not is_sampled and "nbreak" not in pinned:
            raise ValueError(
                "ChiEFT likelihood is enabled with metamodel_cse but 'nbreak' parameter is not in the prior. "
                "Please add 'nbreak' to your prior specification file. "
                f"Current prior parameters: {prior.parameter_names}"
            )
        if is_sampled:
            # Only sampled parameters need explicit preservation; fixed ones
            # are injected into the transform output automatically.
            preserved.append("nbreak")
            logger.info(
                "ChiEFT likelihood enabled: 'nbreak' parameter will be preserved in transform output"
            )
        else:
            logger.info(
                "ChiEFT likelihood enabled: 'nbreak' is fixed, will appear in transform output via fixed_params"
            )
    return preserved or None
# [docs] — Sphinx "view source" link artifact from documentation extraction; not code
def setup_prior(config: InferenceConfig) -> tuple[CombinePrior, dict[str, float]]:
    """
    Build the sampling prior from the configuration.

    Parses the prior specification file (expanding CSE grid parameters when
    the EOS is metamodel_cse) and, when any enabled likelihood resamples
    internally, appends a ``_random_key`` parameter to the prior.

    Parameters
    ----------
    config : InferenceConfig
        Configuration object.

    Returns
    -------
    prior : CombinePrior
        Combined prior over sampled parameters only.
    fixed_params : dict[str, float]
        Parameters pinned to constant values via ``Fixed(...)`` in the prior
        file. These are excluded from the sampling space.
    """
    from .base.prior import UniformPrior, CombinePrior

    # metamodel_cse introduces nb_CSE extra grid parameters; other EOS types add none.
    n_cse = config.eos.nb_CSE if isinstance(config.eos, MetamodelCSEEOSConfig) else 0

    # 'gw_resampled' and 'nicer_kde' draw samples inside the likelihood and
    # therefore need a per-evaluation '_random_key'; the default 'gw' and
    # 'nicer' likelihoods pre-generate masses and do not.
    wants_random_key = any(
        lk.enabled and lk.type in ["gw_resampled", "nicer_kde"]
        for lk in config.likelihoods
    )

    parsed = parse_prior_file(
        config.prior.specification_file,
        nb_CSE=n_cse,
    )
    prior = parsed.prior
    fixed_params = parsed.fixed_params
    if fixed_params:
        logger.info(f"Fixed parameters found in prior file: {fixed_params}")

    if wants_random_key:
        logger.info("Adding _random_key prior for likelihood sampling")
        key_prior = UniformPrior(
            float(0), float(2**32 - 1), parameter_names=["_random_key"]
        )
        # Rebuild flat rather than nesting one CombinePrior inside another.
        prior = CombinePrior(prior.base_prior + [key_prior])
    return prior, fixed_params
def prepare_gw_flows(config: InferenceConfig, outdir: Path) -> InferenceConfig:
    """Prepare normalizing flows for GW events that need flow training.

    For each enabled :class:`~jesterTOV.inference.config.schema.GWLikelihoodConfig`
    with events in bilby or NPZ mode, this function:
    1. Resolves the ``nf_model_dir`` to ``{outdir}/gw_flow_cache/{event.name}``.
    2. Obtains posterior samples — either by extracting from a bilby HDF5
       (``from_bilby_result``) or by using an existing NPZ directly
       (``from_npz_file``).
    3. Trains a normalizing flow (with config-hash-based cache invalidation).
    4. Replaces the training-mode event with a pre-trained-flow-mode event.

    This is a no-op if no ``GWLikelihoodConfig`` events use ``from_bilby_result``
    or ``from_npz_file``.

    Parameters
    ----------
    config : InferenceConfig
        Current inference configuration.
    outdir : Path
        Run output directory (flows are cached under ``{outdir}/gw_flow_cache/``).

    Returns
    -------
    InferenceConfig
        Updated config where all GW events point to a pre-trained flow directory.
    """
    import hashlib
    import shutil
    from .flows.bilby_extract import extract_gw_posterior_from_bilby
    from .flows.train_flow import train_flow_from_config
    from .flows.config import FlowTrainingConfig

    logger.info("Checking to prepare GW normalizing flows...")
    # Work on copies so the original (immutable pydantic) config is untouched.
    updated_likelihoods = list(config.likelihoods)
    any_changes = False
    for lk_idx, lk_config in enumerate(config.likelihoods):
        if not isinstance(lk_config, GWLikelihoodConfig) or not lk_config.enabled:
            continue
        # Events that still need a flow trained (bilby HDF5 or raw NPZ mode).
        training_events = [
            e
            for e in lk_config.events
            if e.from_bilby_result is not None or e.from_npz_file is not None
        ]
        if not training_events:
            continue
        updated_events = list(lk_config.events)
        for ev_idx, event in enumerate(lk_config.events):
            if event.from_bilby_result is None and event.from_npz_file is None:
                # Already in pre-trained-flow mode; nothing to do.
                continue
            # 1. Resolve nf_model_dir (cache directory for this event)
            nf_model_dir = outdir / "gw_flow_cache" / event.name
            nf_model_dir.mkdir(parents=True, exist_ok=True)
            # 2. Obtain NPZ path
            if event.from_bilby_result is not None:
                # Bilby mode: extract NPZ from HDF5 into the cache directory
                npz_path = nf_model_dir / "posterior_samples.npz"
                if not npz_path.exists():
                    logger.info(
                        f"Extracting GW posterior for '{event.name}' from "
                        f"{event.from_bilby_result}"
                    )
                    extract_gw_posterior_from_bilby(
                        bilby_result_file=event.from_bilby_result,
                        output_file=str(npz_path),
                    )
                else:
                    logger.info(
                        f"Using cached posterior samples for '{event.name}' at {npz_path}"
                    )
            else:
                # NPZ mode: use the provided NPZ file directly
                npz_path = Path(event.from_npz_file)  # type: ignore[arg-type]
                if not npz_path.exists():
                    logger.error(
                        f"GW event '{event.name}': from_npz_file path does not exist: {npz_path}"
                    )
                    raise ValueError(
                        f"GW event '{event.name}': from_npz_file '{npz_path}' does not exist."
                    )
                if npz_path.suffix.lower() != ".npz":
                    logger.error(
                        f"GW event '{event.name}': from_npz_file does not have a .npz extension: {npz_path}"
                    )
                    raise ValueError(
                        f"GW event '{event.name}': from_npz_file '{npz_path}' does not have a .npz extension."
                    )
                logger.info(
                    f"Using provided NPZ posterior samples for '{event.name}' "
                    f"at {npz_path}"
                )
            # 3. Build FlowTrainingConfig
            # These fields are always pinned by this function and override any
            # user-provided flow config.
            fixed_fields = {
                "posterior_file": str(npz_path),
                "parameter_names": [
                    "mass_1_source",
                    "mass_2_source",
                    "lambda_1",
                    "lambda_2",
                ],
                "output_dir": str(nf_model_dir),
            }
            if event.flow_config:
                # If there is a flow_config, use it as a base and override the fixed fields (e.g., to allow custom training settings like n_epochs or batch_size)
                ft_config = FlowTrainingConfig.from_yaml(event.flow_config)
                ft_config = ft_config.model_copy(update=fixed_fields)
            else:
                # No flow_config provided, use defaults for everything except the fixed fields
                ft_config = FlowTrainingConfig(**fixed_fields)
            # 4. Check cache
            # The hash of the serialized training config decides whether the
            # cached weights are still valid.
            flow_weights = nf_model_dir / "flow_weights.eqx"
            config_hash_file = nf_model_dir / "flow_config_hash.json"
            current_hash = hashlib.sha256(
                ft_config.model_dump_json().encode()
            ).hexdigest()
            should_train = True
            if flow_weights.exists() and not event.retrain_flow:
                stored_hash: str | None = None
                if config_hash_file.exists():
                    stored_hash = json.loads(config_hash_file.read_text()).get("hash")
                if stored_hash == current_hash:
                    logger.info(
                        f"Reusing cached flow for '{event.name}' at {nf_model_dir}"
                    )
                    should_train = False
                else:
                    logger.warning(
                        f"Flow config changed for '{event.name}' → retraining"
                    )
                    shutil.rmtree(nf_model_dir)
                    nf_model_dir.mkdir(parents=True, exist_ok=True)
                    # For bilby mode: re-extract NPZ after clearing the cache directory
                    if event.from_bilby_result is not None and not npz_path.exists():
                        extract_gw_posterior_from_bilby(
                            bilby_result_file=event.from_bilby_result,
                            output_file=str(npz_path),
                        )
            elif event.retrain_flow and flow_weights.exists():
                logger.info(f"retrain_flow=True for '{event.name}' → retraining")
                shutil.rmtree(nf_model_dir)
                nf_model_dir.mkdir(parents=True, exist_ok=True)
                # For bilby mode: re-extract NPZ after clearing the cache directory
                if event.from_bilby_result is not None and not npz_path.exists():
                    extract_gw_posterior_from_bilby(
                        bilby_result_file=event.from_bilby_result,
                        output_file=str(npz_path),
                    )
            # 5. Train flow if needed
            if should_train:
                logger.info(f"Training normalizing flow for '{event.name}'...")
                train_flow_from_config(ft_config)
                # Record the config hash only after a successful training run.
                config_hash_file.write_text(json.dumps({"hash": current_hash}))
            # 6. Replace event with resolved pre-trained-flow event
            updated_events[ev_idx] = GWEventConfig(
                name=event.name,
                nf_model_dir=str(nf_model_dir),
            )
        updated_likelihoods[lk_idx] = lk_config.model_copy(
            update={"events": updated_events}
        )
        any_changes = True
    if not any_changes:
        return config
    return config.model_copy(update={"likelihoods": updated_likelihoods})
# TODO: remove transform second argument, as it is unused
# TODO: this is a bit redundant and we can just use the factory directly
# [docs] — Sphinx "view source" link artifact from documentation extraction; not code
def setup_likelihood(
    config: InferenceConfig, transform: JesterTransform
) -> LikelihoodBase:
    """
    Build the combined likelihood for this run.

    Thin wrapper around the likelihood factory. The ``transform`` argument is
    accepted for interface compatibility but is not used here (see module
    TODOs about removing it).

    Parameters
    ----------
    config : InferenceConfig
        Configuration object.
    transform : JesterTransform
        Transform instance (currently unused).

    Returns
    -------
    LikelihoodBase
        Combined likelihood instance.
    """
    return create_combined_likelihood(config.likelihoods)
# [docs] — Sphinx "view source" link artifact from documentation extraction; not code
def run_sampling(
    sampler: JesterSampler,
    seed: int,
    config: InferenceConfig,
    outdir: str | Path,
    fixed_params: dict[str, float] | None = None,
) -> InferenceResult:
    """
    Run MCMC sampling and package the output as an InferenceResult.

    Parameters
    ----------
    sampler : JesterSampler
        JesterSampler instance (FlowMC, BlackJAX NS, or BlackJAX SMC).
    seed : int
        Random seed for sampling.
    config : InferenceConfig
        Configuration object.
    outdir : str or Path
        Output directory.
    fixed_params : dict[str, float] | None, optional
        Parameters pinned to constant values during inference.

    Returns
    -------
    InferenceResult
        Result object containing samples, metadata, and histories.
    """
    logger.info(f"Starting sampling with seed {seed}...")
    t0 = time.time()
    sampler.sample(jax.random.PRNGKey(seed))
    sampler.print_summary()
    elapsed = time.time() - t0
    logger.info(
        f"Sampling complete! Runtime: {int(elapsed / 60)} min {int(elapsed % 60)} sec"
    )

    # Only SMC samplers currently expose diagnostic plotting.
    from .samplers.blackjax.smc.base import BlackjaxSMCSampler

    # TODO: plot_diagnostics should be in the base class, do not fail if not implemented but just pass
    # Then, samplers can implement their own diagnostics as needed (e.g., FlowMC could have training diagnostics, acceptance rates, etc.) -- for now we only have this for SMC, but that is fine
    if isinstance(sampler, BlackjaxSMCSampler):
        logger.info("Generating SMC diagnostic plots...")
        sampler.plot_diagnostics(outdir=outdir, filename="smc_diagnostics.png")

    logger.info("Creating InferenceResult from sampler output...")
    return InferenceResult.from_sampler(
        sampler=sampler,
        config=config,
        runtime=elapsed,
        fixed_params=fixed_params,
    )
# TODO: fully deprecate this: remove this entirely (if no other dependency on it)
# [docs] — Sphinx "view source" link artifact from documentation extraction; not code
def generate_eos_samples(
    config: InferenceConfig,
    result: InferenceResult,
    transform_eos: JesterTransform,
    outdir: str | Path,
    n_eos_samples: int = 10_000,
) -> None:
    """
    .. deprecated::
        This function is deprecated and will be removed in a future version.
        Use :meth:`InferenceResult.add_eos_from_transform` instead.

    Generate EOS curves from sampled parameters and add to InferenceResult.

    Note that this mutates ``result.posterior`` in place: the full-length
    ``log_prob`` and sampler fields are preserved under ``*_full`` keys and
    then replaced by versions filtered to the randomly chosen subset.

    Parameters
    ----------
    config : InferenceConfig
        Configuration object
    result : InferenceResult
        Result object with posterior samples
    transform_eos : JesterTransform
        Transform for generating full EOS quantities
    outdir : str or Path
        Output directory
    n_eos_samples : int, optional
        Number of EOS samples to generate
    """
    warnings.warn(
        "generate_eos_samples() is deprecated and will be removed in a future version. "
        "Use InferenceResult.add_eos_from_transform() instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Get log_prob from result
    log_prob = result.posterior["log_prob"]
    # Cap n_eos_samples at available sample size
    n_available = len(log_prob)
    if n_eos_samples > n_available:
        logger.warning(
            f"Requested {n_eos_samples} EOS samples but only {n_available} available."
        )
        logger.warning(f"Using all {n_available} samples instead.")
        n_eos_samples = n_available
    logger.info(f"Generating {n_eos_samples} EOS samples...")
    # Randomly select samples (without replacement; uses global numpy RNG state)
    idx = np.random.choice(np.arange(len(log_prob)), size=n_eos_samples, replace=False)
    # Filter out metadata fields and derived EOS quantities that aren't transform parameters
    # Only keep fields that are NEP/CSE parameters for the transform
    exclude_keys = {
        "weights",
        "ess",
        "logL",
        "logL_birth",
        "log_prob",
        "_sampler_specific",
        "masses_EOS",
        "radii_EOS",
        "Lambdas_EOS",
        "n",
        "p",
        "e",
        "cs2",
    }
    param_samples = {k: v for k, v in result.posterior.items() if k not in exclude_keys}
    chosen_samples = {k: jnp.array(v[idx]) for k, v in param_samples.items()}
    # NOTE: This function is currently unused - see InferenceResult.add_eos_from_transform() instead
    # CRITICAL: If this function is ever re-enabled, remember to filter ALL arrays:
    # - log_prob and sampler fields (weights, ess, logL, logL_birth)
    # - NEP/CSE parameter arrays
    # Store the original full log_prob for reference, then update with filtered version
    result.posterior["log_prob_full"] = result.posterior["log_prob"].copy()
    result.posterior["log_prob"] = result.posterior["log_prob"][idx]
    # Filter other sampler-specific fields if present
    sampler_fields_to_filter = ["weights", "ess", "logL", "logL_birth"]
    for field in sampler_fields_to_filter:
        if field in result.posterior:
            result.posterior[f"{field}_full"] = result.posterior[field].copy()
            result.posterior[field] = result.posterior[field][idx]
    logger.info(
        f"Filtered log_prob and sampler fields from {len(log_prob)} to {len(result.posterior['log_prob'])} samples"
    )
    # Generate EOS curves with batched processing
    logger.info("Running TOV solver with batched processing...")
    my_forward = jax.jit(transform_eos.forward)
    # Get batch size from config
    batch_size = config.sampler.log_prob_batch_size
    if batch_size > n_eos_samples:
        logger.warning(
            f"Requested batch size ({batch_size}) is larger than the number of samples "
            f"({n_eos_samples}). Adjusting batch size to {n_eos_samples}."
        )
        batch_size = n_eos_samples
    logger.info(f"Using batch size: {batch_size}")
    # Run with batched processing (JIT compilation happens on first batch)
    TOV_start = time.time()
    transformed_samples = jax.lax.map(my_forward, chosen_samples, batch_size=batch_size)
    TOV_end = time.time()
    logger.info(
        f"TOV solve time: {TOV_end - TOV_start:.2f} s ({n_eos_samples} samples)"
    )
    # Add derived EOS quantities to result
    result.add_derived_eos(transformed_samples)
    logger.info("Derived EOS quantities added to InferenceResult")
# TODO: there are some calls that check specific types of samplers/configs/...
# Ideally we should remove this and just have a small loop that prints over all available
# attributes of the config/sampler/likelihood/transform/prior objects, so that we don't have to update this function every time we add a new sampler type or likelihood type or EOS type, etc. We can still have some special handling for certain fields (e.g., if chieft enabled then print nbreak bounds, etc.) but in general we should just loop over all fields and print them in a structured way (e.g., using Pydantic's model_dump() with some formatting for logging)
# This is already done a bit with the likelihoods, so follow that approach in the future
# [docs] — Sphinx "view source" link artifact from documentation extraction; not code
def main(config_path: str) -> None:
    """Main inference script.

    Orchestrates the full pipeline: load config → build prior → build
    transform → prepare GW flows → build likelihood → build sampler →
    sample → derive EOS quantities → save → (optionally) postprocess.

    Parameters
    ----------
    config_path : str
        Path to YAML configuration file
    """
    # Load configuration
    logger.info(f"Loading configuration from {config_path}")
    config = load_config(config_path)
    # Enable NaN debugging if requested
    if config.debug_nans:
        logger.info("Enabling JAX NaN debugging")
        jax.config.update("jax_debug_nans", True)
    outdir = config.sampler.output_dir
    # Print GPU info
    logger.info(f"JAX devices: {jax.devices()}")
    # Validation only
    if config.validate_only:
        logger.info("Configuration valid!")
        return
    # Setup components
    logger.info("Setting up prior...")
    prior, fixed_params = setup_prior(config)
    # Log detailed prior information
    logger.info(f"Prior has {prior.n_dim} dimensions")
    logger.info(f"Prior parameter names: {prior.parameter_names}")
    # Get individual priors - CombinePrior stores them in base_prior attribute
    if hasattr(prior, "base_prior") and isinstance(prior.base_prior, list):
        individual_priors = prior.base_prior
    else:
        # For single priors, wrap in a list
        individual_priors = [prior]

    # Flatten the list of priors (in case of nested CombinePriors)
    def flatten_priors(priors_list):
        result = []
        for p in priors_list:
            if hasattr(p, "base_prior") and isinstance(p.base_prior, list):
                result.extend(flatten_priors(p.base_prior))
            else:
                result.append(p)
        return result

    all_priors = flatten_priors(individual_priors)
    # Log each prior with its parameters
    idx = 0
    for param_prior in all_priors:
        for name in param_prior.parameter_names:
            # Priors exposing xmin/xmax are assumed uniform for display purposes
            if hasattr(param_prior, "xmin") and hasattr(param_prior, "xmax"):
                logger.info(
                    f" [{idx}] {name}: Uniform({param_prior.xmin}, {param_prior.xmax})"
                )
            else:
                logger.info(f" [{idx}] {name}: {type(param_prior).__name__}")
            idx += 1
    # Determine which parameters need to be preserved in transform output
    # based on enabled likelihoods (validates required parameters exist in prior or fixed_params)
    keep_names = determine_keep_names(config, prior, fixed_params)
    logger.info("Setting up transform...")
    # NOTE(review): setup_transform is not imported or defined in this module as
    # shown — confirm it is provided elsewhere (e.g. a missing import).
    transform = setup_transform(
        config, prior=prior, keep_names=keep_names, fixed_params=fixed_params
    )
    # Log transform details
    logger.info(f"EOS type: {config.eos.type}")
    if isinstance(config.eos, MetamodelCSEEOSConfig):
        logger.info(f" nb_CSE: {config.eos.nb_CSE}")
        if config.eos.max_nbreak_nsat is not None:
            logger.info(f" max_nbreak_nsat: {config.eos.max_nbreak_nsat:.4f} n_sat")
    if isinstance(config.eos, BaseMetamodelEOSConfig):
        logger.info(f" ndat_metamodel: {config.eos.ndat_metamodel}")
        logger.info(f" nmax_nsat: {config.eos.nmax_nsat}")
    logger.info(f"TOV solver: {config.tov.type}")
    logger.info(f" ndat_TOV: {config.tov.ndat_TOV}")
    if keep_names:
        logger.info(f" Preserving parameters in output: {keep_names}")
    # Train/resolve normalizing flows for GW events before building likelihoods
    config = prepare_gw_flows(config, Path(outdir))
    logger.info("Setting up likelihood...")
    likelihood = setup_likelihood(config, transform)
    # Log detailed likelihood information
    enabled_likelihoods = [lk for lk in config.likelihoods if lk.enabled]
    logger.info(f"Number of enabled likelihoods: {len(enabled_likelihoods)}")
    for lk in enabled_likelihoods:
        # Use Pydantic's model_dump to serialize config for logging
        lk_dict = lk.model_dump(
            exclude={"enabled"}
        )  # Exclude enabled since we already filtered
        logger.info(f" - {lk.type.upper()}:")
        logger.info(f" {json.dumps(lk_dict, indent=6)}")
    logger.info(f"Setting up {config.sampler.type} sampler...")
    sampler = create_sampler(
        config=config.sampler,
        prior=prior,
        likelihood=likelihood,
        likelihood_transforms=[transform],
        seed=config.seed,
    )
    # Log detailed sampler configuration
    logger.info("=" * 60)
    logger.info("Configuration Summary")
    logger.info("=" * 60)
    logger.info(f"EOS type: {config.eos.type}")
    logger.info(f"TOV solver: {config.tov.type}")
    logger.info(f"Random seed: {config.seed}")
    logger.info(f"Sampler type: {config.sampler.type}")
    logger.info("Sampler Configuration:")
    # Log sampler-specific config fields
    if config.sampler.type == "flowmc":
        logger.info(f" Chains: {config.sampler.n_chains}")
        logger.info(f" Training loops: {config.sampler.n_loop_training}")
        logger.info(f" Production loops: {config.sampler.n_loop_production}")
        logger.info(f" Local steps per loop: {config.sampler.n_local_steps}")
        logger.info(f" Global steps per loop: {config.sampler.n_global_steps}")
        logger.info(f" Training epochs: {config.sampler.n_epochs}")
        logger.info(f" Learning rate: {config.sampler.learning_rate}")
        logger.info(f" Training thinning: {config.sampler.train_thinning}")
        logger.info(f" Output thinning: {config.sampler.output_thinning}")
    elif config.sampler.type == "blackjax-ns-aw":
        logger.info(f" Live points: {config.sampler.n_live}")
        logger.info(f" Delete fraction: {config.sampler.n_delete_frac}")
        logger.info(f" Target MCMC steps: {config.sampler.n_target}")
        logger.info(f" Termination dlogZ: {config.sampler.termination_dlogz}")
    elif config.sampler.type in ["smc-rw", "smc-nuts"]:
        logger.info(f" Particles: {config.sampler.n_particles}")
        logger.info(f" MCMC steps: {config.sampler.n_mcmc_steps}")
        logger.info(f" Target ESS: {config.sampler.target_ess}")
    # Log shared sampler config fields
    logger.info(f" EOS samples to generate: {config.sampler.n_eos_samples}")
    logger.info(f" Output directory: {outdir}")
    logger.info("=" * 60 + "\n")
    # Dry run option
    if config.dry_run:
        logger.info("Dry run complete!")
        return
    # Create output directory
    os.makedirs(outdir, exist_ok=True)
    # Test likelihood evaluation # FIXME: this should throw an error if a Nan is found
    logger.info("Testing likelihood evaluation...")
    test_samples = prior.sample(jax.random.PRNGKey(0), 3)
    test_samples_transformed = jax.vmap(transform.forward)(test_samples)
    test_log_prob = jax.vmap(likelihood.evaluate)(test_samples_transformed)
    logger.info(f"Test log probabilities: {test_log_prob}")
    # Run inference
    result = run_sampling(
        sampler, config.seed, config, outdir, fixed_params=fixed_params
    )
    # Generate EOS quantities from posterior samples
    # Note: This requires recomputing the TOV solver for selected samples.
    # Future optimization: implement transform caching outside JAX trace (see JesterSampler)
    result.add_eos_from_transform(
        transform=transform,  # Use the same transform from sampling
        n_eos_samples=config.sampler.n_eos_samples,
        batch_size=config.sampler.log_prob_batch_size,
    )
    # Save unified HDF5 file
    result_path = os.path.join(outdir, "results.h5")
    result.save(result_path)
    logger.info(f"Results saved to {result_path}")
    # Run postprocessing if enabled
    if config.postprocessing.enabled:
        logger.info("\n" + "=" * 60)
        logger.info("Running postprocessing...")
        logger.info("=" * 60)
        from jesterTOV.inference.postprocessing.postprocessing import generate_all_plots

        generate_all_plots(
            outdir=outdir,
            prior_dir=config.postprocessing.prior_dir,
            make_cornerplot_flag=config.postprocessing.make_cornerplot,
            make_massradius_flag=config.postprocessing.make_massradius,
            make_pressuredensity_flag=config.postprocessing.make_pressuredensity,
            make_histograms_flag=config.postprocessing.make_histograms,
            make_cs2_flag=config.postprocessing.make_cs2,
            injection_eos_path=config.postprocessing.injection_eos_path,
        )
        logger.info(f"\nPostprocessing complete! Plots saved to {outdir}")
    logger.info(f"\nInference complete! Results saved to {outdir}")
# [docs] — Sphinx "view source" link artifact from documentation extraction; not code
def cli_entry_point() -> None:
    """
    Console-script entry point.

    Allows running inference with::

        run_jester_inference config.yaml

    instead of::

        python -m jesterTOV.inference.run_inference config.yaml
    """
    import sys

    # Require exactly one positional argument: the YAML config path.
    if len(sys.argv) == 2:
        main(sys.argv[1])
        return

    logger.error("Usage: run_jester_inference <config.yaml>")
    logger.info("\nExamples:")
    logger.info(" run_jester_inference config.yaml")
    logger.info(
        " run_jester_inference examples/inference/full_inference/config.yaml"
    )
    logger.info(
        "\nOptions like dry_run and validate_only should be set in the YAML config file."
    )
    sys.exit(1)
# Module entry point: same behavior as cli_entry_point(), but with the
# `python -m` usage string.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        logger.error("Usage: python -m jesterTOV.inference.run_inference <config.yaml>")
        sys.exit(1)
    main(sys.argv[1])