본문 바로가기
AI - Deep Learning/Super_Resolution

[논문 구현] SRCNN(by Pytorch)을 활용한 Super Resolution

by 모두의 케빈 2023. 6. 14.

앞선 시간에는 SRCNN 논문 리뷰와 Keras를 활용하여 SRCNN을 구현해 봤습니다. 이번 시간에는 Pytorch를 이용하여 SRCNN을 구현해 보도록 하겠습니다. 코드는 논문 리뷰에 근거해서 작성해서 아직 논문을 읽어보시지 못하셨다면 먼저 읽고 오시는 것을 권장드립니다. :)

 

[논문리뷰] SRCNN: Image Super-Resolution Using Deep Convolutional Networks

 

[논문리뷰] SRCNN: Image Super-Resolution Using Deep Convolutional Networks

목차 원본 논문: https://arxiv.org/pdf/1501.00092.pdf #1,2 Introduction and Related Work > SRCNN의 장점 > SRCNN의 업적 #3 CNN for Super Resolution > SRCNN Layer 별 역할 > 기존 방식과 SRCNN의 비교 분석 #4 Experiments: > [1] 데이터

kevinitcoding.tistory.com

 

[논문 구현] SRCNN(by Keras)을 활용한 Super Resolution

 

[논문 구현] SRCNN(by Keras)을 활용한 Super Resolution

저번 시간에는 SRCNN 논문을 리뷰해 봤습니다. 이번 딥러닝 프레임 워크 중, Keras를 활용하여 SRCNN을 구현해 보도록 하겠습니다. 이 코드는 앞에서 제가 작성한 논문 리뷰 내용을 기반으로 작성했

kevinitcoding.tistory.com

 

 


목차

모델의 구현 조건

SRCNN Code 구현(Pytorch)

    > 필요한 라이브러리 선언

    > 파라미터 선언

    > 구글 드라이브 연동 및 Device 설정

    > 저해상도 이미지 만들기

    > 원본 HR 및 (저해상도, 고해상도) Paired 이미지 패치 확인

    > Model 선언

    > Train & Test Loop

    > Model Training

    > Model 성능 Test

    > 이미지 비교(원본 고해상도 vs Bicubic vs SRCNN)

참고 문헌


 

 

모델의 구현 조건 (코드는 Google Colab 기반으로 작성했습니다.)

구현 조건은 Keras 코드와 동일합니다.

1. 논문에서는 Color 이미지를 YCrCb 채널로 변환 없이 RGB 채널을 그대로 SR해도 성능이 좋다고 되어 있습니다. 해당 결과를 참고하여 RGB Channel을 그대로 입력받는 모델을 구현합니다.

2. 논문의 Filter 개수에 대한 실험 결과를 참고하여 n1 = 128, n2 = 64으로 설정합니다. (n은 각 Conv Layer의 Filter 개수) 최종 출력은 고해상도 3차원 이미지가 되어야 하므로 n3 = 3입니다.

3. 논문의 Filter Size에 대한 실험 결과를 참고하여 f1 = 9, f2 = 3, f3= 5로 설정합니다. (f는 각 Conv Layer의 Filter 크기)

4. 첫 번째, 두 번째 Convolution Layer의 활성함수는 ReLU, 마지막 Layer의 활성 함수는 Linear로 설정합니다.

5. 논문 조건과 유사하게 Weights는 Xaiver 정규 분포로, Bias는 Zero Vector(0)으로 초기화합니다.

6. 학습률(Learning Rate)은 구현 편의성을 고려하여 0.003으로 일괄 통일합니다.

7. SRCNN에 사용할 저해상도 이미지는 논문의 방법과 동일하게 고해상도 이미지를 Random Crop 후 Bicubic 보간법에 의해 Upscaling 된 이미지들을 활용합니다.

8. 고해상도 이미지는 논문과 동일하게 91 Images를 활용했으며, Kaggle에서 직접 다운로드하였습니다. 검증 데이터도 논문과 동일하게 set5를 사용했으며 마찬가지로 Kaggel에서 직접 다운로드했습니다. 데이터는 아래 링크 참고해 주세요.

https://www.kaggle.com/datasets/ll01dm/t91-image-dataset

 

T91 Image Dataset

Super Resolution dataset

www.kaggle.com

https://www.kaggle.com/datasets/ll01dm/set-5-14-super-resolution-dataset

 

Set 5 & 14 Super Resolution Dataset

Dataset for evaluating Super Resolution networks

www.kaggle.com

 

 

SRCNN Code 구현 (Pytorch)

 

필요한 라이브러리 선언

