Source code for fiesta.inference.systematic

from ast import literal_eval
import inspect
import os
from pathlib import Path

import jax.numpy as jnp
import yaml

from fiesta.logging import logger
import fiesta.inference.prior as fiesta_prior
from fiesta.inference.prior import ConstrainedPrior

# Registry of every class exposed by fiesta.inference.prior, keyed by class name, so that a
# systematics file can refer to a prior type by its string name (e.g. looked up in fetch_prior_params).
ALL_PRIORS = {
    class_name: prior_class
    for class_name, prior_class in inspect.getmembers(fiesta_prior, inspect.isclass)
}

########################################
# SYSTEMATIC UNCERTAINTY SETUP METHODS #
########################################

def setup_systematics_basic(likelihood, prior: ConstrainedPrior, error_budget: float = 0.3):
    """Configure one collective systematic uncertainty on the likelihood.

    When the prior contains a parameter named ``em_syserr``, the likelihood is switched to a
    freely sampled systematic uncertainty (``likelihood._setup_sys_uncertainty_free()``).
    Otherwise the systematic uncertainty is fixed to ``error_budget`` via
    ``likelihood._setup_sys_uncertainty_fixed``.

    Args:
        likelihood: Likelihood object exposing the ``_setup_sys_uncertainty_*`` hooks.
        prior: Prior whose ``naming`` is inspected for ``em_syserr``; returned unchanged.
        error_budget: Fixed systematic uncertainty used when ``em_syserr`` is not sampled.

    Returns:
        The tuple ``(likelihood, prior)``.
    """
    samples_syserr = "em_syserr" in prior.naming

    if not samples_syserr:
        # No free parameter in the prior -> pin the systematic uncertainty to the budget.
        likelihood._setup_sys_uncertainty_fixed(error_budget=error_budget)
        logger.info(f"Likelihood is using a collective fixed systematic uncertainty {error_budget}.")
        return likelihood, prior

    likelihood._setup_sys_uncertainty_free()
    logger.info("Likelihood is using a collective freely sampled systematic uncertainty as specified in the prior.")
    return likelihood, prior
def setup_systematic_from_file(likelihood, prior: ConstrainedPrior, systematics_file: str):
    """Configure systematic uncertainties from a YAML systematics file.

    The file is parsed by ``process_file``. The likelihood is then configured through
    ``likelihood._setup_sys_uncertainty_from_file``. A new ``ConstrainedPrior`` is built that
    additionally samples the systematic-uncertainty parameters declared in the file.

    Args:
        likelihood: Likelihood object; its ``filters`` attribute is passed to
            ``process_file``.
        prior: Prior to extend. If it is not a ``ConstrainedPrior`` it is first wrapped in
            one built from its ``priors`` attribute. Any ``em_syserr`` sub-prior is dropped,
            because the file-based setup replaces it.
        systematics_file: Path to the YAML systematics file.

    Returns:
        Tuple ``(likelihood, prior)``, where ``prior`` is a newly constructed
        ``ConstrainedPrior``. The prior object passed in is not modified.

    Raises:
        FileNotFoundError: If ``systematics_file`` does not exist (subclass of ``OSError``).
    """
    if not os.path.exists(systematics_file):
        raise FileNotFoundError(f"Provided systematics file {systematics_file} could not be found.")

    sys_params_per_filter, time_nodes_per_filter, additional_priors = process_file(
        systematics_file, likelihood.filters
    )

    # Set up the likelihood.
    likelihood._setup_sys_uncertainty_from_file(sys_params_per_filter, time_nodes_per_filter)
    logger.info(f"Likelihood is using systematic uncertainty sampling as specified in {systematics_file}.")

    # Set up the prior.
    if not isinstance(prior, ConstrainedPrior):
        prior = ConstrainedPrior(prior.priors)

    # Work on a copy: ``prior.priors`` may be the very list held by the caller's prior object
    # (the wrapping above reuses it as well), so popping/extending it in place would silently
    # alter the prior that was passed in.
    prior_list = list(prior.priors)
    if "em_syserr" in prior.naming:
        logger.warning("When providing a systematics_file, 'em_syserr' should not be listed in the prior. Removing 'em_syserr' from prior list.")
        prior_list = [p for p in prior_list if "em_syserr" not in p.naming]

    prior_list.extend(prior.constraints)
    prior_list.extend(additional_priors)
    prior = ConstrainedPrior(priors=prior_list, conversion_function=prior.conversion)
    logger.info(f"Prior is now updated to sample systematic uncertainty parameters {[sys_prior.naming[0] for sys_prior in additional_priors]}.")

    return likelihood, prior
