Finding Learning Rate for Neural Nets

How to decide on a learning rate for training (with PyTorch Lightning)

Among all the hyper-parameters used in machine learning algorithms, the learning rate is probably the very first one you learn about. Most likely it is also the first one you start playing with. There’s a chance you will find the optimal value with a bit of hyper-parameter optimization, but this requires lots of experimenting. So why not let your tools do that (faster)?

Understanding Learning Rate

The shortest explanation of what the learning rate really is: it controls how fast your network (or any algorithm) learns. If you recall how supervised learning works, you should be familiar with the idea of a neural network adapting to the problem based on the supervisor’s response.

Thanks to that feedback, it can give more and more accurate answers over successive iterations. Every time it receives feedback from the supervisor, it learns what the correct answer should have been. The learning rate controls how much the model changes its perception in response to recent errors (feedback).
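In gradient-based training, that "how much" is a single multiplier applied to each correction. A minimal sketch, assuming plain gradient descent on a one-weight model y = w * x with squared-error loss; the function name `sgd_step` is hypothetical and not taken from any library:

```python
def sgd_step(w, x, y_true, lr):
    """One plain gradient-descent update for the model y = w * x
    with squared-error loss L(w) = (w * x - y_true) ** 2."""
    y_pred = w * x
    error = y_pred - y_true   # the "feedback": how wrong the answer was
    grad = 2 * error * x      # dL/dw, the direction that increases the loss
    return w - lr * grad      # the learning rate scales the size of the correction


w_big = sgd_step(0.0, x=2.0, y_true=4.0, lr=0.1)
w_small = sgd_step(0.0, x=2.0, y_true=4.0, lr=0.01)
print(w_big, w_small)
```

With the same error, lr=0.1 moves w from 0.0 to 1.6, while lr=0.01 only moves it to about 0.16: the learning rate changes the size of the step, not its direction.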

Good Learning Rate, Bad Learning Rate

If there is a “good learning rate” that we are looking for, does a “bad learning rate” also exist? And what does it mean? Let me stick to the concept of supervised learning and discuss a trivial example: the Guess the Number game, in which every guess receives feedback from the supervisor.

Let’s imagine how one could play that game, and let’s assume it is not purely random guessing this time.

Each time you get feedback, would you rather take a small step towards the right answer or a big leap? While this example is not really about neural networks or machine learning, this is essentially how the learning rate works.

Now, imagine taking only tiny steps, each bringing you closer to the correct number. Will this work? Of course. However, it can take quite some time until you get there. This is the case of a learning rate that is too small. In the context of machine learning, a model with a too-small LR is a slow learner and needs more iterations to solve the problem. Sometimes you may even decide to stop training (or playing Guess the Number) before it has finished.

You can also pick a value that’s too large. And what then? It may cause a neural network to change its mind too drastically (and too often): every new sample has a huge impact on the network’s beliefs. Such training is highly unstable. The model is no longer a slow learner, but the outcome may be even worse: it may end up not learning anything useful at all.
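Both failure modes are easy to reproduce on a toy problem. A minimal sketch, assuming gradient descent on the one-dimensional loss L(w) = w², where every update multiplies w by (1 - 2 * lr); this is not a neural network, just the simplest setting where the effect is visible, and the function name is hypothetical:

```python
def minimize_quadratic(lr, steps=50, w0=10.0):
    """Run gradient descent on L(w) = w ** 2 (minimum at w = 0)
    and return the final parameter value."""
    w = w0
    for _ in range(steps):
        grad = 2 * w         # dL/dw
        w = w - lr * grad    # same as w *= (1 - 2 * lr)
    return w


for lr in (0.001, 0.1, 1.1):
    print(f"lr={lr}: final w = {minimize_quadratic(lr):.4g}")
```

With lr=0.001 the parameter has barely moved towards 0 after 50 steps (the slow learner), lr=0.1 ends up close to the minimum, and lr=1.1 overshoots further on every step, so w grows instead of shrinking and training diverges.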

Learning Rate Range Test

Leslie N. Smith, in “Cyclical Learning Rates for Training Neural Networks”, introduces the concept of a cyclical learning rate, which is increased and decreased in turns during training. The paper also covers another important idea, the so-called “LR range test” (Section 3.3).
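The simplest cyclical policy in the paper is the “triangular” one, in which the learning rate moves linearly up and down between a lower and an upper bound. A minimal sketch of that formula in plain Python; the function name and the example bounds are illustrative assumptions:

```python
import math


def triangular_lr(iteration, step_size, base_lr, max_lr):
    """Triangular cyclical learning rate: rises linearly from base_lr to
    max_lr over step_size iterations, falls back to base_lr over the next
    step_size iterations, and then repeats."""
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)


schedule = [triangular_lr(i, step_size=4, base_lr=0.001, max_lr=0.01) for i in range(17)]
print([round(lr, 5) for lr in schedule])
```

For real training you would not need to write this yourself: PyTorch ships a scheduler for it, torch.optim.lr_scheduler.CyclicLR (with mode="triangular").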

The whole thing is relatively simple: we run a short training session (a few epochs) in which the learning rate is increased linearly between two boundary values, min_lr and max_lr. At the beginning, with a small learning rate, the network starts to converge slowly, so the loss values get lower and lower. At some point, the learning rate becomes too large and causes the network to diverge.

Figure 1. Learning rate suggested by lr_find method.

If you plot the loss values against the tested learning rates (Figure 1), you usually look for the best initial learning rate somewhere around the middle of the steepest downward part of the loss curve. This still leaves room to decrease the LR a bit later with a learning rate scheduler. In Figure 1, where the loss decreases significantly between LR 0.001 and 0.1, the red dot indicates the value suggested by the PyTorch Lightning framework.
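Lightning documents its suggestion as the point where the recorded loss has the steepest negative gradient, optionally skipping a few noisy points at both ends of the sweep. A simplified, framework-free sketch of that selection rule, assuming the (lr, loss) pairs of the sweep are already recorded; this uses a simple backward difference and is not Lightning’s actual code, and the sample numbers below are made up for illustration:

```python
def steepest_descent_lr(lrs, losses, skip_begin=0, skip_end=0):
    """Return the learning rate at which the loss falls most steeply.

    The slope is the loss difference between neighbouring points of the
    sweep (per step of the sweep, not per unit of learning rate).
    """
    end = len(losses) - skip_end
    best_idx, best_slope = None, float("inf")
    for i in range(skip_begin + 1, end):
        slope = losses[i] - losses[i - 1]
        if slope < best_slope:
            best_idx, best_slope = i, slope
    if best_idx is None:
        raise ValueError("not enough points left after skipping")
    return lrs[best_idx]


lrs = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]
losses = [2.30, 2.28, 1.70, 1.40, 4.00]  # illustrative values only
print(steepest_descent_lr(lrs, losses))
```

Here the biggest drop (2.28 to 1.70) happens when moving to 1e-2, so that value is returned. Because the rule picks the steepest point rather than the middle of the descent, it does not always land exactly halfway down the slope.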

Finding LR in PyTorch Lightning

Recently, PyTorch Lightning became my tool of choice for small machine learning projects. I first used it a couple of months ago and have kept using it ever since. Apart from all the other cool stuff it offers, it provides a Learning Rate Finder class that helps you find a good learning rate.

Using the LR Finder is as simple as setting the auto_lr_find parameter of the Trainer class:

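from torch.optim import Adam
from pytorch_lightning import LightningModule, Trainer

class LitModel(LightningModule):
    def __init__(self, learning_rate):
        super().__init__()
        # The LR finder looks for an attribute named `lr` or `learning_rate`
        # and overwrites it with the value it finds.
        self.learning_rate = learning_rate
        # ... define your layers here ...

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.learning_rate)

trainer = Trainer(auto_lr_find=True)  # by default it's False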

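Now, when you call trainer.fit, it performs the learning rate range test under the hood, finds a good initial learning rate and then trains (fits) your model straight away. It all happens automatically within the fit call, so there is nothing else to worry about. (In later 1.x releases of Lightning, the finder runs when you call trainer.tune(model) before fit, and Lightning 2.x removed auto_lr_find in favour of the Tuner class.)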

As stated in the documentation, there is an alternative approach that lets you run the LR Finder manually and inspect its results. This time, create the Trainer object with the default value of auto_lr_find (False) and call the lr_find method yourself:

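trainer = Trainer()  # auto_lr_find keeps its default value (False)
model = LitModel(learning_rate=1e-3)  # placeholder value, replaced below

lr_finder = trainer.tuner.lr_find(model)  # Run learning rate finder

fig = lr_finder.plot(suggest=True)  # Plot loss vs. LR with the suggestion marked
fig.show()

# Overwrite the attribute that configure_optimizers() reads
model.learning_rate = lr_finder.suggestion()

trainer.fit(model)  # Fit model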

And that’s it. The main advantage of this approach is that you can take a closer look at the plot showing which value was chosen (see Figure 1).

Example: LR Finder for Fashion MNIST

I decided to train a fairly simple network architecture (LeNet) on the Fashion MNIST dataset. I ran four separate experiments that differed only in the initial learning rate. Three values were hand-picked (\(10^{-5}\), \(10^{-4}\), \(10^{-1}\)) and the fourth was suggested by the Learning Rate Finder. I will not describe the whole implementation and the other parameters (you can explore them in the repo). Let me just show you the findings.

It took around 12 seconds (!) to find the best initial learning rate, which for my network and problem turned out to be \(0.0363\). Looking at the loss versus LR plot (exactly the one in Figure 1), I was surprised, because the suggested point is not exactly halfway down the steepest downward slope (as mentioned in the paper). However, I could not tell whether that was good or bad, so I started training the model.

For logging and visualization, I used TensorBoard to record loss and accuracy during the training and validation steps. Below you can see the metric history for each of the four experiments.

Figure 2. Training and validation accuracy for 4 experiments.

The learning rate suggested by Lightning (light blue) seems to outperform the other values in both the training and validation phases. In the end it reached \(88.85\%\) accuracy on the validation set, the highest score of all experiments (Figure 2). Moreover, its loss values were clearly the best of the four experiments: in the last validation step it reached a loss of \(0.3091\), the lowest value among the curves in Figure 3.

Figure 3. Training and validation loss for 4 experiments.

Conclusion

In that short experiment, the Learning Rate Finder outperformed my hand-picked learning rates. Of course, I could have picked \(0.0363\) as my initial guess, but the whole point of the LR Finder is to minimize the guesswork (unless you are a lucky man, in which case you don’t need it).

I think this feature is worth using. As Leslie N. Smith writes:

Whenever one is starting with a new architecture or dataset, a single LR range test provides both a good LR value and a good range. Then one should compare runs with a fixed LR versus CLR with this range. Whichever wins can be used with confidence for the rest of one’s experiments.

If you don’t want to run a hyper-parameter search over different values, which can take ages, you have two options left on the table: pick an initial value at random (which may leave you with terribly bad performance and convergence, though it may work great if you are a lucky man) or use the learning rate finder included in your machine learning framework of choice.

Which one would you pick?