import torch
from torch import nn
import torch.optim as optim
from torch.utils.data.dataloader import Dataset, DataLoader

import cv2
import os
import numpy as np
import glob
import random
import matplotlib.pyplot as plt

 

 

파라미터 선언

path는 학습에 사용할 91 Images 데이터셋을 저장한 위치입니다. save_path는 훈련이 끝난 모델의 가중치를 저장하는 경로입니다. 저 같은 경우에는 Google Drive & Colab을 활용해서 경로를 아래와 같이 지정했는데요.

각자의 환경에 맞춰 path와 save_path는 변경해서 사용하시면 됩니다. 참고로 Local PC에서 사용하시는 분들은 파일 경로를 표시할 때 '/'가 아니라 '역슬래시(\)'를 1개 또는 2개 사용해야 할 수도 있습니다.

n1,n2,n3 = 128, 64, 3
f1,f2,f3 = 9,3,5
upscale_factor = 3
#To avoid border effects during training, all the convolutional layers have no padding, and the network
#produces a smaller output ((fsub − f1 − f2 − f3 + 3)2 × c).

input_size = 33
output_size = input_size - f1 - f2 - f3 + 3

stride = 14

batch_size = 128
epochs = 200

path = "/content/drive/MyDrive/Blog/초해상도/SRCNN/T91"
save_path = "/content/drive/MyDrive/Blog/초해상도/SRCNN/torch_SRCNN_200EPOCHS.h5"

 

 

구글 드라이브 연동

Colab에 Google Drive를 연동하는 코드입니다. Colab을 사용하지 않으시는 분들은 아래 코드는 무시하셔도 됩니다.

from google.colab import drive
drive.mount('/content/drive')

 

Device 설정

Device를 설정합니다. GPU 사용이 가능하면 GPU, 불가능하면 CPU에서 학습이 진행되도록 설정합니다.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

 

저해상도 이미지 만들기

SRCNN은 저해상도와 고해상도 이미지 패치(Crop 된 Sub Images) 사이의 관계를 학습합니다. 따라서 학습 데이터를 만들려면 (1) 고해상도 이미지를 저해상도 이미지로 만드는 작업과 (2) 각각의 이미지를 Crop 하여 맵핑하는 작업이 필요합니다.

학습을 위해 CustomDateset을 설정합니다.

"Zoom_img" 부분은 고해상도 이미지를 저해상도 이미지로 만드는 코드입니다. 고해상도 이미지 label을 upscale_factor만큼(여기서는 3) 줄이고, 다시 Bicubic 보간법으로 크기를 키워주면 한 장의 저해상도 이미지가 완성됩니다. 일상생활에서 이미지를 확대하면 픽셀이 깨지면서 저해상도 이미지가 되는 원리와 유사하다고 생각하시면 됩니다.

"Crop: img to sub_imgs" 부분은 저해상도 이미지와 고해상도 이미지를 각각 Crop 하여 Input & Label 관계로 맵핑하는 코드입니다. SRCNN은 Padding을 사용하지 않기 때문에 입력 이미지의 크기보다 출력 이미지의 크기가 작습니다.

따라서 저해상도와 고해상도 이미지를 맵핑한다는 말은 저해상도 이미지를 축소 복원하여 고해상도 이미지를 만들도록 모델을 설계한다는 의미입니다.

아래 코드에서 Input_size는 저해상도 이미지를 Crop 하는 크기를 말합니다. 논문과 동일하게 33으로 설정해 줍니다. 이렇게 되면 Crop 된 저해상도 이미지 패치들은 33 by 33의 크기를 갖게 됩니다.

Output_size는 저해상도 이미지 패치가 SRCNN을 통과하여 최종 출력된 결과 이미지의 크기를 의미합니다. 논문의 공식에 의해 "33-9-5-3+3"에 의해 19로 설정합니다.

Stride는 논문과 동일하게 14로 설정했습니다. 결과적으로 1장의 저해상도 이미지는 14씩 건너뛰면서 33 by 33의 정사각형 저해상도 Sub Images로 Crop 됩니다. 1장의 고해상도 이미지는 14씩 건너뛰면서 19 by 19의 정사각형 저해상도 Sub Images로 Crop 됩니다.  

