Incorporating 21cmFAST

If you have 21cmFAST installed in your environment, you can use it to calculate the UV luminosity functions (UV LFs) and the electron-scattering optical depth $\tau_e$ analytically. This example shows how, and compares the analytic results to the emulated ones.

[1]:
%matplotlib inline
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import py21cmfast as p21
from matplotlib import rcParams
from scipy.interpolate import interp1d

from py21cmemu import Emulator, EmulatorInput
from py21cmemu.properties import COSMO_PARAMS, FLAG_OPTIONS, USER_PARAMS

# Large default font so figure labels stay readable in the rendered docs.
rcParams["font.size"] = 30
/home/dani/anaconda3/envs/21cmEMU/lib/python3.9/site-packages/py21cmfast/_cfg.py:57: UserWarning: Your configuration file is out of date. Updating...
  warnings.warn(
/home/dani/anaconda3/envs/21cmEMU/lib/python3.9/site-packages/py21cmfast/_cfg.py:41: UserWarning: Your configuration file is out of date. Updating...
  warnings.warn("Your configuration file is out of date. Updating...")
[2]:
# ``np.load`` on an .npz archive returns an NpzFile that keeps the file open;
# the context manager closes it once the array we need has been read into memory.
with np.load("Test_data_sample.npz") as test_sample:
    print(list(test_sample.keys()))
    test_params = test_sample["X_test"]

emu = Emulator()
# Emulate all summaries for the test parameter samples.
normed_input_params, output, output_errors = emu.predict(test_params, verbose=True)
['X_test', 'parameters', 'limits', 'Ts', 'Tb', 'PS', 'tau', 'UVLFs', 'xHI']
4/4 [==============================] - 4s 129ms/step

As in the first tutorial, we obtain the emulated summaries from the emulator; they are stored in the variable `output`.

[3]:
# Pull the emulated summaries we compare against out of ``output``.
# Note: the emulated UV LFs are already in log.
UVLFs_emu, tau_emu, xHI_emu, zs = (
    output[key] for key in ("UVLFs", "tau", "xHI", "redshifts")
)
MUV_emu = output.Muv
[4]:
# Report the (samples, redshift bins, M_UV bins) layout of the emulated UV LFs.
n_samples, n_zbins, n_muv_bins = UVLFs_emu.shape
print(
    f"The shape of the output of the UV LFs is {n_samples} samples with {n_zbins} "
    f"redshift bins (z = 6, 7, 8, 10) and {n_muv_bins} M_UV bins. "
)
The shape of the output of the UV LFs is 100 samples with 4 redshift bins (z = 6, 7, 8, 10) and 16 M_UV bins.
[5]:
[*output.keys()]
[5]:
['Tb', 'xHI', 'Ts', 'PS', 'tau', 'UVLFs']

Convert the input parameters into 21cmFAST units and calculate the UV LFs and $\tau_e$ analytically

[6]:
emu_in = EmulatorInput()
# Column j of ``input_in_p21_units`` is paired with ``astro_param_keys[j]`` below.
astro_param_keys = emu_in.astro_param_keys
# Raw (un-normalised) test parameters, converted to the units 21cmFAST expects.
input_in_p21_units = emu_in.make_param_array(
    test_params,
    normed=False,
)
[7]:
n_test = len(test_params)
# One analytic optical depth per parameter sample.
analytic_tau = np.zeros(n_test)
# Axes: Nsamples, [M_UV, LF], [z = 6,7,8,10], M_UV bins
analytic_UVLFs = np.zeros((n_test, 2, 4, 100))
[8]:
for sample_idx in range(len(test_params)):
    # Optical depth from the emulated neutral-fraction history of this sample.
    analytic_tau[sample_idx] = p21.wrapper.compute_tau(
        redshifts=zs,
        global_xHI=xHI_emu[sample_idx, :],
        cosmo_params=COSMO_PARAMS,
        user_params=USER_PARAMS,
    )

    # Name -> value (21cmFAST units) for this sample's astrophysical parameters.
    astro_params_dict = {
        key: input_in_p21_units[sample_idx, col]
        for col, key in enumerate(astro_param_keys)
    }

    lf_result = np.array(
        p21.wrapper.compute_luminosity_function(
            redshifts=[6, 7, 8, 10],
            astro_params=astro_params_dict,
            cosmo_params=COSMO_PARAMS,
            user_params=USER_PARAMS,
            flag_options=FLAG_OPTIONS,
        )
    )
    # The raw result stacks (magnitude, halo mass, LF); keep rows 0 and 2 only.
    analytic_UVLFs[sample_idx, ...] = lf_result[[0, 2], :, :]

    if sample_idx == 0:
        print(
            "Output from uvlfs calculation has shape: \n",
            "3 (magnitude, halo mass, LF) %d redshift bins, and %d absolute magnitude bins."
            % analytic_UVLFs[sample_idx].shape[1:],
        )
/home/dani/anaconda3/envs/21cmEMU/lib/python3.9/site-packages/py21cmfast/_utils.py:400: UserWarning: The following parameters to FlagOptions are not supported: ['USE_VELS_AUX']
  warnings.warn(
Output from uvlfs calculation has shape:
 3 (magnitude, halo mass, LF) 4 redshift bins, and 100 absolute magnitude bins.

Plot Analytic 21cmFAST vs 21cmEMU

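We define a helper function that overlays the analytic and emulated curves. It also shows their absolute difference, together with its median and 68% / 95% bands over all samples.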

[14]:
def plot_true_vs_emu(
    x_emu,
    y_emu,
    x_true,
    y_true,
    x_label,
    y_label,
    title=None,
    xlims=None,
    leg_loc=(0.5, 0.5),
    N=10,
    offset=0,
    cs=None,
):
    """Plot analytic (21cmFAST) vs. emulated curves and their absolute difference.

    The top panel overlays up to ``N`` analytic curves (solid) with the matching
    emulated curves (dash-dot). The bottom panel shows the per-sample absolute
    difference on the emulator grid ``x_emu``. It also shows the median and the
    68% / 95% percentile bands of that difference over *all* samples.

    Side effect: updates the global matplotlib ``rcParams`` (font size 30,
    white figure background).

    Parameters
    ----------
    x_emu : 1D array
        Common x grid of the emulated curves.
    y_emu : 2D array, shape (n_samples, len(x_emu))
        Emulated curves, one row per sample.
    x_true, y_true : 2D arrays, one row per sample
        x and y values of the analytic curves. Each row is linearly
        interpolated (with extrapolation) onto ``x_emu`` before differencing.
    x_label, y_label : str
        Axis labels.
    title : str, optional
        Title of the top panel.
    xlims : (float, float), optional
        x-axis limits; defaults to (-20, -10).
    leg_loc : legend location for the top panel.
    N : int
        Maximum number of individual samples to draw.
    offset : int
        Index of the first sample to draw.
    cs : list of colours, optional
        One colour per drawn sample; at most ``len(cs)`` samples are drawn.
    """
    rcParams.update({"font.size": 30})
    rcParams.update({"figure.facecolor": "white"})
    if cs is None:
        cs = [
            "k",
            "lime",
            "b",
            "orange",
            "cyan",
            "magenta",
            "grey",
            "pink",
            "darkred",
            "coral",
        ]
    if xlims is None:
        xlims = (-20, -10)

    # Put every analytic curve on the emulator grid so both can be differenced bin by bin.
    y_true_interp = np.zeros(y_emu.shape)
    for i in range(y_emu.shape[0]):
        y_true_interp[i] = interp1d(x_true[i], y_true[i], fill_value="extrapolate")(
            x_emu
        )
    y_diff = np.abs(y_true_interp - y_emu)
    # NaN-aware percentiles over samples: [2.5, 16, 50, 84, 97.5].
    diff_err_z = np.nanpercentile(y_diff, [2.5, 16, 50, 84, 97.5], axis=0)

    fig, (ax_top, ax_bot) = plt.subplots(
        nrows=2,
        ncols=1,
        sharex=True,
        figsize=(14, 12),
        gridspec_kw=dict(height_ratios=[3, 2], hspace=0),
    )
    if title is not None:
        ax_top.set_title(title)

    # Draw at most one curve per colour and never index past the last sample.
    # Plain ``zip``: its ``strict`` keyword only exists on Python >= 3.10.
    n_curves = max(0, min(N, len(cs), y_emu.shape[0] - offset))
    for i, c in zip(range(n_curves), cs):
        # Label a single pair of curves so each line style appears once in the legend.
        if i == n_curves - 1:
            labels = ["21cmEMU", "21cmFAST"]
        else:
            labels = [None, None]
        idx = i + offset
        ax_top.plot(x_true[idx, :], y_true[idx, :], lw=3, color=c, label=labels[1])
        ax_bot.plot(x_emu, y_diff[idx, :], ls="-.", alpha=0.5, lw=2, color=c)
        ax_top.plot(x_emu, y_emu[idx, :], lw=2, ls="-.", color=c, label=labels[0])

    ax_top.legend(loc=leg_loc, frameon=False)
    (median_line,) = ax_bot.plot(
        x_emu, diff_err_z[2, :], ls="--", lw=3, color="k", label=r"Median"
    )
    ax_bot.fill_between(
        x_emu, diff_err_z[1, :], diff_err_z[3, :], color="k", alpha=0.2, label=r"68% CI"
    )
    ax_bot.fill_between(
        x_emu, diff_err_z[0, :], diff_err_z[4, :], color="k", alpha=0.1, label=r"95% CI"
    )

    # Explicit patches: the 68% band is drawn over the 95% band, so its visible alpha is higher.
    handles = [
        median_line,
        mpatches.Patch(color="k", label="68% CI", alpha=0.3),
        mpatches.Patch(color="k", label="95% CI", alpha=0.1),
    ]
    ax_bot.legend(handles=handles, loc=(0.2, 0.5), frameon=False)
    ax_top.set_ylabel(y_label)
    ax_bot.set_ylabel(r"Abs Diff")
    ax_bot.set_xlabel(x_label)
    # x is shared between the panels, so this sets the limits of both.
    ax_bot.set_xlim(*xlims)
    fig.tight_layout()
    plt.show()
[15]:
# One comparison figure per emulated UV LF redshift bin.
for z_idx in range(4):
    z_title = "z = %d" % output.UVLF_redshifts[z_idx]
    plot_true_vs_emu(
        MUV_emu,
        UVLFs_emu[:, z_idx, :],
        analytic_UVLFs[:, 0, z_idx, :],  # analytic M_UV grid
        analytic_UVLFs[:, 1, z_idx, :],  # analytic LF
        r"$M_{UV}$",
        r"log$_{10}\:\phi$ [Mpc$^{-3}$]",
        leg_loc=(0.3, 0.01),
        title=z_title,
    )
../_images/tutorials_21cmFAST_tau_UVLFs_13_0.png
../_images/tutorials_21cmFAST_tau_UVLFs_13_1.png
../_images/tutorials_21cmFAST_tau_UVLFs_13_2.png
../_images/tutorials_21cmFAST_tau_UVLFs_13_3.png
[18]:
rcParams.update({"font.size": 20})
rcParams.update({"figure.facecolor": "white"})

fig, ax = plt.subplots()
ax.scatter(analytic_tau, tau_emu, marker=".", color="k")
# The 1:1 reference line must cover both data sets; stopping at the largest
# analytic value would leave emulated points above it without a reference.
tau_max = max(np.max(analytic_tau), np.max(tau_emu))
one_to_one = np.linspace(0, tau_max, 10)
ax.plot(one_to_one, one_to_one, color="r", ls="--")
ax.set_xlabel(r"Analytic $\tau_e$")
ax.set_ylabel(r"Emulated $\tau_e$")
ax.set_title(r"$\tau_e$: 21cmFAST vs 21cmEMU")
fig.tight_layout()
plt.show()
../_images/tutorials_21cmFAST_tau_UVLFs_14_0.png
[ ]: