A deep neural network can learn a better version of itself

A few months ago I experimented with deep learning to generate art. You may have seen some of the art I posted on Twitter (@pmarelas). In that experiment I used a pre-trained Generative Adversarial Network (CartoonGAN) that was trained to turn everyday images into cartoons. The details of how it works are in the linked paper.

CartoonGAN is great, but it has two problems. First, it is really big and requires a lot of computation to process an image. Second, the network struggles to reproduce colour accurately. To produce realistic cartoon images consistently, I resorted to post-processing the images to correct for the colour mismatch.

After my first experiment my curiosity got the better of me. I wanted to figure out if I could improve CartoonGAN. In particular, I wanted to see if I could eliminate the need to post-process images and make CartoonGAN fast enough to process video and camera frames in real-time.

I am happy to share that I achieved my goal. The problem of speed boiled down to generating a smaller and more efficient network: the smaller the network, the fewer operations are required to process the inputs. It turns out the trick to generating a smaller network is to turn the network on itself.

Neural networks learn to optimise an outcome. If we can express the outcome as a loss function, the network will try its best to minimise that loss, which moves it closer to the desired outcome. For this task there were two desirable outcomes:

  1. Could I shrink the network without impacting its ability to produce realistic cartoon images?
  2. Could I avoid the post-processing?

To shrink the network I used network pruning and distillation methods. The premise is quite simple.
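To make the pruning idea concrete, here is a minimal sketch in PyTorch. The networks are stand-ins, not the real CartoonGAN generator: the actual architecture stacks residual blocks between a downsampling encoder and an upsampling decoder, and the block counts and channel widths below are illustrative only. The point is that removing entire resblocks yields a structurally identical but much smaller network.

```python
import torch
import torch.nn as nn

# A basic residual block, as used in the CartoonGAN generator's middle stage.
class ResBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(ch, ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, 3, padding=1),
        )

    def forward(self, x):
        return x + self.body(x)

def make_generator(n_resblocks, ch=32):
    # Simplified stand-in: conv in, a stack of resblocks, conv out.
    return nn.Sequential(
        nn.Conv2d(3, ch, 3, padding=1),
        *[ResBlock(ch) for _ in range(n_resblocks)],
        nn.Conv2d(ch, 3, 3, padding=1),
    )

teacher = make_generator(n_resblocks=8)  # original-sized network
student = make_generator(n_resblocks=2)  # pruned copy: entire resblocks removed

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(student), count(teacher))  # student has far fewer parameters
```

Because a resblock's input and output shapes match, whole blocks can be dropped without disturbing the rest of the architecture, which is what makes this style of structured pruning so convenient.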

We take a copy of the original network, remove large portions of it (I removed entire resblocks) and train it on a new set of images (I used 200 images from ImageNet). During training, each image is fed to both the small network and the larger original network. The images produced by the two networks are then compared using a measure of similarity; for this experiment I used the sum of mean squared errors between the two images' pixels, which is the basis for the small network's loss function. This process allowed the small network to learn how to reproduce the same cartoon images generated by the larger original network.
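The training loop above can be sketched as follows. This is a hedged illustration, not the author's actual code: the two networks are tiny stand-ins for the real CartoonGAN generator and its pruned copy, and the random tensors stand in for the 200 ImageNet photos.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins: `teacher` plays the large pre-trained generator,
# `student` the pruned copy being trained.
teacher = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                        nn.Conv2d(8, 3, 3, padding=1))
student = nn.Conv2d(3, 3, 3, padding=1)

teacher.eval()  # the teacher is frozen; only the student learns
opt = torch.optim.Adam(student.parameters(), lr=1e-3)

# Unlabelled training images (random tensors standing in for photos).
batches = [torch.rand(4, 3, 32, 32) for _ in range(50)]

losses = []
for x in batches:
    with torch.no_grad():
        target = teacher(x)              # teacher's output acts as the label
    loss = F.mse_loss(student(x), target)  # pixel-wise MSE distillation loss
    losses.append(loss.item())
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Note that no ground-truth cartoons are needed: the teacher's outputs serve as the labels, which is why the method works without the original training data.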

The next step was to figure out how to avoid post-processing. It turns out we can use the exact same approach with a twist. Instead of teaching the small network to reproduce the cartoon images generated by the larger network as-is, we post-processed the images generated by the larger network ahead of time and used these corrected images as the basis for comparison in the small network's loss function. With this additional step the small network learnt how to correct the colour mismatches by learning the difference between the unprocessed and post-processed images.
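The twist amounts to one change in the distillation setup: precompute the corrected targets before training. A minimal sketch, again with toy stand-in networks; `colour_correct` here is a hypothetical placeholder for the author's actual colour post-processing step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

def colour_correct(img):
    # Hypothetical stand-in for the real post-processing:
    # a simple per-channel recentring and rescale.
    return (img - img.mean(dim=(2, 3), keepdim=True)) * 0.9 + 0.5

teacher = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                        nn.Conv2d(8, 3, 3, padding=1))
student = nn.Conv2d(3, 3, 3, padding=1)

images = [torch.rand(4, 3, 32, 32) for _ in range(20)]

# The twist: post-process the teacher's outputs ahead of time. These
# corrected images, not the raw teacher outputs, become the targets.
with torch.no_grad():
    targets = [colour_correct(teacher(x)) for x in images]

opt = torch.optim.Adam(student.parameters(), lr=1e-3)
for x, target in zip(images, targets):
    loss = F.mse_loss(student(x), target)  # compare against corrected target
    opt.zero_grad()
    loss.backward()
    opt.step()
```

The student thus absorbs the post-processing into its own weights, so no separate correction pass is needed at inference time.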

The result is a network that is able to accurately reproduce colour and is 50% smaller than its predecessor.

All this is possible without access to the original network's training data. The new and improved Fast Cartoon GAN is fast enough to process video in real-time using a consumer-grade GPU.

To demonstrate the result I produced a clip of myself below. This was broadcast from my laptop (a little grainy) at 10 frames/s to a server running an NVIDIA 1080 GPU.

Perhaps one day we can use this technology to put the AI into humans instead of trying to replace humans with AI.

PS: Yes, those things on my hands are band-aids, not artefacts 🙂