21cmEMU v3: Full Emulator with 2D Power Spectrum

This tutorial demonstrates how to use the 21cmEMU v3 emulator, including:

  1. 1D Summaries: 21-cm PS (\(\Delta^2_{21} (k)\) [mK\(^2\)]), global brightness temperature (\(\overline{T}_b\)), neutral fraction (\(\overline{x}_{\mathrm{HI}}\)), spin temperature (\(T_s\)), UV luminosity functions (UVLFs), and optical depth (\(\tau_e\))

  2. 2D Power Spectrum: The 21-cm power spectrum \(\Delta^2(k_\perp, k_\parallel)\) [mK\(^2\)] emulated by a diffusion model

  3. Uncertainty Quantification: Variance estimation from multiple diffusion model samples

The v3 emulator uses:

  • An LSTM-based encoder-decoder architecture for 1D summaries

  • A score-based diffusion model architecture for the 2D power spectrum.

Requirements: This tutorial assumes a GPU is available for 2D PS emulation.

If you use this emulator in your work, please cite Breitman+26 (arXiv: 2606.00219).

[1]:
import numpy as np
import h5py
import matplotlib.pyplot as plt
from matplotlib import rcParams
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D


rcParams.update({'font.size': 14})

from py21cmemu import Emulator

1. Load Test Databases

We load the test databases containing simulation outputs to compare against emulator predictions.

[2]:
# Load the 1D summaries test set
test_1d_path = 'test_database.h5'

with h5py.File(test_1d_path, 'r') as f:
    print("Test set keys:", list(f.keys()))

    # Input parameters (11 astrophysical params)
    test_params = np.array(f['input_params'])

    # 1D summaries
    test_xHI = np.array(f['xHI'])
    test_Tb = np.array(f['Tb'])
    test_Ts = np.array(f['Ts'])
    test_tau = np.array(f['tau'])
    test_PS_1D = np.array(f['PS_1D_seeds']) # shape is (Nparams, Nseeds, Nz, Nk)
    PS_redshifts = np.array(f['PS_redshifts'])
    k = np.array(f["k"])
    # UVLFs: (N, 7 z-bins, 60 magnitudes)
    test_LFs_raw = np.array(f['UVLFs'])

    # Axes
    redshifts = np.array(f['redshifts'])
    M_UV_all = np.array(f['M_UV'])
    LF_zs = np.array(f['UVLF_redshifts'])
    limits = np.array(f["limits"])
    ap = np.array(f["astro_param_keys"])

# Filter valid samples (no NaN inputs)
valid_mask = ~np.isnan(test_params.mean(axis=1))
print(f"Valid samples: {valid_mask.sum()} / {len(test_params)}")

test_params = test_params[valid_mask]
test_xHI = test_xHI[valid_mask]
test_Tb = test_Tb[valid_mask]
test_Ts = test_Ts[valid_mask]
test_tau = test_tau[valid_mask]
test_LFs_raw = test_LFs_raw[valid_mask]
test_PS_1D = test_PS_1D[valid_mask]
# Trim UVLFs to -20 < M_UV < -10
m = np.logical_and(M_UV_all < -10, M_UV_all > -20)
M_UV = M_UV_all[m]  # Crop to 30 magnitudes
test_LFs = test_LFs_raw[:, :, m]  # (N, 7, 30)

print("\nData shapes:")
print(f"  Params: {test_params.shape}")
print(f"  xHI: {test_xHI.shape}")
print(f"  Tb: {test_Tb.shape}")
print(f"  Ts: {test_Ts.shape}")
print(f"  tau: {test_tau.shape}")
print(f"  UVLFs: {test_LFs.shape}")
print(f"  Redshifts: {redshifts.shape}")
Test set keys: ['M_UV', 'PS_1D_seeds', 'PS_redshifts', 'Tb', 'Ts', 'UVLF_redshifts', 'UVLFs', 'astro_param_keys', 'astro_param_labels', 'input_params', 'k', 'limits', 'redshifts', 'tau', 'xHI']
Valid samples: 80 / 80

Data shapes:
  Params: (80, 11)
  xHI: (80, 93)
  Tb: (80, 93)
  Ts: (80, 93)
  tau: (80,)
  UVLFs: (80, 7, 30)
  Redshifts: (93,)
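
The NaN filter above relies on the fact that the mean of a row containing a NaN is itself NaN. As a minimal sketch on made-up data (the toy array below is not from the test database), the same selection can be written more explicitly with `np.isnan(...).any(axis=1)`:

```python
import numpy as np

# Toy parameter table: 4 samples x 3 parameters (made-up values).
params = np.array([
    [0.1, 0.2, 0.3],
    [0.4, np.nan, 0.6],
    [0.7, 0.8, 0.9],
    [np.nan, np.nan, np.nan],
])

# Explicit form: a row is valid only if none of its entries is NaN.
valid_mask = ~np.isnan(params).any(axis=1)

# Mean-based form used in the tutorial: a NaN anywhere in the row
# propagates into the row mean, so the two masks agree.
mean_mask = ~np.isnan(params.mean(axis=1))

filtered = params[valid_mask]
print(valid_mask, filtered.shape)
```

Every per-sample array (xHI, Tb, UVLFs, ...) has to be indexed with the same mask so that rows stay aligned with their parameters.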
[3]:
logged = [0,3,5,6,7,8]  # indices of parameters given in log10 (not used below)
for idx, (key, lim) in enumerate(zip(ap, limits)):
    print(key, lim, np.round(np.min(test_params[...,idx]),2), np.round(np.max(test_params[...,idx]),2))
b'F_STAR10' [-2.  -0.5] -1.81 -0.57
b'ALPHA_STAR' [0. 1.] 0.2 0.81
b't_STAR' [0.01 1.  ] 0.09 0.96
b'F_ESC10' [-3.  0.] -2.31 -0.05
b'ALPHA_ESC' [-1.  1.] -0.8 0.99
b'F_STAR7_MINI' [-4. -1.] -3.89 -1.1
b'F_ESC7_MINI' [-3. -1.] -2.89 -1.02
b'L_X' [38. 43.] 38.39 42.75
b'L_X_MINI' [39. 44.] 39.14 43.14
b'NU_X_THRESH' [ 100. 1500.] 118.92 1146.18
b'SIGMA_8' [0.75 0.85] 0.79 0.84
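
The emulator returns normalised parameters alongside its predictions, but this tutorial does not show how they are computed. Purely as an illustration, one common convention is min-max scaling against the limits printed above. The helper `minmax_normalise` below is hypothetical and is an assumption about the general idea, not the py21cmemu implementation:

```python
import numpy as np

# Limits copied from the printed table for three of the parameters.
limits = np.array([
    [-2.0, -0.5],   # F_STAR10 (log10)
    [0.0, 1.0],     # ALPHA_STAR
    [38.0, 43.0],   # L_X (log10)
])

def minmax_normalise(theta, limits):
    """Map each parameter linearly from [low, high] to [0, 1] (hypothetical helper)."""
    low, high = limits[:, 0], limits[:, 1]
    return (theta - low) / (high - low)

# One parameter set placed at the midpoint of every range.
theta = np.array([[-1.25, 0.5, 40.5]])
normed = minmax_normalise(theta, limits)

# A quick sanity check that a sample lies inside the supported range.
in_range = bool(np.all((normed >= 0) & (normed <= 1)))
print(normed, in_range)
```

A check like `in_range` is a cheap way to catch parameter sets that fall outside the region the emulator was trained on.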

Load 2D Power Spectrum Test Database

The 2D PS test database contains power spectra computed at multiple redshifts.

Note: If the 2D PS test database is not available, the PS comparison sections will use emulator properties for k-grids and generate predictions without ground truth comparison.

[86]:
import os

# Try to load the 2D PS test database
test_ps2d_path = 'ps_2d_test_subsample.h5'
HAS_PS2D_DB = os.path.exists(test_ps2d_path)

if HAS_PS2D_DB:
    with h5py.File(test_ps2d_path, 'r') as f:
        print("2D PS test set keys:", list(f.keys()))

        ps_params = np.array(f['input_params'])
        ps_redshifts_all = np.array(f['redshifts'])
        kperp = np.array(f['kperp'])
        kpar = np.array(f['kpar_64'])  # 64 kpar bins

        # 2D PS: seeds (individual realisations) and means (averaged over seeds)
        ps_2d_seeds = np.array(f['PS_2D_64_seeds'])
        ps_2d_means = np.array(f['PS_2D_64_means'])

    # Select a subset of redshifts for 2D PS emulation (to save time)
    n_z_subset = 10
    z_indices = np.linspace(0, len(ps_redshifts_all) - 1, n_z_subset, dtype=int)
    ps_redshifts = ps_redshifts_all[z_indices]

    # Also subset the 2D PS data to match
    ps_2d_means = ps_2d_means[:, z_indices]
    ps_2d_seeds = ps_2d_seeds[:, z_indices]

    print("\n2D PS data shapes:")
    print(f"  Params: {ps_params.shape}")
    print(f"  All redshifts: {ps_redshifts_all.shape}")
    print(f"  Selected redshifts ({n_z_subset}): {ps_redshifts.shape}")
    print(f"  Selected z values: {ps_redshifts}")
    print(f"  kperp: {kperp.shape}")
    print(f"  kpar: {kpar.shape}")
    print(f"  PS seeds (subsetted): {ps_2d_seeds.shape}")
    print(f"  PS means (subsetted): {ps_2d_means.shape}")
else:
    print("2D PS test database not found.")
    print("PS sections will use emulator properties and test params from 1D database.")
    print("\nTo download the 2D PS test database, see the 21cmEMU documentation.")
    ps_params = test_params  # Use same params as 1D test
    ps_redshifts = None  # Will use emulator defaults
    ps_2d_means = None
    ps_2d_seeds = None
    kperp = None  # Will use emulator properties
    kpar = None   # Will use emulator properties
2D PS test set keys: ['N_modes', 'PS_2D_64_means', 'PS_2D_64_seeds', 'fnames', 'input_params', 'kpar_64', 'kperp', 'limits', 'param_keys', 'param_labels', 'redshifts']

2D PS data shapes:
  Params: (100, 11)
  All redshifts: (32,)
  Selected redshifts (10): (10,)
  Selected z values: [ 5.50105421  6.46050805  7.59002262  9.40138137 11.04865309 13.6898716
 16.07786482 19.92643643 23.39700184 28.98670671]
  kperp: (32,)
  kpar: (64,)
  PS seeds (subsetted): (100, 10, 32, 64)
  PS means (subsetted): (100, 10, 32, 64)
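
The redshift subset above is chosen with `np.linspace(..., dtype=int)`, which spaces indices evenly and truncates them to integers. A minimal sketch of what that selects for 32 redshifts and 10 picks (the toy `data` array is a stand-in, not the real database):

```python
import numpy as np

n_z_all, n_z_subset = 32, 10

# Evenly spaced positions between the first and last index; the float
# positions are truncated toward zero by dtype=int.
z_indices = np.linspace(0, n_z_all - 1, n_z_subset, dtype=int)
print(z_indices)  # [ 0  3  6 10 13 17 20 24 27 31]

# The same indices must be applied to the redshift axis (axis 1) of
# every array that is compared later. Toy stand-in with shape
# (n_params, n_z, n_kperp, n_kpar).
data = np.zeros((5, n_z_all, 4, 6))
subset = data[:, z_indices]
print(subset.shape)
```

Because of the truncation the spacing alternates between 3 and 4 indices, but the first and last redshifts are always included.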

2. Initialize the v3 Emulator

We first load the MH emulator with 2D PS emulation disabled by setting emulate_2d_ps=False:

[4]:
emu = Emulator(emulator='mh', emulate_2d_ps=False)

print(f"Emulator initialized: {emu.which_emulator}")
print(f"PS emulation: {emu.emulate_2d_ps}")
print("\nAvailable properties:")
print(f"  Redshifts: {emu.properties.zs.shape}")
print(f"  PS redshifts: {emu.properties.PS_zs.shape}")
print(f"  kperp: {emu.properties.kperp.shape}")
print(f"  kpar: {emu.properties.kpar.shape}")
Emulator initialized: mcg
PS emulation: False

Available properties:
  Redshifts: (93,)
  PS redshifts: (32,)
  kperp: (32,)
  kpar: (64,)

3. Run Emulator Predictions

We run the emulator on a subset of test parameters.

[5]:
# Select a subset for demonstration (to save time)
n_samples = min(10, len(test_params))
idx = np.random.choice(len(test_params), n_samples, replace=False)
idx = np.sort(idx)

params_subset = test_params[idx]

print(f"Running emulator on {n_samples} samples...")

# Run prediction with probability-flow ODE sampling for the 2D PS
normed_params, output, errors = emu.predict(
    params_subset,
    verbose=True,
    ps_sampling_method='ode',  # probability flow ODE or Euler-Maruyama 'em' sampling supported
    n_realisations=10,  # Number of diffusion realisations

)

print(f"\nOutput keys: {list(output.keys())}")
Running emulator on 10 samples...

Output keys: ['Tb', 'xHI', 'Ts', 'tau', 'UVLFs', 'PS', 'PS_2D', 'PS_2D_samples', 'PS_2D_std', 'PS_2D_redshifts']
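
The subset above is drawn with an unseeded `np.random.choice`, so the selected samples (and therefore the figures) change on every run. A minimal sketch of a reproducible alternative using NumPy's `Generator` API; the seed value 42 is arbitrary:

```python
import numpy as np

n_total, n_samples = 80, 10

# A seeded generator makes the selection repeatable across runs.
rng = np.random.default_rng(seed=42)
idx = np.sort(rng.choice(n_total, size=n_samples, replace=False))

# Re-creating the generator with the same seed gives the same indices.
rng_again = np.random.default_rng(seed=42)
idx_again = np.sort(rng_again.choice(n_total, size=n_samples, replace=False))

print(idx)
print(np.array_equal(idx, idx_again))
```

Sorting the indices keeps the rows in database order, which makes it easier to match a curve in a plot to a specific test sample.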
[6]:
# Extract emulated outputs
emu_xHI = output['xHI'].value
emu_Tb = output['Tb'].value
emu_Ts = output['Ts'].value
emu_tau = output['tau'].value
emu_UVLFs = output['UVLFs'].value
emu_PS = output["PS"].value

# Extract true outputs for comparison
true_xHI = test_xHI[idx]
true_Tb = test_Tb[idx]
true_Ts = test_Ts[idx]
true_tau = test_tau[idx]
true_UVLFs = test_LFs[idx]
true_PS = test_PS_1D[idx]

print(f"Emulated xHI shape: {emu_xHI.shape}")
print(f"Emulated Tb shape: {emu_Tb.shape}")
print(f"Emulated Ts shape: {emu_Ts.shape}")
print(f"Emulated tau shape: {emu_tau.shape}")
print(f"Emulated UVLFs shape: {emu_UVLFs.shape}")
Emulated xHI shape: (10, 93)
Emulated Tb shape: (10, 93)
Emulated Ts shape: (10, 93)
Emulated tau shape: (10,)
Emulated UVLFs shape: (10, 7, 30)

4. Compare 1D Summaries: True vs Emulated

We plot several test samples comparing the true (simulation) values with emulated predictions.

[7]:
def plot_true_vs_emu(x, y_true, y_emu, x_label,
                     y_label, y_diff = None,
                     xlims = None, ylims = None,
                     N = 10, offset = 0,
                     plot_realisations=False, FE=False,
                     logFE=False, floor = 1e-2,
                     cs = None, leg_loc = (0.5,0.5), cl_leg_loc = (0.6,0.5)):
    if cs is None:
        cs = ['k','lime','b', 'orange', 'cyan', 'magenta', 'grey', 'pink', 'darkred', 'coral']
    if plot_realisations:
        y_mean = np.nanmean(y_true, axis=1)          # (N_param, len(x)) — true mean
    else:
        y_mean = y_true
    if y_diff is None:
        if not FE:
            y_diff = np.abs(y_mean - y_emu)               # (N_param, len(x)) — mean vs emu residual
        else:
            if logFE:
                y_diff = log10_fractional_error(y_mean, y_emu, floor_log=floor)
            else:
                y_diff = fractional_error(y_mean, y_emu, floor=floor)

    fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(10, 8),
                            gridspec_kw=dict(height_ratios=[3, 2], hspace=0))
    axs = axs.flatten()

    diff_err_z = np.nanpercentile(y_diff, [2.5, 16, 50, 84, 97.5], axis=0)
    for i, c in zip(range(N), cs):
        idx = i + offset
        last = (i == N - 1)
        if plot_realisations:
            # Individual realisations — thin, semi-transparent
            for j in range(y_true.shape[1]):
                axs[0].plot(x, y_true[idx, j], lw=0.7, color=c, alpha=0.15)

        # True mean — thick solid
        axs[0].plot(x, y_mean[idx], lw=2.5, color=c)

        # Emulator prediction — dashed
        axs[0].plot(x, y_emu[idx], lw=2, ls='--', color=c)

        # Per-parameter residual (true mean vs emu) in lower panel
        if not FE:
            axs[1].plot(x, np.abs(y_mean[idx] - y_emu[idx]), ls='--', alpha=0.5, lw=2, color=c)
        else:
            if logFE:
                axs[1].plot(x, log10_fractional_error(y_mean[idx], y_emu[idx], floor_log=floor), ls='--', alpha=0.5, lw=2, color=c)
            else:
                axs[1].plot(x, fractional_error(y_mean[idx], y_emu[idx], floor=floor), ls='--', alpha=0.5, lw=2, color=c)

    if plot_realisations:
        legend_handles = [
            Line2D([0], [0], color='k', lw=0.7, alpha=0.6, label='Realisations'),
            Line2D([0], [0], color='k', lw=2.5, label='True mean'),
            Line2D([0], [0], color='k', lw=2, ls='--', label='21cmEMU'),
        ]
    else:
        legend_handles = [
            Line2D([0], [0], color='k', lw=2.5, label='True'),
            Line2D([0], [0], color='k', lw=2, ls='--', label='21cmEMU'),
        ]
    axs[0].legend(handles=legend_handles, loc=leg_loc, frameon=False, fontsize = 22)
    axs[1].plot(x, diff_err_z[2], ls=':', lw=3, color='k', label='Median')
    axs[1].fill_between(x, diff_err_z[1], diff_err_z[3], color='k', alpha=0.2, label='68% CI')
    axs[1].fill_between(x, diff_err_z[0], diff_err_z[4], color='k', alpha=0.1, label='95% CI')

    handles = [mpatches.Patch(color='k', label='68% CI', alpha=0.2),
               mpatches.Patch(color='k', label='95% CI', alpha=0.1),
#                mpatches.Patch(color='r', label='Mean PS', alpha=0.1),
               Line2D([0], [0], color='k', lw=2, ls=':', label='Median'),]
    if cl_leg_loc is not None:
        axs[1].legend(handles=handles, loc=cl_leg_loc, frameon=False)

    axs[0].set_ylabel(y_label)
    axs[1].set_ylabel(r'Abs diff' if not FE else 'FE (%)')
    axs[1].set_xlabel(x_label)
    if xlims is not None:
        plt.xlim(xlims[0], xlims[1])
    else:
        plt.xlim(min(x), max(x))
    if ylims is not None:
        axs[0].set_ylim(ylims[0], ylims[1])
#         axs[1].set_ylim(-0.1,0.1)
        axs[0].set_yscale('log')
    return fig
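
The shaded bands in the lower panel of `plot_true_vs_emu` come from percentiles of the residuals taken across the parameter axis. A minimal sketch of that step on synthetic data (the arrays are random stand-ins, not emulator output):

```python
import numpy as np

rng = np.random.default_rng(0)
n_param, n_x = 200, 5

# Synthetic "true" curves and a noisy "emulated" version of them.
y_true = rng.normal(size=(n_param, n_x))
y_emu = y_true + 0.1 * rng.normal(size=(n_param, n_x))

y_diff = np.abs(y_true - y_emu)
y_diff[0, 0] = np.nan  # a missing value, ignored by nanpercentile

# Rows: 2.5, 16, 50, 84, 97.5 percentiles; columns: points along x.
bands = np.nanpercentile(y_diff, [2.5, 16, 50, 84, 97.5], axis=0)

median = bands[2]
ci68 = (bands[1], bands[3])   # central 68% interval
ci95 = (bands[0], bands[4])   # central 95% interval
print(bands.shape)
```

With only a handful of test samples, as in the 10-sample subset used here, the 2.5 and 97.5 percentiles are essentially set by the extreme samples, so the outer band should be read loosely.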
[8]:
def fractional_error(true, pred, floor=1e-2):
    """Compute fractional error in percent,
    with a floor to avoid division by small numbers.

    FE = |true - pred| / max(|true|, floor) * 100
    """
    denom = true.copy()
    denom[np.abs(denom) < floor] = floor
    return np.abs((true - pred) / denom) * 100.0


def log10_fractional_error(true, pred, floor_log=1e-2):
    """FE on log10 PS: |log10(true) - log10(pred)| / max(|log10(true)|, floor_log) * 100.

    Uses |log10(true)| as denominator with a floor to handle PS ~ 1 (log10 ~ 0).
    Default floor_log=1e-2: pixels where |log10(PS)| < 1e-2 use 1e-2 as denominator.
    """
    log_true = np.log10(true)
    log_pred = np.log10(pred)
    denom = np.abs(log_true)
    denom[denom < floor_log] = floor_log
    return np.abs((log_true - log_pred) / denom) * 100.0
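
To see what the floor does, here is a small worked example with made-up numbers. The function is restated so the snippet runs on its own; the only change is that the denominator is built with `np.array(..., dtype=float)` so that integer inputs are handled as well:

```python
import numpy as np

def fractional_error(true, pred, floor=1e-2):
    """FE in percent with the denominator magnitude floored at `floor`."""
    denom = np.array(true, dtype=float)
    denom[np.abs(denom) < floor] = floor
    return np.abs((true - pred) / denom) * 100.0

true = np.array([0.001, 0.5, 1.0])
pred = np.array([0.002, 0.55, 0.9])

# With the floor, the first entry is divided by 0.01 instead of 0.001.
fe_floored = fractional_error(true, pred, floor=0.01)

# Without the floor, the tiny true value inflates the error tenfold.
fe_unfloored = np.abs((true - pred) / true) * 100.0

print(fe_floored)    # approximately [10, 10, 10]
print(fe_unfloored)  # approximately [100, 10, 10]
```

This is why quantities that pass through zero, such as the neutral fraction at the end of reionisation or the brightness temperature at its zero crossings, are given a floor suited to their units (0.01 for xHI, 5 mK for Tb later in this tutorial).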
[9]:
from matplotlib import rcParams
rcParams.update({"font.size":25, "font.family": 'serif'})
k_idx = 11
fig = plot_true_vs_emu(PS_redshifts, true_PS[...,k_idx], emu_PS[...,k_idx], r'Redshift z', r'$\Delta^2_{21}$ [mK$^2$]',
                 xlims = [5., 30],
#                        y_diff = diff_random[...,k_idx],
                       plot_realisations=True,
                leg_loc = (0.62,0.44),
                      cl_leg_loc = (0.65, 0.17))
axs = fig.get_axes()
axs[0].text(0.8, 0.9,
            r"k $\sim$ " + str(np.round(k[k_idx],1)) + r" Mpc$^{-1}$",
            horizontalalignment='center',
            verticalalignment='center',
            transform=axs[0].transAxes)

# diff_err_z = np.nanpercentile(diff_mean[...,k_idx], [2.5, 16, 50, 84, 97.5], axis=0)
# axs[1].fill_between(PS_redshifts, diff_err_z[1], diff_err_z[3], color='r', alpha=0.2, label='Mean')
# axs[1].fill_between(PS_redshifts, diff_err_z[0], diff_err_z[4], color='r', alpha=0.1,)
# axs[1].plot(PS_redshifts, diff_err_z[2], ls=':', lw=3, color='r')
plt.tight_layout()
# plt.savefig("PS_true_vs_emu.pdf")
plt.show()
../_images/tutorials_v3_full_emulator_15_0.png
[10]:
from matplotlib import rcParams
rcParams.update({"font.size":25, "font.family": 'serif'})
k_idx = 11
fig = plot_true_vs_emu(PS_redshifts, true_PS[...,k_idx], emu_PS[...,k_idx], r'Redshift z', r'$\Delta^2_{21}$ [mK$^2$]',
                 xlims = [5., 30],
#                        y_diff = fe_random[...,k_idx],
                       FE=True, floor = 0.1, logFE=False,
                       plot_realisations=True,
                leg_loc = (0.62,0.44),
                      cl_leg_loc = (0.65, 0.17))
axs = fig.get_axes()
axs[0].text(0.8, 0.9,
            r"k $\sim$ " + str(np.round(k[k_idx],1)) + r" Mpc$^{-1}$",
            horizontalalignment='center',
            verticalalignment='center',
            transform=axs[0].transAxes)
# diff_err_z = np.nanpercentile(PS_mean_fe[...,k_idx], [2.5, 16, 50, 84, 97.5], axis=0)
# axs[1].fill_between(PS_redshifts, diff_err_z[1], diff_err_z[3], color='r', alpha=0.2, label='Mean')
# axs[1].fill_between(PS_redshifts, diff_err_z[0], diff_err_z[4], color='r', alpha=0.1,)
# axs[1].plot(PS_redshifts, diff_err_z[2], ls=':', lw=3, color='r')
# axs[1].set_ylim(0,50)
plt.tight_layout()
# plt.savefig("PS_true_vs_emu_FE.png")
plt.show()
../_images/tutorials_v3_full_emulator_16_0.png
[11]:
from matplotlib import rcParams
rcParams.update({"font.size":25, "font.family": 'serif'})
fig = plot_true_vs_emu(redshifts, true_xHI, emu_xHI, r'Redshift z', r'$\overline{x}_{\rm HI}$',
                       plot_realisations = False,
                 xlims = [5., 30],
                leg_loc = None,
                cl_leg_loc = (0.65, 0.3))

plt.tight_layout()
# plt.savefig("xHI_true_vs_emu.pdf", format = "pdf")
plt.show()
../_images/tutorials_v3_full_emulator_17_0.png
[12]:
from matplotlib import rcParams
rcParams.update({"font.size":25, "font.family": 'serif'})
fig = plot_true_vs_emu(redshifts, true_xHI, emu_xHI, r'Redshift z', r'$\overline{x}_{\rm HI}$',
                       plot_realisations = False, FE=True,
                 xlims = [5., 30],
                leg_loc = None,
                cl_leg_loc = (0.65, 0.3))
axs = fig.get_axes()
axs[1].set_ylim(-1,30)
plt.tight_layout()
# plt.savefig("xHI_true_vs_emu_FE.pdf")
plt.show()
../_images/tutorials_v3_full_emulator_18_0.png
[13]:
from matplotlib import rcParams
rcParams.update({"font.size":25, "font.family": 'serif'})
fig = plot_true_vs_emu(redshifts, true_Tb, emu_Tb,
                       r'Redshift z', r'$ \delta\overline{\rm{T}}_{\rm b}$ [mK]',
                       plot_realisations = False,
                 xlims = [5., 30],
                leg_loc = None,
                cl_leg_loc = (0.65, 0.3))

plt.tight_layout()
# plt.savefig("Tb_true_vs_emu.pdf", format = "pdf")
plt.show()
../_images/tutorials_v3_full_emulator_19_0.png
[14]:
from matplotlib import rcParams
rcParams.update({"font.size":25, "font.family": 'serif'})
fig = plot_true_vs_emu(redshifts, np.log10(true_Ts), np.log10(emu_Ts), r'Redshift z', r'$\log_{10}\overline{\rm{T}}_{\rm S}$ [K]',
                       plot_realisations = False,
                 xlims = [5., 30],
                leg_loc = None,
                cl_leg_loc = (0.65, 0.3))

plt.tight_layout()
# plt.savefig("Ts_true_vs_emu.pdf", format = "pdf")
plt.show()
/home/dbreitman/.conda/envs/pytorch_env/lib/python3.10/site-packages/numpy/lib/nanfunctions.py:1563: RuntimeWarning: All-NaN slice encountered
  return function_base._ureduce(a,
../_images/tutorials_v3_full_emulator_20_1.png
[15]:
true_UVLFs.shape, emu_UVLFs.shape, M_UV.shape
[15]:
((10, 7, 30), (10, 7, 30), (30,))
[16]:
for i in range(len(LF_zs)):
    fig = plot_true_vs_emu(M_UV, true_UVLFs[:,i,:], emu_UVLFs[:,i,:],
                     r'M$_{\rm UV}$', r'log$_{10}\Phi$ [Mpc$^{-3}$ Mag$^{-1}$]',
                     xlims = [-10, -20],
                           plot_realisations=False,
                      leg_loc = "lower left", cl_leg_loc = (0.65, 0.3))
    axs = fig.get_axes()
    axs[0].text(0.87, 0.85,
                r"z $\sim$"+f" {LF_zs[i]:n} ",
                horizontalalignment='center',
                verticalalignment='center',
                transform=axs[0].transAxes)
    plt.tight_layout()
#     plt.savefig(f"UVLFs_true_vs_emu{LF_zs[i]:n}.pdf", format = "pdf")
    plt.show()
../_images/tutorials_v3_full_emulator_22_0.png
../_images/tutorials_v3_full_emulator_22_1.png
../_images/tutorials_v3_full_emulator_22_2.png
../_images/tutorials_v3_full_emulator_22_3.png
../_images/tutorials_v3_full_emulator_22_4.png
../_images/tutorials_v3_full_emulator_22_5.png
../_images/tutorials_v3_full_emulator_22_6.png
[17]:
cs = ['k','lime','b', 'orange', 'cyan', 'magenta', 'grey', 'pink', 'darkred', 'coral']
[18]:
tau_bins = np.linspace(min(true_tau), np.nanpercentile(true_tau,95), 3)
tau_binned_fe = np.zeros(len(tau_bins))
tau_binned_fe_68 = np.zeros((len(tau_bins),2))
tau_binned_fe_95 = np.zeros((len(tau_bins),2))

tau_binned = np.zeros(len(tau_bins))
tau_binned_68 = np.zeros((len(tau_bins),2))
tau_binned_95 = np.zeros((len(tau_bins),2))
[19]:
tau_diff = abs((emu_tau - true_tau))
tau_frac_err = abs((emu_tau - true_tau) / true_tau) * 100.
[20]:
for i in range(len(tau_bins)-1):
    mask = np.logical_and(true_tau >= tau_bins[i], true_tau < tau_bins[i+1])
    low1, low, med, high, high1 = np.nanpercentile(tau_frac_err[mask], [2.5, 16,50,84, 97.5])
    tau_binned_fe[i] = med
    tau_binned_fe_68[i,:] = [low, high]
    tau_binned_fe_95[i,:] = [low1, high1]

    low1, low, med, high, high1 = np.nanpercentile(tau_diff[mask], [2.5, 16,50,84, 97.5])
    tau_binned[i] = med
    tau_binned_68[i,:] = [low, high]
    tau_binned_95[i,:] = [low1, high1]
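
The loop above fills one entry per bin but stores the results in arrays that have one entry per bin edge, so the last entry stays at zero. As an alternative sketch on synthetic data (the `true_tau` and `err` arrays below are random stand-ins), the binned statistics can be stored per bin and plotted against bin centres:

```python
import numpy as np

rng = np.random.default_rng(1)
true_tau = rng.uniform(0.04, 0.08, size=200)           # stand-in optical depths
err = np.abs(rng.normal(scale=1e-3, size=200))         # stand-in abs. errors

n_bins = 4
edges = np.linspace(true_tau.min(), true_tau.max(), n_bins + 1)
centres = 0.5 * (edges[:-1] + edges[1:])

# np.digitize returns 1..n_bins for values inside the edges; subtract 1
# and clip so that the maximum value lands in the last bin.
which = np.clip(np.digitize(true_tau, edges) - 1, 0, n_bins - 1)

median = np.full(n_bins, np.nan)
ci68 = np.full((n_bins, 2), np.nan)
for b in range(n_bins):
    in_bin = which == b
    if in_bin.any():  # leave empty bins as NaN
        ci68[b, 0], median[b], ci68[b, 1] = np.percentile(err[in_bin], [16, 50, 84])

print(centres)
print(median)
```

Plotting `median` against `centres` then gives one point per bin, and empty bins show up as gaps rather than as zeros.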
[21]:
tau_frac_err2 = abs((emu_tau - true_tau) / true_tau) * 100.
tau_diff2 = abs(emu_tau - true_tau)
[22]:
plt.figure(figsize=(10,7))
rcParams.update({'font.size': 30})
plt.plot(tau_bins, tau_binned, color = 'k', lw = 2, ls = '--')
plt.fill_between(tau_bins, tau_binned_68[:,0], tau_binned_68[:,1], color = 'k', alpha = 0.2)
plt.fill_between(tau_bins, tau_binned_95[:,0], tau_binned_95[:,1], color = 'k', alpha = 0.1)
for i in range(len(emu_tau)):
    plt.scatter(emu_tau[i], tau_diff2[i], color = cs[i], marker = 'o', zorder = 2)
#plt.scatter(10**tau_true, tau_frac_err, alpha = 0.1)
plt.xlabel(r'$\tau_e$', fontsize=35)
handles = [mpatches.Patch(color='k', label='68% CI', alpha = 0.1),
          mpatches.Patch(color='k', label='95% CI', alpha = 0.3),
          ]
#plt.legend(handles=handles, loc = (0.05,0.7), frameon = False, fontsize = 20)
plt.ylabel('Abs diff')
# plt.savefig('tau_binned.png', dpi = 300, bbox_inches = "tight")
plt.show()
../_images/tutorials_v3_full_emulator_28_0.png
[23]:
plt.figure(figsize=(10,7))
rcParams.update({'font.size': 30})
plt.plot(tau_bins, tau_binned_fe, color = 'k', lw = 2, ls = '--')
plt.fill_between(tau_bins, tau_binned_fe_68[:,0], tau_binned_fe_68[:,1], color = 'k', alpha = 0.2)
plt.fill_between(tau_bins, tau_binned_fe_95[:,0], tau_binned_fe_95[:,1], color = 'k', alpha = 0.1)
for i in range(len(emu_tau)):
    plt.scatter(emu_tau[i], tau_frac_err2[i], color = cs[i], marker = 'o', zorder = 2)
#plt.scatter(10**tau_true, tau_frac_err, alpha = 0.1)
plt.xlabel(r'$\tau_e$', fontsize=35)
handles = [mpatches.Patch(color='k', label='68% CI', alpha = 0.1),
          mpatches.Patch(color='k', label='95% CI', alpha = 0.3),
          ]
#plt.legend(handles=handles, loc = (0.05,0.7), frameon = False, fontsize = 20)
plt.ylabel('FE (%)')
# plt.savefig('tau_binned_FE.pdf', dpi = 300, bbox_inches = "tight")
plt.show()
../_images/tutorials_v3_full_emulator_29_0.png

Compute Fractional Errors

[24]:
# Compute FEs
fe_xHI = fractional_error(true_xHI, emu_xHI, floor=0.01)
fe_Tb = fractional_error(true_Tb, emu_Tb, floor=5.0)  # 5 mK floor for Tb
fe_Ts = fractional_error(np.log10(true_Ts + 1e-3), np.log10(emu_Ts + 1e-3), floor=0.01)
fe_tau = fractional_error(true_tau, emu_tau, floor=0.01)

print("Fractional Error Statistics (%)")
print("=" * 50)
for name, fe in [('xHI', fe_xHI), ('Tb', fe_Tb), ('Ts', fe_Ts), ('tau', fe_tau)]:
    finite = fe[np.isfinite(fe)]
    print(f"{name:8s}: median={np.median(finite):6.2f}%  "
          f"68th={np.percentile(finite, 68):6.2f}%  "
          f"95th={np.percentile(finite, 95):6.2f}%")
Fractional Error Statistics (%)
==================================================
xHI     : median=  0.03%  68th=  0.09%  95th=  2.08%
Tb      : median=  0.60%  68th=  0.91%  95th=  4.53%
Ts      : median=  0.18%  68th=  0.37%  95th=  1.22%
tau     : median=  0.26%  68th=  0.38%  95th=  0.65%

5. 2D Power Spectrum Emulation

The v3 emulator uses a score-based diffusion model to emulate the 2D power spectrum. This allows:

  • Probabilistic predictions with multiple samples

  • Variance and covariance estimation from the sample distribution

  • Two sampling methods: Euler-Maruyama (EM) and Probability-flow ODE
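
To illustrate the difference between the two sampling methods, here is a toy 1D sketch in which the score is known analytically, so no neural network is needed. Everything in it is an assumption made for illustration (a constant noise rate `BETA`, a Gaussian "data" distribution, plain Euler steps); it is not the sampler used inside 21cmEMU:

```python
import numpy as np

BETA = 4.0          # constant noise rate of the toy forward SDE (assumed)
M0, S0 = 2.0, 0.5   # mean and std of the toy "data" distribution (assumed)

def marginal(t):
    """Mean and variance of the forward process dx = -0.5*BETA*x dt + sqrt(BETA) dW."""
    decay = np.exp(-BETA * t)
    return M0 * np.sqrt(decay), S0**2 * decay + (1.0 - decay)

def score(x, t):
    """Exact score of the Gaussian marginal at time t."""
    mean, var = marginal(t)
    return -(x - mean) / var

def sample(method, n=20000, n_steps=1000, seed=0):
    """Integrate from t=1 back to t=0 with 'em' (stochastic) or 'ode' (deterministic)."""
    rng = np.random.default_rng(seed)
    dt = 1.0 / n_steps
    mean_T, var_T = marginal(1.0)
    x = mean_T + np.sqrt(var_T) * rng.standard_normal(n)  # start from the t=1 marginal
    for i in range(n_steps):
        t = 1.0 - i * dt
        drift = -0.5 * BETA * x
        if method == "em":
            # Reverse-time SDE: full score term plus fresh noise at every step.
            x = x - (drift - BETA * score(x, t)) * dt
            x = x + np.sqrt(BETA * dt) * rng.standard_normal(n)
        else:
            # Probability-flow ODE: half the score term and no noise.
            x = x - (drift - 0.5 * BETA * score(x, t)) * dt
    return x

x_em = sample("em")
x_ode = sample("ode")
print(x_em.mean(), x_em.std())
print(x_ode.mean(), x_ode.std())
```

Both samplers should recover the data mean and standard deviation up to discretisation and Monte Carlo error. The ODE path is deterministic given the starting noise, whereas EM injects new randomness at each step.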

[26]:
# Select a few parameters for 2D PS demonstration
n_ps_samples = 2
ps_test_idx = np.random.choice(len(ps_params), n_ps_samples, replace=False)

# Prepare parameters: pass only the 11 astrophysical parameters.
# The emulator adds the redshift dimension internally.
ps_test_params = ps_params[ps_test_idx]

print(f"Testing 2D PS on {n_ps_samples} parameter sets")
print(f"Parameter shape: {ps_test_params.shape}")
if ps_redshifts is not None:
    print(f"Using {len(ps_redshifts)} test set redshifts: {ps_redshifts}")

# Run with more samples for variance estimation
# Pass ps_redshifts to match the test set redshifts
_, ps_output, ps_errors = emu.predict(
    ps_test_params,
    verbose=True,
    ps_2d_redshifts=ps_redshifts,  # Use test set redshifts for comparison
    ps_sampling_method='ode',
    n_ps_batch=5,
    n_realisations=100  # More samples for better variance estimate
)
Testing 2D PS on 2 parameter sets
Parameter shape: (2, 11)
Using 10 test set redshifts: [ 5.50105421  6.46050805  7.59002262  9.40138137 11.04865309 13.6898716
 16.07786482 19.92643643 23.39700184 28.98670671]
Computing 2D PS: 100%|██████████| 4/4 [01:55<00:00, 28.93s/batch]
[35]:
# Access PS outputs
# PS_2D shape: (n_params, n_redshifts, 32, 64) - median of diffusion samples
ps_emu = 10**ps_output['PS_2D'].value  # Delta_{21}^2 [mK^2] 2D PS median
print(f"Emulated 2D PS shape (median): {ps_emu.shape}")

# Get the actual redshifts used for 2D PS emulation
ps_emu_zs = ps_output.PS_2D_redshifts if hasattr(ps_output, 'PS_2D_redshifts') else ps_redshifts
print(f"2D PS redshifts: {ps_emu_zs}")

# Also access 1D PS from LSTM
ps_1d = ps_output['PS'].value
print(f"1D PS shape: {ps_1d.shape}")
Emulated 2D PS shape (median): (2, 10, 32, 64)
2D PS redshifts: [ 5.5         6.97446005  7.54906604  7.9582024   9.82883407 10.36152691
 10.63860385 16.66170964 19.52022545 24.10859229]
1D PS shape: (2, 32, 32)
[47]:
# Plot 2D PS comparison for one parameter set at one redshift
param_idx = 0
z_idx = 3  # Index into our redshift subset (0 to n_z_subset-1)

# Get k-grids from emulator properties
kperp_emu = emu.properties.kperp
kpar_emu = emu.properties.kpar

if HAS_PS2D_DB and ps_2d_means is not None:
    # Get true PS from database (already subsetted to match ps_redshifts)
    true_ps_2d_sel = ps_2d_means[ps_test_idx]  # Shape: (n_ps_samples, n_z_subset, 32, 64)
    emu_ps_2d_sel = ps_emu  # Shape: (n_ps_samples, n_z_subset, 32, 64)

    print(f"True PS shape: {true_ps_2d_sel.shape}")
    print(f"Emulated PS shape: {emu_ps_2d_sel.shape}")
    print(f"Comparing at z = {ps_redshifts[z_idx]:.2f}")

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # True PS at selected redshift
    true_ps = true_ps_2d_sel[param_idx, z_idx]
    emu_ps = emu_ps_2d_sel[param_idx, z_idx]

    vmin = np.percentile(true_ps, 5)
    vmax = np.percentile(true_ps, 95)

    # Row 1: PS values
    im0 = axes[0, 0].pcolormesh(kperp_emu, kpar_emu, true_ps.T,
                                 vmin=vmin, vmax=vmax, cmap='inferno')
    axes[0, 0].set_title('True (simulation)', fontsize=14)
    axes[0, 0].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
    axes[0, 0].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
    axes[0, 0].set_xscale('log')
    plt.colorbar(im0, ax=axes[0, 0], label=r'$\Delta^2_{21}$ [mK$^2$]')

    im1 = axes[0, 1].pcolormesh(kperp_emu, kpar_emu, emu_ps.T,
                                 vmin=vmin, vmax=vmax, cmap='inferno')
    axes[0, 1].set_title('Emulated (diffusion)', fontsize=14)
    axes[0, 1].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
    axes[0, 1].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
    axes[0, 1].set_xscale('log')
    plt.colorbar(im1, ax=axes[0, 1], label=r'$\Delta^2_{21}$ [mK$^2$]')

    # Difference
    diff = emu_ps - true_ps
    vlim = np.max(np.abs(diff))
    im2 = axes[0, 2].pcolormesh(kperp_emu, kpar_emu, diff.T,
                                 vmin=-vlim, vmax=vlim, cmap='RdBu_r')
    axes[0, 2].set_title('Difference (emulated - true)', fontsize=14)
    axes[0, 2].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
    axes[0, 2].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
    axes[0, 2].set_xscale('log')
    plt.colorbar(im2, ax=axes[0, 2], label=r'$\Delta^2_{21}$ difference [mK$^2$]')

    # Row 2: FE and 1D slices
    fe_ps = fractional_error(true_ps, emu_ps, floor=0.01)

    im3 = axes[1, 0].pcolormesh(kperp_emu, kpar_emu, fe_ps.T,
                                 vmin=0, vmax=min(100, np.percentile(fe_ps, 95)),
                                 cmap='hot_r')
    axes[1, 0].set_title(f'Fractional Error (mean: {np.mean(fe_ps):.1f}%)', fontsize=14)
    axes[1, 0].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
    axes[1, 0].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
    axes[1, 0].set_xscale('log')
    plt.colorbar(im3, ax=axes[1, 0], label='FE [%]')

    # 1D slice along kperp (fixed kpar)
    kpar_slice_idx = 16  # Fixed kpar bin (index 16 of 64)
    axes[1, 1].plot(kperp_emu, true_ps[:, kpar_slice_idx], 'k-', lw=2, label='True')
    axes[1, 1].plot(kperp_emu, emu_ps[:, kpar_slice_idx], 'r--', lw=2, label='Emulated')
    axes[1, 1].set_xscale('log')
    axes[1, 1].set_yscale('log')
    axes[1, 1].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
    axes[1, 1].set_ylabel(r'PS [mK$^2$]')
    axes[1, 1].set_title(f'1D slice at $k_\\parallel$ = {kpar_emu[kpar_slice_idx]:.3f} Mpc$^{{-1}}$', fontsize=14)
    axes[1, 1].legend()

    # FE histogram
    axes[1, 2].hist(fe_ps.ravel(), bins=50, alpha=0.7, density=True)
    axes[1, 2].axvline(np.mean(fe_ps), color='r', ls='--', lw=2,
                       label=f'Mean: {np.mean(fe_ps):.1f}%')
    axes[1, 2].axvline(np.median(fe_ps), color='b', ls='--', lw=2,
                       label=f'Median: {np.median(fe_ps):.1f}%')
    axes[1, 2].set_xlabel('Fractional Error [%]')
    axes[1, 2].set_ylabel('Density')
    axes[1, 2].set_title('FE Distribution', fontsize=14)
    axes[1, 2].legend()

    plt.suptitle(f'2D Power Spectrum at z = {ps_redshifts[z_idx]:.1f}', fontsize=16, y=1.02)
    plt.tight_layout()
    plt.show()

else:
    # Without database, just show the emulated PS
    kperp_emu = emu.properties.kperp
    kpar_emu = emu.properties.kpar
    emu_ps = ps_emu[param_idx, z_idx]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    vmin = np.percentile(np.log10(emu_ps + 1e-10), 5)
    vmax = np.percentile(np.log10(emu_ps + 1e-10), 95)

    im0 = axes[0].pcolormesh(kperp_emu, kpar_emu, np.log10(emu_ps).T,
                              vmin=vmin, vmax=vmax, cmap='inferno')
    axes[0].set_title('Emulated 2D PS', fontsize=14)
    axes[0].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
    axes[0].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
    axes[0].set_xscale('log')
    plt.colorbar(im0, ax=axes[0], label=r'$\log_{10}\Delta^2_{21}$ [mK$^2$]')

    # 1D slice along kperp (fixed kpar)
    kpar_slice_idx = 16  # Fixed kpar bin (index 16 of 64)
    axes[1].plot(kperp_emu, emu_ps[:, kpar_slice_idx], 'r-', lw=2, label='Emulated')
    axes[1].set_xscale('log')
    axes[1].set_yscale('log')
    axes[1].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
    axes[1].set_ylabel(r'PS [mK$^2$]')
    axes[1].set_title(f'1D slice at $k_\\parallel$ = {kpar_emu[kpar_slice_idx]:.3f} Mpc$^{{-1}}$', fontsize=14)
    axes[1].legend()

    # 1D slice along kpar (fixed kperp)
    kperp_slice_idx = 16  # Middle kperp
    axes[2].plot(kpar_emu, emu_ps[kperp_slice_idx, :], 'r-', lw=2, label='Emulated')
    axes[2].set_yscale('log')
    axes[2].set_xlabel(r'$k_\parallel$ [Mpc$^{-1}$]')
    axes[2].set_ylabel(r'$\Delta^2_{21}$ [mK$^2$]')
    axes[2].set_title(f'1D slice at $k_\\perp$ = {kperp_emu[kperp_slice_idx]:.3f} Mpc$^{{-1}}$', fontsize=14)
    axes[2].legend()

    z_val = ps_redshifts[z_idx] if ps_redshifts is not None else emu.properties.PS_zs[z_idx]
    plt.suptitle(f'Emulated 2D Power Spectrum at z = {z_val:.1f}', fontsize=16)
    plt.tight_layout()
    plt.show()

    print("\nNote: 2D PS test database not available - showing emulated PS only.")
True PS shape: (2, 10, 32, 64)
Emulated PS shape: (2, 10, 32, 64)
Comparing at z = 9.40
../_images/tutorials_v3_full_emulator_35_1.png

6. Uncertainty Quantification from Diffusion Samples

The diffusion model generates multiple samples for each input, allowing variance estimation.

[44]:
# Get the raw output which contains PS_2D_samples
ps_samples = ps_output.PS_2D_samples[:1]  # Shape: (1, n_z, n_samples, 32, 64)
print(f"PS samples shape: {ps_samples.shape}")

# Compute per-pixel mean, standard deviation, and variance across the sample axis (axis=2)
ps_mean = np.mean(ps_samples, axis=2).value
ps_std = np.std(ps_samples, axis=2).value
ps_var = np.var(ps_samples, axis=2).value

print(f"PS mean shape: {ps_mean.shape}")
print(f"PS std shape: {ps_std.shape}")
PS samples shape: (1, 10, 100, 32, 64)
PS mean shape: (1, 10, 32, 64)
PS std shape: (1, 10, 32, 64)
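
With a finite number of diffusion samples, the per-pixel standard deviation is itself a noisy estimate. As a rough guide, and assuming approximately Gaussian scatter, the relative uncertainty on a standard deviation estimated from N samples is about 1/sqrt(2(N - 1)), which is roughly 7% for N = 100. The sketch below checks this rule of thumb on synthetic Gaussian draws rather than emulator output; all names and numbers in it are illustrative.

```python
import numpy as np

def relative_std_error(n_samples):
    """Approximate relative uncertainty of a sample std (Gaussian assumption)."""
    return 1.0 / np.sqrt(2.0 * (n_samples - 1))

rng = np.random.default_rng(0)
n_samples, n_trials, true_sigma = 100, 5000, 3.0

# Each row plays the role of one pixel observed with n_samples draws
draws = rng.normal(loc=10.0, scale=true_sigma, size=(n_trials, n_samples))
std_estimates = draws.std(axis=1, ddof=1)

# Scatter of the std estimates relative to the true value
empirical = std_estimates.std() / true_sigma
predicted = relative_std_error(n_samples)
print(f"Empirical relative scatter of std: {empirical:.3f}")
print(f"Gaussian approximation:            {predicted:.3f}")
```

Note that `np.std` and `np.var` default to `ddof=0`, whereas a covariance normalised by `n_samples - 1` corresponds to `ddof=1`; for 100 samples the two conventions differ by about 0.5% in the standard deviation.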
[48]:
# Plot variance map for one redshift
z_idx = 5  # Use the same z_idx as the comparison plots

# Use emulator k-grids
kperp_emu = emu.properties.kperp
kpar_emu = emu.properties.kpar

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Mean PS
im0 = axes[0].pcolormesh(kperp_emu, kpar_emu, ps_mean[0, z_idx].T, cmap='inferno')
axes[0].set_title('Mean PS', fontsize=14)
axes[0].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
axes[0].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
axes[0].set_xscale('log')
plt.colorbar(im0, ax=axes[0], label=r'$\Delta^2_{21}$ [mK$^2$]')

# Standard deviation
im1 = axes[1].pcolormesh(kperp_emu, kpar_emu, np.log10(ps_std[0, z_idx] + 1e-10).T, cmap='viridis')
axes[1].set_title('Standard Deviation', fontsize=14)
axes[1].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
axes[1].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
axes[1].set_xscale('log')
plt.colorbar(im1, ax=axes[1], label=r'$\log_{10}$ $\sigma$')

# Coefficient of variation (std/mean)
cv = ps_std[0, z_idx] / (ps_mean[0, z_idx] + 1e-10)
im2 = axes[2].pcolormesh(kperp_emu, kpar_emu, cv.T, cmap='hot_r', vmin=0, vmax=1)
axes[2].set_title('Coefficient of Variation', fontsize=14)
axes[2].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
axes[2].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
axes[2].set_xscale('log')
plt.colorbar(im2, ax=axes[2], label='CV = σ/μ')

z_label = ps_redshifts[z_idx] if ps_redshifts is not None else z_idx
plt.suptitle(f'Variance from Diffusion Samples (z = {z_label:.1f})', fontsize=16)
plt.tight_layout()
plt.show()
../_images/tutorials_v3_full_emulator_38_0.png

Covariance and Correlation Analysis

The diffusion model samples allow us to compute the full covariance matrix between all \((k_\perp, k_\parallel)\) pixels. This reveals the correlations induced by the diffusion process.

[51]:
# Compute full covariance matrix from samples
# Flatten samples to (n_samples, n_pixels)
H, W = ps_samples.shape[-2], ps_samples.shape[-1]
n_samples_cov = ps_samples.shape[2]
n_pixels = H * W

samples_flat = ps_samples[0, z_idx].reshape(n_samples_cov, -1).value  # (n_samples, H*W)

# Center the samples
samples_centered = samples_flat - samples_flat.mean(axis=0, keepdims=True)

# Compute covariance matrix: Cov[i,j] = E[(X_i - μ_i)(X_j - μ_j)]
cov_matrix = (samples_centered.T @ samples_centered) / (n_samples_cov - 1)
print(f"Covariance matrix shape: {cov_matrix.shape}")

# Compute correlation matrix from covariance
std_vec = np.sqrt(np.diag(cov_matrix))
std_outer = np.outer(std_vec, std_vec)
corr_matrix = cov_matrix / (std_outer + 1e-10)

# Clip correlation to [-1, 1] due to numerical precision
corr_matrix = np.clip(corr_matrix, -1, 1)
print(f"Correlation matrix shape: {corr_matrix.shape}")

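# Compute diagnostic statistics
diag_var = np.diag(cov_matrix)
total_var = np.sum(diag_var)
# Sum of absolute off-diagonal covariances (not a variance, despite the name)
off_diag_var = np.sum(np.abs(cov_matrix)) - total_var
# Fraction of the total absolute covariance that lies on the diagonal
diag_frac = total_var / (total_var + off_diag_var)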

off_diag_mask = ~np.eye(n_pixels, dtype=bool)
mean_abs_corr = np.mean(np.abs(corr_matrix[off_diag_mask]))
max_abs_corr = np.max(np.abs(corr_matrix[off_diag_mask]))

z_label = ps_redshifts[z_idx] if ps_redshifts is not None else z_idx
print(f"\nCovariance diagnostics at z = {z_label:.1f}:")
print(f"  Diagonal fraction: {diag_frac:.3f}")
print(f"  Mean |off-diag correlation|: {mean_abs_corr:.3f}")
print(f"  Max |off-diag correlation|: {max_abs_corr:.3f}")
Covariance matrix shape: (2048, 2048)
Correlation matrix shape: (2048, 2048)

Covariance diagnostics at z = 13.7:
  Diagonal fraction: 0.007
  Mean |off-diag correlation|: 0.146
  Max |off-diag correlation|: 0.967
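
A covariance matrix estimated from fewer samples than pixels is necessarily rank-deficient: with N samples the centered data matrix has rank at most N - 1, so a 2048 x 2048 matrix built from 100 samples cannot be inverted directly. The sketch below illustrates this on synthetic data with small, made-up dimensions. The shrinkage step at the end is one generic regularisation option, shown only as an assumption; it is not something the emulator is known to apply.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples, n_pixels = 20, 50  # illustrative sizes with n_samples < n_pixels

# Same estimator as in the cell above, applied to synthetic white noise
samples = rng.normal(size=(n_samples, n_pixels))
centered = samples - samples.mean(axis=0, keepdims=True)
cov = centered.T @ centered / (n_samples - 1)

rank = np.linalg.matrix_rank(cov)
print(f"Rank of the {n_pixels}x{n_pixels} sample covariance: {rank}")

# Generic regularisation: shrink the estimate towards its own diagonal
alpha = 0.1  # hypothetical shrinkage strength
cov_shrunk = (1 - alpha) * cov + alpha * np.diag(np.diag(cov))
rank_shrunk = np.linalg.matrix_rank(cov_shrunk)
print(f"Rank after shrinkage: {rank_shrunk}")
```

Any downstream use that requires the inverse covariance (for example a Gaussian likelihood) therefore needs either more samples than pixels, a reduced set of bins, or some form of regularisation.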
[52]:
# Plot correlation matrix
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Full correlation matrix
im0 = axes[0].imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
axes[0].set_title(f'Pixel Correlation Matrix ({n_pixels}x{n_pixels})', fontsize=14)
axes[0].set_xlabel('Pixel index (flattened)')
axes[0].set_ylabel('Pixel index (flattened)')
plt.colorbar(im0, ax=axes[0], label='Correlation')

# Add grid lines to show 32x64 block structure
for i in range(0, n_pixels+1, W):
    axes[0].axhline(i-0.5, color='gray', alpha=0.3, lw=0.5)
    axes[0].axvline(i-0.5, color='gray', alpha=0.3, lw=0.5)

# Histogram of off-diagonal correlations
off_diag_corrs = corr_matrix[off_diag_mask]
axes[1].hist(off_diag_corrs, bins=100, alpha=0.7, density=True, color='steelblue')
axes[1].axvline(0, color='k', ls='--', lw=1)
axes[1].axvline(mean_abs_corr, color='r', ls='--', lw=2,
                label=f'Mean |corr| = {mean_abs_corr:.3f}')
axes[1].axvline(-mean_abs_corr, color='r', ls='--', lw=2)
axes[1].set_xlabel('Correlation coefficient', fontsize=12)
axes[1].set_ylabel('Density', fontsize=12)
axes[1].set_title('Distribution of Off-Diagonal Correlations', fontsize=14)
axes[1].legend(loc='upper right')
axes[1].set_xlim(-1, 1)

z_label = ps_redshifts[z_idx] if ps_redshifts is not None else z_idx
plt.suptitle(f'Correlation Structure (z = {z_label:.1f})', fontsize=16)
plt.tight_layout()
plt.show()
../_images/tutorials_v3_full_emulator_41_0.png
[53]:
# Spatial correlation maps: correlation of each pixel with reference pixels
kperp_emu = emu.properties.kperp
kpar_emu = emu.properties.kpar

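# Define reference pixels (flattened index = i_kperp * W + i_kpar)
center_idx = (H // 2) * W + (W // 2)  # Center of the (kperp, kpar) grid
corner_idx = 0  # Lowest-k corner: first kperp bin, first kpar bin
mid_k_idx = (H // 4) * W + (W // 2)  # Quarter of the way along kperp, middle of kpar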

# Extract correlation with reference pixels and reshape to 2D
corr_with_center = corr_matrix[center_idx, :].reshape(H, W)
corr_with_corner = corr_matrix[corner_idx, :].reshape(H, W)
corr_with_mid = corr_matrix[mid_k_idx, :].reshape(H, W)

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Correlation with center pixel
im0 = axes[0].pcolormesh(kperp_emu, kpar_emu, corr_with_center.T,
                          cmap='RdBu_r', vmin=-1, vmax=1)
axes[0].plot(kperp_emu[H//2], kpar_emu[W//2], 'k*', ms=15, label='Reference')
axes[0].set_title('Correlation with Center Pixel', fontsize=14)
axes[0].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
axes[0].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
axes[0].set_xscale('log')
axes[0].legend()
plt.colorbar(im0, ax=axes[0], label='Correlation')

# Correlation with corner pixel (low-k)
im1 = axes[1].pcolormesh(kperp_emu, kpar_emu, corr_with_corner.T,
                          cmap='RdBu_r', vmin=-1, vmax=1)
axes[1].plot(kperp_emu[0], kpar_emu[0], 'k*', ms=15, label='Reference')
axes[1].set_title('Correlation with Low-k Corner', fontsize=14)
axes[1].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
axes[1].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
axes[1].set_xscale('log')
axes[1].legend()
plt.colorbar(im1, ax=axes[1], label='Correlation')

# Correlation with mid-k pixel
im2 = axes[2].pcolormesh(kperp_emu, kpar_emu, corr_with_mid.T,
                          cmap='RdBu_r', vmin=-1, vmax=1)
axes[2].plot(kperp_emu[H//4], kpar_emu[W//2], 'k*', ms=15, label='Reference')
axes[2].set_title('Correlation with Mid-k Pixel', fontsize=14)
axes[2].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
axes[2].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
axes[2].set_xscale('log')
axes[2].legend()
plt.colorbar(im2, ax=axes[2], label='Correlation')

z_label = ps_redshifts[z_idx] if ps_redshifts is not None else z_idx
plt.suptitle(f'Spatial Correlation Maps (z = {z_label:.1f})', fontsize=16)
plt.tight_layout()
plt.show()

print("\nCorrelation length analysis:")
print(f"  Center pixel self-correlation: {corr_with_center[H//2, W//2]:.3f}")
print(f"  Center-corner correlation: {corr_with_center[0, 0]:.3f}")
print(f"  This indicates {'strong' if abs(corr_with_center[0, 0]) > 0.3 else 'weak'} "
      f"long-range correlations in the diffusion samples.")
../_images/tutorials_v3_full_emulator_42_0.png

Correlation length analysis:
  Center pixel self-correlation: 1.000
  Center-corner correlation: 0.024
  This indicates weak long-range correlations in the diffusion samples.
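
A single center-to-corner value is a coarse summary of the correlation length. One possible refinement is to average the correlation with a reference pixel in rings of increasing pixel separation. The sketch below does this on synthetic smoothed noise, not on emulator samples; the grid size, the box filter, and the use of Chebyshev distance in pixel units are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(2)
n_samples, H, W = 400, 16, 32  # illustrative sizes, smaller than the 32x64 grid

# Synthetic samples: white noise smoothed with a periodic 5x5 box filter,
# which correlates each pixel with neighbours up to 4 pixels away
noise = rng.normal(size=(n_samples, H, W))
smooth = np.zeros_like(noise)
for di in range(-2, 3):
    for dj in range(-2, 3):
        smooth += np.roll(noise, shift=(di, dj), axis=(1, 2))
smooth /= 25.0

# Pearson correlation of every pixel with the reference (center) pixel
i0, j0 = H // 2, W // 2
flat = smooth.reshape(n_samples, -1)
centered = flat - flat.mean(axis=0, keepdims=True)
ref = centered[:, i0 * W + j0]
corr_map = (centered.T @ ref) / (
    np.linalg.norm(centered, axis=0) * np.linalg.norm(ref)
)
corr_map = corr_map.reshape(H, W)

# Average the correlation in rings of constant Chebyshev distance (in pixels)
ii, jj = np.indices((H, W))
dist = np.maximum(np.abs(ii - i0), np.abs(jj - j0))
profile = np.array([corr_map[dist == d].mean() for d in range(dist.max() + 1)])

for d in (0, 1, 2, 4, 8):
    print(f"distance {d}: mean correlation = {profile[d]:+.3f}")
```

Applied to the real `corr_matrix`, the same ring average would show how quickly correlations fall off with separation in the \((k_\perp, k_\parallel)\) plane, keeping in mind that the \(k_\perp\) axis is plotted logarithmically, so pixel distance is not a physical distance in k.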

7. Comparison: EM vs ODE Sampling

The diffusion model supports two sampling methods:

  • Euler-Maruyama (EM): Stochastic sampling; faster, and the spread of its samples allows variance estimation

  • Probability-flow ODE: Deterministic sampling; slower, but permits exact likelihood evaluation
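
To make the distinction between the two samplers concrete, here is a toy one-dimensional sketch. It assumes a simple forward process dx = dW with unit diffusion coefficient and a Gaussian data distribution, for which the score is known in closed form. It is not the emulator's sampler: the actual SDE, noise schedule, and conditioning used by 21cmEMU are not specified in this tutorial, and every name below is hypothetical.

```python
import numpy as np

MU, S0, T = 2.0, 0.5, 4.0  # toy data distribution N(MU, S0^2); diffusion time T

def score(x, t):
    """Exact score of the noised marginal N(MU, S0^2 + t)."""
    return -(x - MU) / (S0**2 + t)

def sample(x_T, method, n_steps=1000, rng=None):
    """Integrate from t = T down to t = 0 for the forward SDE dx = dW (g = 1)."""
    x = np.array(x_T, dtype=float)
    dt = T / n_steps
    for t in np.linspace(T, 0.0, n_steps, endpoint=False):
        if method == "em":
            # Reverse-time SDE: drift g^2 * score plus fresh noise at every step
            x = x + score(x, t) * dt + np.sqrt(dt) * rng.normal(size=x.shape)
        elif method == "ode":
            # Probability-flow ODE: half the drift and no noise
            x = x + 0.5 * score(x, t) * dt
        else:
            raise ValueError(f"unknown method: {method}")
    return x

rng = np.random.default_rng(3)
# Start from the exact noised marginal at t = T
x_T = rng.normal(MU, np.sqrt(S0**2 + T), size=20000)

x_em = sample(x_T, "em", rng=rng)
x_ode = sample(x_T, "ode")
print(f"EM : mean = {x_em.mean():.3f}, std = {x_em.std():.3f}")
print(f"ODE: mean = {x_ode.mean():.3f}, std = {x_ode.std():.3f}")
```

In this toy problem both samplers recover the target mean and standard deviation. The ODE maps each starting point to a single end point, so repeated runs from the same `x_T` agree exactly, whereas EM returns a different end point on every run because of the injected noise.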

[ ]:
import time

# Single parameter for comparison
test_single = ps_test_params[:1]

# EM sampling (perf_counter is a monotonic clock, better suited to timing than time.time)
t0 = time.perf_counter()
_, out_em, _ = emu.predict(test_single, ps_2d_redshifts=ps_redshifts,
                           ps_sampling_method='em', n_realisations=50, verbose=False)
t_em = time.perf_counter() - t0

# ODE sampling
t0 = time.perf_counter()
_, out_ode, _ = emu.predict(test_single, ps_2d_redshifts=ps_redshifts,
                            ps_sampling_method='ode', n_realisations=50, verbose=False)
t_ode = time.perf_counter() - t0

print(f"EM sampling:  {t_em:.2f}s")
print(f"ODE sampling: {t_ode:.2f}s")
[ ]:
# Compare EM vs ODE - use same z_idx as before
z_idx_compare = 5  # Should match the subset of redshifts we selected

ps_em = out_em['PS_2D'][0, z_idx_compare]
ps_ode = out_ode['PS_2D'][0, z_idx_compare]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

vmin = min(np.log10(ps_em).min(), np.log10(ps_ode).min())
vmax = max(np.log10(ps_em).max(), np.log10(ps_ode).max())

kperp_emu = emu.properties.kperp
kpar_emu = emu.properties.kpar

im0 = axes[0].pcolormesh(kperp_emu, kpar_emu, np.log10(ps_em).T,
                          vmin=vmin, vmax=vmax, cmap='inferno')
axes[0].set_title(f'EM Sampling ({t_em:.1f}s)', fontsize=14)
axes[0].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
axes[0].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
axes[0].set_xscale('log')
plt.colorbar(im0, ax=axes[0], label=r'$\log_{10}$ PS')

im1 = axes[1].pcolormesh(kperp_emu, kpar_emu, np.log10(ps_ode).T,
                          vmin=vmin, vmax=vmax, cmap='inferno')
axes[1].set_title(f'ODE Sampling ({t_ode:.1f}s)', fontsize=14)
axes[1].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
axes[1].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
axes[1].set_xscale('log')
plt.colorbar(im1, ax=axes[1], label=r'$\log_{10}$ PS')

# Difference
diff = np.log10(ps_em) - np.log10(ps_ode)
vlim = np.percentile(np.abs(diff), 95)
im2 = axes[2].pcolormesh(kperp_emu, kpar_emu, diff.T,
                          vmin=-vlim, vmax=vlim, cmap='RdBu_r')
axes[2].set_title(f'EM - ODE (RMS: {np.sqrt(np.mean(diff**2)):.3f})', fontsize=14)
axes[2].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
axes[2].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
axes[2].set_xscale('log')
plt.colorbar(im2, ax=axes[2], label=r'$\Delta \log_{10}$ PS')

z_label = ps_redshifts[z_idx_compare] if ps_redshifts is not None else z_idx_compare
plt.suptitle(f'EM vs ODE Sampling Comparison (z = {z_label:.1f})', fontsize=16)
plt.tight_layout()
plt.show()

Using the MHEmulatorErrors Class

The MH (v3) emulator provides error statistics via the MHEmulatorErrors class, which differs significantly from the error handling of the ACG/Radio emulators:

Key Differences:

  • Absolute errors in physical units (not fractional %)

  • Output-dependent errors: computed from FE% × |output_value|

  • 2D PS error statistics: variance, covariance, and correlation matrices

  • Sampling method support: different errors for 'em' vs 'ode' sampling
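
The second bullet above can be read as a conversion from a fractional error in percent to an absolute error in the units of the output. The sketch below assumes that reading, including the division by 100; the helper name and the numbers are hypothetical, and whether the library performs the conversion in exactly this way is not confirmed here.

```python
import numpy as np

def absolute_error(fe_percent, output_value):
    """Convert a fractional error in percent to an absolute error."""
    return np.asarray(fe_percent) / 100.0 * np.abs(output_value)

# Hypothetical numbers: a 2% fractional error on a brightness temperature track
tb_mK = np.array([-120.0, -40.0, 5.0, 20.0])
tb_err_mK = absolute_error(2.0, tb_mK)
print(tb_err_mK)
```

Because the error scales with the magnitude of the output, it shrinks wherever the signal passes close to zero, which is why these errors are described as output-dependent.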

[ ]:
# Get the error object from the prediction
_, output_single, errors = emu.predict(test_params[0:1], verbose=False)

# Inspect the error object
print(f"Error type: {type(errors).__name__}")
print(f"\nAvailable error fields: {errors.keys()}")
print(f"\nError summary:\n{errors.summary()}")

Absolute Errors with Physical Units

Unlike the ACG/Radio emulators, which store FE%, MHEmulatorErrors provides absolute errors with astropy units. For log-space quantities (PS, UVLFs), errors are in dex, i.e. they are errors on log10 of the quantity.

[ ]:
# Check units of each error field
print("Error field units:")
print(f"  PS_err: {errors.PS_err.unit} (dex - error on log10 PS)")
print(f"  Tb_err: {errors.Tb_err.unit} (absolute mK error)")
print(f"  xHI_err: {errors.xHI_err.unit} (absolute neutral fraction error)")
print(f"  Ts_err: {errors.Ts_err.unit} (absolute K error)")
print(f"  tau_err: {errors.tau_err.unit} (absolute optical depth error)")
print(f"  UVLFs_logerr: {errors.UVLFs_logerr.unit} (dex - error on log10 LF)")

# Median PS error interpretation
median_ps_err = np.nanmedian(errors.PS_err.value)
print(f"\nMedian PS error: {median_ps_err:.3f} dex")
print(f"  This means log10(PS) predictions are typically off by ~{median_ps_err:.3f}")
print(f"  In linear PS, this is a factor of 10^{median_ps_err:.3f} ≈ {10**median_ps_err:.2f} ({(10**median_ps_err - 1)*100:.0f}% error)")
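
An error that is symmetric in dex is asymmetric in linear space: the upper bound lies further from the central value than the lower bound. The short sketch below illustrates this with hypothetical numbers; the helper is not part of py21cmemu.

```python
import numpy as np

def dex_to_linear_bounds(value, err_dex):
    """Return (lower, upper) linear bounds for a symmetric error in dex."""
    value = np.asarray(value, dtype=float)
    return value * 10.0 ** (-err_dex), value * 10.0 ** (err_dex)

ps_mK2 = 10.0    # hypothetical power spectrum value in mK^2
err_dex = 0.05   # hypothetical error on log10(PS)

lower, upper = dex_to_linear_bounds(ps_mK2, err_dex)
print(f"upper: +{(upper / ps_mK2 - 1) * 100:.1f}%")
print(f"lower: -{(1 - lower / ps_mK2) * 100:.1f}%")
```

For small errors the two sides are nearly equal (about ln(10) × err_dex as a fraction), but the asymmetry grows quickly once the error approaches a few tenths of a dex.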

2D Power Spectrum Error Statistics (Unique to MH Emulator)

The MH emulator’s diffusion model provides advanced error statistics for the 2D power spectrum:

  • Variance: Per-bin error variance from test set residuals

  • Covariance: Full covariance matrix between (kperp, kpar) bins

  • Correlation diagnostics: ps_diagonal_fraction, ps_mean_abs_correlation

[ ]:
# Access 2D PS error statistics via the errors object
ps_var = errors.get_ps_variance()
ps_cov = errors.get_ps_covariance()

if ps_var is not None:
    print(f"2D PS variance shape: {ps_var.shape}")
    print(f"2D PS covariance shape: {ps_cov.shape if ps_cov is not None else 'N/A'}")
    print("\nCorrelation diagnostics:")
    print(f"  Diagonal fraction: {errors.ps_diagonal_fraction:.3f}")
    print("  (1.0 = uncorrelated errors, <1.0 = some correlation)")
    print(f"  Mean |off-diag correlation|: {errors.ps_mean_abs_correlation:.3f}")
    print("  (0 = uncorrelated, 1 = perfectly correlated)")
else:
    print("2D PS statistics not available (emulator loaded without 2D PS)")
[ ]:
# Visualize 2D PS variance and correlation matrix
if ps_var is not None and ps_cov is not None:
    fig, axes = plt.subplots(1, 3, figsize=(16, 5))

    # Get k-grids from emulator properties
    kperp = emu.properties.kperp
    kpar = emu.properties.kpar

    # Plot variance map
    im0 = axes[0].pcolormesh(kperp, kpar, ps_var.T, cmap='viridis')
    axes[0].set_xlabel(r'$k_\perp$ [Mpc$^{-1}$]')
    axes[0].set_ylabel(r'$k_\parallel$ [Mpc$^{-1}$]')
    axes[0].set_xscale('log')
    axes[0].set_title('2D PS Error Variance (FE%²)')
    plt.colorbar(im0, ax=axes[0], label='Variance')

    # Plot covariance matrix (subsampled for visibility)
    step = 8  # Subsample for clearer visualization
    cov_sub = ps_cov[::step, ::step]
    im1 = axes[1].imshow(cov_sub, cmap='RdBu_r', aspect='equal',
                         vmin=-np.percentile(np.abs(cov_sub), 95),
                         vmax=np.percentile(np.abs(cov_sub), 95))
    axes[1].set_title(f'Covariance Matrix (subsampled {step}x)')
    axes[1].set_xlabel('Pixel index')
    axes[1].set_ylabel('Pixel index')
    plt.colorbar(im1, ax=axes[1], label='Covariance')

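    # Compute and plot correlation matrix
    std = np.sqrt(np.diag(ps_cov))
    std_outer = np.outer(std, std)
    # Guard against zero-variance bins and clip round-off, as in the sample-based analysis above
    corr = np.clip(ps_cov / (std_outer + 1e-10), -1, 1)
    corr_sub = corr[::step, ::step]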

    im2 = axes[2].imshow(corr_sub, cmap='RdBu_r', aspect='equal', vmin=-1, vmax=1)
    axes[2].set_title('Correlation Matrix (subsampled)')
    axes[2].set_xlabel('Pixel index')
    axes[2].set_ylabel('Pixel index')
    plt.colorbar(im2, ax=axes[2], label='Correlation')

    plt.suptitle('MH Emulator: 2D PS Error Statistics', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()
else:
    print("Load emulator with emulate_2d_ps=True to access these statistics")
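
One common downstream use of a bin-to-bin covariance is a Gaussian chi-square that accounts for correlated errors. The sketch below shows the computation on a small synthetic covariance; it does not use `errors.get_ps_covariance()`, and it assumes a positive-definite matrix expressed in the same units as the residual. Whether the emulator's covariance satisfies this without regularisation is not stated in this tutorial.

```python
import numpy as np

def gaussian_chi2(residual, cov):
    """Chi-square r^T C^{-1} r computed via a Cholesky factorisation."""
    L = np.linalg.cholesky(cov)       # requires cov to be positive definite
    z = np.linalg.solve(L, residual)  # solves L z = r, so z @ z = r^T C^{-1} r
    return float(z @ z)

rng = np.random.default_rng(4)
n_bins = 6  # illustrative; far fewer than the 32 x 64 = 2048 bins of the 2D PS

# Synthetic positive-definite covariance with correlated neighbouring bins
idx = np.arange(n_bins)
cov = 0.04 * 0.5 ** np.abs(idx[:, None] - idx[None, :])

# Hypothetical residual (model minus data) drawn from that covariance
residual = rng.multivariate_normal(np.zeros(n_bins), cov)
chi2_full = gaussian_chi2(residual, cov)
chi2_diag = float(np.sum(residual**2 / np.diag(cov)))  # ignores correlations
print(f"chi2 (full covariance): {chi2_full:.2f}")
print(f"chi2 (diagonal only):   {chi2_diag:.2f}")
```

Comparing the two values gives a quick sense of how much the off-diagonal terms matter for a given residual; when correlations are strong, the diagonal-only form can differ substantially from the full one.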

Summary

This tutorial demonstrated the v3 (MH) emulator capabilities:

  1. 1D Summaries: 21-cm PS, global \(T_b\), \(x_{\mathrm{HI}}\), \(T_s\), UVLFs, and \(\tau_e\) emulated by an LSTM-based encoder-decoder

  2. 2D Power Spectrum: \(\Delta^2(k_\perp, k_\parallel)\) emulated by a score-based diffusion model

  3. Uncertainty Quantification:

    • Variance/standard deviation maps from diffusion samples

    • Full covariance matrix computation between PS pixels

    • Correlation structure analysis and spatial correlation maps

  4. Sampling Methods: EM (fast, stochastic) vs ODE (slower, deterministic)

The emulator achieves:

  • Sub-percent median fractional errors on most 1D summaries

  • ~10% median FE on 2D PS for most k-modes

  • Full probabilistic uncertainty through diffusion sampling

  • Pixel-by-pixel covariance for downstream statistical analyses