class CustomDataset(Dataset):
  def __init__(self, img_paths, input_size, output_size, stride = 14, upscale_factor = 3):
    super(CustomDataset, self).__init__()
    
    self.img_paths = glob.glob(img_paths + '/' + '*.png')
    self.stride = stride
    self.upscale_factor = upscale_factor
    self.sub_lr_imgs = []
    self.sub_hr_imgs = []
    self.input_size = input_size
    self.output_size = output_size
    self.pad = abs(self.input_size - self.output_size) // 2 # 7

    print("Start {} Images Pre-Processing".format(len(self.img_paths)))
    for img_path in self.img_paths:
      img = cv2.imread(img_path, cv2.COLOR_BGR2RGB)
      #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  
      # mod_crop
      h = img.shape[0] - np.mod(img.shape[0], self.upscale_factor)
      w = img.shape[1] - np.mod(img.shape[1], self.upscale_factor)
      img = img[:h, :w, :]

      # zoom_img
      label = img.astype(np.float32) / 255.0
      temp_input = cv2.resize(label, dsize=(0,0), fx = 1/self.upscale_factor, fy = 1/self.upscale_factor,
                              interpolation = cv2.INTER_AREA)
      input = cv2.resize(temp_input, dsize=(0,0), fx = self.upscale_factor, fy = self.upscale_factor,
                        interpolation = cv2.INTER_CUBIC)
  
      # Crop: img to sub_imgs
      for h in range(0, input.shape[0] - self.input_size + 1, self.stride):
        for w in range(0, input.shape[1] - self.input_size + 1, self.stride):
          sub_lr_img = input[h:h+self.input_size, w:w+self.input_size, :]
          sub_hr_img = label[h+self.pad:h+self.pad+self.output_size, w+self.pad:w+self.pad+self.output_size, :]

          sub_lr_img = sub_lr_img.transpose((2,0,1))
          sub_hr_img = sub_hr_img.transpose((2,0,1))
          #sub_lr_img = sub_lr_img.reshape(3, self.input_size, self.input_size)
          #sub_hr_img = sub_hr_img.reshape(3, self.output_size, self.output_size)

          self.sub_lr_imgs.append(sub_lr_img)
          self.sub_hr_imgs.append(sub_hr_img)
    print("Finish, Created {} Sub-Images".format(len(self.sub_lr_imgs)))
    self.sub_lr_imgs = np.asarray(self.sub_lr_imgs)
    self.sub_hr_imgs = np.asarray(self.sub_hr_imgs)


  def __len__(self):
    return len(self.sub_lr_imgs)

  def __getitem__(self, idx):
    lr_img = self.sub_lr_imgs[idx]
    hr_img = self.sub_hr_imgs[idx]
    return lr_img, hr_img

 

 

원본 HR 이미지 확인

img = cv2.imread(train_dataset.img_paths[12])
print(img.shape)
plt.imshow(img)

 

HR 이미지(from 91 Images 데이터셋)

 

 

(저해상도, 고해상도) Paired 이미지 패치 확인

fig, axes = plt.subplots(1,2, figsize = (5,5))
idx = random.randint(0, len(train_dataset.sub_lr_imgs))

axes[0].imshow(train_dataset.sub_lr_imgs[idx].transpose(1,2,0))
axes[1].imshow(train_dataset.sub_hr_imgs[idx].transpose(1,2,0))

print(idx)
axes[0].set_title('lr_img')
axes[1].set_title('hr_img')

LR vs HR Sub Image

 

 

Model 선언

class SRCNN(nn.Module):
  def __init__(self, kernel_list, filters_list, num_channels = 3):
    super(SRCNN, self).__init__()

    f1,f2,f3 = kernel_list
    n1,n2,n3 = filters_list
    
    self.conv1 = nn.Conv2d(num_channels, n1, kernel_size = f1)
    self.conv2 = nn.Conv2d(n1, n2, kernel_size = f2)
    self.conv3 = nn.Conv2d(n2, num_channels, kernel_size = f3)
    self.relu = nn.ReLU(inplace = True)

    torch.nn.init.xavier_normal_(self.conv1.weight)
    torch.nn.init.xavier_normal_(self.conv2.weight)
    torch.nn.init.xavier_normal_(self.conv3.weight)

    torch.nn.init.zeros_(self.conv1.bias)
    torch.nn.init.zeros_(self.conv2.bias)
    torch.nn.init.zeros_(self.conv3.bias)

  def forward(self, x):
    x = self.relu(self.conv1(x))
    x = self.relu(self.conv2(x))
    x = self.conv3(x)
    return x

 

Train & Test loop 

