Text Generation With Pytorch


Hello guys! Here we are again to have some fun with deep learning. In the previous post, we trained a model to generate text with Tensorflow. Today, I am gonna show you how we can do it with Pytorch. Let’s go!

Overview

In this blog post, what we are going to do is pretty much the same as what we did in the last post. We will create a model which can learn to generate some meaningful content like the sample below:

“I am sure we have studied Hogwarts and saw that the magical appearance of Dumbledore was starting to fear that the innocent” she said. Harry had just given the illusion how stars had lunged in at the same moment they were staring into a corner, the faint wall had obliged in the ground, he tried, but the detritus of magical creature lay in the air in front and spitefully the bond about a celebrated of green and brown, that weapons began weight magicked the wounds quickly; Dolohov.

Prerequisites

To get the most out of today’s post, I suggest that you have:

  • Python installed (Python3 is definitely recommended)
  • Pytorch installed (at least version 1.0)
  • Some experience with Python and a basic understanding of how RNNs and word embeddings work
  • Read my previous post (link here)

About the last point, though: the logic behind how things work remains the same whether your code is written in Tensorflow or Pytorch, so this post will focus on the Pytorch implementation only.

Dataset

The data processing code from the last post is not Tensorflow-dependent, which means that we can use it as-is without any modifications.

Firstly, let’s import the packages we need for today:
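Here is a minimal set of imports that the sketches in this post rely on (the exact list in the repo may differ slightly):

```python
import argparse
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
```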

Obviously we can’t use tf.app.flags, but we always have argparse at our back to do the job.
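As a rough sketch, the flags could look like this with argparse; the flag names and default values below are my assumptions, not necessarily the ones used in the repo:

```python
def get_args():
    # argparse stands in for tf.app.flags; flag names and defaults are illustrative
    parser = argparse.ArgumentParser(description='Text generation with Pytorch')
    parser.add_argument('--train-file', type=str, default='data/train.txt')
    parser.add_argument('--seq-size', type=int, default=32)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--embedding-size', type=int, default=64)
    parser.add_argument('--lstm-size', type=int, default=64)
    parser.add_argument('--gradients-norm', type=float, default=5.0)
    parser.add_argument('--initial-words', nargs='+', default=['I', 'am'])
    parser.add_argument('--predict-top-k', type=int, default=5)
    return parser.parse_args()
```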

Next, we need a function to process the raw data. You can check the implementation details in the Dataset section of the last post. Here I only show you the complete code:
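A sketch consistent with that description looks like this; the function name get_data_from_file and its exact return values are assumptions of this sketch:

```python
def get_data_from_file(train_file, batch_size, seq_size):
    # Read the raw text and split it into word tokens
    with open(train_file, 'r', encoding='utf-8') as f:
        text = f.read().split()

    # Build the vocabulary, most frequent words first
    word_counts = Counter(text)
    sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
    int_to_vocab = {k: w for k, w in enumerate(sorted_vocab)}
    vocab_to_int = {w: k for k, w in int_to_vocab.items()}
    n_vocab = len(int_to_vocab)

    # Encode the corpus as integers and trim it so it divides evenly into batches
    int_text = [vocab_to_int[w] for w in text]
    num_batches = len(int_text) // (seq_size * batch_size)
    in_text = int_text[:num_batches * batch_size * seq_size]

    # The target is the input shifted one word to the left
    out_text = np.zeros_like(in_text)
    out_text[:-1] = in_text[1:]
    out_text[-1] = in_text[0]

    in_text = np.reshape(in_text, (batch_size, -1))
    out_text = np.reshape(out_text, (batch_size, -1))
    return int_to_vocab, vocab_to_int, n_vocab, in_text, out_text
```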

And finally, we must define a function to generate batches for training:
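A simple generator that slices the encoded text into (batch_size, seq_size) chunks will do; something along these lines:

```python
def get_batches(in_text, out_text, batch_size, seq_size):
    # Yield (input, target) chunks of shape (batch_size, seq_size)
    num_batches = np.prod(in_text.shape) // (seq_size * batch_size)
    for i in range(0, num_batches * seq_size, seq_size):
        yield in_text[:, i:i + seq_size], out_text[:, i:i + seq_size]
```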

That is all we need for this step. Phew! It’s not always that easy, but let’s just keep things simple where they can be simple, right?

Model

Creating a network in Pytorch is very straightforward. All we have to do is create a subclass of torch.nn.Module, define the necessary layers in the __init__ method and implement the forward pass within the forward method.

Let’s recall a little bit. We need an embedding layer, an LSTM layer, and a dense layer, so here is the __init__ method:
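A sketch of the module might look like this; the class name RNNModule and the hyper-parameter names are assumptions of this sketch:

```python
class RNNModule(nn.Module):
    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        # Embedding layer: maps word indices to dense vectors
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        # LSTM layer: batch_first so inputs are (batch, seq, feature)
        self.lstm = nn.LSTM(embedding_size, lstm_size, batch_first=True)
        # Dense layer: projects LSTM outputs back to vocabulary logits
        self.dense = nn.Linear(lstm_size, n_vocab)
```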

The next method, forward, will take an input sequence and the previous states and produce the output together with the states of the current timestep:
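Continuing the RNNModule sketch above:

```python
    def forward(self, x, prev_state):
        # x: (batch, seq) tensor of word indices
        embed = self.embedding(x)
        # output: (batch, seq, lstm_size); state is the tuple (state_h, state_c)
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)
        return logits, state
```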

Because we need to reset states at the beginning of every epoch, we need to define one more method to help us set all states to zero:
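Still inside the same class, the helper can simply return two zero tensors:

```python
    def zero_state(self, batch_size):
        # One zero tensor for the hidden state and one for the memory (cell) state
        return (torch.zeros(1, batch_size, self.lstm_size),
                torch.zeros(1, batch_size, self.lstm_size))
```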

That may look strange to some of you. An LSTM’s state consists of two separate tensors: the hidden state and the memory (cell) state, denoted as state_h and state_c respectively. Remember this difference when using LSTM units.

Loss

We are done with the network. Now we need a loss function and a training op. Defining the two is surprisingly simple in Pytorch:
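For example, cross-entropy as the loss and Adam as the optimizer (the helper name is mine):

```python
def get_loss_and_train_op(net, lr=0.001):
    # Cross-entropy over the vocabulary, Adam to update the parameters
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    return criterion, optimizer
```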

“We’re not doing gradient clipping this time?”, you may ask. So glad that you pointed it out. Of course we will, but not here. You will see in a second.

Training

We are ready to train the network. Here we will come across one thing that some may like while others may not favor at all: manually managing data transfer between devices.

If your machine doesn’t have a GPU, you are somewhat lucky. For those who have one, just don’t forget to keep track of where your tensors are. Here are some tips of mine:

  • If the training is slow, you might have forgotten to move the data to the GPU
  • You can move everything to the GPU first, then fix the errors as they pop up until things work.

Okay, let’s code. First, we will get the device information, load the training data, and create the network, the loss function and the training op. And don’t forget to transfer the network to the GPU:
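A possible setup, reusing the helpers sketched above:

```python
def main():
    args = get_args()
    # Prefer the GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    int_to_vocab, vocab_to_int, n_vocab, in_text, out_text = get_data_from_file(
        args.train_file, args.batch_size, args.seq_size)

    net = RNNModule(n_vocab, args.seq_size, args.embedding_size, args.lstm_size)
    net = net.to(device)  # transfer the network to the GPU

    criterion, optimizer = get_loss_and_train_op(net, 0.01)
```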

Next, for each epoch, we will loop through the batches to compute loss values and update the network’s parameters. A typical set of steps for training in Pytorch is:

  • Call the train() method on the network’s instance (this informs the inner mechanism that we are about to train; it does not execute the training)
  • Reset all gradients
  • Compute output, loss value, accuracy, etc
  • Perform back-propagation
  • Update the network’s parameters

Here is how it looks in code:
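Continuing main() from above, a sketch of the loop (the number of epochs is a placeholder):

```python
    iteration = 0
    for e in range(50):
        # Reset the states at the beginning of every epoch
        state_h, state_c = net.zero_state(args.batch_size)
        state_h = state_h.to(device)
        state_c = state_c.to(device)

        for x, y in get_batches(in_text, out_text, args.batch_size, args.seq_size):
            iteration += 1
            # 1. Tell the network we are training
            net.train()
            # 2. Reset all gradients
            optimizer.zero_grad()
            # 3. Compute the output and the loss value
            x = torch.tensor(x).to(device)
            y = torch.tensor(y).to(device)
            logits, (state_h, state_c) = net(x, (state_h, state_c))
            loss = criterion(logits.transpose(1, 2), y)
            # Detach the states so back-propagation stops at this batch
            state_h = state_h.detach()
            state_c = state_c.detach()
            # 4. Perform back-propagation
            loss.backward()
            # 5. Update the network's parameters
            optimizer.step()
```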

You may notice the detach() calls. Whenever we want to reuse something that belongs to the computational graph for other operations, we must remove it from the graph by calling the detach() method. The reason is that Pytorch keeps track of the tensors’ flow to perform back-propagation through a mechanism called autograd. If we mess that up, Pytorch will fail to deliver the loss.

Is there anything I have missed? Oh, the gradient clipping! While it may not be as intuitive, it only requires one line of code. We just need to put it after calling loss.backward() and before optimizer.step() like this:
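Something like this, using the in-place clip_grad_norm_ utility (the gradients_norm flag comes from the argparse sketch above):

```python
            # Goes between loss.backward() and optimizer.step()
            torch.nn.utils.clip_grad_norm_(net.parameters(), args.gradients_norm)
```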

Finally, we will add code to print the loss value to console and have the model generate some text for us during training:
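For example, printing the loss every 100 iterations and sampling some text every 1,000 (the intervals are arbitrary; predict is defined in the Inference section below):

```python
            if iteration % 100 == 0:
                print('Epoch: {}'.format(e),
                      'Iteration: {}'.format(iteration),
                      'Loss: {}'.format(loss.item()))

            if iteration % 1000 == 0:
                predict(device, net, args.initial_words,
                        vocab_to_int, int_to_vocab, top_k=args.predict_top_k)
```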

That is the training loop. The only thing left is to define the predict method.

Inference

We finally reached the last and most fun part: implementing the predict method. What we are going to do can be illustrated in the figure below:

Fig. 1: the inference process

Assuming that we have some initial words (“Lord” and “Voldemort” in this case), we will use them as input to compute the final output, which is the word “is”. Don’t forget to tell the network that we are about to evaluate by calling the eval() method and, of course, remember to move your tensors to the GPU. The full code comes right after the next paragraph.

Next, we will use that final output as input for the next time step and continue doing so until we have a sequence of the length we want. Finally, we simply print the resulting sequence to the console:
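Here is a sketch of the whole predict method, assuming the helpers and flags introduced earlier; the 100-word output length and the top-k sampling strategy are choices of this sketch:

```python
def predict(device, net, words, vocab_to_int, int_to_vocab, top_k=5):
    # Tell the network we are about to evaluate, not train
    net.eval()
    words = list(words)  # copy so we don't mutate the caller's list

    # Warm up the states by feeding the initial words one at a time
    state_h, state_c = net.zero_state(1)
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    for w in words:
        ix = torch.tensor([[vocab_to_int[w]]]).to(device)
        output, (state_h, state_c) = net(ix, (state_h, state_c))

    # Pick the next word from the top-k most likely candidates
    _, top_ix = torch.topk(output[0], k=top_k)
    choice = int(np.random.choice(top_ix.tolist()[0]))
    words.append(int_to_vocab[choice])

    # Feed the last prediction back in until we have 100 more words
    for _ in range(100):
        ix = torch.tensor([[choice]]).to(device)
        output, (state_h, state_c) = net(ix, (state_h, state_c))
        _, top_ix = torch.topk(output[0], k=top_k)
        choice = int(np.random.choice(top_ix.tolist()[0]))
        words.append(int_to_vocab[choice])

    print(' '.join(words))
```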

We can now hit the run button and of course, don’t forget to get yourself a cup of coffee. Enjoy your machine’s creativity!

Final word

So in today’s post, we have created a model which can learn from any raw text source and generate some interesting content for us.

We have done it with ease by using Pytorch, a deep learning library which has gained a lot of attention in recent years. All the code and training data can be found in my repo (Pytorch scripts have the _pt postfix).

That’s it for today, guys! Thank you so much for reading. And I will definitely see you again soon.

Reference

  1. Text generation with Tensorflow: link
  2. Colah’s excellent blog post about LSTM: link
  3. Intro to RNN’s tutorial from Mat, Udacity: link
  4. Donald Trump’s full speech: link
  5. Oliver Twist: link

Trung Tran is a Deep Learning Engineer working in the car industry. His main daily job is to build deep learning models for autonomous driving projects, ranging from 2D/3D object detection to road scene segmentation. After office hours, he works on his personal projects, which focus on Natural Language Processing and Reinforcement Learning. He loves to write technical blog posts to share his knowledge and experience with those who are struggling. Less pain, more gain.
