[생성모델 연구] GAN 발전동향
GAN 학습과 관련된 설명
- p(data):
- 여기서 x1, x2, x3, x4는 특정 샘플 이미지의 표현 벡터를 나타내는 것이 아니라, 고차원 공간에서 해당 샘플 이미지가 분포할 확률을 의미합니다.
- Discriminator 학습:
- GAN을 학습할 때, Discriminator를 학습할 때는 Generator를 고정한 채로 Discriminator를 학습시킵니다. 이때 Generator는 이미지만 생성하는 역할을 합니다.
- Generator 학습:
- GAN을 학습할 때, Generator를 학습할 때는 Discriminator를 고정한 채로 Generator를 학습시킵니다. 이때 Discriminator는 Real과 Fake만 구분하는 역할을 합니다.
- 분리된 학습 방식:
- Generator와 Discriminator는 각각 따로 학습하는 방식을 취합니다.
- Mode collapse 문제:
- Generator와 Discriminator가 경쟁적으로 (따로따로) 학습할 때, Discriminator가 너무 빨리 학습되면 Generator가 제대로 된 이미지를 생성하지 못하는 경우가 발생하는데, 이를 Mode collapse라고 명명합니다. 이를 해결하기 위해 Generator를 더 많이 학습시키는 방식이 있었으나, 현재 트렌드는 loss 측면에서 이 문제를 해결하는 방향으로 발전하고 있습니다.
- Generator weight 공유:
- 일반적으로 Generator weight만 GitHub에서 공유하며, Discriminator weight는 공개하지 않습니다. 이는 Generator가 이미지를 잘 생성하는 것이 가장 중요하기 때문입니다.
- Latent vector 조작:
- “man with glasses - man without glasses + woman without glasses = woman with glasses”라는 변환을 할 때 각 특징에 해당하는 latent vector를 평균하여 더하고 빼서 이미지를 생성합니다. 이미지에 변화를 주기 위해서는 전체적으로 작은 노이즈를 latent vector에 더해주는 방식을 취합니다.
- Conditional GAN (CGAN):
- z를 input으로 줄 때, z (1,256)와 y (1,10)을 concat하여 (1,266)으로 만들어서 입력합니다. Generator에서 latent vector z와 conditional vector y를 concat하여 입력하고, Real image를 넣을 때도 Real에 y를 concat하여 Discriminator에서도 y를 조건으로 입력하는 방식입니다. 이 부분이 그림에서 명확히 표현되지 않아 혼선이 있을 수 있습니다.
- Semi Supervised GAN (SGAN):
- SGAN은 스스로 condition을 찾도록 학습하는 네트워크입니다.
GAN 학습의 발전
- 초기 단계 학습:
- 학습 초기 단계에서 Generator가 생성하는 데이터는 좋지 않기 때문에 Discriminator의 accuracy가 100%입니다. 그러나 학습이 진행되면서 Generator의 성능이 향상되어 Real과 가까운 이미지를 생성하게 되면 Discriminator의 accuracy는 점차 떨어져 이상적으로는 50%에 도달합니다. Ideal한 Discriminator의 학습 그래프를 그려보면 Discriminator의 accuracy가 60%에 도달하는 것을 보여줍니다.
- Discriminator의 실제 accuracy:
- 실제 Discriminator의 accuracy는 학습 초기에는 보통 구분력이 없기 때문에 100%에서 시작하는 것이 아니라, 0%에서 100%로 올라가는 형태로 관찰됩니다. 또는 0%에서 시작해 spike가 생겨 100%를 찍었다가 떨어져서 60~70%에 수렴하는 형태를 보이기도 합니다.
- Discriminator accuracy 기반 학습 수렴 결정:
- Discriminator의 accuracy를 기반으로 학습의 수렴을 결정하면 40%~60%에서 변동폭이 심할 수 있습니다. 그러나 Discriminator의 accuracy가 60%라고 해서 Generator가 생성한 이미지가 꼭 좋은 것은 아니고, Discriminator의 accuracy가 40%라고 해서 Generator가 생성한 이미지가 꼭 나쁜 것은 아닙니다. 따라서 Discriminator의 accuracy만으로 성능의 최적점을 평가하기에는 부적절합니다.
- Logit 값을 이용한 학습:
- 이 문제를 해결하기 위해 Discriminator의 accuracy만으로 학습하는 것이 아니라, Discriminator가 Real인지 Fake인지 구분한 샘플들의 logit 값들의 기대값을 구하여 이를 Real, Fake의 augmentation 기준으로 사용합니다. 만약 r=1이면 Discriminator가 Real을 Real이라고만 구분하는 상황이므로 overfitting이 되었다고 볼 수 있습니다. r=0이면 Discriminator가 Real과 Fake를 헷갈리지 않는 상황이며, 우리가 원하는 방향입니다. Augmentation의 강도를 조절하여 Discriminator의 학습 난이도를 조정할 수 있습니다.
- Sigmoid와 logit 값:
- Sigmoid를 사용하면 0~1 사이로 변환되는데, Sigmoid에 넣기 전의 Discriminator의 logit 값을 r을 구할 때 사용합니다.
- 균형 잡힌 학습:
- Discriminator와 Generator가 균형 있게 학습하는 것이 중요합니다. Discriminator가 Generator보다 살짝 성능이 좋아야 하지만, 너무 성능이 좋으면 Real을 Real로, Fake를 Fake로만 구분하게 되어 Mode collapse에 빠질 수 있습니다.
- 학습 주기 조절:
- 학습 밸런스가 잘 맞지 않아 Generator와 Discriminator의 학습 주기를 iteration마다 Generator는 1번, Discriminator는 5번으로 설정하는 기법을 사용하기도 했습니다. 그러나 이는 근본적인 해결책이 아니므로, 이후 발전된 Discriminator에서는 patch 단위로 least square로 구하여 loss를 더 복잡하게 만들어 균형 있게 학습하도록 발전하였고, Discriminator와 Generator의 학습 주기는 1:1로 학습하는 방향으로 발전하였습니다.
- 균형 조절 기법:
- 일반적으로 Discriminator가 더 많이 iteration을 수행하는 것이 맞지만, Discriminator가 지나치게 강해지면 Generator의 학습 횟수를 오히려 더 키우는 것이 맞습니다.
- WGAN과 LSGAN의 logit 값 사용:
- WGAN(Wasserstein GAN)과 LSGAN(Least Square GAN)에서는 sigmoid를 통과하지 않은 logit 값을 사용합니다. 이는 sigmoid 함수의 사용으로 인해 발생할 수 있는 gradient vanishing 문제를 피하고, 보다 안정적이고 효율적인 학습을 위해서입니다. WGAN은 Wasserstein distance를 이용하여 생성된 데이터와 실제 데이터 간의 차이를 줄이려고 하며, LSGAN은 least square loss를 사용하여 generator와 discriminator의 학습을 더 부드럽게 합니다.
- 최신 GAN의 Loss Function:
- 최근의 GAN들은 더 이상 0과 1로 구분하는 전통적인 binary cross-entropy loss를 사용하지 않습니다. 초기의 Conditional GAN (pix2pix)나 CycleGAN도 처음에는 Discriminator loss로 0과 1로 구분하는 binary cross-entropy loss를 사용했지만, 이러한 접근법은 gradient vanishing 문제를 일으킬 수 있었습니다. 최근의 발전에서는 logit 값을 활용하여 보다 안정적이고 효율적인 학습을 가능하게 하는 방향으로 GAN의 loss function이 변경되었습니다. 이는 학습 과정에서 gradient flow를 개선하고, 보다 현실적인 이미지를 생성하는 데 기여합니다.
- CycleGAN과 Unpaired Data:
- CycleGAN은 unpaired data를 사용하여 이미지 간의 변환을 학습할 수 있는 모델입니다. 이는 이미지 쌍이 필요했던 기존의 모델들과 달리, 서로 다른 두 도메인 간의 매핑을 학습할 때 paired data가 필요하지 않습니다. CycleGAN은 두 개의 생성자와 두 개의 판별자를 사용하여, 한 도메인의 이미지를 다른 도메인의 이미지로 변환한 후 다시 원래 도메인으로 복원하는 사이클 일관성(cycle consistency)을 통해 학습합니다. 이를 통해 CycleGAN은 unpaired data가 없는 상황에서도 높은 품질의 이미지 변환을 달성할 수 있습니다.
SSIM 이나 PSNR이 높다고 해서 이미지 퀄리티가 좋다고 볼 수는 없습니다.
아래 그림에서 SRGAN이 PSNR, SSIM은 낮지만 이미지의 blur한 영역이 SRResNet보다 적고 이미지 퀄리티 자체가 original과 더 비슷합니다.
PSNR, SSIM을 높이는 L2 loss가 아니라, GAN loss로 이미지 자체의 퀄리티를 높이는 목적으로 GAN loss를 사용할 수도 있습니다.
Leave a comment