Learning AI Poorly: How Does AI Generate Pictures?
(originally posted to LinkedIn)
Did you ever wonder how a model can generate a picture from a prompt? It is nearly midnight on a Thursday… Let’s see if I can explain it in a mostly correct but kind of wrong way that makes sense and still hit publish while it is still Thursday.
In 2015, a group of researchers from Toronto thought, “Hey! We can train models to write captions for images… Can we turn that around and train a model to generate images from text prompts?” So they did, and they wrote a paper:
https://arxiv.org/abs/1511.02793v2
I mean, they literally say, “One approach is to learn a generative model of text conditioned on images, known as caption generation […] These models take an image descriptor and generate unstructured texts using a recurrent decoder. In contrast, in this paper we explore models that condition in the opposite direction, i.e. taking textual descriptions as input and using them to generate relevant images.” They took something that already worked, turned it around, and kind of changed the world.
Great! How does it work? If you’ve been reading for a while, you know that ChatGPT works by taking words (or pieces of words) and converting them to numbers (tokens). You train an AI model on those numbers, use the model to generate new numbers, convert those numbers back into words, and boom! You’ve got generative text.
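To make the words-to-numbers idea concrete, here is a toy word-level tokenizer. It is only a sketch of the idea, not how ChatGPT actually does it: real tokenizers split text into sub-word pieces using a learned vocabulary, and `build_vocab`, `encode`, and `decode` are hypothetical helper names made up for this example.

```python
def build_vocab(texts):
    """Assign an integer ID to every distinct word, in order of first appearance."""
    vocab = {}
    for text in texts:
        for word in text.lower().split():
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab


def encode(text, vocab):
    """Words in, numbers out."""
    return [vocab[word] for word in text.lower().split()]


def decode(ids, vocab):
    """Numbers in, words out."""
    id_to_word = {i: w for w, i in vocab.items()}
    return " ".join(id_to_word[i] for i in ids)


vocab = build_vocab(["city skyline with buildings", "avocado on a table"])
ids = encode("avocado on a table", vocab)
print(ids)                 # [4, 5, 6, 7]
print(decode(ids, vocab))  # avocado on a table
```

A model trained on text only ever sees lists like `[4, 5, 6, 7]`; the decoding step is what turns its output back into something we can read.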
Did you know that computer images are already numbers? A picture is made up of a bunch of pixels, and each pixel is literally just a few numbers that represent its color (for example, how much red, green, and blue it has). So we kind of get to skip a step on the image side. But when we train, we also need something that describes what is in the picture. Say we have a photo of a city skyline. We train the model by giving it both a caption, “city skyline with buildings on a cloudy day,” and the pixel data of the image.
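Here is roughly what “images are already numbers” looks like. This is a made-up 2x2 image where each pixel is a (red, green, blue) triple from 0 to 255, flattened into a list and paired with a caption. The pixel values and the `training_example` tuple are illustrative assumptions, not any real dataset format.

```python
# A made-up 2x2 image: each pixel is a (red, green, blue) triple, 0-255.
image = [
    [(180, 180, 190), (170, 175, 185)],  # top row: cloudy grey sky
    [(90, 90, 100), (60, 60, 70)],       # bottom row: darker buildings
]

# Flatten the grid into one long list of numbers.
flat = [channel for row in image for pixel in row for channel in pixel]

# One training example: a caption plus the pixel numbers it describes.
training_example = ("city skyline with buildings on a cloudy day", flat)

print(len(flat))      # 12 (2 rows x 2 columns x 3 colors)
print(flat[:3])       # [180, 180, 190]
print(512 * 512 * 3)  # 786432 numbers for a 512x512 color image
```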
Latent Spaces
So, what happens when we “train an AI” on images? The whole point is to get a neural network to correctly guess pixel data given some sort of caption. So, you feed it thousands (really, millions) of captioned images. The model slowly gets better at, sort of, categorizing that pixel data.
How does it do that? The model has to find some sort of metrics that separate images within a mathematical space. In other words, it has to figure out how to separate a collection of pixels that represents an avocado from a collection of pixels that represents a banana.
Think of one way you might do that. Color, maybe? Yeah, color. Let’s grade each collection of pixels on its “greenness.” An avocado will have a lot of greenness, while a banana will not… But what if there’s also a “yellowness” scale? Bananas will score high, while avocados will score close to zero.
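A rough sketch of grading pixels on a scale. The `greenness` and `yellowness` formulas below are invented purely for illustration; a real model learns its own measures during training, and nobody writes them by hand.

```python
# HYPOTHETICAL hand-written scores for a list of (red, green, blue) pixels.

def greenness(pixels):
    """Average amount by which green beats both red and blue (0 if it doesn't)."""
    return sum(max(0, g - max(r, b)) for r, g, b in pixels) / len(pixels)


def yellowness(pixels):
    """Average amount by which red and green (kept close together) beat blue."""
    return sum(max(0, min(r, g) - b - abs(r - g)) for r, g, b in pixels) / len(pixels)


avocado = [(80, 140, 40), (70, 120, 30)]   # made-up green pixels
banana = [(240, 220, 60), (230, 210, 50)]  # made-up yellow pixels

print(greenness(avocado), yellowness(avocado))  # 55.0 0.0
print(greenness(banana), yellowness(banana))    # 0.0 140.0
```

Two numbers per image are already enough to tell these two apart. The trouble starts when a third thing (a leaf) scores just like one of them.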
What if we look at a collection of pixels that represent a leaf? Leaves are usually green… so, high on the greenness scale, but, like, leaves aren’t avocados. So, are we screwed? No. The model will just kind of automatically create some sort of “leafy-ness” scale. The leaves will score high on that scale, but an avocado will not.
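Our model now has three dimensions: yellowness, greenness, and leafy-ness. As we train the model on more images, it will find more dimensions… like “roundness” or “flatness” or, I don’t know, “l;skjdflksj=-ness.” (In a real model, the number of dimensions is fixed before training starts; training figures out how to use them.)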
What is “l;skjdflksj=-ness”? That’s the point: it doesn’t matter. The machine is going to come up with dimensions that we, as humans, don’t have words for. The model is just finding variables that help it do better at its task, and in the process it builds out a mathematical space with tons of dimensions.
Those dimensions make up a “Latent Space.” Within that latent space, collections of pixels that are similar end up closer together. So you might have an area that is “cute puppy dog” and another that is “dirty trash can.” Collections of pixels that look like trash will cluster near the dirty trash can area of the multidimensional space. Cute puppy pixels will cluster somewhere else.
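To picture a latent space, imagine each image boiled down to a short list of scores, one per dimension. The sketch below uses four made-up dimensions and made-up values, and treats “closer together” as plain Euclidean distance (`math.dist`, Python 3.8+). Real latent spaces have many more dimensions, and the points come from the model, not from a hand-typed dictionary.

```python
import math

# Made-up points in a tiny 4-dimensional "latent space".
# Axes: (yellowness, greenness, leafy-ness, roundness).
latent_points = {
    "avocado photo 1": (0.1, 0.8, 0.1, 0.7),
    "avocado photo 2": (0.0, 0.9, 0.2, 0.8),
    "banana photo":    (0.9, 0.1, 0.0, 0.2),
    "leaf photo":      (0.1, 0.9, 0.9, 0.1),
}


def nearest(query, points):
    """Return the name of the stored point closest to `query`."""
    return min(points, key=lambda name: math.dist(query, points[name]))


# Something green, round, and not leafy lands next to the avocados, not the leaf.
print(nearest((0.05, 0.85, 0.2, 0.8), latent_points))  # avocado photo 2
```

The leaf and the avocados are equally green, but the leafy-ness and roundness axes push them apart, which is exactly why the extra dimensions help.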
So, you can think of a point within that space as the recipe for a collection of pixels that look like a thing. A text prompt is the thing that navigates us to that location.
Diffusion
Ok… so… how do you generate an image from a latent space? First, you need a prompt. Say we want an avocado on a table. Somewhere inside that latent space there is a “recipe” for a collection of pixels that is high on the “table-ness” scale and on the “avocado” scales (greenness, roundness, foody-ness, lkasjflkjaa-ness, etc.). The prompt will get you to that recipe.
The crazy part starts with randomness. Say your model was trained on 512 pixel by 512 pixel images. What would happen if you generated a 512x512 pixel image of just random numbers? You’d get noise, or static, or distortion… whatever you want to call it. You won’t have an avocado.
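The “just random numbers” image is easy to make. This sketch assumes a grayscale image (a color one would have three numbers per pixel) and fills it with Gaussian noise, the kind of noise diffusion models typically start from.

```python
import random

# A hypothetical 512x512 grayscale image where every pixel is an
# independent random draw from a Gaussian (mean 0, standard deviation 1).
SIZE = 512
rng = random.Random(42)  # fixed seed so the "static" is reproducible
noise = [[rng.gauss(0.0, 1.0) for _ in range(SIZE)] for _ in range(SIZE)]

# 262,144 random numbers, and no avocado anywhere in them.
print(len(noise) * len(noise[0]))  # 262144
```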
What we do is work backwards from random noise and try to zero in on the recipe for that collection of pixels within the latent space. To do that, you train a network that can remove noise a little bit at a time, until you get something that looks like an avocado.
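How do you get a network that removes noise? One common recipe, used by DDPM-style diffusion models, is to take a clean training image, mix in random noise you generated yourself, and train the network to predict that exact noise. The sketch below only builds one training pair; `make_training_pair` and `alpha_bar` (how much of the original image survives at that noise level) are names chosen for illustration, and the network and training loop are left out.

```python
import math
import random


def make_training_pair(clean_pixels, alpha_bar, rng):
    """Mix a clean image with known noise; return (noisy image, the noise used).

    alpha_bar is between 0 and 1: close to 1 means barely noisy,
    close to 0 means almost pure static.
    """
    noise = [rng.gauss(0.0, 1.0) for _ in clean_pixels]
    noisy = [
        math.sqrt(alpha_bar) * x + math.sqrt(1.0 - alpha_bar) * e
        for x, e in zip(clean_pixels, noise)
    ]
    # The network would see `noisy` (plus the caption and the noise level)
    # and be trained to output `noise`: that is "learning what noise to remove".
    return noisy, noise


rng = random.Random(0)
clean = [0.2, -0.5, 0.9]  # a tiny made-up "image" scaled to roughly -1..1
noisy, target_noise = make_training_pair(clean, alpha_bar=0.5, rng=rng)

# Knowing the noise exactly lets us undo the mixing and recover the clean image.
recovered = [
    (n - math.sqrt(1.0 - 0.5) * e) / math.sqrt(0.5)
    for n, e in zip(noisy, target_noise)
]
print([round(v, 6) for v in recovered])  # [0.2, -0.5, 0.9]
```

The last few lines show why predicting the noise is useful: if you know the noise, you can subtract it and get the image back.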
So, what we do is give the model a 512x512 image of noise and prompt it with “avocado on a table.” The model gives its best estimate of what noise to remove from that image to get it a bit closer to the “avocado on a table” spot in the latent space. Then we add back just a pinch of fresh noise and feed the result into the model again. The model does its best to estimate, or “predict,” what noise to remove. Do that over and over again and you’ll eventually end up with an image that looks like an avocado on a table.
So, generally, the algorithm is this: start with noise, predict what noise to remove to get closer to a spot in your latent space, and remove it. Add a tiny bit of fresh noise back in (that’s part of how the classic sampler works; some samplers skip this step), then predict what noise to remove again. Repeat until you have an avocado on a table.
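Here is that loop as a toy Python sketch in the style of the classic DDPM sampler. Big assumption: there is no trained network here. `TOY_RECIPES` and `cheating_noise_predictor` are hypothetical stand-ins; the predictor cheats by already knowing the target “image” for the prompt, so the loop can run without a real model. A real model only estimates the noise, so each step just nudges the image a bit closer.

```python
import math
import operator
import random
from itertools import accumulate

# HYPOTHETICAL: a lookup from prompt to the target "recipe" (4 made-up pixels).
TOY_RECIPES = {"avocado on a table": [0.3, 0.8, -0.2, 0.5]}

# Noise schedule: how much noise each of the T steps deals with.
T = 50
betas = [0.001 + (0.2 - 0.001) * i / (T - 1) for i in range(T)]
alphas = [1.0 - b for b in betas]
alpha_bars = list(accumulate(alphas, operator.mul))  # signal left after step t


def cheating_noise_predictor(x, t, prompt):
    """HYPOTHETICAL stand-in for a trained network.

    Returns exactly the noise separating x from the prompt's target at step t,
    which a real model can only estimate.
    """
    target = TOY_RECIPES[prompt]
    ab = alpha_bars[t]
    return [(xi - math.sqrt(ab) * ti) / math.sqrt(1.0 - ab) for xi, ti in zip(x, target)]


def generate(prompt, rng):
    # Start with pure noise.
    x = [rng.gauss(0.0, 1.0) for _ in TOY_RECIPES[prompt]]
    for t in reversed(range(T)):
        # Predict what noise to remove, then remove one step's worth of it.
        eps = cheating_noise_predictor(x, t, prompt)
        coef = betas[t] / math.sqrt(1.0 - alpha_bars[t])
        x = [(xi - coef * ei) / math.sqrt(alphas[t]) for xi, ei in zip(x, eps)]
        # Add a pinch of fresh noise back in, except on the very last step.
        if t > 0:
            x = [xi + math.sqrt(betas[t]) * rng.gauss(0.0, 1.0) for xi in x]
    return x


image = generate("avocado on a table", random.Random(0))
print([round(v, 3) for v in image])  # [0.3, 0.8, -0.2, 0.5]
```

Because the cheating predictor is perfect, the loop lands exactly on the recipe. Swap in a real, imperfect network and the same loop structure still applies; the many small steps are what let it recover from bad guesses along the way.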
Classifier-Free Guidance
The diffusion process above will give you an image, but people have come up with a trick to get the model to generate higher-quality images that stick more closely to the prompt. When you’re iterating along, trying to predict what noise to remove from the image, you normally give the model the noisy image and the prompt. With Classifier-Free Guidance, you run the model twice on the same noisy image: once with the prompt and once without it. You then take the difference between the two noise predictions, scale it up, and use that boosted prediction to decide what noise to remove on that step. (This works because, during training, the prompt is randomly left out some of the time, so one network learns to denoise both with and without a prompt.)
What that does is ask, sort of, “What would this network do with this noise given a prompt vs. what would it do without a prompt?” Then you use that information to move closer to the “with prompt” prediction and further away from the “without prompt” prediction. This lets you zero in on what the network would do with the prompt and steer away from what it would do on its own. In practice, it gives you images that follow the prompt much more closely, at the cost of some variety.
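Classifier-free guidance is usually written as guided = uncond + scale * (cond - uncond), where uncond and cond are the network’s noise predictions without and with the prompt. A minimal sketch, assuming both predictions have already been computed (the numbers below are made up, and the function name is hypothetical). A scale of 0 ignores the prompt, 1 gives the plain prompted prediction, and anything above 1 exaggerates whatever the prompt changed; popular Stable Diffusion tools commonly default to around 7.5.

```python
def classifier_free_guidance(eps_uncond, eps_cond, guidance_scale):
    """Combine noise predictions made without and with the prompt.

    Moves away from the "no prompt" prediction, toward (and past)
    the "with prompt" prediction when guidance_scale > 1.
    """
    return [u + guidance_scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


# Made-up noise predictions for the same 3-pixel noisy image.
eps_without_prompt = [0.10, -0.20, 0.05]
eps_with_prompt = [0.30, -0.10, 0.05]

guided = classifier_free_guidance(eps_without_prompt, eps_with_prompt, 7.5)
print([round(v, 3) for v in guided])  # [1.6, 0.55, 0.05]
```

The last pixel shows the idea nicely: where the prompt changed nothing, guidance changes nothing either.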
Ooph, 11:57pm. Publishing now… I hope that sort of made sense.