상세 컨텐츠

본문 제목

[4주차/김희준/논문리뷰]Generative Adversarial Nets

2023 Summer Session/CV Team 1

by heestogram 2023. 8. 5. 16:30

본문

 

논문 제목: Generative Adversarial Nets

 

작성자: 17기 김희준

 

 


1. Introduction

딥러닝은 natural images, audio waveform, natural language 등에서 큰 성과를 거두었다. 특히 discriminative model이 가장 뛰어난데, 이는 고차원의 inputclass label을 매겨주는 모델로, 우리가 흔히 알고 있는 classification model 등을 의미한다.

 

그러나 deep generative modelmaximum likelihood estimation에 쓰이는 interactable probabilistic computation에서 적잖은 어려움을 겪었고, 본 논문은 새로운 방식의 adversarial nets framework를 제안한다.

 

adversarial netGenerative model(G)Discriminative model(D)이 서로 경쟁하면서 모델이 학습된다. G는 가짜 이미지를 생성해내서 D에게 보낸다. 이 때 D는 실제 이미지와 비교하여 G가 보낸 이미지가 진짜인지 가짜인지를 판별한다. 논문에선 G를 위조지폐범, D를 경찰로 비유한다.

 

결국 G는 어느 순간부터 D가 진짜로 속을 만큼 진짜같은 가짜를 만들 수 있을 것이고, 이 과정은 approximate inferenceMarkov chain을 활용하지 않고도 수행된다는 점에서 앞서 짚은 한계를 극복한 케이스이다.

 


3. Adversarial nets

x라는 input이 주어졌을 때 G의 분포인 p_g를 학습하기 위해 noise variablep_z(Z)를 정의하고, 이를 data spaceG(z; θ_g)mapping해야 한다. 이 때 G는 미분 가능한 multilayer perceptron 함수이다. 우리가 서론에서 언급한 G와 동일하다.

 

두 번째 multilayer perceptron으로 D(x; θ_d)를 정의한다. 이 역시 서론에서 언급한 D와 동일하다. D는 입력되는 xp_g라는 가짜 분포가 아니라 진짜 data 분포에 속할 확률을 single scalar로 출력한다.

 

 

위 식이 objective function이다. D(x)는 입력된 x가 진짜 데이터셋일 확률을 나타낸다. 이 때 D는 정확한 label을 지정할 확률을 최대화하기 위해 학습되어야 하고, Glog(1-D(G(z)))가 최소화되게끔 학습되어야 한다. , D(G(z))1이 되게끔 학습되어야 한다는 것인데 이는 즉 가짜 데이터인 G(z)가 입력되었을 때 D가 진짜로 착각하여 1로 레이블링하게 만들어야 한단 의미이다.

 

결론적으로 DV(D,G)를 최대화시키고, GV(D,G)를 최소화시키는 objective를 갖게 된다.

 

Dinner loop 안에서 학습시키는 것은 계산적인 한계가 있고, overfitting의 위험도 크다. 따라서 Dk step만큼 학습시키고, G1 step 학습시키는 과정을 번갈아가며 학습시킨다. 논문에서는 k=1로 설정하여 DG가 한번씩 번갈아가며 학습되게끔 했다. 구체적인 내용은 아래 사진 Algorithm1에 설명되어있다.

 

 

다만 훈련 초기에 G의 생성 성능은 좋지 못하기 때문에 D가 높은 confidence로 가짜를 판별해낸다. 이러한 경우엔 log(1-D(G(z)))gradient를 계산할 때 너무 작은 값이 나오므로 log(1-D(G(z)))를 최소화하는 것보다 logD(G(z))를 최대화하도록 학습하는 것이 유리하다.

 

  • 파란색 점선: D
  • 검은색 점선: 실제 데이터
  • 초록색 실선: G가 생성한 데이터

-(a): 학습이 되지 않은 상태. 실제 데이터의 분포와 가짜 데이터의 분포가 차이가 있다. D는 이 차이를 어느 정도 포착은 하지만 살짝 들쭉날쭉한 모습을 보인다.

-(b): D를 학습시킨 상태. 진짜 데이터가 모여있는 경우 1, 가짜 데이터는 0으로 잘 분류하고 있는 것을 알 수 있다.

-(c): G를 학습시킨 상태. 거의 진짜 데이터의 분포와 유사하게 가짜 데이터를 생성한 모습이다.

