# Source code for starred.plots.plot_function

import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors
from matplotlib.widgets import Slider
from copy import deepcopy
from astropy.visualization import simple_norm

from starred.utils.generic_utils import Downsample


# Diverging colormap shared by every imshow panel that displays residuals, so that
# positive and negative deviations are drawn with opposite colours.
CMAP_RR: str = 'RdBu_r'


def single_PSF_plot(model, data, sigma_2, kwargs, n_psf=0, figsize=(15, 8), units=None, upsampling=None,
                    masks=None, mask_alpha=0.7, star_coordinates=None, n_sigma=5):
    """
    Plots the narrow PSF fit for a single observation.

    The figure has six panels: the data, the full PSF model, the residuals normalised by the noise,
    the Moffat component, the pixel-grid (background) component and the narrow PSF.

    :param model: PSF model object providing ``model``, ``get_moffat``, ``get_narrow_psf`` and
        ``get_background``
    :param data: array containing the observations, indexed as ``data[n_psf, :, :]``
    :param sigma_2: array containing the square of the noise maps, indexed like ``data``
    :param kwargs: dictionary containing the parameters of the model; must contain the keys
        ``'kwargs_moffat'`` and ``'kwargs_background'``
    :param n_psf: selected PSF index
    :type n_psf: int
    :param figsize: tuple that indicates the size of the figure
    :param units: units in which the pixel values are expressed
    :type units: str
    :param upsampling: upsampling factor used to degrade the Moffat, pixel-grid and narrow PSF panels
        to the image resolution. Leave to None to show the higher resolution model.
    :type upsampling: int
    :param masks: masks with the same shape as ``data``; pixels where the mask is 0 (or False) are
        drawn with opacity ``mask_alpha`` in the residual panel
    :param mask_alpha: opacity used for the masked pixels. Default is 0.7.
    :type mask_alpha: float
    :param star_coordinates: array of shape (N, 2), where N is the number of stamps in data, each row
        contains (x, y) coordinates, in pixels, with center the middle of the original astronomical
        image. Default None (all stars at the origin).
    :param n_sigma: the colour scale of the residual panel spans [-n_sigma, +n_sigma]. Default is 5.
    :type n_sigma: float
    :return: output figure
    """
    str_unit = '[' + units + ']' if units is not None else ''

    if masks is not None:
        # Work on a float copy: writing mask_alpha into a boolean or integer array in place would
        # store True (1) or 0 instead of the requested opacity.
        alphas = np.array(masks, dtype=float)
        alphas[alphas == 0] = mask_alpha
    else:
        alphas = np.ones_like(data)

    if star_coordinates is None:
        star_coordinates = np.zeros((len(data), 2))

    estimated_full_psf = model.model(**kwargs, positions=star_coordinates)[n_psf]
    analytic = model.get_moffat(kwargs['kwargs_moffat'], norm=True)
    s = model.get_narrow_psf(**kwargs, position=star_coordinates[n_psf], norm=True)
    background = model.get_background(kwargs['kwargs_background'])
    if upsampling is not None:
        # Bring the high-resolution components back to the pixel scale of the data.
        analytic = Downsample(analytic, factor=upsampling)
        s = Downsample(s, factor=upsampling)
        background = Downsample(background, factor=upsampling)

    dif = data[n_psf, :, :] - estimated_full_psf
    rr = dif / np.sqrt(sigma_2[n_psf, :, :])

    fig, axs = plt.subplots(2, 3, figsize=figsize)
    fraction = 0.046
    pad = 0.04
    font_size = 14
    ticks_size = 6
    plt.rc('font', size=font_size)

    panel_titles = {
        (0, 0): 'Data %s' % str_unit,
        (0, 1): 'PSF model %s' % str_unit,
        (0, 2): 'Map of relative residuals',
        (1, 0): 'Moffat',
        (1, 1): 'Grid of pixels',
        (1, 2): 'Narrow PSF',
    }
    for (row, col), title in panel_titles.items():
        axs[row, col].set_title(title, fontsize=font_size)
        axs[row, col].tick_params(axis='both', which='major', labelsize=ticks_size)

    fig.colorbar(axs[0, 0].imshow(data[n_psf, :, :], norm=colors.SymLogNorm(linthresh=100), origin='lower'),
                 ax=axs[0, 0], fraction=fraction, pad=pad, format='%.0e')
    fig.colorbar(axs[0, 1].imshow(estimated_full_psf, norm=colors.SymLogNorm(linthresh=100), origin='lower'),
                 ax=axs[0, 1], fraction=fraction, pad=pad, format='%.0e')
    fig.colorbar(axs[0, 2].imshow(rr, origin='lower', alpha=alphas[n_psf, :, :], cmap=CMAP_RR,
                                  vmin=-n_sigma, vmax=n_sigma),
                 ax=axs[0, 2], fraction=fraction, pad=pad)
    fig.colorbar(axs[1, 0].imshow(analytic, norm=colors.SymLogNorm(linthresh=1e-2), origin='lower'),
                 ax=axs[1, 0], fraction=fraction, pad=pad)
    fig.colorbar(axs[1, 1].imshow(background, origin='lower'), ax=axs[1, 1], fraction=fraction, pad=pad)
    fig.colorbar(axs[1, 2].imshow(s, norm=colors.SymLogNorm(linthresh=1e-3), origin='lower'),
                 ax=axs[1, 2], fraction=fraction, pad=pad)

    for ax in np.array(axs).flatten():
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    plt.tight_layout()
    return fig
