
PyTorch Lightning for surface crack classification. Less boilerplate code?

Jul 24, 2020 · 7 mins read

The first version was released in March 2019, and at the time of writing we have the 0.9 pre-release available. That means PyTorch Lightning is a new player in town. But what is it, and how is it different from TensorFlow or (more importantly) PyTorch? Is it really a PyTorch wrapper with less boilerplate? I used it in a simple project, so I can share my thoughts.

Less boilerplate?

The authors of the library describe their project as:

The lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate

PyTorch with less boilerplate sounded amazing to me, because I was personally overwhelmed by the amount of code I had to write in it. Manually transferring data to the GPU device, implementing the boilerplate training loop and all that stuff. I would love to get rid of that. So which parts of the boilerplate have been cut out?

Defining model and training loop (LightningModule)

Take a look at the screenshot from Lightning that shows code equivalents in PyTorch and in their project. Nothing has changed in the way you define layers or the forward pass, but look: there is a huge gap in the amount of code you need for the training loop!

Basically, all the low-level functionality (manually loading tensors to the GPU, zero_grad(), the backward pass, etc.) is gone. All you need to do is define a training_step method that returns your metrics after a forward step.

And it’s all nicely encapsulated in your model class. Even creating PyTorch DataLoader objects (separately for the training, validation and test sets) is handled as part of the LightningModule interface.

So for my project I defined my own LightningModule and, well, those images don’t lie. My code looks just as simple. And damn, I love the way it’s all organized in one place, with simple functions. The class does get bigger once you implement methods like validation_step or training_epoch_end and calculate metrics inside them, but it’s still clear and readable enough for me.
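To make that concrete, here is a minimal sketch of what such a class can look like. The network, layer sizes and random tensor dataset below are made up for illustration; this is not my actual crack-classification model.

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class LitModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams  # plain dict of hyperparameters
        self.net = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(3 * 32 * 32, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 2),  # two classes: crack / no crack
        )

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        # No .to(device), no zero_grad(), no .backward() - Lightning does it.
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': F.cross_entropy(self(x), y)}

    def validation_epoch_end(self, outputs):
        # Aggregate per-batch losses into one epoch-level metric.
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams['lr'])

    def train_dataloader(self):
        # Random tensors stand in for the real crack dataset here.
        data = TensorDataset(torch.randn(256, 3, 32, 32),
                             torch.randint(0, 2, (256,)))
        return DataLoader(data, batch_size=32)

    def val_dataloader(self):
        data = TensorDataset(torch.randn(64, 3, 32, 32),
                             torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=32)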

Training part

In the next step I wanted to train my model. How do we configure and run the training process now? There is another screenshot from the authors, and it looks promising. The Trainer class takes care of the training (and validation and testing) process, and you don’t have to write all that boilerplate on your own. It is all hidden within a class that has a fit() method with a nice API. Take a look.

I think this example is oversimplified and fails to show the cool things you can do with Trainer. So take another look, this time at my usage of that class, which has some more code. Here I use an additional logger object (Neptune.ai, for tracking experiments’ history) and a callback for early stopping, to prevent overfitting.

import pytorch_lightning as pl

# `cfg` is my experiment configuration object, loaded earlier.

# Construct a LightningModule object with a parameter map.
model = LitModel(hparams=dict(cfg))

# Neptune.ai logger for tracking my experiments (optional, but cool).
# PyTorch Lightning also supports Tensorboard, MLFlow and Comet here.
logger = pl.loggers.NeptuneLogger(
    api_key=None,  # None means the key is read from the NEPTUNE_API_TOKEN env variable
    params=dict(cfg),
    **cfg.neptune
)

# Stop early when val_loss stops improving.
early_stopping = pl.callbacks.EarlyStopping('val_loss')

# Create Trainer object, note that we don't pass any specific model yet!
trainer = pl.Trainer(
    logger=logger,
    early_stop_callback=early_stopping,
    gpus=cfg.training.use_gpu,
    max_epochs=cfg.training.max_epochs,
)

# Here we train a specific model (our LitModel).
trainer.fit(model)

# And here we ask for predictions and the score on the test set.
# This may be a slightly unfortunate example, because I don't pass any
# dataset to the model's constructor explicitly (it is resolved from the
# cfg dictionary inside the LitModel constructor).
trainer.test()

It is still quite simple, and there are more cool things you can configure in your Trainer object. For instance, you may set train_percent_check to say that, e.g., only 10% of your training data should be used right now. Such a useful feature if you only want a quick run for debugging purposes.
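That looks roughly like this (a sketch; note that later Lightning releases renamed this flag to limit_train_batches):

# Quick debugging run: use only 10% of the training batches per epoch.
debug_trainer = pl.Trainer(
    train_percent_check=0.1,
    max_epochs=1,
)
debug_trainer.fit(model)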

Moreover, I really like the idea of pl.loggers, which lets you use tools like Neptune.ai, MLFlow, Comet or Tensorboard interchangeably. Lightning supports all of them, so if you decide to switch, you just change NeptuneLogger to CometLogger and it all keeps working. I really appreciate it.
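For example, switching my script above to Comet would boil down to swapping the logger construction; everything else stays untouched (the key placeholder and project name below are made up):

logger = pl.loggers.CometLogger(
    api_key='<your-comet-api-key>',  # placeholder for a real key
    project_name='surface-cracks',   # made-up project name
)

# The Trainer itself stays exactly the same.
trainer = pl.Trainer(logger=logger)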

You can see the class documentation here to read about its other features. And speaking of docs…

Any documentation?

Lightning has a documentation page with the classic Read-the-Docs layout that you’re probably familiar with. It starts with a Quick Start guide that shows how Lightning code is organized (similar to the screenshots above). They also have another tutorial explaining how to implement your first project step by step.

Core parts of the library (such as LightningModule, Trainer or the logger objects) have their own pages in the documentation, where you can read about basic usage, best practices and the full API. Frankly, I don’t remember anything I could not find there while implementing my project.

One more cool thing is the list of community examples, showing how people have used Lightning for image classification, NLP or speech transformers. Unfortunately, my little project is not on the list yet… 🤔

Sometimes you get stuck

And here is the first problem I can mention. It is a (relatively) new project, so even though it is getting more and more popular, there are still only a few StackOverflow Q&As or community examples you can find. Sometimes you may get stuck, and there may not be any answer out there (yet).

This was my impression when I was working on the project (about a month ago), and I won’t be surprised if there is twice as much information and twice as many examples available right now. Either way, I think it’s a matter of time, because the governance and the community itself are stable. Lightning has many users and contributors who will gladly help you solve your problems.

But I think we still need to wait a bit longer to see if it’s a real threat to other libraries.

Lightning is a better PyTorch

So what do they have in common? In fact, a lot of things. Lightning turned out to really be a lightweight wrapper that reduces the amount of boilerplate, but it still belongs to the PyTorch ecosystem and uses the same code and principles as PyTorch.

When you prepare datasets or define the layers of the model, you use pure PyTorch. That is what makes it so easy to convert your existing projects to Lightning: you only remove the part of the code responsible for training, making your code shorter and better organized. It’s a win-win for me.

Final thoughts

I can definitely feel that I wrote “less boilerplate” this time. Even though you still need to implement a few things yourself (forward, training_step), they are nicely organized within the Lightning class, so they look similar in every project you may see, whether from your colleagues or from the Internet.

I also like the idea of a Trainer class that can be used to fit different models once you create it. It could be great if you want to train and test different models with the same training setup (split ratio, maximum number of epochs etc.).
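Something like this sketch of the idea, reusing the LitModel from before (the learning rates are arbitrary example values):

# One shared Trainer, so the training setup stays identical across variants.
trainer = pl.Trainer(max_epochs=20)

for lr in (1e-3, 1e-4):
    trainer.fit(LitModel(hparams={'lr': lr}))
    trainer.test()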

Finally, it reminds me of Keras - or maybe it should be placed somewhere between Keras and PyTorch in terms of the abstraction it provides. You don’t need to care about all the boilerplate and details, but it still gives you a bit more control than Keras. That seems perfect for me.

And you? Have you ever worked with Lightning before? If so, share your thoughts and opinions.