-(d): (a)~(c)를 계속해서 반복하여 진짜 데이터와 완전히 유사하게 가짜 데이터를 생성했고, 그 결과 D가 두 데이터를 구분할 수 없어 0.5output이 출력된다.

 


4.1. Global Optimality of p_g=p_data

G가 고정되어있을 때 최적의 D는 다음과 같다.

 

 

즉 아래 식을 최대화시키는 D가 위 식이란 의미이다.

 

 

미분을 해서 최적의 D를 구하는 증명과정은 아래와 같다.

 

 

이 때 찾은 최적의 D를 사용하면 V(G,D)를 최대화할 수 있다.

 

 

V(G,D)를 최대화하기 위해선 D(x) 자리에 앞서 구한 최적의 D를 대입하면 된다.

 

반면에 GV(G,D)를 최소화시키는 것이 목적이다. C(G)가 최솟값을 갖기 위해선 p_data(x)=p_g(x)여야 한다. 그 경우 global minimumlog(1/2)+log(1/2)=-log4가 된다는 것을 어렵지 않게 알 수 있다.

 

, -log4라는 결과값은 D는 계속해서 최대화를 시도하고 G가 최소화를 시도하는 minmax game에서 완벽하게 G가 진짜 같은 데이터를 생산해내는 minimum global에 도달한 것이라고 볼 수 있다.

 

그렇다면 어째서 p_data(x)=p_g(x)여야만 최솟값을 갖게 되는 것인지 증명해보자.

 

 

C(G)를 풀어서 쓰면 위와 같다. KL(p||q)는 확률분포 p와 확률분포 q가 얼마나 다른지 측정하는 지표이다.

 

 

위 증명에서 JSD(P||Q)P라는 확률분포와 Q라는 확률분포의 거리를 측정하는 지표로, 두 분포가 같을 경우 0이고 아닌 경우 양수이다. , C(G)global minimumlog4임을 증명해냈다.

 


4.2. Convergence of Algorithm1

Algorithm1이 제대로 작동하는지를 증명해야 한다.

 

D가 학습되며 G의 생산 퀄리티가 높아져갈 때 아래 criterion을 잘 개선해나가면 P_g 분포가 P_data 분포, 즉 진짜 데이터 확률분포에 수렴하게 된다.

 

 

이를 증명해보자. V(G,D)P_g의 함수 U(P_g, D)로 가정하자. U(P_g, D)P_g에서 convex(볼록)하다. convex한 특성 덕분에 U함수의 supremum(상한)subderivatives(하방미분)하는 행위가 곧 P_ggradient를 계산하는 행위라고 한다. 이로써 P_g를 적당히만 업데이트해주어도 진짜 데이터 분포에 수렴하는 것을 증명할 수 있다.

 


5. Experiments

이제 adversarial nets 모델을 MNIST, TFD, CIFAR-10 데이터셋으로 훈련을 시킨다. D는 훈련을 할 때 dropout층과 maxout 활성화함수를 사용한다. Gsigmoidrectifier linear 활성화함수를 사용한다.

 

훈련은 Parzen windowfitting하는 식으로 이루어진다. Parzen window란 데이터가 특정 분포를 따르지 않는다는 가정 하에 확률을 추정해나가는 non-parametric(비모수) 방식이다. 이 때 얻은 log-likelihoodtest 데이터의 확률을 추정했다.

 

 

그 결과 다른 모델에 비해 압도적으로 좋은 성과를 보인 것은 아니지만 괄목할만한 잠재성을 가지고 있다.

 


6. Advantages and disadvantages

 

Adversarial Net의 단점은 아래와 같다.

  1. P_g(x)가 명시적으로 존재하지 않는다.
  2. GD의 학습이 잘 동기화되어야 하는데, 어느 한쪽만 가중치를 업데이트하는 경우 Helvetica scenario에 봉착하여 제대로 된 G를 얻을 수 없다.

 

장점은 아래와 같다.

  1. Markov chain을 사용할 필요가 없고 back-propagation으로 gradient를 간편히 계산할 수 있다.
  2. 추론(inference)과정이 필요하지 않다.
  3. Markov chain 방식보다 훨씬 sharp(선명)한 이미지를 생성할 수 있다.

 

관련글 더보기

댓글 영역