Boosted tree predictions in SQL
Computing predictions with a boosted tree model is usually little more than a standard tree traversal. Not uncommonly, the actual challenge is getting the data for the decision at each node: the features are exported from a database and loaded into the tree library, which can take more time than computing the predictions themselves. Pity.
Generally this is a pretty well-known problem, with a standard solution: move the computation to where the data is. Which can be done quite easily in this case, even if it means computing boosted tree predictions in SQL.
Let's take a look at the demo model from the XGBoost tutorial, which has the following representation:
booster:
0:[f29<-9.53674316e-07] yes=1,no=2,missing=1
  1:[f56<-9.53674316e-07] yes=3,no=4,missing=3
    3:leaf=1.71217716
    4:leaf=-1.70044053
  2:[f109<-9.53674316e-07] yes=5,no=6,missing=5
    5:leaf=-1.94070864
    6:leaf=1.85964918
booster:
0:[f60<-9.53674316e-07] yes=1,no=2,missing=1
  1:[f29<-9.53674316e-07] yes=3,no=4,missing=3
    3:leaf=0.78471756
    4:leaf=-0.968530357
  2:leaf=-6.23624468
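A dump in this text format can be obtained from XGBoost via Booster.dump_model (or get_dump), and the per-node attributes can be extracted with a couple of regular expressions. A minimal sketch in Python (the function name and the tuple layout are my own choices, not an XGBoost API):

```python
import re

# Patterns for the two node kinds in the default XGBoost text dump:
# split nodes like "0:[f29<-9.53674316e-07] yes=1,no=2,missing=1"
# and leaf nodes like "3:leaf=1.71217716".
SPLIT_RE = re.compile(
    r"(\d+):\[f(\d+)<([^\]]+)\] yes=(\d+),no=(\d+),missing=(\d+)")
LEAF_RE = re.compile(r"(\d+):leaf=([-0-9.eE+]+)")

def parse_tree(dump_text):
    """Yield one (node_id, feature_index, split, yes, no, missing,
    prediction) tuple per line of a single tree's dump."""
    for line in dump_text.strip().splitlines():
        line = line.strip()
        m = SPLIT_RE.match(line)
        if m:
            node, feat, split, yes, no, missing = m.groups()
            yield (int(node), int(feat), float(split),
                   int(yes), int(no), int(missing), None)
            continue
        m = LEAF_RE.match(line)
        if m:
            yield (int(m.group(1)), None, None, None, None, None,
                   float(m.group(2)))
```

The tuple layout deliberately mirrors the table columns defined below, so the parser output can be inserted as-is.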
The model consists of a set of two booster trees. Each tree has a root node, a set of regular inner nodes, and a set of leaf nodes. Each regular node has an associated feature and a split value, followed by three links to possible next nodes (yes, no, and missing). Each leaf node has only one attribute: the prediction. We can represent this with the following table, where some of the attributes will be NULL for each node:

CREATE TABLE tree_model (
    model_name VARCHAR NOT NULL,
    booster_id INTEGER NOT NULL,
    node_id INTEGER NOT NULL,
    feature_index INTEGER,
    split DOUBLE PRECISION,
    node_yes INTEGER,
    node_no INTEGER,
    node_missing INTEGER,
    prediction DOUBLE PRECISION
);

And the demo model from above translates to these rows:

INSERT INTO tree_model VALUES
    ('demo_model.xgb', 0, 0, 29, -9.53674316e-07, 1, 2, 1, NULL),
    ('demo_model.xgb', 0, 1, 56, -9.53674316e-07, 3, 4, 3, NULL),
    ('demo_model.xgb', 0, 2, 109, -9.53674316e-07, 5, 6, 5, NULL),
    ('demo_model.xgb', 0, 3, NULL, NULL, NULL, NULL, NULL, 1.71217716),
    ('demo_model.xgb', 0, 4, NULL, NULL, NULL, NULL, NULL, -1.70044053),
    ('demo_model.xgb', 0, 5, NULL, NULL, NULL, NULL, NULL, -1.94070864),
    ('demo_model.xgb', 0, 6, NULL, NULL, NULL, NULL, NULL, 1.85964918),
    ('demo_model.xgb', 1, 0, 60, -9.53674316e-07, 1, 2, 1, NULL),
    ('demo_model.xgb', 1, 1, 29, -9.53674316e-07, 3, 4, 3, NULL),
    ('demo_model.xgb', 1, 2, NULL, NULL, NULL, NULL, NULL, -6.23624468),
    ('demo_model.xgb', 1, 3, NULL, NULL, NULL, NULL, NULL, 0.78471756),
    ('demo_model.xgb', 1, 4, NULL, NULL, NULL, NULL, NULL, -0.968530357);
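Of course, nobody wants to write such an INSERT statement by hand. It can be generated mechanically from node tuples in the (node_id, feature_index, split, yes, no, missing, prediction) order of the table columns; a Python sketch with illustrative helper names:

```python
def sql_literal(value):
    """Render a Python value as a SQL literal (NULL for None)."""
    if value is None:
        return "NULL"
    if isinstance(value, str):
        return "'" + value.replace("'", "''") + "'"
    return repr(value)

def insert_statement(model_name, trees):
    """Build one INSERT for all boosters; trees is a list of
    node-tuple lists, one entry per booster."""
    rows = []
    for booster_id, nodes in enumerate(trees):
        for node in nodes:
            values = (model_name, booster_id) + tuple(node)
            rows.append(
                "(" + ", ".join(sql_literal(v) for v in values) + ")")
    return "INSERT INTO tree_model VALUES\n    " + ",\n    ".join(rows) + ";"
```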
Now assume the feature data for which we want to compute predictions is available in the following database table:
CREATE TABLE features (
    sample_id INTEGER NOT NULL,
    feature_values INTEGER[] NOT NULL -- We only need integers for the demo.
);
We can join the feature table to the model table and start computing predictions:

WITH RECURSIVE tree_eval AS (
    -- The non-recursive term first selects all the root nodes.
    SELECT
        -- Keep sample_id, model_name, and booster_id as index.
        f.sample_id,
        m.model_name,
        m.booster_id,
        CASE -- Choose the next node based on the split value of the root node.
            WHEN f.feature_values[m.feature_index] < m.split THEN m.node_yes
            WHEN f.feature_values[m.feature_index] >= m.split THEN m.node_no
            WHEN f.feature_values[m.feature_index] IS NULL THEN m.node_missing
        END AS node_id,
        m.prediction -- In case the root node is already a leaf...
    FROM tree_model m
    CROSS JOIN features f -- Compute predictions for all samples...
    WHERE m.model_name = 'demo_model.xgb' -- ...but only for one specific model.
      AND m.node_id = 0 -- Start from the root nodes.
    UNION ALL -- We won't have duplicates; each level has one node per tree.
    -- The recursive term now follows the tree down by one level.
    SELECT
        f.sample_id,
        p.model_name,
        p.booster_id,
        CASE -- Again choose the next node based on the split value.
            WHEN f.feature_values[m.feature_index] < m.split THEN m.node_yes
            WHEN f.feature_values[m.feature_index] >= m.split THEN m.node_no
            WHEN f.feature_values[m.feature_index] IS NULL THEN m.node_missing
        END AS node_id,
        m.prediction -- Will only be set for leaf nodes.
    FROM tree_eval p
    -- If the current node is a leaf node, the following join will
    -- return no rows, and therefore act as the stop condition.
    JOIN tree_model m USING (model_name, booster_id, node_id)
    JOIN features f USING (sample_id)
)
-- Finally sum up the predictions from the different trees for the result
-- (according to the model objective function).
SELECT
    sample_id,
    (1 / (1 + exp(-sum(prediction)))) AS prediction
FROM tree_eval
WHERE prediction IS NOT NULL -- Only leaf nodes have prediction values.
GROUP BY sample_id
ORDER BY sample_id;

And that's it. We computed predictions from a boosted tree model directly in SQL. Demo code for translating GBM models to SQL as shown above can be found on GitHub.
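To sanity-check what the query returns, the same traversal can be mirrored in plain Python. This sketch follows exactly the yes/no/missing logic of the CASE expression and assumes a binary logistic objective; a dict stands in for the feature array, with an absent key playing the role of NULL:

```python
import math

def predict(trees, features):
    """Walk each booster from the root, sum the leaf values,
    and apply the logistic function. Each tree maps
    node_id -> (feature_index, split, yes, no, missing, prediction)."""
    total = 0.0
    for nodes in trees:
        node_id = 0
        while True:
            feature, split, yes, no, missing, prediction = nodes[node_id]
            if prediction is not None:  # Leaf node: the stop condition.
                total += prediction
                break
            value = features.get(feature)  # None models a missing feature.
            if value is None:
                node_id = missing
            elif value < split:
                node_id = yes
            else:
                node_id = no
    return 1.0 / (1.0 + math.exp(-total))
```

Running this against the demo model with an all-missing feature vector follows the missing links to leaves 1.71217716 and 0.78471756, matching what the recursive query computes for such a row.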
Some additional notes:
- The implementation here uses recursive queries, a feature that not necessarily every database implements. While recursion certainly makes the prediction query easier to write, it is not strictly required: since the trees of a given model have a fixed maximum depth, the recursion can be replaced with a fixed set of (left) joins. While this solution is arguably less elegant, that shouldn't matter too much if the query is generated automatically. In fact, such a query may even be faster than the recursive approach.
- The implementation here also uses an array data type for the feature values, which is another feature that not necessarily every database implements. And even if the database has such a data type, it is probably not used to model the feature values (for good reasons). Again, the array data type is not strictly required, and is mainly used here to simplify the example queries. The array access for the feature values can be replaced either by direct access to the row ids in some databases, or by an arguably less elegant CASE statement in most other databases (again: this matters less if the queries are generated automatically).
- In any case, computing the predictions in SQL is unlikely to be actually faster than the computation using the optimized code of a boosted tree library. So this approach only makes sense in scenarios where the majority of the time is indeed spent moving data, instead of computing anything.
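The first two alternatives from the notes, an unrolled set of joins and a CASE expression for feature access, can both be generated mechanically as well. A Python sketch; the table and column names follow the example schema, while the function names, aliases, and the feature_&lt;i&gt; column convention are made up for illustration:

```python
def next_node_case(node, feat="f"):
    """CASE expression picking the next node_id on one tree level."""
    value = "{f}.feature_values[{n}.feature_index]".format(f=feat, n=node)
    return ("CASE WHEN {v} < {n}.split THEN {n}.node_yes"
            " WHEN {v} >= {n}.split THEN {n}.node_no"
            " ELSE {n}.node_missing END").format(v=value, n=node)

def unrolled_joins(max_depth):
    """Replace the recursion with one LEFT JOIN per tree level
    below the root (aliased m0)."""
    joins = []
    for level in range(1, max_depth + 1):
        prev, cur = "m%d" % (level - 1), "m%d" % level
        joins.append(
            "LEFT JOIN tree_model {cur}"
            " ON {cur}.model_name = {prev}.model_name"
            " AND {cur}.booster_id = {prev}.booster_id"
            " AND {cur}.node_id = {case}".format(
                cur=cur, prev=prev, case=next_node_case(prev)))
    return "\n".join(joins)

def feature_value_case(feature_indexes, node="m", feat="f"):
    """CASE fallback for databases without arrays: one column per feature."""
    branches = " ".join(
        "WHEN {n}.feature_index = {i} THEN {f}.feature_{i}".format(
            n=node, f=feat, i=i)
        for i in sorted(feature_indexes))
    return "CASE " + branches + " END"
```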