def multiple_PSF_plot(model, data, sigma_2, kwargs, star_coordinates=None, masks=None, mask_alpha=0.7,
                      figsize=None, units=None, vmin=None, vmax=None, n_sigma=5):
    """
    Plots the narrow PSF fit for all observations.

    The top row shows the PSF model of each stamp, and the bottom row shows the residuals normalised
    by the noise.

    :param model: PSF model object providing ``model`` and the number of stamps ``M``
    :param data: array containing the observations, indexed as ``data[i, :, :]``
    :param sigma_2: array containing the square of the noise maps, indexed like ``data``
    :param kwargs: dictionary containing the parameters of the model
    :param star_coordinates: array of shape (N, 2), containing the pixel coordinates of each star,
        with N the number of stamps (data.shape[0]); coordinates are relative to the center of the
        original astronomical image. Default None (all stars at the origin).
    :param masks: masks with the same shape as ``data``; pixels where the mask is 0 (or False) are
        drawn with opacity ``mask_alpha`` in the residual panels
    :param mask_alpha: opacity used for the masked pixels. Default is 0.7.
    :type mask_alpha: float
    :param figsize: tuple that indicates the size of the figure; defaults to (12 + 2 * n_stamps, 10)
    :param units: units in which the pixel values are expressed
    :type units: str
    :param vmin: lower limit for displaying the residuals (in unit of noise sigma)
    :type vmin: float
    :param vmax: upper limit for displaying the residuals (in unit of noise sigma)
    :type vmax: float
    :param n_sigma: when neither ``vmin`` nor ``vmax`` is given, the residuals are displayed in the
        range [-n_sigma, +n_sigma]. Default is 5.
    :type n_sigma: float
    :return: output figure
    """
    str_unit = '[' + units + ']' if units is not None else ''

    if figsize is None:
        nimage = np.shape(data)[0]
        figsize = (12 + nimage * 2, 10)

    fig, axs = plt.subplots(2, model.M, figsize=figsize)
    plt.subplots_adjust(wspace=0.3)
    if model.M == 1:
        # plt.subplots returns a 1-D array for a single column; make it (2, 1).
        axs = np.asarray([axs]).T
    fraction = 0.046
    pad = 0.04
    font_size = 14
    plt.rc('font', size=12)
    fmt_PSF = '%.0e'
    fmt_residuals = '%2.f'

    if masks is not None:
        # Float copy so mask_alpha is not truncated by a boolean/integer mask dtype.
        alphas = np.array(masks, dtype=float)
        alphas[alphas == 0] = mask_alpha
        kargs = [{'alpha': alphas[i, :, :]} for i in range(model.M)]
    else:
        kargs = [{} for _ in range(model.M)]

    if star_coordinates is None:
        star_coordinates = np.zeros((data.shape[0], 2))

    # User-provided limits win; the symmetric n_sigma range is only a fallback when neither is set.
    for ka in kargs:
        if vmin is not None:
            ka['vmin'] = vmin
        if vmax is not None:
            ka['vmax'] = vmax
        elif vmin is None:
            ka['vmin'] = -n_sigma
            ka['vmax'] = +n_sigma

    all_estimated_full_psf = model.model(**kwargs, positions=star_coordinates)
    for i in range(model.M):
        estimated_full_psf = all_estimated_full_psf[i]
        axs[0, i].set_title('PSF model %i %s' % (i + 1, str_unit), fontsize=font_size)
        axs[0, i].tick_params(axis='both', which='major', labelsize=10)
        axs[1, i].set_title('Relative residuals %i' % (i + 1), fontsize=font_size)
        axs[1, i].tick_params(axis='both', which='major', labelsize=10)
        fig.colorbar(axs[0, i].imshow(estimated_full_psf, norm=colors.SymLogNorm(linthresh=100), origin='lower'),
                     ax=axs[0, i], fraction=fraction, pad=pad, format=fmt_PSF)
        residuals = (data[i, :, :] - estimated_full_psf) / np.sqrt(sigma_2[i, :, :])
        fig.colorbar(axs[1, i].imshow(residuals, origin='lower', cmap=CMAP_RR, **kargs[i]),
                     ax=axs[1, i], fraction=fraction, pad=pad, format=fmt_residuals)

    for ax in np.array(axs).flatten():
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    return fig
def display_data(data, sigma_2=None, masks=None, figsize=None, units=None, center=None, mask_alpha=0.9):
    """
    Plots the observations and the noise maps.

    :param data: array containing the observations, of shape (n_image, n_rows, n_cols)
    :param sigma_2: array containing the square of the noise maps; when given, a second row shows
        ``sqrt(sigma_2)`` for each observation
    :param masks: masks with the same shape as ``data``; pixels where the mask is 0 (or False) are
        drawn with opacity ``mask_alpha``
    :param figsize: tuple that indicates the size of the figure; defaults to (12 + 2 * n_image, 10)
    :param units: units in which the pixel values are expressed
    :type units: str
    :param center: pair (c_x, c_y) of sequences giving, for each observation, the x and y offsets in
        pixels of a point relative to the middle of the stamp; a red cross is drawn there
    :param mask_alpha: opacity used for the masked pixels. Default is 0.9.
    :type mask_alpha: float
    :return: output figure
    """
    show_sigma = sigma_2 is not None
    row = 2 if show_sigma else 1

    data = np.asarray(data)
    n_image, n_rows, n_cols = np.shape(data)
    if figsize is None:
        figsize = (12 + n_image * 2, 10)

    str_unit = '[' + units + ']' if units is not None else ''

    fig, axs = plt.subplots(row, n_image, figsize=figsize)
    plt.subplots_adjust(wspace=0.3)
    # Normalise axs to a 2-D (row, n_image) array whatever the grid size.
    if row == 1 and n_image == 1:
        axs = np.asarray([[axs]])
    elif row == 1:
        axs = np.asarray([axs])
    elif n_image == 1:
        axs = np.asarray([axs]).T
    fraction = 0.046
    pad = 0.04
    fontsize = 12

    if masks is not None:
        # Float copy so mask_alpha is not truncated by a boolean/integer mask dtype.
        alphas = np.array(masks, dtype=float)
        alphas[alphas == 0] = mask_alpha
        kargs = [{'alpha': alphas[i, :, :]} for i in range(n_image)]
    else:
        kargs = [{} for _ in range(n_image)]

    plt.rc('font', size=12)
    for i in range(n_image):
        axs[0][i].set_title('Data %i %s' % (i + 1, str_unit), fontsize=fontsize)
        axs[0][i].tick_params(axis='both', which='major', labelsize=10)
        if show_sigma:
            axs[1][i].set_title('Noise map %i %s' % (i + 1, str_unit), fontsize=fontsize)
            axs[1][i].tick_params(axis='both', which='major', labelsize=10)
        fig.colorbar(axs[0][i].imshow(data[i, :, :], norm=colors.SymLogNorm(linthresh=10), origin='lower',
                                      **kargs[i]),
                     ax=axs[0][i], fraction=fraction, pad=pad, format='%.0e')
        if center is not None:
            c_x, c_y = center[0], center[1]
            # x runs along the columns and y along the rows of the stamp. matplotlib puts pixel
            # centres on integer coordinates, so the geometric middle of the stamp is at size/2 - 0.5.
            axs[0][i].scatter(n_cols / 2. + c_x[i] - 0.5, n_rows / 2. + c_y[i] - 0.5, marker='x', c='r')
        if show_sigma:
            fig.colorbar(
                axs[1][i].imshow(np.sqrt(sigma_2[i, :, :]), norm=colors.SymLogNorm(linthresh=10), origin='lower',
                                 **kargs[i]),
                ax=axs[1][i], fraction=fraction, pad=pad, format='%2.f')

    for ax in axs.flatten():
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    plt.tight_layout()
    return fig
