Generative Adversarial Networks (GANs) are a powerful class of deep learning models that have revolutionized the field of generative modeling. GANs have been used to generate high-quality images, videos, music, and even text. In this hands-on project, I will walk you through the process of building a GAN to generate realistic images of faces.

Project overview

Our goal in this project is to build a GAN that can generate realistic images of faces. We will use the CelebA dataset (available on Kaggle), which contains over 200,000 images of celebrity faces. CelebA is a popular benchmark for GANs and has been used in many research papers.

The GAN we will build consists of two neural networks: a generator and a discriminator. The generator network takes in random noise as input and generates fake images, while the discriminator network takes in real and fake images and tries to distinguish between them. During training, the generator tries to generate images that fool the discriminator, while the discriminator tries to correctly identify real and fake images.
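Written as an objective, this is the standard GAN minimax game. D(x) is the discriminator's estimated probability that image x is real, and G(z) is the image the generator produces from a noise vector z:

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim \mathcal{N}(0, I)}\big[\log\big(1 - D(G(z))\big)\big]
```

Maximizing V over D is equivalent to minimizing the binary cross-entropy of labelling real images 1 and fake images 0. For the generator, a common practical variant is to minimize -log D(G(z)) instead of log(1 - D(G(z))). This variant gives stronger gradients early in training, when the discriminator rejects fakes easily. The training code in this post uses it by scoring the generator's fakes against the "real" label.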

We will use PyTorch (a popular deep learning framework and the one I personally like) to implement our GAN. PyTorch provides the building blocks we need: convolution and transposed-convolution layers, batch normalization, loss functions, and optimizers. The official PyTorch tutorials also include a DCGAN example.

Preprocessing our data

Before we can train our GAN, we need to preprocess the CelebA dataset. We will use the torchvision package to download and preprocess the data. Specifically, we will resize all the images to 64x64 pixels and convert them to tensors. We will then normalize the pixel values to the range [-1, 1], which matches the output range of the generator's final tanh layer.

In technical terms, a tensor is a multi-dimensional array of numbers, and it can have any number of dimensions. A scalar is a rank-0 tensor, a vector is rank 1, a matrix is rank 2, and in general a rank-n tensor has n dimensions. In deep learning, we use tensors to represent data such as images, audio, and text, as well as the weights and activations of neural network layers.
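To make the idea of rank concrete, here are the tensor shapes this project deals with. The sizes match the hyperparameters used later in the post: 100-dimensional noise, 64x64 RGB images, and batches of 128. The values themselves are random placeholders.

```python
import torch

# Rank 0: a scalar, e.g. a single pixel intensity
scalar = torch.tensor(0.5)
# Rank 1: a vector, e.g. one latent noise vector of length 100
vector = torch.randn(100)
# Rank 3: one RGB image stored as (channels, height, width)
image = torch.rand(3, 64, 64)
# Rank 4: a mini-batch of images stored as (batch, channels, height, width)
batch = torch.rand(128, 3, 64, 64)

for name, t in [("scalar", scalar), ("vector", vector), ("image", image), ("batch", batch)]:
    print(f"{name}: rank={t.ndim}, shape={tuple(t.shape)}")
```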


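    import torch
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms

    # Define data transforms:
    # resize the shorter side to 64 px and center-crop to 64x64 (CelebA images are 178x218,
    # so a plain Resize((64, 64)) would squash the faces), convert to a [0, 1] tensor,
    # then normalize each RGB channel to [-1, 1] via (x - 0.5) / 0.5
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load the CelebA dataset (training split by default).
    # torchvision downloads CelebA from Google Drive, which can fail when the download quota is
    # exceeded; if so, download the images manually and load them with ImageFolder as in the training section.
    dataset = datasets.CelebA(root='./data', download=True, transform=transform)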
    
Building the generator

Next, we will build the generator network. The generator takes in random noise as input and generates fake images. We will use a deep convolutional neural network (CNN) architecture for the generator, consisting of several layers of transposed convolutions, batch normalization, and ReLU activation functions.


    import torch.nn as nn

    class Generator(nn.Module):
        def __init__(self, latent_dim):
            super(Generator, self).__init__()
            
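            # Input layer: project the noise vector to 512 feature maps of size 4x4
            self.fc = nn.Linear(latent_dim, 4 * 4 * 512)
            self.bn1 = nn.BatchNorm2d(512)
            
            # Hidden layers: each stride-2 transposed convolution doubles height and width
            self.conv1 = nn.ConvTranspose2d(512, 256, 4, 2, 1)  # 4x4 -> 8x8
            self.bn2 = nn.BatchNorm2d(256)
            self.conv2 = nn.ConvTranspose2d(256, 128, 4, 2, 1)  # 8x8 -> 16x16
            self.bn3 = nn.BatchNorm2d(128)
            self.conv3 = nn.ConvTranspose2d(128, 64, 4, 2, 1)   # 16x16 -> 32x32
            self.bn4 = nn.BatchNorm2d(64)
            
            # Output layer: 3 RGB channels at 64x64, squashed to [-1, 1] by tanh
            self.conv4 = nn.ConvTranspose2d(64, 3, 4, 2, 1)     # 32x32 -> 64x64
            self.tanh = nn.Tanh()
            
        def forward(self, z):
            # Flatten the noise so inputs shaped (batch, latent_dim) or
            # (batch, latent_dim, 1, 1) both work with the linear layer
            x = self.fc(z.view(z.size(0), -1))
            x = x.view(-1, 512, 4, 4)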
            x = self.bn1(x)
            x = nn.functional.relu(x)
            x = self.conv1(x)
            x = self.bn2(x)
            x = nn.functional.relu(x)
            x = self.conv2(x)
            x = self.bn3(x)
            x = nn.functional.relu(x)
            x = self.conv3(x)
            x = self.bn4(x)
            x = nn.functional.relu(x)
            x = self.conv4(x)
            x = self.tanh(x)
            return x
    

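The generator takes a batch of noise vectors of size latent_dim as input and outputs a tensor of size (batch_size, 3, 64, 64), representing our fake images. Because the last layer is a tanh, every pixel value lies in [-1, 1], the same range as the normalized real images.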
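Why does the output come out at exactly 64x64? The helper below evaluates the output-size formula from the nn.ConvTranspose2d documentation. The function name conv_transpose_out is a hypothetical helper, not part of PyTorch. It is applied to the generator's four ConvTranspose2d(..., 4, 2, 1) layers, starting from the 4x4 maps produced by the linear layer:

```python
def conv_transpose_out(size, kernel_size, stride, padding, output_padding=0, dilation=1):
    """Output height/width of nn.ConvTranspose2d, per the formula in the PyTorch docs."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

size = 4  # spatial size after self.fc and x.view(-1, 512, 4, 4)
sizes = [size]
for _ in range(4):  # conv1 .. conv4, each ConvTranspose2d(..., 4, 2, 1)
    size = conv_transpose_out(size, kernel_size=4, stride=2, padding=1)
    sizes.append(size)

print(sizes)  # [4, 8, 16, 32, 64]
```

With kernel 4, stride 2, and padding 1, the formula reduces to 2 * size, so four layers take 4x4 up to 64x64.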

Building the discriminator

Next, we will build the discriminator network. The discriminator takes in real and fake images as input and tries to distinguish between them. We will use a CNN architecture for the discriminator, consisting of several layers of convolutions, batch normalization, and LeakyReLU activation functions.


    class Discriminator(nn.Module):
        def __init__(self):
            super(Discriminator, self).__init__()
            
            # Input layer: 3x64x64 -> 64x32x32 (no batch norm on the first layer)
            self.conv1 = nn.Conv2d(3, 64, 4, 2, 1)
            self.leaky_relu1 = nn.LeakyReLU(0.2)
            
            # Hidden layers: each stride-2 convolution halves height and width
            self.conv2 = nn.Conv2d(64, 128, 4, 2, 1)    # -> 128x16x16
            self.bn2 = nn.BatchNorm2d(128)
            self.leaky_relu2 = nn.LeakyReLU(0.2)
            self.conv3 = nn.Conv2d(128, 256, 4, 2, 1)   # -> 256x8x8
            self.bn3 = nn.BatchNorm2d(256)
            self.leaky_relu3 = nn.LeakyReLU(0.2)
            self.conv4 = nn.Conv2d(256, 512, 4, 2, 1)   # -> 512x4x4
            self.bn4 = nn.BatchNorm2d(512)
            self.leaky_relu4 = nn.LeakyReLU(0.2)
            
            # Output layer: one probability per image
            self.fc = nn.Linear(4 * 4 * 512, 1)
            self.sigmoid = nn.Sigmoid()
            
        def forward(self, x):
            x = self.conv1(x)
            x = self.leaky_relu1(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.leaky_relu2(x)
            x = self.conv3(x)
            x = self.bn3(x)
            x = self.leaky_relu3(x)
            x = self.conv4(x)
            x = self.bn4(x)
            x = self.leaky_relu4(x)
            x = x.view(-1, 4 * 4 * 512)  # flatten to (batch_size, 8192)
            x = self.fc(x)
            x = self.sigmoid(x)
            return x
    

The discriminator takes a tensor of size (batch_size, 3, 64, 64) as input and outputs a tensor of size (batch_size, 1). This is one value between 0 and 1 per image, representing the estimated probability that the image is real.
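The 4 * 4 * 512 input size of the final linear layer follows from the nn.Conv2d output-size formula in the PyTorch docs. The sketch below traces a 64x64 image through the four Conv2d(..., 4, 2, 1) layers. conv_out is a hypothetical helper name used only for illustration:

```python
def conv_out(size, kernel_size, stride, padding, dilation=1):
    """Output height/width of nn.Conv2d, per the formula in the PyTorch docs."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

size = 64  # input images are 64x64
channels = [3, 64, 128, 256, 512]
for c_out in channels[1:]:  # conv1 .. conv4, each Conv2d(..., 4, 2, 1)
    size = conv_out(size, kernel_size=4, stride=2, padding=1)
    print(f"{c_out} channels x {size} x {size}")

flat = channels[-1] * size * size
print(flat)  # 8192, i.e. 4 * 4 * 512, the in_features of self.fc
```

This mirrors the generator: where each transposed convolution doubled the spatial size, each convolution here halves it, from 64 to 32, 16, 8, and finally 4.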

Training the GAN

Yes, we're almost there. Now that we have defined our generator and discriminator networks, we can train the GAN. During training, we alternate between updating the discriminator and the generator. For each batch, we sample noise vectors from a standard normal distribution and use the generator to turn them into fake images. We also take a batch of real images from the CelebA dataset. We then train the discriminator on these real and fake images, and train the generator to produce images that fool the discriminator.


    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    # Use a GPU if one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define hyperparameters
    num_epochs = 200  # one epoch over CelebA is about 1,600 batches of 128; lower this for a quick test
    batch_size = 128
    lr = 0.0002
    betas = (0.5, 0.999)
    latent_dim = 100

    # Create the generator and discriminator (classes defined above)
    generator = Generator(latent_dim).to(device)
    discriminator = Discriminator().to(device)

    # Define loss function and optimizers
    criterion = nn.BCELoss()
    gen_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=betas)
    dis_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

    # Load the dataset. ImageFolder expects the images inside at least one subdirectory
    # of root (e.g. celeba/img_align_celeba/*.jpg) and returns (image, class_index) pairs;
    # the class index is ignored below.
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    dataset = datasets.ImageFolder(root='celeba', transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Train the GAN
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.shape[0]  # the last batch may be smaller
            
            # Labels: 1 for real images, 0 for fake images
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)
            
            # Train discriminator on real images
            real_images = real_images.to(device)
            dis_optimizer.zero_grad()
            real_outputs = discriminator(real_images)
            dis_loss_real = criterion(real_outputs, real_labels)
            dis_loss_real.backward()
            
            # Train discriminator on fake images; detach() keeps this loss from updating the generator
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_outputs = discriminator(fake_images.detach())
            dis_loss_fake = criterion(fake_outputs, fake_labels)
            dis_loss_fake.backward()
            
            # Update discriminator (gradients from both backward() calls have accumulated)
            dis_loss = dis_loss_real + dis_loss_fake
            dis_optimizer.step()
            
            # Train generator: score its fakes against the "real" label, i.e. minimize -log D(G(z))
            gen_optimizer.zero_grad()
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            fake_outputs = discriminator(fake_images)
            gen_loss = criterion(fake_outputs, real_labels)
            gen_loss.backward()
            
            # Update generator
            gen_optimizer.step()
            
            # Print loss every 100 steps
            if (i+1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'
                    .format(epoch+1, num_epochs, i+1, len(dataloader), dis_loss.item(), gen_loss.item()))
        
    

In the training code above, we define the hyperparameters for the GAN, create the generator and discriminator networks, and define the loss function and optimizers. We then load the CelebA images using torchvision's ImageFolder dataset and PyTorch's DataLoader.

We then train the GAN for a specified number of epochs. In each epoch, we iterate over the batches of the dataset and update the discriminator and the generator in turn. First, we train the discriminator on a batch of real and fake images, computing the loss with binary cross-entropy (BCE), and update its parameters with its Adam optimizer. Next, we train the generator to produce images that fool the discriminator. Its fake images are scored against the "real" label, again with BCE, and its parameters are updated with its own Adam optimizer. Every 100 batches, we print the current discriminator and generator losses.
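Passing real_labels when computing the generator loss can look odd at first. The sketch below uses made-up discriminator outputs for four fake images and checks what nn.BCELoss actually computes in each case. With target 0 it equals the mean of -log(1 - D(G(z))); with target 1 it equals the mean of -log D(G(z)), which is the loss the generator minimizes.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs (probabilities) for a batch of 4 fake images
fake_outputs = torch.tensor([[0.1], [0.3], [0.6], [0.9]])
real_labels = torch.ones(4, 1)
fake_labels = torch.zeros(4, 1)

# Discriminator on fakes: target 0, so BCE reduces to mean(-log(1 - D(G(z))))
d_fake = criterion(fake_outputs, fake_labels)
assert torch.allclose(d_fake, (-torch.log(1 - fake_outputs)).mean())

# Generator: same outputs but target 1, so BCE reduces to mean(-log D(G(z)))
g_loss = criterion(fake_outputs, real_labels)
assert torch.allclose(g_loss, (-torch.log(fake_outputs)).mean())

print(f"discriminator (fake part): {d_fake.item():.4f}, generator: {g_loss.item():.4f}")
```

The generator's loss is large when the discriminator assigns its fakes a low probability of being real. Minimizing it therefore pushes the generator toward images the discriminator accepts.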

After training, we save the generator's parameters (its state_dict) to a file named generator.pth. We then generate 64 sample images with the generator and save them as an 8x8 grid in a file named generated_images.png.

        
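    from torchvision.utils import save_image

    # Save the generator's weights
    torch.save(generator.state_dict(), 'generator.pth')

    # Generate sample images (eval mode makes BatchNorm use its running statistics)
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(64, latent_dim, device=device)
        fake_images = generator(noise)
        # Map the tanh output from [-1, 1] back to [0, 1] before saving
        save_image(fake_images * 0.5 + 0.5, 'generated_images.png')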
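Because only the state_dict was saved, reusing the generator later requires two steps: construct a model with the same architecture, then load the weights into it. The sketch below shows that round trip. nn.Linear stands in for Generator so the example runs on its own, and the temporary file path is an assumption made for the demo.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for Generator(latent_dim): any nn.Module saves and loads the same way
model = nn.Linear(100, 8)
path = os.path.join(tempfile.mkdtemp(), "generator.pth")
torch.save(model.state_dict(), path)

# Later, e.g. in a separate script: rebuild the same architecture, then load the weights
restored = nn.Linear(100, 8)
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # matters for Generator, whose BatchNorm layers behave differently in eval mode

with torch.no_grad():
    z = torch.randn(4, 100)
    same = torch.equal(model(z), restored(z))
print(same)  # True
```

With the real model, you would build Generator(latent_dim) in place of nn.Linear, load generator.pth the same way, and sample noise as in the code above.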
        
    
Use cases

There are several use cases for generating face images using GANs.

  1. Entertainment: The film and video game industries often use CGI to create realistic human characters. GANs could offer another way to generate high-quality human faces for such characters.

  2. Facial recognition: GANs could be used to generate synthetic images of faces, which could then be used to augment existing datasets for training purposes. This could help improve the accuracy and robustness of facial recognition models, especially in scenarios where little training data is available.

  3. Forensics: GANs could potentially be used to generate realistic images of suspects based on eyewitness descriptions. This could help law enforcement agencies identify suspects and solve crimes more quickly.

Conclusions

GANs have become an increasingly popular area of research in the field of deep learning, and have been used for a variety of tasks such as image generation, style transfer, and anomaly detection. While the basic principles of GANs are relatively simple, training them can be challenging and requires careful tuning of hyperparameters and network architecture.

Despite these challenges, GANs offer a powerful tool for generating high-quality, realistic images, and have the potential to revolutionize the field of computer vision. With further research and development, GANs may be able to generate not just images, but entire virtual worlds, with applications in fields such as gaming, virtual reality, and simulation.

To conform to standards of decency, I have chosen not to include the face images I generated in this blog post. However, if you encounter any issues with the code, please feel free to contact me for assistance.