A JIT-compiled solver#

Solving a chemical mechanism is a computationally intensive task. There are many reasons for this, such as multiple iterations by a solver to achieve an integration, stiff systems requiring internal substepping to achieve a numerically stable solution, and cache misses, among others.

This tutorial focuses on alleviating the cache misses. A popular method for handling cache misses is to pre-compute the indices. This method, which may be referred to as ahead-of-time (AOT) compilation, is used in applications such as KPP [DSD+02]. Pre-computed methods require code preprocessors and preclude runtime configurable software, which is a goal of MICM.

MICM uses just-in-time (JIT) compiled functions built with LLVM JIT libraries to supply runtime-configurable chemistry to avoid cache misses in important chemistry functions.

Up until now, a micm::RosenbrockSolver has been used. This is a special class in MICM which builds all of the components needed to solve chemistry in memory, including the forcing function and the Jacobian. MICM also provides a micm::JitRosenbrockSolver, which builds and compiles several important chemistry functions at runtime.

What are we JIT-compiling?#

A list of compiled functions is below. How they are compiled and the elided operations are beyond the scope of this tutorial, but you are free to inspect the source code yourself.

So what does this gain me?#

Runtime-configurable chemical mechanisms that are also fast. Let’s compare the speed of the micm::RosenbrockSolver and micm::JitRosenbrockSolver classes applied to the same problem. We will use the simple fictitious chemical system from previous examples for simplicity, but feel free to try out more complex mechanisms to see the effects of JIT compiling on compute time in more realistic cases.

If you’re looking for a copy and paste, choose the appropriate tab below and be on your way! Otherwise, stick around for a line by line explanation.

#include <micm/jit/jit_compiler.hpp>
#include <micm/jit/solver/jit_rosenbrock.hpp>
#include <micm/jit/solver/jit_solver_parameters.hpp>
#include <micm/solver/rosenbrock.hpp>
#include <micm/solver/solver_builder.hpp>
#include <micm/jit/solver/jit_solver_builder.hpp>

#include <chrono>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <tuple>
#include <vector>

// Use our namespace so that this example is easier to read
using namespace micm;

// Number of grid cells solved at once. It is also passed as the grouping size
// (template argument) of the vectorized matrix aliases below and of
// micm::JitSolverBuilder in main().
constexpr std::size_t n_grid_cells = 1;

// Alias templates (not partial specializations) for vectorized dense and sparse
// matrices whose data is grouped in blocks of n_grid_cells.
// NOTE(review): neither alias is referenced elsewhere in this example; the solver
// builders in main() presumably pick their own matrix types.
template<class T>
using GroupVectorMatrix = micm::VectorMatrix<T, n_grid_cells>;
using GroupSparseVectorMatrix = micm::SparseMatrix<double, micm::SparseMatrixVectorOrdering<n_grid_cells>>;

// Runs `solver` for ten consecutive 500 s time steps from a fixed initial state.
//
// Every species starts at 1 (and Foo is set explicitly to 1.0) in every grid
// cell, with fixed temperature, pressure and air density.
//
// Parameters:
//   solver - any solver exposing GetState(), CalculateRateConstants(state) and
//            Solve(time_step, state) (abbreviated function template, C++20).
// Returns:
//   std::tuple of (final State, accumulated SolverStats, total wall-clock time
//   spent in CalculateRateConstants + Solve, in nanoseconds).
auto run_solver(auto& solver)
{
  SolverStats total_stats;
  State state = solver.GetState();

  // Initialize every variable (species concentration) to 1.
  state.variables_ = 1;

  for (std::size_t i = 0; i < n_grid_cells; ++i)
  {
    state.conditions_[i].temperature_ = 287.45;  // K
    state.conditions_[i].pressure_ = 101319.9;   // Pa
    state.conditions_[i].air_density_ = 1e6;     // mol m-3
  }
  auto foo = Species("Foo");
  std::vector<double> foo_conc(n_grid_cells, 1.0);
  state.SetConcentration(foo, foo_conc);

  // choose a timestep
  const double time_step = 500;  // s

  auto total_solve_time = std::chrono::nanoseconds::zero();

  // solve for ten time steps
  for (int i = 0; i < 10; ++i)
  {
    double elapsed_solve_time = 0;

    // Solve() may stop before covering the whole requested interval, so keep
    // calling it on the remaining interval until the full time step is done.
    while (elapsed_solve_time < time_step)
    {
      auto start = std::chrono::high_resolution_clock::now();
      solver.CalculateRateConstants(state);
      auto result = solver.Solve(time_step - elapsed_solve_time, state);
      auto end = std::chrono::high_resolution_clock::now();
      total_solve_time += std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);

      // Each call only receives the *remaining* interval (never the offset
      // already covered), so its final_time_ measures progress made in this call.
      // Accumulate it; overwriting would re-integrate and overshoot the step
      // whenever a call stops short.
      elapsed_solve_time += result.final_time_;

      total_stats.function_calls_ += result.stats_.function_calls_;
      total_stats.jacobian_updates_ += result.stats_.jacobian_updates_;
      total_stats.number_of_steps_ += result.stats_.number_of_steps_;
      total_stats.accepted_ += result.stats_.accepted_;
      total_stats.rejected_ += result.stats_.rejected_;
      total_stats.decompositions_ += result.stats_.decompositions_;
      total_stats.solves_ += result.stats_.solves_;
    }
  }

  return std::make_tuple(state, total_stats, total_solve_time);
}

