본문 바로가기
AI - Deep Learning/Super_Resolution

[논문 구현] SRGAN(by Pytorch)을 활용한 Super Resolution Code

by 모두의 케빈 2023. 6. 27.

이번 시간에는 SRGAN에 대한 논문 리뷰 내용을 토대로 Pytorch를 활용하여 직접 코드로 구현해 보는 시간을 갖도록 하겠습니다. 혹시 SRGAN 논문에 대해 잘 모르시는 분들께서는 아래 링크를 먼저 정독하고 오시면 코드 해석에 도움이 될 것 같습니다. 

[논문 리뷰] SRGAN 논문 완벽 정리: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

 

[논문 리뷰] SRGAN 논문 완벽 정리: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

논문 제목: Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network(2017) 논문 링크: https://arxiv.org/pdf/1609.04802.pdf 목차 #1 Abstract & Introduction > 기존 방법들에 대한 단점이 보이기 시작하다. >

kevinitcoding.tistory.com

 

 

 

 

#1 코드 구현 조건


논문에서 제시한 내용을 토대로 코드를 구현하다.

 

[1] 논문에서는 ImageNet 데이터셋 약 35만 장을 학습에 활용했지만, 저는 용량 문제로 VOC2012 데이터셋을 활용하도록 하겠습니다. VOC2012 데이터셋은 Classification, Detection, Segmentation, Action Classification 등의 Task에 활용되는 데이터 셋입니다. 그중에서 Classification을 위한 약 17,000장의 이미지를 학습에 사용했습니다.

[2] SRGAN의 성능은 논문과 동일하게 Set5, Set14 등을 활용했습니다. 성능 평가는 공평하게 진행하기 위해 Center Crop 한 이미지로 평가합니다.

[3] 모델은 4x 업스케일 버전(Upscale factor r =4)으로만 작성했습니다.

[4]  논문에서는 저해상도 이미지를 [0,1] 사이로, 고해상도 이미지는 [-1,1] 사이로 Normalize 했지만 저는 편의성을 위해 그냥 [0,1] 사이로 Normalize 했습니다.

[5] 논문에서는 Generator의 가중치를 SRResNet의 가중치로 초기화했지만, SOTA를 달성하기 위한 것이 아니라 스터디 용이기 때문에 이 과정은 생략했습니다.

[6] SRGAN의 학습 과정에서 학습률(Learning rate)은 10만 번의 iteration 동안 10^-4, 이후 10만 번 동안은 10^-5로 바뀝니다. 그러나 저는 그렇게 까지 오래 학습을 시키지 않아서 학습률은 10^-4로 고정했습니다. 하지만 이 글을 읽으시는 분들께 도움이 될 것 같아서 학습률 스케쥴러(Learning rate Scheduler)를 활용한 코드를 작성해 두었습니다.

[7] 논문 기준, 성능이 가장 좋은 SRGAN-VGG54를 구현했습니다.(VGG Net의 마지막 Pooling Layer 바로 직전까지의 Feature Map을 활용) 또한 논문과 다르게 TVLoss를 Total Loss에 추가했습니다. TV Loss는 생성자(Generator)가 조금 더 부드럽고 자연스러운 이미지를 만들 수 있도록 학습에 도움을 주는 Loss 함수입니다.

[8] 논문에서는 판별자(Discriminator)의 마지막 Layer에 Dense Layer를 활용했지만, 저는 Global Average Pooling(GAP)과 Convolutiojn Layer를 활용하여 그 기능을 대체했습니다. GAP와 Conv Layer를 함께 사용하면 Dense Layer를 사용하는 것보다 연산량이 감소하는 장점이 있습니다.

 

 

#2 코드 설명


코드는 Google Colab을 기반으로 작성했습니다.

Google Drive Mount

Colab에 Drive를 연동합니다.

import os
from google.colab import drive
drive.mount('/content/drive')

 

VOC2012 데이터셋 다운로드

VOC2012 데이터셋 다운로드 코드입니다. 다른 데이터셋을 활용하거나 다른 방식으로 다운로드해도 상관없습니다.

# VOC2012 데이터셋 다운로드
!wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
!tar -xvf VOCtrainval_11-May-2012.tar -C /content/drive/MyDrive/Blog/초해상도/SRGAN

 

