Image Generation with DCGAN
Nov 28, 2016
This project wraps up my Fall Quarter 2016, after taking the Neural Networks & Deep Learning and Image Processing courses. The goal is to familiarize myself with TensorFlow, the DCGAN model, and image generation in general. Code and detailed configuration are up here.
Images are 2D media that carry information. They take many forms: newspapers, book covers, movie posters, floor plans, company logos, even clothing. Interestingly, images of the same kind usually follow a certain norm. For example, if we knew we were about to see a portrait, we would probably expect a human face with clear facial features, rather than a car or some other random thing. Similarly, we think it is normal for an image of a bedroom to contain features like a bed, a closet, and windows; any image that possesses these features is considered a bedroom instead of a kitchen. This might seem obvious to us, but not quite for computers. If we could teach a computer what makes an image "normal" for a portrait, we would end up with a machine that can do two things: 1. differentiate real portraits from non-portraits; 2. come up with realistic portraits even though the people in the images do not actually exist.
However, teaching a computer what underlies the norm is not an easy task. Let's say we want our generator to output an image of the digit "3". In the computer's language, we are essentially asking it to return a matrix of numbers, where bright areas are represented by high values and dark areas by low values. The question is: since there are infinite ways to write a "3", how can we make sure the generator learns the inherent distribution of pixel values, so that each time it outputs not just a "3" but a different "3"? It turns out we can achieve this with an algorithm named DCGAN. (Image source)
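To make the matrix view concrete, here is a toy sketch (my own illustration in plain NumPy, not part of the project code) of a crude "3" as a 5x5 matrix, where 1.0 marks a bright pixel and 0.0 a dark one:

```python
import numpy as np

# a crude "3": bright strokes (1.0) on a dark background (0.0)
three = np.array([
    [0, 1, 1, 1, 0],
    [0, 0, 0, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 1, 0],
    [0, 1, 1, 1, 0],
], dtype=np.float32)

print(three.shape)  # (5, 5)
```

A real MNIST digit is just a bigger version of this: a 28*28 matrix of grayscale intensities.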
What is DCGAN?
DCGAN stands for Deep Convolutional Generative Adversarial Networks; it was proposed by Alec Radford et al. in late 2015. To explain this model intuitively, I'll borrow a simple analogy from mother nature. Millions of years ago, there were insects in the woods that looked normal, acted normal, nothing special. Meanwhile, their predators, the birds, were also trying to make a living. To survive, the insects had to escape the birds by coming up with reliable camouflage, and the birds had to fill their stomachs by sharpening their sight to recognize the prey. These two parties formed a natural competition that drove both to evolve and, eventually, become good at their jobs. Millions of years later, we now have stick insects that look very much like branches, and birds with sharp eyes.
In the context of machine learning, the birds can be seen as object classifiers that differentiate the stick insects from their surroundings, whereas the stick insects are generators that mimic the branches to fool the birds. We represent this two-player competition with the model below. For now, let's treat the generator and discriminator as two black-box functions that map an input to an output. Specifically, the generator takes a vector of random numbers (either normally or uniformly distributed) and converts it into an image. The discriminator, on the other hand, has two input options: when fed a real image, it should say "real!" and output a probability close to 1; when fed a generated image, it should say "fake!" and output a probability close to 0. By competing with each other, both the generator and the discriminator get better. Ultimately, the generator is able to create images so realistic that the discriminator can no longer tell the difference, which marks the end of the game. This training mechanism is named Generative Adversarial Networks (GAN), proposed by Ian Goodfellow et al. in 2014. (Image source)
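The competition can be written down as a pair of loss functions. Below is a minimal NumPy sketch (my own illustration, not the project code) of the discriminator's cross-entropy loss, together with the commonly used non-saturating generator loss:

```python
import numpy as np

def d_loss(d_real, d_fake, eps=1e-8):
    # the discriminator wants outputs close to 1 on real images
    # and close to 0 on generated ones
    return -np.mean(np.log(d_real + eps) + np.log(1.0 - d_fake + eps))

def g_loss(d_fake, eps=1e-8):
    # the generator wants the discriminator to say "real!" on its fakes
    return -np.mean(np.log(d_fake + eps))

# a sharp, correct discriminator has a small loss...
print(round(d_loss(np.array([0.9]), np.array([0.1])), 3))  # 0.211
# ...while at equilibrium, where it outputs 0.5 everywhere,
# the loss settles at 2*log(2) ~ 1.386
print(round(d_loss(np.array([0.5]), np.array([0.5])), 3))  # 1.386
```

In training, each network takes a gradient step on its own loss while the other's weights stay frozen.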
In the original GAN, the generator and discriminator are constructed with multilayer perceptrons. Building on top of that, DCGAN uses convolutional/deconvolutional neural networks, which makes the model more powerful. (Image source)
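To get a feel for the generator side, here is a quick sketch of the shape bookkeeping in a DCGAN-style generator (layer depths follow the figure in the paper; treat the exact numbers as an assumption): the latent vector is projected to a small 4x4 feature map, and each stride-2 fractionally-strided ("deconv") layer doubles the spatial size until we reach a 64*64*3 image:

```python
def deconv_out(size, stride=2):
    # a stride-2 transposed convolution with SAME padding
    # doubles each spatial dimension
    return size * stride

# project-and-reshape: z with 100 entries -> 4x4x1024 feature map
shapes = [(4, 4, 1024)]
for depth in (512, 256, 128, 3):
    h, w, _ = shapes[-1]
    shapes.append((deconv_out(h), deconv_out(w), depth))

print(shapes)
# [(4, 4, 1024), (8, 8, 512), (16, 16, 256), (32, 32, 128), (64, 64, 3)]
```

The discriminator mirrors this path in reverse, halving the spatial size with strided convolutions down to a single scalar.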
In my experiments, I had a hard time training these bad boys, with two major roadblocks:
First, using a pair of ConvNets with the structure above means a HUGE number of parameters: around two million weights to be trained and optimized. I initially ran the code on my Mac (2.4GHz i5 with 8GB RAM) and soon realized the progress was minimal compared with the torture my laptop was going through. So I migrated to AWS on a g2.2xlarge instance with a single GPU, which gave at least a five-fold speedup and could run for 10 hours.
Second, it was really hard to keep a balance between the generator and the discriminator. Nine out of ten times, one would outperform the other and dominate the game, resulting in nonsensical images because neither network was sufficiently trained. Besides the setup suggestions from the original DCGAN paper, I implemented a bunch of tricks to keep the competition going:
- Start with a small network structure and only add extra complexity when necessary. This is to avoid overfitting and allow rapid trial-and-error.
- Intuitively, the generator's job of transforming a noise vector into an image is harder than the discriminator's job of taking an image and outputting a scalar value. To make the game even, I limited the capacity of the discriminator by decreasing its number of features, so that the generator is relatively more powerful.
- Since the discriminator is weakened this way, I found it beneficial to pre-train it before entering the competition, so that it warms up and knows what's going on.
- The generator might need a few iterations to get started as well. Instead of training the generator multiple times per step, I made the discriminator's life harder by substituting 1/5 of the real images with generated images and pretending they are just as real. This slows the discriminator down at the beginning so the generator can take its time to observe and learn. I turned this trick off after 200 epochs to let the competition go back to normal.
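The last trick can be sketched in a few lines of NumPy (a hypothetical helper, not the project code): swap a fifth of each real batch for generated images, but keep presenting the whole batch as real:

```python
import numpy as np

rng = np.random.default_rng(0)

def poison_real_batch(real, fake, frac=0.2):
    # replace `frac` of the real batch with generated images,
    # still shown to the discriminator under the "real" label
    batch = real.copy()
    n_swap = int(len(batch) * frac)
    idx = rng.choice(len(batch), size=n_swap, replace=False)
    batch[idx] = fake[idx]
    return batch

real = np.ones((10, 28, 28, 1))   # stand-ins for real MNIST images
fake = np.zeros((10, 28, 28, 1))  # stand-ins for generated images
mixed = poison_real_batch(real, fake)
print(int(mixed.sum() / (28 * 28)))  # 8 -> two of ten images were swapped
```

Because the mislabeled fakes inject noise into the discriminator's gradient, it cannot converge as quickly, which buys the generator time.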
Here are graphs from TensorFlow reflecting what is supposed to happen during the early training stage. Each horizontal slice represents the probability density distribution of the discriminator's belief that the input images are real, at a given training epoch (vertical axis).
The graph above shows the discriminator's predictions on real images. Before the 100th epoch, the distributions are widely spread around 0.5, indicating the discriminator can't really tell whether the images are "real" or not. From epoch 100 to 200, it starts figuring out that these images are real and assigns them higher probabilities. Around epoch 240, it gets confused because the generator has caught up. Eventually, the distribution settles around 0.5 again, marking the algorithm's convergence.
The graph below shows the discriminator's predictions on generated images. The pre-trained discriminator gradually learns the differences between real and generated images, assigning lower and lower probabilities to the latter. As the training error gets back-propagated to the generator, it is forced to learn the underlying representations of real images so that the generated images look more normal. Around epoch 240, the generator rises up and fools the discriminator.
Experiments and Results
On the left is a sample of the original MNIST dataset: 36 digits of size 28*28*1. The goal is to generate digits 0-9 that look very much like human handwriting. Click the button to check out random digits generated from multiple model checkpoints.
At epoch 0, the generated images are pure random noise. Around epoch 60, the generator figures out that each digit should be surrounded by black areas. From epoch 100 to 500, although the images do not show intelligible digits, the generator learns that a digit should be made of strokes (connected bright areas). Starting at epoch 2000, we can make out recognizable digits. We could increase the number of training epochs to get better results.
On the left is a sample of the original CelebA dataset: 36 celebrity faces of size 64*64*3. The goal is to generate visually realistic human faces. Click the button to check out some freshly generated faces. Note that only "Epoch 16500" is rendered in real time.
The face images start out as random noise as well. At epoch 150, the generator learns that the faces have a different color distribution from the background. It also learns that each face image should be separated by white padding on the sides. From epoch 400 to 3000, the generator learns to represent human faces with major facial features like eyes, nose, and mouth, and refines the pixels so that edges are sharper. Around epoch 16500, although some of them are quite disturbing, most show normal human faces, especially if you take off your glasses and keep the faces small. I'm glad to see a variety of genders, ages, facial expressions, backgrounds, and so on; it feels like there is a story behind each person. Keep in mind that these people do not exist in this world!
(might take a while)
To make the task even more fun, we want not just to generate an image that looks normal, but also to control features in the image. For example, when we ask Google to search for an image of a tough man with a gun in hand, we get an image like the one below. What if we ask the generator to come up with an image that contains these elements? How can we map those semantic features down to the pixel level in the generated image?
A 64*64*3 image is essentially a point in a 64*64*3-dimensional space; in our case, that point is produced by transforming an initial vector z in a non-linear manner. We call z the latent vector, and we want to find out the relationship between the input values in z and the output images.
Here I randomly select one dimension of the latent vector, then change its value from negative to positive in equal step sizes, while the values of all other dimensions remain unchanged. We can notice some interesting reflections in the images, such as changes in background brightness or facial expressions. But most of the time, image features are entangled, such that totally different images are generated when tuning the value of just one dimension.
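The sweep itself is simple to set up. Here is a sketch (the dimension index is a hypothetical example) that builds the batch of latent vectors fed to the generator:

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(size=100)            # one base latent vector
dim = 42                            # hypothetical dimension to sweep
steps = np.linspace(-2.0, 2.0, 7)   # negative to positive, equal steps

sweep = np.tile(z, (len(steps), 1))
sweep[:, dim] = steps               # vary only this one coordinate
print(sweep.shape)  # (7, 100) -> seven images from one base vector
```

Feeding the seven rows through the generator yields the image strips shown below.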
darken | brighten | get bald | assassin to smiley girl
Say we have two points in the 100-dimensional latent space, each representing a valid digit (a "2" and an "8"). We can draw a "hyper-straight" line between these two points, and linear interpolation describes the transition from one digit to the other by traversing that line. The transition is quite smooth, indicating the model did not memorize the training images but actually learned the underlying representations of the digits.
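Traversing the "hyper-straight" line is plain linear interpolation between the two latent vectors; a minimal sketch:

```python
import numpy as np

def lerp(z0, z1, t):
    # a point on the straight line from z0 (t=0) to z1 (t=1)
    return (1.0 - t) * z0 + t * z1

# toy 2D stand-ins for the 100-dimensional latent codes of "2" and "8"
z_two, z_eight = np.array([0.0, 0.0]), np.array([2.0, 2.0])
print(lerp(z_two, z_eight, 0.5))  # [1. 1.] -- the midpoint
```

Sweeping t from 0 to 1 in small steps and decoding each point produces the morphing sequence between the two digits.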
In linear interpolation, we have a "hyper-straight" line that connects two generated images. Traversing this line gives us an acceptable interpolation but not a good one, because points around the middle of the line may be nonsensical. Spherical interpolation is a better approach: it treats the latent space as a hypersphere, so traversing a "hyper-circular" path gives better qualitative results.
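Here is a sketch of spherical interpolation (the standard slerp formula; my own implementation, not the project code). Note how, unlike the straight chord, it keeps interpolants of two unit vectors on the unit hypersphere, where the generator has actually seen inputs:

```python
import numpy as np

def slerp(z0, z1, t):
    # travel along the great-circle arc between z0 and z1
    # instead of the straight chord
    omega = np.arccos(np.clip(
        np.dot(z0 / np.linalg.norm(z0), z1 / np.linalg.norm(z1)), -1.0, 1.0))
    if np.isclose(omega, 0.0):  # parallel vectors: fall back to a straight line
        return (1.0 - t) * z0 + t * z1
    return (np.sin((1.0 - t) * omega) * z0
            + np.sin(t * omega) * z1) / np.sin(omega)

z0, z1 = np.array([1.0, 0.0]), np.array([0.0, 1.0])  # two unit vectors
mid = slerp(z0, z1, 0.5)
print(np.linalg.norm(mid))  # ~1.0 -- stays on the sphere, whereas the
                            # straight-line midpoint would have norm ~0.71
```

The straight line cuts through the interior of the hypersphere, a region the latent distribution rarely covers, which is exactly why mid-line samples can look nonsensical.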
Future work: more on unsupervised feature learning and interpolation
Generative Visual Manipulation on the Natural Image Manifold: https://arxiv.org/abs/1609.03552
Adversarial Feature Learning: https://arxiv.org/abs/1605.09782
Adversarially Learned Inference: https://arxiv.org/abs/1606.00704
Deep Feature Interpolation: https://arxiv.org/abs/1611.05507
One of the ways we humans perceive the world is, obviously, by looking at it. In fact, the raw visual input that the retina feeds to the brain can be considered a 2D image, which doesn't contain any explicit information about the third dimension: depth. Take a snapshot of the view you see right now; whether it is the laptop right in front of you or your neighbor Mr. Smith passing by on the street, they all fall onto a 2D plane and have no intrinsic difference from the letters in a newspaper. Luckily, your brain is a complex machine that does more than a camera: evolution has enabled it to pick up pictorial cues, such as differences in brightness, size, and position, to infer 3D relationships among the objects in sight. This ability to interpret 2D images as 3D sensations provides us with visual immersiveness.
Following this logic, if we extrapolate the idea of image generation to an extreme, it gives us another interesting application: a real-time rendered, one-of-a-kind virtual reality. Of course, there may be many roadblocks preventing this from happening. For example, it would require tremendous computational resources to reach a high rendering speed, not to mention rendering quality. However, if we let our imagination go wild for a minute, a technology like this would drastically enrich our experience as humans: we could take a nap in the ocean, ride a rocket to the moon, explore the jungle as an ant, take a trip back through history, and so on; only time and our own imagination would be the limits!