Decisions, Decisions, Decisions…
Implementing Decision Trees, My 5th Machine Learning Algorithm
Up to this point in the Cornell Machine Learning Certificate program, the algorithms have handled either classification (put a data point into a category such as +1/-1 or yes/no) or regression (predict a data point’s numeric value from its features). Decision Trees are the first machine learning algorithm we’ve discussed where the same type of algorithm can be used for either purpose, which makes them flexible and powerful, as long as you know how to use them.
In this write-up, I will go over the basics of what a Decision Tree is, how it works, and how it can be used.
As with most machine learning scenarios, we start with a data set and a question. Let’s say we are an automobile manufacturer and we want to know, based on data we can collect about a customer, whether that person is more likely to purchase a car or a truck. We collect data on 1,000 vehicle buyers: gender, income, immediate family size, and the type of vehicle eventually purchased. Now that we have our data set and our question, it’s time to put them into a Decision Tree.
Let’s say we have 500 cars and 500 trucks in our data set. Right now, we have a 50/50 chance of pulling a random person from that top level node and getting a car vs. truck buyer. The question becomes, if I were to separate my data by one of the features available, would that end up being predictive in terms of car vs. truck purchase? The first thing a Decision Tree does is exactly that — split the data according to a feature value, and then see if the resulting data are more or less predictive than before. Let’s take gender, for example (and to keep things simple, we are going to assume only two categories for illustration purposes — in the real world, you’d really want to account for various gender expressions in order to be complete). If we divide our data into gender 1 and gender 2, what do we end up with in terms of predictive power based on that split?
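To make the idea of a split concrete, here is a minimal Python sketch. The buyer counts are hypothetical, picked only so that the top-level node is 50/50 and each gender group comes out 60/40, and `class_shares` is an illustrative helper of my own, not a function from any library.

```python
from collections import Counter

def class_shares(labels):
    """Return the fraction of each class among the given labels."""
    counts = Counter(labels)
    total = len(labels)
    return {label: count / total for label, count in counts.items()}

# Hypothetical buyers as (gender, vehicle) pairs: 500 cars and 500 trucks overall.
buyers = (
    [("gender 1", "car")] * 300 + [("gender 1", "truck")] * 200
    + [("gender 2", "car")] * 200 + [("gender 2", "truck")] * 300
)

# Top-level node: everyone together, before any split.
root_shares = class_shares([vehicle for _, vehicle in buyers])

# Child nodes: one per value of the feature we split on.
child_shares = {}
for value in ("gender 1", "gender 2"):
    child_shares[value] = class_shares([v for g, v in buyers if g == value])

print(root_shares)
print(child_shares)
```

Comparing `root_shares` with `child_shares` is the "more or less predictive than before" check: the further each child node moves away from an even mix, the more the split has told us.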
In our fictitious example, it turns out dividing data by gender improves our predictive capability, going from a 50/50 split to a 60/40 split. So now you have two options — either pick one of the sub-nodes and try another split — say, by income — or check to see if one of your other features has even more predictive power as a first split. While gender appears to have *some* influence, it doesn’t look like our data supports this as a major predictor. What if we started over with family size?
So here we get some much nicer results: families with 3 or more members end up buying cars by a 6:1 ratio, whereas families of 1 or 2 in our data set were more than twice as likely to buy a truck as a car. One of the goals of a Decision Tree is to be as predictive as possible while also being as shallow as possible, so if you have a feature in your data that by itself gives you that kind of bang for your buck, that’s where you want to start.
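One way to put a number on "which split is better" is an impurity score. The sketch below uses Gini impurity, which is one common choice and an assumption on my part, since this write-up does not commit to a particular measure. The counts are hypothetical, chosen to roughly match the 60/40 gender split and the 6:1 family split described above.

```python
def gini(counts):
    """Gini impurity from class counts: 0 is a pure node, 0.5 is an even two-class mix."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

def weighted_gini(children):
    """Average impurity of the child nodes, weighted by how many rows land in each."""
    total = sum(sum(counts) for counts in children)
    return sum(sum(counts) / total * gini(counts) for counts in children)

# Hypothetical (cars, trucks) counts for each child node of a split.
gender_split = [(300, 200), (200, 300)]   # each side is 60/40
family_split = [(420, 70), (80, 430)]     # 3+ members: 6:1 cars; 1 or 2 members: mostly trucks

root_impurity = gini((500, 500))          # the unsplit 50/50 node
gender_impurity = weighted_gini(gender_split)
family_impurity = weighted_gini(family_split)

# The split with the lowest weighted impurity is the better first question.
best = min(("gender", gender_impurity), ("family size", family_impurity), key=lambda p: p[1])
print(best[0])
```

With these made-up counts, the family size split leaves much less impurity behind than the gender split, which matches the intuition from the ratios.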
The process continues from there, in much the same fashion. You keep going down the tree until you find the right level of accuracy, which might end up looking something like this:
In this example, we just picked one node each time to dive deeper into, but in reality the tree can branch out (pun intended) at any level and keep dividing until each node contains a single class, the features have been exhausted, or a stopping rule such as a maximum depth is reached. At the end of each branch we have what is called a “leaf node” (pun intended, I’m sure, but I don’t get to claim credit for that one), which is the node at which a prediction is made based on the most common value in that node. In the example above, it would look like this:
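For a classification tree, the leaf's prediction is simply a majority vote over the training rows that landed there. A minimal sketch, with a hypothetical leaf and an illustrative helper name:

```python
from collections import Counter

def leaf_prediction(labels):
    """Predict the most common label among the training rows in a leaf."""
    return Counter(labels).most_common(1)[0][0]

# Hypothetical leaf holding 80 car buyers and 430 truck buyers.
leaf_labels = ["car"] * 80 + ["truck"] * 430
prediction = leaf_prediction(leaf_labels)
print(prediction)  # truck
```

Any new customer who follows the same path down the tree and lands in this leaf would get the same "truck" prediction.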
So is this the best model to deploy into production?
Nope. Look carefully: after the first split (Family < 3), every leaf generated below it predicts truck. Remember, our goal is the shallowest Decision Tree possible without losing predictive accuracy, so the model we would want to deploy into production, given the data we have available, would look like this:
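The simplification can be sketched in a few lines of Python. The nested-dictionary tree, the `collapse` helper, and the deeper questions (including the income threshold) are all hypothetical stand-ins for the example, and real pruning methods usually also check accuracy on held-out data rather than only comparing leaf labels.

```python
def collapse(node):
    """Replace any subtree whose leaves all predict the same class with a single leaf."""
    if isinstance(node, str):           # a leaf is stored as its predicted label
        return node
    yes = collapse(node["yes"])
    no = collapse(node["no"])
    if isinstance(yes, str) and yes == no:
        return yes                      # both branches agree, so the question adds nothing
    return {"question": node["question"], "yes": yes, "no": no}

# Hypothetical deep tree: everything under "family < 3" predicts truck.
deep_tree = {
    "question": "family < 3",
    "yes": {
        "question": "income > 50k",
        "yes": "truck",
        "no": {"question": "gender 1", "yes": "truck", "no": "truck"},
    },
    "no": "car",
}

shallow_tree = collapse(deep_tree)
print(shallow_tree)
# {'question': 'family < 3', 'yes': 'truck', 'no': 'car'}
```

If any one of the deeper leaves predicted car instead, `collapse` would keep that branch, because the extra question would then change the answer.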
In the real world, the number of features and the number of observations would likely be significantly higher, so a single split at the top level would rarely turn out to be the best model. That said, in our fictitious example, that’s exactly what we would do.
However, if we modified our example slightly, the more complex tree would likely be worth the extra compute it took up. Take this alternative reality:
Since we now have a split further down the tree that changes the prediction, we would use the more complex model and not stop after the first split. It all depends on how the data shakes out.
If we come back to Decision Trees, we will need to talk about regression problems, overfitting (where you get a very accurate model on your training data that doesn’t end up working very well in the real world), calculating “impurity” at each node, and a number of other topics that go into making the best Decision Tree possible. For now, I hope this high-level overview and simple example at least help you understand what Decision Trees are, how they work, and what they might be used for in real life.