def train(dataloader, model, loss_fn, optimizer):
  size = len(dataloader.dataset)

  for batch, (X,y) in enumerate(dataloader):
    X = X.to(device)
    y = y.to(device)

    pred = model(X)
    loss = loss_fn(pred, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if batch % 100 == 0:
      loss, current = loss.item(), (batch + 1) * len(X)
      print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
  size = len(dataloader.dataset)
  num_batches = len(dataloader)
  test_loss = 0

  with torch.no_grad():
      for batch, (X,y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)

        pred = model(X)
        test_loss += loss_fn(pred, y)
  test_loss /= num_batches
  print(f"Avg loss: {test_loss:>8f} \n")

 

Model Compile

model = SRCNN(kernel_list = [f1,f2,f3], filters_list = [n1,n2,n3]).to(device)
print(model)

params = model.parameters()

optimizer = optim.Adam(params = params, lr=1e-3)
loss_fn = nn.MSELoss()

 

Model Training

for i in range(epochs):
  print("{} Epochs ... ".format(i+1))
  model = model.train()
  train(train_dataloader, model, loss_fn, optimizer)
print("Done!")

torch.save(model.state_dict, save_path)

 

 

Model 성능 Test

Model 성능은 논문과 동일하게 Set5에서 'Butterfly' 이미지를 사용하여 확인해 보도록 하겠습니다.

hr_img_path = '/content/drive/MyDrive/Blog/초해상도/SRCNN/set5_set14/Set5/butterfly.png'

hr_img = cv2.imread(hr_img_path)
hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)
print("img shape: {}".format(hr_img.shape))

plt.imshow(hr_img)

hr_img = hr_img.astype(np.float32) / 255.0
temp_img = cv2.resize(hr_img, dsize=(0,0), fx = 1/upscale_factor, fy = 1/upscale_factor,interpolation = cv2.INTER_AREA)
bicubic_img = cv2.resize(temp_img, dsize=(0,0), fx = upscale_factor, fy = upscale_factor, interpolation = cv2.INTER_CUBIC)


model.eval()
input_img = bicubic_img.transpose((2,0,1))
input_img = torch.tensor(input_img).unsqueeze(0).to(device)

with torch.no_grad():
  srcnn_img = model(input_img)

srcnn_img = srcnn_img.squeeze().cpu().numpy().transpose((1,2,0))
#srcnn_img = cv2.cvtColor(srcnn_img, cv2.COLOR_BGR2RGB)

 

 

원본 vs Bicubic vs SRCNN

원본 이미지와 Bicubic 이미지, SRCNN 출력 이미지를 한 번에 비교해 보겠습니다.

fig, axes = plt.subplots(1,3, figsize = (10,5))

axes[0].imshow(hr_img)
axes[1].imshow(bicubic_img)
axes[2].imshow(np.squeeze(srcnn_img))

axes[0].set_title('hr_img')
axes[1].set_title('bicubic_img')
axes[2].set_title('srcnn_img')

 

여러분들 눈에는 어떠신가요? 논문에서 제안한 대로 RGB 채널 이미지를 그대로 활용하더라도 충분한 성능이 나오는 것으로 보입니다. 물론 아직은 원본 고해상도 이미지에 비하면 Detail 한 부분이 떨어지지만, 단순한 모델치고는 성능이 만족스럽습니다. 

성능 비교: 원본 고해상도 이미지 vs Bicubic vs SRCNN

 

참고로 저는 Study용으로 작성해서 모델 성능 평가 과정에서 PSNR 계산을 따로 하지는 않았는데요. 필요하신 분들은 아래 코드 참고하시면 도움이 될 것 같습니다.

def PSNR(y_pred, y_ture):
    return 10. * torch.log10(1. / torch.mean((y_pred - y_ture) ** 2))

 

 

참고 문헌

https://github.com/yjn870/SRCNN-pytorch

 

GitHub - yjn870/SRCNN-pytorch: PyTorch implementation of Image Super-Resolution Using Deep Convolutional Networks (ECCV 2014)

PyTorch implementation of Image Super-Resolution Using Deep Convolutional Networks (ECCV 2014) - GitHub - yjn870/SRCNN-pytorch: PyTorch implementation of Image Super-Resolution Using Deep Convoluti...

github.com

https://github.com/MarkPrecursor/SRCNN-keras

 

GitHub - MarkPrecursor/SRCNN-keras

Contribute to MarkPrecursor/SRCNN-keras development by creating an account on GitHub.

github.com

https://hwangtoemat.github.io/paper-review/2019-07-11-SRCNN-%EB%82%B4%EC%9A%A9/

 

[Super Resolution] 1-1. SRCNN 논문 리뷰

SRCNN - Image Super-Resolution Using Deep Convolutional Networks 을 읽고 논문 주요내용을 정리해본다.

hwangtoemat.github.io

 

댓글