def dict_to_kwargs_list(dict):
    """
    Transform a dictionary of sequences into a list of keyword-argument dictionaries.

    Entry ``i`` of the result maps every key to ``dict[key][i]``. All values should have the same
    length; the number of entries is taken from the first key.

    Example: ``{'a': [1, 2], 'b': [3, 4]}`` -> ``[{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]``.

    :param dict: dictionary whose values are sequences of equal length (the parameter name shadows the
        builtin and is kept for backward compatibility with keyword callers)
    :return: list of dictionaries, one per index; an empty list if ``dict`` is empty
    """
    mapping = dict
    if not mapping:
        return []
    keys = list(mapping)
    n_entries = len(mapping[keys[0]])
    return [{key: mapping[key][i] for key in keys} for i in range(n_entries)]
def plot_deconvolution(model, data, sigma_2, s, kwargs, epoch=0, units=None, figsize=(15, 10), cut_dict=None):
    """
    Plots the results of the deconvolution.

    :param model: deconvolution model providing ``model(kwargs)`` and ``getDeconvolved(kwargs, epoch)``
    :param data: array containing the observations. Has shape (n_epoch, n_pixel, n_pixel).
    :param sigma_2: array containing the square of the noise maps (n_epoch, n_pixel, n_pixel).
    :param s: array containing the narrow PSF (n_epoch, n_pixel*subsampling factor,
        n_pixel*subsampling factor).
    :param kwargs: dictionary containing the parameters of the model
    :param epoch: index of the epoch to plot
    :type epoch: int
    :param units: units in which the pixel values are expressed
    :type units: str
    :param figsize: tuple that indicates the size of the figure
    :param cut_dict: colour-scale settings, a dictionary with the keys 'linthresh', 'vmin' and 'vmax',
        each holding one value per panel in the order: data, model, residuals, background,
        deconvolved image, narrow PSF. 'linthresh' is ignored for the (linear) residual panel.
        Default None uses built-in settings.
    :return: output figure
    """
    str_unit = '[' + units + ']' if units is not None else ''

    if cut_dict is None:  # Default setting
        cut_dict = {
            'linthresh': [5e2, 5e2, None, 5e1, 5e1, 1e-3],
            'vmin': [None, None, None, None, None, None],
            'vmax': [None, None, None, None, None, None],
        }
    k_dict = dict_to_kwargs_list(cut_dict)

    output = model.model(kwargs)[epoch]
    deconv, h = model.getDeconvolved(kwargs, epoch)
    data_show = data[epoch, :, :]
    dif = data_show - output
    # NOTE(review): absolute residuals are drawn with the diverging CMAP_RR colormap, so their sign is
    # lost here, unlike the other residual panels of this module.
    rr = np.abs(dif) / np.sqrt(sigma_2[epoch, :, :])

    # Honour the documented figsize argument (it used to be ignored in favour of (15, 8)).
    fig, axs = plt.subplots(2, 3, figsize=figsize)
    fraction = 0.046
    pad = 0.04
    font_size = 10
    ticks_size = 6
    plt.rc('font', size=font_size)

    panel_titles = {
        (0, 0): f'Data {str_unit}',
        (0, 1): f'Convolving back {str_unit}',
        (0, 2): 'Map of relative residuals',
        (1, 0): f'Background {str_unit}',
        (1, 1): f'Deconvolved image {str_unit}',
        (1, 2): 'Narrow PSF',
    }
    for (row, col), title in panel_titles.items():
        axs[row, col].set_title(title, fontsize=8)
        axs[row, col].tick_params(axis='both', which='major', labelsize=ticks_size)

    # The residual panel uses a linear scale, so 'linthresh' does not apply to it.
    residual_kwargs = {key: value for key, value in k_dict[2].items() if key != 'linthresh'}

    fig.colorbar(axs[0, 0].imshow(data_show, norm=colors.SymLogNorm(**k_dict[0]), origin='lower'),
                 ax=axs[0, 0], fraction=fraction, pad=pad)
    fig.colorbar(axs[0, 1].imshow(output, norm=colors.SymLogNorm(**k_dict[1]), origin='lower'),
                 ax=axs[0, 1], fraction=fraction, pad=pad)
    fig.colorbar(axs[0, 2].imshow(rr, origin='lower', cmap=CMAP_RR, **residual_kwargs),
                 ax=axs[0, 2], fraction=fraction, pad=pad)
    fig.colorbar(axs[1, 0].imshow(h, norm=colors.SymLogNorm(**k_dict[3]), origin='lower'),
                 ax=axs[1, 0], fraction=fraction, pad=pad)
    fig.colorbar(axs[1, 1].imshow(deconv, norm=colors.SymLogNorm(**k_dict[4]), origin='lower'),
                 ax=axs[1, 1], fraction=fraction, pad=pad)
    fig.colorbar(axs[1, 2].imshow(s[epoch, :, :], norm=colors.SymLogNorm(**k_dict[5]), origin='lower'),
                 ax=axs[1, 2], fraction=fraction, pad=pad)

    return fig
