Source code for cortecs.eval.eval_pca

"""
For evaluating the PCA fit at a given temperature, pressure, and wavelength.
"""

import jax
import numpy as np


@jax.jit
def eval_pca_ind_wav(first_ind, second_ind, vectors, pca_coeffs):
    """
    Evaluates the PCA fit at a single pair of grid indices.

    The result is the sum over components of
    ``vectors[first_ind, c] * pca_coeffs[c, second_ind]``.

    The sum is written as an explicit loop rather than a dot product
    because, per the original author's note, not all GPUs are believed to
    support the simpler dot product.

    The number of components is not an argument: it is read from
    ``vectors.shape[1]``. Array shapes are static under ``jax.jit``, so the
    Python loop below has a fixed length when the function is traced. A
    traced argument could not be used as a ``range`` bound.

    Inputs
    ------
    first_ind: int
        The index of the first axis quantity (default temperature) to
        evaluate at. Indexes axis 0 of ``vectors``.
    second_ind: int
        The index of the second axis quantity (default pressure) to
        evaluate at. Indexes axis 1 of ``pca_coeffs``.
    vectors: array
        The PCA vectors. Axis 1 runs over the PCA components.
    pca_coeffs: array
        The PCA coefficients. Axis 0 runs over the PCA components.

    Returns
    -------
    xsec_val:
        The reconstructed value at ``(first_ind, second_ind)``.
    """
    xsec_val = 0.0
    # Static under jit: taken from the array shape, not from a traced value.
    n_components = vectors.shape[1]
    for component in range(n_components):
        xsec_val += vectors[first_ind, component] * pca_coeffs[component, second_ind]
    return xsec_val
def eval_pca(
    temperature, pressure, wavelength, T, P, wl, fitter_results, fit_axis="pressure", **kwargs
):
    """
    Evaluates the PCA fit at a given temperature, pressure, and wavelength.

    No interpolation is performed: each requested value is snapped to the
    nearest point on its grid (``T``, ``P``, ``wl``), and the fit is
    evaluated at those grid indices.

    Inputs
    ------
    temperature: float
        The temperature to evaluate at.
    pressure: float
        The pressure to evaluate at.
    wavelength: float
        The wavelength to evaluate at.
    T: array
        The temperature grid; the entry nearest ``temperature`` is used.
    P: array
        The pressure grid; the entry nearest ``pressure`` is used.
    wl: array
        The wavelength grid; the entry nearest ``wavelength`` is used.
    fitter_results: sequence of two arrays
        ``(pca_vectors, pca_coeffs_all_wl)``. The second array is indexed as
        ``pca_coeffs_all_wl[wavelength_ind, :, :]``, i.e. wavelength is its
        leading axis.
    fit_axis: str
        Which quantity indexes the first axis of the PCA vectors:
        ``"pressure"`` (default), ``"temperature"``, or ``"best"``.
        ``"best"`` picks temperature when ``len(T) > len(P)`` and pressure
        otherwise (including when the two lengths are equal).
    **kwargs:
        Accepted and ignored.

    Returns
    -------
    The value returned by ``eval_pca_ind_wav`` for the selected indices.

    Raises
    ------
    ValueError
        If ``fit_axis`` is not one of the accepted values.
    """
    # Resolve "best" to a concrete axis first, so only two cases remain.
    if fit_axis == "best":
        # todo: what if equal? For now a tie falls through to pressure.
        fit_axis = "temperature" if len(T) > len(P) else "pressure"

    # Fail with a clear message instead of leaving the index arguments
    # unbound further down.
    if fit_axis not in ("pressure", "temperature"):
        raise ValueError(
            f"fit_axis must be 'pressure', 'temperature', or 'best'; got {fit_axis!r}"
        )

    # find the nearest temperature, pressure, and wavelength indices.
    temperature_ind = np.argmin(np.abs(T - temperature))
    pressure_ind = np.argmin(np.abs(P - pressure))
    wavelength_ind = np.argmin(np.abs(wl - wavelength))

    pca_vectors, pca_coeffs_all_wl = fitter_results
    pca_coeffs = pca_coeffs_all_wl[wavelength_ind, :, :]

    # todo: figure out how to order the pressure and temperature inds!
    if fit_axis == "pressure":
        first_arg, second_arg = pressure_ind, temperature_ind
    else:
        first_arg, second_arg = temperature_ind, pressure_ind

    return eval_pca_ind_wav(first_arg, second_arg, pca_vectors, pca_coeffs)