Deep Learning

72_Generative Adversarial Networks


In real-world applications such as image generation, distributions are highly complex, leading to significant performance improvements with the introduction of deep learning-based generative models. In this post, we will study Generative Adversarial Networks (GANs).

 

In GANs, the generative model learns a distribution from a training dataset and generates new examples from that distribution. Such a generative model can be thought of as a distribution $p({\text{x}}|{\text{w}})$ over the data space of vectors ${\text{x}}$, where ${\text{w}}$ denotes the trainable parameters. Conditional generative models of the form $p({\text{x}}|{\text{c}},{\text{w}})$, in which the conditioning variable ${\text{c}}$ specifies which kind of image to generate, are increasingly important in practice.

 

 

When considering a generative model based on a nonlinear transformation from the latent space ${\text{z}}$ to the data space ${\text{x}}$, the latent distribution $p({\text{z}})$ is typically chosen to be a simple Gaussian, represented as follows.

$p({\text{z}}) = \mathcal{N}({\text{z}}|{\mathbf{0}},{\mathbf{I}})$

The generative model, known as the generator, with trainable parameters ${\text{w}}$ and defined by a deep neural network, enables the implicit definition of the distribution over ${\text{x}}$ through the nonlinear transformation ${\text{x = g(z,w)}}$. The objective is to adapt this distribution to fit the training example dataset ${\text{\{ }}{{\text{x}}_n}\} $. A fundamental concept of Generative Adversarial Networks is the incorporation of a second network, the discriminator, which learns in tandem with the generator network. This discriminator provides the learning signal to update the generator's weights by distinguishing between real and generated examples.
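As a concrete sketch of this setup (assuming PyTorch as the framework; the latent dimensionality, layer widths, and data dimensionality below are illustrative choices, not taken from the text), we can draw ${\text{z}}$ from the Gaussian prior and pass it through a small network playing the role of $g({\text{z}},{\text{w}})$:

```python
import torch
import torch.nn as nn

latent_dim = 100   # dimensionality of the latent space z (illustrative choice)
data_dim = 784     # dimensionality of the data space x, e.g. a flattened 28x28 image (assumed)

# Generator g(z, w): a nonlinear transformation from latent space to data space,
# parameterized by the trainable weights w of a deep neural network.
generator = nn.Sequential(
    nn.Linear(latent_dim, 128),
    nn.ReLU(),
    nn.Linear(128, data_dim),
    nn.Tanh(),   # bounded outputs; a common (optional) choice for image-like data
)

# Sample from the Gaussian latent distribution p(z) = N(z | 0, I) and transform
# the samples; this implicitly defines a distribution over x.
z = torch.randn(16, latent_dim)
x_fake = generator(z)
print(x_fake.shape)   # torch.Size([16, 784])
```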

 

The objective of the discriminator network is to distinguish between real examples from the dataset and fake examples generated by the generator network, and it is trained by minimizing a typical classification error function. Conversely, the goal of the generator network is to synthesize examples from the same distribution as the training set, thereby maximizing this error.

The term 'adversarial' is used because the generator and discriminator networks act in opposition to each other, where the gain of one network signifies a loss for the other network.

 

When $t=1$ denotes real data and $t=0$ denotes synthetic data, the discriminator network has a single output unit with a logistic sigmoid activation function. This output represents the probability that the data vector ${\text{x}}$ is real.

 

 

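Continuing the sketch above, a minimal discriminator can be written as a second network whose single sigmoid output is read as the probability that its input is a real training example rather than a synthetic one (layer sizes again illustrative):

```python
# Discriminator d(x, phi): single output unit with a logistic sigmoid activation,
# interpreted as the probability that the input vector x is real (t = 1).
discriminator = nn.Sequential(
    nn.Linear(data_dim, 128),
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),
)

x_real = torch.randn(16, data_dim)   # placeholder standing in for a batch of real data
d_real = discriminator(x_real)       # values near 1: "probably real"
d_fake = discriminator(x_fake)       # values near 0: "probably synthetic"
```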
The discriminator network is trained using the standard cross-entropy error function.

$E({\text{w}},\phi ) =  - \frac{1}{N}\sum\limits_{n = 1}^N {\left\{ {{t_n}\ln {d_n} + (1 - {t_n})\ln (1 - {d_n})} \right\}} $

Here, ${d_n}$ denotes the output of the discriminator network for the $n$-th input vector, and the error is normalized by the total number of data points $N$. The training set consists of real data ${x_n}$ and random samples ${z_n}$ drawn from the latent space, so the error function can be written as follows.

$E({\text{w}},\phi ) =  - \frac{1}{{{N_{{\text{real}}}}}}\sum\limits_{n \in {\text{real}}} {\ln d({x_n},\phi )}  - \frac{1}{{{N_{{\text{synth}}}}}}\sum\limits_{n \in {\text{synth}}} {\ln \left( {1 - d(g({z_n},{\text{w}}),\phi )} \right)} $

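In code, this error can be evaluated with a binary cross-entropy term for each half of the batch (continuing the sketches above; `nn.BCELoss` averages over the batch, which supplies the $1/N_{\text{real}}$ and $1/N_{\text{synth}}$ normalization):

```python
bce = nn.BCELoss()   # -[t ln d + (1 - t) ln(1 - d)], averaged over the batch

# Targets: t = 1 for real examples, t = 0 for synthetic examples.
ones = torch.ones_like(d_real)
zeros = torch.zeros_like(d_fake)

# E(w, phi): the discriminator is trained to make this small,
# while the generator is trained to make it large.
error = bce(d_real, ones) + bce(d_fake, zeros)
```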
Typically, the number of real data points is equal to the number of synthetic data points. The combined generator and discriminator networks can be trained using backpropagation. However, training is adversarial: the error is minimized with respect to $\phi $ but maximized with respect to ${\text{w}}$. The maximization can still be carried out with standard gradient-based methods by simply reversing the sign of the gradient when updating the generator parameters.

$\Delta \phi  =  - \lambda {\nabla _\phi }{E_n}({\text{w}},\phi )$

$\Delta {\text{w}} = \lambda {\nabla _{\text{w}}}{E_n}({\text{w}},\phi )$

Here, ${E_n}$ represents the error defined for data point $n$, and $\lambda $ is the learning rate. The opposite signs show that the discriminator is trained to decrease the error while the generator is trained to increase it. Training alternates between updating the parameters of the generator network and the discriminator network, with a new set of synthetic samples generated after each gradient descent step.
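A bare-bones alternating training loop, continuing the sketches above, might look like the following; the generator step minimizes $-E_n$, which is the same as ascending the gradient of $E_n$ (in practice a "non-saturating" generator loss is often preferred, but the sign-reversal form matches the update rules above):

```python
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4)
batch_size = 16

for step in range(1000):
    x_real = torch.randn(batch_size, data_dim)   # placeholder for a batch of real training data

    # --- Discriminator update: minimize E with respect to phi ---
    x_fake = generator(torch.randn(batch_size, latent_dim)).detach()  # generator frozen
    error_d = (bce(discriminator(x_real), torch.ones(batch_size, 1)) +
               bce(discriminator(x_fake), torch.zeros(batch_size, 1)))
    opt_d.zero_grad()
    error_d.backward()
    opt_d.step()

    # --- Generator update: maximize E with respect to w (reversed gradient sign) ---
    d_fake = discriminator(generator(torch.randn(batch_size, latent_dim)))
    error_g = -bce(d_fake, torch.zeros(batch_size, 1))   # minimizing -E ascends E
    opt_g.zero_grad()
    error_g.backward()
    opt_g.step()
```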

 

If the generator finds a perfect solution, the discriminator network will be unable to distinguish between real and synthetic data and will therefore always output 0.5.