// Builds a standard and a JIT-compiled Rosenbrock solver for the same toy
// mechanism, times JIT compilation and both solvers, prints their statistics,
// and checks that both produce the same final concentrations.
// Returns 0 when all final concentrations agree, 1 otherwise.
int main()
{
  auto foo = Species{ "Foo" };
  auto bar = Species{ "Bar" };
  auto baz = Species{ "Baz" };

  Phase gas_phase{ std::vector<Species>{ foo, bar, baz } };

  System chemical_system{ SystemParameters{ .gas_phase_ = gas_phase } };

  // Foo -> 0.8 Bar + 0.2 Baz
  Process r1 = Process::Create()
                   .SetReactants({ foo })
                   .SetProducts({ Yield(bar, 0.8), Yield(baz, 0.2) })
                   .SetRateConstant(ArrheniusRateConstant({ .A_ = 1.0e-3 }))
                   .SetPhase(gas_phase);

  // Foo + Bar -> Baz
  Process r2 = Process::Create()
                   .SetReactants({ foo, bar })
                   .SetProducts({ Yield(baz, 1) })
                   .SetRateConstant(ArrheniusRateConstant({ .A_ = 1.0e-5, .C_ = 110.0 }))
                   .SetPhase(gas_phase);

  std::vector<Process> reactions{ r1, r2 };

  auto solver_parameters = RosenbrockSolverParameters::ThreeStageRosenbrockParameters();

  // Standard (non-JIT) solver.
  auto solver = micm::CpuSolverBuilder<micm::RosenbrockSolverParameters>(solver_parameters)
                    .SetSystem(chemical_system)
                    .SetReactions(reactions)
                    .SetNumberOfGridCells(n_grid_cells)
                    .Build();

  // JIT solver: its functions are compiled while building, so time that step.
  auto start = std::chrono::high_resolution_clock::now();
  auto jit_solver = micm::JitSolverBuilder<micm::JitRosenbrockSolverParameters, n_grid_cells>(solver_parameters)
                        .SetSystem(chemical_system)
                        .SetReactions(reactions)
                        .SetNumberOfGridCells(n_grid_cells)
                        .Build();
  auto end = std::chrono::high_resolution_clock::now();
  auto jit_compile_time = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);

  std::cout << "Jit compile time: " << jit_compile_time.count() << " nanoseconds" << std::endl;

  // Warm-up run of both solvers; these results are discarded.
  auto result_tuple = run_solver(solver);
  auto jit_result_tuple = run_solver(jit_solver);

  // Rerun for a fairer comparison, presumably after cache and branch-predictor
  // warm-up from the first run.
  result_tuple = run_solver(solver);
  jit_result_tuple = run_solver(jit_solver);

  std::cout << "Standard solve time: " << std::get<2>(result_tuple).count() << " nanoseconds" << std::endl;
  std::cout << "JIT solve time: " << std::get<2>(jit_result_tuple).count() << " nanoseconds" << std::endl;

  // Prints each accumulated solver statistic exactly once under `label`.
  auto print_stats = [](const char* label, const auto& stats)
  {
    std::cout << label << std::endl;
    std::cout << "\taccepted: " << stats.accepted_ << std::endl;
    std::cout << "\tfunction_calls: " << stats.function_calls_ << std::endl;
    std::cout << "\tjacobian_updates: " << stats.jacobian_updates_ << std::endl;
    std::cout << "\tnumber_of_steps: " << stats.number_of_steps_ << std::endl;
    std::cout << "\trejected: " << stats.rejected_ << std::endl;
    std::cout << "\tdecompositions: " << stats.decompositions_ << std::endl;
    std::cout << "\tsolves: " << stats.solves_ << std::endl;
  };
  print_stats("Standard solve stats: ", std::get<1>(result_tuple));
  print_stats("JIT solve stats: ", std::get<1>(jit_result_tuple));

  auto& result = std::get<0>(result_tuple);
  auto& jit_result = std::get<0>(jit_result_tuple);

  // Compare final concentrations with a relative tolerance of 1e-5 (plus a tiny
  // absolute floor so that two zeros compare equal).
  bool all_match = true;
  for (const auto& species : result.variable_names_)
  {
    // .at() looks up without inserting, unlike operator[].
    const auto standard_index = result.variable_map_.at(species);
    const auto jit_index = jit_result.variable_map_.at(species);
    for (std::size_t i = 0; i < n_grid_cells; ++i)
    {
      const double a = result.variables_[i][standard_index];
      const double b = jit_result.variables_[i][jit_index];
      if (std::abs(a - b) > 1.0e-5 * (std::abs(a) + std::abs(b)) / 2.0 + 1.0e-30)
      {
        std::cout << species << " does not match final concentration" << std::endl;
        all_match = false;
      }
    }
  }
  // Signal a mismatch through the exit status as well as the message.
  return all_match ? 0 : 1;
}

