Building LSTMs with PyTorch and Lightning AI Part 8: Setting Up a Simpler LSTM
DEV Community

Continuing Training with More Epochs

In the previous article, we saw how easily we could continue training by adding more epochs. We also observed the improvements in the model's predictions using TensorBoard. Let's train the model one more time to bring the predictions even closer to the desired values.

As before, we first retrieve the latest checkpoint:

path_to_best_checkpoint = trainer.checkpoint_callback.best_model_path

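Next, we increase the maximum number of epochs to 5000. Note that max_epochs is a total that includes the epochs already recorded in the checkpoint, so training picks up where it left off and runs only the remaining epochs: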

trainer = L.Trainer(max_epochs=5000)

Then we resume training by calling fit() with the checkpoint:

trainer = L.Trainer(max_epochs=5000)
trainer.fit(
    model,
    train_dataloaders=dataloader,
    ckpt_path=path_to_best_checkpoint
)

Once training is complete, we can print the predictions again:

print("\nComparing observed and predicted values")
print(
    "Company A: Observed = 0, Predicted =",
    model(torch.tensor([0., 0.5, 0.25, 1.])).detach()
)
print(
    "Company B: Observed = 1, Predicted =",
    model(torch.tensor([1., 0.5, 0.25, 1.])).detach()
)

This gives the following output:

Comparing observed and predicted values
Company A: Observed = 0, Predicted = tensor(0.0004)
Company B: Observed = 1, Predicted = tensor(0.9672)

As you can see, both predictions are now much closer to their target values.
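To put a number on "much closer", we can compute how far each prediction is from its target. This is only a quick sketch using the values printed above; the dictionary names are ours, and absolute error is just one convenient way to measure the gap.

```python
# Values copied from the printed output above
observed = {"Company A": 0.0, "Company B": 1.0}
predicted = {"Company A": 0.0004, "Company B": 0.9672}

# Absolute difference between each prediction and its target
errors = {name: abs(observed[name] - predicted[name]) for name in observed}

for name, error in errors.items():
    print(f"{name}: absolute error = {error:.4f}")
```

Both errors are small, although Company B's prediction is still further from its target than Company A's.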

Inspecting Training Progress with TensorBoard

Next, let's open TensorBoard and inspect the updated training graphs. Notice that all of the graphs have become much flatter, indicating that the model has nearly converged. The predictions for Company A and Company B are now very close to their target values of 0 and 1, respectively.

At this point, we have successfully trained our LSTM model.

Building a Simpler LSTM with nn.LSTM

Now let's look at an even simpler way to build an LSTM using PyTorch's built-in nn.LSTM module. We will start by creating another class:

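import lightning as L
import torch.nn as nn

class LightningLSTM(L.LightningModule):
    def __init__(self):
        super().__init__()
        # One feature per time step in, one hidden-state value out
        self.lstm = nn.LSTM(input_size=1, hidden_size=1)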

The input_size argument specifies the number of features at each time step. In our example, each day has only a single feature, the company's stock price, so input_size is 1. The hidden_size argument specifies the number of values in the hidden state, which is also the number of values the LSTM outputs at each time step. We use a hidden size of 1 because we ultimately want the model to predict a single value for Day 5.
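To make these two arguments concrete, here is a small standalone sketch that passes Company A's four days of prices through an untrained nn.LSTM and prints the resulting shapes. It sits outside the LightningLSTM class, and the variable names are ours; the weights are random, so only the shapes are meaningful here, not the values.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=1, hidden_size=1)

# Four days of prices, reshaped to (sequence_length, input_size) = (4, 1)
prices = torch.tensor([0.0, 0.5, 0.25, 1.0])
sequence = prices.view(len(prices), 1)

# output holds the hidden state after every day;
# hidden and cell hold the final short-term and long-term memories
output, (hidden, cell) = lstm(sequence)

print(output.shape)  # torch.Size([4, 1]): one hidden value per day
print(hidden.shape)  # torch.Size([1, 1]): the hidden state after the last day
print(cell.shape)    # torch.Size([1, 1]): the cell state after the last day

# With input_size=1 and hidden_size=1 the layer has 16 trainable values
n_params = sum(p.numel() for p in lstm.parameters())
print(n_params)      # 16
```

The last row of output is the same value as hidden, which is the number we would read off as the prediction for Day 5.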

Next, we will implement the forward() method, which works a little differently when using nn.LSTM. We will explore that in the next article.

