A JIT-compiled solver
Solving a chemical mechanism is computationally intensive. There are many reasons for this. A solver may need multiple iterations to complete an integration. Stiff systems require internal substepping to achieve a numerically stable solution. Cache misses add further cost, among other factors.
This tutorial focuses on alleviating cache misses. A popular way to handle them is to pre-compute the indices used by the chemistry calculations ahead of time. This approach, which may be referred to as ahead-of-time (AOT) compilation, is used in applications such as KPP [DSD+02]. Pre-computed methods require a code preprocessor, so they preclude runtime-configurable software, which is a goal of MICM.
MICM instead uses just-in-time (JIT) compiled functions, built with the LLVM JIT libraries. This keeps the chemistry configurable at runtime while still avoiding cache misses in the most important chemistry functions.
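The sketch below illustrates the idea of pre-computing indices. It stores a sparse matrix in compressed sparse row (CSR) form, where looking up element (i, j) requires a search through row i. Pre-computing resolves each lookup to a flat offset once, during setup, so the per-step loop only does indexed additions. The CsrMatrix type and both helper functions are hypothetical and were written for this illustration. They are not KPP's or MICM's data structures.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical compressed sparse row (CSR) matrix; illustration only, not MICM's type.
struct CsrMatrix
{
  std::size_t n = 0;                   // number of rows
  std::vector<std::size_t> row_start;  // row i occupies [row_start[i], row_start[i + 1])
  std::vector<std::size_t> col;        // column index of each stored element
  std::vector<double> values;          // stored (non-zero) elements

  // Slow path: search row i for column j every time an element is needed.
  std::size_t find(std::size_t i, std::size_t j) const
  {
    for (std::size_t k = row_start[i]; k < row_start[i + 1]; ++k)
      if (col[k] == j)
        return k;
    throw std::out_of_range("element is not in the sparsity pattern");
  }
};

// Setup phase (done once): turn every (row, column) pair the hot loop will touch
// into a flat offset into `values`.
std::vector<std::size_t> precompute_indices(const CsrMatrix& m,
                                            const std::vector<std::pair<std::size_t, std::size_t>>& entries)
{
  std::vector<std::size_t> offsets;
  offsets.reserve(entries.size());
  for (const auto& [i, j] : entries)
    offsets.push_back(m.find(i, j));
  return offsets;
}

// Hot phase (done every step): accumulate contributions through the stored
// offsets, with no searching.
void add_contributions(CsrMatrix& m, const std::vector<std::size_t>& offsets, const std::vector<double>& contributions)
{
  assert(offsets.size() == contributions.size());
  for (std::size_t k = 0; k < offsets.size(); ++k)
    m.values[offsets[k]] += contributions[k];
}
```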
Up until now, a micm::RosenbrockSolver has been used. This class builds all of the components needed to solve the chemistry in memory, including the forcing function and the Jacobian. MICM also provides a micm::JitRosenbrockSolver, which builds and compiles several important chemistry functions at runtime.
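To see where the forcing function and the Jacobian come in, here is a sketch of one linearly implicit Euler step, the simplest Rosenbrock-type method. It forms \(\alpha I - J\) with \(\alpha = 1/h\), solves \((\alpha I - J)\,\Delta y = f(y)\), and sets \(y_{new} = y + \Delta y\). The dense types and function names are hypothetical, and the linear solve uses Gaussian elimination without pivoting. MICM's three-stage Rosenbrock method uses several stages with method-specific coefficients and sparse matrices, so treat this only as an outline of the kind of work involved.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using DenseMatrix = std::vector<std::vector<double>>;

// Solve a x = b by Gaussian elimination without pivoting (sketch only).
// `a` and `b` are taken by value because they are modified in place.
std::vector<double> solve_dense(DenseMatrix a, std::vector<double> b)
{
  const std::size_t n = b.size();
  assert(a.size() == n);
  // Forward elimination, applied to the right-hand side at the same time
  for (std::size_t k = 0; k < n; ++k)
    for (std::size_t i = k + 1; i < n; ++i)
    {
      const double factor = a[i][k] / a[k][k];
      for (std::size_t j = k; j < n; ++j)
        a[i][j] -= factor * a[k][j];
      b[i] -= factor * b[k];
    }
  // Back substitution on the resulting upper-triangular system
  std::vector<double> x(n);
  for (std::size_t i = n; i-- > 0;)
  {
    double sum = b[i];
    for (std::size_t j = i + 1; j < n; ++j)
      sum -= a[i][j] * x[j];
    x[i] = sum / a[i][i];
  }
  return x;
}

// One linearly implicit Euler step, the simplest Rosenbrock-type method:
//   (alpha * I - J) dy = f(y),  alpha = 1 / h,  y_new = y + dy
std::vector<double> linearly_implicit_euler_step(const std::vector<double>& y,
                                                 const std::vector<double>& forcing,
                                                 const DenseMatrix& jacobian,
                                                 double h)
{
  const std::size_t n = y.size();
  const double alpha = 1.0 / h;
  DenseMatrix m(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      m[i][j] = (i == j ? alpha : 0.0) - jacobian[i][j];  // alpha * I - J
  const std::vector<double> dy = solve_dense(m, forcing);
  std::vector<double> y_new = y;
  for (std::size_t i = 0; i < n; ++i)
    y_new[i] += dy[i];
  return y_new;
}
```

Consider a single reaction Foo -> Bar with k = 1 s^-1 and h = 1 s, starting from [Foo] = 1 and [Bar] = 0. The step gives [Foo] = [Bar] = 0.5. This matches the implicit Euler result 1/(1 + hk) for [Foo], and total mass is conserved.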
What are we JIT-compiling?
The functions that are JIT compiled are listed below. How they are compiled, and which operations are elided, is beyond the scope of this tutorial, but you are free to inspect the source code yourself.
- micm::RosenbrockSolver::AlphaMinusJacobian() is JIT compiled and called with micm::JitRosenbrockSolver::AlphaMinusJacobian()
- micm::LuDecomposition::Decompose() is JIT compiled and called with micm::JitLuDecomposition::Decompose()
- micm::LinearSolver::Factor() and micm::LinearSolver::Solve() are JIT compiled and called with micm::JitLinearSolver::Factor() and micm::JitLinearSolver::Solve()
- micm::ProcessSet::AddForcingTerms() and micm::ProcessSet::AddJacobianTerms() are JIT compiled and called with micm::JitProcessSet::AddForcingTerms() and micm::JitProcessSet::AddJacobianTerms()
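For the Foo/Bar/Baz mechanism used later in this tutorial, you can write out by hand the quantities that AddForcingTerms() and AddJacobianTerms() accumulate. The sketch below assumes mass-action kinetics and takes the rate constants k1 and k2 as inputs. The function and type names are hypothetical and the matrices are dense. MICM's own sign and storage conventions (sparse, vectorized) may differ.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Species order used throughout this sketch (hypothetical).
enum SpeciesIndex : std::size_t { FOO = 0, BAR = 1, BAZ = 2 };
using Vec3 = std::array<double, 3>;
using Mat3 = std::array<Vec3, 3>;

// Forcing (time derivative of each concentration), assuming mass-action kinetics:
//   r1: Foo       -> 0.8 Bar + 0.2 Baz   rate1 = k1 [Foo]
//   r2: Foo + Bar -> Baz                 rate2 = k2 [Foo][Bar]
Vec3 forcing(const Vec3& c, double k1, double k2)
{
  assert(k1 >= 0.0 && k2 >= 0.0);
  const double rate1 = k1 * c[FOO];
  const double rate2 = k2 * c[FOO] * c[BAR];
  return { -rate1 - rate2, 0.8 * rate1 - rate2, 0.2 * rate1 + rate2 };
}

// Analytic Jacobian J[i][j] = d(forcing_i) / d(c_j) of the function above.
Mat3 jacobian(const Vec3& c, double k1, double k2)
{
  Mat3 j{};  // zero-initialized; Baz is never a reactant, so its column stays zero
  j[FOO][FOO] = -k1 - k2 * c[BAR];
  j[FOO][BAR] = -k2 * c[FOO];
  j[BAR][FOO] = 0.8 * k1 - k2 * c[BAR];
  j[BAR][BAR] = -k2 * c[FOO];
  j[BAZ][FOO] = 0.2 * k1 + k2 * c[BAR];
  j[BAZ][BAR] = k2 * c[FOO];
  return j;
}
```

Only six of the nine Jacobian entries can be non-zero here. Sparse storage, and code specialized to one mechanism's sparsity pattern, exploit exactly this kind of structure.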
So what does this gain me?
Fast chemical mechanisms that are still configurable at runtime. Let’s compare the speed of the micm::RosenbrockSolver and micm::JitRosenbrockSolver classes applied to the same problem. For simplicity, we will use the fictitious chemical system from previous examples. Feel free to try more complex mechanisms to see how JIT compiling affects compute time in more realistic cases.
If you’re looking for something to copy and paste, the complete example is below. Otherwise, stick around for a line-by-line explanation.
#include <micm/jit/jit_compiler.hpp>
#include <micm/jit/solver/jit_rosenbrock.hpp>
#include <micm/jit/solver/jit_solver_parameters.hpp>
#include <micm/solver/rosenbrock.hpp>
#include <micm/solver/solver_builder.hpp>
#include <micm/jit/solver/jit_solver_builder.hpp>
#include <chrono>
#include <cmath>
#include <iostream>
#include <tuple>
#include <vector>
// Use our namespace so that this example is easier to read
using namespace micm;
constexpr size_t n_grid_cells = 1;
// aliases for our vectorized matrix types (vector length n_grid_cells)
template<class T>
using GroupVectorMatrix = micm::VectorMatrix<T, n_grid_cells>;
using GroupSparseVectorMatrix = micm::SparseMatrix<double, micm::SparseMatrixVectorOrdering<n_grid_cells>>;
auto run_solver(auto& solver)
{
SolverStats total_stats;
State state = solver.GetState();
state.variables_ = 1;
for (int i = 0; i < n_grid_cells; ++i)
{
state.conditions_[i].temperature_ = 287.45; // K
state.conditions_[i].pressure_ = 101319.9; // Pa
state.conditions_[i].air_density_ = 1e6; // mol m-3
}
auto foo = Species("Foo");
std::vector<double> foo_conc(n_grid_cells, 1.0);
state.SetConcentration(foo, foo_conc);
// choose a time step
double time_step = 500; // s
auto total_solve_time = std::chrono::nanoseconds::zero();
// solve for ten time steps
for (int i = 0; i < 10; ++i)
{
double elapsed_solve_time = 0;
while (elapsed_solve_time < time_step)
{
auto start = std::chrono::high_resolution_clock::now();
solver.CalculateRateConstants(state);
auto result = solver.Solve(time_step - elapsed_solve_time, state);
auto end = std::chrono::high_resolution_clock::now();
total_solve_time += std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
elapsed_solve_time += result.final_time_; // accumulate the time covered by this call to Solve()
total_stats.function_calls_ += result.stats_.function_calls_;
total_stats.jacobian_updates_ += result.stats_.jacobian_updates_;
total_stats.number_of_steps_ += result.stats_.number_of_steps_;
total_stats.accepted_ += result.stats_.accepted_;
total_stats.rejected_ += result.stats_.rejected_;
total_stats.decompositions_ += result.stats_.decompositions_;
total_stats.solves_ += result.stats_.solves_;
}
}
return std::make_tuple(state, total_stats, total_solve_time);
}
int main(const int argc, const char* argv[])
{
auto foo = Species{ "Foo" };
auto bar = Species{ "Bar" };
auto baz = Species{ "Baz" };
Phase gas_phase{ std::vector<Species>{ foo, bar, baz } };
System chemical_system{ SystemParameters{ .gas_phase_ = gas_phase } };
Process r1 = Process::Create()
.SetReactants({ foo })
.SetProducts({ Yield(bar, 0.8), Yield(baz, 0.2) })
.SetRateConstant(ArrheniusRateConstant({ .A_ = 1.0e-3 }))
.SetPhase(gas_phase);
Process r2 = Process::Create()
.SetReactants({ foo, bar })
.SetProducts({ Yield(baz, 1) })
.SetRateConstant(ArrheniusRateConstant({ .A_ = 1.0e-5, .C_ = 110.0 }))
.SetPhase(gas_phase);
std::vector<Process> reactions{ r1, r2 };
auto solver_parameters = RosenbrockSolverParameters::ThreeStageRosenbrockParameters();
auto solver = micm::CpuSolverBuilder<micm::RosenbrockSolverParameters>(solver_parameters)
.SetSystem(chemical_system)
.SetReactions(reactions)
.SetNumberOfGridCells(n_grid_cells)
.Build();
auto start = std::chrono::high_resolution_clock::now();
auto jit_solver = micm::JitSolverBuilder<micm::JitRosenbrockSolverParameters, n_grid_cells>(solver_parameters)
.SetSystem(chemical_system)
.SetReactions(reactions)
.SetNumberOfGridCells(n_grid_cells)
.Build();
auto end = std::chrono::high_resolution_clock::now();
auto jit_compile_time = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
std::cout << "Jit compile time: " << jit_compile_time.count() << " nanoseconds" << std::endl;
auto result_tuple = run_solver(solver);
auto jit_result_tuple = run_solver(jit_solver);
// Rerun both solvers for a fairer comparison, since the first run is assumed
// to pay one-time warm-up costs (e.g. branch prediction during state update)
result_tuple = run_solver(solver);
jit_result_tuple = run_solver(jit_solver);
std::cout << "Standard solve time: " << std::get<2>(result_tuple).count() << " nanoseconds" << std::endl;
std::cout << "JIT solve time: " << std::get<2>(jit_result_tuple).count() << " nanoseconds" << std::endl;
auto result_stats = std::get<1>(result_tuple);
std::cout << "Standard solve stats: " << std::endl;
std::cout << "\tfunction_calls: " << result_stats.function_calls_ << std::endl;
std::cout << "\tjacobian_updates: " << result_stats.jacobian_updates_ << std::endl;
std::cout << "\tnumber_of_steps: " << result_stats.number_of_steps_ << std::endl;
std::cout << "\taccepted: " << result_stats.accepted_ << std::endl;
std::cout << "\trejected: " << result_stats.rejected_ << std::endl;
std::cout << "\tdecompositions: " << result_stats.decompositions_ << std::endl;
std::cout << "\tsolves: " << result_stats.solves_ << std::endl;
auto jit_result_stats = std::get<1>(jit_result_tuple);
std::cout << "JIT solve stats: " << std::endl;
std::cout << "\tfunction_calls: " << jit_result_stats.function_calls_ << std::endl;
std::cout << "\tjacobian_updates: " << jit_result_stats.jacobian_updates_ << std::endl;
std::cout << "\tnumber_of_steps: " << jit_result_stats.number_of_steps_ << std::endl;
std::cout << "\taccepted: " << jit_result_stats.accepted_ << std::endl;
std::cout << "\trejected: " << jit_result_stats.rejected_ << std::endl;
std::cout << "\tdecompositions: " << jit_result_stats.decompositions_ << std::endl;
std::cout << "\tsolves: " << jit_result_stats.solves_ << std::endl;
auto result = std::get<0>(result_tuple);
auto jit_result = std::get<0>(jit_result_tuple);
for (auto& species : result.variable_names_)
{
for (int i = 0; i < n_grid_cells; ++i)
{
double a = result.variables_[i][result.variable_map_[species]];
double b = jit_result.variables_[i][jit_result.variable_map_[species]];
if (std::abs(a - b) > 1.0e-5 * (std::abs(a) + std::abs(b)) / 2.0 + 1.0e-30)
{
std::cout << species << " does not match final concentration" << std::endl;
}
}
}
return 0;
}
To build and run the example with GCC (assuming MICM is installed in the default location), copy and paste the example code into a file named foo_jit_chem.cpp and run:
g++ -o foo_jit_chem foo_jit_chem.cpp -I/usr/local/micm-3.2.0/include `llvm-config --cxxflags --ldflags --system-libs --libs support core orcjit native irreader` -std=c++20 -fexceptions
./foo_jit_chem
Line-by-line explanation
Starting with the header files, we need headers for both types of solvers, as well as standard headers for timing and output.
#include <micm/jit/jit_compiler.hpp>
#include <micm/jit/solver/jit_rosenbrock.hpp>
#include <micm/jit/solver/jit_solver_parameters.hpp>
#include <micm/solver/rosenbrock.hpp>
#include <micm/solver/solver_builder.hpp>
#include <micm/jit/solver/jit_solver_builder.hpp>
#include <chrono>
#include <cmath>
#include <iostream>
#include <tuple>
#include <vector>
Next, we use our namespace, define the number of grid cells (1 for now), and declare a couple of type aliases. They refer to our custom vectorized matrix, which groups values for multiple grid cells into tiny blocks in a vector, allowing multiple grid cells to be solved simultaneously. The aliases are not used directly later in this example.
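// Use our namespace so that this example is easier to read
using namespace micm;
constexpr size_t n_grid_cells = 1;
// aliases for our vectorized matrix types (vector length n_grid_cells)
template<class T>
using GroupVectorMatrix = micm::VectorMatrix<T, n_grid_cells>;
using GroupSparseVectorMatrix = micm::SparseMatrix<double, micm::SparseMatrixVectorOrdering<n_grid_cells>>;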
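Here is a sketch of a grouped, or vector-ordered, layout to make the grouping idea concrete. It assumes that rows are grid cells and columns are species, and that rows are split into groups of L. Within a group, the L values of one column are stored next to each other. GroupedMatrix, Offset, and ScaleSpecies are hypothetical names, and MICM's VectorMatrix may differ in its details.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical grouped ("vector-ordered") matrix: rows = grid cells, columns = species.
template<std::size_t L>
class GroupedMatrix
{
 public:
  GroupedMatrix(std::size_t rows, std::size_t cols)
      : rows_(rows),
        cols_(cols),
        data_(((rows + L - 1) / L) * L * cols, 0.0)  // last group padded to a full L rows
  {
  }

  // Flat position of element (row, col): skip whole groups, then whole columns
  // of the current group, then move within the column.
  std::size_t Offset(std::size_t row, std::size_t col) const
  {
    assert(row < rows_ && col < cols_);
    return (row / L) * L * cols_ + col * L + row % L;
  }

  double& operator()(std::size_t row, std::size_t col) { return data_[Offset(row, col)]; }
  std::size_t StorageSize() const { return data_.size(); }

 private:
  std::size_t rows_;
  std::size_t cols_;
  std::vector<double> data_;
};

// Scale one species (column) in every grid cell (row). Within each group the
// L values of that column are adjacent in memory.
template<std::size_t L>
void ScaleSpecies(GroupedMatrix<L>& m, std::size_t rows, std::size_t col, double factor)
{
  for (std::size_t row = 0; row < rows; ++row)
    m(row, col) *= factor;
}
```

With L = 1, which corresponds to n_grid_cells = 1 in this tutorial, Offset(row, col) reduces to row * cols + col. That is ordinary row-major storage, so in this sketch the grouping only starts to matter with more grid cells.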
Now, all at once, is the function that runs either type of solver. It sets all species concentrations to 1 \(\mathrm{mol\ m^{-3}}\) and sets the temperature, pressure, and air density of each grid cell. It then solves ten 500 s time steps, timing the rate-constant calculation and solve calls and collecting the solver stats across all of them.
auto run_solver(auto& solver)
{
SolverStats total_stats;
State state = solver.GetState();
state.variables_ = 1;
for (int i = 0; i < n_grid_cells; ++i)
{
state.conditions_[i].temperature_ = 287.45; // K
state.conditions_[i].pressure_ = 101319.9; // Pa
state.conditions_[i].air_density_ = 1e6; // mol m-3
}
auto foo = Species("Foo");
std::vector<double> foo_conc(n_grid_cells, 1.0);
state.SetConcentration(foo, foo_conc);
// choose a time step
double time_step = 500; // s
auto total_solve_time = std::chrono::nanoseconds::zero();
// solve for ten time steps
for (int i = 0; i < 10; ++i)
{
double elapsed_solve_time = 0;
while (elapsed_solve_time < time_step)
{
auto start = std::chrono::high_resolution_clock::now();
solver.CalculateRateConstants(state);
auto result = solver.Solve(time_step - elapsed_solve_time, state);
auto end = std::chrono::high_resolution_clock::now();
total_solve_time += std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
elapsed_solve_time += result.final_time_; // accumulate the time covered by this call to Solve()
total_stats.function_calls_ += result.stats_.function_calls_;
total_stats.jacobian_updates_ += result.stats_.jacobian_updates_;
total_stats.number_of_steps_ += result.stats_.number_of_steps_;
total_stats.accepted_ += result.stats_.accepted_;
total_stats.rejected_ += result.stats_.rejected_;
total_stats.decompositions_ += result.stats_.decompositions_;
total_stats.solves_ += result.stats_.solves_;
}
}
return std::make_tuple(state, total_stats, total_solve_time);
}
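run_solver times each call with std::chrono::high_resolution_clock. main runs each solver twice, so the reported times come from a second, warmed-up run. If you extend the benchmark, a helper like the hypothetical best_of below may be useful. It performs a few untimed warm-up runs and then reports the fastest of several timed runs. It uses std::chrono::steady_clock, because the C++ standard does not require high_resolution_clock to be steady.

```cpp
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstddef>

// Hypothetical benchmarking helper: call `work` `warmups` times without timing,
// then time `repeats` calls with a monotonic clock and return the fastest one.
template<class F>
std::chrono::nanoseconds best_of(F&& work, std::size_t warmups, std::size_t repeats)
{
  assert(repeats > 0);
  for (std::size_t i = 0; i < warmups; ++i)
    work();
  auto best = std::chrono::nanoseconds::max();
  for (std::size_t i = 0; i < repeats; ++i)
  {
    const auto start = std::chrono::steady_clock::now();
    work();
    const auto end = std::chrono::steady_clock::now();
    best = std::min(best, std::chrono::duration_cast<std::chrono::nanoseconds>(end - start));
  }
  return best;
}
```

For example, best_of([&] { run_solver(jit_solver); }, 1, 5) would time five further runs of the JIT solver after one warm-up run.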
Next comes the main function. It defines the species, the gas phase, and the two reactions, and builds the standard solver.
int main(const int argc, const char* argv[])
{
auto foo = Species{ "Foo" };
auto bar = Species{ "Bar" };
auto baz = Species{ "Baz" };
Phase gas_phase{ std::vector<Species>{ foo, bar, baz } };
System chemical_system{ SystemParameters{ .gas_phase_ = gas_phase } };
Process r1 = Process::Create()
.SetReactants({ foo })
.SetProducts({ Yield(bar, 0.8), Yield(baz, 0.2) })
.SetRateConstant(ArrheniusRateConstant({ .A_ = 1.0e-3 }))
.SetPhase(gas_phase);
Process r2 = Process::Create()
.SetReactants({ foo, bar })
.SetProducts({ Yield(baz, 1) })
.SetRateConstant(ArrheniusRateConstant({ .A_ = 1.0e-5, .C_ = 110.0 }))
.SetPhase(gas_phase);
std::vector<Process> reactions{ r1, r2 };
auto solver_parameters = RosenbrockSolverParameters::ThreeStageRosenbrockParameters();
auto solver = micm::CpuSolverBuilder<micm::RosenbrockSolverParameters>(solver_parameters)
.SetSystem(chemical_system)
.SetReactions(reactions)
.SetNumberOfGridCells(n_grid_cells)
.Build();
Building the JIT solver looks almost the same. The only difference is that we use micm::JitSolverBuilder instead of micm::CpuSolverBuilder. The micm::JitRosenbrockSolver only works with the vectorized matrix, and its builder takes the number of grid cells as a template argument. The micm::RosenbrockSolver also works with a regular matrix. The standard solver above uses a regular matrix, since no matrix types are passed to micm::CpuSolverBuilder.
All JIT functions are compiled when the JIT solver is built, so we record that time here.
auto start = std::chrono::high_resolution_clock::now();
auto jit_solver = micm::JitSolverBuilder<micm::JitRosenbrockSolverParameters, n_grid_cells>(solver_parameters)
.SetSystem(chemical_system)
.SetReactions(reactions)
.SetNumberOfGridCells(n_grid_cells)
.Build();
auto end = std::chrono::high_resolution_clock::now();
auto jit_compile_time = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start);
std::cout << "Jit compile time: " << jit_compile_time.count() << " nanoseconds" << std::endl;
Finally, we run both solvers, output their cumulative stats, and compare their results.
auto result_tuple = run_solver(solver);
auto jit_result_tuple = run_solver(jit_solver);
// Rerun both solvers for a fairer comparison, since the first run is assumed
// to pay one-time warm-up costs (e.g. branch prediction during state update)
result_tuple = run_solver(solver);
jit_result_tuple = run_solver(jit_solver);
std::cout << "Standard solve time: " << std::get<2>(result_tuple).count() << " nanoseconds" << std::endl;
std::cout << "JIT solve time: " << std::get<2>(jit_result_tuple).count() << " nanoseconds" << std::endl;
auto result_stats = std::get<1>(result_tuple);
std::cout << "Standard solve stats: " << std::endl;
std::cout << "\tfunction_calls: " << result_stats.function_calls_ << std::endl;
std::cout << "\tjacobian_updates: " << result_stats.jacobian_updates_ << std::endl;
std::cout << "\tnumber_of_steps: " << result_stats.number_of_steps_ << std::endl;
std::cout << "\taccepted: " << result_stats.accepted_ << std::endl;
std::cout << "\trejected: " << result_stats.rejected_ << std::endl;
std::cout << "\tdecompositions: " << result_stats.decompositions_ << std::endl;
std::cout << "\tsolves: " << result_stats.solves_ << std::endl;
auto jit_result_stats = std::get<1>(jit_result_tuple);
std::cout << "JIT solve stats: " << std::endl;
std::cout << "\tfunction_calls: " << jit_result_stats.function_calls_ << std::endl;
std::cout << "\tjacobian_updates: " << jit_result_stats.jacobian_updates_ << std::endl;
std::cout << "\tnumber_of_steps: " << jit_result_stats.number_of_steps_ << std::endl;
std::cout << "\taccepted: " << jit_result_stats.accepted_ << std::endl;
std::cout << "\trejected: " << jit_result_stats.rejected_ << std::endl;
std::cout << "\tdecompositions: " << jit_result_stats.decompositions_ << std::endl;
std::cout << "\tsolves: " << jit_result_stats.solves_ << std::endl;
auto result = std::get<0>(result_tuple);
auto jit_result = std::get<0>(jit_result_tuple);
for (auto& species : result.variable_names_)
{
for (int i = 0; i < n_grid_cells; ++i)
{
double a = result.variables_[i][result.variable_map_[species]];
double b = jit_result.variables_[i][jit_result.variable_map_[species]];
if (std::abs(a - b) > 1.0e-5 * (std::abs(a) + std::abs(b)) / 2.0 + 1.0e-30)
{
std::cout << species << " does not match final concentration" << std::endl;
}
}
}
return 0;
}
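The last loop above flags a species when its two final concentrations differ by more than a mixed relative/absolute tolerance. If you want to reuse that test, you can factor it into a small helper. The name close_enough and its default arguments are hypothetical, chosen to mirror the constants in the example. The absolute term keeps two concentrations that are both essentially zero from being reported as a mismatch.

```cpp
#include <cassert>
#include <cmath>

// Same form as the check in the example, inverted to return true on a match:
//   |a - b| <= rel_tol * (|a| + |b|) / 2 + abs_tol
bool close_enough(double a, double b, double rel_tol = 1.0e-5, double abs_tol = 1.0e-30)
{
  assert(rel_tol >= 0.0 && abs_tol >= 0.0);
  return std::abs(a - b) <= rel_tol * (std::abs(a) + std::abs(b)) / 2.0 + abs_tol;
}
```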
The output will be similar to this:
Jit compile time: 38305334 nanoseconds
Standard solve time: 122582 nanoseconds
JIT solve time: 96541 nanoseconds
Standard solve stats:
    function_calls: 260
    jacobian_updates: 130
    number_of_steps: 130
    accepted: 130
    rejected: 0
    decompositions: 130
    solves: 390
JIT solve stats:
    function_calls: 260
    jacobian_updates: 130
    number_of_steps: 130
    accepted: 130
    rejected: 0
    decompositions: 130
    solves: 390
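The sample output also shows that JIT compilation is a one-time cost that is large compared with a single run. A quick way to judge when it pays off is to divide the compile time by the per-run saving. The helper below is a hypothetical sketch of that arithmetic and is not part of MICM. Each reported solve time covers one call to run_solver (ten 500 s time steps), and your timings will differ by machine.

```cpp
#include <cassert>
#include <cstdint>

// Number of runs after which a one-time compile cost is recovered by a per-run
// saving, rounded up. All inputs are in nanoseconds.
std::int64_t break_even_runs(std::int64_t compile_ns, std::int64_t standard_ns, std::int64_t jit_ns)
{
  assert(standard_ns > jit_ns);  // otherwise the JIT solver never pays off
  const std::int64_t saving = standard_ns - jit_ns;
  return (compile_ns + saving - 1) / saving;  // ceiling division
}
```

The sample numbers above are 38305334 ns to compile, 122582 ns per standard run, and 96541 ns per JIT run. With these, the JIT solver saves 26041 ns per run and needs 1471 runs to recover its compile time. For larger mechanisms or more grid cells, the per-run saving, and therefore the break-even point, will be different.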