A deep neural network can learn a better version of itself

A few months ago I experimented with deep learning to generate art. You may have seen some of the art I posted on Twitter @pmarelas. In that experiment I used a pre-trained Generative Adversarial Network (CartoonGAN) that was trained to turn everyday images into cartoons. The details of how it works are in the linked paper.

CartoonGAN is great, but it has two problems. First, it is really big and requires a lot of computation to process an image. Second, the network has trouble reproducing colour accurately. To produce realistic cartoon images consistently, I had to resort to post-processing the images to correct colour mismatches.

I embarked on a journey to determine if I could improve CartoonGAN. In particular, I wanted to see if I could eliminate the need to post-process images and at the same time speed up the network so it could process video and camera frames in real-time.

I am happy to share that I achieved my goal. The problem of speed boiled down to generating a smaller, more efficient network: the smaller the network, the fewer instructions are required to process the inputs (i.e. the image or video frame). What I discovered is that the trick to generating a smaller network is to turn the network on itself. Say what? Let me explain.

Neural networks learn to optimise an outcome. If we can express the outcome as a loss/cost function, the network will try its best to minimise that loss/cost, which moves it closer to the desired outcome. For this task there were two challenges:

  1. Could I shrink the network without impacting its ability to produce realistic cartoon images?
  2. Could I avoid the manual post-processing step?
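Before getting to those, here is the loss-minimisation idea from above in its simplest possible form: a toy PyTorch sketch (nothing to do with CartoonGAN itself) where a single parameter is nudged towards a target by minimising a mean-squared-error loss.

```python
import torch

# A single learnable parameter starts at 0 and should learn the value 3.
param = torch.tensor([0.0], requires_grad=True)
target = torch.tensor([3.0])
optimiser = torch.optim.SGD([param], lr=0.1)

for _ in range(200):
    optimiser.zero_grad()
    loss = torch.nn.functional.mse_loss(param, target)  # the cost to minimise
    loss.backward()   # compute the gradient of the loss
    optimiser.step()  # move the parameter downhill

# After training, param has converged towards the target value of 3.0.
```

Everything that follows is this same loop, just with images in place of numbers and a big pre-trained network supplying the targets.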

To shrink the network I used network pruning and distillation methods. The idea is quite simple.

Take a copy of the original network, remove large portions of it (I removed entire resblocks) and train it on a new set of images (I used 200 images from ImageNet). During training, feed each image to both the small network and the original network. Take the images produced by the two networks and compare them using a measure of similarity; for this experiment I used the sum of mean squared errors between the two images' pixels. This measure forms the basis of the small network's loss function. Through this process, the small network learns to reproduce the same cartoon images generated by the larger original network.
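The steps above can be sketched in PyTorch roughly as follows. To be clear, this is a minimal illustration, not the real CartoonGAN: the generator here is a stand-in (a small resblock stack), and random tensors stand in for the 200 ImageNet training images.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ResBlock(nn.Module):
    """A residual block: the unit I pruned from the original network."""
    def __init__(self, ch):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(ch, ch, 3, padding=1), nn.ReLU(),
            nn.Conv2d(ch, ch, 3, padding=1),
        )
    def forward(self, x):
        return x + self.body(x)

def make_generator(n_resblocks, ch=16):
    # Stand-in image-to-image generator: stem, resblock stack, output conv.
    return nn.Sequential(
        nn.Conv2d(3, ch, 3, padding=1), nn.ReLU(),
        *[ResBlock(ch) for _ in range(n_resblocks)],
        nn.Conv2d(ch, 3, 3, padding=1),
    )

teacher = make_generator(n_resblocks=8).eval()  # the original, frozen network
student = make_generator(n_resblocks=2)         # pruned copy: fewer resblocks

optimiser = torch.optim.Adam(student.parameters(), lr=1e-3)
criterion = nn.MSELoss()  # per-pixel similarity between the two outputs

images = torch.rand(8, 3, 32, 32)  # stand-in for the ImageNet images
losses = []

for step in range(20):
    with torch.no_grad():
        target = teacher(images)      # the big network's cartoon output
    output = student(images)          # the small network's attempt
    loss = criterion(output, target)  # how far apart the two images are
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    losses.append(loss.item())
```

The teacher never trains; it only supplies targets. The student's loss falls as it learns to imitate the teacher's outputs with a fraction of the layers.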

The next step was to avoid post-processing. It turns out we can use exactly the same approach, with a twist: instead of teaching the small network to reproduce the cartoon images generated by the larger network, we post-process the larger network's images ahead of time and use these modified images in the loss function. With this additional step, the small network learns to correct the colour mismatches by learning the difference between the unprocessed and post-processed images.
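In code, the twist is a one-line change to the training targets. The `colour_correct` function below is a hypothetical placeholder (a simple per-channel rescale) standing in for whatever manual correction was applied; the point is only that it runs once, offline, on the teacher's outputs.

```python
import torch

def colour_correct(img):
    # Placeholder for the real manual post-processing step:
    # an illustrative per-channel gain, clamped to the valid range.
    gain = torch.tensor([1.1, 1.0, 0.9]).view(1, 3, 1, 1)
    return (img * gain).clamp(0.0, 1.0)

# Cartoon frames from the big network (random stand-ins here),
# corrected ahead of time, before any student training begins.
teacher_outputs = torch.rand(8, 3, 32, 32)
targets = colour_correct(teacher_outputs)

# In the distillation loop these replace the raw teacher outputs:
# loss = criterion(student(images), targets)
```

Because the student only ever sees the corrected targets, the colour fix gets baked into its weights and the manual step disappears at inference time.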

The result is an improved network called Fast CartoonGAN (named after fast.ai, where I acquired my deep learning education) that accurately reproduces colour and is significantly smaller (1455 MiB versus 2285 MiB) and faster (11 ms versus 20.5 ms per frame) than its predecessor.

To demonstrate the result I recorded a video clip of myself. This was captured from my laptop webcam (a little grainy) using OBS at 10 frames/s and broadcast live to a server running Fast CartoonGAN in PyTorch, processing the stream on an NVIDIA 1080 GPU.

Perhaps one day we can use this technology to put the AI into humans instead of replacing humans with AI.

PS: Yes, those things on my hands are band-aids, not artefacts 🙂