Source code for cortecs.eval.eval_neural_net

"""
This file holds the functions for evaluating opacity data with a trained neural network.

maybe call all of these eval_poly.py, eval_neural_net.py, etc.?

author: @arjunsavel
"""
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def feed_forward_equal_layer_sizes(x, n_layers, weights, biases):
    """
    feed forward neural network for layers that all share the same size.

    Applies ``sigmoid(x.dot(weights[i]) + biases[i])`` for each ``i`` in
    ``range(n_layers)``, feeding the output of one layer into the next. The
    sigmoid is applied after every layer, including the last one.

    This function only works if all layers have the same size, because the
    per-layer weights and biases are stacked into single arrays.

    Inputs
    ------
    x: array-like
        the input to the neural network.
    n_layers: int
        the number of layers to apply.
    weights: list or array
        the weights of the neural network, one equally-shaped entry per
        layer (a list of arrays or an already-stacked array).
    biases: list or array
        the biases of the neural network, one equally-shaped entry per
        layer (a list of arrays or an already-stacked array).

    Returns
    -------
    res: array
        the output of the final layer.
    """
    # Inside lax.fori_loop the loop index is a traced value, which cannot be
    # used to index a Python list. Stacking into arrays makes ``[i]`` a
    # regular (dynamic) array index; already-stacked inputs pass through.
    stacked_weights = jnp.asarray(weights)
    stacked_biases = jnp.asarray(biases)

    def apply_layer(i, activation):
        return jax.nn.sigmoid(
            activation.dot(stacked_weights[i]) + stacked_biases[i]
        )

    return lax.fori_loop(0, n_layers, apply_layer, x)
# TODO: add a test to make sure feed_forward actually works!
def feed_forward(x, n_layers, weights, biases):
    """
    feed forward neural network with a linear output layer.

    The first ``n_layers - 1`` entries of ``weights`` and ``biases`` are
    applied in order, each followed by a sigmoid. The last entry
    (``weights[-1]``, ``biases[-1]``) is then applied without an activation.

    Inputs
    ------
    x: array-like
        the input to the neural network.
    n_layers: int
        the number of layers in the neural network.
    weights: list
        the weights of the neural network, one entry per layer.
    biases: list
        the biases of the neural network, one entry per layer.

    Returns
    -------
    the output of the final (linear) layer.
    """
    n_hidden = n_layers - 1

    activation = x
    for layer in range(n_hidden):
        pre_activation = activation.dot(weights[layer]) + biases[layer]
        activation = jax.nn.sigmoid(pre_activation)

    # The output layer is always the last entry and has no sigmoid.
    return activation.dot(weights[-1]) + biases[-1]
# @jax.jit
def eval_neural_net(
    T,
    P,
    temperatures=None,
    pressures=None,
    wavelengths=None,
    n_layers=None,
    weights=None,
    biases=None,
    **kwargs
):
    """
    evaluates the neural network at a given temperature and pressure.

    The network input is the two-element array ``[log10(T), log10(P)]``,
    which is passed through ``feed_forward``.

    Inputs
    ------
    T: float
        The temperature to evaluate at.
    P: float
        The pressure to evaluate at.
    temperatures, pressures, wavelengths:
        Accepted but not used by this function.
    n_layers: int
        The number of layers in the neural network. If None, it is taken
        to be ``len(weights)``.
    weights: list
        The weights of the neural network.
    biases: list
        The biases of the neural network.
    **kwargs:
        Any additional keyword arguments are accepted and ignored.

    Returns
    -------
    the output of ``feed_forward`` for the given temperature and pressure.
    """
    # feed_forward computes ``n_layers - 1``, so the None default cannot be
    # passed through as-is; fall back to one layer per weight entry.
    if n_layers is None:
        n_layers = len(weights)

    x = jnp.array([jnp.log10(T), jnp.log10(P)])
    return feed_forward(x, n_layers, weights, biases)