라이브러리 호출

필요한 라이브러리를 불러옵니다. pytorch_ssim는 SR 분야의 Metric인 SSIM을 계산해 주는 라이브러리인데, 저는 따로 사용하지는 않았습니다. 필요하신 분은 '!pip install pytorch_ssim'로 설치 후 사용하시면 됩니다.

from os import listdir
from os.path import join
import random
import matplotlib.pyplot as plt
import math
from math import log10
import time

from PIL import Image
import torch
from torch import nn
import torch.optim as optim
from torchvision.models.vgg import vgg16
from torch.utils.data.dataset import Dataset
from torch.utils.data import random_split, DataLoader
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize

import cv2
from tqdm import tqdm
#import pytorch_ssim

 

파라미터 선언

참고로 VOC2012 데이터 셋 안에는 가로나 세로 크기가 96 이하인 이미지가 2장이 있습니다. 해당 이미지는 사전에 제거해 주셔야, Crop_size를 96으로 해도 에러가 발생하지 않으니 참고해 주세요.

crop_size = 96

if (crop_size % 4) != 0:
  crop_size = crop_size - (crop_size%4)

upscale_factor = 4
epochs = 250
batch_size = 64
dataset_dir = '/content/drive/MyDrive/Blog/초해상도/SRGAN/VOCdevkit/VOC2012/JPEGImages'
val_dataset_dir = '/content/drive/MyDrive/Blog/초해상도/SRGAN/VOCdevkit/VOC2012/Val_Images'

save_path = '/content/drive/MyDrive/Blog/초해상도/SRGAN'

print("train_dataset: {}개, val_dataset: {}개". format(len(listdir(dataset_dir)), 
                                          len(listdir(val_dataset_dir))))
train_dataset: 16823개, val_dataset: 300개

 

Custom Dataset에 필요한 함수 선언

is_image_file 함수는 특정 파일이 이미지 파일인지 검사하여 참이면 True를, 거짓이면 False를 Return 합니다.

hr_transform 함수는 PIL Image 파일을 입력받아서 Random으로 Crop 후, Torch Tensor로 변경해 주는 함수입니다.

lr_transform 함수는 hr_transform에서 처리된 Troch Tensor를 입력값으로 받습니다. 이를 PIL Image 파일로 변경하고 4배로 Scale Down(upscale_factor = 4)한 다음 Torch Tensor로 다시 변경해 주는 함수입니다.

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])

def hr_transform(crop_size):
  return Compose([RandomCrop(crop_size), ToTensor()])