def _build_sys_priors(label, nodes, sys_prior_type, sys_prior_params):
    """Create one prior per time node for the sys. uncertainty group ``label``.

    Returns the parameter names ``syserr_{label}_{j}`` for ``j = 1..nodes`` together with the
    matching list of prior instances of class ``ALL_PRIORS[sys_prior_type]``.
    """
    prior_class = ALL_PRIORS[sys_prior_type]
    names = [f"syserr_{label}_{j}" for j in range(1, nodes + 1)]
    priors = [prior_class(**sys_prior_params, naming=[name]) for name in names]
    return names, priors


def process_file(systematic_file, filters):
    """Parse a systematics YAML file into per-filter parameter names, time nodes and priors.

    The top-level mapping must use exactly one of three layouts:

    * ``collective``: one set of node parameters ``syserr_collective_{j}`` shared by all filters.
    * ``individual``: one set of node parameters ``syserr_{filter}_{j}`` for every filter,
      all using the same settings.
    * Named groups, each with a ``filters`` list and an optional ``remaining`` group. The
      groups are reconciled with ``filters`` by ``check_filter_compatability``. Parameters
      are named ``syserr_{group}_{j}``.

    Args:
        systematic_file: Path to the YAML file.
        filters: Filters present in the data.

    Returns:
        ``(sys_params_per_filter, time_nodes_per_filter, additional_priors)``:
        * a dict mapping each filter to its list of parameter names;
        * a dict mapping each filter to the ``t_range`` returned by ``fetch_prior_params``;
        * the list of all created prior objects.

    Raises:
        ValueError: If the file does not contain a mapping, or if ``collective`` or
            ``individual`` is combined with other top-level entries.
    """
    yaml_dict = yaml.safe_load(Path(systematic_file).read_text())
    # An empty file loads as None; a list/scalar is also not a valid layout.
    if not isinstance(yaml_dict, dict):
        raise ValueError(f"Systematics file {systematic_file} must contain a mapping of sys. uncertainty groups.")

    additional_priors = []
    sys_params_per_filter = {}
    time_nodes_per_filter = {}

    if "collective" in yaml_dict:
        if len(yaml_dict) > 1:
            raise ValueError(f"'collective' sys. uncertainty can only be specified if no other sys. uncertainty setup is given in {systematic_file}.")
        nodes, t_range, sys_prior_type, sys_prior_params = fetch_prior_params(yaml_dict["collective"])
        sys_parameters, priors = _build_sys_priors("collective", nodes, sys_prior_type, sys_prior_params)
        additional_priors.extend(priors)
        for filt in filters:
            sys_params_per_filter[filt] = sys_parameters
            time_nodes_per_filter[filt] = t_range

    elif "individual" in yaml_dict:
        if len(yaml_dict) > 1:
            raise ValueError(f"'individual' sys. uncertainty for each filter can only be specified if no other sys. uncertainty setup is given in {systematic_file}.")
        nodes, t_range, sys_prior_type, sys_prior_params = fetch_prior_params(yaml_dict["individual"])
        for filt in filters:
            sys_parameters, priors = _build_sys_priors(filt, nodes, sys_prior_type, sys_prior_params)
            additional_priors.extend(priors)
            sys_params_per_filter[filt] = sys_parameters
            time_nodes_per_filter[filt] = t_range

    else:
        yaml_dict = check_filter_compatability(yaml_dict, filters)
        for key, group in yaml_dict.items():
            nodes, t_range, sys_prior_type, sys_prior_params = fetch_prior_params(group)
            sys_parameters, priors = _build_sys_priors(key, nodes, sys_prior_type, sys_prior_params)
            additional_priors.extend(priors)
            for filt in group["filters"]:
                sys_params_per_filter[filt] = sys_parameters
                time_nodes_per_filter[filt] = t_range

    return sys_params_per_filter, time_nodes_per_filter, additional_priors