def view_deconv_model(model, kwargs, data, sigma_2, figsize=(9, 7.5), cmap='gist_heat'):
    """
    Display the deconvolution results in a six-panel figure, one epoch at a time.

    Panels: data, model, ``(data - model) / noise``, background (with the per-epoch constant
    ``kwargs['kwargs_background']['mean']`` removed), deconvolved image and narrow PSF. When there is
    more than one epoch, a slider selects the epoch; the right/left arrow keys and the mouse wheel
    also step through the epochs. Colour scales are set from the first epoch.

    :param model: deconvolution model providing ``model(kwargs)``, ``getDeconvolved(kwargs, epoch)``
        and a ``psf`` attribute indexed by epoch
    :param kwargs: dictionary containing the parameters of the model
    :param data: array containing the observations, indexed by epoch
    :param sigma_2: array containing the square of the noise maps, indexed by epoch
    :param figsize: tuple that indicates the size of the figure
    :param cmap: colormap used for every panel except the residuals
    """
    output = model.model(kwargs)
    psf = model.psf
    noisemap = sigma_2 ** 0.5

    # Deconvolved image and background of every epoch, computed once up front.
    deconvs = [model.getDeconvolved(kwargs, i) for i in range(len(output))]
    decs, hs = zip(*deconvs)
    # subtract the constant component from h
    hs = [h - kwargs['kwargs_background']['mean'][i] for i, h in enumerate(hs)]

    # setup for first epoch
    deconv = decs[0]
    h = hs[0]
    normdeconv = simple_norm(deconv, stretch='asinh', percent=99.9)
    if np.std(h) > 0.:
        normh = simple_norm(h, stretch='asinh', percent=99.99)
    else:
        # A flat background has no range to stretch: fall back to a fixed [0, 1] scale.
        normh = simple_norm(h)
        normh.vmin = 0
        normh.vmax = 1
    s = psf[0]

    ##########################################################################
    # figure
    fig, axs = plt.subplots(2, 3, figsize=figsize)
    for ax in axs.flatten():
        ax.set_xticks([])
        ax.set_yticks([])
    datap = axs[0, 0].imshow(data[0], origin='lower', cmap=cmap)
    axs[0, 0].set_title('data')
    modelp = axs[0, 1].imshow(output[0], origin='lower', cmap=cmap)
    axs[0, 1].set_title('model')
    diffp = axs[1, 0].imshow((data[0] - output[0]) / noisemap[0], origin='lower', cmap=CMAP_RR)
    axs[1, 0].set_title('(data-model)/noise')
    backp = axs[1, 1].imshow(h, origin='lower', cmap=cmap, norm=normh)
    axs[1, 1].set_title('background')
    decp = axs[0, 2].imshow(deconv, origin='lower', cmap=cmap, norm=normdeconv)
    axs[0, 2].set_title('deconvolved')
    psfp = axs[1, 2].imshow(s, origin='lower', cmap=cmap)
    axs[1, 2].set_title('narrow psf')
    plt.tight_layout()

    if len(output) > 1:
        axcolor = 'lightgoldenrodyellow'
        axslider = plt.axes([0.1, 0.05, 0.75, 0.01], facecolor=axcolor)
        slider = Slider(axslider, 'Epoch', 0, len(output) - 1, valinit=0, valstep=1)

        #######################################################################
        # functions for slider update, only if more than one epoch.
        def press(event):  # pragma: no cover
            # Key events have no 'button' attribute; scroll events carry 'up'/'down'.
            button = getattr(event, 'button', None)
            if event.key == 'right' or button == 'down':
                if slider.val < len(output) - 1:
                    slider.set_val(slider.val + 1)
            elif event.key == 'left' or button == 'up':
                if slider.val > 0:
                    slider.set_val(slider.val - 1)
            update(slider.val)
            fig.canvas.draw_idle()

        def update(val):  # pragma: no cover
            epoch0 = int(slider.val)
            deconv, h = decs[epoch0], hs[epoch0]
            s = psf[epoch0]
            # update all the plots
            datap.set_data(data[epoch0])
            modelp.set_data(output[epoch0])
            diffp.set_data((data[epoch0] - output[epoch0]) / noisemap[epoch0])
            backp.set_data(h)
            decp.set_data(deconv)
            psfp.set_data(s)

        # The canvas keeps strong references to these closures, which in turn keep the slider alive.
        fig.canvas.mpl_connect('key_press_event', press)
        fig.canvas.mpl_connect('scroll_event', press)
        slider.on_changed(update)

    plt.show(block=False)