To build and run the example using GNU (assuming the default install location), copy and paste the example code into a file named foo_jit_chem.cpp and run:

g++ -o foo_jit_chem foo_jit_chem.cpp -I/usr/local/micm-3.2.0/include `llvm-config --cxxflags --ldflags --system-libs --libs support core orcjit native irreader` -std=c++20 -fexceptions
./foo_jit_chem

Line-by-line explanation#

Starting with the header files, we need headers for timing, output, and of course for both types of solvers.

#include <micm/jit/jit_compiler.hpp>
#include <micm/jit/solver/jit_rosenbrock.hpp>
#include <micm/jit/solver/jit_solver_parameters.hpp>
#include <micm/solver/rosenbrock.hpp>
#include <micm/solver/solver_builder.hpp>

Next, we use our namespace, define our number of grid cells (1 for now), and declare some alias templates for vectorized matrices. Our custom vectorized matrix groups the data for multiple grid cells into tiny contiguous blocks, allowing multiple grid cells to be solved simultaneously.

#include <chrono>
#include <iostream>

// Use our namespace so that this example is easier to read
using namespace micm;

// Number of grid cells solved at once; also used as the grouping size of the
// vectorized matrix types and of micm::JitSolverBuilder.
constexpr std::size_t n_grid_cells = 1;

// Alias templates (not partial specializations) for vectorized matrices

Now, all at once, is the function which runs either type of solver. We set all species concentrations to 1 \(\mathrm{mol\ m^{-3}}\). Additionally, we collect all of the solver stats across all solving timesteps.

// Vectorized dense and sparse matrices whose data is grouped in blocks of n_grid_cells.
template<class T>
using GroupVectorMatrix = micm::VectorMatrix<T, n_grid_cells>;
using GroupSparseVectorMatrix = micm::SparseMatrix<double, micm::SparseMatrixVectorOrdering<n_grid_cells>>;

