Understand binary classification and decision trees with only 9 lines of code

Jul 01, 2020 · 6 mins read

I’ve already mentioned that there are plenty of machine learning tasks (problems) that you can tackle. Some of them are pretty easy; others require a lot of skill and experience in the field. In this post, I will introduce what is, in my opinion, the easiest one and describe it in a way that everyone should understand. At the end, we will create and train a decision tree model on our data and see how it performs.

What is Binary Classification?

Binary classification is a problem you already know; in fact, you face it every day. Here are some real-world examples of binary classification tasks:

  • Will it rain tomorrow? (yes or no)
  • Is this an orange or a grapefruit?
  • Does the patient have heart disease? (yes or no)

The general idea is to assign each element (sample) of a set to one of two groups on the basis of the available information (more formally, a classification rule). That information comes from the features each sample has, e.g. weight, color, current humidity, temperature, etc.
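To make the idea of a classification rule concrete, here is a minimal sketch of a hand-written rule for the rain example: features go in, one of exactly two labels comes out. The function name `will_it_rain` and its thresholds are hypothetical, made up purely for illustration; a machine learning model would learn such a rule from data instead of having it written by hand.

```python
# A hypothetical, hand-written classification rule for "Will it rain tomorrow?".
# The thresholds are made up for illustration only.
def will_it_rain(humidity_percent: float, temperature_c: float) -> str:
    """Map a sample's features to one of exactly two labels."""
    if humidity_percent > 80 and temperature_c > 5:
        return "yes"
    return "no"


print(will_it_rain(90, 12))  # yes
print(will_it_rain(40, 25))  # no
```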

Humans do a lot of binary classification on autopilot, but we had to learn how first. You may be able to tell whether it will rain tomorrow, but could you always do that? Similarly, an experienced doctor can tell whether a patient is sick or has a certain disease, but they had to learn a lot first.

The same is true for computers - and that’s what we call machine learning.

Using Decision Trees for classification

Decision trees are, I believe, very intuitive, and you have probably seen them once or twice already. All they do is divide a set of samples into two paths (smaller decision trees) based on the values of their features. For example, asking “does the patient have a higher heart rate?” creates two smaller groups of samples: those answering “yes” and those answering “no”. We can then ask further questions that bring us closer to the classification result.
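A single split can be sketched in plain Python as a function that sorts samples into a “yes” group and a “no” group. The patient records, the `heart_rate` field and the cut-off of 80 are hypothetical values chosen for illustration.

```python
# Hypothetical patient samples; field names and values are made up.
patients = [
    {"heart_rate": 95, "sick": True},
    {"heart_rate": 62, "sick": False},
    {"heart_rate": 88, "sick": True},
    {"heart_rate": 70, "sick": False},
]


def split(samples, question):
    """Divide samples into a 'yes' group and a 'no' group, as one tree node does."""
    yes = [s for s in samples if question(s)]
    no = [s for s in samples if not question(s)]
    return yes, no


# "Does the patient have a higher heart rate?" (the cut-off of 80 is an assumption)
yes_group, no_group = split(patients, lambda s: s["heart_rate"] > 80)
print(len(yes_group), len(no_group))  # 2 2
```

Each of the two groups can then be split again with another question, which is exactly what makes the structure a tree.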

To better understand it, let me first introduce a simple problem we will tackle in this post:

  1. We want to classify flower samples as either roses or tulips.
  2. Here’s what we know about each flower:
    • whether it has thorns (yes or no),
    • how much it costs (a single price value).

And here is what an example decision tree might look like for our problem:

But it’s not our job to come up with all these conditions; that’s what the decision tree algorithm does. All we have to do is prepare the data: the features of each flower with their corresponding labels. Then we train our tree model so that it finds the right splits. Let’s see how it works in practice.
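How does a tree find “the right splits”? A common approach, and the default in scikit-learn’s DecisionTreeClassifier (`criterion="gini"`), is to try candidate thresholds and keep the one that leaves the two resulting groups purest, as measured by Gini impurity. The sketch below searches a single numeric feature only; it is a simplified illustration, not scikit-learn’s actual implementation, and the prices and labels are hypothetical.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 0.0 for a pure group, 0.5 for a 50/50 mix of two classes."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_threshold(values, labels):
    """Try a threshold between each pair of neighbouring distinct values and
    return (threshold, weighted_impurity) for the best one."""
    pairs = sorted(zip(values, labels))
    distinct = sorted(set(values))
    best = (None, float("inf"))
    for low, high in zip(distinct, distinct[1:]):
        threshold = (low + high) / 2  # midpoint between neighbouring values
        left = [lab for v, lab in pairs if v <= threshold]
        right = [lab for v, lab in pairs if v > threshold]
        # Impurity of each side, weighted by how many samples it holds.
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(pairs)
        if score < best[1]:
            best = (threshold, score)
    return best


# Hypothetical prices and labels (not the author's flowers.csv).
prices = [1.0, 1.5, 2.0, 2.0, 2.5, 3.0, 3.5, 4.0, 5.0, 6.0]
flowers = ["tulip", "tulip", "rose", "tulip", "tulip",
           "tulip", "rose", "rose", "rose", "rose"]
threshold, impurity = best_threshold(prices, flowers)
print(threshold)  # 3.25
```

Running the same search for every feature, keeping the overall winner, and then repeating it inside each resulting group is, roughly, how a tree grows from the top down.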

Training your first machine learning model

First of all, I created a small set of training data that you can see below. It has 10 rows (5 roses and 5 tulips) with features (has_thorns, price) and labels (flower). For such a trivial example, this should be enough.

Here, the 1 and 0 values of the has_thorns feature correspond to yes and no, respectively. Then we have the price of each flower and, at the end, a label (rose or tulip). You may have noticed that I made the data up to fit the sample tree I introduced earlier. This means that our final tree should look similar if it is trained well.
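If you want to run the code below yourself, you could write a stand-in flowers.csv with the same columns. The values here are hypothetical (made up to loosely follow the rose/tulip story), so the tree you get will not necessarily match the one in this post. Note that this overwrites any existing flowers.csv in the working directory.

```python
import pandas as pd

# Hypothetical stand-in values with the same layout as described above:
# has_thorns (1 = yes, 0 = no), price, and the flower label.
rows = [
    (1, 4.0, "rose"), (1, 3.5, "rose"), (1, 2.0, "rose"),
    (0, 5.0, "rose"), (0, 6.0, "rose"),
    (0, 1.5, "tulip"), (0, 2.0, "tulip"), (0, 1.0, "tulip"),
    (0, 2.5, "tulip"), (0, 3.0, "tulip"),
]
frame = pd.DataFrame(rows, columns=["has_thorns", "price", "flower"])
frame.to_csv("flowers.csv", index=False)  # index=False: no extra index column

print(pd.read_csv("flowers.csv").head(3))
```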

Training decision tree classifier

Let’s waste no time and jump straight into the code. I will use Python and the three classic modules imported below. The class of the day is DecisionTreeClassifier from the scikit-learn library, which is going to do all the work.

import pandas as pd # for reading CSV files
import matplotlib.pyplot as plt # for displaying plots
from sklearn.tree import DecisionTreeClassifier, plot_tree # the model and a helper to draw it

Now we use pandas’ read_csv function to read the CSV data I prepared. It reads the whole file into one DataFrame, but we need to store the features (data) and the labels separately, so we pop the flower column out into a separate Series.

data = pd.read_csv('flowers.csv')
labels = data.pop('flower') # remove the "flower" column from data (leaving only features) and return it
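If `pop` is new to you, this small sketch on a two-row, hypothetical table shows what it does: it removes the column from the DataFrame in place and returns that column as a pandas Series.

```python
import pandas as pd

# Tiny hypothetical table with the same columns as flowers.csv.
data = pd.DataFrame({
    "has_thorns": [1, 0],
    "price": [4.0, 1.5],
    "flower": ["rose", "tulip"],
})

labels = data.pop("flower")  # removes the column from `data` and returns it

print(type(labels).__name__)  # Series
print(list(data.columns))     # ['has_thorns', 'price']
print(labels.tolist())        # ['rose', 'tulip']
```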

Finally, let’s create a DecisionTreeClassifier and fit it, passing our features as the first argument and the labels as the second. This single line is enough to train the model. Additionally, we will use sklearn’s plot_tree function to visualize the final tree and see how it decided to split the data into smaller trees.

clf = DecisionTreeClassifier().fit(data, labels) # fit() returns the trained classifier itself

plot_tree(clf, feature_names=['has_thorns', 'price'],
          class_names=['rose', 'tulip'], impurity=False, filled=True)
plt.show() # open the matplotlib window with the drawn tree

That’s it: 9 lines. Our model is ready, and we can see what it looks like and how well it performs.
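A trained model is only useful if we can ask it about new flowers. Here is a self-contained sketch using `predict`; the training values are hypothetical stand-ins for flowers.csv, and `random_state=0` is an addition that simply makes the result reproducible.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Hypothetical training data in the same shape as flowers.csv (values made up).
data = pd.DataFrame({
    "has_thorns": [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
    "price": [4.0, 3.5, 2.0, 5.0, 6.0, 1.5, 2.0, 1.0, 2.5, 3.0],
    "flower": ["rose"] * 5 + ["tulip"] * 5,
})
labels = data.pop("flower")
clf = DecisionTreeClassifier(random_state=0).fit(data, labels)

# Two new, unlabeled flowers; the columns must match the training features.
new_flowers = pd.DataFrame({"has_thorns": [1, 0], "price": [4.5, 1.2]})
print(clf.predict(new_flowers))  # ['rose' 'tulip']
print(clf.classes_)              # ['rose' 'tulip']
```

`clf.classes_` is sorted alphabetically, which is why `class_names=['rose', 'tulip']` in the plot_tree call lines up with the model’s classes.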

What does our decision tree look like?

Here is the result of the plot_tree function. Compare it with the tree I drew at the top: does it look similar?

The first line of each block shows the condition used to split the tree at that point. The samples line tells us how many samples reach that block, and the value line shows how many of them are roses and how many are tulips.
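If you prefer text over a picture (for example, in a terminal without a display), scikit-learn also provides `export_text`, which prints the same splits as an indented outline. The sketch below rebuilds a tree on hypothetical stand-in data so it runs on its own; `tree_.n_node_samples` holds the samples count shown in each block, starting with the root at index 0.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical stand-in data with the same columns as flowers.csv.
data = pd.DataFrame({
    "has_thorns": [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
    "price": [4.0, 3.5, 2.0, 5.0, 6.0, 1.5, 2.0, 1.0, 2.5, 3.0],
})
labels = ["rose"] * 5 + ["tulip"] * 5
clf = DecisionTreeClassifier(random_state=0).fit(data, labels)

# Text version of the tree: one line per condition, leaves show the class.
text = export_text(clf, feature_names=["has_thorns", "price"])
print(text)

# Number of samples that reach the root block (all of them).
print(clf.tree_.n_node_samples[0])  # 10
```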

One more thing: with one additional line, we can check how the model performs on our training data:

clf.score(data, labels)
# => 1.0

All of the training rows were classified correctly (on the very data the model was trained on). With only 9 lines of code, plus one to score it, that’s fairly simple, I would say.
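For a classifier, `score` returns the mean accuracy: the fraction of samples whose predicted label equals the true one. This sketch, again on hypothetical stand-in data, checks that it matches computing the same fraction by hand and with `sklearn.metrics.accuracy_score`.

```python
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in data with the same columns as flowers.csv.
data = pd.DataFrame({
    "has_thorns": [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
    "price": [4.0, 3.5, 2.0, 5.0, 6.0, 1.5, 2.0, 1.0, 2.5, 3.0],
})
labels = pd.Series(["rose"] * 5 + ["tulip"] * 5, name="flower")
clf = DecisionTreeClassifier(random_state=0).fit(data, labels)

predictions = clf.predict(data)
by_hand = (predictions == labels).mean()         # fraction of correct labels
by_metric = accuracy_score(labels, predictions)  # same fraction via sklearn.metrics
by_score = clf.score(data, labels)               # the call used in this post

print(by_hand, by_metric, by_score)  # 1.0 1.0 1.0
```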


Of course, this was an extremely simple problem: only two features and ten samples without any noise. Moreover, you would normally have not only training samples but also separate data sets for validating and testing your model, since a perfect score on the training data says little about how the model handles flowers it has never seen. But that’s a separate topic that I will cover soon.
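As a small preview of that topic, here is a minimal sketch of holding out part of the data with scikit-learn’s `train_test_split` and scoring the model only on rows it never saw during training. The data is the same hypothetical stand-in, and with just ten rows the resulting test score is not meaningful; `stratify=labels` keeps the rose/tulip mix roughly the same in both parts.

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical stand-in data with the same columns as flowers.csv.
data = pd.DataFrame({
    "has_thorns": [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
    "price": [4.0, 3.5, 2.0, 5.0, 6.0, 1.5, 2.0, 1.0, 2.5, 3.0],
})
labels = pd.Series(["rose"] * 5 + ["tulip"] * 5, name="flower")

# Hold out 30% of the rows; the model never sees them during training.
X_train, X_test, y_train, y_test = train_test_split(
    data, labels, test_size=0.3, stratify=labels, random_state=0
)

clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
print("train accuracy:", clf.score(X_train, y_train))
print("test accuracy:", clf.score(X_test, y_test))
```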

Nonetheless, the goal of this post was to introduce binary classification and decision trees and show how painless it is to train your first machine learning model. No math included. Simple, short code.

Please let me know if you liked the topic. If you want to practice a bit, here are a few other (more serious) binary classification datasets where you can use decision trees and get some decent results: