Avoiding conversions & boilerplate in pybind11
2021-02-11
The goals of pygram11 include:
- providing the fastest possible histogram calculations.
- supporting uncertainties on weighted histograms.
- supporting multiple weight variations in a single histogramming routine.
The first bullet is primarily accomplished via parallel loops provided by OpenMP, but we can squeeze out a bit more performance in other places. This post focuses on a second-order performance consideration: avoiding potentially expensive array conversions while still supporting many data and weight array types.
Early versions of pygram11 (up to version 0.10.3) supported input data (arrays) of any type, but in the backend we supported histogramming calculations only on 32- and 64-bit floating-point inputs (for both the data and the weights). If an array with a non-floating-point type was passed (as either the data or the weights), we converted it to a floating-point array and passed the converted data to the backend C++ functions.
The backend C++ code has always been generic (implemented with templated functions). An example function and its pybind11 bindings were of this form:
py::tuple
This setup supports np.float32 (C++ float) and np.float64 (C++ double) input and weights via a single templated function that is used by four bindings, one per (data type, weight type) combination. This was the implementation before version 0.11.0 of pygram11.
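As a rough illustration of that pattern (not pygram11's actual code), the sketch below assumes a fixed-width, one-dimensional weighted histogram that accumulates both the sum of weights and the sum of squared weights per bin. The names `fixed_weighted_1d`, `hist_result`, and `hist_ff` through `hist_dd` are hypothetical. To keep it self-contained without Python headers, the four pybind11 bindings are stood in for by four function pointers to the four float/double instantiations.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type: per-bin sum of weights and sum of squared weights.
template <typename TW>
struct hist_result {
  std::vector<TW> sumw;   // bin contents
  std::vector<TW> sumw2;  // sqrt(sumw2) is the per-bin statistical uncertainty
};

// Hypothetical generic kernel: TX is the data type, TW is the weight type.
// Fixed-width bins on [xmin, xmax); values outside that range are skipped.
template <typename TX, typename TW>
hist_result<TW> fixed_weighted_1d(const TX* x, const TW* w, std::size_t n,
                                  std::size_t nbins, double xmin, double xmax) {
  assert(nbins > 0 && xmax > xmin);
  hist_result<TW> res{std::vector<TW>(nbins, TW{0}), std::vector<TW>(nbins, TW{0})};
  const double norm = static_cast<double>(nbins) / (xmax - xmin);
  for (std::size_t i = 0; i < n; ++i) {
    const double xi = static_cast<double>(x[i]);
    if (xi < xmin || xi >= xmax) continue;
    auto bin = static_cast<std::size_t>((xi - xmin) * norm);
    if (bin >= nbins) bin = nbins - 1;  // guard against rounding at the upper edge
    res.sumw[bin] += w[i];
    res.sumw2[bin] += w[i] * w[i];
  }
  return res;
}

// Stand-ins for the four bindings: one instantiation per (data, weight) pair.
auto* const hist_ff = &fixed_weighted_1d<float, float>;
auto* const hist_fd = &fixed_weighted_1d<float, double>;
auto* const hist_df = &fixed_weighted_1d<double, float>;
auto* const hist_dd = &fixed_weighted_1d<double, double>;
```

In a pybind11 module each of these instantiations would be handed to its own `m.def(...)` call; registering several functions under the same name is how pybind11 exposes overloads.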
Now let's look at the signature of a two-dimensional weighted histogram function:
py::tuple
Let's say we want to support 32-bit and 64-bit floating-point, integer, and unsigned integer input, along with 64-bit and 32-bit floating-point weights. That's six types for each of x and y and two types for w; that's 6 × 6 × 2 = 72 total overloads.
We can lean on some template metaprogramming to make a clean (low-boilerplate) and extensible implementation while explicitly avoiding conversions.
We'll use boost::mp11 for some help. We create a type list for the possible data types (pg_Ts) and one for the possible weight types (pg_Ws). The boost::mp11 library provides the mp_product metafunction, which, at compile time, applies a template (here, type_list) to every element of the Cartesian product of its list arguments.
#include <boost/mp11.hpp>
#include <cstdint>
using boost::mp11::mp_product;
// a minimal variadic type list; any class template over types is a valid mp11 list
template <typename... Ts>
struct type_list {};
// all data types
using pg_Ts = type_list<
  double, int64_t, uint64_t, float, int32_t, uint32_t
>;
// all weight types
using pg_Ws = type_list<double, float>;
// all combinations of data types and weight types
using pg_Ts_and_Ws = mp_product<type_list, pg_Ts, pg_Ws>;
// all combinations of data types and data types
using pg_T_pairs = mp_product<type_list, pg_Ts, pg_Ts>;
// all combinations of data types, data types, and weight types
using pg_T_pairs_and_Ws = mp_product<type_list, pg_Ts, pg_Ts, pg_Ws>;
Let's think about our new types a bit:
- pg_Ts_and_Ws is made up of pairs of data types and weight types:
  - type_list<double, double> (1st)
  - type_list<double, float> (2nd)
  - type_list<int64_t, double> (3rd)
  - ...
  - type_list<uint32_t, float> (12th)
- pg_T_pairs is made up of pairs of data types:
  - type_list<double, double> (1st)
  - ...
  - type_list<uint32_t, uint32_t> (36th)
- pg_T_pairs_and_Ws is made up of two data types and a weight type:
  - type_list<double, double, double> (1st)
  - ...
  - type_list<uint32_t, uint32_t, float> (72nd)
mp_product varies its last list fastest, which is why the weight type cycles before the data type does.
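To check that ordering without pulling in Boost, here is a dependency-free sketch that reproduces mp_product's documented nested-loop order (first list outermost) for the two-list case using C++17 fold expressions. The helpers `for_each_pair`, `visit_row`, `describe`, and `type_name` are hypothetical, written only for this illustration; they are not Boost.Mp11 API.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the post's type_list (plain C++17, no Boost).
template <typename... Ts>
struct type_list {};

// Readable names for the six data types (hypothetical helper).
template <typename T> constexpr const char* type_name = "?";
template <> constexpr const char* type_name<double> = "double";
template <> constexpr const char* type_name<float> = "float";
template <> constexpr const char* type_name<std::int64_t> = "int64_t";
template <> constexpr const char* type_name<std::uint64_t> = "uint64_t";
template <> constexpr const char* type_name<std::int32_t> = "int32_t";
template <> constexpr const char* type_name<std::uint32_t> = "uint32_t";

// Inner loop: pair a fixed A with every B, in list order.
template <typename A, typename F, typename... Bs>
void visit_row(F& f, type_list<Bs...>) {
  (f(type_list<A, Bs>{}), ...);
}

// Outer loop over the first list, so the second list varies fastest
// (the same nested-loop order Boost.Mp11 documents for mp_product).
template <typename L2, typename F, typename... As>
void for_each_pair(type_list<As...>, L2 l2, F f) {
  (visit_row<As>(f, l2), ...);
}

template <typename X, typename W>
std::string describe(type_list<X, W>) {
  return std::string(type_name<X>) + ", " + type_name<W>;
}

// Collect the 12 (data type, weight type) combinations in enumeration order.
std::vector<std::string> pair_names() {
  using pg_Ts = type_list<double, std::int64_t, std::uint64_t, float,
                          std::int32_t, std::uint32_t>;
  using pg_Ws = type_list<double, float>;
  std::vector<std::string> names;
  for_each_pair(pg_Ts{}, pg_Ws{}, [&names](auto pair) { names.push_back(describe(pair)); });
  return names;
}
```

Calling `pair_names()` yields the twelve combinations in the same order as the pg_Ts_and_Ws list.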
We can write some small functions that each take a type list of one of these three forms and inject the corresponding overload into a py::module_ object:
// to use "x"_a instead of py::arg("x")
using namespace pybind11::literals;
/// inject a data type and weight type pair
void
/// inject a data type and data type pair
void
/// inject a data type, data type, and weight type triplet
void
Then we use boost::mp11::mp_for_each to inject an overload for each of our type combinations:
These three mp_for_each calls add 12 + 36 + 72 = 120 overloads to the py::module_ instance. Now imagine extending the set of types the library supports: instead of writing new function calls to cover every new combination, we just add the type to a type list, and the template metaprogramming generates all of the new overloads.
This code isn't an exact copy of what is in pygram11 version 0.11.0, but it's quite close. Check out the _backend.cpp file.
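The pieces described above (three injector functions driven over every type combination) can be tied together in a dependency-free sketch that swaps Boost.Mp11 and pybind11 for plain C++17 stand-ins so it compiles on its own. Everything here is hypothetical and for illustration only: `for_each_product` plays the role of mp_product plus mp_for_each, `fake_module` plays the role of py::module_, and `f1dw`, `f2d`, and `f2dw` are stub kernels, not pygram11 functions.

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>

// Hypothetical stand-in for the post's type_list (plain C++17, no Boost).
template <typename... Ts>
struct type_list {};

// Base case: no lists left, so call f with the accumulated combination.
template <typename F, typename... Prefix>
void for_each_product(F& f, type_list<Prefix...>) {
  f(type_list<Prefix...>{});
}
// Recursive case: extend the prefix with each type of the next list,
// first list outermost (the order mp_product documents).
template <typename F, typename... Prefix, typename... Heads, typename... Rest>
void for_each_product(F& f, type_list<Prefix...>, type_list<Heads...>, Rest... rest) {
  (for_each_product(f, type_list<Prefix..., Heads>{}, rest...), ...);
}

// Hypothetical stand-in for py::module_: def() only counts overloads per name.
struct fake_module {
  std::map<std::string, std::size_t> overloads;
  template <typename Fn>
  void def(const std::string& name, Fn*) { ++overloads[name]; }
};

// Hypothetical kernels; the bodies are stubs because only the instantiations matter.
template <typename X, typename W>
void f1dw(const X*, const W*, std::size_t) {}
template <typename X, typename Y>
void f2d(const X*, const Y*, std::size_t) {}
template <typename X, typename Y, typename W>
void f2dw(const X*, const Y*, const W*, std::size_t) {}

// One injector per type-list shape; with pybind11 each would be an m.def(...)
// call that also names the arguments (e.g. with "x"_a).
template <typename X, typename W>
void inject_Tw(fake_module& m, type_list<X, W>) { m.def("f1dw", &f1dw<X, W>); }
template <typename X, typename Y>
void inject_TT(fake_module& m, type_list<X, Y>) { m.def("f2d", &f2d<X, Y>); }
template <typename X, typename Y, typename W>
void inject_TTw(fake_module& m, type_list<X, Y, W>) { m.def("f2dw", &f2dw<X, Y, W>); }

fake_module build_module() {
  using pg_Ts = type_list<double, std::int64_t, std::uint64_t, float,
                          std::int32_t, std::uint32_t>;
  using pg_Ws = type_list<double, float>;
  fake_module m;
  auto tw = [&m](auto l) { inject_Tw(m, l); };
  auto tt = [&m](auto l) { inject_TT(m, l); };
  auto ttw = [&m](auto l) { inject_TTw(m, l); };
  for_each_product(tw, type_list<>{}, pg_Ts{}, pg_Ws{});
  for_each_product(tt, type_list<>{}, pg_Ts{}, pg_Ts{});
  for_each_product(ttw, type_list<>{}, pg_Ts{}, pg_Ts{}, pg_Ws{});
  return m;
}
```

Adding a type to pg_Ts or pg_Ws in this sketch grows the overload counts automatically without touching any injector, which is the extensibility argument made in the post.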