// Runs `solver` for ten consecutive 500 s time steps from a fixed initial state
// and returns (final State, accumulated SolverStats, total solve time in ns).
// Works with any solver exposing GetState(), CalculateRateConstants(state) and
// Solve(time_step, state).
auto run_solver(auto& solver)
{
  SolverStats total_stats;
  State state = solver.GetState();

  // Initialize every variable (species concentration) to 1.
  state.variables_ = 1;

  for (std::size_t i = 0; i < n_grid_cells; ++i)
  {
    state.conditions_[i].temperature_ = 287.45;  // K
    state.conditions_[i].pressure_ = 101319.9;   // Pa
    state.conditions_[i].air_density_ = 1e6;     // mol m-3
  }
  auto foo = Species("Foo");
  std::vector<double> foo_conc(n_grid_cells, 1.0);
  state.SetConcentration(foo, foo_conc);

  // choose a timestep
  const double time_step = 500;  // s

  auto total_solve_time = std::chrono::nanoseconds::zero();

  // solve for ten time steps
  for (int i = 0; i < 10; ++i)
  {
    double elapsed_solve_time = 0;

    // Solve() may stop before covering the whole requested interval, so keep
    // calling it on the remaining interval until the full time step is done.
    while (elapsed_solve_time < time_step)
    {
      auto start = std::chrono::high_resolution_clock::now();
      solver.CalculateRateConstants(state);
      auto result = solver.Solve(time_step - elapsed_solve_time, state);
      auto end = std::chrono::high_resolution_clock::now();
      total_solve_time += std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);

      // Each call only receives the remaining interval, so its final_time_ is the
      // progress made in this call; accumulate it rather than overwrite it.
      elapsed_solve_time += result.final_time_;

      total_stats.function_calls_ += result.stats_.function_calls_;
      total_stats.jacobian_updates_ += result.stats_.jacobian_updates_;
      total_stats.number_of_steps_ += result.stats_.number_of_steps_;
      total_stats.accepted_ += result.stats_.accepted_;
      total_stats.rejected_ += result.stats_.rejected_;
      total_stats.decompositions_ += result.stats_.decompositions_;
      total_stats.solves_ += result.stats_.solves_;
    }
  }

  return std::make_tuple(state, total_stats, total_solve_time);
}

int main()
{

Next is the main function, which defines the chemical species, phase, and reactions directly in code and builds the solvers.

  auto foo = Species{ "Foo" };
  auto bar = Species{ "Bar" };
  auto baz = Species{ "Baz" };

  Phase gas_phase{ std::vector<Species>{ foo, bar, baz } };

  System chemical_system{ SystemParameters{ .gas_phase_ = gas_phase } };

  // Foo -> 0.8 Bar + 0.2 Baz
  Process r1 = Process::Create()
                   .SetReactants({ foo })
                   .SetProducts({ Yield(bar, 0.8), Yield(baz, 0.2) })
                   .SetRateConstant(ArrheniusRateConstant({ .A_ = 1.0e-3 }))
                   .SetPhase(gas_phase);

  // Foo + Bar -> Baz
  Process r2 = Process::Create()
                   .SetReactants({ foo, bar })
                   .SetProducts({ Yield(baz, 1) })
                   .SetRateConstant(ArrheniusRateConstant({ .A_ = 1.0e-5, .C_ = 110.0 }))
                   .SetPhase(gas_phase);

  std::vector<Process> reactions{ r1, r2 };

  auto solver_parameters = RosenbrockSolverParameters::ThreeStageRosenbrockParameters();

  auto solver = micm::CpuSolverBuilder<micm::RosenbrockSolverParameters>(solver_parameters)
                    .SetSystem(chemical_system)
                    .SetReactions(reactions)
                    .SetNumberOfGridCells(n_grid_cells)
                    .Build();

  auto start = std::chrono::high_resolution_clock::now();
  auto jit_solver = micm::JitSolverBuilder<micm::JitRosenbrockSolverParameters, n_grid_cells>(solver_parameters)
                        .SetSystem(chemical_system)
                        .SetReactions(reactions)
                        .SetNumberOfGridCells(n_grid_cells)
                        .Build();
  auto end = std::chrono::high_resolution_clock::now();
  auto jit_compile_time = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);

  std::cout << "Jit compile time: " << jit_compile_time.count() << " nanoseconds" << std::endl;

  auto result_tuple = run_solver(solver);