def make_movie(model, kwargs, data, sigma_2, outpath, figsize=(9, 7.5), epochs_list=None, duration=20, loop=1,
               format='gif', cmap=None):
    """
    Render one frame per epoch of the deconvolution results and assemble them into a movie.

    Each frame shows the data, the model, ``(data - model) / noise``, the background, the deconvolved
    image and the narrow PSF. Frames are written as ``frameNNNNN.png`` in ``outpath``, then combined
    into ``deconv.gif`` (format 'gif') or ``deconv.mp4v`` (format 'mp4v').

    :param model: deconvolution model providing ``model(kwargs)``, ``getDeconvolved(kwargs, epoch)``
        and a ``psf`` attribute indexed by epoch
    :param kwargs: dictionary containing the parameters of the model
    :param data: array containing the observations, indexed by epoch
    :param sigma_2: array containing the square of the noise maps, indexed by epoch
    :param outpath: directory where frames and movie are written
    :param figsize: tuple that indicates the size of the figure
    :param epochs_list: epochs to render, in order; default is every epoch
    :param duration: for 'gif', per-frame duration forwarded to ``make_gif``; for 'mp4v', the frame
        rate is ``number of frames / duration``
    :param loop: loop count forwarded to ``make_gif``
    :param format: 'gif' or 'mp4v'
    :param cmap: colormap used for every panel except the residuals
    :raises RuntimeError: if ``format`` is not 'gif' or 'mp4v'
    """
    # Validate before doing any work; the original built this exception without raising it.
    if format not in ('gif', 'mp4v'):
        raise RuntimeError('Unsupported video format. Use "gif" or "mp4v".')

    output = model.model(kwargs)
    psf = model.psf
    noisemap = sigma_2 ** 0.5
    if epochs_list is None:
        epochs_list = range(len(output))

    # setup for first epoch
    deconvs = [model.getDeconvolved(kwargs, i) for i in range(len(output))]
    deconv, h = deconvs[0]
    s = psf[0]

    ##########################################################################
    # figure
    fig, axs = plt.subplots(2, 3, figsize=figsize)
    for ax in axs.flatten():
        ax.set_xticks([])
        ax.set_yticks([])
    datap = axs[0, 0].imshow(data[0], origin='lower', cmap=cmap)
    axs[0, 0].set_title('data')
    modelp = axs[0, 1].imshow(output[0], origin='lower', cmap=cmap)
    axs[0, 1].set_title('model')
    diffp = axs[1, 0].imshow((data[0] - output[0]) / noisemap[0], origin='lower', cmap=CMAP_RR)
    axs[1, 0].set_title('(data-model)/noise')
    backp = axs[1, 1].imshow(h, origin='lower', cmap=cmap)
    axs[1, 1].set_title('background')
    decp = axs[0, 2].imshow(deconv, origin='lower', cmap=cmap)
    axs[0, 2].set_title('deconvolved')
    psfp = axs[1, 2].imshow(s, origin='lower', cmap=cmap)
    axs[1, 2].set_title('narrow psf')
    plt.tight_layout()

    # update all the plots, saving one PNG per requested epoch
    files = []
    for i, epoch0 in enumerate(epochs_list):
        deconv, h = deconvs[epoch0]
        s = psf[epoch0]
        datap.set_data(data[epoch0])
        modelp.set_data(output[epoch0])
        diffp.set_data((data[epoch0] - output[epoch0]) / noisemap[epoch0])
        backp.set_data(h)
        decp.set_data(deconv)
        psfp.set_data(s)
        file_png = os.path.join(outpath, "frame{0:05d}.png".format(i))
        fig.savefig(file_png)
        files.append(file_png)

    # Every frame is on disk: release the figure so repeated calls do not pile up open figures.
    plt.close(fig)

    if format == 'gif':
        gif_name = os.path.join(outpath, "deconv.gif")
        make_gif(files, gif_name, duration=duration, loop=loop)
    else:  # 'mp4v' -- pragma: no cover
        video_name = os.path.join(outpath, f'deconv.{format}')
        fps = len(files) / duration
        make_video(files, outvid=video_name, fps=fps, size=None, is_color=True, format=format)
