본문 바로가기
AI Research/Deep Learning

[Pytorch-기초강의] 경쟁하며 학습하는 GAN

by ga.0_0.ga 2023. 3. 3.
728x90
반응형

▶ GAN(Generative Adversarial Network) 이란?

직역하면 "적대적 생성 신경망"입니다.

1. 앞서 배운 CNN과 RNN 모델로는 새로운 것을 만들어 낼 수 없었습니다. 그러나 GAN은 새로운 이미지나 음성을 "생성(창작)"하도록 할 수 있습니다.

2. GAN은 적대적으로 학습합니다. GAN은 가짜이미지를 생성하는 생성자(generator)와 생성된 이미지의 진위를 판별하는 판별자(discriminator)가 번갈아가며 학습하는 경쟁적 방식으로 학습을 진행합니다.

3. GAN은 생성자와 판별자 모두 신경망으로 되어있는 인공신경망 모델입니다.

=> 요약하자면, GAN은 서로 대립하는 두 모델이 경쟁해 학습하는 방법론입니다.

 

● GAN이 주목받는 이유

비지도학습 방식이기 때문입니다. 세상에 존재하는 데이터는 매우 많고, 대부분의 데이터에는 정답이 없습니다. 그 수 많은 데이터에 맞는 정답을 찾아 사람이 모든 데이터를 일일이 가공하는 것은 거의 불가능합니다. GAN은 앞서 배운 오토인코더와 같이 비지도학습을 하여 사람의 노력을 최소화합니다. 또한 GAN의 방법론은 적용할 수 있는 분야가 매우 많습니다. 그래서 머신러닝으로 풀고자 하는 문제들 대부분에 GAN을 이용한 방법이 시도되고 있으며 놀라운 성과를 거두고 있습니다.

● 생성자와 판별자

GAN 모델은 생성자와 판별자라는 주요 모듈 2가지로 구성됩니다. GAN은 아래 그림과 같은 구조를 갖습니다.

GAN의 구조

- 생성자

무작위 텐서로부터 여러 가지 형태의 가짜 이미지를 생성합니다. 학습이 진행되면서 생성자는 판별자를 속이려고 점점 더 정밀한 가짜 이미지를 생성하게됩니다. 따라서 마지막에 생성자는 진짜 이미지와 거의 흡사한가짜이미지를 만들 수 있게 됩니다.

- 판별자

진짜 이미지와 가짜 이미지를 구분합니다. 학습이 진행되면서 판별자는 학습 데이터에서 가져온 진짜 이미지와 생성자가 만든 가짜 이미지를 점점 더 잘 구별하게 됩니다.

=> 지폐위조범과 경찰에 비유해보자면 지폐 위조범(생성자)은 경찰을 속이기 위해 최대한 진짜 같은 지폐를 만드는 한편, 경찰(판별자)은 위조지폐와 진짜를 감별하려고 노력합니다. 이런 경쟁 구도 속에서 지폐 위조범과 경찰의 능력이 발전하게 되고, 결과적으로는 경찰이 위조범이 만드는 위폐와 진폐를 구별하기 힘든 효과를 얻을 수 있습니다.

▶ GAN으로 새로운 패션 아이템 생성하기
전체 코드 주소입니다!

 

GitHub - jgyy4775/3-min-pytorch: <펭귄브로의 3분 딥러닝, 파이토치맛> 예제 코드

<펭귄브로의 3분 딥러닝, 파이토치맛> 예제 코드. Contribute to jgyy4775/3-min-pytorch development by creating an account on GitHub.

github.com

 

● 간단 코드 설명

이전 게시물들에서 설명한 부분은 제외하고 설명하겠습니다.

- Fashion MNIST 데이터 셋 불러오기

trainset = datasets.FashionMNIST(
    './.data',
    train=True,
    download=True,
    transform=transforms.Compose([
       transforms.ToTensor(),
       transforms.Normalize((0.5,), (0.5,))
    ])
)
train_loader = torch.utils.data.DataLoader(
    dataset     = trainset,
    batch_size  = BATCH_SIZE,
    shuffle     = True
)

 

- 생성자 구현

=> Sequential 클래스는 신경망을 이루는 각 층에서 수행할 연산들을 입력받아 차례대로 실행하는 역할을 합니다. 생성자는 실제 데이터와 비슷한 가짜 데이터를 만들어내는 신경망입니다. 무작위 텐서를 입력하는 이유는 생성자가 실제 데이터의 ‘분포’를 배우는 것이기 때문입니다. 생성자는 정규분포 같은 단순한 분포에서부터 실제 데이터의 복잡한 분포를 학습합니다.

예제의 생성자 모델 구조
# 생성자는 64차원의 랜덤한 텐서를 입력받아 이에 행렬곱(Linear)과 활성화 함수(ReLU, Tanh) 연산을 
# 실행, 생성자의 결과값은 784차원, 즉 Fashion MNIST 속의 이미지와 같은 차원의 텐서를 생성
G = nn.Sequential(
        nn.Linear(64, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 784),
        nn.Tanh()) # 결과값을 -1과 1 사이로 압축
 
 

- 판별자 구현

=> 판별자는 입력된 784차원의 텐서가 생성자가 만든 가짜 이미지인지, 혹은 실제 Fashion MNIST의 이미지인지 구분하는 분류합니다.

예제의 판별자 모델 구조
# 판별자는 784차원의 텐서를 입력받습니다. 판별자 역시 입력된 데이터에 행렬곱과 활성화 함수를 실행
# 시키지만, 생성자와 달리 판별자의 결과값은 입력받은 텐서가 진짜인지 구분하는 예측값임.
D = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1),
        nn.Sigmoid())

 

