# Source code for webbpsf_ext.imreg_tools

import numpy as np
import os

import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange

from .image_manip import fourier_imshift, fshift, frebin
from .image_manip import get_im_cen, pad_or_cut_to_size, bp_fix
from .image_manip import apply_pixel_diffusion, add_ipc, add_ppc
from .image_manip import crop_observation, crop_image
from .coords import dist_image, get_sgd_offsets
from .maths import round_int

from astropy.io import fits
from skimage.registration import phase_cross_correlation

# Module-level NIRCam SIAF object, built once at import time.
# Functions in this module index it by aperture name (nrc_siaf[apname])
# and consult its list of known aperture names (nrc_siaf.apernames).
from .utils import get_one_siaf
nrc_siaf = get_one_siaf(instrument='NIRCam')

import logging
# Module logger; note that its level is forced to INFO on import
_log = logging.getLogger(__name__)
_log.setLevel(logging.INFO)

###########################################################################
#    File Information
###########################################################################

def get_detname(det_id, use_long=True):
    """Return NRC[A-B][1-4,LONG] for valid detector/SCA IDs

    Convenience wrapper that forwards both arguments unchanged to
    `get_detname` in this package's `utils` module. The import is
    performed at call time rather than at module import.

    Parameters
    ==========
    det_id : int or str
        Detector or SCA identifier (accepted forms are defined by
        `utils.get_detname`).
    use_long : bool
        Passed straight through to `utils.get_detname`.
    """

    from . import utils

    return utils.get_detname(det_id, use_long=use_long)

def get_mask_from_pps(apname_pps):
    """Extract the coronagraphic mask label from a PPS aperture name

    PPS aperture names look like:
        NRC[A/B][1-5]_[FULL]_[MASK]
    The first underscore-delimited piece containing "MASK" is taken
    as the mask label.

    Target acquisition apertures prefix the label with "TA" (and
    possibly "FS"), e.g. TAMASK335R. For those, the "TA" and "FS"
    strings are removed, and wedge/bar labels (containing "WB") lose
    a trailing "S" or "L".

    Parameters
    ==========
    apname_pps : str
        Aperture name as supplied by PPS.

    Returns
    =======
    str
        Cleaned mask label, or '' if no piece of the name contains "MASK".
    """

    # First piece of the name that carries the mask label ('' if none does)
    pieces = apname_pps.split('_')
    mask = next((piece for piece in pieces if 'MASK' in piece), '')

    # Science apertures (and the no-mask case) need no further clean-up
    if 'TA' not in mask:
        return mask

    # TA apertures: drop the TA / FS markers
    mask = mask.replace('TA', '').replace('FS', '')

    # LWB and SWB TA apertures may carry a trailing S or L; drop it
    if 'WB' in mask and mask.endswith(('S', 'L')):
        mask = mask[:-1]

    return mask

def get_coron_apname(input):
    """Work out the SIAF aperture name for a (coronagraphic) exposure

    Reads the reported aperture name, the PPS aperture name and the
    subarray name from `input`. When the reported and PPS names disagree
    for a MASK aperture (or the subarray is a 400X256 one), a new name
    is assembled from the detector prefix, the subarray, the mask label
    and an optional suffix. The assembled name is only returned if it is
    present in `nrc_siaf.apernames`; otherwise the reported aperture name
    is returned unchanged.

    Parameters
    ==========
    input : fits.header.Header or datamodels.DataModel
        Input header (keywords APERNAME, PPS_APER, SUBARRAY) or data
        model (meta.aperture.name, meta.aperture.pps_name,
        meta.subarray.name).

    Returns
    =======
    str
        Aperture name.
    """

    # Gather the three identifying strings from whichever container we got
    if isinstance(input, fits.header.Header):
        apname, apname_pps, subarray = (
            input[key] for key in ('APERNAME', 'PPS_APER', 'SUBARRAY')
        )
    else:
        meta = input.meta
        apname = meta.aperture.name
        apname_pps = meta.aperture.pps_name
        subarray = meta.subarray.name

    # Nothing to rebuild if the names agree or PPS names no mask,
    # unless this is a 400X256 subarray
    names_agree = (apname == apname_pps) or ('MASK' not in apname_pps)
    if names_agree and ('400X256' not in subarray):
        candidate = apname
    else:
        # Detector prefix comes from the reported name, mask from PPS
        sca = apname.split('_')[0]
        image_mask = get_mask_from_pps(apname_pps)

        # The reported name sometimes erroneously contains 'FULL', so
        # subarray information is taken from the PPS name / SUBARRAY value
        if ('400X256' in apname_pps) or ('400X256' in subarray):
            prefix = f'{sca}_400X256'
        elif 'FULL' in apname_pps:
            prefix = f'{sca}_FULL'
        else:
            prefix = sca

        candidate = f'{prefix}_{image_mask}'

        # Optional suffix: filter name, NARROW, or a wedge designation
        last_str = apname_pps.split('_')[-1]
        is_ta = 'TAMASK' in apname_pps
        if any(tag in apname_pps for tag in ('_F1', '_F2', '_F3', '_F4')):
            # A filter is always the final piece of the PPS name
            # (its length varies, e.g. F322W2)
            candidate += f'_{last_str}'
        elif last_str == 'NARROW':
            candidate += '_NARROW'
        elif is_ta and ('WB' in apname_pps[-1]):
            # NOTE(review): apname_pps[-1] is a single character, so this
            # two-character test can never succeed and '_WEDGE_BAR' is
            # never appended. Intended check needs confirming.
            candidate += '_WEDGE_BAR'
        elif is_ta and (apname_pps[-1] == 'R'):
            candidate += '_WEDGE_RND'

    # Fall back to the reported name if the assembled one is not in SIAF,
    # even if the reported name may not completely make sense
    return candidate if candidate in nrc_siaf.apernames else apname

def apname_full_frame_coron(apname):
    """Return the full-frame counterpart of a coronagraphic aperture name

    Names that already contain 'FULL' are returned untouched (a warning
    is logged for 'FULL_WEDGE' names, since those do not identify the
    occulting mask). Otherwise any '_400X256' piece is dropped and
    '_FULL_' is substituted for the first underscore.

    Parameters
    ==========
    apname : str
        Aperture name.

    Returns
    =======
    str
        Full-frame aperture name.
    """

    # Subarray name: strip the 400X256 tag, then splice FULL in after
    # the leading detector piece
    if 'FULL' not in apname:
        trimmed = apname.replace('_400X256', '')
        return trimmed.replace('_', '_FULL_', 1)

    # Already full frame
    if 'FULL_WEDGE' in apname:
        _log.warning(f'Aperture name {apname} does not specify occulting mask.')
    return apname

def get_files(indir, pid=None, obsid=None, sca=None, filt=None, file_type='uncal.fits',
              exp_type=None, vst_grp_act=None, apername=None, apername_pps=None):
    """Get files of interest

    Files are first selected by name (prefix, detector and file type,
    optionally the observation number) and then, if any header-based
    criterion is given, by the contents of their primary FITS header.
    Each header is read at most once per file.

    Parameters
    ==========
    indir : str
        Location of FITS files.
    pid: int
        Program ID number.
    obsid : int
        Observation number. Matched against the file name, which is
        expected to begin with jw<ppppp><ooo>. Works with or without
        `pid` being specified.
    sca : str
        Name of detector (e.g., 'along' or 'a3')
    filt : str
        Return files observed in given filter. The PUPIL keyword is used
        instead of FILTER when it looks like a narrow/medium filter name
        (starts with 'F' and ends with 'N' or 'M').
    file_type : str
        uncal.fits or rate.fits, etc. Leading underscores are ignored.
    exp_type : str
        Exposure type such as NRC_TACQ, NRC_TACONFIRM, NRC_CORON, etc.
    vst_grp_act : str
        The _<gg><s><aa>_ portion of the file name.
        hdr0['VISITGRP'] + hdr0['SEQ_ID'] + hdr0['ACT_ID']
    apername : str
        Name of aperture (e.g., NRCA5_FULL). Matches either the APERNAME
        keyword or the name returned by `get_coron_apname`.
    apername_pps : str
        Name of aperture from PPS (e.g., NRCA5_FULL)

    Returns
    =======
    ndarray
        Sorted array of matching file names (without directory).
    """

    sca = '' if sca is None else get_detname(sca).lower()

    # File name start and end
    file_start = 'jw' if pid is None else f'jw{pid:05d}'
    # Clear any leading underscores from file type input, then
    # add SCA (if specified) and prepend a single underscore
    file_end = f"{sca}_{file_type.lstrip('_')}"

    names = sorted(
        f for f in os.listdir(indir)
        if (file_end in f) and f.startswith(file_start)
    )

    # Filter by obsid using the file name
    if obsid is not None:
        obs_str = f'{obsid:03d}'
        if pid is None:
            # No program ID to anchor on: the observation number follows
            # the 'jw' + 5-digit program ID at the start of the name
            names = [f for f in names if f[7:].startswith(obs_str)]
        else:
            prefix = f'{file_start}{obs_str}'
            names = [f for f in names if f.startswith(prefix)]

    def _header_matches(hdr):
        """Return True if a primary header passes every requested criterion."""
        if filt is not None:
            obs_filt = hdr.get('FILTER', 'none')
            obs_pup = hdr.get('PUPIL', 'none')
            # Check if filter string exists in the pupil wheel
            if obs_pup.startswith('F') and obs_pup.endswith(('N', 'M')):
                filt_match = obs_pup
            else:
                filt_match = obs_filt
            if filt != filt_match:
                return False

        if (exp_type is not None) and (hdr.get('EXP_TYPE', 'none') != exp_type):
            return False

        if vst_grp_act is not None:
            same_visit = (
                hdr.get('VISITGRP', 'none') == vst_grp_act[0:2].upper()
                and hdr.get('SEQ_ID', 'none') == vst_grp_act[2].upper()
                and hdr.get('ACT_ID', 'none') == vst_grp_act[3:].upper()
            )
            if not same_visit:
                return False

        if apername is not None:
            apname_obs = hdr.get('APERNAME', 'none')
            if not (apname_obs == apername or apername == get_coron_apname(hdr)):
                return False

        if (apername_pps is not None) and (hdr.get('PPS_APER', 'none') != apername_pps):
            return False

        return True

    # Only touch the files on disk if a header-based criterion was requested
    header_criteria = (filt, exp_type, vst_grp_act, apername, apername_pps)
    if any(crit is not None for crit in header_criteria):
        names = [
            f for f in names
            if _header_matches(fits.getheader(os.path.join(indir, f)))
        ]

    return np.array(names)

def filter_files(files, save_dir):
    """Remove files where source is offset off of observed aperture

    For each file, the expected source location (as returned by
    `get_loc_all` with `find_func=get_expected_loc`) is compared against
    the science-frame size of the aperture named by the file's APERNAME
    keyword. A file is kept only if 0 < x < XSciSize and 0 < y < YSciSize.

    Parameters
    ==========
    files : array-like of str
        File names, relative to `save_dir`.
    save_dir : str
        Directory containing the files.

    Returns
    =======
    ndarray
        The subset of `files` that passed the check.
    """

    # Check if we've dithered outside of FoV
    exp_ind = get_loc_all(files, save_dir, find_func=get_expected_loc)

    ind_keep = []
    for i, f in enumerate(files):
        xi, yi = exp_ind[i]

        # Only the aperture name is needed; the context manager makes sure
        # the file is closed even if the keyword is missing
        with fits.open(os.path.join(save_dir, f)) as hdul:
            apname = hdul[0].header['APERNAME']
        ap = nrc_siaf[apname]

        if (0 < xi < ap.XSciSize) and (0 < yi < ap.YSciSize):
            ind_keep.append(i)

    # asarray lets plain lists be indexed with the list of kept indices
    return np.asarray(files)[ind_keep]

def get_save_dir(pid, mast_dir=None):
    """Return save directory for processed files

    Takes the MAST directory, appends '_proc' to it (ignoring one trailing
    '/'), and adds a zero-padded 5-digit sub-directory for the program ID.
    If the resulting directory doesn't exist, it will be created.

    Parameters
    ==========
    pid : int
        Program ID number.
    mast_dir : str or None
        MAST download directory. If None, the JWSTDOWNLOAD_OUTDIR
        environment variable is used.

    Returns
    =======
    str
        Path of the form '<mast_dir>_proc/<ppppp>/'.

    Raises
    ======
    ValueError
        If no directory is given and JWSTDOWNLOAD_OUTDIR is not set.
    """
    if mast_dir is None:
        mast_dir = os.getenv('JWSTDOWNLOAD_OUTDIR')

    # Fail with a clear message rather than on indexing None / ''
    if not mast_dir:
        raise ValueError(
            "No MAST directory available: specify mast_dir or set the "
            "JWSTDOWNLOAD_OUTDIR environment variable."
        )

    # Output directory
    base_dir = mast_dir[:-1] if mast_dir.endswith('/') else mast_dir
    save_dir = os.path.join(base_dir + '_proc/', f'{pid:05d}/')

    # Create directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    return save_dir

###########################################################################
#    Target Acquisition
###########################################################################

