jesterTOV.inference.flows.config.FlowTrainingConfig

class FlowTrainingConfig(**data)

Bases: BaseModel

Configuration for training normalizing flows on posterior samples.

Variables:
  • posterior_file (str) – Path to the .npz file containing the posterior samples

  • output_dir (str) – Directory to save model weights, kwargs, and plots

  • parameter_names (list[str]) – List of parameter names to extract from the posterior file. Examples: GW parameters [“mass_1_source”, “mass_2_source”, “lambda_1”, “lambda_2”]; NICER parameters [“mass”, “radius”]

  • num_epochs (int) – Number of training epochs (default: 600)

  • learning_rate (float) – Learning rate for training (default: 1e-3)

  • max_patience (int) – Early stopping patience (default: 50)

  • nn_depth (int) – Depth of neural network blocks (default: 5)

  • nn_block_dim (int) – Dimension of neural network blocks (default: 8)

  • flow_layers (int) – Number of flow layers (default: 1)

  • invert (bool) – Whether to invert the flow (default: True)

  • cond_dim (int | None) – Dimension of the conditioning variable for conditional flows; None gives an unconditional flow (default: None)

  • max_samples (int) – Maximum number of samples to use for training (default: 50,000)

  • seed (int) – Random seed for reproducibility (default: 0)

  • plot_corner (bool) – Whether to generate a corner-plot comparison (default: True)

  • plot_losses (bool) – Whether to plot the training and validation losses (default: True)

  • flow_type (Literal["block_neural_autoregressive_flow", "masked_autoregressive_flow", "coupling_flow"]) – Type of normalizing flow to use (default: masked_autoregressive_flow)

  • nn_width (int) – Width of neural network hidden layers (default: 50)

  • standardize (bool) – Whether to standardize input data (default: True, changed from False)

  • standardization_method (Literal["zscore", "minmax"]) – Method for standardizing input data (default: zscore). Only used if standardize=True.
      - “zscore”: standardize to mean 0 and standard deviation 1 (recommended for most cases)
      - “minmax”: rescale to the [0, 1] range (legacy, kept for backward compatibility)

  • transformer (Literal["affine", "rational_quadratic_spline"]) – Transformer type used by masked_autoregressive_flow and coupling_flow (default: rational_quadratic_spline, changed from affine)

  • transformer_knots (int) – Number of knots for the RationalQuadraticSpline transformer (default: 10, changed from 8)

  • transformer_interval (float) – Interval for the RationalQuadraticSpline transformer (default: 5.0, changed from 4.0)

  • val_prop (float) – Proportion of data to use for validation (default: 0.2)

  • batch_size (int) – Batch size for training (default: 128)
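The data-related fields (posterior_file, parameter_names, max_samples, seed, val_prop, standardize, standardization_method) work together during preprocessing. The sketch below is a hypothetical illustration, not the jesterTOV implementation: the helper name prepare_training_data is invented here, and it assumes the .npz file stores one 1-D array per parameter name, that subsampling happens before the train/validation split, and that standardization statistics come from the training split.

```python
import os
import tempfile

import numpy as np


def prepare_training_data(
    posterior_file,
    parameter_names,
    max_samples=50_000,
    val_prop=0.2,
    seed=0,
    standardize=True,
    standardization_method="zscore",
):
    """Hypothetical helper: load, subsample, split and standardize samples."""
    # Assumption: the .npz file holds one 1-D array per parameter name.
    with np.load(posterior_file) as posterior:
        data = np.column_stack([posterior[name] for name in parameter_names])

    rng = np.random.default_rng(seed)
    # Keep at most max_samples rows, drawn without replacement.
    if data.shape[0] > max_samples:
        keep = rng.choice(data.shape[0], size=max_samples, replace=False)
        data = data[keep]

    # Shuffle, then hold out a val_prop fraction as the validation set.
    data = data[rng.permutation(data.shape[0])]
    n_val = int(round(val_prop * data.shape[0]))
    train, val = data[n_val:], data[:n_val]

    if standardize:
        if standardization_method == "zscore":
            # Per-parameter mean 0 and standard deviation 1 on the training set.
            shift, scale = train.mean(axis=0), train.std(axis=0)
        elif standardization_method == "minmax":
            # Per-parameter rescaling of the training set to [0, 1].
            shift = train.min(axis=0)
            scale = train.max(axis=0) - shift
        else:
            raise ValueError(f"unknown standardization_method: {standardization_method}")
        # Apply the training-set statistics to both splits.
        train = (train - shift) / scale
        val = (val - shift) / scale
    return train, val


# Usage with a synthetic NICER-style posterior ("mass", "radius").
demo_rng = np.random.default_rng(1)
demo_path = os.path.join(tempfile.mkdtemp(), "posterior.npz")
np.savez(
    demo_path,
    mass=demo_rng.normal(1.4, 0.1, size=1000),
    radius=demo_rng.normal(12.0, 0.5, size=1000),
)
train, val = prepare_training_data(demo_path, ["mass", "radius"], max_samples=500)
print(train.shape, val.shape)  # (400, 2) (100, 2)
```

With max_samples=500 and the default val_prop=0.2, 500 of the 1000 rows are kept and 100 of them are held out for validation.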

__init__(**data)

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

Methods

__init__(**data)

Create a new model by parsing and validating input data from keyword arguments.

construct([_fields_set])

copy(*[, include, exclude, update, deep])

Returns a copy of the model.

dict(*[, include, exclude, by_alias, ...])

from_orm(obj)

from_yaml(filepath)

Load configuration from a YAML file.

json(*[, include, exclude, by_alias, ...])

model_construct([_fields_set])

Creates a new instance from trusted or pre-validated data, without running validation.

model_copy(*[, update, deep])

Returns a copy of the model.

model_dump(*[, mode, include, exclude, ...])

Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.

model_dump_json(*[, indent, ensure_ascii, ...])

Generates a JSON representation of the model.

model_json_schema([by_alias, ref_template, ...])

Generates a JSON schema for a model class.

model_parametrized_name(params)

Compute the class name for parametrizations of generic classes.

model_post_init(context, /)

Override this method to perform additional initialization after __init__ and model_construct.

model_rebuild(*[, force, raise_errors, ...])

Try to rebuild the pydantic-core schema for the model.

model_validate(obj, *[, strict, extra, ...])

Validate a pydantic model instance.

model_validate_json(json_data, *[, strict, ...])

Validate the given JSON data against the Pydantic model.

model_validate_strings(obj, *[, strict, ...])

Validate the given object with string data against the Pydantic model.

parse_file(path, *[, content_type, ...])

parse_obj(obj)

parse_raw(b, *[, content_type, encoding, ...])

schema([by_alias, ref_template])

schema_json(*[, by_alias, ref_template])

update_forward_refs(**localns)

validate(value)

validate_parameter_names(v)

Validate that parameter_names is a non-empty list.

validate_positive_float(v)

Validate that float value is positive.

validate_positive_int(v)

Validate that integer value is positive.

validate_val_prop_range(v)

Validate that validation proportion is in (0, 1).
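Because FlowTrainingConfig inherits the standard pydantic v2 methods listed above, model_dump, model_dump_json and model_copy are the usual tools for saving a configuration or deriving a variant of it. The sketch below uses a small hypothetical stand-in class (SketchConfig, not part of jesterTOV) so that it runs without jesterTOV installed. One pydantic caveat applies to the real class as well: model_copy(update=...) does not re-run validation, so pass modified data through model_validate when the new values must satisfy the field validators.

```python
from pydantic import BaseModel, ValidationError, field_validator


class SketchConfig(BaseModel):
    """Hypothetical stand-in with two FlowTrainingConfig-like fields."""

    num_epochs: int = 600
    learning_rate: float = 1e-3

    @field_validator("num_epochs")
    @classmethod
    def check_positive(cls, v: int) -> int:
        if v <= 0:
            raise ValueError("num_epochs must be positive")
        return v


base = SketchConfig()

# Serialize, e.g. to store the settings next to the trained weights.
as_dict = base.model_dump()  # {'num_epochs': 600, 'learning_rate': 0.001}
as_json = base.model_dump_json()  # JSON string with the same content

# Derive a variant without modifying the original instance.
longer = base.model_copy(update={"num_epochs": 1200})

# Caveat: model_copy(update=...) skips validation, so this is accepted.
unchecked = base.model_copy(update={"num_epochs": -1})

# Re-validating the dumped data applies the validators again.
try:
    SketchConfig.model_validate({**as_dict, "num_epochs": -1})
except ValidationError as err:
    print(f"rejected with {err.error_count()} error")  # rejected with 1 error
```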

Attributes

model_computed_fields

model_config

Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.

model_extra

Get extra fields set during validation.

model_fields

model_fields_set

Returns the set of fields that have been explicitly set on this model instance.

posterior_file

output_dir

parameter_names

num_epochs

learning_rate

max_patience

nn_depth

nn_block_dim

flow_layers

invert

cond_dim

max_samples

seed

plot_corner

plot_losses

flow_type

nn_width

standardize

standardization_method

transformer

transformer_knots

transformer_interval

val_prop

batch_size

batch_size: int
cond_dim: int | None
flow_layers: int
flow_type: Literal['block_neural_autoregressive_flow', 'masked_autoregressive_flow', 'coupling_flow']

classmethod from_yaml(filepath)

Load configuration from a YAML file.

Parameters:

filepath (str | Path) – Path to YAML configuration file

Return type:

FlowTrainingConfig

Returns:

FlowTrainingConfig instance with loaded configuration

Example

>>> config = FlowTrainingConfig.from_yaml("config.yaml")
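This page does not spell out the YAML layout that from_yaml expects. Since FlowTrainingConfig is a pydantic BaseModel, a reasonable assumption is that each top-level YAML key maps to the field of the same name. The sketch below shows that pattern on a hypothetical stand-in class (MiniFlowConfig, not part of jesterTOV) using PyYAML's yaml.safe_load and pydantic v2's model_validate; the file paths in the YAML are made up, and the real method may differ in detail.

```python
import tempfile
from pathlib import Path
from typing import Literal

import yaml
from pydantic import BaseModel


class MiniFlowConfig(BaseModel):
    """Hypothetical stand-in with a few FlowTrainingConfig fields."""

    posterior_file: str
    output_dir: str
    parameter_names: list[str]
    num_epochs: int = 600
    learning_rate: float = 1e-3
    flow_type: Literal[
        "block_neural_autoregressive_flow",
        "masked_autoregressive_flow",
        "coupling_flow",
    ] = "masked_autoregressive_flow"

    @classmethod
    def from_yaml(cls, filepath):
        # Parse the YAML mapping, then let pydantic validate each field.
        with open(filepath) as f:
            data = yaml.safe_load(f)
        return cls.model_validate(data)


# Example YAML: top-level keys named after the fields (hypothetical paths).
yaml_text = """\
posterior_file: data/gw_posterior.npz
output_dir: models/gw_flow
parameter_names: [mass_1_source, mass_2_source, lambda_1, lambda_2]
num_epochs: 1000
"""

config_path = Path(tempfile.mkdtemp()) / "config.yaml"
config_path.write_text(yaml_text)

config = MiniFlowConfig.from_yaml(config_path)
print(config.num_epochs, config.learning_rate)  # 1000 0.001
```

Keys omitted from the file (learning_rate and flow_type here) fall back to their defaults.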

invert: bool
learning_rate: float
max_patience: int
max_samples: int
nn_block_dim: int
nn_depth: int
nn_width: int
num_epochs: int
output_dir: str
parameter_names: list[str]
plot_corner: bool
plot_losses: bool
posterior_file: str
seed: int
standardization_method: Literal['zscore', 'minmax']
standardize: bool
transformer: Literal['affine', 'rational_quadratic_spline']
transformer_interval: float
transformer_knots: int
val_prop: float

classmethod validate_parameter_names(v)

Validate that parameter_names is a non-empty list.

Return type:

list[str]

classmethod validate_positive_float(v)

Validate that float value is positive.

Return type:

float

classmethod validate_positive_int(v)

Validate that integer value is positive.

Return type:

int

classmethod validate_val_prop_range(v)

Validate that validation proportion is in (0, 1).

Return type:

float
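The four validators above are documented only by their one-line summaries. The hypothetical stand-in below (ValidatedFlowConfig, not jesterTOV code) reproduces the described checks with pydantic v2's field_validator so you can see which inputs get rejected. Which fields each validator is attached to in the real class (for example, whether validate_positive_int covers batch_size) is an assumption here.

```python
from pydantic import BaseModel, ValidationError, field_validator


class ValidatedFlowConfig(BaseModel):
    """Hypothetical stand-in mirroring the documented validation rules."""

    parameter_names: list[str]
    num_epochs: int = 600
    batch_size: int = 128
    learning_rate: float = 1e-3
    val_prop: float = 0.2

    @field_validator("parameter_names")
    @classmethod
    def validate_parameter_names(cls, v: list[str]) -> list[str]:
        if not v:
            raise ValueError("parameter_names must be a non-empty list")
        return v

    @field_validator("num_epochs", "batch_size")
    @classmethod
    def validate_positive_int(cls, v: int) -> int:
        if v <= 0:
            raise ValueError("value must be a positive integer")
        return v

    @field_validator("learning_rate")
    @classmethod
    def validate_positive_float(cls, v: float) -> float:
        if v <= 0:
            raise ValueError("value must be positive")
        return v

    @field_validator("val_prop")
    @classmethod
    def validate_val_prop_range(cls, v: float) -> float:
        if not 0 < v < 1:
            raise ValueError("val_prop must lie in the open interval (0, 1)")
        return v


def is_valid(**kwargs) -> bool:
    """Return True if the keyword arguments pass validation."""
    try:
        ValidatedFlowConfig(**kwargs)
    except ValidationError:  # pydantic wraps the ValueErrors raised above
        return False
    return True


print(is_valid(parameter_names=["mass", "radius"]))  # True
print(is_valid(parameter_names=[]))  # False
print(is_valid(parameter_names=["mass"], val_prop=1.0))  # False
print(is_valid(parameter_names=["mass"], learning_rate=-1e-3))  # False
```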