def lr_transform(crop_size, upscale_factor):
  return Compose([ToPILImage(), Resize(crop_size//upscale_factor, interpolation = Image.BICUBIC), ToTensor()])

 

Custom Dataset 선언

TrainDataset은 위에서 정의한 hr_transform 한 hr_image와 lr_transform 적용한 lr_image를 각각 Return 합니다.

ValDataset은 성능 평가를 위해 Center Crop 한 hr_image, Scale Down 한 lr_image, bicubic upscale 한 이미지를 return 합니다.

최종 Test는 따로 진행하지 않기 때문에, TestDataset은 생략했습니다.

class TrainDataset(Dataset):
  def __init__(self, dataset_dir, crop_size, upscale_factor):
    super(TrainDataset, self).__init__()

    self.image_filenames = [join(dataset_dir,x) for x in listdir(dataset_dir) if is_image_file(x)]
    self.hr_transform = hr_transform(crop_size)
    self.lr_transform = lr_transform(crop_size, upscale_factor)

  def __getitem__(self, index):
    hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
    lr_image = self.lr_transform(hr_image)
    return lr_image, hr_image

  def __len__(self):
    return len(self.image_filenames)

class ValDataset(Dataset):
  def __init__(self, dataset_dir, crop_size, upscale_factor):
    super(ValDataset, self).__init__()

    self.image_filenames = [join(dataset_dir,x) for x in listdir(dataset_dir) if is_image_file(x)]
    self.upscale_factor = upscale_factor

  def __getitem__(self, index):
    hr_image = Image.open(self.image_filenames[index])
    hr_image = CenterCrop(crop_size)(hr_image)

    lr_image = Resize(crop_size // self.upscale_factor, interpolation = Image.BICUBIC)(hr_image)
    bicubic_hr_image = Resize(crop_size, interpolation = Image.BICUBIC)(lr_image)

    return ToTensor()(lr_image), ToTensor()(bicubic_hr_image), ToTensor()(hr_image)

  def __len__(self):
    return len(self.image_filenames)

class TestDataset(Dataset):
  pass

 

Dataset 확인

데이터셋이 잘 구성되었는지 확인합니다.

train_dataset = TrainDataset(dataset_dir, crop_size, upscale_factor)
val_dataset = ValDataset(val_dataset_dir, crop_size, upscale_factor)

hr, lr = train_dataset[4]

fig,axes = plt.subplots(1,2,figsize = (10,10))

axes[0].imshow(hr.permute(1,2,0))
axes[1].imshow(lr.permute(1,2,0))

 

저해상도 이미지(왼쪽), 고해상도 이미지(오른쪽)

 

데이터셋이 잘 구성되었으니, 이제 DataLoader를 선언합니다.

train_dataloader = DataLoader(train_dataset, num_workers = 2,  batch_size = batch_size, shuffle = True)
val_dataloader = DataLoader(val_dataset,num_workers =2, batch_size = 1, shuffle = False)

 

Loss Function 정의

SR 이미지를 위한 Generator의 Loss 함수를 정의합니다. 참고로 Discriminator는 일반 GAN과 동일하므로, 나중에 Train Loop에서 별도의 정의 없이 계산하여 사용했습니다.

TV Loss는 자연스러운 이미지를 생성할 수 있도록 도움을 주는 함수입니다. 자세한 내용은 본문 위쪽의 조건 또는 본문 아래 설명 링크를 참고해 주세요.

Total Loss에서 Content Loss와 Perceptual Loss에 곱해지는 가중치는 논문을 참고했고 TV Loss의 가중치는 제가 참고한 코드와 동일하게 설정해 두었습니다.

class Generator_Loss(nn.Module):
  def __init__(self):
    super(Generator_Loss, self).__init__()
    vgg = vgg16(pretrained=True)
    loss_net = nn.Sequential(*list(vgg.features[:30])).eval()

    for params in loss_net.parameters():
      params.requires_grad = False

    self.loss_net = loss_net
    self.mse = nn.MSELoss()
    self.tv_loss = TVLoss()

  def forward(self,  netD_out, fake_img, real_img):
    # Content Loss: Feature Map간 MSE 연산
    content_loss = self.mse(self.loss_net(fake_img), self.loss_net(real_img))
    # Adversarial Loss: Generator Loss, 학습을 위해 Log를 사용하지 않음
    adversarial_loss = torch.mean(1-netD_out)
    # Perceptual Loss: Content와 Adversarial의 가중치 합
    perceptual_loss = content_loss * 0.006 + adversarial_loss * 0.001
    # Total Loss: TV Loss 추가, 이미지가 사실적이게 보이는 효과
    tv_loss = self.tv_loss(fake_img)

    return perceptual_loss + tv_loss * 2e-8


class TVLoss(nn.Module):
  def __init__(self, tv_loss_weight=1):
      super(TVLoss, self).__init__()
      self.tv_loss_weight = tv_loss_weight

  def forward(self, x):
      batch_size = x.size()[0]
      h_x = x.size()[2]
      w_x = x.size()[3]
      count_h = self.tensor_size(x[:, :, 1:, :])
      count_w = self.tensor_size(x[:, :, :, 1:])
      h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
      w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
      return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

  @staticmethod
  def tensor_size(t):
      return t.size()[1] * t.size()[2] * t.size()[3]

 

모델 선언: Generator

class Generator(nn.Module):
    def __init__(self, scale_factor):
        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        block8 = self.block8(block1 + block7)

        return (torch.tanh(block8) + 1) / 2


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        return x + residual


class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x

 

모델 선언: Discriminator

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        batch_size = x.size(0)
        return torch.sigmoid(self.net(x).view(batch_size))

 

GPU 가속

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

netG = Generator(4)
netD = Discriminator()

Generator_loss = Generator_Loss()

netG.to(device)
netD.to(device)
Generator_loss.to(device)

optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

 

Train Loop: SRGAN-VGG54 학습

for epoch in range(1,epochs+1):
  print("Start {} epochs ... ".format(epoch+1))
  train_bar = tqdm(train_dataloader)

  netG.train()
  netD.train()

  for lr_img, hr_img in train_bar:
    lr_img = lr_img.to(device)
    hr_img = hr_img.to(device)

    # ===================================
    #  Train Discriminator : Maximize D(x) - 1 - D(G(z))
    # ===================================

    fake_img = netG(lr_img)

    netD.zero_grad()
    loss = netD(hr_img).mean() - 1 + netD(fake_img).mean() # netD의 결과값은 배치크기 만큼의 확률값 벡터
    loss.backward(retain_graph = True)
    optimizerD.step()

    # ===================================
    #  Train Generator : Mimize 1 - D(G(z)) + Content Loss + TV Loss
    # ===================================

    netG.zero_grad()
    ## 아래 두 줄은 구글 코랩 런타임 에러 방지를 위해 추가
    fake_img = netG(lr_img)
    netD_output = netD(fake_img).mean()
    ##
    loss = Generator_loss(netD_output, fake_img, hr_img)
    loss.backward()

    fake_img = netG(lr_img)
    netD_output = netD(fake_img).mean()

    optimizerG.step()

  # ===================================
  #  Epoch 당 ValDataset 검증
  # ===================================
  netG.eval()

  with torch.no_grad():
    val_bar = tqdm(val_dataloader)
    valing_results = {'mse': 0, 'psnr': 0, 'batch_sizes': 0}

    for val_lr_img, val_bicubic_img, val_hr_img in val_bar:
      batch_size = val_lr_img.size(0)
      valing_results['batch_sizes'] += batch_size

      val_lr_img= val_lr_img.to(device)
      val_hr_img = val_hr_img.to(device)
      val_bicubic_img = val_bicubic_img.to(device)

      val_sr_img = netG(val_lr_img)

      batch_mse = ((val_sr_img - val_hr_img)**2).data.mean()
      valing_results['mse'] += (batch_mse * batch_size) # 누적 mse

      valing_results['psnr'] = 10 * log10((val_hr_img.max()**2) / (valing_results['mse'] / valing_results['batch_sizes']))

      print("Val_Score >> PSNR: {}".format(valing_results['psnr']))


torch.save(netG.state_dict(), save_path + '/' + str(epoch) + 'epochs')
torch.save(netD.state_dict(), save_path + '/' + str(epoch) + 'epochs')

 

SRGAN-VGG54 성능 평가

Colab에 오류가 있는지 학습 도중 종료된 다음 Chrome이 아이에 실행이 안되네요. 우선 아래 모델 성능 평가를 위해 저해상도 이미지를 Generator에 넣고 SR 이미지를 얻는 코드를 넣어 두었습니다. 이 글을 읽으시는 분들께서는 성공하셨으면 좋겠네요. :)

test_path = '/content/drive/MyDrive/Blog/초해상도/SRCNN/set5_set14/Set14'

model = netG.eval()

image = Image.open(test_path + '/' + 'baboon.png') # baboon.png, comic.png, monarch.png
print(ToTensor()(image).size())
image = Resize(96, interpolation = Image.BICUBIC)(image)
image = ToTensor()(image).unsqueeze(0).to(device)
print(image.size())

start = time.time()
sr_img = netG(image)
end = (time.time() - start)
print(sr_img.size())
sr_img = ToPILImage()(sr_img[0].data.cpu())

# out_img.save('out_srf_' + str(UPSCALE_FACTOR) + '_' + IMAGE_NAME)

 

 

[참고]

#1 SRGAN Pytorch 코드 (from GitHub)

https://github.com/leftthomas/SRGAN

 

GitHub - leftthomas/SRGAN: A PyTorch implementation of SRGAN based on CVPR 2017 paper "Photo-Realistic Single Image Super-Resolu

A PyTorch implementation of SRGAN based on CVPR 2017 paper "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network" - GitHub - leftthomas/SRGAN: A PyTorch im...

github.com

 

#2 TV Loss

https://curaai00.tistory.com/category/%EB%94%A5%EB%9F%AC%EB%8B%9D/GAN

 

'딥러닝/GAN' 카테고리의 글 목록

 

curaai00.tistory.com

 

댓글