def get_ictm_event_log(startdate, enddate, hdr=None, mast_api_token=None, verbose=False):
    """Get ICTM event log from MAST

    Downloads the ICTM_EVENT_MSG mnemonic for the requested time span
    from the MAST engineering database and returns it as text lines.
    A failed request is re-issued up to 5 times, waiting 5 seconds
    between attempts.

    Parameters
    ==========
    startdate : str
        Start date of observation of format YYYY-MM-DD HH:MM:SS.sss
    enddate : str or None
        End date of observation of format YYYY-MM-DD HH:MM:SS.sss.
        If None, the current UTC time is used.
    hdr : fits.header.Header or None
        If given, `startdate` and `enddate` are replaced by the header's
        VSTSTART and VISITEND values.
    mast_api_token : str or None
        MAST API token. If None, the MAST_API_TOKEN environment variable
        is used when set; otherwise the request is sent without a token.
    verbose : bool
        Log the URL being retrieved.

    Returns
    =======
    list of str
        Lines of the downloaded CSV file.

    Raises
    ======
    SystemExit
        If the server answers with HTTP status 401.
    requests.exceptions.RequestException
        If the request still fails after the final attempt.
    """

    from datetime import datetime, timezone
    from requests import Session
    from requests.exceptions import RequestException
    import time

    # parameters
    mnemonic = 'ICTM_EVENT_MSG'
    retry_limit = 5
    retry_wait = 5  # seconds between attempts

    # constants
    base = 'https://mast.stsci.edu/jwst/api/v0.1/Download/file?uri=mast:jwstedb'
    mastfmt = '%Y%m%dT%H%M%S'

    # Attempt to find MAST token if set to None
    # NOTE: MAST token is no longer strictly necessary (I think?)
    if mast_api_token is None:
        mast_api_token = os.environ.get('MAST_API_TOKEN')

    # Determine date range to grab data
    if hdr is not None:
        startdate = hdr['VSTSTART']
        enddate = hdr['VISITEND']

    def _to_datetime(datestr):
        """Parse a date string after dropping any fractional seconds."""
        # The blank between date and time is replaced by '+', which
        # fromisoformat accepts as the date/time separator
        return datetime.fromisoformat(datestr.replace(' ', '+').split('.')[0])

    start = _to_datetime(startdate)
    end = datetime.now(tz=timezone.utc) if enddate is None else _to_datetime(enddate)

    # fetch event messages from MAST engineering database (lags FOS EDB)
    filename = f'{mnemonic}-{start.strftime(mastfmt)}-{end.strftime(mastfmt)}.csv'
    url = f'{base}/{filename}'

    # establish MAST session (closed automatically when done)
    with Session() as session:
        if mast_api_token is not None:
            session.headers.update({'Authorization': f'token {mast_api_token}'})

        if verbose:
            _log.info(f"Retrieving {url}")

        for attempt in range(1, retry_limit + 1):
            try:
                # Re-issue the request on every attempt; checking the same
                # failed response again could never succeed
                response = session.get(url)
                if response.status_code == 401:
                    raise SystemExit('HTTPError 401 - Check your MAST token and EDB authorization.')
                response.raise_for_status()
                break
            except RequestException:
                if attempt == retry_limit:
                    _log.error(f'Failed to retreieve url after {retry_limit} tries')
                    raise
                time.sleep(retry_wait)

        lines = response.content.decode('utf-8').splitlines()

    return lines

# TA subarray name -> SIAF aperture name
_TASUB_TO_APNAME = {
    # Bright-source (FS) TA subarrays
    'SUBFSA210R': 'NRCA2_FSTAMASK210R',
    'SUBFSA335R': 'NRCA5_FSTAMASK335R',
    'SUBFSA430R': 'NRCA5_FSTAMASK430R',
    'SUBFSALWB': 'NRCA5_FSTAMASKLWB',
    'SUBFSASWB': 'NRCA4_FSTAMASKSWB',
    # Module A TA subarrays
    'SUBNDA210R': 'NRCA2_TAMASK210R',
    'SUBNDA335R': 'NRCA5_TAMASK335R',
    'SUBNDA430R': 'NRCA5_TAMASK430R',
    'SUBNDALWBL': 'NRCA5_TAMASKLWBL',
    'SUBNDALWBS': 'NRCA5_TAMASKLWB',
    'SUBNDASWBL': 'NRCA4_TAMASKSWB',
    'SUBNDASWBS': 'NRCA4_TAMASKSWBS',
    # Module B TA subarrays
    'SUBNDB210R': 'NRCB1_TAMASK210R',
    'SUBNDB335R': 'NRCB5_TAMASK335R',
    'SUBNDB430R': 'NRCB5_TAMASK430R',
    'SUBNDBLWBL': 'NRCB5_TAMASKLWBL',
    'SUBNDBLWBS': 'NRCB5_TAMASKLWB',
    'SUBNDBSWBL': 'NRCB3_TAMASKSWB',
    'SUBNDBSWBS': 'NRCB3_TAMASKSWBS',
}


def tasub_to_apname(tasub):
    """Get aperture name from TA subarray name

    Parameters
    ==========
    tasub : str
        Target acquisition subarray name (e.g., 'SUBNDA335R').

    Returns
    =======
    str
        Corresponding aperture name (e.g., 'NRCA5_TAMASK335R').

    Raises
    ======
    KeyError
        If `tasub` is not one of the known TA subarray names.
    """

    return _TASUB_TO_APNAME[tasub]


def print_ta_visit_times(eventlog, verbose=True):
    """Print start and end times of every visit found in an event log

    Scans the event log for 'VISIT <id> STARTED' and 'VISIT <id> ENDED'
    messages. Then, for each unique visit ID (in sorted order), prints
    the start time, runs `find_centroid_det` for that visit (which logs
    the TA information it encounters), and prints the end time.

    End times are matched to visits through the visit ID in the ENDED
    message, so a visit without an ENDED message does not shift the end
    times of the other visits.

    Parameters
    ==========
    eventlog : list of str
        CSV lines of the event log (e.g., from `get_ictm_event_log`).
        Must be re-iterable, since it is scanned once here and once more
        per visit by `find_centroid_det`.
    verbose : bool
        Print the visit start / end times. If False, nothing is printed.
    """

    from csv import reader

    # First time stamp seen for each visit ID
    vstart_by_id = {}
    vend_by_id = {}
    for value in reader(eventlog, delimiter=',', quotechar='"'):
        # Skip blank or malformed rows lacking a message column
        if len(value) < 3:
            continue

        val_str = value[2]
        if val_str[:6] != 'VISIT ':
            continue

        timestamp = 'T'.join(value[0].split())[:-3]
        if val_str[-7:] == 'STARTED':
            vstart_by_id.setdefault(val_str.split()[1], timestamp)
        elif val_str[-5:] == 'ENDED':
            vend_by_id.setdefault(val_str.split()[1], timestamp)

    # Unique visit ids, sorted
    visit_ids = sorted(vstart_by_id)

    for i, vid in enumerate(visit_ids):
        if verbose:
            print(f"VISIT {vid} STARTED at {vstart_by_id[vid]}")

        find_centroid_det(eventlog, vid)

        if verbose:
            vend = vend_by_id.get(vid)
            if vend is None:
                print(f"VISIT {vid} END not found in event log")
            else:
                print(f"VISIT {vid} ENDED at {vend}")
            # Blank line between visits
            if i + 1 < len(visit_ids):
                print('')


def find_centroid_det(eventlog, selected_visit_id):
    """Get centroid position of TA as reported in JWST event logs

    Walks through the event log and, while inside the selected visit and
    inside a TA section (sections are delimited by messages starting
    with '*'), looks for the reported centroid in detector coordinates.
    The first centroid message found ends the search.

    If no centroid message is found but a postage-stamp peak position
    was reported, the peak is shifted by the minimum 'det' corner of the
    TA subarray aperture (minus 0.5) and returned instead.

    Parameters
    ==========
    eventlog : list of str
        CSV lines of the event log (e.g., from `get_ictm_event_log`).
    selected_visit_id : str
        Visit ID as it appears in the 'VISIT <id> STARTED' message.

    Returns
    =======
    tuple of float, ndarray, or None
        (x, y) centroid as reported in detector coordinates; or a
        2-element array with the peak position if only the peak was
        found; or None if neither was found.

    Raises
    ======
    KeyError
        If falling back to peak coordinates and the TA subarray name is
        unknown or was never reported for the visit.
    """

    from csv import reader

    def _parse_xy(msg):
        """Parse the '(x, y)' pair following '=' in an event message."""
        xstr, ystr = msg.split('=')[1].split(',')
        xstr = xstr[xstr.find('(') + 1:]
        ystr = ystr[0:ystr.find(')')]
        return float(xstr), float(ystr)

    # Message labels that announce the centroid in detector coordinates
    centroid_labels = (
        'detector coord (colCentroid, rowCentroid)',
        'detector coord (colCen, rowCen)',
    )

    vid = ''
    in_selected_visit = False
    in_ta = False
    tasub = None
    peak_coords = None

    for value in reader(eventlog, delimiter=',', quotechar='"'):
        # Skip blank or malformed rows lacking a message column
        if len(value) < 3:
            continue
        val_str = value[2]

        # Get subarray name for visit (first one reported is retained)
        if in_selected_visit and ('Configured NIRCam subarray' in val_str):
            if tasub is None:
                tasub = val_str.split(' ')[-1].split(',')[0]
            _log.info(val_str)

        # Only TA sections of the selected visit are inspected
        if in_selected_visit and in_ta:
            # Log coordinate location info
            if ('postage-stamp coord' in val_str) or ('detector coord' in val_str):
                _log.info(val_str)

            # Backup coords in case of TA centroid failure.
            # These are NOT 'sci' coords, but instead a
            # subarray cut-out in detector coords
            if 'postage-stamp coord (colPeak, rowPeak)' in val_str:
                peak_coords = _parse_xy(val_str)

            # Centroid position reported in detector coordinates
            if any(label in val_str for label in centroid_labels):
                return _parse_xy(val_str)

        # Flag if current line is between when visit starts and ends.
        # State is updated after the checks above, so the line that opens
        # a visit or TA section is not itself inspected.
        if val_str[:6] == 'VISIT ':
            if val_str[-7:] == 'STARTED':
                vstart = 'T'.join(value[0].split())[:-3]
                vid = val_str.split()[1]

                if vid == selected_visit_id:
                    _log.debug(f"VISIT {selected_visit_id} START FOUND at {vstart}")
                    in_selected_visit = True
                    tasub = None

            elif val_str[-5:] == 'ENDED' and in_selected_visit:
                assert vid == val_str.split()[1]
                assert selected_visit_id  == val_str.split()[1]

                vend = 'T'.join(value[0].split())[:-3]
                _log.debug(f"VISIT {selected_visit_id} END FOUND at {vend}")

                in_selected_visit = False
        elif val_str[:31] == f'Script terminated: {vid}':
            # A script of the most recently started visit halted with an error
            if val_str[-5:] == 'ERROR':
                in_selected_visit = False
        elif in_selected_visit and val_str.startswith('*'):
            # this string is used to mark the start and end of TA sections
            in_ta = not in_ta

    # If we've gotten here, then no centroid was found
    if peak_coords is None:
        _log.warning(f'No centroid found for {selected_visit_id}.')
        return None

    # Return peak coords if available
    _log.warning(f'No centroid found for {selected_visit_id}. Using peak coords instead.')
    apname = tasub_to_apname(tasub)
    ap = nrc_siaf[apname]
    x0, y0 = np.min(ap.corners('det'), axis=1)

    # Figure out location of peak in full frame
    xp_full = peak_coords[0] + x0 - 0.5
    yp_full = peak_coords[1] + y0 - 0.5

    return np.array([xp_full, yp_full])

