jesterTOV.inference.flows.flow.Flow

class Flow(flow, metadata, flow_kwargs)

Bases: object

Wrapper class for flowjax normalizing flows with automatic standardization handling.

This class encapsulates a trained normalizing flow and handles data standardization transparently. When sampling, it automatically converts samples back to the original scale if standardization was used during training.

Variables:
  • flow – The underlying flowjax flow model

  • metadata – Training metadata dictionary

  • flow_kwargs – Flow architecture kwargs

  • standardize – Whether standardization was used during training

  • data_bounds – Min/max bounds for each feature (if standardization was used)

Example

>>> import jax
>>> from jesterTOV.inference.flows.flow import Flow
>>>
>>> # Load a trained flow
>>> flow = Flow.from_directory("./models/gw170817/")
>>>
>>> # Sample in original scale (standardization handled automatically)
>>> samples = flow.sample(jax.random.key(0), (1000,))
>>>
>>> # Access metadata
>>> print(f"Flow type: {flow.metadata['flow_type']}")
>>> print(f"Standardized: {flow.standardize}")
__init__(flow, metadata, flow_kwargs)

Initialize Flow wrapper.

Parameters:
  • flow (AbstractDistribution) – Trained flowjax flow model

  • metadata (Dict[str, Any]) – Training metadata

  • flow_kwargs (Dict[str, Any]) – Flow architecture kwargs

Methods

  • __init__(flow, metadata, flow_kwargs) – Initialize Flow wrapper.

  • destandardize_output(data) – Convert standardized data back to original scale.

  • from_directory(output_dir) – Load a trained flow from a directory.

  • log_prob(x) – Evaluate log probability of data under the flow.

  • sample(key, shape) – Sample from the flow and return in original scale.

  • standardize_input(data) – Standardize input data using the method from training.

destandardize_output(data)

Convert standardized data back to original scale.

Applies the inverse of the standardization method:
  • Z-score: x * std + mean

  • Min-max: x * (max - min) + min

  • None: identity (no-op)

Parameters:

data (Array) – Data in standardized space (z-score or [0, 1])

Return type:

Array

Returns:

Data in original scale (or unchanged if standardization not used)

Example

>>> import jax.numpy as jnp
>>> standardized_data = jnp.array([[0.5, 0.5, 0.5, 0.5]])
>>> original = flow.destandardize_output(standardized_data)
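The inverse transforms listed for destandardize_output can be sketched in a few lines of JAX. This is only an illustration under stated assumptions, not the library's implementation: the function name destandardize, the method labels "zscore" and "minmax", the stats dictionary, and all numeric bounds are hypothetical.

```python
import jax.numpy as jnp


def destandardize(data, method, stats):
    """Hypothetical stand-in for the inverse transforms described above."""
    if method == "zscore":
        return data * stats["std"] + stats["mean"]
    if method == "minmax":
        return data * (stats["max"] - stats["min"]) + stats["min"]
    return data  # no standardization: identity


# Made-up per-feature bounds for a 4-feature problem.
stats = {
    "min": jnp.array([1.0, 1.0, 0.0, 0.0]),
    "max": jnp.array([2.0, 2.0, 1000.0, 1000.0]),
}

standardized = jnp.array([[0.5, 0.5, 0.5, 0.5]])
original = destandardize(standardized, "minmax", stats)
# Under min-max scaling, 0.5 maps to the midpoint of each feature's range.
```

With these assumed bounds, the standardized value 0.5 lands halfway between each feature's minimum and maximum, which is a quick sanity check for any min-max inverse.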
classmethod from_directory(output_dir)

Load a trained flow from a directory.

Parameters:

output_dir (str) – Directory containing flow_weights.eqx, flow_kwargs.json, metadata.json

Return type:

Flow

Returns:

Flow instance with loaded model and metadata

Example

>>> flow = Flow.from_directory("./models/gw170817/")
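Because from_directory expects three specific files, it can be convenient to check a directory before loading. The helper below is a hypothetical sketch using only the standard library; the name missing_flow_files is not part of the package, and only the three file names come from the parameter description above.

```python
from pathlib import Path

# File names taken from the from_directory parameter description.
EXPECTED_FILES = ("flow_weights.eqx", "flow_kwargs.json", "metadata.json")


def missing_flow_files(output_dir):
    """Return the expected file names that are absent from output_dir."""
    directory = Path(output_dir)
    return [name for name in EXPECTED_FILES if not (directory / name).is_file()]
```

An empty list from missing_flow_files("./models/gw170817/") would indicate that all three files are present; it says nothing about whether their contents are valid.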
log_prob(x)

Evaluate log probability of data under the flow.

If standardization was used during training, the input is standardized before evaluation and a Jacobian correction is applied to the result. Otherwise both steps are identity operations (no-op).

The Jacobian correction accounts for the change of variables:
  • Z-score: log p(x) = log p(x_std) - sum(log(std))

  • Min-max: log p(x) = log p(x_std) - sum(log(max - min))

  • None: log p(x) = log p(x_std) (no correction)

Parameters:

x (Array) – Data in original scale, shape (n_samples, n_features). JAX array.

Return type:

Array

Returns:

Log probabilities as JAX array, shape (n_samples,)

Example

>>> import jax.numpy as jnp
>>> data = jnp.array([[1.4, 1.3, 100, 200]])
>>> log_prob = flow.log_prob(data)
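The z-score Jacobian correction can be checked numerically with a toy model. In the sketch below, independent standard normals stand in for the trained flow in standardized space, so the corrected density should match a normal density written directly in the original scale. The statistics, the data point, and the function names are illustrative assumptions, not values or code from the package.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Made-up per-feature statistics.
mean = jnp.array([1.4, 1.3, 100.0, 200.0])
std = jnp.array([0.1, 0.2, 50.0, 80.0])


def base_log_prob(z):
    """Stand-in for the flow: independent standard normals in standardized space."""
    return norm.logpdf(z).sum(axis=-1)


def log_prob_original(x):
    """Z-score change of variables: log p(x) = log p(x_std) - sum(log(std))."""
    z = (x - mean) / std
    return base_log_prob(z) - jnp.sum(jnp.log(std))


x = jnp.array([[1.5, 1.1, 120.0, 150.0]])
corrected = log_prob_original(x)

# Reference: the same density evaluated directly in the original scale.
direct = norm.logpdf(x, loc=mean, scale=std).sum(axis=-1)
```

Here corrected and direct agree up to floating-point rounding. Dropping the sum(log(std)) term would leave a density that does not integrate to one in the original coordinates.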
sample(key, shape)

Sample from the flow and return in original scale.

If standardization was used during training, samples are automatically converted back to the original scale using the inverse transformation (z-score or min-max). If not, the transformation is identity (no-op).

Parameters:
  • key (Array) – JAX random key (jax.Array)

  • shape (Tuple[int, ...]) – Shape of samples to generate (e.g., (1000,) for 1000 samples)

Return type:

Array

Returns:

Samples in original scale as JAX array of shape (*shape, n_features)

Example

>>> samples = flow.sample(jax.random.key(0), (1000,))
>>> print(samples.shape)  # (1000, 4) for 4D flow
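The returned shape (*shape, n_features) also holds when shape has more than one axis, because the inverse scaling broadcasts over all leading axes. The sketch below illustrates this with a stand-in sampler that draws uniform values in [0, 1) and undoes a min-max scaling. The sampler, its name, and the bounds are hypothetical and only mimic the documented behaviour.

```python
import jax
import jax.numpy as jnp

n_features = 4
lo = jnp.array([1.0, 1.0, 0.0, 0.0])        # made-up per-feature minima
hi = jnp.array([2.0, 2.0, 1000.0, 1000.0])  # made-up per-feature maxima


def sample_original_scale(key, shape):
    """Stand-in sampler: draw in [0, 1), then undo min-max scaling."""
    unit = jax.random.uniform(key, (*shape, n_features))
    return unit * (hi - lo) + lo  # broadcasts over all leading axes


key = jax.random.key(0)
flat = sample_original_scale(key, (1000,))    # shape (1000, 4)
batched = sample_original_scale(key, (10, 5))  # shape (10, 5, 4)
```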
standardize_input(data)

Standardize input data using the method from training.

Applies the same standardization method used during training:
  • Z-score: (x - mean) / std → mean=0, std=1

  • Min-max: (x - min) / (max - min) → [0, 1]

  • None: identity (no-op)

Parameters:

data (Array) – Input data in original scale (JAX array)

Return type:

Array

Returns:

Standardized data (z-score, [0,1], or unchanged)

Example

>>> import jax.numpy as jnp
>>> original_data = jnp.array([[1.4, 1.3, 100, 200]])
>>> standardized = flow.standardize_input(original_data)
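The two forward transforms can be illustrated on synthetic data whose features live on very different scales. This is a sketch under assumptions: the data are randomly generated, and the per-feature statistics are computed here from that data, whereas a trained Flow would use the values stored at training time.

```python
import jax
import jax.numpy as jnp

# Synthetic "training" data: 500 samples, 4 features on different scales.
key = jax.random.key(0)
scales = jnp.array([0.1, 0.2, 50.0, 80.0])
offsets = jnp.array([1.4, 1.3, 100.0, 200.0])
train = jax.random.normal(key, (500, 4)) * scales + offsets

# Z-score: per-feature statistics taken over the sample axis.
mean, std = train.mean(axis=0), train.std(axis=0)
z = (train - mean) / std

# Min-max: per-feature bounds map the training data onto [0, 1].
lo, hi = train.min(axis=0), train.max(axis=0)
u = (train - lo) / (hi - lo)

# Applying the inverse recovers the original data up to float rounding.
roundtrip = u * (hi - lo) + lo
```

After the z-score transform each column of z has mean close to 0 and standard deviation close to 1, and each column of u spans exactly [0, 1], which puts all features on a comparable scale.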