def check_filter_compatability(yaml_dict, filters):
    """Reconcile the sys. error groups of a systematics file with the filters in the data.

    For every group except ``remaining``:

    * filters that are not in ``filters`` are dropped, with a warning;
    * each remaining filter must belong to only one group.

    Filters not claimed by any group are put into the ``remaining`` group. If every filter
    is already claimed, an existing ``remaining`` entry is deleted.

    Args:
        yaml_dict: Parsed systematics file; every non-``remaining`` entry needs a ``filters`` list.
        filters: Filters present in the data.

    Returns:
        ``yaml_dict``, modified in place.

    Raises:
        ValueError: If a filter is assigned more than once (to two groups or twice in one group).
        KeyError: If some filters are unassigned and no ``remaining`` group exists.
    """
    available = set(filters)
    unassigned = list(filters)

    for key, group in yaml_dict.items():
        if key == "remaining":
            continue

        ill_specified_filters = set(group["filters"]) - available
        if ill_specified_filters:
            logger.warning(f"Filters {ill_specified_filters} in systematics file are not part of the lightcurve data. Removing them from the sys. error group {key}.")
            # Drop every occurrence (list.remove would only drop the first one).
            group["filters"] = [filt for filt in group["filters"] if filt in available]

        for filt in group["filters"]:
            # filt is a known filter, so if it is no longer unassigned it was claimed before.
            if filt not in unassigned:
                raise ValueError(f"Filter {filt} is assigned to more than one sys. error group in the systematics file (found again in group {key}).")
            unassigned.remove(filt)

    if not unassigned:
        yaml_dict.pop("remaining", None)
    else:
        if "remaining" not in yaml_dict:
            raise KeyError(f"Sys error groups in systematic file do not include the following filters {unassigned}. Set up a 'remaining' group to include those.")
        yaml_dict["remaining"]["filters"] = unassigned

    return yaml_dict
def fetch_prior_params(yaml_entry: dict):
    """Parse the settings of one sys. uncertainty group from a systematics file.

    Args:
        yaml_entry: Mapping with the following keys:
            * ``time_nodes``: number of nodes;
            * ``time_range`` (optional): string ``"<linear|log> <t0> <t1>"``;
            * ``prior``: class name looked up in ``ALL_PRIORS``;
            * ``params``: keyword arguments for that prior class.

    Returns:
        ``(nodes, t_range, sys_prior_type, sys_prior_params)``. ``t_range`` is ``None``
        when no ``time_range`` is given. Otherwise it is a ``jnp`` array of ``nodes``
        values from ``t0`` to ``t1``, spaced linearly or geometrically.

    Raises:
        ValueError: If ``time_range`` is not three whitespace-separated tokens, uses a
            spacing other than ``linear``/``log``, or if ``prior`` is not in ``ALL_PRIORS``.
        KeyError: If ``time_nodes``, ``prior`` or ``params`` is missing.
    """
    nodes = yaml_entry["time_nodes"]
    t_range = yaml_entry.get("time_range", None)

    if t_range is not None:
        # split() tolerates repeated/irregular whitespace; validate the token count explicitly
        # instead of failing with an opaque tuple-unpacking error.
        tokens = t_range.split() if isinstance(t_range, str) else []
        if len(tokens) != 3:
            raise ValueError(f"'time_range' in systematics file must have the form '<linear|log> <t0> <t1>', got {t_range!r}.")
        spacing, t0, t1 = tokens
        if spacing == "log":
            t_range = jnp.geomspace(float(t0), float(t1), nodes)
        elif spacing == "linear":
            t_range = jnp.linspace(float(t0), float(t1), nodes)
        else:
            raise ValueError(f"Range specified in systematics file must either be 'linear' or 'log', not {spacing}.")

    sys_prior_type = yaml_entry["prior"]
    if sys_prior_type not in ALL_PRIORS:
        raise ValueError(f"Prior type specified in systematic file not implemented in fiesta.inference.prior. Allowed priors are {sorted(ALL_PRIORS)}.")
    sys_prior_params = yaml_entry["params"]

    return nodes, t_range, sys_prior_type, sys_prior_params