def make_gif(list_files, output_path, duration=100, loop=1):
    """
    Assemble a sequence of image files into an animated GIF.

    :param list_files: paths of the frames, in display order
    :param output_path: path of the GIF file to write
    :param duration: display time of each frame, forwarded to PIL (milliseconds)
    :param loop: loop count forwarded to PIL (0 means loop forever)
    :raises ValueError: if ``list_files`` is empty
    :raises ImportError: if PIL (Pillow) is not installed
    """
    list_files = list(list_files)
    if not list_files:
        raise ValueError('make_gif needs at least one frame.')
    try:
        from PIL import Image
    except ImportError as e:
        # Fail here with a clear message instead of a NameError on 'Image' further down.
        raise ImportError('Python package PIL is required for gif creation.') from e

    frames = [Image.open(image) for image in list_files]
    try:
        # The first frame is the one being saved; appending it again would show it twice.
        frame_one, *other_frames = frames
        frame_one.save(output_path, format="GIF", append_images=other_frames, save_all=True,
                       duration=duration, loop=loop)
    finally:
        for frame in frames:
            frame.close()
def make_video(images, outvid=None, fps=5, size=None, is_color=True, format="mp4v"):  # pragma: no cover
    """
    Create a video from a list of image files using OpenCV.

    By default the video has the size of the first image. Any image whose width or height differs
    from that size is resized before being written.

    :param images: paths of the images to use, in display order
    :param outvid: path of the output video
    :param fps: frames per second
    :param size: (width, height) of each frame; default is the size of the first image
    :param is_color: whether the video is written in colour
    :param format: four-character codec code passed to ``cv2.VideoWriter_fourcc``
        (see http://www.fourcc.org/codecs.php)
    :return: the (released) ``cv2.VideoWriter``
    :raises ValueError: if ``images`` is empty
    :raises FileNotFoundError: if one of the images does not exist
    :raises ImportError: if opencv-python is not installed
    """
    images = list(images)
    if not images:
        raise ValueError('make_video needs at least one image.')
    try:
        from cv2 import VideoWriter, VideoWriter_fourcc, imread, resize
    except ImportError as e:
        # Fail here with a clear message instead of a NameError further down.
        raise ImportError('Python package opencv-python is required for video creation.') from e

    fourcc = VideoWriter_fourcc(*format)
    vid = None
    for image in images:
        if not os.path.exists(image):
            raise FileNotFoundError(image)
        img = imread(image)
        if vid is None:
            if size is None:
                size = img.shape[1], img.shape[0]
            vid = VideoWriter(outvid, fourcc, float(fps), size, is_color)
        # Resize when either dimension differs (the original required both to differ).
        if size[0] != img.shape[1] or size[1] != img.shape[0]:
            img = resize(img, size)
        vid.write(img)
    vid.release()
    return vid