- 진짜와 가짜 레이블 생성

=> 생성자가 만든 데이터는 ‘가짜’라는 레이블을 부여받고 Fashion MNIST 데이터셋의 ‘진짜’ 데이터와 함께 판별자 신경망에 입력됩니다. 진짜와 가짜 이미지에 레이블을 달아주기 위해 두 레이블 텐서를 정의합니다. real _labels 텐서는 ones( ) 함수를 불러 1로만 이루어진 텐서를 만들고 fake_labels는 zeros( ) 함수로 0으로 채워줍니다.

real_labels = torch.ones(BATCH_SIZE, 1).to(DEVICE)
fake_labels = torch.zeros(BATCH_SIZE, 1).to(DEVICE)

 

- 판별자가 진짜 이미지를 진짜로 인식하는 오차를 계산

outputs = D(images) # images 는 진짜 이미지
d_loss_real = criterion(outputs, real_labels)
real_score = outputs
 

- 가짜 이미지 생성 후 판별자가 가짜 이미지를 가짜로 인식하는 오차 계산

# 무작위 텐서로 가짜 이미지 생성
z = torch.randn(BATCH_SIZE, 64).to(DEVICE)
fake_images = G(z)
        
outputs = D(fake_images)
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs

 

- 판별자 학습

=> 진짜와 가짜 이미지를 갖고 낸 오차를 더해서 판별자의 오차 계산

# 진짜와 가짜 이미지를 갖고 낸 오차를 더해서 판별자의 오차 계산
d_loss = d_loss_real + d_loss_fake

d_optimizer.zero_grad()
g_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()

 

- 생성자 학습

=> 생성자가 만들어낸 가짜 이미지를 판별자가 진짜라고 착각할 정도까지 만들어야 합니다. 즉, 생성자의 결과물을 다시 판별자에 입력시켜, 그 결과물과 real _labels 사이의 오차를 최소화하는 식으로 학습을 진행해야합니다.

# 생성자가 판별자를 속였는지에 대한 오차를 계산
fake_images = G(z)
outputs = D(fake_images)
g_loss = criterion(outputs, real_labels)

d_optimizer.zero_grad()
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()

 

=> GAN을 통해 생성된 이미지들

▶ cGAN으로 생성 제어하기

GAN이 더욱 쓸모 있으려면 사용자가 원하는 이미지를 생성하는 기능을 제공해야 합니다. 이전 예제는 말 그대로 무작위 벡터를 입력받아 무작위로 패션 아이템을 출력했습니다. 이러한 GAN의 한계를 보완해서 출력할 아이템의 종류를 사용자로부터 입력받아 그에 해당하는 이미지를 생성하는 모델이 바로 '조건부 GAN(Conditional GAN)'입니다. 따라서, cGAN의 생성자는 학습 과정에서 생성하고픈 아이템의 종류 정보를 입력받아야 합니다. 아래 그림은 cGAN의 구조도 입니다.

cGAN의 구조도

구현은 크게 복잡하지 않습니다. 생성자와 판별자의 입력에 레이블 정보를 이어 붙이기만 하면 됩니다. 그림처럼 생성자와 판별자에 레이블 정보만 넣어주면 됩니다.

▶ 조건부 GAN의 예제

● 조건부 생성자와 판별자

- 이번 예제에선 생성자와 판별자가 하나의 입력이 아닌 레이블 정보까지 두 가지를 입력을 받습니다.
전체 코드 주소입니다!

 

GitHub - jgyy4775/3-min-pytorch: <펭귄브로의 3분 딥러닝, 파이토치맛> 예제 코드

<펭귄브로의 3분 딥러닝, 파이토치맛> 예제 코드. Contribute to jgyy4775/3-min-pytorch development by creating an account on GitHub.

github.com

 

- 생성자

=> 무작위로 생성한 레이블 정보를 받아 해당 레이블에 대한 이미지를 생성하도록 학습하게 됩니다.

예제의 cGAN 생성자 구조
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        
        # ‘배치 x 1’ 크기의 레이블 텐서를 받아 ‘배치 x 10 ‘의 연속적인 텐서로 전환
        self.embed = nn.Embedding(10, 10)
        
        self.model = nn.Sequential(
            nn.Linear(110, 256),  # 무작위 텐서100, 레이블 정보 10
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    def forward(self, z, labels):
        c = self.embed(labels)
        x = torch.cat([z, c], 1) # 두 벡터를 (두 번째 인수 차원에 대해서) 이어붙이는 연산
        return self.model(x)

 

- 판별자

=> 판별자 역시 레이블 정보를 받습니다. 이때 생성자에서 이미지를 만들때 쓴 레이블 정보를 입력받아 “레이블이 주어졌을때 가짜인 확률과 진짜인 확률”을 추정합니다.

예제의 cGAN 판별자 구조
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.embed = nn.Embedding(10, 10)
        
        self.model = nn.Sequential(
            nn.Linear(794, 1024), # 레이블정보 10 포함
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 각각 가짜와 진짜를 뜻하는 0과 1 사이를 반환
        )
    
    def forward(self, x, labels):
        c = self.embed(labels)
        x = torch.cat([x, c], 1)
        return self.model(x)

=> 생성자와 판별자의 학습은 이전 예제와 동일합니다.

=> cGAN으로 생성된 패션 아이템들을 확인해보면 아래와 같습니다.

728x90
반응형

댓글