# Makes print and division act like Python 3
from __future__ import print_function, division
# Import the usual libraries
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from copy import deepcopy
#from .nrc_utils import S, stellar_spectrum, jupiter_spec, cond_table, cond_filter
#from .nrc_utils import read_filter, bp_2mass, channel_select, coron_ap_locs
#from .nrc_utils import dist_image, pad_or_cut_to_size
from .nrc_utils import *
from .obs_nircam import obs_hci
#from .obs_nircam import plot_contrasts, plot_contrasts_mjup, planet_mags, plot_planet_patches
from tqdm.auto import tqdm, trange
import logging
_log = logging.getLogger('nb_funcs')
import pynrc
pynrc.setup_logging('WARN', verbose=False)
"""
Common functions for notebook simulations and plotting.
This is my attempt to standardize these routines over
all the various GTO programs.
"""
# Observation Definitions
# Functions to create and optimize a series of observation objects stored as a dictionary.
bp_k = bp_2mass('k')
[docs]def make_key(filter, pupil=None, mask=None):
"""Create identification key (string) based on filter, pupil, and mask"""
mask_key = 'none' if mask is None else mask
pupil_key = 'none' if pupil is None else pupil
key = '{}_{}_{}'.format(filter,mask_key,pupil_key)
return key
# Disk Models
[docs]def model_info(source, filt, dist, model_dir=''):
# base_dir = '/Volumes/NIRData/Andras_models_v2/'
# model_dir = base_dir + source + '/'
# Match filters with model
filt_switch = {'F182M':'F210M', 'F210M':'F210M', 'F250M':'F250M',
'F300M':'F300M', 'F335M':'F335M', 'F444W':'F444W'}
filt_model = filt_switch.get(filt, filt)
fname = source + '_' + filt_model +'sc.fits'
bp = read_filter(filt_model)
w0 = bp.avgwave() / 1e4
# Model pixels are 4x oversampled
detscale = (channel_select(bp))[0]
model_scale = detscale / 4.
# File name, arcsec/pix, dist (pc), wavelength (um), flux units, cen_star?
model_dict = {
'file' : model_dir+fname,
'pixscale' : model_scale,
'dist' : dist,
'wavelength' : w0,
'units' : 'Jy/pixel',
'cen_star' : True
}
# args_model = (model_dir+fname, model_scale, dist, w0, 'Jy/pixel', True)
return model_dict
[docs]def disk_rim_model(a_asec, b_asec, pa=0, sig_asec=0.1, flux_frac=0.5,
flux_tot=1.0, flux_units='mJy', wave_um=None, dist_pc=None,
pixsize=0.007, fov_pix=401):
"""
Simple geometric model of an inner disk rim that simply creates an
ellipsoidal ring with a brightness gradient along the major axis.
Parameters
----------
a_asec : float
Semi-major axis of ellipse
ba_asec : float
Semi-minor axis of ellipse
Keyword Args
------------
pa : float
Position angle of major axis
sig_asec : float
Sigma width of ring model
flux_frac : float
A brightness gradient can be applied along the semi-major axis.
This parameter dictates the relative brightness of the minimum flux
(at the center of the axis) compared to the flux at the out edge
of the geometric ring.
flux_tot : float
The total integrated flux of disk model.
flux_units : str
Units corresponding to `flux_tot`.
wave_um : float or None
Wavelength (in um) corresponding to `flux_tot`. Saved in output
FITS header unless the value is None.
dist_pc : float or None
Assumed distance of model (in pc). Saved in output FITS header
unless the value is None.
pixsize : float
Desired model pixel size in arcsec.
fov_pix : int
Number of pixels for x/y dimensions of output model data.
"""
from astropy.modeling.models import Ellipse2D
from astropy.convolution import Gaussian2DKernel, convolve_fft
from astropy.io import fits
# Get polar and cartesian pixel coordinate grid
sh = (fov_pix, fov_pix)
r_pix, th_ang = dist_image(np.ones(sh), return_theta=True)
x_pix, y_pix = rtheta_to_xy(r_pix, th_ang)
# In terms of arcsec
x_asec = pixsize * x_pix
y_asec = pixsize * y_pix
r_asec = pixsize * r_pix
# Semi major/minor axes (pix)
a_pix = a_asec / pixsize
b_pix = b_asec / pixsize
# Create ellipse functions
e1 = Ellipse2D(theta=0, a=a_pix+1, b=b_pix+1)
e2 = Ellipse2D(theta=0, a=a_pix-1, b=b_pix-1)
# Make the two ellipse images and subtract
e1_im = e1(x_pix,y_pix)
e2_im = e2(x_pix,y_pix)
e_im = e1_im - e2_im
# Produce a brightness gradient along major axis
grad_im = (1-flux_frac) * np.abs(x_pix) / a_pix + flux_frac
e_im = e_im * grad_im
# Convolve image with Gaussian to simulate scattering
sig_pix = sig_asec / pixsize
kernel = Gaussian2DKernel(sig_pix)
e_im = convolve_fft(e_im, kernel)
# Rotate
th_deg = pa - 90.
e_im = rotate_offset(e_im, angle=-th_deg, order=3, reshape=False)
e_im = flux_tot * e_im / np.sum(e_im)
hdu = fits.PrimaryHDU(e_im)
hdu.header['PIXELSCL'] = (pixsize, "Pixel Scale (asec/pix)")
hdu.header['UNITS'] = "{}/pixel".format(flux_units)
if wave_um is not None:
hdu.header['WAVE'] = (wave_um, "Wavelength (microns)")
if dist_pc is not None:
hdu.header['DISTANCE'] = (dist_pc, "Distance (pc)")
return fits.HDUList([hdu])
[docs]def obs_wfe(wfe_ref_drift, filt_list, sp_sci, dist, sp_ref=None, args_disk=None,
wind_mode='WINDOW', subsize=None, fov_pix=None, verbose=False, narrow=False,
model_dir=None, large_grid=False, **kwargs):
"""
For a given WFE drift and series of filters, create a list of
NIRCam observations.
"""
if sp_ref is None: sp_ref = sp_sci
obs_dict = {}
for filt, mask, pupil in filt_list:
# Create identification key
key = make_key(filt, mask=mask, pupil=pupil)
print(key)
# Disk Model
if args_disk is None:
args_disk_temp = None
elif 'auto' in args_disk:
# Convert to photons/sec in specified filter
args_disk_temp = model_info(sp_sci.name, filt, dist, model_dir=model_dir)
else:
args_disk_temp = args_disk
fov_pix_orig = fov_pix
# Define the subarray readout size
if 'FULL' in wind_mode: # Full frame
subuse = 2048
# Define PSF pixel size defaults
if mask is None:
fov_pix = 400 if fov_pix is None else fov_pix
elif ('210R' in mask) or ('SWB' in mask):
fov_pix = 640 if fov_pix is None else fov_pix
else:
fov_pix = 320 if fov_pix is None else fov_pix
elif subsize is None: # Window Mode defaults
if mask is None: # Direct Imaging
subuse = 400
elif ('210R' in mask) or ('SWB' in mask): # SW Coronagraphy
subuse = 640
else: # LW Coronagraphy
subuse = 320
else: # No effect if full frame
subuse = subsize
# Define PSF pixel size
fov_pix = subuse if fov_pix is None else fov_pix
# Make sure fov_pix is odd for direct imaging
# if (mask is None) and (np.mod(fov_pix,2)==0):
# fov_pix += 1
if np.mod(fov_pix,2)==0:
fov_pix += 1
# Other coronagraph vs direct imaging settings
module, oversample = ('B', 4) if mask is None else ('A', 2)
if narrow and ('SWB' in mask):
bar_offset=-8
elif narrow and ('LWB' in mask):
bar_offset=8
else:
bar_offset=None
# Initialize and store the observation
# A reference observation is stored inside each parent obs_hci class.
obs = obs_hci(sp_sci, dist, sp_ref=sp_ref, filter=filt, image_mask=mask, pupil_mask=pupil,
module=module, wind_mode=wind_mode, xpix=subuse, ypix=subuse,
wfe_ref_drift=wfe_ref_drift, fov_pix=fov_pix, oversample=oversample,
disk_params=args_disk_temp, verbose=verbose, bar_offset=bar_offset,
autogen_coeffs=False, **kwargs)
obs.gen_psf_coeff()
# Enable WFE drift
obs.gen_wfedrift_coeff()
# Enable mask-dependent
obs.gen_wfemask_coeff(large_grid=large_grid)
obs_dict[key] = obs
fov_pix = fov_pix_orig
# if there's a disk input, then we want to remove disk
# contributions from stellar flux and recompute to make
# sure total flux counts matches what we computed for
# sp_sci in previous section to match real photometry
if args_disk is not None:
obs = obs_dict[key]
star_flux = obs.star_flux(sp=sp_sci) # Pass original input spectrum
disk_flux = obs.disk_hdulist[0].data.sum()
obs.sp_sci = sp_sci * (1 - disk_flux / star_flux)
obs.sp_sci.name = sp_sci.name
if sp_ref is sp_sci:
obs.sp_ref = obs.sp_sci
# Generation mask position dependent PSFs
for key in tqdm(obs_dict.keys(), desc='Obs', leave=False):
obs_dict[key].gen_disk_psfs()
return obs_dict
[docs]def obs_optimize(obs_dict, sp_opt=None, well_levels=None, tacq_max=1800, **kwargs):
"""
Perform ramp optimization on each science and reference observation
in a list of filter observations. Updates the detector MULTIACCUM
settings for each observation in the dictionary.
snr_goal = 5
snr_frac = 0.02
tacq_max = 1400
tacq_frac = 0.01
nint_min = 15
ng_max = 10
"""
# A very faint bg object on which to maximize S/N
# If sp_opt is not set, then default to a 20th magnitude flat source
if sp_opt is None:
sp_opt = stellar_spectrum('flat', 20, 'vegamag', bp_k)
# Some observations may saturate, so define a list of maximum well level
# values that we will incrementally check until a ramp setting is found
# that meets the contraints.
if well_levels is None:
well_levels = [0.8, 1.5, 3.0, 5.0, 10.0, 20.0, 100.0, 150.0, 300.0, 500.0]
filt_keys = list(obs_dict.keys())
filt_keys.sort()
print(['Pattern', 'NGRP', 'NINT', 't_int', 't_exp', 't_acq', 'SNR', 'Well', 'eff'])
for j, key in enumerate(filt_keys):
print('')
print(key)
obs = obs_dict[key]
obs_ref = obs.nrc_ref
sp_sci, sp_ref = (obs.sp_sci, obs.sp_ref)
# SW filter piggy-back on two LW filters, so 2 x tacq
is_SW = obs.bandpass.avgwave()/1e4 < 2.5
# Ramp optimization for both science and reference targets
for j, sp in enumerate([sp_sci, sp_ref]):
i = nrow = 0
while nrow==0:
well_max = well_levels[i]
tbl = obs.ramp_optimize(sp_opt, sp, well_frac_max=well_max, tacq_max=tacq_max, **kwargs)
nrow = len(tbl)
i+=1
# Grab the highest ranked MULTIACCUM settings and update the detector readout
v1, v2, v3 = tbl['Pattern', 'NGRP', 'NINT'][0]
vals = list(tbl[0])#.as_void()
strout = '{:10} {:4.0f} {:4.0f}'.format(vals[0], vals[1], vals[2])
for v in vals[3:]:
strout = strout + ', {:.4f}'.format(v)
print(strout)
# SW filter piggy-back on two LW filters, so 2 x tacq
# is_SW = obs.bandpass.avgwave()/1e4 < 2.5
# if is_SW:
# v3 *= 2
# Coronagraphic observations have two roll positions, so cut NINT by 2
if obs.image_mask is not None:
v3 = int(v3/2)
obs2 = obs if j==0 else obs_ref
obs2.update_detectors(read_mode=v1, ngroup=v2, nint=v3)
###########################################
# Functions to run a series of operations
###########################################
# Optimize observations
[docs]def do_opt(obs_dict, tacq_max=1800, **kwargs):
sp_opt = stellar_spectrum('flat', 20, 'vegamag', bp_k)
obs_optimize(obs_dict, sp_opt=sp_opt, tacq_max=tacq_max, **kwargs)
# For each filter setting, generate a series of contrast curves at different WFE values
[docs]def do_contrast(obs_dict, wfe_list, filt_keys, nsig=5, roll_angle=10, verbose=False, **kwargs):
"""
kwargs to pass to calc_contrast() and their defaults:
no_ref = False
func_std = robust.medabsdev
exclude_disk = True
exclude_planets = True
exclude_noise = False
opt_diff = True
fix_sat = False
ref_scale_all = False
"""
contrast_all = {}
for i in trange(len(filt_keys), desc='Observations'):
key = filt_keys[i]
obs = obs_dict[key]
if verbose:
print(key)
wfe_roll_temp = obs.wfe_roll_drift
wfe_ref_temp = obs.wfe_ref_drift
# Stores tuple of (Radial Distances, Contrast, and Sensitivity) for each WFE drift
curves = []
for wfe_drift in tqdm(wfe_list, leave=False, desc='WFE Drift'):
if ('no_ref' in list(kwargs.keys())) and (kwargs['no_ref']==True):
obs.wfe_roll_drift = wfe_drift
else:
obs.wfe_ref_drift = wfe_drift
result = obs.calc_contrast(roll_angle=roll_angle, nsig=nsig, **kwargs)
curves.append(result)
obs.wfe_roll_drift = wfe_roll_temp
obs.wfe_ref_drift = wfe_ref_temp
contrast_all[key] = curves
return contrast_all
[docs]def do_gen_hdus(obs_dict, filt_keys, wfe_ref_drift, wfe_roll_drift,
return_oversample=True, **kwargs):
"""
kwargs to pass to gen_roll_image() and their defaults:
PA1 = 0
PA2 = 10
zfact = None
return_oversample = True
exclude_disk = False
exclude_noise = False
no_ref = False
opt_diff = True
use_cmask = False
ref_scale_all = False
xyoff_roll1 = None
xyoff_roll2 = None
xyoff_ref = None
"""
hdulist_dict = {}
for key in tqdm(filt_keys):
# if verbose: print(key)
obs = obs_dict[key]
use_cmask = kwargs.pop('use_cmask', False)
hdulist = obs.gen_roll_image(return_oversample=return_oversample, use_cmask=use_cmask,
wfe_ref_drift=wfe_ref_drift, wfe_roll_drift=wfe_roll_drift, **kwargs)
hdulist_dict[key] = hdulist
return hdulist_dict
[docs]def do_sat_levels(obs, satval=0.95, ng_min=2, ng_max=None, verbose=True,
plot=True, xylim=2.5, return_fig_axes=False):
"""Only for obs.hci classes"""
ng_max = obs.det_info['ngroup'] if ng_max is None else ng_max
kw_gen_psf = {'return_oversample': False,'return_hdul': False}
# Well level of each pixel for science source
image = obs.calc_psf_from_coeff(sp=obs.sp_sci, **kw_gen_psf)
sci_levels1 = obs.saturation_levels(ngroup=ng_min, image=image)
sci_levels2 = obs.saturation_levels(ngroup=ng_max, image=image)
# Well level of each pixel for reference source
image = obs.calc_psf_from_coeff(sp=obs.sp_ref, **kw_gen_psf)
ref_levels1 = obs.saturation_levels(ngroup=ng_min, image=image, do_ref=True)
ref_levels2 = obs.saturation_levels(ngroup=ng_max, image=image, do_ref=True)
# Which pixels are saturated?
sci_mask1 = sci_levels1 > satval
sci_mask2 = sci_levels2 > satval
# Which pixels are saturated?
ref_mask1 = ref_levels1 > satval
ref_mask2 = ref_levels2 > satval
# How many saturated pixels?
nsat1_sci = len(sci_levels1[sci_mask1])
nsat2_sci = len(sci_levels2[sci_mask2])
# How many saturated pixels?
nsat1_ref = len(ref_levels1[ref_mask1])
nsat2_ref = len(ref_levels2[ref_mask2])
# Get saturation radius
if nsat1_sci == nsat1_ref == 0:
sat_rad = 0
else:
mask_temp = sci_mask1 if nsat1_sci>nsat1_ref else ref_mask1
rho_asec = dist_image(mask_temp, pixscale=obs.pix_scale)
sat_rad = rho_asec[mask_temp].max()
if verbose:
print('Sci: {}'.format(obs.sp_sci.name))
print(' {} saturated pixel at NGROUP={}; Max Well: {:.2f}'\
.format(nsat1_sci, ng_min, sci_levels1.max()))
print(' {} saturated pixel at NGROUP={}; Max Well: {:.2f}'\
.format(nsat2_sci, ng_max, sci_levels2.max()))
print(' Sat Dist NG={}: {:.2f} arcsec'.format(ng_min, sat_rad))
print('Ref: {}'.format(obs.sp_ref.name))
print(' {} saturated pixel at NGROUP={}; Max Well: {:.2f}'.\
format(nsat1_ref, ng_min, ref_levels1.max()))
print(' {} saturated pixel at NGROUP={}; Max Well: {:.2f}'.\
format(nsat2_ref, ng_max, ref_levels2.max()))
if (nsat2_sci==nsat2_ref==0) and (plot==True):
plot=False
print('Plotting turned off; no saturation detected.')
if plot:
fig, axes_all = plt.subplots(2,2, figsize=(8,8))
xlim = ylim = np.array([-1,1])*xylim
# Plot science source
nsat1, nsat2 = (nsat1_sci, nsat2_sci)
sat_mask1, sat_mask2 = (sci_mask1, sci_mask2)
sp = obs.sp_sci
xpix, ypix = (obs.det_info['xpix'], obs.det_info['ypix'])
bar_offpix = obs.bar_offset / obs.pixelscale
if ('FULL' in obs.det_info['wind_mode']) and (obs.image_mask is not None):
cdict = coron_ap_locs(obs.module, obs.channel, obs.image_mask, full=True)
xcen, ycen = cdict['cen_V23']
xcen += bar_offpix
else:
xcen, ycen = (xpix/2 + bar_offpix, ypix/2)
# rho = dist_image(sci_mask1, center=(xcen,ycen))
delx, dely = (xcen - xpix/2, ycen - ypix/2)
extent_pix = np.array([-xpix/2-delx,xpix/2-delx,-ypix/2-dely,ypix/2-dely])
extent = extent_pix * obs.pix_scale
axes = axes_all[0]
axes[0].imshow(sat_mask1, extent=extent)
axes[1].imshow(sat_mask2, extent=extent)
axes[0].set_title('{} Saturation (NGROUP=2)'.format(sp.name))
axes[1].set_title('{} Saturation (NGROUP={})'.format(sp.name,ng_max))
for ax in axes:
ax.set_xlabel('Arcsec')
ax.set_ylabel('Arcsec')
ax.tick_params(axis='both', color='white', which='both')
for k in ax.spines.keys():
ax.spines[k].set_color('white')
ax.set_xlim(xlim)
ax.set_ylim(ylim)
# Plot ref source sat mask
nsat1, nsat2 = (nsat1_ref, nsat2_ref)
sat_mask1, sat_mask2 = (ref_mask1, ref_mask2)
sp = obs.sp_ref
axes = axes_all[1]
axes[0].imshow(sat_mask1, extent=extent)
axes[1].imshow(sat_mask2, extent=extent)
axes[0].set_title('{} Saturation (NGROUP=2)'.format(sp.name))
axes[1].set_title('{} Saturation (NGROUP={})'.format(sp.name,ng_max))
for ax in axes:
ax.set_xlabel('Arcsec')
ax.set_ylabel('Arcsec')
ax.tick_params(axis='both', color='white', which='both')
for k in ax.spines.keys():
ax.spines[k].set_color('white')
ax.set_xlim(xlim)
ax.set_ylim(ylim)
fig.tight_layout()
if return_fig_axes and plot:
return (fig, axes), sat_rad
else:
return sat_rad
###########################################
# Simulated Data
###########################################
[docs]def average_slopes(hdulist):
"""
For a series of ramps, calculate the slope images then average together.
"""
ramps = hdulist[1].data
header = hdulist[0].header
slopes_fin = []
for i in range(len(ramps)):
data = ramps[i]
# Create time array
ng, ypix, xpix = data.shape
tvals = (np.arange(ng)+1) * header['TGROUP']
# Flatten image space to 1D
data = data.reshape([ng,-1])
# Make saturation mask
sat_val = 0.95*data.max()
sat_mask = data > sat_val
# Create slope images
# Cycle through groups using only unsaturated pixels
im_slope = np.zeros_like(data[0]) - 10
for i in np.arange(1,ng)[::-1]:
ind = (im_slope==-10) & (~sat_mask[i])
if np.any(ind): # Check if any pixels are still True
im_slope[ind] = jl_poly_fit(tvals, data[:,ind])[1]
#print(im_slope[ind].shape)
# Special case of only first frame unsaturated
ind = (im_slope==-10) & (~sat_mask[0])
im_slope[ind] = data[:,ind] / tvals[0]
#print(im_slope[ind].shape)
# If saturated on first frame, set to NaN
ind = sat_mask[0]
im_slope[ind] = np.nan
#print(im_slope[ind].shape)
data = data.reshape([ng,ypix,xpix])
im_slope = im_slope.reshape([ypix,xpix])
slopes_fin.append(im_slope)
# Average slopes together
# us nanmean() to ignore those with NaNs
slopes_fin = np.array(slopes_fin)
slope_final = np.nanmean(slopes_fin, axis=0)
return slope_final
###########################################
# Plotting images and contrast curves
###########################################
[docs]def plot_contrasts_mjup(curves, nsig, wfe_list, obs=None, sat_rad=None, age=100,
ax=None, colors=None, xr=[0,10], yr=None, file=None, linder_models=True,
twin_ax=False, return_axes=False, **kwargs):
"""Plot mass contrast curves
Plot a series of mass contrast curves for corresponding WFE drifts.
Parameters
----------
curves : list
A list with length corresponding to `wfe_list`. Each list element
has three arrays in a tuple: the radius in arcsec, n-sigma contrast,
and n-sigma sensitivity limit (vega mag).
nsig : float
N-sigma limit corresponding to sensitivities/contrasts.
wfe_list : array-like
List of WFE drift values corresponding to each set of sensitivities
in `curves` argument.
Keyword Args
------------
obs : :class:`obs_hci`
Corresponding observation class that created the contrast curves.
Uses distances and stellar magnitude to plot contrast and AU
distances on opposing axes. Also necessary for mjup=True.
sat_rad : float
Saturation radius in arcsec. If >0, then that part of the contrast
curve is excluded from the plot
age : float
Required for plotting limiting planet masses.
file : string
Location and name of COND or Linder isochrone file.
ax : matplotlib.axes
Axes on which to plot curves.
colors : None, array-like
List of colors for contrast curves. Default is gradient of blues.
twin_ax : bool
Plot opposing axes in alternate units.
return_axes : bool
Return the matplotlib axes to continue plotting. If `obs` is set,
then this returns three sets of axes.
"""
if sat_rad is None:
sat_rad = 0
if ax is None:
fig, ax = plt.subplots()
if colors is None:
lin_vals = np.linspace(0.2,0.8,len(wfe_list))
colors = plt.cm.Blues_r(lin_vals)
filt = obs.filter
mod = obs.module
dist = obs.distance
if linder_models:
# Grab Linder model data
tbl = linder_table(file=file)
mass_data, mag_data = linder_filter(tbl, filt, age, dist=dist)
else:
# Grab COND model data
tbl = cond_table(age=age, file=file)
mass_data, mag_data = cond_filter(tbl, filt, module=mod, dist=dist)
# Plot the data
isort = np.argsort(mag_data)
for j, wfe_ref_drift in enumerate(wfe_list):
rr, contrast, mag_sens = curves[j]
label='$\Delta$' + "WFE = {} nm".format(wfe_list[j])
# Interpolate in log space
xv, yv = mag_data[isort], np.log10(mass_data[isort])
xint = mag_sens
yint = np.interp(xint, xv, yv)
# Choose the lowest mass value brighter than the given mag limits
yvals = np.array([np.min(yint[xint<=xv]) for xv in xint])
yvals = 10**yvals
xvals = rr[rr>sat_rad]
yvals = yvals[rr>sat_rad]
ax.plot(xvals, yvals, label=label, color=colors[j], zorder=1, lw=2)
if xr is not None: ax.set_xlim(xr)
if yr is not None: ax.set_ylim(yr)
ax.xaxis.get_major_locator().set_params(nbins=10, steps=[1, 2, 5, 10])
ax.yaxis.get_major_locator().set_params(nbins=10, steps=[1, 2, 5, 10])
ylabel = 'Mass Limits ($M_{\mathrm{Jup}}$)'
ax.set_ylabel(ylabel)
ax.set_xlabel('Separation (arcsec)')
if twin_ax:
# Plot opposing axes in alternate units
yr2 = np.array(ax.get_ylim()) * 318.0 # Convert to Earth masses
ax2 = ax.twinx()
ax2.set_ylim(yr2)
ax2.set_ylabel('Earth Masses')
ax3 = ax.twiny()
xr3 = np.array(ax.get_xlim()) * obs.distance
ax3.set_xlim(xr3)
ax3.set_xlabel('Separation (AU)')
ax3.xaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
if return_axes:
return (ax, ax2, ax3)
else:
if return_axes:
return ax
[docs]def plot_contrasts(curves, nsig, wfe_list, obs=None, sat_rad=None, ax=None,
colors=None, xr=[0,10], yr=[25,5], return_axes=False):
"""Plot contrast curves
Plot a series of contrast curves for corresponding WFE drifts.
Parameters
----------
curves : list
A list with length corresponding to `wfe_list`. Each list element
has three arrays in a tuple: the radius in arcsec, n-sigma contrast,
and n-sigma sensitivity limit (vega mag).
nsig : float
N-sigma limit corresponding to sensitivities/contrasts.
wfe_list : array-like
List of WFE drift values corresponding to each set of sensitivities
in `curves` argument.
Keyword Args
------------
obs : :class:`obs_hci`
Corresponding observation class that created the contrast curves.
Uses distances and stellar magnitude to plot contrast and AU
distances on opposing axes.
sat_rad : float
Saturation radius in arcsec. If >0, then that part of the contrast
curve is excluded from the plot
ax : matplotlib.axes
Axes on which to plot curves.
colors : None, array-like
List of colors for contrast curves. Default is gradient of blues.
return_axes : bool
Return the matplotlib axes to continue plotting. If `obs` is set,
then this returns three sets of axes.
"""
if sat_rad is None:
sat_rad = 0
if ax is None:
fig, ax = plt.subplots()
if colors is None:
lin_vals = np.linspace(0.3,0.8,len(wfe_list))
colors = plt.cm.Blues_r(lin_vals)
for j in range(len(wfe_list)): #for j, wfe_ref_drift in enumerate(wfe_list):
rr, contrast, mag_sens = curves[j]
xvals = rr[rr>sat_rad]
yvals = mag_sens[rr>sat_rad]
label='$\Delta$' + "WFE = {} nm".format(wfe_list[j])
ax.plot(xvals, yvals, label=label, color=colors[j], zorder=1, lw=2)
if xr is not None: ax.set_xlim(xr)
if yr is not None: ax.set_ylim(yr)
ax.xaxis.get_major_locator().set_params(nbins=10, steps=[1, 2, 5, 10])
ax.yaxis.get_major_locator().set_params(nbins=10, steps=[1, 2, 5, 10])
ax.set_ylabel('{:.0f}-$\sigma$ Sensitivities (mag)'.format(nsig))
ax.set_xlabel('Separation (arcsec)')
# Plot opposing axes in alternate units
if obs is not None:
yr1 = np.array(ax.get_ylim())
yr2 = 10**((obs.star_flux('vegamag') - yr1) / 2.5)
ax2 = ax.twinx()
ax2.set_yscale('log')
ax2.set_ylim(yr2)
ax2.set_ylabel('{:.0f}-$\sigma$ Contrast'.format(nsig))
ax3 = ax.twiny()
xr3 = np.array(ax.get_xlim()) * obs.distance
ax3.set_xlim(xr3)
ax3.set_xlabel('Separation (AU)')
ax3.xaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
if return_axes:
return (ax, ax2, ax3)
else:
if return_axes:
return ax
[docs]def planet_mags(obs, age=10, entropy=13, mass_list=[10,5,2,1], av_vals=[0,25], atmo='hy3s',
cond=False, linder=False, **kwargs):
"""Exoplanet Magnitudes
Determine a series of exoplanet magnitudes for given observation.
By default, use Spiegel & Burrows 2012 models, but has the option
to use the COND models from https://phoenix.ens-lyon.fr/Grids.
These are useful because SB12 model grids only ranges from 1-1000 Myr
with masses 1-15 MJup.
cond : bool
Instead of plotting sensitivities, use COND models to plot the
limiting planet masses.
linder : bool
Instead of plotting sensitivities, use Linder models to plot the
limiting planet masses.
file : string
Location and name of COND or Linder file.
"""
if av_vals is None:
av_vals = [0,0]
pmag = {}
for i,m in enumerate(mass_list):
flux_list = []
for j,av in enumerate(av_vals):
sp = obs.planet_spec(mass=m, age=age, Av=av, entropy=entropy, atmo=atmo, **kwargs)
sp_obs = S.Observation(sp, obs.bandpass, binset=obs.bandpass.wave)
flux = sp_obs.effstim('vegamag')
flux_list.append(flux)
pmag[m] = tuple(flux_list)
# Do COND models instead
# But still want SB12 models to get A_V information
if cond or linder:
# All mass and mag data for specified filter
filt = obs.filter
mod = obs.module
dist = obs.distance
if linder:
tbl = linder_table(**kwargs)
mass_data, mag_data = linder_filter(tbl, filt, age, dist=dist, **kwargs)
else:
# Grab COND model data
tbl = cond_table(age=age, **kwargs)
mass_data, mag_data = cond_filter(tbl, filt, module=mod, dist=dist, **kwargs)
# Mag information for the requested masses
isort = np.argsort(mass_data)
xv, yv = np.log10(mass_data[isort]), mag_data[isort]
mags0 = np.interp(np.log10(mass_list), np.log10(mass_data[isort]), mag_data[isort])
# Apply extinction
for i, m in enumerate(mass_list):
if np.allclose(av_vals, 0):
dm = np.array([0,0])
else:
#SB12 at A_V=0
sp = obs.planet_spec(mass=m, age=age, Av=0, entropy=entropy, atmo=atmo, **kwargs)
sp_obs = S.Observation(sp, obs.bandpass, binset=obs.bandpass.wave)
sb12_mag = sp_obs.effstim('vegamag')
# Get magnitude offset due to extinction
dm = np.array(pmag[m]) - sb12_mag
dm2 = pmag[m][1] - sb12_mag
# Apply extinction to COND models
pmag[m] = tuple(mags0[i] + dm)
return pmag
[docs]def plot_planet_patches(ax, obs, age=10, entropy=13, mass_list=[10,5,2,1], av_vals=[0,25],
cols=None, update_title=False, linder=False, **kwargs):
"""Plot exoplanet magnitudes in region corresponding to extinction values."""
import matplotlib.patches as mpatches
# Don't plot anything if
if mass_list is None:
_log.info("mass_list=None; Not plotting planet patch locations.")
return
xlim = ax.get_xlim()
#lin_vals = np.linspace(0,0.5,4)
#cols = plt.cm.Purples_r(lin_vals)[::-1]
if cols is None:
cols = plt.cm.tab10(np.linspace(0,1,10))
dist = obs.distance
if entropy<8: entropy=8
if entropy>13: entropy=13
pmag = planet_mags(obs, age, entropy, mass_list, av_vals, linder=linder, **kwargs)
for i,m in enumerate(mass_list):
label = 'Mass = {} '.format(m) + '$M_{\mathrm{Jup}}$'
if av_vals is None:
ax.plot(xlim, pmag[m], color=cols[i], lw=1, ls='--', label=label)
else:
pm_min, pm_max = pmag[m]
rect = mpatches.Rectangle((xlim[0], pm_min), xlim[1], pm_max-pm_min,
alpha=0.2, color=cols[i], label=label, zorder=2)
ax.add_patch(rect)
ax.plot(xlim, [pm_min]*2, color=cols[i], lw=1, alpha=0.3)
ax.plot(xlim, [pm_max]*2, color=cols[i], lw=1, alpha=0.3)
entropy_switch = {13:'Hot', 8:'Cold'}
entropy_string = entropy_switch.get(entropy, "Warm")
ent_str = 'BEX Models' if linder else '{} Start'.format(entropy_string)
if av_vals is None:
av_str = ''
else:
av_str = ' ($A_V = [{:.0f},{:.0f}]$)'.format(av_vals[0],av_vals[1])
#age_str = 'Age = {:.0f} Myr; '.format(age)
#dist_str = 'Dist = {:.1f} pc; '.format(dist) if dist is not None else ''
#dist_str=""
#ax.set_title('{} -- {} ({}{}{})'.format(obs.filter,ent_str,age_str,dist_str,av_str))
if update_title:
ax.set_title('{} -- {}{}'.format(obs.filter,ent_str,av_str))
[docs]def plot_hdulist(hdulist, ext=0, xr=None, yr=None, ax=None, return_ax=False,
cmap=None, scale='linear', vmin=None, vmax=None, axes_color='white',
half_pix_shift=False, cb_label='Counts/sec', **kwargs):
from webbpsf import display_psf
if ax is None:
fig, ax = plt.subplots()
if cmap is None:
cmap = matplotlib.rcParams['image.cmap']
# This has to do with even/odd number of pixels in array.
# Usually everything is centered in the middle of a pixel
# and for odd array sizes that is where (0,0) will be plotted.
# However, even array sizes will have (0,0) at the pixel border,
# so this just shifts the entire image accordingly.
if half_pix_shift:
oversamp = hdulist[ext].header['OSAMP']
shft = 0.5*oversamp
hdul = deepcopy(hdulist)
hdul[0].data = fshift(hdul[0].data, shft, shft)
else:
hdul = hdulist
data = hdul[ext].data
if vmax is None:
vmax = 0.75 * np.nanmax(data) if scale=='linear' else np.nanmax(data)
if vmin is None:
vmin = 0 if scale=='linear' else vmax/1e6
out = display_psf(hdul, ext=ext, ax=ax, title='', cmap=cmap,
scale=scale, vmin=vmin, vmax=vmax, return_ax=True, **kwargs)
try:
ax, cb = out
cb.set_label(cb_label)
except:
ax = out
ax.set_xlim(xr)
ax.set_ylim(yr)
ax.set_xlabel('Arcsec')
ax.set_ylabel('Arcsec')
ax.tick_params(axis='both', color=axes_color, which='both')
for k in ax.spines.keys():
ax.spines[k].set_color(axes_color)
ax.xaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
ax.yaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
if return_ax:
return ax
###########################################
# Plotting images and contrast curves
###########################################
[docs]def update_yscale(ax, scale_type, ylim=None):
# Some fancy log+linear plotting
from matplotlib.ticker import FixedLocator, ScalarFormatter, LogFormatterSciNotation
if scale_type=='symlog':
ylim = [0,100] if ylim is None else ylim
ax.set_ylim(ylim)
yr = ax.get_ylim()
ax.set_yscale('symlog', linthreshy=10, linscaley=2)
ax.set_yticks(list(range(0,10)) + [10,100,1000])
#ax.get_yaxis().set_major_formatter(ScalarFormatter())
ax.yaxis.set_major_formatter(ScalarFormatter())
minor_log = list(np.arange(20,100,10)) + list(np.arange(200,1000,100))
minorLocator = FixedLocator(minor_log)
ax.yaxis.set_minor_locator(minorLocator)
ax.set_ylim([0,yr[1]])
elif scale_type=='log':
ax.set_yscale('log')
ylim = [0.1,100] if ylim is None else ylim
ax.set_ylim(ylim)
ax.yaxis.set_major_formatter(LogFormatterSciNotation())
elif 'lin' in scale_type:
ax.set_yscale('linear')
ylim = [0,100] if ylim is None else ylim
ax.set_ylim(ylim)
[docs]def do_plot_contrasts(curves_ref, curves_roll, nsig, wfe_list, obs, age, age2=None,
sat_rad=0, jup_mag=True, xr=[0,10], yr=[22,8], xr2=[0,10], yscale2='log', yr2=None,
save_fig=False, outdir='', return_fig_axes=False, **kwargs):
"""
Plot series of contrast curves.
"""
if (curves_ref is None) and (curves_roll is None):
_log.warning('Both curves set no none. Returning...')
return
lin_vals = np.linspace(0.2,0.8,len(wfe_list))
c1 = plt.cm.Blues_r(lin_vals)
c2 = plt.cm.Reds_r(lin_vals)
c3 = plt.cm.Purples_r(lin_vals)
c4 = plt.cm.Greens_r(lin_vals)
fig, axes = plt.subplots(1,2, figsize=(14,4.5))
ax = axes[0]
if curves_ref is not None:
ax1, ax2, ax3 = plot_contrasts(curves_ref, nsig, wfe_list,
obs=obs, ax=ax, colors=c1, xr=xr, yr=yr, return_axes=True)
if curves_roll is not None:
obs_kw = None if curves_ref is not None else obs
axes2 = plot_contrasts(curves_roll, nsig, wfe_list,
obs=obs_kw, ax=ax, colors=c2, xr=xr, yr=yr, return_axes=True)
if curves_ref is None:
ax1, ax2, ax3 = axes2
axes1_all = [ax1, ax2, ax3]
#plot_planet_patches(ax, obs, age=age, av_vals=None, cond=True)
#ax.set_ylim([22,8])
# Legend organization
nwfe = len(wfe_list)
if curves_ref is None:
ax.legend(loc='upper right', title='Roll Sub')
elif curves_roll is None:
ax.legend(loc='upper right', title='Ref Sub')
else:
handles, labels = ax.get_legend_handles_labels()
h1 = handles[0:nwfe][::-1]
h2 = handles[nwfe:][::-1]
h1_t = [mpatches.Patch(color='none', label='Ref Sub')]
h2_t = [mpatches.Patch(color='none', label='Roll Sub')]
handles_new = h1_t + h1 + h2_t + h2
ax.legend(ncol=2, handles=handles_new, loc='upper right')
# Magnitude of Jupiter at object's distance
if jup_mag:
jspec = jupiter_spec(dist=obs.distance)
jobs = S.Observation(jspec, obs.bandpass, binset=obs.bandpass.wave)
jmag = jobs.effstim('vegamag')
if jmag<np.max(ax.get_ylim()):
ax.plot(xr, [jmag,jmag], color='C2', ls='--')
txt = 'Jupiter at {:.1f} pc'.format(obs.distance)
ax.text(xr[0]+0.02*(xr[1]-xr[0]), jmag, txt, horizontalalignment='left', verticalalignment='bottom')
# Plot in terms of Jupiter Masses
ax = axes[1]
age1 = age
if curves_ref is not None:
ax1, ax2, ax3 = plot_contrasts_mjup(curves_ref, nsig, wfe_list, obs=obs,
age=age1, ax=ax, colors=c1, xr=xr2, twin_ax=True, yr=None, return_axes=True)
if curves_roll is not None:
twin_kw = False if curves_ref is not None else True
axes2 = plot_contrasts_mjup(curves_roll, nsig, wfe_list, obs=obs,
age=age1, ax=ax, colors=c2, xr=xr2, twin_ax=twin_kw, yr=None, return_axes=True)
if curves_ref is None:
ax1, ax2, ax3 = axes2
axes2_all = [ax1, ax2, ax3]
if age2 is not None:
if curves_ref is not None:
plot_contrasts_mjup(curves_ref, nsig, wfe_list, obs=obs, age=age2, ax=ax, colors=c3, xr=xr2, yr=None)
if curves_roll is not None:
plot_contrasts_mjup(curves_roll, nsig, wfe_list, obs=obs, age=age2, ax=ax, colors=c4, xr=xr2, yr=None)
# Legend organization
handles, labels = ax.get_legend_handles_labels()
if curves_ref is None:
handles_new = [handles[i*nwfe] for i in range(2)]
labels_new = ['Roll Sub ({:.0f} Myr)'.format(age1),
'Roll Sub ({:.0f} Myr)'.format(age2)
]
elif curves_roll is None:
handles_new = [handles[i*nwfe] for i in range(2)]
labels_new = ['Ref Sub ({:.0f} Myr)'.format(age1),
'Ref Sub ({:.0f} Myr)'.format(age2)
]
else:
handles_new = [handles[i*nwfe] for i in range(4)]
labels_new = ['Ref Sub ({:.0f} Myr)'.format(age1),
'Roll Sub ({:.0f} Myr)'.format(age1),
'Ref Sub ({:.0f} Myr)'.format(age2),
'Roll Sub ({:.0f} Myr)'.format(age2),
]
else:
handles, labels = ax.get_legend_handles_labels()
if curves_ref is None:
handles_new = [handles[0]]
labels_new = ['Roll Sub ({:.0f} Myr)'.format(age1)]
elif curves_roll is None:
handles_new = [handles[0]]
labels_new = ['Ref Sub ({:.0f} Myr)'.format(age1)]
else:
handles_new = [handles[i*nwfe] for i in range(2)]
labels_new = ['Ref Sub ({:.0f} Myr)'.format(age1),
'Roll Sub ({:.0f} Myr)'.format(age1),
]
ax.legend(handles=handles_new, labels=labels_new, loc='upper right', title='COND Models')
# Update fancing y-axis scaling on right plot
update_yscale(ax, yscale2, ylim=yr2)
yr_temp = np.array(ax.get_ylim()) * 318.0
update_yscale(axes2_all[1], yscale2, ylim=yr_temp)
# Saturation regions
if sat_rad > 0:
sat_rad_asec = sat_rad
for ax in axes:
ylim = ax.get_ylim()
rect = mpatches.Rectangle((0, ylim[0]), sat_rad, ylim[1]-ylim[0], alpha=0.2, color='k', zorder=2)
ax.add_patch(rect)
name_sci = obs.sp_sci.name
name_ref = obs.sp_ref.name
if curves_ref is None:
title_str = '{} (dist = {:.1f} pc) -- {} Contrast Curves'\
.format(name_sci, obs.distance, obs.filter)
else:
title_str = '{} (dist = {:.1f} pc; PSF Ref: {}) -- {} Contrast Curves'\
.format(name_sci, obs.distance, name_ref, obs.filter)
fig.suptitle(title_str, fontsize=16)
fig.tight_layout()
fig.subplots_adjust(top=0.85, bottom=0.1 , left=0.05, right=0.95)
fname = "{}_contrast_{}.pdf".format(name_sci.replace(" ", ""), obs.image_mask)
if save_fig:
fig.savefig(outdir+fname)
if return_fig_axes:
return fig, (axes1_all, axes2_all)
[docs]def do_plot_contrasts2(key1, key2, curves_all, nsig, obs_dict, wfe_list, age, sat_dict=None,
label1='Curves1', label2='Curves2', xr=[0,10], yr=[24,8],
yscale2='log', yr2=None, av_vals=[0,10], curves_all2=None,
c1=None, c2=None, linder_models=True, planet_patches=True, **kwargs):
fig, axes = plt.subplots(1,2, figsize=(14,4.5))
lin_vals = np.linspace(0.2,0.8,len(wfe_list))
if c1 is None: c1 = plt.cm.Blues_r(lin_vals)
if c2 is None: c2 = plt.cm.Reds_r(lin_vals)
c3 = plt.cm.Purples_r(lin_vals)
c4 = plt.cm.Greens_r(lin_vals)
# Left plot (5-sigma sensitivities)
ax = axes[0]
k = key1
curves = curves_all[k]
obs = obs_dict[k]
sat_rad = None if sat_dict is None else sat_dict[k]
ax, ax2, ax3 = plot_contrasts(curves, nsig, wfe_list, obs=obs, sat_rad=sat_rad,
ax=ax, colors=c1, xr=xr, yr=yr, return_axes=True)
axes1_all = [ax, ax2, ax3]
if key2 is not None:
k = key2
curves = curves_all[k] if curves_all2 is None else curves_all2[k]
obs = None
sat_rad = None if sat_dict is None else sat_dict[k]
plot_contrasts(curves, nsig, wfe_list, obs=obs, sat_rad=sat_rad,
ax=ax, xr=xr, yr=yr, colors=c2)
# Planet mass locations
if planet_patches:
plot_planet_patches(ax, obs_dict[key1], age=age, update_title=True, av_vals=av_vals,
linder=linder_models, **kwargs)
ax.set_title('Flux Sensitivities')
# Right plot (Converted to MJup/MEarth)
ax = axes[1]
k = key1
curves = curves_all[k]
obs = obs_dict[k]
sat_rad = None if sat_dict is None else sat_dict[k]
ax, ax2, ax3 = plot_contrasts_mjup(curves, nsig, wfe_list, obs=obs, age=age, sat_rad=sat_rad,
ax=ax, colors=c1, xr=xr, twin_ax=True, return_axes=True,
linder_models=linder_models)
axes2_all = [ax, ax2, ax3]
if key2 is not None:
k = key2
curves = curves_all[k] if curves_all2 is None else curves_all2[k]
obs = obs_dict[k]
sat_rad = None if sat_dict is None else sat_dict[k]
plot_contrasts_mjup(curves, nsig, wfe_list, obs=obs, age=age, sat_rad=sat_rad,
ax=ax, colors=c2, xr=xr, linder_models=linder_models)
mod_str = 'BEX' if linder_models else 'COND'
ax.set_title('Mass Sensitivities -- {} Models'.format(mod_str))
# Update fancy y-axis scaling on right plot
ax = axes2_all[0]
update_yscale(ax, yscale2, ylim=yr2)
yr_temp = np.array(ax.get_ylim()) * 318.0
update_yscale(axes2_all[1], yscale2, ylim=yr_temp)
# Left legend
nwfe = len(wfe_list)
ax=axes[0]
handles, labels = ax.get_legend_handles_labels()
h1 = handles[0:nwfe][::-1]
h2 = handles[nwfe:2*nwfe][::-1]
h3 = handles[2*nwfe:]
h1_t = [mpatches.Patch(color='none', label=label1)]
h2_t = [mpatches.Patch(color='none', label=label2)]
h3_t = [mpatches.Patch(color='none', label='{} ({})'.format(mod_str, obs_dict[key1].filter))]
if planet_patches:
if key2 is not None:
handles_new = h1_t + h1 + h2_t + h2 + h3_t + h3
ncol = 3
else:
h3 = handles[nwfe:]
handles_new = h1_t + h1 + h3_t + h3
ncol = 2
else:
if key2 is not None:
handles_new = h1_t + h1 + h2_t + h2
ncol = 2
else:
handles_new = h1_t + h1
ncol = 1
ax.legend(ncol=ncol, handles=handles_new, loc=1, fontsize=9)
# Right legend
ax=axes[1]
handles, labels = ax.get_legend_handles_labels()
h1 = handles[0:nwfe][::-1]
h2 = handles[nwfe:2*nwfe][::-1]
h1_t = [mpatches.Patch(color='none', label=label1)]
h2_t = [mpatches.Patch(color='none', label=label2)]
if key2 is not None:
handles_new = h1_t + h1 + h2_t + h2
ncol = 2
else:
handles_new = h1_t + h1
ncol = 1
ax.legend(ncol=ncol, handles=handles_new, loc=1, fontsize=9)
# Title
name_sci = obs.sp_sci.name
dist = obs.distance
age_str = 'Age = {:.0f} Myr'.format(age)
dist_str = 'Distance = {:.1f} pc'.format(dist) if dist is not None else ''
title_str = '{} ({}, {})'.format(name_sci,age_str,dist_str)
fig.suptitle(title_str, fontsize=16);
fig.tight_layout()
fig.subplots_adjust(top=0.8, bottom=0.1 , left=0.05, right=0.95)
return (fig, (axes1_all, axes2_all))
[docs]def plot_images(obs_dict, hdu_dict, filt_keys, wfe_drift, fov=10,
save_fig=False, outdir='', return_fig_axes=False):
nfilt = len(filt_keys)
ext_name = ['Model', 'Sim Image (linear scale)', 'Sim Image ($r^2$ scale)']
nim = len(ext_name)
fig, axes = plt.subplots(nfilt, nim, figsize=(8.5,6.5))
#axes = axes.transpose()
for j, k in enumerate(filt_keys):
obs = obs_dict[k]
hdu_mod = obs.disk_hdulist
if hdu_mod is None:
raise ValueError('Disk model image is None. Did you forget to add the disk image?')
hdu_sim = hdu_dict[k]
data = hdu_sim[0].data
data -= np.nanmedian(data)
# Make r^2 scaled version of data
hdu_sim_r2 = deepcopy(hdu_sim)
data = hdu_sim_r2[0].data
data -= np.nanmedian(data)
header = hdu_sim_r2[0].header
rho = dist_image(data, pixscale=header['PIXELSCL'])
data *= rho**2
# Max value for model
data_mod = hdu_mod[0].data
header_mod = hdu_mod[0].header
# Scale to data pixelscale
data_mod = frebin(data_mod, scale=header_mod['PIXELSCL']/header['PIXELSCL'])
rho_mod = dist_image(data_mod, pixscale=header['PIXELSCL'])
data_mod_r2 = data_mod*rho_mod**2
vmax = np.max(data_mod)
vmax2 = np.max(data_mod_r2)
# Scale value for data
im_temp = pad_or_cut_to_size(data_mod, hdu_sim[0].data.shape)
mask_good = im_temp>(0.1*vmax)
scl1 = np.nanmedian(hdu_sim[0].data[mask_good] / im_temp[mask_good])
scl1 = np.abs(scl1)
# Scale value for r^2 version
im_temp = pad_or_cut_to_size(data_mod_r2, hdu_sim_r2[0].data.shape)
mask_good = im_temp>(0.1*vmax2)
scl2 = np.nanmedian(hdu_sim_r2[0].data[mask_good] / im_temp[mask_good])
scl2 = np.abs(scl2)
vmax_vals = [vmax, vmax*scl1, vmax2*scl2]
hdus = [hdu_mod, hdu_sim, hdu_sim_r2]
for i, ax in enumerate(axes[j]):
hdulist = hdus[i]
data = hdulist[0].data
header = hdulist[0].header
pixscale = header['PIXELSCL']
rho = dist_image(data, pixscale=pixscale)
rad = data.shape[0] * pixscale / 2
extent = [-rad, rad, -rad, rad]
ax.imshow(data, vmin=0, vmax=0.9*vmax_vals[i], extent=extent)
ax.set_aspect('equal')
if i > 0: ax.set_yticklabels([])
if j < nfilt-1: ax.set_xticklabels([])
if j==nfilt-1: ax.set_xlabel('Arcsec')
if j==0: ax.set_title(ext_name[i])
if i==0:
texp = obs.multiaccum_times['t_exp']
texp = round(2*texp/100)*100
exp_text = "{:.0f} sec".format(texp)
ax.set_ylabel('{} ({})'.format(obs.filter, exp_text))
xlim = [-fov/2,fov/2]
ylim = [-fov/2,fov/2]
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.xaxis.get_major_locator().set_params(nbins=10, steps=[1, 2, 5, 10])
ax.yaxis.get_major_locator().set_params(nbins=10, steps=[1, 2, 5, 10])
ax.tick_params(axis='both', color='white', which='both')
for k in ax.spines.keys():
ax.spines[k].set_color('white')
name_sci = obs.sp_sci.name
wfe_text = "WFE Drift = {} nm".format(wfe_drift)
fig.suptitle('{} ({})'.format(name_sci, wfe_text), fontsize=16);
fig.tight_layout()
fig.subplots_adjust(wspace=0.05, hspace=0.05, top=0.9, bottom=0.1)
#fig.subplots_adjust(wspace=0.1, hspace=0.1, top=0.9, bottom=0.07 , left=0.05, right=0.97)
fname = "{}_images_{}.pdf".format(name_sci.replace(" ", ""), obs.image_mask)
if save_fig:
fig.savefig(outdir+fname)
if return_fig_axes:
return fig, axes
[docs]def plot_images_swlw(obs_dict, hdu_dict, filt_keys, wfe_drift, fov=10,
save_fig=False, outdir='', return_fig_axes=False):
nfilt = len(filt_keys)
ext_name = ['Model', 'Sim Image (linear scale)', 'Sim Image ($r^2$ scale)']
nim = len(ext_name)
fig, axes = plt.subplots(nim, nfilt, figsize=(14,7.5))
axes = axes.transpose()
for j, k in enumerate(filt_keys):
obs = obs_dict[k]
hdu_mod = obs.disk_hdulist
if hdu_mod is None:
raise ValueError('Disk model image is None. Did you forget to add the disk image?')
hdu_sim = hdu_dict[k]
data = hdu_sim[0].data
data -= np.nanmedian(data)
# Make r^2 scaled version of data
hdu_sim_r2 = deepcopy(hdu_sim)
data = hdu_sim_r2[0].data
data -= np.nanmedian(data)
header = hdu_sim_r2[0].header
rho = dist_image(data, pixscale=header['PIXELSCL'])
data *= rho**2
# Max value for model
data_mod = hdu_mod[0].data
header_mod = hdu_mod[0].header
# Scale to data pixelscale
data_mod = frebin(data_mod, scale=header_mod['PIXELSCL']/header['PIXELSCL'])
rho_mod = dist_image(data_mod, pixscale=header['PIXELSCL'])
data_mod_r2 = data_mod*rho_mod**2
vmax = np.max(data_mod)
vmax2 = np.max(data_mod_r2)
# Scale value for data
im_temp = pad_or_cut_to_size(data_mod, hdu_sim[0].data.shape)
mask_good = im_temp>(0.1*vmax)
scl1 = np.nanmedian(hdu_sim[0].data[mask_good] / im_temp[mask_good])
scl1 = np.abs(scl1)
# Scale value for r^2 version
im_temp = pad_or_cut_to_size(data_mod_r2, hdu_sim_r2[0].data.shape)
mask_good = im_temp>(0.1*vmax2)
scl2 = np.nanmedian(hdu_sim_r2[0].data[mask_good] / im_temp[mask_good])
scl2 = np.abs(scl2)
vmax_vals = [vmax,vmax*scl1,vmax2*scl2]
hdus = [hdu_mod, hdu_sim, hdu_sim_r2]
for i, ax in enumerate(axes[j]):
hdulist = hdus[i]
data = hdulist[0].data
header = hdulist[0].header
pixscale = header['PIXELSCL']
rho = dist_image(data, pixscale=pixscale)
rad = data.shape[0] * pixscale / 2
extent = [-rad, rad, -rad, rad]
ax.imshow(data, vmin=0, vmax=0.9*vmax_vals[i], extent=extent)
ax.set_aspect('equal')
if j > 0: ax.set_yticklabels([])
if i < nim-1: ax.set_xticklabels([])
if i==nim-1: ax.set_xlabel('Arcsec')
if j==0: ax.set_ylabel(ext_name[i])
if i==0:
texp = obs.multiaccum_times['t_exp']
texp = round(2*texp/100)*100
exp_text = "{:.0f} sec".format(texp)
ax.set_title('{} ({})'.format(obs.filter, exp_text))
xlim = [-fov/2,fov/2]
ylim = [-fov/2,fov/2]
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.xaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
ax.yaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
if fov<=2*rad:
ax.tick_params(axis='both', color='white', which='both')
for k in ax.spines.keys():
ax.spines[k].set_color('white')
name_sci = obs.sp_sci.name
wfe_text = "WFE Drift = {} nm".format(wfe_drift)
fig.suptitle('{} ({})'.format(name_sci, wfe_text), fontsize=16);
fig.tight_layout()
fig.subplots_adjust(wspace=0.1, hspace=0.1, top=0.9, bottom=0.07 , left=0.05, right=0.97)
fname = "{}_images_{}.pdf".format(name_sci.replace(" ", ""), obs.image_mask)
if save_fig:
fig.savefig(outdir+fname)
if return_fig_axes:
return fig, axes