In this series, Artem Nazarenko, Computer Vision Engineer at Everypixel, shows you how to implement the architecture of a neural network. The first part covered the working principles of GANs and methods of collecting datasets for training, and the second part was about preparing for GAN training. Today, we are going to start training.
We declare the datasets and dataloaders for loading data.
We also specify the device on which the network will be trained.
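A minimal sketch of this setup step is shown below. The random tensors stand in for the real dataset classes from the previous part, and the batch size and image size are assumptions chosen only to keep the example small.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in data (assumption): 8 samples with a 4-channel input and a 1-channel target.
inputs = torch.rand(8, 4, 64, 64)
targets = torch.rand(8, 1, 64, 64)
train_dataset = TensorDataset(inputs[:6], targets[:6])
valid_dataset = TensorDataset(inputs[6:], targets[6:])

# Shuffle only the training data; validation order does not matter.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False)

# Each batch is moved to the chosen device before it is fed to the model.
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)
print(x.shape, y.shape, device)
```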
We train the attention and shadow generation blocks separately.
Attention block training. We take U-Net as the attention block model and import the architecture from the segmentation_models.pytorch repository. To improve the quality of the network, we replace the standard encoder of the U-Net with the resnet34 classification network.
Since the attention block takes a shadow-free image and a mask of the inserted object as input, we replace the first convolutional layer in the model so that it accepts a 4-channel tensor (3 color channels + 1 single-channel mask).
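One way to swap the first convolution can be sketched as follows. The helper `widen_first_conv` is a hypothetical name introduced for this illustration, and a standalone `nn.Conv2d` stands in for the first encoder layer so that the example does not depend on segmentation_models.pytorch. Copying the existing RGB weights into the new layer is a design choice of this sketch, not something the article specifies.

```python
import torch
import torch.nn as nn

def widen_first_conv(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    """Return a copy of `conv` that accepts `in_channels` input channels."""
    new_conv = nn.Conv2d(
        in_channels, conv.out_channels,
        kernel_size=conv.kernel_size, stride=conv.stride,
        padding=conv.padding, bias=conv.bias is not None,
    )
    with torch.no_grad():
        # Reuse the existing weights for the channels both layers share;
        # the extra channels keep their default random initialization.
        shared = min(in_channels, conv.in_channels)
        new_conv.weight[:, :shared] = conv.weight[:, :shared]
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv

# Stand-in with the shape of resnet34's first layer (7x7, stride 2, 64 filters).
rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
rgbm_conv = widen_first_conv(rgb_conv, in_channels=4)

image = torch.rand(1, 3, 256, 256)  # shadow-free image
mask = torch.rand(1, 1, 256, 256)   # mask of the inserted object
out = rgbm_conv(torch.cat([image, mask], dim=1))
print(out.shape)  # torch.Size([1, 64, 128, 128])
```

The segmentation_models.pytorch model constructors also expose an `in_channels` argument, which may make the manual replacement unnecessary; it is worth checking the version you have installed.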
Declare the loss function, metric and optimizer.
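The declarations might look like the sketch below. The specific choices here are assumptions: a mean squared error loss, an IoU metric written by hand as `iou_score`, the Adam optimizer with a learning rate of 1e-4, and a single convolution standing in for the U-Net with two output maps.

```python
import torch
import torch.nn as nn

def iou_score(pred, target, threshold=0.5, eps=1e-7):
    """Intersection over union of two masks after thresholding."""
    pred = (pred > threshold).float()
    target = (target > threshold).float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    # eps avoids division by zero when both masks are empty.
    return ((intersection + eps) / (union + eps)).item()

# Stand-in for the attention model (assumption): 4 input channels, 2 output maps.
model = nn.Conv2d(4, 2, kernel_size=3, padding=1)

criterion = nn.MSELoss()                                    # loss function
metric = iou_score                                          # quality metric
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)   # optimizer
```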
Create a function to train the attention block. The training is standard. It consists of three loops: a loop over epochs, a training loop over batches, and a validation loop over batches.
At each iteration of the dataloader, the data is passed forward through the model to obtain predictions. Then the loss function and metrics are calculated, after which the backward pass (backpropagation of the error) is performed and the weights are updated.
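The three loops can be sketched as follows. The function name `train_attention`, its signature, and the stand-in model and data in the usage part are assumptions made for this illustration; the real function would also compute metrics and save checkpoints.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_attention(model, train_loader, valid_loader, criterion,
                    optimizer, device, n_epochs):
    """Train `model` and return a list of (train_loss, valid_loss) per epoch."""
    history = []
    model.to(device)
    for epoch in range(n_epochs):                  # loop over epochs
        model.train()
        train_loss = 0.0
        for x, y in train_loader:                  # training loop over batches
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            pred = model(x)                        # forward pass
            loss = criterion(pred, y)
            loss.backward()                        # backpropagation
            optimizer.step()                       # weight update
            train_loss += loss.item()

        model.eval()
        valid_loss = 0.0
        with torch.no_grad():                      # no gradients during validation
            for x, y in valid_loader:              # validation loop over batches
                x, y = x.to(device), y.to(device)
                valid_loss += criterion(model(x), y).item()

        history.append((train_loss / len(train_loader),
                        valid_loss / len(valid_loader)))
    return history

# Usage with stand-in data and a stand-in model (assumptions).
data = TensorDataset(torch.rand(8, 4, 32, 32), torch.rand(8, 2, 32, 32))
loader = DataLoader(data, batch_size=4)
model = nn.Conv2d(4, 2, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
history = train_attention(model, loader, loader, nn.MSELoss(),
                          optimizer, torch.device("cpu"), n_epochs=2)
```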
After the training of the attention block is completed, proceed to the main part of the network.
Shadow generation block training. For the shadow generation block, we again take U-Net, this time with a lighter network, resnet18, as the encoder.
Since the shadow generation block takes a shadow-free image and 3 masks as input (the mask of the inserted object, the mask of neighboring objects and the mask of their shadows), we replace the first convolutional layer in the model so that it accepts a 6-channel tensor (3 color channels + 3 single-channel masks).
After the U-Net, we add 4 refinement blocks. Each block is a sequence of BatchNorm2d, ReLU and Conv2d.
Declare a generator class.
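A possible shape for such a class is sketched below. The `backbone` argument stands in for the U-Net with the resnet18 encoder; the number of feature channels, the 3x3 kernel size in the refinement blocks, and the final 1x1 convolution `head` that maps to 3 output channels are all assumptions of this sketch.

```python
import torch
import torch.nn as nn

def refinement_block(channels):
    """One refinement block: BatchNorm2d -> ReLU -> Conv2d."""
    return nn.Sequential(
        nn.BatchNorm2d(channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(channels, channels, kernel_size=3, padding=1),
    )

class Generator(nn.Module):
    def __init__(self, backbone, features=32, n_refinement=4):
        super().__init__()
        self.backbone = backbone  # U-Net in the article, any module here
        self.refinement = nn.Sequential(
            *[refinement_block(features) for _ in range(n_refinement)]
        )
        # Assumption: a 1x1 convolution maps the features to an RGB image.
        self.head = nn.Conv2d(features, 3, kernel_size=1)

    def forward(self, image, masks):
        # 3 color channels + 3 mask channels = 6 input channels.
        x = torch.cat([image, masks], dim=1)
        return self.head(self.refinement(self.backbone(x)))

# Stand-in backbone (assumption): a single convolution instead of a U-Net.
backbone = nn.Conv2d(6, 32, kernel_size=3, padding=1)
generator = Generator(backbone)
output = generator(torch.rand(2, 3, 64, 64), torch.rand(2, 3, 64, 64))
print(output.shape)  # torch.Size([2, 3, 64, 64])
```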
Declare a discriminator class.
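The article's discriminator code is not reproduced here, so the block below is only a generic convolutional discriminator in the patch-based style, shown to illustrate what such a class can look like. The layer widths, kernel sizes, LeakyReLU slope, and the choice of 3 input channels are all assumptions.

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, in_channels=3, base=64):
        super().__init__()
        layers = []
        channels = in_channels
        # Three stride-2 convolutions halve the resolution each time.
        for out_channels in (base, base * 2, base * 4):
            layers += [
                nn.Conv2d(channels, out_channels, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            channels = out_channels
        # Final layer outputs one real/fake logit per image patch.
        layers.append(nn.Conv2d(channels, 1, kernel_size=4, stride=1, padding=1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

discriminator = Discriminator()
logits = discriminator(torch.rand(2, 3, 64, 64))
print(logits.shape)  # torch.Size([2, 1, 7, 7])
```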
Declare generator and discriminator model objects, as well as loss functions and optimizers for the generator and discriminator.
Everything is ready for training. Define a function for training the SG block. Calling it is similar to calling the attention training function.
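A single adversarial training step can be sketched as follows. The tiny stand-in networks, the loss choices (`MSELoss` for L2 and `BCEWithLogitsLoss` for the adversarial term), the `adv_weight` value, and the learning rates are assumptions; the real generator loss contains more terms.

```python
import torch
import torch.nn as nn

# Stand-in models (assumptions), far smaller than the real generator and discriminator.
generator = nn.Conv2d(6, 3, kernel_size=3, padding=1)
discriminator = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(8, 1, kernel_size=4, stride=2, padding=1),
)

l2_loss = nn.MSELoss()
adv_loss = nn.BCEWithLogitsLoss()
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

def train_step(inputs, real, adv_weight=0.01):
    """Run one discriminator update followed by one generator update."""
    fake = generator(inputs)

    # Discriminator: real images are labeled 1, generated images 0.
    d_optimizer.zero_grad()
    real_logits = discriminator(real)
    fake_logits = discriminator(fake.detach())  # detach: do not update the generator here
    d_loss = (adv_loss(real_logits, torch.ones_like(real_logits))
              + adv_loss(fake_logits, torch.zeros_like(fake_logits)))
    d_loss.backward()
    d_optimizer.step()

    # Generator: match the ground truth and try to fool the discriminator.
    g_optimizer.zero_grad()
    fake_logits = discriminator(fake)
    g_loss = (l2_loss(fake, real)
              + adv_weight * adv_loss(fake_logits, torch.ones_like(fake_logits)))
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item(), d_loss.item()

g_value, d_value = train_step(torch.rand(2, 6, 32, 32), torch.rand(2, 3, 32, 32))
```

A full training function would wrap this step in the same epoch and batch loops used for the attention block.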
Visualization of the learning process
Graphs, general information. For training, I used a GTX 1080Ti graphics card on a Hostkey server. During training, I tracked the loss functions with the tensorboard utility and plotted them. The figures below show the training graphs for the training and validation samples.
Training Graphs — Training Samples
The second figure is especially useful because the validation samples are not used in the generator training process, so they give an independent estimate. The training graphs show that the loss reached a plateau at approximately the 200th to 250th epoch. At that point it already seemed possible to stop training the generator, since the loss function was no longer decreasing monotonically.
However, it is useful to look at the training graphs on a logarithmic scale, as it shows the trend of the curve more clearly. The graph of the logarithm of the validation loss shows that the 200th to 250th epoch was too early to stop training. It could have been stopped later, at around the 400th epoch.
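The effect of the logarithmic scale can be illustrated with a few numbers. The loss values below are synthetic and chosen only for the illustration; they are not taken from the training graphs in this article.

```python
import math

# Synthetic validation losses by epoch (made-up numbers, not the author's curve).
losses = {1: 1.0, 100: 0.1, 250: 0.02, 400: 0.01}

# Share of the vertical axis taken by the drop between epochs 250 and 400.
linear_span = losses[1] - losses[400]
late_drop_linear = (losses[250] - losses[400]) / linear_span

log_span = math.log10(losses[1]) - math.log10(losses[400])
late_drop_log = (math.log10(losses[250]) - math.log10(losses[400])) / log_span

print(f"linear scale: {late_drop_linear:.1%} of the axis")
print(f"log scale:    {late_drop_log:.1%} of the axis")
```

With these numbers, the loss halving between epochs 250 and 400 takes up about 1% of a linear axis but about 15% of a logarithmic one, which is why a curve that looks flat on a linear plot can still show visible progress on a logarithmic plot.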
Training Graphs — Validation Samples
For clarity of the experiment, the predicted picture was periodically saved (see the gif of the visualization of the learning process above).
Some difficulties. During the training process, we had to solve one seemingly simple problem: incorrect weighting of the loss functions.
Since our final loss function consists of the weighted sum of the other loss functions, the contribution of each of them to the total must be adjusted separately by setting the coefficients for them. The best option is to take the coefficients suggested in the original article.
If the loss functions are balanced incorrectly, the results can be unsatisfactory. For example, if the contribution of L2 is set too high, the training of the neural network can even come to a standstill. L2 converges quickly, but it is undesirable to remove it from the total completely: without it, the output shadow will be less realistic and less consistent in color and transparency.
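The weighted sum itself can be sketched in a few lines. The term names, loss values, and coefficients below are hypothetical placeholders, not the coefficients from the original article; the point is only to show how the share of each term in the total can be inspected.

```python
def total_loss(losses, weights):
    """Weighted sum of the individual loss terms."""
    return sum(weights[name] * value for name, value in losses.items())

def contributions(losses, weights):
    """Share of each weighted term in the total, useful for checking the balance."""
    total = total_loss(losses, weights)
    return {name: weights[name] * value / total for name, value in losses.items()}

# Hypothetical values for one training step (assumptions).
losses = {"l2": 0.2, "perceptual": 1.5, "adversarial": 0.7}
weights = {"l2": 10.0, "perceptual": 1.0, "adversarial": 0.01}

total = total_loss(losses, weights)
shares = contributions(losses, weights)
for name, share in shares.items():
    print(f"{name}: {share:.1%} of the total")
```

The same functions work with PyTorch tensors in place of floats, since they rely only on multiplication and addition. Logging the shares during training makes it easier to notice when one term, such as L2, dominates the total.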
An example of a generated shadow in the absence of an L2-loss contribution
The picture shows the ground truth image on the left and the generated image on the right.
Inference. For prediction and testing, we will combine the attention and SG models into one ARShadowGAN class.
The inference code is below.
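A minimal sketch of how the two models can be chained at inference time is shown below. The single-convolution stand-ins, the assumption that the attention block returns two maps (neighboring objects and their shadows), and the order in which the channels are concatenated are all assumptions of this sketch.

```python
import torch
import torch.nn as nn

class ARShadowGAN(nn.Module):
    def __init__(self, attention, generator):
        super().__init__()
        self.attention = attention
        self.generator = generator

    def forward(self, image, object_mask):
        # Attention block: image (3 channels) + object mask (1 channel) -> 2 maps.
        attention_maps = self.attention(torch.cat([image, object_mask], dim=1))
        # SG block: image (3) + object mask (1) + attention maps (2) = 6 channels.
        sg_input = torch.cat([image, object_mask, attention_maps], dim=1)
        return self.generator(sg_input)

# Stand-ins for the trained models (assumptions); real weights would be loaded here.
attention = nn.Conv2d(4, 2, kernel_size=3, padding=1)
generator = nn.Conv2d(6, 3, kernel_size=3, padding=1)

model = ARShadowGAN(attention, generator)
model.eval()  # switch layers such as BatchNorm to inference mode

image = torch.rand(1, 3, 64, 64)
object_mask = torch.rand(1, 1, 64, 64)
with torch.no_grad():  # gradients are not needed for prediction
    result = model(image, object_mask)
print(result.shape)  # torch.Size([1, 3, 64, 64])
```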
This article discussed a generative adversarial network using the example of one ambitious and difficult task at the intersection of Augmented Reality and Computer Vision. Overall, the resulting model can generate shadows, although they are not always perfect.
Note that a GAN is not the only way to generate shadows. There are other approaches that use, for example, 3D object reconstruction techniques, differentiable rendering, etc.
P.S. I would be happy to answer any questions you may have and to receive your feedback. Thank you for your attention!