jesterTOV.inference.flows.flow.Flow#
- class Flow(flow, metadata, flow_kwargs)[source]#
Bases: object

Wrapper class for flowjax normalizing flows with automatic standardization handling.
This class encapsulates a trained normalizing flow and handles data standardization transparently. When sampling, it automatically converts samples back to the original scale if standardization was used during training.
- Variables:
flow – The underlying flowjax flow model
metadata – Training metadata dictionary
flow_kwargs – Flow architecture kwargs
standardize – Whether standardization was used during training
data_bounds – Min/max bounds for each feature (if standardization was used)
Example
>>> # Load a trained flow
>>> flow = Flow.from_directory("./models/gw170817/")
>>>
>>> # Sample in original scale (standardization handled automatically)
>>> samples = flow.sample(jax.random.key(0), (1000,))
>>>
>>> # Access metadata
>>> print(f"Flow type: {flow.metadata['flow_type']}")
>>> print(f"Standardized: {flow.standardize}")
Methods
__init__(flow, metadata, flow_kwargs) – Initialize Flow wrapper.
destandardize_output(data) – Convert standardized data back to original scale.
from_directory(output_dir) – Load a trained flow from a directory.
log_prob(x) – Evaluate log probability of data under the flow.
sample(key, shape) – Sample from the flow and return in original scale.
standardize_input(data) – Standardize input data using the method from training.
- destandardize_output(data)[source]#
Convert standardized data back to original scale.
Applies the inverse of the standardization method:
- Z-score: x * std + mean
- Min-max: x * (max - min) + min
- None: identity (no-op)
- Parameters:
data (Array) – Data in standardized space (z-score or [0, 1])
- Return type:
Array
- Returns:
Data in original scale (or unchanged if standardization was not used)
Example
>>> standardized_data = jnp.array([[0.5, 0.5, 0.5, 0.5]])
>>> original = flow.destandardize_output(standardized_data)
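The min-max branch of the inverse transform can be sketched on its own. The bounds below are hypothetical stand-ins for the values stored in flow.data_bounds:

```python
import jax.numpy as jnp

# Hypothetical per-feature bounds, standing in for flow.data_bounds
mins = jnp.array([1.0, 1.0, 50.0, 100.0])
maxs = jnp.array([2.0, 1.5, 150.0, 300.0])

def destandardize_minmax(x_std):
    # Inverse of min-max scaling: x * (max - min) + min
    return x_std * (maxs - mins) + mins

standardized = jnp.array([[0.5, 0.5, 0.5, 0.5]])
original = destandardize_minmax(standardized)
print(original)  # each feature lands at the midpoint of its range
```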
- classmethod from_directory(output_dir)[source]#
Load a trained flow from a directory.
- Parameters:
output_dir (str) – Directory containing flow_weights.eqx, flow_kwargs.json, and metadata.json
- Return type:
Flow
- Returns:
Flow instance with loaded model and metadata
Example
>>> flow = Flow.from_directory("./models/gw170817/")
- log_prob(x)[source]#
Evaluate log probability of data under the flow.
If standardization was used, input data is automatically standardized before evaluation and Jacobian correction is applied. If not, operations are identity (no-op).
The Jacobian correction accounts for the change of variables:
- Z-score: log p(x) = log p(x_std) - sum(log(std))
- Min-max: log p(x) = log p(x_std) - sum(log(max - min))
- None: log p(x) = log p(x_std) (no correction)
- Parameters:
x (Array) – Data in original scale, shape (n_samples, n_features). JAX array.
- Return type:
Array
- Returns:
Log probabilities as JAX array, shape (n_samples,)
Example
>>> data = jnp.array([[1.4, 1.3, 100, 200]])
>>> log_prob = flow.log_prob(data)
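The z-score correction can be checked against a known density: if the flow in standardized space were exactly a standard normal, the corrected log-density must equal that of a normal with the original mean and std. The statistics below are hypothetical:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical training statistics for z-score standardization
mean = jnp.array([1.35, 1.35, 120.0, 250.0])
std = jnp.array([0.15, 0.15, 30.0, 60.0])

def log_prob_with_jacobian(x, log_prob_std_fn):
    # Standardize, evaluate, then apply the change-of-variables correction
    x_std = (x - mean) / std
    return log_prob_std_fn(x_std) - jnp.sum(jnp.log(std))

# Pretend the standardized-space flow is a standard normal
x = jnp.array([1.4, 1.3, 100.0, 200.0])
lp = log_prob_with_jacobian(x, lambda z: jnp.sum(norm.logpdf(z)))

# The corrected density must match a normal with the original mean/std
expected = jnp.sum(norm.logpdf(x, loc=mean, scale=std))
print(jnp.allclose(lp, expected))
```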
- sample(key, shape)[source]#
Sample from the flow and return in original scale.
If standardization was used during training, samples are automatically converted back to the original scale using the inverse transformation (z-score or min-max). If not, the transformation is identity (no-op).
- Parameters:
key (Array) – JAX random key used for sampling
shape (tuple) – Shape of the batch of samples to draw
- Return type:
Array
- Returns:
Samples in original scale as JAX array of shape (*shape, n_features)
Example
>>> samples = flow.sample(jax.random.key(0), (1000,))
>>> print(samples.shape)  # (1000, 4) for 4D flow
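The (*shape, n_features) convention and the automatic destandardization can be illustrated with a uniform base draw standing in for the actual flow (the bounds here are hypothetical):

```python
import jax
import jax.numpy as jnp

# Hypothetical per-feature bounds, standing in for flow.data_bounds
n_features = 4
mins = jnp.array([1.0, 1.0, 50.0, 100.0])
maxs = jnp.array([2.0, 1.5, 150.0, 300.0])

def sample_sketch(key, shape):
    # Draw standardized [0, 1] samples, then invert min-max scaling
    u = jax.random.uniform(key, shape + (n_features,))
    return u * (maxs - mins) + mins

samples = sample_sketch(jax.random.key(0), (1000,))
print(samples.shape)  # (1000, 4)
```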
- standardize_input(data)[source]#
Standardize input data using the method from training.
Applies the same standardization method used during training:
- Z-score: (x - mean) / std → mean=0, std=1
- Min-max: (x - min) / (max - min) → [0, 1]
- None: identity (no-op)
- Parameters:
data (Array) – Input data in original scale (JAX array)
- Return type:
Array
- Returns:
Standardized data (z-score, [0, 1], or unchanged)
Example
>>> original_data = jnp.array([[1.4, 1.3, 100, 200]])
>>> standardized = flow.standardize_input(original_data)
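standardize_input and destandardize_output are inverses of each other. A minimal round-trip check of the min-max pair, with hypothetical bounds in place of flow.data_bounds:

```python
import jax.numpy as jnp

# Hypothetical per-feature bounds, standing in for flow.data_bounds
mins = jnp.array([1.0, 1.0, 50.0, 100.0])
maxs = jnp.array([2.0, 1.5, 150.0, 300.0])

def standardize_minmax(x):
    # Forward: (x - min) / (max - min) maps each feature into [0, 1]
    return (x - mins) / (maxs - mins)

def destandardize_minmax(x_std):
    # Inverse: x * (max - min) + min recovers the original scale
    return x_std * (maxs - mins) + mins

data = jnp.array([[1.4, 1.3, 100.0, 200.0]])
roundtrip = destandardize_minmax(standardize_minmax(data))
print(jnp.allclose(roundtrip, data))
```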