2가지 모델을 훈련시켜 적대적인 프로세스를 통해 생성모델을 평가하는 새로운 프레임워크를 제시한다.
데이터를 제공하는 생성모델 G와 해당 데이터가 G가 만든것인지, 원본인지 확률을 평가하는 구별모델 D. G는 D가 실수할 확률을 높이는 방향으로 학습된다.(잘 구별하지 못하도록). 이 프레임워크는 2인용 minimax 게임과 비슷하다. 임의의 경우에서 G 와 D는 각각 유일한 답이 존재하며 모두 1/2의 확률을 갖는다. G와 D가 다층퍼셉트론의 구조를 가지고 있다면 역전파를 이용한 학습이 가능하다.
딥러닝은 계층적이고 풍부한 모델을 학습하여 이미지, 음성을 포함한 오디오 파형, 자연어 특징과 같은 데이터를 인공지능을 활용하여 확률로 나타내어 준다. 성공한 딥러닝 모델들은 역전파와 드롭아웃 알고리즘에 기초하였으며, 선형적인 단위를 사용하여 가중치에 잘 작용했다. Maximum Likelihood Estimation(MLE)과 같은 전략들에서 나오는 많은 확률론적 계산의 어려움과 생성적인 맥락에서 선형단위의 이점을 활용하는 것의 어려움이 generation 분야에서 있어왔다. 이러한 어려움을 피하는 새로운 생성모델을 소개한다.
구별모델은 샘플이 원본인지 생성모델이 만든 데이터였는지 결정하는 것을 학습한다. 생성모델은 위조지폐를 만드는 위조범과 비슷하고 구별모델은 위조지폐를 검거하려는 경찰과 비슷하다. 이러한 방식은 진짜와 가짜가 구별되지 않게끔 서로를 끌어올린다. 이 프레임워크는 다양한 모델과 최적화 알고리즘에 대해 특정한 훈련 알고리즘을 생성할 수 있다. 생성 모델이 다층 퍼셉트론을 통해 랜덤 노이즈가 부가된 샘플을 생성하고, 구별 모델도 다층 퍼셉트론으로 이루어져 있다. 역전파, 드롭아웃 알고리즘만을 사용하여 두 모델을 훈련하고, 생성모델은 순전파만을 사용하여 샘플을 생성할 수 있다. 다른 inference나 Markov chain은 필요없다.
딥러닝에서 잠재 변수가 있는 무향 그래픽 모델은 랜덤 변수의 모든 상태에 대한 전역 합산/적분으로 정규화된 비정규화 전위 함수의 산물로 표현된다. 이러한 모델은 마르코프 연쇄 몬테카를로(MCMC) 방법으로 추정할 수 있지만, 대부분의 경우에는 양과 그레이디언트를 다루기 어려워 한다. MCMC를 기반으로 하는 학습 알고리즘들은 혼합 문제에 영향을 미친다.
Deep belief networks(DBNs) : 무향 계층과 여러 개의 유향 계층을 가지는 하이브리드 모델로, 빠른 근사 계층별 훈련 기준이 있지만 무향과 유향 모델 모두에서 계산이 어렵다.
로그 우도를 근사하거나 제한하지 않는 대체 기준으로는 점수 매칭(score matching)과 소음 대비 추정(NCE, Noise-Contrastive Estimation)이 있으며, 이들은 학습된 확률 밀도를 정규화 상수까지 분석적으로 지정해야 한다. 하지만 몇몇 흥미로운 생성 모델에서는 다루기 쉬운 비정규 확률 밀도를 도출하는 것이 불가능할 수도 있다.
일부 모델에서는 확률 분포를 명시적으로 정의하는 대신 원하는 분포에서 샘플을 추출하여 생성 모델을 훈련한다. 이 접근 방식은 역전파로 모델을 훈련할 수 있는 장점이 있다.
최근의 주목할 만한 연구로는 generative stochastic network(GSN) 프레임워크가 있으며, 이는 일반화된 denoising 오토 인코더를 확장하는 것이다. GSN과 비교하여, 적대신경망은 샘플링에 Markov Chain 이 필요 없다. 적대신경망은 피드백 루프가 필요 없기 때문에 역전파의 성능을 향상시킬 수 있으며, 피드백 루프로 사용될 때 문제가 있는 단계별 선형 단위를 더 잘 활용할 수 있다.
위의 함수는 Discriminator model 와 Generator model 를 학습시키는 목적 함수이다.
dataset에서 추출해낸 값
D : 입력받은 값이 dataset에 속한 real data일 확률을 나타내는 함수, 를 최대화하도록 학습함.
노이즈 분포에서 추출해낸 노이즈 값.
z를 입력받아 는 새로운 가짜 데이터를 생성한다. 는 물론 이런 경우의 의 출력값을 최소화하도록 학습하게 되며, 반대로 는 이를 최대화하도록 학습하게 된다. 즉, Generator로 만들어낸 가짜 데이터들이 진짜 데이터셋의 분포를 최대한 따라가도록 학습하게 된다.
조금 더 실질적인 학습법을 제시하고 있는데, 초기 학습 시기에는 가 학습이 되어있지 않아 데이터셋과 한참 동떨어진 값을 내놓을 것이고 그렇게 되면 는 그 절댓값이 과도하게 커질 수 있다. 때문에 초기에 한해 가 너무 말도 안되는 결과물을 내놓는다면 를 최대화하는 방향으로 학습하도록 할 수 있다.
위의 그림은 학습 과정을 이해하기 쉽게 도식화한 것이다. 초록 선을 Generator 결과값의 분포, 검은 점들은 dataset의 분포, 파란 점선은 Discriminator의 경계선을 나타낸 것이다. 처음에는 노이즈로부터 Generate된 결과물들이 학습이 덜 된 Generator 모델의 분포를 따르기에 실제 데이터셋과 거리가 있는 것을 확인할 수 있다. 그러나 학습이 진행될수록 Generator의 분포가 데이터셋의 분포를 따라가게 되며, 이상적인 결과인 (d)에서는 Discriminator가 마침내 실제 데이터와 생성 데이터의 차이를 찾기 못해 어떠한 경우에도 반반의 확률을 출력하게 된다.
앞서 제시된 GAN의 minmax problem 이 제대로 작동한다면, minmax problem이 global optimum일 때 pg = pdata여야 하고 우리가 제안하는 알고리즘이 실제로 equation을 최적화하여 global optimum을 가질 수 있어야한다.
GAN 의 미니배치 sgd 훈련 알고리즘은 다음과 같다. discriminator model에 적용하는 k는 단계의 수를 나타내는 하이퍼파라미터이고 k=1을 사용했다.
먼저, G가 주어진 경우 discriminator D에 대한 훈련 기준은 V(G,D)를 최대화하는 것이다. 이를 통해 주어진 generator G에 대한 최적의 discriminator D를 구한다. V(G,D)를 D(x)에 대해 편미분하여 구한다.
증명 : 주어진 어떤 G에 대해 D의 훈련법은 V(G,D)의 양을 최대화 시키는 것이다.
D에 대한 훈련 목표는 조건부 확률 P(Y = y|x)를 추정하기 위한 로그 우도를 최대화하는 것으로 해석될 수 있다. optimal D를 원래의 목적함수에 넣어 목적함수를 재구성한다.
C(G)는 generator가 최소화하고자 하는 기준(virtual training criterion)이 되고, 이 식의 global optimal은 p_g = p_data 일 때 만족한다.
두 분포 사이의 JSD는 항상 음이 아니며 두 분포가 같을 때 0이 되기 때문에 -log4가 C(G)의 global minimum이고 여기서 p_g = p_data 이다. 즉, 실제 데이터를 완벽하게 복제하는 generator model임을 보여준다.
- KL : Kullback-Leibler divergence, p라는 분포가 있을때 q와 p가 얼마나 다른지 측정하는 값(대칭적이지 않음)
- JSD : Jensen-Shannon divergence, 두 확률 분포의 거리를 측정하며 두 분포가 같을 때 0이 되고 항상 양수이다.
실제로, 적대적 네트워크는 함수 G(z;θg)를 통해 제한된 확률분포 패밀리를 표현하며, 우리는 pg 자체보다는 θg를 최적화한다. 다층 퍼셉트론을 사용하여 G를 정의하면 매개변수 공간에 여러 개의 중요한 점들이 도입된된다. 그러나 다층 퍼셉트론의 우수한 성능은 이론적인 보장이 부족함에도 불구하고 합리적인 모델로 사용될 수 있.
Gaussian Parzen window를 G에 의해 생성된 샘플들에 fitting하고 이렇게 추정된 분포 하에 얻어진 log-likelihood를 확인함으로써 저자들은 하에서 test set 데이터의 확률을 추정하였다. 해당 방법을 옳은 평가 척도라고 할 수 없지만 이전 모델과 비교했을 때, 경쟁력을 갖추고 있고, 잠재력을 보여준다.
이렇게 G가 생성해낸 샘플이 기존 방법으로 만든 샘플보다 좋다고 주장할 수 없지만, 더 나은 생성 모델과 경쟁할 수 있다고 생각하며, adversarial framework의 잠재력을 강조한다.
단점 : 훈련 중에 D를 G와 잘 동기화해야 한다.
장점 :
Reference
https://www.youtube.com/watch?v=AVvlDmhHgC4
댓글 영역