Read on realworldml.net
This is Part 3 of this mini-series on how to solve real-world business problems using Machine Learning.
So far we have learned how to frame the business problem as an ML problem, and how to prepare the data.
Let’s move on to the model training step.
Remember 🙋
The 4 steps to building a real-world ML product are:
Problem framing (2 weeks ago)
Data preparation (last week)
Model training (today) 🏋️
MLOps (next week)
Let’s get started!
Example
Imagine you work as an ML engineer at a ride-sharing company in NYC, and you want to help the operations team allocate the fleet of drivers optimally for each hour of the day. The end goal is to maximize revenue.
You have already framed the problem as a time-series prediction problem and prepared the data.
Step 3. Model training
It is now time to train a good predictive model. For that, I recommend following these steps.
#1 Split the data into train and test sets
Pick a cutoff timestamp and split the data into 2 disjoint sets:
Training data, which contains observations before the cutoff timestamp. This dataset is used to create (aka train) the model.
Test data, which contains observations after the cutoff timestamp. This dataset is only used at the end of training, to evaluate your model's performance on unseen data.
I recommend splitting the data at the very beginning of your model training script, to avoid leakage between these 2 sets.
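A minimal sketch of this split with pandas, assuming the prepared data sits in one DataFrame with a timestamp column called `pickup_hour` and a target column called `target_rides_next_hour` (both names, and the helper `train_test_split_by_time`, are hypothetical and chosen for illustration):

```python
import pandas as pd


def train_test_split_by_time(df: pd.DataFrame, cutoff: pd.Timestamp,
                             time_col: str = "pickup_hour",
                             target_col: str = "target_rides_next_hour"):
    """Split df into disjoint train/test sets around a cutoff timestamp."""
    train = df[df[time_col] < cutoff]   # observations strictly before the cutoff
    test = df[df[time_col] >= cutoff]   # observations at or after the cutoff
    X_train = train.drop(columns=[target_col])
    y_train = train[target_col]
    X_test = test.drop(columns=[target_col])
    y_test = test[target_col]
    return X_train, y_train, X_test, y_test


# Toy data: 48 hourly rows with made-up ride counts
df = pd.DataFrame({
    "pickup_hour": pd.date_range("2022-01-01", periods=48, freq="60min"),
    "rides_previous_1_hour": range(48),
    "target_rides_next_hour": range(1, 49),
})
X_train, y_train, X_test, y_test = train_test_split_by_time(
    df, cutoff=pd.Timestamp("2022-01-02"))
print(len(X_train), len(X_test))
```

Every training row is strictly older than every test row, which is exactly the situation the model will face in production.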
What about random splitting? 🤔
Random splitting is not a valid strategy for datasets with high temporal autocorrelation, like this one, especially when you use a powerful algorithm like XGBoost. Neighboring hours, which have very similar demand, end up in both sets, so the model can effectively memorize the test targets. The result is overly optimistic test metrics that will never be matched in production once you deploy the model.
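To see the effect, here is a toy experiment (an illustration only, not the ride-sharing data): a random walk stands in for hourly demand, and a deliberately "memorizing" model predicts each test hour with the value of the closest training hour. The helper `nearest_neighbour_mae` is hypothetical.

```python
import numpy as np


def nearest_neighbour_mae(series: np.ndarray, train_idx: np.ndarray,
                          test_idx: np.ndarray) -> float:
    """MAE of a 'memorizing' model that predicts each test hour with the
    value of the closest hour (in time) seen during training."""
    train_idx = np.sort(train_idx)
    # Where each test hour would fall among the sorted training hours
    pos = np.searchsorted(train_idx, test_idx)
    left = train_idx[np.clip(pos - 1, 0, len(train_idx) - 1)]
    right = train_idx[np.clip(pos, 0, len(train_idx) - 1)]
    # Use whichever training neighbour is closer in time
    nearest = np.where(np.abs(test_idx - left) <= np.abs(right - test_idx),
                       left, right)
    return float(np.mean(np.abs(series[test_idx] - series[nearest])))


rng = np.random.default_rng(seed=0)
n_hours = 1_000
# Toy hourly "demand": a random walk, so consecutive hours are highly correlated
series = np.cumsum(rng.normal(size=n_hours))
hours = np.arange(n_hours)
cutoff = int(0.8 * n_hours)

# Temporal split: first 80% of hours to train, last 20% to test
temporal_mae = nearest_neighbour_mae(series, hours[:cutoff], hours[cutoff:])

# Random split: same proportions, but hours are shuffled first
shuffled = rng.permutation(n_hours)
random_mae = nearest_neighbour_mae(series, shuffled[:cutoff], shuffled[cutoff:])

print(f"temporal split MAE: {temporal_mae:.2f}")
print(f"random split MAE:   {random_mae:.2f}")
```

The random split should report a noticeably lower error, only because the model can peek at the hours right before and after each test hour, which it will never have in production.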
#2 Create a baseline model without ML
Before using any ML algorithm, you should set a baseline performance for the problem.
I recommend using a simple heuristic that you can implement in a few lines of code.
For example, to predict next-hour taxi demand, you can use the demand observed in the latest hour as the estimate.
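```python
import pandas as pd


class BaselineModelPreviousHour:
    """
    Prediction = actual demand observed in the last hour
    """

    def fit(self, X: pd.DataFrame, y: pd.Series):
        # Nothing to learn: the heuristic has no parameters
        return self

    def predict(self, X: pd.DataFrame) -> pd.Series:
        """Return the number of rides observed in the previous hour."""
        return X['rides_previous_1_hour']
```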
Then compute your error metric (e.g. Mean Absolute Error, MAE) on the test set.
This number gives you a baseline against which you can compare the error metrics you will get when using more complex ML models.
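A minimal sketch of this evaluation, computing MAE by hand with pandas and NumPy. The four-row test set is made up; `rides_previous_1_hour` is the feature the baseline above reads.

```python
import numpy as np
import pandas as pd


def mean_absolute_error(y_true: pd.Series, y_pred: pd.Series) -> float:
    """Average absolute difference between actual and predicted rides."""
    # to_numpy() avoids any accidental index alignment between the two Series
    return float(np.mean(np.abs(y_true.to_numpy() - y_pred.to_numpy())))


# Toy test set: the feature is last hour's demand, the target is next hour's
X_test = pd.DataFrame({"rides_previous_1_hour": [10, 12, 15, 11]})
y_test = pd.Series([12, 15, 11, 9])

# Baseline heuristic: next hour's demand equals last hour's demand
y_pred = X_test["rides_previous_1_hour"]
baseline_mae = mean_absolute_error(y_test, y_pred)
print(baseline_mae)  # 2.75
```

scikit-learn's `sklearn.metrics.mean_absolute_error` computes the same quantity if you already depend on it.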
#3 Iteratively beat the baseline model using ML
Machine Learning algorithms are pattern-finding machines that can learn the right mapping between your features and your target.
It is best to start simple and add complexity iteratively. Every new model version is trained on the training data and evaluated on the test data, using the right error metric (e.g. Mean Absolute Error).
For example:
Build baseline model (from step #2) → MAE 1
Train an XGBoost model on the original raw features, with default hyperparameters → MAE 2 < MAE 1 🥳
Engineer a new feature and re-train your XGBoost model → MAE 3 > MAE 2 😔
Increase the number of samples in your training set by decreasing the step size of the sliding window used to generate them, and re-train your XGBoost model → MAE 4 < MAE 2 🥳 🥳
Optimize your XGBoost hyperparameters → MAE 5 < MAE 4 🥳 🥳 🥳
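The "step size" trick above controls how many training samples you cut out of each time series. A minimal sketch, assuming the samples are built with a sliding window over one location's hourly ride counts (the function `sliding_window_samples` and its parameters are hypothetical): each sample uses `n_features` consecutive hours as features and the following hour as the target, and the window moves forward `step_size` hours at a time.

```python
from typing import List, Tuple


def sliding_window_samples(rides: List[int], n_features: int,
                           step_size: int) -> List[Tuple[List[int], int]]:
    """Cut an hourly series into (features, target) pairs."""
    samples = []
    start = 0
    # Keep sliding while there is room for n_features inputs plus 1 target
    while start + n_features < len(rides):
        features = rides[start:start + n_features]
        target = rides[start + n_features]
        samples.append((features, target))
        start += step_size
    return samples


hourly_rides = list(range(100))  # toy series: 100 hours of made-up counts
for step in (24, 12, 1):
    n = len(sliding_window_samples(hourly_rides, n_features=24, step_size=step))
    print(f"step_size={step}: {n} samples")
```

Smaller steps produce many more samples from the same history, at the cost of more overlap between consecutive samples.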
4 strategies to improve your Machine Learning model ✨
Add more samples to the dataset
Add more features to the dataset, either new or engineered from existing ones.
Try another algorithm, e.g. LightGBM instead of XGBoost.
Tune the algorithm hyperparameters.
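Whichever strategy you try, the loop is the same: train each candidate on the training set, score it on the test set, and keep it only if it beats the best MAE so far. A minimal dependency-free sketch with NumPy, where the synthetic data, the class names, and the least-squares `LinearModel` (a stand-in for XGBoost or LightGBM, swapped in only to avoid extra dependencies) are all hypothetical:

```python
import numpy as np


class LastHourBaseline:
    """Predict next-hour demand as the demand observed in the last hour."""
    def fit(self, X: np.ndarray, y: np.ndarray):
        return self  # nothing to learn

    def predict(self, X: np.ndarray) -> np.ndarray:
        return X[:, -1]  # assumes the last column holds last hour's rides


class LinearModel:
    """Ordinary least-squares regression, a stand-in for XGBoost/LightGBM."""
    def fit(self, X: np.ndarray, y: np.ndarray):
        X1 = np.column_stack([X, np.ones(len(X))])  # add an intercept column
        self.coef_, *_ = np.linalg.lstsq(X1, y, rcond=None)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        return np.column_stack([X, np.ones(len(X))]) @ self.coef_


def mae(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    return float(np.mean(np.abs(y_true - y_pred)))


# Synthetic data: 3 feature columns (think: rides 3, 2 and 1 hours ago)
rng = np.random.default_rng(seed=42)
X = rng.uniform(0, 100, size=(500, 3))
y = 0.2 * X[:, 0] + 0.3 * X[:, 1] + 0.5 * X[:, 2] + rng.normal(0, 2, size=500)
# Rows are assumed to be in time order, so a positional split is temporal
X_train, X_test, y_train, y_test = X[:400], X[400:], y[:400], y[400:]

best_name, best_mae = None, float("inf")
for name, model in [("baseline", LastHourBaseline()), ("linear", LinearModel())]:
    score = mae(y_test, model.fit(X_train, y_train).predict(X_test))
    improved = score < best_mae
    if improved:
        best_name, best_mae = name, score
    label = " (new best)" if improved else ""
    print(f"{name}: MAE={score:.2f}{label}")
print("best model:", best_name)
```

Each new candidate (more samples, new features, another algorithm, tuned hyperparameters) simply becomes one more entry in the list of models to compare.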
Once you are happy with the test metrics, your Machine Learning model prototype is ready to be deployed.
Next steps
Our ML model prototype is ready. It is now time to deploy it and put it to work.
But this is something we will cover next week.
Have a great weekend
And keep on learning!
Pau