flynn.gg

Christopher Flynn

Machine Learning Systems Architect, PhD Mathematician


skranger - ranger in Python

2020-07-25

Photo by Johannes Plenio on Unsplash

The ranger library describes itself as a fast implementation of Random Forests, particularly suited for high-dimensional data. It is available as an R package that provides bindings to its C++ implementation of Random Forest algorithms, based primarily on Breiman's 2001 paper on the subject. In benchmarks, it trains significantly faster than scikit-learn's RandomForestClassifier, whose tree implementation is written in Cython and compiled to C.

Last month, with a ton of help from Kevin Cybura, I released the first version of skranger, which provides Python bindings for ranger. It includes scikit-learn compatible classes for classification and regression that use the fast C++ implementation. Ranger also provides survival forests, so skranger also includes a scikit-survival compatible class for survival prediction using ranger's implementation.

Here are some basic examples using sample datasets (from the README).

Classification

The RangerForestClassifier predictor uses ranger’s ForestProbability class to enable both predict and predict_proba methods.

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from skranger.ensemble import RangerForestClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)

rfc = RangerForestClassifier()
rfc.fit(X_train, y_train)

predictions = rfc.predict(X_test)
print(predictions)
# [1 2 0 0 0 0 1 2 1 1 2 2 2 1 1 0 1 1 0 1 1 1 0 2 1 0 0 1 2 2 0 1 2 2 0 2 0 0]

probabilities = rfc.predict_proba(X_test)
print(probabilities)
# [[0.01333333 0.98666667 0.        ]
#  [0.         0.         1.        ]
#  ...
#  [0.98746032 0.01253968 0.        ]
#  [0.99       0.01       0.        ]]
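Because the classifier is backed by a probability forest, a hard label can be read off the class probabilities by taking the most probable class in each row, as scikit-learn's own forest classifiers do. The sketch below illustrates that mapping with a hypothetical `predict_from_proba` helper and rows copied from the output above; it is not skranger's code, and skranger's tie-breaking may differ.

```python
import numpy as np


def predict_from_proba(proba, classes):
    """Map each row of class probabilities to its most probable class label."""
    proba = np.asarray(proba)
    # argmax picks the column index of the largest probability in each row
    return np.asarray(classes)[np.argmax(proba, axis=1)]


# Rows taken from the predict_proba output above.
proba = np.array([
    [0.01333333, 0.98666667, 0.0],
    [0.0, 0.0, 1.0],
    [0.99, 0.01, 0.0],
])
print(predict_from_proba(proba, classes=[0, 1, 2]))  # [1 2 0]
```

These labels agree with the first two and the last entries of the predict output shown earlier.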

Regression

The RangerForestRegressor predictor uses ranger’s ForestRegression class.

# load_boston was removed in scikit-learn 1.2, so this example needs an older release
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from skranger.ensemble import RangerForestRegressor

X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)

rfr = RangerForestRegressor()
rfr.fit(X_train, y_train)

predictions = rfr.predict(X_test)
print(predictions)
# [20.01270808 24.65041667 11.97722067 20.10345    26.48676667 42.19045952
#  19.821      31.51163333  8.34169603 18.94511667 20.21901915 16.01440705
#  ...
#  18.37752952 19.34765    20.13355    21.19648333 18.91611667 15.58964837
#  31.4223    ]

Survival

The RangerForestSurvival predictor uses ranger’s ForestSurvival class, and has an interface similar to the RandomSurvivalForest found in the scikit-survival package.

from sksurv.datasets import load_veterans_lung_cancer
from sklearn.model_selection import train_test_split
from skranger.ensemble import RangerForestSurvival

X, y = load_veterans_lung_cancer()
# select the numeric columns as features
X = X[["Age_in_years", "Karnofsky_score", "Months_from_Diagnosis"]]
X_train, X_test, y_train, y_test = train_test_split(X, y)

rfs = RangerForestSurvival()
rfs.fit(X_train, y_train)

predictions = rfs.predict(X_test)
print(predictions)
# [107.99634921  47.41235714  88.39933333  91.23566667  61.82104762
#   61.15052381  90.29888492  47.88706349  21.25111508  85.5768254
#   ...
#   56.85498016  53.98227381  48.88464683  95.58649206  48.9142619
#   57.68516667  71.96549206 101.79123016  58.95402381  98.36299206]

chf = rfs.predict_cumulative_hazard_function(X_test)
print(chf)
# [[0.04233333 0.0605     0.24305556 ... 1.6216627  1.6216627  1.6216627 ]
#  [0.00583333 0.00583333 0.00583333 ... 1.55410714 1.56410714 1.58410714]
#  ...
#  [0.12933333 0.14766667 0.14766667 ... 1.64342857 1.64342857 1.65342857]
#  [0.00983333 0.0112619  0.04815079 ... 1.79304365 1.79304365 1.79304365]]

survival = rfs.predict_survival_function(X_test)
print(survival)
# [[0.95855021 0.94129377 0.78422794 ... 0.19756993 0.19756993 0.19756993]
#  [0.99418365 0.99418365 0.99418365 ... 0.21137803 0.20927478 0.20513086]
#  ...
#  [0.87868102 0.86271864 0.86271864 ... 0.19331611 0.19331611 0.19139258]
#  [0.99021486 0.98880127 0.95299007 ... 0.16645277 0.16645277 0.16645277]]
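The two outputs above are linked by S(t) = exp(-H(t)): for example, exp(-0.04233333) ≈ 0.95855, the first survival value. To my understanding, ranger's survival trees estimate the cumulative hazard H(t) in each terminal node with the Nelson-Aalen estimator, and the forest averages these across trees. The sketch below shows that estimator for a single group of samples; it illustrates the technique rather than reproducing ranger's code, and the function names are hypothetical.

```python
import numpy as np


def nelson_aalen(times, events, grid):
    """Nelson-Aalen cumulative hazard H(t), evaluated at each time in ``grid``.

    H(t) is the sum, over distinct event times t_i <= t, of d_i / n_i, where
    d_i is the number of events at t_i and n_i the number still at risk.
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    event_times = np.unique(times[events])
    at_risk = np.array([np.sum(times >= t) for t in event_times])
    deaths = np.array([np.sum((times == t) & events) for t in event_times])
    cumulative = np.concatenate([[0.0], np.cumsum(deaths / at_risk)])
    # the number of event times <= each grid point selects the running sum
    return cumulative[np.searchsorted(event_times, grid, side="right")]


def survival_from_chf(chf):
    """Survival function from a cumulative hazard: S(t) = exp(-H(t))."""
    return np.exp(-np.asarray(chf, dtype=float))


# close to the first two survival values printed above
print(survival_from_chf([0.04233333, 0.0605]))
```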

Cython

skranger uses Cython, a successor to Pyrex, to call into the ranger C++ code. Cython is a programming language with Python-like syntax plus extensions, such as static C types and declarations of external C/C++ code, for writing C or C++ extension modules that can be called directly from Python. Cython code is translated into C or C++ source, which is then compiled. The compiled extensions are generally much faster than equivalent pure-Python implementations.

Binding with ranger

To interact with ranger, we mostly follow the Cython documentation's guide to using C++ in Cython.

For any C++ object in ranger that we want to use, we must declare it in a .pxd file, which plays a role similar to a C/C++ header file. For instance, we want to create instances of the ForestProbability C++ class. We reference both the class's .cpp file, so that its implementation is compiled into the extension, and its header file, which declares the class. Inside the header block we declare the methods, with their signatures, that we want to call from Cython. Here is the ForestProbability declaration in the .pxd file.

from libcpp cimport bool
from libcpp.vector cimport vector

# Forest (the base class) is declared earlier in the same .pxd file.

# Including the .cpp file compiles the class implementation into the extension.
cdef extern from "./ranger/src/Forest/ForestProbability.cpp":
    pass

# The header provides the declarations that Cython calls into.
cdef extern from "./ranger/src/Forest/ForestProbability.h" namespace "ranger":
    cdef cppclass ForestProbability(Forest):
        ForestProbability() except +
        vector[double]& getClassValues()
        vector[vector[vector[double]]] getTerminalClassCounts()
        void setClassWeights(vector[double]& class_weights)
        void loadForest(
            size_t num_trees,
            vector[vector[vector[size_t]]]& forest_child_nodeIDs,
            vector[vector[size_t]]& forest_split_varIDs,
            vector[vector[double]]& forest_split_values,
            vector[double]& class_values,
            vector[vector[vector[double]]]& forest_terminal_class_counts,
            vector[bool]& is_ordered_variable
        )

We also need a .pyx file containing the functions that call the C++ objects and are exposed to Python. This Cython module is imported from ordinary .py modules and acts as the bridge between Python and C++. Below is a heavily truncated version of skranger's implementation; the function closely mirrors the R binding implementation in ranger.

# Function which is called by a python module
cpdef dict ranger(
    ranger_.TreeType treetype,
    np.ndarray[double, ndim=2, mode="fortran"] x,
    np.ndarray[double, ndim=2, mode="fortran"] y,
    vector[string]& variable_names,
    unsigned int mtry,
    unsigned int num_trees,
    bool verbose,
    ...
):

    result = {}

    # C++ object declarations
    cdef unique_ptr[ranger_.Forest] forest

    cdef ranger_.ostream* verbose_out

    cdef vector[vector[vector[size_t]]] child_node_ids
    cdef vector[vector[size_t]] split_var_ids
    cdef vector[vector[double]] split_values

    ...

    try:
        ...
        # Calling into the C++ object
        deref(forest).initR(
            move(data.c_data),
            mtry,
            num_trees,
            verbose_out,
            ...
        )

        ...

    except Exception as exc:
        raise exc

    # Returning a python dictionary with the results
    return result

Finally, we can import the compiled .pyx module in Python and call it directly. Here is a heavily truncated version of one of the calls from skranger.

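import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin

# importing the compiled pyx module
from skranger.ensemble import ranger

# RangerValidationMixin comes from skranger (import not shown)
class RangerForestClassifier(RangerValidationMixin, ClassifierMixin, BaseEstimator):
    r"""Ranger Random Forest Probability/Classification implementation for scikit-learn.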
    ...
    """

    ...

    def fit(self, X, y, sample_weight=None):
        """Fit the ranger random forest using training data.
        :param array2d X: training input features
        :param array1d y: training input target classes
        :param array1d sample_weight: optional weights for input samples
        """
        self.tree_type_ = 9  # tree_type, TREE_PROBABILITY enables predict_proba

        # lots of input validation
        ...

        # Fit the forest using the cpdef function
        self.ranger_forest_ = ranger.ranger(
            self.tree_type_,
            np.asfortranarray(X.astype("float64")),
            np.asfortranarray(np.atleast_2d(y).astype("float64").transpose()),
            self.feature_names_,  # variable_names
            self.mtry_,
            self.n_estimators,  # num_trees
            ...
        )
        return self
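The magic number in `fit` and the array conversions can be spelled out. In the sketch below, `TreeType` and `prepare_inputs` are hypothetical names, not skranger's; the enum values follow ranger's `TreeType` enum in `globals.h` as I understand it (only `TREE_PROBABILITY = 9` appears in the code above). Fortran order matters because the `.pyx` signature declares `mode="fortran"`, so Cython rejects C-ordered arrays at call time.

```python
from enum import IntEnum

import numpy as np


class TreeType(IntEnum):
    """Mirror of ranger's C++ TreeType enum (values as in ranger's globals.h)."""
    TREE_CLASSIFICATION = 1
    TREE_REGRESSION = 3
    TREE_SURVIVAL = 5
    TREE_PROBABILITY = 9


def prepare_inputs(X, y):
    """Return float64, Fortran-ordered 2-D copies of X and a 1-D target y."""
    X_f = np.asfortranarray(np.asarray(X, dtype=np.float64))
    # A 1-D target of length n becomes a single column with shape (n, 1).
    y_f = np.asfortranarray(np.atleast_2d(np.asarray(y, dtype=np.float64)).T)
    return X_f, y_f


X_f, y_f = prepare_inputs([[5.1, 3.5], [4.9, 3.0], [6.2, 3.4]], [0, 0, 1])
print(X_f.flags["F_CONTIGUOUS"], y_f.shape)  # True (3, 1)
```

Converting once on the Python side keeps the Cython signature strict instead of making the extension handle arbitrary memory layouts.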

Since scikit-learn uses numpy arrays as its data containers, the final piece of this project is passing numpy arrays to C++. Ranger encapsulates its input data in a parent class, Data, with child classes for different C++ numeric data types.

To pass numpy arrays, we create a DataNumpy class that inherits from ranger's Data, allowing us to hand numpy data and the arguments ranger needs directly to its C++ code.

#ifndef DATANUMPY_H_
#define DATANUMPY_H_

#include <string>
#include <vector>

#include "globals.h"
#include "utility.h"
#include "Data.h"

namespace ranger {

class DataNumpy: public Data {
public:
  DataNumpy() = default;
  DataNumpy(double* x, double* y, std::vector<std::string> variable_names, size_t num_rows, size_t num_cols, size_t num_cols_y) {
    // x and y are column-major (Fortran-ordered) numpy buffers; copy them
    std::vector<double> xv(x, x + num_cols * num_rows);
    std::vector<double> yv(y, y + num_cols_y * num_rows);
    this->x = xv;
    this->y = yv;
    this->variable_names = variable_names;
    this->num_rows = num_rows;
    this->num_cols = num_cols;
    this->num_cols_no_snp = num_cols;
  }

  DataNumpy(const DataNumpy&) = delete;
  DataNumpy& operator=(const DataNumpy&) = delete;

  virtual ~DataNumpy() override = default;

  // getters and setters
  ...

private:
  std::vector<double> x;
  std::vector<double> y;
};

} // namespace ranger

#endif // DATANUMPY_H_
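The getters of `DataNumpy` are elided above. Because the constructor copies each Fortran-ordered numpy buffer as-is, element `(row, col)` lives at flat offset `col * num_rows + row`; to my understanding ranger's own `DataDouble` indexes its data the same way. A small Python check of that layout assumption, using a hypothetical `column_major_index` helper:

```python
import numpy as np


def column_major_index(row, col, num_rows):
    """Flat offset of element (row, col) in a column-major buffer."""
    return col * num_rows + row


# A Fortran-ordered array, like the ones the Cython wrapper accepts.
X = np.asfortranarray(np.arange(12, dtype=np.float64).reshape(4, 3))
num_rows, num_cols = X.shape

# Elements in memory order, i.e. what the std::vector copy in the DataNumpy
# constructor receives when it starts reading at &x[0, 0].
buffer = X.ravel(order="K")

for row in range(num_rows):
    for col in range(num_cols):
        assert buffer[column_major_index(row, col, num_rows)] == X[row, col]
```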

We must also wrap the class in Cython, so that we can call it directly from Python.

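# enable using numpy and its C API
import numpy as np
cimport numpy as np
cimport cython
from cython.operator cimport dereference as deref
from libcpp cimport bool
from libcpp.memory cimport unique_ptr
from libcpp.string cimport string
from libcpp.vector cimport vector

# the numpy C API (e.g. PyArray_DIMS) must be initialized before use
np.import_array()

# ``ranger_`` refers to the .pxd module holding the C++ declarations (cimport not shown)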


cdef class DataNumpy:
    """Cython wrapper for DataNumpy C++ class in ``DataNumpy.h``.
    This wraps the Data class in C++, which encapsulates training data passed to the
    random forest classes. It allows us to pass numpy arrays as a ranger-compatible
    Data object.
    """
    cdef unique_ptr[ranger_.DataNumpy] c_data

    @cython.boundscheck(False)
    @cython.wraparound(False)
    def __cinit__(self,
        np.ndarray[double, ndim=2, mode="fortran"] x not None,
        np.ndarray[double, ndim=2, mode="fortran"] y not None,
        vector[string] variable_names,
    ):
        cdef size_t num_rows = np.PyArray_DIMS(x)[0]  # in lieu of x.shape
        cdef size_t num_cols = np.PyArray_DIMS(x)[1]
        cdef size_t num_cols_y = np.PyArray_DIMS(y)[1]
        self.c_data.reset(
            new ranger_.DataNumpy(
                &x[0, 0],
                &y[0, 0],
                variable_names,
                num_rows,
                num_cols,
                num_cols_y,
            )
        )

    def get_x(self, size_t row, size_t col):
        return deref(self.c_data).get_x(row, col)

    def get_y(self, size_t row, size_t col):
        return deref(self.c_data).get_y(row, col)

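    def reserve_memory(self, size_t y_cols):
        # reserveMemory returns void, so there is nothing to return to Python
        deref(self.c_data).reserveMemory(y_cols)

    def set_x(self, size_t col, size_t row, double value):
        # a def function cannot take a C++ reference (bool&) argument, so
        # collect ranger's error flag locally and return it to the caller
        cdef bool error = False
        deref(self.c_data).set_x(col, row, value, error)
        return error

    def set_y(self, size_t col, size_t row, double value):
        cdef bool error = False
        deref(self.c_data).set_y(col, row, value, error)
        return error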

Tying all of these components together requires compilation. For that we use a build.py file, referenced in the project's pyproject.toml under the Poetry tool settings. At build time, Poetry generates a setup.py file that first calls build.py to compile the C++/Cython code. This build process might not work on every platform, however, which motivates building pre-compiled, platform-specific wheel (.whl) files that can be installed without invoking build.py.
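A minimal sketch of what such a build.py can look like, assuming Poetry 1.x, where a `build = "build.py"` entry under `[tool.poetry]` makes the generated setup.py call a `build(setup_kwargs)` function from that file. The module name, source paths and compiler flag are hypothetical, not skranger's actual configuration.

```python
"""Sketch of a Poetry build hook (build.py) for a Cython/C++ extension.

Hypothetical: module name, paths and compiler flags are illustrative only.
"""
import os

# Hypothetical layout: the .pyx module plus ranger sources copied beside it.
PACKAGE_DIR = os.path.join("skranger", "ensemble")
RANGER_SRC = os.path.join(PACKAGE_DIR, "ranger", "src")


def extension_options(numpy_include):
    """Keyword arguments for a setuptools ``Extension`` wrapping ranger."""
    return {
        "name": "skranger.ensemble.ranger",
        "sources": [os.path.join(PACKAGE_DIR, "ranger.pyx")],
        # numpy headers for ``cimport numpy`` plus ranger's header directories
        "include_dirs": [
            numpy_include,
            RANGER_SRC,
            os.path.join(RANGER_SRC, "Forest"),
            os.path.join(RANGER_SRC, "Tree"),
            os.path.join(RANGER_SRC, "utility"),
        ],
        "language": "c++",
        # assumes a GCC/Clang-style compiler
        "extra_compile_args": ["-std=c++11"],
    }


def build(setup_kwargs):
    """Called by the generated setup.py; adds the compiled extension module."""
    # Build-time dependencies are imported lazily, so the helper above can be
    # imported without them.
    import numpy as np
    from Cython.Build import cythonize
    from setuptools import Extension

    extension = Extension(**extension_options(np.get_include()))
    setup_kwargs.update(
        {"ext_modules": cythonize([extension], compiler_directives={"language_level": "3"})}
    )
```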

Wheels

Wheels are the standard built-distribution format for Python packages, and a wheel containing compiled code targets specific combinations of Python version and platform. They have the .whl extension and replace the legacy .egg format for installing packages.
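The Python-version and platform combinations a wheel targets are encoded in its filename, as specified by PEP 427. A small stdlib-only sketch that splits a filename into those fields follows; the `packaging` library's `packaging.utils.parse_wheel_filename` is a maintained alternative, and the skranger filename in the example is hypothetical.

```python
def parse_wheel_filename(filename):
    """Split a wheel filename into the fields defined by PEP 427.

    Format: {name}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl
    Each tag field may hold several dot-separated values (e.g. ``py2.py3``).
    """
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel file: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        name, version, python, abi, platform = parts
        build = None
    elif len(parts) == 6:
        name, version, build, python, abi, platform = parts
    else:
        raise ValueError(f"unexpected wheel filename: {filename}")
    return {
        "name": name,
        "version": version,
        "build": build,
        "python": python.split("."),
        "abi": abi.split("."),
        "platform": platform.split("."),
    }


# Hypothetical filename for a CPython 3.8 manylinux build.
info = parse_wheel_filename("skranger-0.1.0-cp38-cp38-manylinux2010_x86_64.whl")
print(info["platform"])  # ['manylinux2010_x86_64']
```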

To build wheels for multiple platforms, skranger uses multibuild, a project commonly used in Python's scientific ecosystem to build wheels for packages such as Pillow and scikit-learn.

The project provides scripts for CI on Travis CI and AppVeyor, which allow Cython projects to be built on Linux, macOS, and Windows. The Linux builds produce manylinux wheels, which are compatible with a wide range of Linux distributions.

The scripts build wheels by running a setup.py file, but since we maintain the project with Poetry, we have to write custom scripts to invoke the build. There is also a pre-build step, which must run first, that copies the ranger C++ files into the skranger project.

Luckily, multibuild is flexible enough that downstream projects can override its commands. This allows us to use Poetry in CI to build the wheels.
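multibuild's hooks are shell functions defined by the downstream project (for example in its `config.sh`), and an overridden hook can delegate to a small Python script. The sketch below is hypothetical: the paths, the `copy_ranger_sources` and `build_wheel` names, and the choice to copy only `.cpp` and `.h` files are illustrative assumptions, not skranger's CI scripts.

```python
import shutil
import subprocess
from pathlib import Path

# Hypothetical locations: a ranger checkout, and the package directory where
# the .pxd/.pyx files expect to find the C++ sources.
RANGER_CPP_SRC = Path("ranger") / "cpp_version" / "src"
VENDORED_SRC = Path("skranger") / "ensemble" / "ranger" / "src"


def copy_ranger_sources(src=RANGER_CPP_SRC, dest=VENDORED_SRC):
    """Pre-build step: copy ranger's .cpp/.h files, keeping the directory layout."""
    src, dest = Path(src), Path(dest)
    copied = []
    for path in sorted(src.rglob("*")):
        if path.is_file() and path.suffix in {".cpp", ".h"}:
            target = dest / path.relative_to(src)
            target.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(path, target)
            copied.append(target)
    return copied


def build_wheel():
    """Run the pre-build copy first, then let Poetry build the wheel."""
    copy_ranger_sources()
    subprocess.run(["poetry", "build", "--format", "wheel"], check=True)
```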

Thanks to multibuild, skranger wheel files are available on Linux and macOS for Python 3.6, 3.7, and 3.8.

License

Ranger has two licenses: the C++ source code is licensed under MIT, and the R package bindings under GPLv3. Because skranger borrows a lot of implementation logic from the R package, the skranger Python and Cython code is also licensed under GPLv3.
