본문 바로가기
Computer Science/Machine Learning

Back-propagation on Affine Layers

by zxcvber 2020. 4. 20.

Affine Layer in Neural Networks (Source: https://deepai.org/machine-learning-glossary-and-terms/affine-layer)

딥러닝을 하다 보면 affine layer를 반드시(!) 만나게 된다. Vectorized input/output에 대해 back-propagation을 처음으로 적용하게 되는 대상이기도 하다. 이 글은 딥러닝이나 affine layer의 역할을 설명하려는 것이 아니고, affine layer에서 gradient 구하는 과정을 헷갈려한 나 자신을 돌아보기 위함이 주목적이다.

두 번째 목적은 복잡한 notation을 정리하며, affine layer에서 gradient를 구하는 모든 과정을 분명하게 밝히는 것에 있다.

표기법. 모든 실수 값 (스칼라)는 italic 체로 (wij,xij,yij), 모든 벡터는 소문자 bold 체로 (w,x,y), 행렬은 대문자 bold 체로 (W,X,Y) 로 표기한다.

Affine layer는 X를 입력으로 받아 Y=XW를 출력한다. 뒤에 bias term B가 붙기도 하지만, 단순하게 생각하기로 하자. 그리고 이 layer의 output Y는 결과적으로 신경망의 loss L을 구하는데 이용된다. Loss는 Y에 대한 함수이므로, L=f(Y) 라 둘 수 있을 것이다. Back-propagation에서 우리가 필요한 값은 LX, LW이다. 이제

(1)LX=LYWT (2)LW=XTLY 를 증명하는 것이 목표이다.


미분

증명을 하기에 앞서, 몇 가지 명확하게 짚고 넘어가야 할 부분들이 있다. 우선 벡터나 행렬로의 미분에 대한 정의를 분명하게 하고 넘어가야 할 것이다. 익숙한 것들부터 시작해보자.

미분가능성에 대한 가정은 모두 생략했다.

스칼라로 미분

스칼라 함수 y가 주어질 때, yx는 굉장히 편안하다. 설명이 필요 없다.

벡터를 스칼라로 미분할 수 있다. 미적분학에서 공간의 매개화된 곡선을 공부할 때 접했던 기억이 있다.

y=(y1,y2,,ym) 일 때, yx=(y1x,y2x,,ymx)

행렬도 스칼라로 미분할 수 있다. 미분방정식을 공부할 때 접했던 기억이 있다.

Y=(yij)m×n 로 정의하면, Yx=(yijx)m×n

단, y,Y의 모든 성분들은 스칼라 함수이다.

스칼라 값으로 미분하는 경우는 굉장히 간단하다. 벡터나 행렬의 각 성분들을 미분하려는 변수로 각각 미분만 하면 되기 때문이다.

벡터로 미분

한 걸음 나아가, 벡터 x=(x1,x2,,xn) 으로 미분하는 경우를 살펴보자.

먼저 스칼라 함수 y를 미분하는 경우이다.

yx=(yx1,yx2,,yxn)

엄청 비직관적인 정의는 아니다. 벡터의 각 성분들로 스칼라 함수를 미분하여 얻은 벡터를 생각할 수 있을 것이다.

이제 벡터 y를 미분하면 상황이 조금 복잡해진다.

y=(y1,y2,,ym) 일 때, yx=(yixj)m×n=(y1x1y1x2y1xny2x1y2x2y2xnymx1ymx2ymxn)

그래도 정의를 자세히 관찰해보면 받아들일 만하다. yxi-번째 행은 스칼라 함수 yi를 벡터 x로 미분한 yix와 정확히 일치하기 때문이다.

그런데 이제 행렬 Y=(yij)m×n 을 벡터로 미분하려고 보니 상황이 좀 이상하다. 벡터로 미분하는 방법에 대한 흐름을 해치지 않으면서 행렬을 벡터로 미분하는 자연스러운 방법을 찾아보면...

Yx(i,j)-성분을 스칼라 함수 yij를 벡터 x로 미분한 yijx 로 정의할 수는 있을 것이다... 그런데 뭔가 이상하지 않은가?

행렬로 미분

마지막으로 한 걸음 더 나아간다. 행렬 X=(xij)m×n 으로 미분해 보자.

위와 마찬가지로 스칼라 함수 y부터 미분해 본다.

yX=(yxij)m×n=(yx11yx12yx1nyx21yx22yx2nyxm1yxm2yxmn)

이제는 행렬로 미분했더니 스칼라를 미분해도 행렬이 나오면서 상황이 복잡하다. 그래도 자세히 관찰해 보면 벡터로 미분했을 때 벡터의 각 성분으로 미분하여 벡터를 만들었던 것처럼, 행렬의 경우에도 행렬의 각 성분으로 미분하여 행렬을 만들 수 있을 것이다.

이제 벡터 y=(y1,y2,,ym) 를 행렬 X로 미분해보자. 이 경우에도 큰 흐름을 해치지 않는 방법으로 정의해보자면... 미분한 결과 yX i-번째 행은 스칼라 함수 yiX의 각 성분으로 미분한 yiX 로 정의할 수는 있을 것이다.

또, 행렬 Y=(yij)m×n 를 행렬 X로 미분하는 경우에도 마찬가지로 생각해 보면, 미분한 결과 YX (i,j)-성분은 스칼라 함수 yijX의 각 성분으로 미분한 yijX 가 될 것이다.

억지로 정의를 확장시키는 느낌이 든다면 정상이다.

참고

Yx(i,j)-성분을 yijx 로 정의할 수는 있을 것이라고 위에서 구렁이 담 넘어가듯 언급하고 지나갔다. 그런데 yijx 의 차원은? 얘는 생각해 보면 사실 벡터이다. 그러면 Yx 는 각 성분이 벡터인 행렬이다. 물론 이런 게 존재하지 않아야 할 이유는 없다.

마찬가지로 YX (i,j)-성분은 yijX 로 정의하면 될 것이라고 했지만, 이 때는 (i,j)-성분이 m×n 행렬이 되며, YX 는 총 m2n2 개의 미분값을 갖고 있게 된다. 물론 이런 object (??) 또한 존재하지 말라는 법은 없다.

행렬을 벡터로 미분하거나, 벡터를 행렬로 미분하거나, 행렬을 행렬로 미분하는 경우에 대해서 위키백과는 'not as widely considered and a notation is not widely agreed upon' 이라고 언급하고 있다. 그래서 이 3가지 경우는 회색 박스에 담지 않았다.

이런 게 있구나 하고 적당히 넘어가는 편이 현명한지도 모르겠다.


연쇄 법칙

두 번째로 짚고 넘어가야 할 부분은 연쇄 법칙 (chain rule)이다. Stewart Calculus 7th Ed.의 명제를 복사해 왔다.

Chain Rule (General Version). Suppose that u is a differentiable funtion of the n variables x1,x2,,xn and each xj is a differentiable function of the m variables t1,t2,,tm. Then u is a function of t1,t2,,tm and
uti=j=1nuxjxjti=ux1x1ti+ux2x2ti++uxnxnti for each i=1,2,,m.

식만 복잡하지 해석해 보면 일변수에서의 연쇄 법칙과 크게 다를 바 없다. 직관적으로 설명해 보자. 현재 상황은 uxj 들의 함수이고, xj 들이 각각 ti 들의 함수이며, 우리는 ti에 대한 u의 변화율을 구하고 싶다. 그렇다면, xj에 대하여 (일변수 때와 마찬가지로) xj에 대한 u의 변화율과 ti에 대한 xj의 변화율을 곱해 uxjxjti 를 얻고, 이 값들을 모두 더해주면 총 변화율이 나오지 않을까?

더불어, 스칼라로 미분하는 경우 뿐만 아니라 벡터나 행렬로 미분하는 경우에도 연쇄 법칙은 동일하게 성립한다.


증명

이제 식 (1),(2)를 증명할 준비가 되었다. 우선 증명하기 전에 행렬의 dimension을 정해주자.

Y=(yij)N×MX=(xij)N×DW=(wij)D×M

이렇게 잡으면 Y=XW 를 만족한다.

(1)의 증명

LX 를 구하려고 한다. LY에 대한 함수이고 YX의 함수이므로 연쇄 법칙이 자연스럽게 떠오른다.

LX=LYYX

하나 문제가 있다면 YX 의 정의가 명확하지 않다는 것이다. 일단 위키백과의 철학을 받아들이고, 다른 방식으로 접근해야 한다. 설령 정의가 명확하다고 하더라도, 실제로 신경망을 구현할 때 계산이 불가능하다. 위에서 언급한대로 이 '행렬'의 dimension은 무지하게 크다. 메모리가 과하게 많이 소비된다.

그렇다면 행렬의 각 성분별로 따져보면 어떨까? LX(i,j)-성분은 Lxij 이다. 얘를 구해보자. 이제는 연쇄 법칙을 자신있게 적용할 수 있다. Lyαβ에 대한 함수이므로, 연쇄 법칙을 사용하면 다음을 얻는다.

(3)Lxij=α=1Nβ=1MLyαβyαβxij

한편 우리는 Y=XW 으로부터

(4)yαβ=k=1Dxαkwkβ

임도 알고 있다. Index가 지나치게 많아 헷갈리지만, 조심스럽게 xij에 대해 미분해 보면

yαβxij=δαiwjβ={wjβ(i=α)0(iα)

를 얻는다. 여기서 δijKronecker Delta function 이다. 이제 (3)을 정리할 수 있다!

Lxij=α=1Nβ=1MLyαβyαβxij=α=1Nβ=1MLyαβδαiwjβ=β=1MLyiβwjβ

마지막 등호는 α=i 인 경우만 남고 나머지는 모두 0 이기 때문에 성립한다.

이제 이 식을 자세히 관찰하면, Lyiβ 는 행렬 LY(i,β)-성분이다. 또, wjβWT(β,j)-성분임을 알 수 있다. 행렬 곱셈의 정의로부터 LxijLYWT(i,j)-성분임을 알 수 있다.

모든 i,j 에 대하여 성립하므로, LX=LYWT 를 얻는다. 

(2)의 증명

(1)의 증명에서 몇 글자를 고치면 (2)의 증명을 얻는다. 연습문제로 남기고 싶지만, 데이터를 낭비하도록 하겠다.

LW 를 구하려고 한다. LY에 대한 함수이고 YW의 함수이므로 연쇄 법칙이 자연스럽게 떠오른다.

LW=LYYW

여기서도 비슷한 이유로 YW 를 계산할 수 없기 때문에, 각 성분별로 따져본다.

LW(i,j)-성분은 Lwij 이므로, 연쇄 법칙을 자신있게 적용하여 다음을 얻는다.

Lwij=α=1Nβ=1MLyαβyαβwij

이제 (4)를 직접 wij에 대해 미분하여

yαβwij=δβjxαi={xαi(j=β)0(jβ)

를 얻고 (5)를 정리하면,

Lwij=α=1Nβ=1MLyαβyαβwij=α=1Nβ=1MLyαβδβjxαi=α=1NxαiLyαj

마지막 등호는 β=j 인 경우만 남고 나머지는 모두 0 이기 때문에 성립한다.

이제 이 식을 자세히 관찰하면, Lyαj 는 행렬 LY(α,j)-성분이다. 또, xαiXT(i,α)-성분임을 알 수 있다. 행렬 곱셈의 정의로부터 LwijXTLY(i,j)-성분임을 알 수 있다.

모든 i,j 에 대하여 성립하므로, LW=XTLY 를 얻는다.  


이제 편안하게 affine layer에서 back-propagation을 할 수 있다. 정의와 표기법, 그리고 핵심이 되는 연쇄법칙만 알고 있으면 어렵지 않게 증명할 수 있는 내용이었는데, 직관적으로 '대충 이렇게 되겠지~' 하고 넘어갔다가, 다시 자세히 증명을 하려다 보니 막상 잘 되지 않았다. 정의와 표기법이 중요하다는 사실을 다시 한 번 깨닫게 된다.

내가 보는 책에서는 이 글에서 증명한 식들을 대충 언급하고 넘어가며, 증명은 커녕 계산해야 하는 행렬들의 size를 바탕으로 식을 유추하면 결론이 나온다는 식으로 얼렁뚱땅 넘어간다. 개인적으로 굉장히 맘에 들지 않는다. 이것이 직접 증명을 시도한 이유 중 하나이기도 하다. 구현을 하거나, 실무에 활용하는 입장에서는 이론적 배경이 중요하지 않을 수 있다. 하지만, 이론적 배경을 잊어버리더라도 최소한 한 번쯤은 읽어보고 이해하고 지나가야, 더욱 근본이 탄탄한 지식을 쌓을 수 있지 않을까.

댓글