The only additional step here is to build the JIT solver with micm::JitSolverBuilder, which takes the number of grid cells as a template argument. The JIT solver only works with the vectorized matrix, whereas the standard Rosenbrock solver can also work with a regular matrix. All JIT functions are compiled when the JIT solver is built, so we record that time here.

  // The JIT functions are compiled while the solver is built, so time that step.
  auto start = std::chrono::high_resolution_clock::now();
  auto jit_solver = micm::JitSolverBuilder<micm::JitRosenbrockSolverParameters, n_grid_cells>(solver_parameters)
                        .SetSystem(chemical_system)
                        .SetReactions(reactions)
                        .SetNumberOfGridCells(n_grid_cells)
                        .Build();
  auto end = std::chrono::high_resolution_clock::now();
  auto jit_compile_time = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);

  std::cout << "Jit compile time: " << jit_compile_time.count() << " nanoseconds" << std::endl;

  // Warm-up run of both solvers; these results are overwritten by the rerun.
  auto result_tuple = run_solver(solver);
  auto jit_result_tuple = run_solver(jit_solver);

Finally, we run both solvers, output their cumulative stats, and compare their results.

  // Rerun for more fair comparison after assumed improvements to
  // branch-prediction during state update
  result_tuple = run_solver(solver);
  jit_result_tuple = run_solver(jit_solver);

  std::cout << "Standard solve time: " << std::get<2>(result_tuple).count() << " nanoseconds" << std::endl;
  std::cout << "JIT solve time: " << std::get<2>(jit_result_tuple).count() << " nanoseconds" << std::endl;

  auto result_stats = std::get<1>(result_tuple);
  std::cout << "Standard solve stats: " << std::endl;
  std::cout << "\taccepted: " << result_stats.accepted_ << std::endl;
  std::cout << "\tfunction_calls: " << result_stats.function_calls_ << std::endl;
  std::cout << "\tjacobian_updates: " << result_stats.jacobian_updates_ << std::endl;
  std::cout << "\tnumber_of_steps: " << result_stats.number_of_steps_ << std::endl;
  std::cout << "\taccepted: " << result_stats.accepted_ << std::endl;
  std::cout << "\trejected: " << result_stats.rejected_ << std::endl;
  std::cout << "\tdecompositions: " << result_stats.decompositions_ << std::endl;
  std::cout << "\tsolves: " << result_stats.solves_ << std::endl;

  auto jit_result_stats = std::get<1>(jit_result_tuple);
  std::cout << "JIT solve stats: " << std::endl;
  std::cout << "\taccepted: " << jit_result_stats.accepted_ << std::endl;
  std::cout << "\tfunction_calls: " << jit_result_stats.function_calls_ << std::endl;
  std::cout << "\tjacobian_updates: " << jit_result_stats.jacobian_updates_ << std::endl;
  std::cout << "\tnumber_of_steps: " << jit_result_stats.number_of_steps_ << std::endl;
  std::cout << "\taccepted: " << jit_result_stats.accepted_ << std::endl;
  std::cout << "\trejected: " << jit_result_stats.rejected_ << std::endl;
  std::cout << "\tdecompositions: " << jit_result_stats.decompositions_ << std::endl;
  std::cout << "\tsolves: " << jit_result_stats.solves_ << std::endl;

  auto result = std::get<0>(result_tuple);
  auto jit_result = std::get<0>(jit_result_tuple);

  for (auto& species : result.variable_names_)
  {
    for (int i = 0; i < n_grid_cells; ++i)
    {
      double a = result.variables_[i][result.variable_map_[species]];
      double b = jit_result.variables_[i][jit_result.variable_map_[species]];
      if (std::abs(a - b) > 1.0e-5 * (std::abs(a) + std::abs(b)) / 2.0 + 1.0e-30)
      {
        std::cout << species << " does not match final concentration" << std::endl;
      }
    }
  }
  return 0;
}

The output will be similar to this:

Jit compile time: 38305334 nanoseconds
Standard solve time: 122582 nanoseconds
JIT solve time: 96541 nanoseconds
Standard solve stats:
        accepted: 130
        function_calls: 260
        jacobian_updates: 130
        number_of_steps: 130
        accepted: 130
        rejected: 0
        decompositions: 130
        solves: 390
        singular: 0

JIT solve stats:
        accepted: 130
        function_calls: 260
        jacobian_updates: 130
        number_of_steps: 130
        accepted: 130
        rejected: 0
        decompositions: 130
        solves: 390
        singular: 0