"""
The high-level API for fitting. Requires the Opac object.
"""
import warnings
from functools import partial
from multiprocessing import Pool
import numpy as np
from tqdm import tqdm
from cortecs.fit.fit_neural_net import *
from cortecs.fit.fit_pca import *
from cortecs.fit.fit_polynomial import *
[docs]
class Fitter(object):
    """
    Fits opacity data, one wavelength point at a time, with a selectable method.

    The method is chosen by name at construction ('pca', 'neural_net', or
    'polynomial'). The three class-level dictionaries below map that name to the
    functions used to prepare for, perform, and save the fit.

    todo: fit CIA. only on one dimension, because there's no pressure dependence.
    """

    # method name -> function called once per wavelength slice of the cross-section
    method_dict = {
        "pca": fit_pca,
        "neural_net": fit_neural_net,
        "polynomial": fit_polynomial,
    }

    # method name -> function called once on the full cross-section before fitting
    prep_method_dict = {
        "pca": prep_pca,
        "neural_net": prep_neural_net,
        "polynomial": prep_polynomial,
    }

    # method name -> function that writes the fitter results to disk
    save_method_dict = {
        "pca": save_pca,
        "neural_net": save_neural_net,
        "polynomial": save_polynomial,
    }

    def __init__(self, opac, method="pca", **fitter_kwargs):
        """
        Sets up a fitter for a single opacity object.

        todo: make list of opac

        Inputs
        ------
        opac: Opac
            the opacity object. Its wl, P, T, and cross_section attributes are read.
        method: str
            the method to use for fitting. Options (in order of increasing complexity) are
            'polynomial', 'pca', and 'neural_net'. The more complex the model, the larger the model size (i.e., potentially
            the lower the compression factor), and the more likely it is to fit well.
        fitter_kwargs: dict
            kwargs that are passed to both the prep function and the fit function of the chosen method.
            one kwarg, for instance, is the fit_axis: for PCA, this determines what axis is fit against.

        Raises
        ------
        ValueError
            if method is not one of the keys of method_dict.
        """
        if method not in self.method_dict:
            raise ValueError("method {} not supported".format(method))

        self.opac = opac
        self.fitter_kwargs = fitter_kwargs
        self.method = method
        self.fit_func = self.method_dict[self.method]

        # shortcuts to the grids stored on the opacity object
        self.wl = self.opac.wl
        self.P = self.opac.P
        self.T = self.opac.T
        # todo: figure out how to change the fitting...based on the fit axis?

    def fit(self, parallel=False, verbose=1, num_processes=1):
        """
        Runs the prep step for the chosen method, then fits every wavelength point.

        Stores the prep output on self.prep_res and the fit output on self.fitter_results.
        Can loop over a list of species? ...should be parallelized!

        Inputs
        ------
        parallel: bool
            whether to dispatch the per-wavelength fits through a multiprocessing pool.
        verbose: int
            1 shows a tqdm progress bar; any other value hides it.
        num_processes: int or None
            number of worker processes used when parallel is True. Ignored otherwise.
        """
        prep_method = self.prep_method_dict[self.method]
        self.prep_res = prep_method(self.opac.cross_section, **self.fitter_kwargs)

        if parallel:
            self.fit_parallel(num_processes=num_processes, verbose=verbose)
        else:
            self.fit_serial(verbose=verbose)
        # will I need to save and stuff like that?

    def fit_serial(self, verbose=0):
        """
        Fits the opacity data with the chosen method, one wavelength after another.

        Reads self.prep_res, so the prep step (see fit) must have run first. Stores
        [prep_res, array of per-wavelength results] on self.fitter_results.

        todo: keep in mind that the PCA method is not actually independent of wavelength.

        Inputs
        ------
        verbose: int
            1 shows a tqdm progress bar; any other value hides it.
        """
        n_wl = len(self.wl)
        indices = range(n_wl)
        if verbose == 1:
            indices = tqdm(indices, total=n_wl)

        # catch_warnings installs no filter here, so warnings are still shown; it only
        # restores the warning-filter state on exit if the fit functions change it.
        with warnings.catch_warnings():
            res = [
                self.fit_func(
                    # the last axis of the cross-section is indexed by wavelength position
                    self.opac.cross_section[:, :, i],
                    self.P,
                    self.T,
                    self.prep_res,
                    **self.fitter_kwargs
                )
                for i in indices
            ]

        self.fitter_results = [self.prep_res, np.array(res)]

    def update_pbar(self, arg):
        """
        Advances the tqdm progress bar by one step.

        Used as the apply_async callback in fit_parallel.

        Inputs
        ------
        arg: object
            the finished task's result, as handed over by the pool. Unused.
        """
        self.pbar.update(1)

    def fit_parallel(self, num_processes=1, verbose=1):
        """
        Fits the opacity data with the chosen method through a multiprocessing pool.

        Reads self.prep_res, so the prep step (see fit) must have run first. Stores
        [prep_res, array of per-wavelength results] on self.fitter_results, in
        wavelength order and in the same form as fit_serial.

        Inputs
        ------
        num_processes: int or None
            number of worker processes handed to multiprocessing.Pool. The default of 1
            keeps the previous behavior; None lets Pool pick the number of CPUs.
        verbose: int
            1 shows a tqdm progress bar; any other value hides it.
        """
        n_wl = len(self.wl)
        func = partial(
            self.fit_func,
            P=self.P,
            T=self.T,
            prep_res=self.prep_res,
            **self.fitter_kwargs
        )

        # reformat the cross_section to be a list of 2D arrays, one per wavelength
        cross_section_reformatted = [
            self.opac.cross_section[:, :, i] for i in range(n_wl)
        ]

        self.pbar = tqdm(
            total=n_wl,
            position=0,
            leave=True,
            unit="chunk",
            desc="Fitting with {} method".format(self.method),
            disable=(verbose != 1),
        )
        try:
            # leaving the Pool context terminates the workers, so they are cleaned up
            # even if submitting or collecting a task raises.
            with warnings.catch_warnings(), Pool(num_processes) as pool:
                # apply_async is the only way to get the progress bar to work. tasks are
                # kept in submission order, so the results come back sorted by wavelength.
                pending = [
                    pool.apply_async(func, args=(item,), callback=self.update_pbar)
                    for item in cross_section_reformatted
                ]
                # no more tasks; wait for the workers to finish
                pool.close()
                pool.join()
                # get() re-raises any exception that occurred in a worker
                results = [task.get() for task in pending]
        finally:
            self.pbar.close()

        self.fitter_results = [self.prep_res, np.array(results)]

    def save(self, savename):
        """
        Saves the fitter results with the save function of the chosen method.

        Reads self.fitter_results, so a fit must have been run first.

        Inputs
        ------
        savename: str
            passed through to the save function as the output name.
        """
        save_func = self.save_method_dict[self.method]
        save_func(savename, self.fitter_results)