def plot_loss(loss_history, figsize=(10, 5), ax=None, title=None):
    """
    Plot a loss curve against the optimisation step, with a logarithmic y axis.

    :param loss_history: sequence of loss values, one per step
    :param figsize: size of the new figure; only used when ``ax`` is None
    :param ax: existing matplotlib axis to draw on; a new figure is created when None
    :param title: optional axis title
    :return: the figure that contains the plot
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        # Previously 'fig' was left unbound here, so passing an axis raised UnboundLocalError.
        fig = ax.get_figure()
    ax.plot(range(len(loss_history)), loss_history)
    ax.set_xlabel('Steps')
    ax.set_ylabel('Loss')
    ax.set_yscale('log')
    if title is not None:
        ax.set_title(title)
    return fig
def plot_convergence_by_walker(samples_mcmc, param_mcmc, n_walkers, verbose=False):
    """
    Plot, for every parameter, how the distribution over walkers evolves along the MCMC chain.

    For each step, the median over walkers is drawn in green and the 16th-84th percentile band is
    shaded. The red horizontal line is the median of that median curve over the last 10% of the steps.

    :param samples_mcmc: flattened chain of shape (n_walkers * n_step, n_params). Rows are assumed to
        be step-major, i.e. the n_walkers rows of step 0 come first, then those of step 1, and so on.
    :param param_mcmc: names of the parameters, used as y labels
    :param n_walkers: number of walkers
    :param verbose: if True, print for each parameter the median +/- half the 16-84 range at the last
        step
    :return: output figure
    :raises ValueError: if the number of samples is not a multiple of ``n_walkers``
    """
    samples_mcmc = np.asarray(samples_mcmc)
    n_samples, n_params = samples_mcmc.shape
    n_step, remainder = divmod(n_samples, n_walkers)
    if remainder:
        raise ValueError('The number of samples (%i) is not a multiple of n_walkers (%i).'
                         % (n_samples, n_walkers))

    # (n_step, n_walkers, n_params): row k of the flat chain is step k // n_walkers, walker k % n_walkers.
    per_step = samples_mcmc.reshape(n_step, n_walkers, n_params)
    # Statistics over walkers, transposed to (n_params, n_step).
    median_pos = np.median(per_step, axis=1).T
    q16_pos, q84_pos = (q.T for q in np.percentile(per_step, [16., 84.], axis=1))

    fig, ax = plt.subplots(n_params, sharex=True, figsize=(16, 2 * n_params))
    if n_params == 1:
        ax = [ax]
    last = n_step
    burnin = int((9. * n_step) / 10.)  # get the final value on the last 10% of the chain
    for i in range(n_params):
        if verbose:
            print(param_mcmc[i], '{:.4f} +/- {:.4f}'.format(median_pos[i][last - 1],
                                                        (q84_pos[i][last - 1] - q16_pos[i][last - 1]) / 2))
        ax[i].plot(median_pos[i][:last], c='g')
        ax[i].axhline(np.median(median_pos[i][burnin:last]), c='r', lw=1)
        ax[i].fill_between(np.arange(last), q84_pos[i][:last], q16_pos[i][:last], alpha=0.4)
        ax[i].set_ylabel(param_mcmc[i], fontsize=10)
        ax[i].set_xlim(0, last)
    return fig