Is there a way to significantly speed up this code?
I'm solving a dynamic programming model with a backward induction algorithm. A crucial step is computing the current-period value function (VF), which is a function of two state variables. It is defined as the maximum, over two control variables, of the current-period reward plus the discounted next-period value function; the two controls turn out to be the next-period states.
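Written out, the definition above is the Bellman equation below. The symbols ($V_t$, $r$, $s_0, s_1$, $c_0, c_1$) are introduced here only for illustration; in the code, `current_reward` plays the role of $r$, `next_reward` is $V_{t+1}$ evaluated on the control grid, and the argmax gives the policy.

```latex
V_t(s_0, s_1) = \max_{c_0,\, c_1} \Big[ r(s_0, s_1, c_0, c_1) + \beta \, V_{t+1}(c_0, c_1) \Big]

(c_0^*, c_1^*)(s_0, s_1) = \arg\max_{c_0,\, c_1} \Big[ r(s_0, s_1, c_0, c_1) + \beta \, V_{t+1}(c_0, c_1) \Big]
```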
My implementation of this step is the function compute_value_function_jax() in the example below. I have to call this function many times inside the loop of my backward induction setup. I suspect the code is already close to the limit of what NumPy and JIT compilation can deliver, so I'd be interested to hear what alternatives, in terms of software, remain for considerably reducing its run time while keeping its results unchanged.
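As a sketch of the repeated calls described above (separate from the question's own example), the whole backward induction loop can be compiled as one program with `jax.lax.scan`, so Python does not dispatch each period. This is a minimal sketch under stated assumptions: `backward_induction` and `n_periods` are hypothetical names, the reward is assumed time-invariant, the terminal value function is assumed to be zero, and the control grid is assumed to coincide with the next-period state grid.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n_periods",))
def backward_induction(reward, beta, n_periods):
    # Hypothetical helper: assumes a time-invariant reward of shape
    # (n_s0, n_s1, n_c0, n_c1) and a terminal value function of zero.
    n_s0, n_s1, n_c0, n_c1 = reward.shape
    if (n_c0, n_c1) != (n_s0, n_s1):
        raise ValueError("controls must index the next-period state grid")

    def step(v_next, _):
        # One Bellman step: V_{t+1} broadcasts over the two control axes.
        flat = (reward + beta * v_next).reshape(n_s0, n_s1, n_c0 * n_c1)
        idx = jnp.argmax(flat, axis=-1)
        # Read the maximum at the argmax instead of running a second reduction.
        v = jnp.take_along_axis(flat, idx[..., None], axis=-1)[..., 0]
        return v, idx

    v_terminal = jnp.zeros_like(reward[:, :, 0, 0])
    # scan runs all periods inside one compiled program; the stacked policies
    # come out in the order computed, so policies[0] belongs to the last period.
    v0, policies = jax.lax.scan(step, v_terminal, None, length=n_periods)
    return v0, policies
```

If the reward differs by period, the same pattern should carry over by passing the stacked rewards as `xs` to `scan`, at the cost of holding all periods in memory.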
    import time

    import numpy as np
    import jax
    import jax.numpy as jnp


    def compute_value_function_jax(n_state: int, current_reward: np.ndarray, next_reward: np.ndarray,
                                   beta: float, n_control_grid0: int, n_control_grid1: int,
                                   n_state_grid0: int, n_state_grid1: int):
        # Axes of current_reward that hold the two controls
        ind_control_axes = (2, 3)
        states_sizes = (n_state_grid0, n_state_grid1)
        controls_sizes = (n_control_grid0, n_control_grid1)
        # Bellman right-hand side; next_reward broadcasts over the trailing (control) axes
        RHS = current_reward + beta * next_reward
        # Merge the two control axes so a single argmax searches all control pairs
        RHS_flat_controls = jnp.reshape(RHS, states_sizes + (-1,), order='C')
        argmax_flat = jnp.argmax(RHS_flat_controls, axis=n_state)
        # Convert flat indices back into (control0, control1) indices
        argmax_ind = jnp.unravel_index(argmax_flat, controls_sizes)
        return jnp.max(RHS, axis=ind_control_axes), argmax_ind


    compute_value_function_jax = jax.jit(
        compute_value_function_jax,
        static_argnames=('n_state', 'n_control_grid0', 'n_control_grid1',
                         'n_state_grid0', 'n_state_grid1', 'beta'))

    beta = 0.99
    n_state_grid0, n_control_grid0 = 120, 120
    n_state_grid1, n_control_grid1 = 150, 150
    current_reward = 1000 * np.random.rand(n_state_grid0, n_state_grid1, n_control_grid0, n_control_grid1)
    next_reward = 1000 * np.random.rand(n_state_grid0, n_state_grid1)
    n_state = 2

    args = (n_state, current_reward, next_reward, beta,
            n_control_grid0, n_control_grid1, n_state_grid0, n_state_grid1)
    # Warm-up call: the first call includes JIT compilation
    jax.block_until_ready(compute_value_function_jax(*args))

    t0 = time.perf_counter()
    # JAX dispatches asynchronously, so wait for the result before stopping the clock
    jax.block_until_ready(compute_value_function_jax(*args))
    t1 = time.perf_counter()
    print(f'Elapsed time {1000 * (t1 - t0)} millisecs.')
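One variant that may be worth timing, sketched here under assumptions rather than as a measured improvement: it reuses the flat argmax to gather the maximum instead of running a second reduction over all control pairs, reads the grid sizes from the array shapes instead of passing them as static arguments, and places float32 inputs on the device once with `jax.device_put` so that repeated calls do not re-transfer them from host memory. `compute_value_function_single_pass` is a hypothetical name, and whether it is faster depends on how XLA fuses the original two reductions.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def compute_value_function_single_pass(current_reward, next_reward, beta):
    # Hypothetical variant: grid sizes come from the array shapes, which are
    # static under jit, so no static size arguments are needed.
    n_s0, n_s1, n_c0, n_c1 = current_reward.shape
    # Bellman right-hand side; next_reward broadcasts over the two control axes.
    rhs = current_reward + beta * next_reward
    # Merge the two control axes so one argmax covers all control pairs.
    flat = rhs.reshape(n_s0, n_s1, n_c0 * n_c1)
    argmax_flat = jnp.argmax(flat, axis=-1)
    # Gather the maximum at the argmax instead of running a second reduction.
    v = jnp.take_along_axis(flat, argmax_flat[..., None], axis=-1)[..., 0]
    return v, jnp.unravel_index(argmax_flat, (n_c0, n_c1))


# Small example grid; float32 arrays are put on the device once so that
# repeated calls reuse them instead of copying from host memory each time.
rng = np.random.default_rng(0)
current_reward = jax.device_put(1000 * rng.random((4, 5, 4, 5), dtype=np.float32))
next_reward = jax.device_put(1000 * rng.random((4, 5), dtype=np.float32))
v, (i0, i1) = compute_value_function_single_pass(current_reward, next_reward, 0.99)
```

With the question's 120 x 150 x 120 x 150 grid, the float32 reward array alone is about 1.3 GB, so keeping it resident on the device between calls is likely to matter as much as the kernel itself.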