Source code for cosmos20_colors.data_loader
"""Module implements the `load_cosmos20` function"""
import os
import numpy as np
from astropy.table import Table
from jax import numpy as jnp
COSMOS20_BASENAME = "COSMOS2020_Farmer_processed_hlin.fits"
SKY_AREA = 1.21 # square degrees
NANFILL = -999.0
__all__ = ("load_cosmos20",)
[docs]
def load_cosmos20(
drn=None, bn=COSMOS20_BASENAME, apply_cuts=True, mag_lo=-100, mag_hi=5
):
"""Load the COSMOS-20 dataset from disk and calculate quality cuts
Parameters
----------
drn : string, optional
Absolute path to directory containing .fits file storing COSMOS-20 dataset
Default value is os.environ['COSMOS20_DRN'].
For bash users, add the following line to your `.bash_profile` in order to
configure the package to use your default dataset location:
export COSMOS20_DRN="/drn/storing/COSMOS20"
bn : string, optional
Absolute path to directory containing .fits file storing COSMOS-20 dataset
Default value is COSMOS20_BASENAME set at top of module
apply_cuts : bool, optional
If True, returned Table will have quality cuts imposed on the data
Default is True
mag_lo : int, optional
Smallest absolute magnitude in any band before galaxy is considered unphysical
mag_hi : int, optional
Largest absolute magnitude in any band before galaxy is considered unphysical
Returns
-------
cat : astropy.table.Table
Table of length ngals
Notes
-----
Quality cuts include lp_type=0 for the `galaxies` flag.
And for every Mag in the Le Phare absolute magnitudes,
we require mag_lo < Mag < mag_hi
"""
if drn is None:
drn = os.environ["COSMOS20_DRN"]
fn = os.path.join(drn, bn)
cat = Table.read(fn, format="fits", hdu=1)
if apply_cuts:
cat_out = Table()
cuts = []
sel_galaxies = np.array(cat["lp_type"] == 0).astype(bool)
cuts.append(sel_galaxies)
lp_keys = [key for key in cat.keys() if "lp_M" in key]
for key in lp_keys:
x = np.nan_to_num(
cat[key], copy=True, nan=NANFILL, posinf=NANFILL, neginf=NANFILL
)
key_finite_msk = np.isfinite(x == NANFILL)
cuts.append(key_finite_msk)
cuts.append(x > mag_lo)
cuts.append(x < mag_hi)
msk = np.prod(cuts, axis=0).astype(bool)
for key in cat.keys():
cat_out[key] = jnp.array(cat[key][msk])
return cat_out
else:
return cat
return cat