First steps with Scikit-plot

Eager to use Scikit-plot? Let’s get started! This section of the documentation will teach you the basic philosophy behind Scikit-plot by running you through a quick example.


Before anything else, make sure you’ve installed the latest version of Scikit-plot. Scikit-plot is on PyPI, so simply run:

$ pip install scikit-plot

to install the latest version.

Alternatively, you can clone the source repository and run:

$ python setup.py install

at the root folder.

Scikit-plot depends on Scikit-learn and Matplotlib to do its magic, so make sure you have them installed as well.
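If you are unsure whether everything is in place, a quick check like the following can help. It is a minimal sketch using only the standard library (Python 3.8+), and it assumes the PyPI distribution names are "scikit-plot", "scikit-learn" and "matplotlib"; the helper name installed_versions is made up for this example.

```python
# Sketch: report which of Scikit-plot's dependencies are installed.
# Standard library only (importlib.metadata, Python 3.8+).
# installed_versions is a hypothetical helper, not part of Scikit-plot.
from importlib.metadata import PackageNotFoundError, version

# Distribution names as published on PyPI (assumed).
PACKAGES = ["scikit-plot", "scikit-learn", "matplotlib"]


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver else 'NOT INSTALLED'}")
```

Any package reported as NOT INSTALLED can be added with pip in the same way as Scikit-plot itself.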

Your First Plot

For our quick example, let’s show how well a Random Forest can classify the digits dataset bundled with Scikit-learn. A popular way to evaluate a classifier’s performance is by viewing its confusion matrix.

Before we begin plotting, we’ll need to import the following for Scikit-plot:

>>> import matplotlib.pyplot as plt

matplotlib.pyplot is used by Matplotlib to make plotting work like it does in MATLAB and deals with things like axes, figures, and subplots. But don’t worry. Unless you’re an advanced user, you won’t need to understand any of that while using Scikit-plot. All you need to remember is that we use the matplotlib.pyplot.show() function (called here as plt.show()) to show any plots generated by Scikit-plot.

Let’s begin by generating our sample digits dataset:

>>> from sklearn.datasets import load_digits
>>> X, y = load_digits(return_X_y=True)

Here, X and y contain the features and labels of our classification dataset, respectively.

We’ll proceed by creating an instance of a RandomForestClassifier object from Scikit-learn with some initial parameters:

>>> from sklearn.ensemble import RandomForestClassifier
>>> random_forest_clf = RandomForestClassifier(n_estimators=5, max_depth=5, random_state=1)

The magic happens in the next two lines:

>>> from scikitplot import classifier_factory
>>> classifier_factory(random_forest_clf)
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
        max_depth=5, max_features='auto', max_leaf_nodes=None,
        min_impurity_split=1e-07, min_samples_leaf=1,
        min_samples_split=2, min_weight_fraction_leaf=0.0,
        n_estimators=5, n_jobs=1, oob_score=False, random_state=1,
        verbose=0, warm_start=False)

In detail, here’s what happened. classifier_factory() is a function that modifies an instance of a scikit-learn classifier. When we passed random_forest_clf to classifier_factory(), it appended new plotting methods to the instance while leaving everything else alone: the original variables and methods of random_forest_clf are kept intact. In fact, if you take any of your existing scripts, pass your classifier instances to classifier_factory() at the top and run them, you’ll likely never notice a difference! (If something does break, though, we’d appreciate it if you open an issue at Scikit-plot’s GitHub repository.)
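To make the idea of "appending methods to an instance" concrete, here is a hypothetical sketch in plain Python. It is not Scikit-plot’s source code: ToyClassifier, describe_predictions and add_plotting_methods are invented names, and the sketch only assumes that binding a function to a single object with types.MethodType captures the general technique.

```python
# Hypothetical sketch of the "factory" idea: attach a new method to one
# existing object without touching its original attributes or methods.
# NOT Scikit-plot's implementation; all names here are made up.
import types


class ToyClassifier:
    """Stand-in for a scikit-learn estimator."""

    def __init__(self, n_estimators=5):
        self.n_estimators = n_estimators

    def predict(self, X):
        # Always predicts class 0; enough for the demonstration.
        return [0 for _ in X]


def describe_predictions(self, X):
    """The 'new' method: count how often each label is predicted."""
    preds = self.predict(X)
    return {label: preds.count(label) for label in set(preds)}


def add_plotting_methods(clf):
    """Bind extra methods onto this one instance and return the same instance."""
    if hasattr(clf, "describe_predictions"):
        raise ValueError("refusing to overwrite an existing attribute")
    clf.describe_predictions = types.MethodType(describe_predictions, clf)
    return clf


clf = add_plotting_methods(ToyClassifier(n_estimators=5))
print(clf.n_estimators)                     # 5: original attribute intact
print(clf.describe_predictions([1, 2, 3]))  # {0: 3}: new bound method works
```

Because the method is bound to the instance rather than added to the class, other ToyClassifier objects are unaffected, and the factory hands back the very object it received.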

Among the methods added to our classifier instance is the plot_confusion_matrix() method, used to generate a colored heatmap of the classifier’s confusion matrix as evaluated on a dataset.
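As a reminder of what the heatmap encodes, the sketch below counts a confusion matrix by hand. It uses the same layout as scikit-learn’s confusion_matrix (rows are true labels, columns are predicted labels); the function itself is a plain-Python illustration, not the code Scikit-plot runs.

```python
# Sketch: build a confusion matrix from true and predicted labels.
# Rows are true labels, columns are predicted labels.
def confusion_matrix(y_true, y_pred, labels=None):
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if labels is None:
        labels = sorted(set(y_true) | set(y_pred))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        # Cell [i][j] counts samples of true class i predicted as class j.
        matrix[index[t]][index[p]] += 1
    return labels, matrix


labels, cm = confusion_matrix([0, 1, 1, 2, 2, 2], [0, 1, 2, 2, 2, 1])
for label, row in zip(labels, cm):
    print(label, row)
```

Correct predictions land on the diagonal; every off-diagonal cell shows how often one class was mistaken for another.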

To plot and show how well our classifier does on the sample dataset, we’ll run random_forest_clf’s new instance method plot_confusion_matrix(), passing it the features and labels of our sample dataset. We’ll also pass normalize=True to plot_confusion_matrix() so the values displayed in our confusion matrix plot will fall in the range [0, 1]. Finally, to show our plot, we’ll call plt.show().

>>> random_forest_clf.plot_confusion_matrix(X, y, normalize=True)
<matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
>>> plt.show()
[Figure: Confusion matrix]

And that’s it! A quick glance at our confusion matrix shows that our classifier isn’t doing so well with identifying the digits 1, 8, and 9. Hmm. Perhaps a bit more tweaking of our Random Forest’s hyperparameters is in order.
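The reading above relies on normalization. Assuming normalize=True divides each row by its total (one row per true label), each diagonal cell becomes the fraction of that class predicted correctly, so low diagonal values mark the weak classes. The sketch below shows that arithmetic on a made-up 3-class matrix; the 0.8 cut-off and the helper names are arbitrary choices for illustration.

```python
# Sketch: row-normalise a confusion matrix (each row divided by its sum),
# which puts every cell in [0, 1], then flag classes whose diagonal value
# (fraction of that true class predicted correctly) is below a threshold.
def normalize_rows(cm):
    normalized = []
    for row in cm:
        total = sum(row)
        # Guard against classes that never occur in the true labels.
        normalized.append([c / total if total else 0.0 for c in row])
    return normalized


def weak_classes(cm, labels, threshold=0.8):
    """Labels whose normalised diagonal entry falls below `threshold` (arbitrary)."""
    norm = normalize_rows(cm)
    return [label for i, label in enumerate(labels) if norm[i][i] < threshold]


# Made-up counts: rows are true labels 0, 1, 2; columns are predictions.
cm = [[8, 2, 0],
      [1, 9, 0],
      [0, 3, 7]]
print(normalize_rows(cm)[0])        # [0.8, 0.2, 0.0]
print(weak_classes(cm, [0, 1, 2]))  # [2]
```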


The more observant of you will notice that we didn’t train our classifier at all. Exactly how was the confusion matrix generated? Well, plot_confusion_matrix() provides an optional parameter do_cv, set to True by default, that determines whether or not the classifier will use cross-validation to generate the confusion matrix. If True, the predictions generated by each iteration in the cross-validation are aggregated and used to generate the confusion matrix.

If you do not wish to do cross-validation (e.g., because you have separate training and testing datasets), simply set do_cv to False and make sure the classifier is already trained before calling plot_confusion_matrix(). In this case, the confusion matrix will be generated from the trained classifier’s predictions on the X and y you pass in.
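The two paths described above can be sketched in plain Python. This is a hypothetical illustration, not Scikit-plot’s implementation: NearestCentroid1D and collect_predictions are invented names, and the sketch uses a simple interleaved k-fold split rather than whatever splitter Scikit-plot uses internally.

```python
# Hypothetical sketch of do_cv=True (aggregate out-of-fold predictions)
# versus do_cv=False (predict with an already-trained model).
# Not Scikit-plot's code; the names below are invented for illustration.
class NearestCentroid1D:
    """Toy classifier on 1-D features: predicts the class with the closest mean."""

    def fit(self, X, y):
        sums, counts = {}, {}
        for x, label in zip(X, y):
            sums[label] = sums.get(label, 0.0) + x
            counts[label] = counts.get(label, 0) + 1
        self.centroids_ = {label: sums[label] / counts[label] for label in sums}
        return self

    def predict(self, X):
        return [min(self.centroids_, key=lambda c: abs(x - self.centroids_[c])) for x in X]


def collect_predictions(make_clf, X, y, do_cv=True, n_folds=3, fitted_clf=None):
    """Return (y_true, y_pred) lists to build a confusion matrix from."""
    if not do_cv:
        # No cross-validation: use an already-trained model on the X and y given.
        if fitted_clf is None:
            raise ValueError("do_cv=False needs an already-fitted classifier")
        return list(y), fitted_clf.predict(X)
    y_true, y_pred = [], []
    for fold in range(n_folds):
        # Interleaved split: sample i is held out in fold i % n_folds.
        test_idx = [i for i in range(len(X)) if i % n_folds == fold]
        train_idx = [i for i in range(len(X)) if i % n_folds != fold]
        # Fit a fresh model on the other folds, then predict the held-out fold.
        clf = make_clf().fit([X[i] for i in train_idx], [y[i] for i in train_idx])
        y_pred.extend(clf.predict([X[i] for i in test_idx]))
        y_true.extend(y[i] for i in test_idx)
    return y_true, y_pred


X = [0.0, 0.2, 0.1, 1.0, 1.2, 0.9, 0.3, 1.1, 0.15]
y = [0, 0, 0, 1, 1, 1, 0, 1, 0]
y_true, y_pred = collect_predictions(NearestCentroid1D, X, y, do_cv=True)
```

With cross-validation, every sample is predicted exactly once by a model that never saw it during fitting, which is why the aggregated predictions can fill a single confusion matrix even though no model was trained beforehand.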

The Functions API

Although convenient, the Factory API may feel a little restrictive for more advanced users and users of external libraries. Thus, to offer more flexibility over your plotting, Scikit-plot also exposes a Functions API that, well, exposes functions.

Because its functions take ground-truth labels and predictions rather than a classifier object, the Functions API also works with models that are not scikit-learn objects.

Here’s a quick example to generate the precision-recall curves of a Keras classifier on a sample dataset.

>>> # Import what's needed for the Functions API
>>> import matplotlib.pyplot as plt
>>> import scikitplot.plotters as skplt
>>> # This is a Keras classifier (keras_clf, X_train, y_train and X_test are assumed
>>> # to be defined already). We'll generate probabilities on the test set.
>>> keras_clf.fit(X_train, y_train, batch_size=64, nb_epoch=10, verbose=2)
>>> probas = keras_clf.predict_proba(X_test, batch_size=64)
>>> # Now plot.
>>> skplt.plot_precision_recall_curve(y_test, probas)
<matplotlib.axes._subplots.AxesSubplot object at 0x7fe967d64490>
>>> plt.show()
[Figure: Precision-recall curves]

And again, that’s it! You’ll notice that in this plot, all we needed to do was pass the ground-truth labels and predicted probabilities to plot_precision_recall_curve() to generate the precision-recall curves. This means you can use practically any classifier you want to generate the precision-recall curves, from Keras classifiers to NLTK Naive Bayes to XGBoost, as long as you pass the predicted probabilities as an array of shape (n_samples, n_classes), with one column per class.
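To see why labels and probabilities are all that is needed, the sketch below computes one-vs-rest precision/recall points for a single class by sweeping a threshold over that class’s probability column. It is a plain-Python illustration, not Scikit-plot’s code, and it assumes integer labels 0..n_classes-1 that line up with the probability columns.

```python
# Sketch: one-vs-rest precision/recall points for one class, computed only
# from ground-truth labels and a probability matrix of shape
# (n_samples, n_classes). Assumes labels are 0..n_classes-1, matching columns.
def precision_recall_points(y_true, probas, class_index):
    scores = [row[class_index] for row in probas]
    positives = sum(1 for label in y_true if label == class_index)
    if positives == 0:
        raise ValueError("class_index never occurs in y_true")
    points = []
    # Use each distinct score, from highest to lowest, as a decision threshold.
    for threshold in sorted(set(scores), reverse=True):
        predicted = [s >= threshold for s in scores]
        tp = sum(1 for p, t in zip(predicted, y_true) if p and t == class_index)
        fp = sum(1 for p, t in zip(predicted, y_true) if p and t != class_index)
        # At least one sample passes each threshold, so tp + fp > 0.
        points.append((tp / positives, tp / (tp + fp)))  # (recall, precision)
    return points


y_true = [0, 1, 1, 0]
probas = [[0.9, 0.1], [0.4, 0.6], [0.3, 0.7], [0.2, 0.8]]
pts = precision_recall_points(y_true, probas, class_index=1)
```

Nothing here depends on where probas came from, which is the property that lets a Keras, NLTK or XGBoost model feed the same plotting function.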

More Plots

Want to know the other plots you can generate using Scikit-plot? Visit the Factory API Reference or the Functions API Reference.