def diff_ta_data(uncal_data):
    """Difference the groups of a TA ramp (mimics the onboard algorithm)

    Using only the first integration, two difference images are formed:
    last group minus middle group, and middle group minus first group
    (middle being group index ngroups//2). The pixel-wise minimum of the
    two is returned.

    Parameters
    ==========
    uncal_data : ndarray
        4D array indexed as [integration, group, y, x].

    Returns
    =======
    ndarray
        2D float image.
    """

    ramp = uncal_data.astype('float')

    # Unpacking also enforces that the input is 4D
    _, ngroups, _, _ = ramp.shape

    first = ramp[0, 0]
    middle = ramp[0, ngroups // 2]
    last = ramp[0, -1]

    return np.minimum(last - middle, middle - first)

def read_ta_files(indir, pid, obsid, sca, file_type='rate.fits',
                  uncal_dir=None, bpfix=False):
    """Store all TA and Conf data into a dictionary

    indir should include rate.fits. For the initial TACQ, can use
    uncal files (via `uncal_dir` input flag) to simulate onboard
    subtraction.

    bpfix is only for the TACONF data and mainly for display purposes.

    Parameters
    ==========
    indir : str
        Input directory
    pid : int
        Program ID number
    obsid : int
        Observation number
    sca : str
        SCA name, such as a1, a2, a3, a4, along, etc
    file_type : str
        File extension, such as uncal.fits, rate.fits, cal.fits, etc.
    uncal_dir : str or None
        If not None, use uncal files in this directory for TACQ data.
    bpfix : bool
        If True, perform bad pixel fixing on the data.
        Mainly for display purposes.

    Returns
    =======
    dict
        Key 'dta' holds the target acquisition exposure (the last
        NRC_TACQ file found). Keys 'dconf1', 'dconf2', ... hold the
        NRC_TACONFIRM exposures, if any. Each entry is a dictionary with
        keys 'file', 'type', 'data', 'dq', 'hdr0', 'hdr1', 'date',
        'apname', 'apname_pps', 'ap', 'ap_pps', and 'exp_type'.

    Raises
    ======
    RuntimeError
        If the NRC_TACQ file could not be determined.
    """

    from jwst import datamodels

    # Option to use uncal files for subarray TA observation
    ta_dir = indir if uncal_dir is None else uncal_dir
    taconf_dir = indir

    fta_type = file_type if uncal_dir is None else 'uncal.fits'

    # Get TACQ
    try:
        fta = get_files(ta_dir, pid, obsid=obsid, sca=sca,
                        file_type=fta_type, exp_type='NRC_TACQ')[-1]
    except Exception as err:
        raise RuntimeError(f'Unable to determine NRC_TACQ file for PID {pid} Obs, {obsid}, {sca}') from err

    # Full path
    fta_path = os.path.join(ta_dir, fta)
    ta_dict = {'dta': {'file': fta_path, 'type': 'Target Acq'}}

    # Get TACONFIRM
    fconf = get_files(taconf_dir, pid, obsid=obsid, sca=sca,
                      file_type=file_type, exp_type='NRC_TACONFIRM')
    if len(fconf) > 0:
        # Normally two confirmation images; handle any number gracefully
        for iconf, fname in enumerate(fconf, start=1):
            ta_dict[f'dconf{iconf}'] = {
                'file': os.path.join(taconf_dir, fname),
                'type': f'TA Conf{iconf}',
            }
    else:
        _log.warning(f'NRC_TACQ exists, but no NRC_TACONFIRM observed for PID {pid}, Obs {obsid}, {sca}')

    # Build dictionary of data and header info
    for k, d in ta_dict.items():
        f = d['file']

        with fits.open(f) as hdul:
            # Get data and take diff if uncal
            data = hdul['SCI'].data.astype('float')
            if 'uncal.fits' in f:
                # For TACQ, do difference and get DQ mask from rate file
                data = diff_ta_data(data)
                frate = get_files(indir, pid, obsid=obsid, sca=sca,
                                  file_type=file_type, exp_type='NRC_TACQ')[0]
                frate_path = os.path.join(indir, frate)
                dq = fits.getdata(frate_path, extname='DQ')
            else:
                dq = hdul['DQ'].data

            hdr0 = hdul[0].header
            hdr1 = hdul[1].header

        # Get date from datamodel
        with datamodels.open(f) as data_model:
            date = data_model.meta.observation.date_beg

        d['data'] = data
        d['dq']   = dq
        d['hdr0'] = hdr0
        d['hdr1'] = hdr1
        d['date'] = date

        d['apname'] = get_coron_apname(d['hdr0'])
        # Apername supplied by PPS for pointing control
        d['apname_pps'] = d['hdr0']['PPS_APER']

        d['ap'] = nrc_siaf[d['apname']]
        d['ap_pps'] = nrc_siaf[d['apname_pps']]

        # Exposure type
        d['exp_type'] = d['hdr0']['EXP_TYPE']

        # bad pixel fixing for TA confirmation
        if bpfix and ('conf' in k):
            # NOTE(review): the fix only reaches d['data'] if
            # crop_observation returns a view of it -- confirm
            im = crop_observation(d['data'], d['ap'], 100)
            # Perform pixel fixing in place
            _ = bp_fix(im, sigclip=10, niter=1, in_place=True)

    return ta_dict


def read_sgd_files(indir, pid, obsid, filter, sca, bpfix=False,
                   file_type='rate.fits', exp_type=None, vst_grp_act=None,
                   apername=None, apername_pps=None, nodata=False,
                   combine_same_dithers=False):
    """Store SGD or science data into a dictionary

    By default, excludes any TAMASK or TACONFIRM data, but can be overridden
    by setting exp_type.

    Parameters
    ==========
    indir : str
        Input directory
    pid : int
        Program ID number
    obsid : int
        Observation number
    filter : str
        Name of filter element
    sca : str
        SCA name, such as a1, a2, a3, a4, along, etc
    file_type : str
        File extension, such as uncal.fits, rate.fits, cal.fits, etc.
    exp_type : str
        Exposure type such as NRC_TACQ, NRC_TACONFIRM
    vst_grp_act : str
        The _<gg><s><aa>_ portion of the file name.
        hdr0['VISITGRP'] + hdr0['SEQ_ID'] + hdr0['ACT_ID']
    apername : str
        Name of aperture (e.g., NRCA5_FULL)
    apername_pps : str
        Name of aperture from PPS (e.g., NRCA5_FULL)
    bpfix : bool
        If True, perform bad pixel fixing on the data.
        Mainly for display purposes.
    nodata : bool
        If True, only return header info and not data.
    combine_same_dithers : bool
        Combine same dither positions? Looks at the 'PATT_NUM' keyword.

    Returns
    =======
    dict
        Dictionary keyed by the file's position in the sorted file list.
        Each entry is a dictionary with keys 'file', 'hdr0', 'hdr1',
        'date', 'apname', 'apname_pps', 'ap', 'ap_pps', 'exp_type' and,
        unless `nodata` is set, 'data', 'dq' and 'err' ('err' is None if
        the file has no ERR extension). Entries combined from several
        files additionally list them under 'files'. An empty dictionary
        is returned if no files were found.
    """

    from jwst import datamodels

    def _warn_nothing_found(what):
        """Log the search settings when no matching files remain."""
        _log.warning(f'No {what} found for PID {pid}, Obs {obsid}, {sca} with filter {filter}')
        _log.warning(f'file_type={file_type}, exp_type={exp_type}, vst_grp_act={vst_grp_act}, apername={apername}, apername_pps={apername_pps}')
        _log.warning(f'Input directory: {indir}')

    files = get_files(indir, pid, obsid=obsid, sca=sca, filt=filter,
                      file_type=file_type, exp_type=exp_type, vst_grp_act=vst_grp_act,
                      apername=apername, apername_pps=apername_pps)

    if len(files) == 0:
        _warn_nothing_found('files')
        return {}

    # Exclude any TAMASK or TACONFIRM data by default
    if exp_type is None:
        ikeep = []
        for i, f in enumerate(files):
            obs_type = fits.getheader(os.path.join(indir, f), ext=0)['EXP_TYPE']
            if ('_TACQ' not in obs_type) and ('_TACONFIRM' not in obs_type):
                ikeep.append(i)

        files = files[ikeep]

    if len(files) == 0:
        _warn_nothing_found('science files')
        return {}

    sgd_dict = {}
    for i, f in enumerate(files):
        fpath = os.path.join(indir, f)
        d = {'file': fpath}

        with fits.open(fpath) as hdul:
            if not nodata:
                d['data'] = hdul['SCI'].data.astype('float')
                d['dq']   = hdul['DQ'].data
                try:
                    d['err'] = hdul['ERR'].data
                except KeyError:
                    # File has no ERR extension
                    d['err'] = None

            d['hdr0'] = hdul[0].header
            d['hdr1'] = hdul[1].header

        # Get date from datamodel
        with datamodels.open(fpath) as data_model:
            d['date'] = data_model.meta.observation.date_beg

        d['apname'] = get_coron_apname(d['hdr0'])
        # Apername supplied by PPS for pointing control
        d['apname_pps'] = d['hdr0']['PPS_APER']

        # Add SIAF apertures
        d['ap'] = nrc_siaf[d['apname']]
        d['ap_pps'] = nrc_siaf[d['apname_pps']]

        # Exposure type
        d['exp_type'] = d['hdr0']['EXP_TYPE']

        sgd_dict[i] = d

        # bad pixel fixing
        if bpfix and not nodata:
            # NOTE(review): the fix only reaches d['data'] if
            # crop_observation returns a view of it -- confirm
            im = crop_observation(d['data'], d['ap'], 100)
            # Perform pixel fixing in place
            _ = bp_fix(im, sigclip=10, niter=1, in_place=True)

    # Loop through dictionaries and combine observations at same dither position
    first_patt_num = sgd_dict[0]['hdr0'].get('PATT_NUM', None)
    if combine_same_dithers and (first_patt_num is not None):
        patt_num_arr = np.array([d['hdr0'].get('PATT_NUM') for d in sgd_dict.values()])
        patt_num_uniq = np.unique(patt_num_arr)

        if len(patt_num_uniq) != len(patt_num_arr):
            _log.warning('Combining observation data of same dither positions. Only header info for the first instance will be retained.')

        # Combine data at same dither positions
        for patt in patt_num_uniq:
            # Find all instances of this pattern number
            ind_patt = np.where(patt_num_arr == patt)[0]
            if len(ind_patt) == 1:
                continue

            # Everything is merged into the first entry of the group
            d = sgd_dict[ind_patt[0]]
            for i in ind_patt[1:]:
                other = sgd_dict[i]
                if d.get('files', None) is None:
                    d['files'] = [d['file']]
                d['files'] = d['files'] + [other['file']]
                if not nodata:
                    d['data'] = np.concatenate((d['data'], other['data']), axis=0)
                    d['dq']   = np.concatenate((d['dq'], other['dq']), axis=0)
                    if d['err'] is not None and other['err'] is not None:
                        d['err'] = np.concatenate((d['err'], other['err']), axis=0)

            # Remove the entries that were merged into the first one
            for i in ind_patt[1:]:
                del sgd_dict[i]

    return sgd_dict


###########################################################################
#    Image Cropping
###########################################################################

def get_expected_loc(input, return_indices=True, add_sroffset=None):
    """Expected pixel position of the target from a header or data model

    Integer values correspond to the center of a pixel, whereas values
    ending in 0.5 correspond to pixel edges.

    With `return_indices=True` the (xi, yi) index within the observed
    aperture subarray is returned, otherwise the 'sci' coordinate position.
    The two differ by exactly 1 (index = sci - 1), because 'sci'
    coordinates are 1-indexed while numpy arrays are 0-indexed.

    SR offsets are excluded for dates prior to 2022-07-01, otherwise
    included. Specify `add_sroffset=True` or `add_sroffset=False` to
    override that default. If excluded, any SGD offsets are added back in.
    TODO: What about normal dithers?

    Parameters
    ==========
    input : fits.header.Header or datamodels.DataModel
        Input header or data model
    return_indices : bool
        Return indices of the expected location within the subarray,
        otherwise return the 'sci' coordinate position.
    add_sroffset : None or bool
        Include the Special Requirements (SR) offset in the calculation.
        Defaults to False if date<2022-07-01, otherwise True.
        Specify True or False to override the default.
        If False, any SGD offsets will be added back in.

    Returns
    =======
    tuple
        (x, y) expected position, either as array indices or 'sci' values.
    """

    from astropy.time import Time

    apname = get_coron_apname(input)

    if isinstance(input, fits.header.Header):
        hdr = input

        # Aperture used by PPS for pointing, dither offsets, and obs date
        apname_pps = hdr['PPS_APER']
        xoff_asec, yoff_asec = hdr['XOFFSET'], hdr['YOFFSET']
        date_obs = hdr['DATE-OBS']

        # SGD info (only needed when the SR offsets are excluded)
        is_sgd = hdr.get('SUBPXPAT', False)
        sgd_pattern = hdr.get('SMGRDPAT', None)
        sgd_pos = hdr.get('PATT_NUM', 1) - 1
    else:
        # Same quantities, pulled from the data model meta tree
        meta = input.meta

        apname_pps = meta.aperture.pps_name
        xoff_asec = meta.dither.x_offset
        yoff_asec = meta.dither.y_offset
        date_obs = meta.observation.date

        # Small-grid dither only if the subpixel pattern name says so
        subpix_name = getattr(meta.dither, 'subpixel_pattern', None)
        is_sgd = (subpix_name is not None) and ('small-grid' in subpix_name.lower())

        # SGD type
        sgd_pattern = None
        if is_sgd and hasattr(meta.dither, 'small_grid_pattern'):
            sgd_pattern = meta.dither.small_grid_pattern

        # SGD position index (0-based)
        sgd_pos = 0
        if is_sgd and hasattr(meta.dither, 'position_number'):
            sgd_pos = meta.dither.position_number - 1

    if add_sroffset is None:
        # SR offsets prior to 2022-07-01 were included to match expected
        # changes to the SIAF that were made after July 1 (or around there),
        # so they are dropped for observations taken before that date.
        add_sroffset = not (Time(date_obs) < Time('2022-07-01'))

    if not add_sroffset:
        # Discard the reported offsets, keeping only the SGD contribution
        xoff_asec = yoff_asec = 0.0
        if is_sgd and (sgd_pattern is not None):
            sgd_xoffs, sgd_yoffs = get_sgd_offsets(sgd_pattern)
            xoff_asec += sgd_xoffs[sgd_pos]
            yoff_asec += sgd_yoffs[sgd_pos]

    # Observed aperture, and the aperture referenced for pointing / dithering
    ap = nrc_siaf[apname]
    ap_pps = nrc_siaf[apname_pps]

    if apname == apname_pps:
        # Same aperture: offset the 'sci' reference point directly
        xsci_exp = ap.XSciRef + xoff_asec / ap.XSciScale
        ysci_exp = ap.YSciRef + yoff_asec / ap.YSciScale
    else:
        # Different apertures: go through telescope (V2/V3) coordinates
        has_no_offset = np.allclose([xoff_asec, yoff_asec], 0.0)
        if has_no_offset:
            xtel, ytel = (ap_pps.V2Ref, ap_pps.V3Ref)
        else:
            xtel, ytel = ap_pps.idl_to_tel(xoff_asec, yoff_asec)
        xsci_exp, ysci_exp = ap.tel_to_sci(xtel, ytel)

    if not return_indices:
        return xsci_exp, ysci_exp
    return xsci_exp-1, ysci_exp-1

def get_gfit_cen(im, xysub=11, return_sci=False, find_max=True, **kwargs):
    """Gaussian fit to get centroid position

    A 2D Gaussian is fit to an `xysub`-sized cutout of `im` and the fitted
    mean is returned in full-image coordinates.

    NaNs in `im` are temporarily replaced by 0 in place while fitting and
    are always restored before returning, even if the crop or fit raises.

    Parameters
    ==========
    im : ndarray
        2D image. Temporarily modified in place (see above).
    xysub : int
        Size of the cutout passed to `crop_image` for the fit.
    return_sci : bool
        If True, return 1-indexed positions (index + 1), otherwise return
        0-indexed array positions.
    find_max : bool
        If True, center the cutout on the maximum pixel of the image.
        Otherwise `xyloc=None` is passed to `crop_image`.
    **kwargs
        Accepted for call compatibility; not used.

    Returns
    =======
    tuple
        (x, y) centroid from the Gaussian fit.
    """

    from astropy.modeling import models, fitting

    # Set NaNs to 0 so they affect neither argmax nor the fit
    ind_nan = np.isnan(im)
    im[ind_nan] = 0

    try:
        # Crop around max value?
        if find_max:
            yind, xind = np.unravel_index(np.argmax(im), im.shape)
            xyloc = (xind, yind)
        else:
            xyloc = None
        im_sub, (x1, x2, y1, y2) = crop_image(im, xysub, return_xy=True, xyloc=xyloc)

        # Use the crop indices to build a grid in terms of full image indices
        xv = np.arange(x1, x2)
        yv = np.arange(y1, y2)
        xgrid, ygrid = np.meshgrid(xv, yv)
        xc, yc = (xv.mean(), yv.mean())

        # Fit the data using astropy.modeling
        # NOTE(review): initial widths are asymmetric (1 and 2) - confirm intended
        p_init = models.Gaussian2D(amplitude=im_sub.max(), x_mean=xc, y_mean=yc,
                                   x_stddev=1, y_stddev=2)
        fit_p = fitting.LevMarLSQFitter()
        pfit = fit_p(p_init, xgrid, ygrid, im_sub)
    finally:
        # Always hand the caller's array back with its NaNs intact
        im[ind_nan] = np.nan

    xind_cen = pfit.x_mean.value
    yind_cen = pfit.y_mean.value

    if return_sci:
        return xind_cen+1, yind_cen+1
    else:
        return xind_cen, yind_cen

def get_com(im, halfwidth=7, return_sci=False, **kwargs):
    """Center of mass centroiding

    Runs poppy's `fwcentroid` on the image. If that raises an IndexError,
    the call is retried once with half the requested halfwidth.

    NaNs in `im` are temporarily replaced by 0 in place and are always
    restored before returning, even if centroiding raises.

    Parameters
    ==========
    im : ndarray
        2D image. Temporarily modified in place (see above).
    halfwidth : int
        Halfwidth of the centroiding box passed to `fwcentroid`.
    return_sci : bool
        If True, return 1-indexed positions (index + 1), otherwise return
        0-indexed array positions.
    **kwargs
        Additional keywords passed through to `fwcentroid`.

    Returns
    =======
    tuple
        (x, y) centroid position.
    """

    from poppy.fwcentroid import fwcentroid

    # Set NaNs to 0
    ind_nan = np.isnan(im)
    im[ind_nan] = 0

    try:
        # Find center of mass centroid
        try:
            com = fwcentroid(im, halfwidth=halfwidth, **kwargs)
        except IndexError:
            # Retry with a box half the size
            hw = int(halfwidth / 2)
            com = fwcentroid(im, halfwidth=hw, **kwargs)
    finally:
        # Always restore the NaNs, including when the retry fails
        im[ind_nan] = np.nan

    # fwcentroid result is unpacked as (y, x)
    yind_com, xind_com = com

    if return_sci:
        return xind_com+1, yind_com+1
    else:
        return xind_com, yind_com

def get_peak(im, nsig_threshold=50, box_size=15, return_sci=False, **kwargs):
    """Location of the single brightest peak above a noise threshold

    The detection threshold is `nsig_threshold` times the median absolute
    deviation of the image, and `photutils.detection.find_peaks` is asked
    for one peak only.

    Parameters
    ==========
    im : ndarray
        Image to search.
    nsig_threshold : float
        Threshold in multiples of `robust.medabsdev(im)`.
    box_size : int
        `box_size` argument passed to `find_peaks`.
    return_sci : bool
        If True, return 1-indexed positions (index + 1), otherwise return
        0-indexed array positions.
    **kwargs
        Accepted for call compatibility; not used.

    Returns
    =======
    tuple
        (x, y) position of the peak.
    """

    from photutils.detection import find_peaks
    from . import robust

    # Threshold expressed in units of the robust scatter of the image
    threshold = nsig_threshold * robust.medabsdev(im)

    # NOTE(review): the indexing below fails if find_peaks yields no table
    # row - confirm whether callers rely on that
    peak_tbl = find_peaks(im, threshold, box_size=box_size, npeaks=1)
    brightest = peak_tbl[0]
    xind_peak, yind_peak = (brightest['x_peak'], brightest['y_peak'])

    if not return_sci:
        return xind_peak, yind_peak
    return xind_peak+1, yind_peak+1

def get_loc_all(files, indir, find_func=get_com, 
                fix_bad_pixels=True, **kwargs):
    """Find the star position in each of a list of FITS files

    For every file the 'SCI' extension is read (and 'DQ' if present). Data
    cubes are collapsed to a median image after setting DO_NOT_USE pixels
    to NaN. The position is then obtained with `find_func`.

    Parameters
    ==========
    files : list
        List of file names.
    indir : str
        Directory containing the files.
    find_func : function
        Function used to locate the star. `get_expected_loc` is called with
        the primary header; `get_com` is called on the (optionally bad pixel
        fixed) image; anything else is called as `find_func(data, **kwargs)`.
    fix_bad_pixels : bool
        Only used when `find_func` is `get_com`. If True, run `bp_fix` on
        the image before centroiding, otherwise set DO_NOT_USE pixels to NaN.
    **kwargs
        Passed through to `find_func`.

    Returns
    =======
    ndarray
        Array built from the per-file positions returned by `find_func`.
    """

    from jwst.datamodels import dqflags
    from .image_manip import bp_fix

    star_locs = []
    for f in files:
        fpath = os.path.join(indir, f)

        # Context manager guarantees the file is closed even if the
        # position finding below raises
        with fits.open(fpath) as hdul:

            # Science data and data quality flags (all zeros if no DQ ext)
            data = hdul['SCI'].data
            try:
                dqmask = hdul['DQ'].data
            except KeyError:
                dqmask = np.zeros_like(data).astype(np.uint64)

            # If data is 3D, then get median image
            if len(data.shape) > 2:
                bpmask = (dqmask & dqflags.pixel['DO_NOT_USE']) > 0
                data[bpmask] = np.nan
                data = np.nanmedian(data, axis=0)
                # Bitwise AND of DQ mask: keep flags set in every frame
                dqmask = np.bitwise_and.reduce(dqmask, axis=0)

            # Get rough stellar position
            if find_func is get_expected_loc:
                xy = get_expected_loc(hdul[0].header, **kwargs)
            elif find_func is get_com:
                bpmask = (dqmask & dqflags.pixel['DO_NOT_USE']) > 0

                if fix_bad_pixels:
                    # First pass clips outliers, second fixes flagged pixels
                    data = bp_fix(data, sigclip=20, in_place=False)
                    data = bp_fix(data, bpmask=bpmask)
                else:
                    data[bpmask] = np.nan

                xy = get_com(data, **kwargs)
            else:
                xy = find_func(data, **kwargs)

            star_locs.append(xy)

    return np.array(star_locs)


def load_cropped_files(save_dir, files, xysub=65, bgsub=False, 
                       fix_bad_pixels=True, find_func=get_com, **kwargs):
    """Load a cropped version of the files
    
    Opens the files, crops them, and returns the cropped data, DQ arrays,
    indices of the cropped images, and bad pixel masks. The indices are an
    array of (x1, x2, y1, y2) in shape of (nfiles,4).

    Parameters
    ==========
    save_dir : str
        Directory where the files are saved
    files : list
        List of file names
    xysub : int
        Size of the subarray to use for cropping
    bgsub : bool
        If True, then subtract the background from each cropped image.
        The background is the median of the good pixels of that same image
        at r > int(0.7*xysub/2).
    fix_bad_pixels : bool
        Forwarded to `get_loc_all`, where it controls bad pixel fixing
        while locating the star. The returned images are not modified by it.
    find_func : function
        Function to use to find the location of the star.
    **kwargs
        Passed through to `get_loc_all`. If `find_func` is `get_com` and no
        `halfwidth` is given, `halfwidth=15` is used.

    Returns
    =======
    imsub_arr : ndarray
        Cropped images.
    dqsub_arr : ndarray
        Cropped DQ arrays, with pixels equal to 0 in the cropped data
        additionally flagged as DO_NOT_USE.
    xyind_arr : ndarray
        Crop indices (x1, x2, y1, y2) for each file.
    bp_masks : ndarray
        Boolean masks, True where OTHER_BAD_PIXEL is set or the data is NaN.

    Raises
    ======
    ValueError
        If the cropped images do not all share the same (ny, nx) shape.
    """

    from jwst.datamodels import dqflags

    # Get index location and 'sci' position
    if find_func is get_com:
        kwargs['halfwidth'] = kwargs.get('halfwidth', 15)
    com_ind = get_loc_all(files, save_dir, find_func=find_func,
                          fix_bad_pixels=fix_bad_pixels, **kwargs)

    imsub_arr = []
    dqsub_arr = []
    xyind_arr = []
    for i, f in enumerate(files):
        fpath = os.path.join(save_dir, f)

        # Context manager guarantees the file is closed if cropping raises
        with fits.open(fpath) as hdul:
            sci = hdul['SCI'].data
            ndim = len(sci.shape)

            # Use the first frame of a cube to determine the crop indices
            data = sci[0] if ndim==3 else sci
            ny, nx = data.shape[-2:]

            # Crop and roughly center image
            data, xy = crop_image(data, xysub, xyloc=com_ind[i], return_xy=True)
            x1, x2, y1, y2 = xy

            # Apply the same crop indices to every frame of a cube
            if ndim==3:
                data = sci[:,y1:y2,x1:x2]

            # Matching crop of the DQ array (all zeros if no DQ extension)
            try:
                if ndim==3:
                    dqmask = hdul['DQ'].data[:,y1:y2,x1:x2]
                else:
                    dqmask = hdul['DQ'].data[y1:y2,x1:x2]
            except KeyError:
                dqmask = np.zeros_like(data).astype(np.uint64)

            # For arrays padded with 0s, flag those pixels as DO_NOT_USE
            indz = (data==0)
            dqmask[indz] = dqmask[indz] | dqflags.pixel['DO_NOT_USE']

            imsub_arr.append(data)
            dqsub_arr.append(dqmask)
            xyind_arr.append(xy)

    # Ensure data are of the same shape
    sh1 = imsub_arr[0].shape[-2:]
    xymin_size = np.min([sh1[0], sh1[1]])
    same_shape = True
    for i in range(1, len(imsub_arr)):
        sh2 = imsub_arr[i].shape[-2:]
        if sh1 != sh2:
            same_shape = False
        xymin_size = np.min([xymin_size, np.min([sh2[0], sh2[1]])])
        # Make sure xymin_size is odd
        if xymin_size % 2 == 0:
            xymin_size -= 1

    if not same_shape:
        raise ValueError(f'xysub={xysub} is too large shifted data of shape {(ny,nx)}. Trying shinking to {xymin_size}.')

    try:
        imsub_arr = np.asarray(imsub_arr)
        dqsub_arr = np.asarray(dqsub_arr)
        xyind_arr = np.asarray(xyind_arr)
    except ValueError:
        # Stacking fails when files hold different numbers of integrations
        _log.warning('Unequal number of integrations. Concatenating arrays into [nim_tot,ny,nx].')
        imsub_arr = np.concatenate(imsub_arr, axis=0)
        dqsub_arr = np.concatenate(dqsub_arr, axis=0)
        # NOTE(review): this flattens the per-file (x1,x2,y1,y2) tuples into
        # a 1D array rather than (nfiles,4) - kept as is, confirm intended
        xyind_arr = np.concatenate(xyind_arr, axis=0)

    # NOTE(review): padded pixels above are flagged DO_NOT_USE, but the mask
    # here tests OTHER_BAD_PIXEL - kept as is, confirm intended
    bp_masks1 = (dqsub_arr & dqflags.pixel['OTHER_BAD_PIXEL']) > 0
    bp_masks = bp_masks1 | np.isnan(imsub_arr)

    # Do bg subtraction from r>bg_rad and only include good pixels
    if bgsub:
        # Radial position to set background
        bg_rad = int(0.7 * xysub / 2)
        ind_bg = dist_image(np.zeros([xysub,xysub])) > bg_rad

        # Iterate over the leading axis of the stacked array so that every
        # image is handled, whether entries are single images or cubes.
        # Each image is corrected with its own background level.
        for im_stack, mask_stack in zip(imsub_arr, bp_masks):
            if im_stack.ndim==3:
                for im, mask in zip(im_stack, mask_stack):
                    indgood = (~mask) & ind_bg
                    im -= np.nanmedian(im[indgood])
            else:
                indgood = (~mask_stack) & ind_bg
                im_stack -= np.nanmedian(im_stack[indgood])

    return imsub_arr, dqsub_arr, xyind_arr, bp_masks



def recenter_psf(psfs_over, niter=3, halfwidth=7, 
                 gfit=True, in_place=False, **kwargs):
    """Use Gaussian fit or center of mass algorithm to relocate PSF to center of image.
    
    Returns recentered PSFs and shift values used.

    Parameters
    ----------
    psfs_over : ndarray
        Oversampled PSF(s) to recenter, either a single 2D image or a
        3D stack of images.
    niter : int
        Number of centroid-and-shift iterations applied to each PSF.
    halfwidth : int or None
        Halfwidth of box to use for center of mass algorithm.
        Default is 7, which is a 15x15 box. For Gaussian fitting the
        cutout size is 2*halfwidth+1.
    gfit : bool
        If True, use Gaussian fitting instead of center of mass.
    in_place : bool
        If True, then perform the shift in place, overwriting the input
        PSF array. This holds for both 2D and 3D inputs.
    **kwargs
        Passed to `get_gfit_cen` when `gfit=True`; otherwise unused.

    Returns
    -------
    psfs_over : ndarray
        Recentered PSF(s) with the same number of dimensions as the input.
    xyoff_psfs_over : ndarray
        Total (x, y) shift applied to each PSF, in oversampled pixels.
        Shape (2,) for a 2D input, (npsf, 2) for a 3D input.
    """

    from .image_manip import fourier_imshift

    psfs_input = psfs_over
    is_single = len(psfs_over.shape) == 2
    if is_single:
        # Wrap in a list so the loop below handles both cases; the input
        # array itself is only written to at the very end if in_place
        psfs_over = [psfs_over]
    elif not in_place:
        psfs_over = psfs_over.copy()

    gc_str = 'Gaussian Fit' if gfit else 'CoM'

    # Reposition oversampled PSF to center of array
    xyoff_psfs_over = []
    for i, psf in enumerate(psfs_over):
        xc_psf, yc_psf = get_im_cen(psf)
        xsh_sum, ysh_sum = (0, 0)
        for _ in range(niter):
            if gfit:
                xc, yc = get_gfit_cen(psf, xysub=2*halfwidth+1, 
                                      return_sci=False, **kwargs)
            else:
                xc, yc = get_com(psf, halfwidth=halfwidth, return_sci=False)
            # Shift needed to move the measured centroid onto the image center
            xsh, ysh = (xc_psf - xc, yc_psf - yc)
            psf = fourier_imshift(psf, xsh, ysh)
            xsh_sum += xsh
            ysh_sum += ysh
        psfs_over[i] = psf
        xyoff_psfs_over.append(np.array([xsh_sum, ysh_sum]))
        
        _log.info(f"Recentered oversampled PSF ({xsh_sum:.3f}, {ysh_sum:.3f}) pixels using {gc_str} algorithm.")

    # Oversampled offsets
    xyoff_psfs_over = np.array(xyoff_psfs_over)

    # If input was a single image, return same dimensions
    if is_single:
        psf_out = psfs_over[0]
        if in_place:
            # Rebinding the list element does not touch the caller's array,
            # so copy the result into it explicitly
            psfs_input[...] = psf_out
            psf_out = psfs_input
        return psf_out, xyoff_psfs_over[0]

    return psfs_over, xyoff_psfs_over


def subtract_psf(image, psf, osamp=1, bpmask=None, rin=None, rout=None,
                 xyshift=(0,0), psf_scale=None, psf_offset=0,
                 method='fourier', interp='lanczos', pad=True, cval=0,
                 kipc=None, kppc=None, diffusion_sigma=None, psf_corr_over=None,
                 weights=None, return_sum2=False, return_scale=False, **kwargs):
    """ Subtract PSF from image

    Provide scale, offset, and shift values to PSF before subtraction.
    Uses the `image_shift_with_nans` function to shift the PSF.

    Processing order: shift the oversampled PSF, apply charge diffusion,
    multiply by the PSF correction image, rebin to detector sampling,
    apply IPC and PPC, crop to the image shape, scale and offset, subtract.

    Parameters
    ----------
    image: ndarray
        Observed science image.
    psf: ndarray
        Oversampled PSF (shifted and scaled to match).
    osamp: int
        Oversampling factor of PSF.
    bpmask: bool array
        Bad pixel mask indicating pixels in input image to ignore.
    rin: float
        Inner radius of annulus used when fitting the PSF scale factor.
        Default is None.
    rout: float
        Outer radius of annulus used when fitting the PSF scale factor.
        Default is None.
    xyshift: tuple
        Shift values in (x,y) directions. Units of pixels.
    psf_scale: float
        Scale factor to apply to PSF. If set to None, then will find
        the best scaling factor.
    psf_offset: float
        Offset to apply to PSF.
    psf_corr_over: ndarray
        Oversampled PSF correction image. If provided, then this image is
        multiplied with the PSF after diffusion. These are empirical
        corrections to the STPSF model to better match the observed PSF.
    kipc: ndarray
        3x3 array of IPC kernel values. If None, then no IPC is applied.
    kppc: ndarray
        3x3 array of PPC kernel values. If None, then no PPC is applied.
        Should already be oriented along readout direction of PSF.
    diffusion_sigma: float
        Sigma value for Gaussian diffusion kernel. If None, then
        no diffusion is applied. In units of detector pixels.
    weights: ndarray
        Array of weights multiplied into the difference image. Useful if
        you have bad pixels to mask out (ie., set them to zero).
        Default is None (no weights). Should be same size as image.
        Recommended is inverse variance map.
    method : str or None
        Method to use for shifting. Options are:

        - 'fourier' : Shift in Fourier space
        - 'fshift' : Shift using interpolation
        - 'opencv' : Shift using OpenCV warpAffine
        - None : Do not shift; `xyshift` is ignored and a copy of the
          input PSF is used as is.
    interp : str
        Interpolation method to use for shifting using 'fshift' or 'opencv'.
        Default is 'lanczos'.
        For 'opencv', valid options are 'linear', 'cubic', and 'lanczos'.
        For 'fshift', valid options are 'linear', 'cubic', and 'quintic'.
    pad : bool
        Should we pad the array before shifting, then truncate?
        Otherwise, the image is wrapped.
    cval : sequence or float, optional
        The values to set the padded values for each axis. Default is 0.
        ((before_1, after_1), ... (before_N, after_N)) unique pad constants for each axis.
        ((before, after),) yields same before and after constants for each axis.
        (constant,) or int is a shortcut for before = after = constant for all axes.
    return_sum2 : bool
        Return the sum of the squared difference between the image and PSF.
        Default is False.
    return_scale : bool
        Also return the PSF scale factor that was applied. Default is False.

    Keyword Args
    ------------
    gstd_pix : float
        Standard deviation of Gaussian kernel to blur PSF during shift.
    oversample : int
        Oversampling factor for fractional shift. Default is 1.
    order : int
        Interpolation order for oversampling during shifting. Default is 1.
    rescale_pix : bool
        Explicitly rescale the pixel values during resampling to ensure
        that the flux within a superpixel is preserved.
        Default is False (zoom default behavior).

    Returns
    -------
    ndarray or float, optionally with the scale factor
        The difference image, or the sum of squared differences if
        `return_sum2=True`. With `return_scale=True` a tuple of
        (result, psf_scale) is returned instead.
    """

    # Shift oversampled PSF by the requested amount in oversampled pixels
    xsh_over, ysh_over = np.array(xyshift) * osamp
    if method is not None:
        # Imported here so the no-shift path has no dependency on it
        from webbpsf_ext.image_manip import image_shift_with_nans

        kwargs_shift = {}
        kwargs_shift['pad'] = pad
        kwargs_shift['cval'] = cval
        if method in ['fshift', 'opencv']:
            kwargs_shift['interp'] = interp

        # Scale Gaussian std dev by oversampling factor
        gstd_pix = kwargs.pop('gstd_pix', None)
        if gstd_pix is not None:
            kwargs_shift['gstd_pix'] = gstd_pix * osamp

        # Perform oversampling during shifting process?
        kwargs_shift['oversample'] = kwargs.pop('oversample', 1)
        kwargs_shift['order'] = kwargs.pop('order', 1)
        kwargs_shift['rescale_pix'] = kwargs.pop('rescale_pix', False)
        psf_over = image_shift_with_nans(psf, xsh_over, ysh_over,
                                         shift_method=method, **kwargs_shift)
    else:
        # No shifting requested. Work on a copy so the in-place PSF
        # correction below cannot modify the caller's array.
        psf_over = np.array(psf)

    # Charge diffusion
    if diffusion_sigma is not None:
        sigma_osamp = diffusion_sigma * osamp
        psf_over = apply_pixel_diffusion(psf_over, sigma_osamp)

    # Apply PSF correction
    if psf_corr_over is not None:
        psf_over *= crop_image(psf_corr_over, psf_over.shape, fill_val=1)

    # Rebin to detector sampling
    psf_det = frebin(psf_over, scale=1/osamp) if osamp!=1 else psf_over

    # Add IPC and PPC to detector-sampled PSF
    if kipc is not None:
        psf_det = add_ipc(psf_det, kernel=kipc)
    if kppc is not None:
        psf_det = add_ppc(psf_det, kernel=kppc, nchans=1)

    # Crop image
    if psf_det.shape != image.shape:
        psf_det = crop_image(psf_det, image.shape)

    if psf_scale is None:
        # Get optimal scale factor between images
        # Ignore NaNs and zeros
        good_mask = ~np.isnan(image) & ~np.isnan(psf_det)
        good_mask = good_mask & (~np.isclose(image,0)) & (~np.isclose(psf_det,0))
        if bpmask is not None:
            good_mask &= ~bpmask

        # Restrict the fit to the requested annulus
        if (rin is not None) or (rout is not None):
            rho = dist_image(image)
            rin = 0 if rin is None else rin
            rout = np.inf if rout is None else rout
            good_mask &= (rho >= rin) & (rho <= rout)

        # Single-parameter linear least squares: image - offset = scale * psf
        im_good = image[good_mask].flatten() - psf_offset
        psf_good = psf_det[good_mask].flatten()
        cf = np.linalg.lstsq(psf_good.reshape([1,-1]).T, im_good, rcond=None)[0]
        psf_scale = cf[0]

    psf_det = psf_det * psf_scale + psf_offset

    # Subtract PSF from image
    diff = image - psf_det
    if weights is not None:
        diff = diff * weights

    if return_sum2:
        # Set anything that are 0 in either image as zero in difference
        zmask = np.isclose(image,0) | np.isclose(psf_det,0)
        nmask = np.isnan(image) | np.isnan(psf_det)
        mask = zmask | nmask
        if bpmask is not None:
            mask |= bpmask
        diff[mask] = 0
        sum2 = np.sum(diff**2)
        return (sum2, psf_scale) if return_scale else sum2

    return (diff, psf_scale) if return_scale else diff
def correl_images(im1, im2, mask=None):
    """ Image correlation coefficient

    Calculate the 2D cross-correlation coefficient between two images or
    arrays of images. Images must have the same x and y dimensions and
    should already be aligned.

    Parameters
    ----------
    im1 : ndarray
        Single image or image cube (nz1, ny, nx).
    im2 : ndarray
        Single image or image cube (nz2, ny, nx).
        If both im1 and im2 are cubes, then returns a matrix of
        coefficients.
    mask : ndarray or None
        If set, then a binary mask of 1=True and 0=False.
        Excludes pixels marked with 0s/False. Must be same size/shape
        as images (ny, nx). Any NaNs in the images will automatically
        be masked.

    Returns
    -------
    float or ndarray
        A scalar if only one image pair was supplied, otherwise the
        squeezed array of coefficients with shape (nz1, nz2).
    """

    sh1 = im1.shape
    sh2 = im2.shape

    if len(sh1)==2:
        ny1, nx1 = sh1
        nz1 = 1
    else:
        nz1, ny1, nx1 = sh1

    if len(sh2)==2:
        ny2, nx2 = sh2
        nz2 = 1
    else:
        nz2, ny2, nx2 = sh2

    assert (nx1==nx2) and (ny1==ny2), "Input images must have same sizes"

    # Flatten the spatial axes so each row is one image
    im1 = im1.reshape([nz1,-1])
    im2 = im2.reshape([nz2,-1])

    # Cast to boolean so that a 0/1 integer mask selects pixels rather than
    # being interpreted as integer (fancy) indices into columns 0 and 1.
    if mask is not None:
        mask = np.asarray(mask, dtype=bool)

    # Mask out pixels that are NaN in any of the input images
    nanvals = np.sum(np.isnan(im1), axis=0) + np.sum(np.isnan(im2), axis=0)
    nan_mask = (nanvals > 0).reshape([ny1,nx1])
    if nan_mask.any():
        mask = ~nan_mask if mask is None else (mask & ~nan_mask)

    # Apply masking
    if mask is not None:
        good = mask.ravel()
        im1 = im1[:, good]
        im2 = im2[:, good]

    # Subtract the mean of each image
    im1 = im1 - np.mean(im1, axis=1).reshape([-1,1])
    im2 = im2 - np.mean(im2, axis=1).reshape([-1,1])

    # Calculate numerators for each image pair
    correl_top = np.dot(im1, im2.T)

    # Calculate denominators for each image pair
    im1_tot = np.sum(im1**2, axis=1)
    im2_tot = np.sum(im2**2, axis=1)
    correl_bot = np.sqrt(np.multiply.outer(im1_tot, im2_tot))

    correl_fin = correl_top / correl_bot
    if correl_fin.size==1:
        return correl_fin.flatten()[0]
    else:
        return correl_fin.squeeze()


def sample_crosscorr(corr, xcoarse, ycoarse, xfine, yfine, method='cubic'):
    """Interpolate a coarse cross-correlation map onto a finer grid

    Parameters
    ----------
    corr : ndarray
        Correlation values sampled on the (ycoarse, xcoarse) grid.
    xcoarse, ycoarse : array_like
        Coordinates of the coarse grid along x and y.
    xfine, yfine : array_like
        Coordinates of the fine grid along x and y.
    method : str
        Interpolation method passed to `scipy.interpolate.griddata`.
        Default is 'cubic'.

    Returns
    -------
    ndarray
        Interpolated values with shape (len(yfine), len(xfine)).
    """

    from scipy.interpolate import griddata

    # (npts, 2) array of coarse (x, y) sample positions
    xycoarse = np.asarray(np.meshgrid(xcoarse, ycoarse)).reshape([2,-1]).transpose()

    # Sub-sampling shifts to interpolate over
    xv, yv = np.meshgrid(xfine, yfine)

    # Perform interpolation
    corr_fine = griddata(xycoarse, corr.flatten(), (xv, yv), method=method)

    return corr_fine


def find_max_crosscorr(corr, xsh_arr, ysh_arr, sub_sample):
    """Interpolate a finer grid onto a cross-correlation map and locate the maximum

    Parameters
    ----------
    corr : ndarray
        Correlation map sampled at (ysh_arr, xsh_arr).
    xsh_arr, ysh_arr : array_like
        Shift values along x and y at which `corr` was sampled.
    sub_sample : float
        Step size of the fine grid used for interpolation.

    Returns
    -------
    tuple
        (xsh_fine, ysh_fine) position of the interpolated maximum.
    """

    # Sub-sampling shifts to interpolate over (end point is excluded by arange)
    xsh_fine_vals = np.arange(xsh_arr[0], xsh_arr[-1], sub_sample)
    ysh_fine_vals = np.arange(ysh_arr[0], ysh_arr[-1], sub_sample)
    corr_all_fine = sample_crosscorr(corr, xsh_arr, ysh_arr, xsh_fine_vals, ysh_fine_vals)

    # Find position of the (first) maximum, ignoring NaNs
    iymax, ixmax = np.argwhere(corr_all_fine==np.nanmax(corr_all_fine))[0]
    xsh_fine, ysh_fine = xsh_fine_vals[ixmax], ysh_fine_vals[iymax]

    return xsh_fine, ysh_fine


def gen_psf_offsets(psf, crop=65, xlim_pix=(-3,3), ylim_pix=(-3,3), dxy=0.05,
                    psf_osamp=1, shift_func=fourier_imshift, ipc_vals=None,
                    kipc=None, kppc=None, diffusion_sigma=None,
                    psf_corr_image=None, monitor=False, prog_leave=False,
                    **kwargs):
    """ Generate a series of downsampled cropped and shifted PSF images

    If fov_pix is odd, then crop should be odd.
    If fov_pix is even, then crop should be even.

    Add IPC:
        Either ipc_vals = 0.006 or ipc_vals=[0.006,0.0004].
        The former adds 0.6% to each side pixel, while the latter
        includes 0.04% to the corners. Can also supply kernel directly
        with kipc.

    Add PPC:
        Specify kppc kernel directly. This must be correctly oriented
        for the PSF image readout direction. Assumes single output amplifier.

    Parameters
    ----------
    psf : ndarray
        Input PSF image, sampled at `psf_osamp` times detector sampling.
    crop : int
        Size of the output images in detector pixels. Incremented by 1
        if its parity does not match that of the detector-sampled PSF.
    xlim_pix, ylim_pix : tuple
        (min, max) range of offsets to probe along x and y, in detector pixels.
    dxy : float
        Step size of the offset grid in detector pixels.
    psf_osamp : int
        Oversampling factor of the input PSF.
    shift_func : callable
        Shifting function passed to `crop_image`.
    ipc_vals : float, sequence, or None
        IPC values used to build a kernel if `kipc` is not given.
    kipc, kppc : ndarray or None
        IPC and PPC kernels.
    diffusion_sigma : float or None
        Sigma (detector pixels) of the Gaussian pixel diffusion kernel.
    psf_corr_image : ndarray or None
        Correction image multiplied into each shifted PSF.
    monitor : bool
        Show a progress bar.
    prog_leave : bool
        Leave the progress bar displayed after completion.

    Returns
    -------
    xoff_pix, yoff_pix, psf_sh_all
        1D arrays of the x and y offsets, and the array of shifted PSFs
        ordered as the flattened (y, x) offset grid.
    """

    psf_is_even = np.mod(psf.shape[0] / psf_osamp, 2) == 0
    crop_is_even = np.mod(crop, 2) == 0

    # PSF and crop must share the same parity
    if psf_is_even != crop_is_even:
        crop = crop + 1
        _log.warning('PSF and crop must both be even or odd. Incrementing crop by 1.')

    # Range of offsets to probe in fractional pixel steps
    xmin_pix, xmax_pix = xlim_pix
    ymin_pix, ymax_pix = ylim_pix

    # Pixel offsets
    xoff_pix = np.arange(xmin_pix, xmax_pix+dxy, dxy)
    yoff_pix = np.arange(ymin_pix, ymax_pix+dxy, dxy)

    # Create a grid and flatten
    xoff_all, yoff_all = np.meshgrid(xoff_pix, yoff_pix)
    xoff_all = xoff_all.flatten()
    yoff_all = yoff_all.flatten()

    # Make initial crop so we don't shift entire image
    max_off = np.max(np.abs(np.concatenate([xoff_pix, yoff_pix])))
    crop_init = crop + int(2*(max_off + 1))
    crop_init_over = crop_init * psf_osamp
    psf0 = crop_image(psf, crop_init_over)

    # Final crop size in oversampled pixels (same for every offset)
    crop_over = crop*psf_osamp
    apply_diffusion = (diffusion_sigma is not None) and (diffusion_sigma > 0)

    # Create a series of shifted PSFs to compare to images
    psf_sh_all = []
    if monitor:
        iter_vals = tqdm(zip(xoff_all, yoff_all), total=len(xoff_all), leave=prog_leave)
    else:
        iter_vals = zip(xoff_all, yoff_all)
    for xoff, yoff in iter_vals:
        xoff_over = xoff*psf_osamp
        yoff_over = yoff*psf_osamp
        psf_sh = crop_image(psf0, crop_over, xyloc=None, delx=xoff_over, dely=yoff_over,
                            shift_func=shift_func, **kwargs)

        # Apply pixel diffusion as Gaussian kernel
        if apply_diffusion:
            dsig = diffusion_sigma * psf_osamp
            psf_sh = apply_pixel_diffusion(psf_sh, dsig)

        # Apply PSF correction image
        if psf_corr_image is not None:
            psf_corr_im_sh = crop_image(psf_corr_image, crop_over, xyloc=None,
                                        delx=xoff_over, dely=yoff_over,
                                        shift_func=shift_func, fill_val=1, **kwargs)
            # Out-of-place multiply so the input PSF can never be modified
            psf_sh = psf_sh * psf_corr_im_sh

        # Rebin to detector pixels
        psf_sh = frebin(psf_sh, scale=1/psf_osamp)
        psf_sh_all.append(psf_sh)

    psf_sh_all = np.asarray(psf_sh_all)
    psf_sh_all[np.isnan(psf_sh_all)] = 0

    # Add IPC
    if (kipc is not None) or (ipc_vals is not None):
        # Build kernel if it wasn't already specified
        if kipc is None:
            if isinstance(ipc_vals, (tuple, list, np.ndarray)):
                a1, a2 = ipc_vals
            else:
                a1, a2 = ipc_vals, 0
            kipc = np.array([[a2,a1,a2], [a1,1-4*(a1+a2),a1], [a2,a1,a2]])
        psf_sh_all = add_ipc(psf_sh_all, kernel=kipc)

    # Add PPC
    if kppc is not None:
        psf_sh_all = add_ppc(psf_sh_all, kernel=kppc, nchans=1)

    return xoff_pix, yoff_pix, psf_sh_all


def find_offsets(input, psf, crop=65, xlim_pix=(-3,3), ylim_pix=(-3,3),
                 shift_func=fshift, rin=0, rout=None, dxy_coarse=0.05,
                 dxy_fine=0.01, **kwargs):
    """Find offsets necessary to align observations with input psf

    Parameters
    ----------
    input : dict or iterable of ndarray
        Either a dictionary whose values hold 'data' and 'ap' entries,
        or an iterable of images.
    psf : ndarray
        PSF image to align to.
    crop : int
        Crop size applied to images and PSFs.
    xlim_pix, ylim_pix : tuple
        Range of offsets to probe along x and y.
    shift_func : callable
        Shifting function used to generate the offset PSFs.
    rin, rout : float or None
        Inner and outer radii of the annulus of pixels to use.
    dxy_coarse : float
        Step size of the grid of shifted PSFs.
    dxy_fine : float
        Step size of the interpolated correlation map.

    Returns
    -------
    xsh0_pix, ysh0_pix : ndarray
        Best-fit x and y offsets for each image.
    """

    # Check if input is a dictionary
    is_dict = isinstance(input, dict)

    res = gen_psf_offsets(psf, crop=crop, xlim_pix=xlim_pix, ylim_pix=ylim_pix,
                          dxy=dxy_coarse, shift_func=shift_func)
    xoff_pix, yoff_pix, psf_sh_all = res

    # Grid shape
    sh_grid = (len(yoff_pix), len(xoff_pix))

    # Cycle through each SGD position
    keys = list(input.keys()) if is_dict else None
    rin = 0 if rin is None else rin

    xsh0_pix = []
    ysh0_pix = []
    iter_vals = tqdm(keys) if is_dict else tqdm(input)
    for val in iter_vals:
        if is_dict:
            d = input[val]
            im = crop_observation(d['data'], d['ap'], crop)
        else:
            im = pad_or_cut_to_size(val, crop)

        # Create masks
        rdist = dist_image(im)
        rmask = (rdist>=rin) if rout is None else (rdist>=rin) & (rdist<=rout)

        # Exclude 0s and NaNs
        zmask = (im!=0) & (~np.isnan(im))
        ind_mask = rmask & zmask

        # Cross-correlate to find best x,y shift to align image with PSF
        cc = correl_images(psf_sh_all, im, mask=ind_mask)
        cc = cc.reshape(sh_grid)

        # Cubic interpolation of cross correlation image onto a finer grid
        xsh, ysh = find_max_crosscorr(cc, xoff_pix, yoff_pix, dxy_fine)

        xsh0_pix.append(xsh)
        ysh0_pix.append(ysh)

    xsh0_pix = np.array(xsh0_pix)
    ysh0_pix = np.array(ysh0_pix)

    return xsh0_pix, ysh0_pix


def find_offsets2(input, xoff_pix, yoff_pix, psf_sh_all, bpmasks=None, crop=65,
                  rin=0, rout=None, dxy_fine=0.01, prog_leave=True,
                  return_more=False, lsq_diff=False, **kwargs):
    """Find offsets necessary to align observations with input psf

    Parameters
    ----------
    input : dict or ndarray
        Either a dictionary whose values hold 'data' and 'ap' entries,
        a single 2D image, or a cube of images.
    xoff_pix, yoff_pix : ndarray
        Offsets along x and y at which `psf_sh_all` was generated.
    psf_sh_all : ndarray
        Cube of shifted PSFs ordered as the flattened (y, x) offset grid.
    bpmasks : ndarray or None
        Bad pixel mask(s), one per image. A single 2D mask is accepted.
    crop : int, tuple, or None
        Crop size. If None, uses the smaller of the image and PSF sizes.
    rin, rout : float or None
        Inner and outer radii of the annulus of pixels to use.
    dxy_fine : float
        Step size of the interpolated metric map.
    prog_leave : bool
        Leave the progress bar displayed after completion.
    return_more : bool
        Also return a dictionary, keyed by 0-based image index, holding
        the metric map and offset grids for each image.
    lsq_diff : bool
        Use the inverse sum of squared residuals from `subtract_psf` as
        the metric instead of the correlation coefficient.

    Returns
    -------
    xsh0_pix, ysh0_pix[, res_dict]
        Best-fit offsets (scalars for a single 2D input image, otherwise
        arrays), plus the diagnostics dictionary if `return_more` is set.
    """

    # Check if input is a dictionary
    is_dict = isinstance(input, dict)

    # Make sure input image is 3D
    if not is_dict and len(input.shape)==2:
        input2d = True
        input = [input]
    else:
        input2d = False

    if (bpmasks is not None) and (len(bpmasks.shape)==2):
        bpmasks = [bpmasks]

    # Grid shape
    sh_grid = (len(yoff_pix), len(xoff_pix))

    # Cycle through each SGD position; only show progress for multiple images
    keys = list(input.keys()) if is_dict else None
    if is_dict:
        iter_vals = keys if len(keys)==1 else tqdm(keys, leave=prog_leave)
    elif len(input)==1:
        iter_vals = input
    else:
        iter_vals = tqdm(input, leave=prog_leave)

    rin = 0 if rin is None else rin

    xsh0_pix = []
    ysh0_pix = []
    res_dict = {}
    for idx, val in enumerate(iter_vals):

        if crop is None:
            im0 = input[val]['data'] if is_dict else val
            ny1, nx1 = im0.shape
            ny2, nx2 = psf_sh_all.shape[-2:]
            ny_crop = np.min([ny1, ny2])
            nx_crop = np.min([nx1, nx2])
            crop = (ny_crop, nx_crop)

        # Crop the input image
        if is_dict:
            d = input[val]
            im = crop_observation(d['data'], d['ap'], crop)
        else:
            im = crop_image(val, crop)

        # Crop PSFs to match size
        psf_sh_crop = crop_image(psf_sh_all, crop)

        # Crop bp mask to match
        if bpmasks is None:
            bpmask = np.zeros_like(im).astype('bool')
        else:
            bpmask = crop_image(bpmasks[idx], crop)

        # Create masks
        rdist = dist_image(im)
        rmask = (rdist>=rin) if rout is None else (rdist>=rin) & (rdist<=rout)

        # Exclude 0s and NaNs in the image and in any of the shifted PSFs
        zmask = (im!=0) & (~np.isnan(im))
        nanmask_psf = (psf_sh_crop==0) | np.isnan(psf_sh_crop)
        zmask2 = np.sum(nanmask_psf, axis=0) == 0
        ind_mask = rmask & zmask & zmask2 & (~bpmask)

        if lsq_diff:
            # Least squares difference; everything outside ind_mask is flagged bad
            bpmask = ~ind_mask
            sum_sqrs = np.array([subtract_psf(im, psf, bpmask=bpmask, return_sum2=True)
                                 for psf in psf_sh_crop])
            correlation_metric = 1 / sum_sqrs.reshape(sh_grid)
        else:
            # Cross-correlate to find best (x,y) shift to align image with PSF
            cc = correl_images(psf_sh_crop, im, mask=ind_mask)
            correlation_metric = cc.reshape(sh_grid)

        # Cubic interpolation of cross correlation image onto a finer grid
        xsh, ysh = find_max_crosscorr(correlation_metric, xoff_pix, yoff_pix, dxy_fine)

        if return_more:
            res_dict[idx] = {'corr_map': correlation_metric,
                             'xoff_pix': xoff_pix, 'yoff_pix': yoff_pix}

        xsh0_pix.append(xsh)
        ysh0_pix.append(ysh)

    xsh0_pix = np.array(xsh0_pix)
    ysh0_pix = np.array(ysh0_pix)

    # If we had a single image input, return first elements
    if input2d:
        xsh0_pix = xsh0_pix[0]
        ysh0_pix = ysh0_pix[0]

    if return_more:
        return xsh0_pix, ysh0_pix, res_dict
    else:
        return xsh0_pix, ysh0_pix


def find_offsets_phase(input, psf, crop=65, rin=0, rout=None, dxy_fine=0.01,
                       prog_leave=False):
    """Use phase_cross_correlation to determine offset

    Returns offset (delx,dely) required to register input image[s]
    onto psf image.

    Parameters
    ----------
    input : dict or ndarray
        Either a dictionary whose values hold 'data' and 'ap' entries,
        a single 2D image, or a cube of images.
    psf : ndarray
        PSF image to register onto.
    crop : int
        Crop size applied to images and PSF.
    rin, rout : float or None
        Inner and outer radii of the annulus of pixels to use.
    dxy_fine : float
        Desired precision; `1/dxy_fine` is used as the upsample factor.
    prog_leave : bool
        If True, iterate with a progress bar.

    Returns
    -------
    ndarray
        Squeezed array of (x, y) offsets, one row per image.
    """

    # Check if input is a dictionary
    is_dict = isinstance(input, dict)

    # Make sure input image is 3D
    if not is_dict and len(input.shape)==2:
        input = [input]

    # Cycle through each SGD position
    keys = list(input.keys()) if is_dict else None

    # Ensure PSF is correct size
    psf_sub = crop_image(psf, crop, fill_val=0)

    rin = 0 if rin is None else rin
    upsample = 1/dxy_fine

    xsh0_pix = []
    ysh0_pix = []

    if prog_leave:
        iter_vals = tqdm(keys) if is_dict else tqdm(input)
    else:
        iter_vals = keys if is_dict else input

    for val in iter_vals:
        if is_dict:
            d = input[val]
            imfull = d['data']
            # Copy so that zeroing bad pixels does not touch the stored data
            im = crop_observation(imfull, d['ap'], crop).copy()
        else:
            imfull = val
            im = crop_image(imfull, crop, fill_val=0)

        # Create masks
        rdist = dist_image(im)
        rmask = (rdist>=rin) if rout is None else (rdist>=rin) & (rdist<=rout)

        # Exclude 0s and NaNs
        zmask = (im!=0) & (~np.isnan(im))
        ind_mask = rmask & zmask

        # Zero-out bad pixels
        im[~ind_mask] = 0

        # Initial offset required to move im onto psf_sub
        ysh, xsh = phase_cross_correlation(psf_sub, im, upsample_factor=upsample,
                                           return_error=False)

        # Shift PSF in opposite direction to register onto im.
        # We do this under the assumption that PSF is more ideal (no bad pixels)
        # compared to im, so there will be fewer Fourier artifacts after the shift.
        # Then find any residual necessary moves.
        psf_sh = pad_or_cut_to_size(fourier_imshift(psf, -1*xsh, -1*ysh), crop)
        del_ysh, del_xsh = phase_cross_correlation(psf_sh, im, upsample_factor=upsample,
                                                   return_error=False)
        xsh += del_xsh
        ysh += del_ysh

        xsh0_pix.append(xsh)
        ysh0_pix.append(ysh)

    xsh0_pix = np.array(xsh0_pix)
    ysh0_pix = np.array(ysh0_pix)

    res = np.array([xsh0_pix, ysh0_pix]).T
    return res.squeeze()


def find_pix_offsets(imsub_arr, psfs, psf_osamp=1, bpmask_arr=None, crop=None,
                     kipc=None, kppc=None, diffusion_sigma=None,
                     psf_corr_image=None, phase=False, xcorr=True,
                     lsq_diff=False, **kwargs):
    """Find number of pixels to offset PSFs to corresponding images

    If multiple methods are selected, then will return values for each
    in a dictionary. If only one method is selected, then will return
    a single array of offsets.

    Note that when `phase` is True, the cross-correlation and least
    squares branches are skipped, so their entries in a returned
    dictionary are empty.

    Parameters
    ----------
    imsub_arr : ndarray
        Array of cropped images
    psfs : ndarray
        Array of PSFs to align to images. Either same number of images
        or a single PSF to align to all images.
    psf_osamp : int
        Oversampling factor of PSFs
    bpmask_arr : ndarray
        Bad pixel mask array. Should be same shape as imsub_arr.
    crop : int or None
        Crop size. If None, defaults to 15 for phase and 21 otherwise.
        Enlarged if needed to be at least `rin + 20` and forced to be odd.
    diffusion_sigma : float
        Diffusion kernel sigma value to apply to psfs.
    kipc : ndarray
        IPC kernel to apply to PSFs.
    kppc : ndarray
        PPC kernel. Should already align to readout direction of detector
        along rows.
    psf_corr_image : ndarray
        Correction factor to multiply PSF after diffusion
    phase : bool
        Use phase cross-correlation to find offsets.
    xcorr : bool
        Perform traditional cross correlation to find offsets.
    lsq_diff : bool
        Use least squares difference to find offsets.

    Keyword Args
    ============
    rin : float
        Exclude pixels interior to this radius.
    rout : float or None
        Exclude pixel exterior to this radius.
    xylim_pix : tuple or list
        Initial coarse step range in detector pixels.

    Raises
    ------
    ValueError
        If none of `phase`, `xcorr`, or `lsq_diff` is selected.
    """

    def find_pix_phase(im, psf, psf_osamp, kipc=None, kppc=None,
                       diffusion_sigma=None, psf_corr_image=None, crop=15,
                       **kwargs):
        """Prepare a detector-sampled PSF and register via phase correlation"""

        # Rebin to detector sampling
        if psf_osamp!=1:
            psf = frebin(psf, scale=1/psf_osamp)

        # Add diffusion
        if (diffusion_sigma is not None) and (diffusion_sigma>0):
            psf = apply_pixel_diffusion(psf, diffusion_sigma)

        # Apply PSF correction image. Multiply out-of-place so the caller's
        # PSF is not modified (it may be reused for every image).
        if psf_corr_image is not None:
            psf = psf * crop_image(psf_corr_image, psf.shape[-2:], fill_val=1)

        # Add IPC
        if kipc is not None:
            psf = add_ipc(psf, kernel=kipc)

        # Add PPC
        if kppc is not None:
            psf = add_ppc(psf, kernel=kppc, nchans=1)

        rin = kwargs.get('rin', 0)
        rout = kwargs.get('rout', None)
        res = find_offsets_phase(im, psf, crop=crop, rin=rin, rout=rout, dxy_fine=0.001)

        return res

    def find_pix_cc(im, psf, psf_osamp, bpmask=None, crop=33, kipc=None,
                    kppc=None, diffusion_sigma=None, psf_corr_image=None,
                    lsq_diff=False, return_grids=False, **kwargs):
        """Cross correlate by shifting PSF in fine steps"""

        # Create a series of coarse offset PSFs to find initial estimate
        xylim_pix = kwargs.get('xylim_pix')
        if xylim_pix is not None:
            xlim_pix = ylim_pix = xylim_pix
        else:
            xlim_pix = ylim_pix = (-5,5)

        dxy_coarse = kwargs.pop('dxy_coarse', 0.250)
        dxy_fine = kwargs.pop('dxy_fine', 0.005)

        # Reuse a previously computed coarse grid if one was supplied
        res_coarse = kwargs.get('res_coarse', None)
        if res_coarse is None:
            res_coarse = gen_psf_offsets(psf, crop=crop, xlim_pix=xlim_pix, ylim_pix=ylim_pix,
                                         dxy=dxy_coarse, psf_osamp=psf_osamp,
                                         kipc=None, kppc=None, diffusion_sigma=None,
                                         psf_corr_image=psf_corr_image, prog_leave=False,
                                         shift_func=fshift, **kwargs)

        xoff_pix, yoff_pix, psf_sh_all = res_coarse

        # psf_sh_all are cropped to `crop` value, whereas im is still input size
        xsh_coarse, ysh_coarse = find_offsets2(im, xoff_pix, yoff_pix, psf_sh_all,
                                               bpmasks=bpmask, crop=crop,
                                               dxy_fine=dxy_coarse, prog_leave=False,
                                               **kwargs)

        # Create finer grid of offset PSFs
        xlim_pix = (xsh_coarse-dxy_coarse/2, xsh_coarse+dxy_coarse/2)
        ylim_pix = (ysh_coarse-dxy_coarse/2, ysh_coarse+dxy_coarse/2)
        res2 = gen_psf_offsets(psf, crop=crop, xlim_pix=xlim_pix, ylim_pix=ylim_pix,
                               dxy=dxy_fine, psf_osamp=psf_osamp, kipc=kipc, kppc=kppc,
                               diffusion_sigma=diffusion_sigma,
                               psf_corr_image=psf_corr_image, prog_leave=False,
                               **kwargs)
        xoff_pix, yoff_pix, psf_sh_all = res2

        # Perform cross correlations and interpolate at 0.001 pixel
        xsh_fine, ysh_fine = find_offsets2(im, xoff_pix, yoff_pix, psf_sh_all,
                                           bpmasks=bpmask, crop=crop, dxy_fine=0.001,
                                           lsq_diff=lsq_diff, prog_leave=False,
                                           **kwargs)

        res = (xsh_fine, ysh_fine)
        if return_grids:
            return res, res_coarse
        else:
            return res

    if not (phase or xcorr or lsq_diff):
        raise ValueError("At least one of phase, xcorr, or lsq_diff must be True")

    sh_orig = imsub_arr.shape
    sh_orig_psfs = psfs.shape
    if len(sh_orig)==2:
        imsub_arr = [imsub_arr]
        bpmask_arr = [bpmask_arr]
        psfs = [psfs]
    elif len(sh_orig_psfs)==2:
        psfs = [psfs]

    nimages = len(imsub_arr)
    # Use one PSF per image only when the counts match; comparing the full
    # shapes would wrongly reject oversampled PSF cubes.
    psf_per_image = len(psfs)==nimages
    # The coarse PSF grid can be shared only if a single PSF serves all images
    share_coarse = len(sh_orig)==3 and len(sh_orig_psfs)==2

    if crop is None:
        crop = 15 if phase else 21
    # Ensure crop is at least 20 pixels larger than rin
    rin = kwargs.get('rin', 0)
    rin = 0 if rin is None else rin
    if crop-rin < 20:
        crop = rin + 20
    # Ensure crop is odd
    if np.mod(crop, 2)==0:
        crop += 1

    xysh_pix_phase = []
    xysh_pix_cc = []
    xysh_pix_lsq = []

    if nimages>=10:
        iter_vals = trange(nimages, desc='Image Alignment', leave=False)
    else:
        iter_vals = range(nimages)
    for i in iter_vals:
        im = imsub_arr[i]
        # If only a single PSF was passed, then use it for all images
        psf = psfs[i] if psf_per_image else psfs[0]

        if phase:
            res = find_pix_phase(im, psf, psf_osamp, kipc=kipc, kppc=kppc,
                                 diffusion_sigma=diffusion_sigma,
                                 psf_corr_image=psf_corr_image, crop=crop, **kwargs)
            xysh_pix_phase.append(res)
        elif xcorr or lsq_diff:
            # Only set to return grid on first iteration
            return_grids = share_coarse and i==0

            try:
                bpmask = bpmask_arr[i]
            except TypeError:
                # bpmask_arr is None
                bpmask = None

            # Cross-correlation first, then least squares difference
            for use_lsq, selected, results in ((False, xcorr, xysh_pix_cc),
                                               (True, lsq_diff, xysh_pix_lsq)):
                if not selected:
                    continue
                res = find_pix_cc(im, psf, psf_osamp, bpmask=bpmask, crop=crop,
                                  kipc=kipc, kppc=kppc, diffusion_sigma=diffusion_sigma,
                                  psf_corr_image=psf_corr_image, lsq_diff=use_lsq,
                                  return_grids=return_grids, **kwargs)
                # Set res_coarse going forward
                if return_grids and i==0:
                    res, res_coarse = res
                    kwargs['res_coarse'] = res_coarse
                    return_grids = False
                results.append(res)

    # A single 2D input image returns a single set of offsets
    if len(sh_orig)==2 and len(xysh_pix_phase)>0:
        xysh_pix_phase = np.asarray(xysh_pix_phase[0])
    if len(sh_orig)==2 and len(xysh_pix_cc)>0:
        xysh_pix_cc = np.asarray(xysh_pix_cc[0])
    if len(sh_orig)==2 and len(xysh_pix_lsq)>0:
        xysh_pix_lsq = np.asarray(xysh_pix_lsq[0])

    if phase + xcorr + lsq_diff > 1:
        res = {}
        if phase:
            res['phase'] = xysh_pix_phase
        if xcorr:
            res['xcorr'] = xysh_pix_cc
        if lsq_diff:
            res['lsqdiff'] = xysh_pix_lsq
    elif phase:
        res = xysh_pix_phase
    elif xcorr:
        res = xysh_pix_cc
    else:
        res = xysh_pix_lsq

    return res


###########################################################################
#    MAST and Guidestar Catalog Retrieval
###########################################################################

def download_file(filename, outdir=None, timeout=None, mast_api_token=None,
                  overwrite=False, verbose=False):
    """ Download a MAST file

    Modified from M. Perrin's tools:
    https://github.com/mperrin/misc_jwst/blob/main/misc_jwst/guiding_analyses.py

    Parameters
    ----------
    filename : str
        Name of file to download
    outdir : str
        Output directory
    timeout : float
        Timeout in seconds to wait for download to start
    mast_api_token : str
        MAST API token. Defaults to the MAST_API_TOKEN environment variable.
    overwrite : bool
        Overwrite existing file?
    verbose : bool
        Print extra info?

    Returns
    -------
    requests.Response or None
        The (closed) response object, or None if the file already existed
        and `overwrite` is False.

    Raises
    ------
    Exception
        If the request fails and no alternate token from the MAST_API_TOKEN
        or MAST_API_TOKEN2 environment variables succeeds.
    """
    import requests, io
    from astropy.utils.console import ProgressBarOrSpinner
    from astropy.utils.data import conf

    blocksize = conf.download_block_size

    outpath = os.path.join(outdir, filename) if outdir is not None else filename

    if os.path.isfile(outpath) and (not overwrite):
        if verbose:
            print("ALREADY DOWNLOADED: ", outpath)
        return

    mast_url = 'https://mast.stsci.edu/api/v0.1/Download/file'
    uri_prefix = 'mast:JWST/product/'
    uri = uri_prefix + filename

    # Include MAST API token
    if mast_api_token is None:
        mast_api_token = os.environ.get('MAST_API_TOKEN')
    headers = None if mast_api_token is None else dict(Authorization=f"token {mast_api_token}")

    response = requests.get(mast_url, params=dict(uri=uri), timeout=timeout,
                            stream=True, headers=headers)
    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError as exc:
        # Retry with any alternate tokens that differ from the one just used
        alt_tokens = [os.environ.get('MAST_API_TOKEN'), os.environ.get('MAST_API_TOKEN2')]
        success = False
        for token in alt_tokens:
            if (token is None) or (token == mast_api_token):
                continue
            _log.info(f'Attempting alternate MAST_API_TOKEN...')
            response.close()
            headers = dict(Authorization=f"token {token}")
            response = requests.get(mast_url, params=dict(uri=uri), timeout=timeout,
                                    stream=True, headers=headers)
            try:
                response.raise_for_status()
            except requests.exceptions.HTTPError:
                continue
            else:
                success = True
                break
        if not success:
            # Never fall through and write an error page to disk
            response.close()
            raise Exception(exc) from exc

    try:
        # Full URL of data product (for messages only)
        url = f'{mast_url}?uri={uri}'

        if 'content-length' in response.headers:
            length = int(response.headers['content-length'])
            if length == 0:
                _log.warning(f'URL {url} has length=0')
        else:
            length = None

        # Only show progress bar if logging level is INFO or lower.
        if _log.getEffectiveLevel() <= logging.INFO:
            progress_stream = None  # Astropy default
        else:
            progress_stream = io.StringIO()

        bytes_read = 0
        msg = f'Downloading URL {url} to {outpath} ...'
        with ProgressBarOrSpinner(length, msg, file=progress_stream) as pb:
            with open(outpath, 'wb') as fd:
                for data in response.iter_content(chunk_size=blocksize):
                    fd.write(data)
                    bytes_read += len(data)
                    if length is not None:
                        # Do not report more than the advertised length
                        pb.update(bytes_read if bytes_read <= length else length)
                    else:
                        pb.update(bytes_read)
    finally:
        response.close()

    return response


def retrieve_mast_files(filenames, outdir=None, verbose=False, **kwargs):
    """Download one or more guiding data products from MAST

    Modified from M. Perrin's tools:
    https://github.com/mperrin/misc_jwst/blob/main/misc_jwst/guiding_analyses.py

    Parameters
    ----------
    filenames : list of str
        Names of the files to download.
    outdir : str or None
        Output directory. If None, files are written relative to the
        current working directory.
    verbose : bool
        Print extra info?

    Keyword Args
    ============
    Additional keywords are passed to `download_file`.

    Returns
    -------
    list of str
        Paths of the files that exist on disk after the download attempts.
    """

    outputs = []

    for f in filenames:
        download_file(f, outdir=outdir, verbose=verbose, **kwargs)

        # Check if files exist and append to outputs
        outfile = os.path.join(outdir, f) if outdir is not None else f
        if not os.path.isfile(outfile):
            print("ERROR: " + outfile + " failed to download.")
        else:
            if verbose:
                print("COMPLETE: ", outfile)
            outputs.append(outfile)

    return outputs


def set_params(parameters):
    """Utility function for making dicts used in MAST queries"""
    return [{"paramName":p, "values":v} for p,v in parameters.items()]


import functools

@functools.lru_cache
def find_relevant_guiding_file(sci_filename, outdir=None, verbose=False,
                               uncals=False, **kwargs):
    """ Download fine guide file for a given science data product

    Given a filename of a JWST science file, retrieve the relevant guiding
    data product. This uses FITS keywords in the science header to determine
    the time period and guide mode, and then retrieves the file from MAST.

    Results are cached per set of arguments via `functools.lru_cache`.

    Modified from M. Perrin's tools:
    https://github.com/mperrin/misc_jwst/blob/main/misc_jwst/guiding_analyses.py

    Parameters
    ----------
    sci_filename : str
        Path to the science FITS file.
    outdir : str or None
        Output directory. If None, uses the 'fgs' subdirectory of
        $JWSTDOWNLOAD_OUTDIR/<program> when that environment variable is
        set; otherwise files are written relative to the current directory.
    verbose : bool
        Print extra info?
    uncals : bool
        Retrieve the '_uncal' versions of the guiding files.

    Returns
    -------
    list of str
        Paths of the downloaded guiding files.

    Raises
    ------
    RuntimeError
        If the MAST query returns no guiding files.
    IndexError
        If no guiding file has a timestamp at or after the exposure end.
    """

    from astropy.time import Time
    from astroquery.mast import Mast

    # Read everything needed from the header, then close the file
    with fits.open(sci_filename) as sci_hdul:
        header = sci_hdul[0].header
        progid = header['PROGRAM']
        obs = header['OBSERVTN']
        guidemode = header['PCS_MODE']
        date_end = header['DATE-END']

    # Set output directory if it wasn't specified
    if outdir is None:
        mast_dir = os.getenv('JWSTDOWNLOAD_OUTDIR', None)
        if mast_dir is not None:
            outdir = os.path.join(mast_dir, progid, 'fgs')

    # Create directory if it doesn't exist
    if outdir is not None:
        os.makedirs(outdir, exist_ok=True)

    # Set up the query
    keywords = {
        'program': [progid],
        'observtn': [obs],
        'exp_type': ['FGS_'+guidemode],
    }
    params = {
        'columns': '*',
        'filters': set_params(keywords),
    }

    # Run the web service query. This uses the specialized, lower-level webservice for the
    # guidestar queries: https://mast.stsci.edu/api/v0/_services.html#MastScienceInstrumentKeywordsGuideStar
    service = 'Mast.Jwst.Filtered.GuideStar'
    t = Mast.service_request(service, params)

    if len(t) == 0:
        raise RuntimeError(f"No guiding files found in MAST for {sci_filename}")

    # Ensure unique file names, should any be repeated over multiple
    # observations (e.g. if parallels)
    file_names = list(set(t['fileName']))
    if uncals:
        products = list(set([x.replace('_cal','_uncal') for x in file_names]))
    else:
        products = file_names
    products.sort()

    if verbose:
        print(f"For science data file: {sci_filename}")
        print("Found guiding telemetry files:")
        for p in products:
            print("   ", p)

    # Some guide files are split into multiple segments, which we have to deal with.
    guide_timestamp_parts = [p.split('_')[2] for p in products]
    is_segmented = ['seg' in part for part in guide_timestamp_parts]
    for i, part in enumerate(guide_timestamp_parts):
        if is_segmented[i]:
            guide_timestamp_parts[i] = part.split('-')[0]
    guide_timestamps = np.asarray(guide_timestamp_parts, int)

    t_end = Time(date_end)
    obs_end_time = int(t_end.strftime('%Y%j%H%M%S'))

    delta_times = np.array(guide_timestamps-obs_end_time, float)
    # want to find the minimum delta which is at least positive
    if not np.any(delta_times >= 0):
        raise IndexError(f"No guiding file with timestamp after exposure end for {sci_filename}")
    delta_times_nan = delta_times.copy()
    delta_times_nan[delta_times<0] = np.nan
    wmatch = np.where(delta_times_nan == np.nanmin(delta_times_nan))[0][0]

    delta_min = (guide_timestamps-obs_end_time)[wmatch]

    if verbose:
        print("Based on science DATE-END keyword and guiding timestamps, the matching GS file is: ")
        print("   ", products[wmatch])
        print(f"    t_end = {obs_end_time}\t delta = {delta_min}")

    if is_segmented[wmatch]:
        # We ought to fetch all the segmented GS files for that guide period
        prefix = products[wmatch][0:33]
        products_to_fetch = [p for p in products if p.startswith(prefix)]
        if verbose:
            print("   That GS data is divided into multiple segment files:")
            print("   ".join(products_to_fetch))
    else:
        products_to_fetch = [products[wmatch],]

    outfiles = retrieve_mast_files(products_to_fetch, outdir=outdir, verbose=verbose)

    return outfiles


def get_jitter_balls(files_sci, indir, outdir=None, verbose=False, return_raw=False):
    """ Get jitter ball positions from guiding files

    Find the jitter ball positions from the guiding files associated
    with a science file. By default, downloads FGS fine guide files to
    MAST output directory if it exists and places into 'fgs' subdirectory.
    Otherwise, downloads to current working directory.

    Returns `(xoff_all, yoff_all)` lists of x and y positions for each
    science file. Values are in units of arcsec relative to the first
    science file.

    Parameters
    ----------
    files_sci : list
        List of science file names
    indir : str
        Input directory of science files
    outdir : str
        Output directory for downloaded guiding files
    verbose : bool
        Print extra info during download?
    return_raw : bool
        Return raw xidl and yidl values instead of relative offsets?
        Default is False

    Raises
    ------
    RuntimeError
        If no guiding files could be retrieved for a science file.
    """

    from astropy.table import Table, vstack
    from astropy.time import Time

    xidl_all = []
    yidl_all = []

    for sci_filename in files_sci:
        fpath = os.path.join(indir, sci_filename)

        # Find guidestar files and read in centroid data as astropy Table
        gs_files = find_relevant_guiding_file(fpath, outdir=outdir, verbose=verbose)
        if len(gs_files) == 0:
            raise RuntimeError(f"No guiding files available for {fpath}")
        tables = [Table.read(gs_fn, hdu=5) for gs_fn in gs_files]
        if len(tables) == 1:
            centroid_table = tables[0]
        else:
            centroid_table = vstack(tables, metadata_conflicts='silent')

        # Determine start and end times for the exposure
        with fits.open(fpath) as sci_hdul:
            t_beg = Time(sci_hdul[0].header['DATE-BEG'])
            t_end = Time(sci_hdul[0].header['DATE-END'])

        # Find the subset of good centroid data during exposure
        ctimes = Time(centroid_table['observatory_time'])
        mask_good = centroid_table['bad_centroid_dq_flag'] == 'GOOD'
        ctimes_during_exposure = (t_beg < ctimes) & (ctimes < t_end) & mask_good

        xpos = centroid_table[ctimes_during_exposure]['guide_star_position_x']
        ypos = centroid_table[ctimes_during_exposure]['guide_star_position_y']

        xidl_all.append(xpos)
        yidl_all.append(ypos)

    if return_raw:
        return xidl_all, yidl_all
    else:
        # Subtract nominal position (mean of the first science file)
        xmean0 = np.mean(xidl_all[0])
        ymean0 = np.mean(yidl_all[0])
        xoff_all = [(x - xmean0) for x in xidl_all]
        yoff_all = [(y - ymean0) for y in yidl_all]

        return xoff_all, yoff_all


@plt.style.context('webbpsf_ext.wext_style')
def plot_jitter_balls(xoff_all, yoff_all, sci_filename=None, fov_size=50,
                      save=False, save_dir=None, return_fixaxes=False):
    """ Plot jitter ball positions

    Parameters
    ----------
    xoff_all, yoff_all : list of arrays
        Lists of x and y offsets, one array per exposure. Values are
        multiplied by 1000 and labeled as mas in the plot.
    sci_filename : str or None
        Science file name used for the figure title and the saved file name.
    fov_size : float
        Full width of the plotted region in mas.
    save : bool
        Save the figure as a PDF?
    save_dir : str or None
        Directory in which to save the figure.
    return_fixaxes : bool
        If True, return `(fig, (ax, ax_histx, ax_histy))`.

    Raises
    ------
    ValueError
        If `xoff_all` or `yoff_all` is not a list.
    """

    # Check that xoff_all and yoff_all are lists
    if not isinstance(xoff_all, list) or not isinstance(yoff_all, list):
        raise ValueError("xoff_all and yoff_all must be a list of arrays")

    # Convert to mas
    xoff_all = [x*1000 for x in xoff_all]
    yoff_all = [y*1000 for y in yoff_all]

    xoff_mean = np.array([np.mean(x) for x in xoff_all])
    yoff_mean = np.array([np.mean(y) for y in yoff_all])

    # Create Plots
    fig = plt.figure(figsize=(8,8), layout='constrained')

    # Create axes for scatter plot
    ax = fig.add_gridspec(top=0.75, right=0.75).subplots()
    ax.set_aspect('equal')

    # Create axes for histograms
    ax_histx = ax.inset_axes([0, 1.01, 1, 0.25], sharex=ax)
    ax_histy = ax.inset_axes([1.01, 0, 0.25, 1], sharey=ax)

    for xoffsets, yoffsets in zip(xoff_all, yoff_all):
        ax.scatter(xoffsets, yoffsets, alpha=0.1, marker='.', s=1)

    # Center the field of view on the first exposure
    xylim = np.array([-1,1]) * fov_size/2
    ax.set_xlim(xylim + xoff_mean[0])
    ax.set_ylim(xylim + yoff_mean[0])

    ax.set_xlabel("FGS Centroid Offset XIDL [mas]")
    ax.set_ylabel("FGS Centroid Offset YIDL [mas]")

    sci_filename_act = None
    if sci_filename is not None:
        sci_filename_act = '_'.join(os.path.basename(sci_filename).split('_')[0:2])
        fig.suptitle(f"Guiding during {sci_filename_act}_*", fontsize=14)
    else:
        fig.suptitle("Guiding during science exposures", fontsize=14)

    # Reference circles around the mean position of each exposure
    for i in range(len(xoff_all)):
        xc, yc = (xoff_mean[i], yoff_mean[i])
        for j, rad in enumerate([1,2,3]):
            ax.add_artist(plt.Circle((xc, yc), rad, fill=False, color='gray', ls='--'))
            if rad<fov_size/2 and i==0:
                ax.text(j*0.5, rad+0.1, f"{rad} mas", color='gray')

    # Draw histograms
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)

    bsize = 0.2
    nbins = int(fov_size / bsize)
    xbins = np.linspace(xoff_mean[0]-fov_size/2, xoff_mean[0]+fov_size/2, nbins)
    ybins = np.linspace(yoff_mean[0]-fov_size/2, yoff_mean[0]+fov_size/2, nbins)
    for xoffsets, yoffsets in zip(xoff_all, yoff_all):
        ax_histx.hist(xoffsets, bins=xbins, alpha=0.8)
        ax_histy.hist(yoffsets, bins=ybins, orientation='horizontal', alpha=0.8)

    if save:
        # Fall back to a generic name when no science file name was given
        if sci_filename_act is not None:
            figname = f'guiding_{sci_filename_act}.pdf'
        else:
            figname = 'guiding_jitter.pdf'
        if save_dir is not None:
            figname = os.path.join(save_dir, figname)
        fig.savefig(figname, bbox_inches='tight')
        print(f"Saved {figname}")

    if return_fixaxes:
        return fig, (ax, ax_histx, ax_histy)