본문 바로가기
Computer Science/Machine Learning

Back-propagation on Affine Layers

by zxcvber 2020. 4. 20.

Affine Layer in Neural Networks (Source: https://deepai.org/machine-learning-glossary-and-terms/affine-layer)

딥러닝을 하다 보면 affine layer를 반드시(!) 만나게 된다. Vectorized input/output에 대해 back-propagation을 처음으로 적용하게 되는 대상이기도 하다. 이 글은 딥러닝이나 affine layer의 역할을 설명하려는 것이 아니고, affine layer에서 gradient 구하는 과정을 헷갈려한 나 자신을 돌아보기 위함이 주목적이다.

두 번째 목적은 복잡한 notation을 정리하며, affine layer에서 gradient를 구하는 모든 과정을 분명하게 밝히는 것에 있다.\(\newcommand{\X}{\mathbf{X}}\newcommand{\Y}{\mathbf{Y}}\newcommand{\W}{\mathbf{W}}\newcommand{\x}{\mathbf{x}}\newcommand{\y}{\mathbf{y}}\newcommand{\w}{\mathbf{w}}\newcommand{\d}{\partial}\)

표기법. 모든 실수 값 (스칼라)는 italic 체로 (\(w_{ij}, x_{ij}, y_{ij}\)), 모든 벡터는 소문자 bold 체로 (\(\w, \x, \y\)), 행렬은 대문자 bold 체로 (\(\W, \X, \Y\)) 로 표기한다.

Affine layer는 \(\mathbf{X}\)를 입력으로 받아 \(\mathbf{Y} = \mathbf{X\cdot W}\)를 출력한다. 뒤에 bias term \(\mathbf{B}\)가 붙기도 하지만, 단순하게 생각하기로 하자. 그리고 이 layer의 output \(\mathbf{Y}\)는 결과적으로 신경망의 loss \(L\)을 구하는데 이용된다. Loss는 \(\Y\)에 대한 함수이므로, $$L = f(\Y)$$ 라 둘 수 있을 것이다. Back-propagation에서 우리가 필요한 값은 \(\dfrac{\partial L}{\partial \X}\), \(\dfrac{\partial L}{\partial \W}\)이다. 이제

$$\dfrac{\partial L}{\partial \X} = \dfrac{\partial L}{\partial \Y} \cdot \W^T \tag{1}$$ $$\dfrac{\partial L}{\partial \W} =  \X^T \cdot \dfrac{\partial L}{\partial \Y} \tag{2}$$ 를 증명하는 것이 목표이다.


미분

증명을 하기에 앞서, 몇 가지 명확하게 짚고 넘어가야 할 부분들이 있다. 우선 벡터나 행렬로의 미분에 대한 정의를 분명하게 하고 넘어가야 할 것이다. 익숙한 것들부터 시작해보자.

미분가능성에 대한 가정은 모두 생략했다.

스칼라로 미분

스칼라 함수 \(y\)가 주어질 때, \(\dfrac{\d y}{\d x}\)는 굉장히 편안하다. 설명이 필요 없다.

벡터를 스칼라로 미분할 수 있다. 미적분학에서 공간의 매개화된 곡선을 공부할 때 접했던 기억이 있다.

\(\y = (y_1, y_2, \dots, y_m)\) 일 때, $$\dfrac{\d\y}{\d x} = \left(\frac{\d y_1}{\d x}, \frac{\d y_2}{\d x}, \dots, \frac{\d y_m}{\d x}\right) $$

행렬도 스칼라로 미분할 수 있다. 미분방정식을 공부할 때 접했던 기억이 있다.

\(\mathbf{Y} = (y_{ij})_{m\times n}\) 로 정의하면, $$\frac{\d\Y}{\d x} = \left(\dfrac{\d y_{ij}}{\d x}\right)_{m\times n}$$

단, \(\y, \Y\)의 모든 성분들은 스칼라 함수이다.

스칼라 값으로 미분하는 경우는 굉장히 간단하다. 벡터나 행렬의 각 성분들을 미분하려는 변수로 각각 미분만 하면 되기 때문이다.

벡터로 미분

한 걸음 나아가, 벡터 \(\x = (x_1, x_2, \dots, x_n)\) 으로 미분하는 경우를 살펴보자.

먼저 스칼라 함수 \(y\)를 미분하는 경우이다.

$$\frac{\d y}{\d \x} = \left(\frac{\d y}{\d x_1}, \frac{\d y}{\d x_2}, \dots, \frac{\d y}{\d x_n}\right)$$

엄청 비직관적인 정의는 아니다. 벡터의 각 성분들로 스칼라 함수를 미분하여 얻은 벡터를 생각할 수 있을 것이다.

이제 벡터 \(\y\)를 미분하면 상황이 조금 복잡해진다.

\(\y = (y_1, y_2, \dots, y_m)\) 일 때, $$\dfrac{\d\y}{\d\x} = \left(\frac{\d y_i}{\d x_j}\right)_{m\times n} = \begin{pmatrix}
\frac{\d y_1}{\d x_1} &
\frac{\d y_1}{\d x_2} & \cdots &
\frac{\d y_1}{\d x_n} \\

\frac{\d y_2}{\d x_1} &
\frac{\d y_2}{\d x_2} & \cdots &
\frac{\d y_2}{\d x_n} \\
\vdots & \vdots & \ddots & \vdots \\

\frac{\d y_m}{\d x_1} &
\frac{\d y_m}{\d x_2} & \cdots &
\frac{\d y_m}{\d x_n}
\end{pmatrix}$$

그래도 정의를 자세히 관찰해보면 받아들일 만하다. \(\dfrac{\d\y}{\d\x}\) 의 \(i\)-번째 행은 스칼라 함수 \(y_i\)를 벡터 \(\x\)로 미분한 \(\dfrac{\d y_i}{\d\x}\)와 정확히 일치하기 때문이다.

그런데 이제 행렬 \(\Y = (y_{ij})_{m\times n}\) 을 벡터로 미분하려고 보니 상황이 좀 이상하다. 벡터로 미분하는 방법에 대한 흐름을 해치지 않으면서 행렬을 벡터로 미분하는 자연스러운 방법을 찾아보면...

\(\dfrac{\d\Y}{\d\x}\) 의 \((i, j)\)-성분을 스칼라 함수 \(y_{ij}\)를 벡터 \(\x\)로 미분한 \(\dfrac{\d y_{ij}}{\d\x}\) 로 정의할 수는 있을 것이다... 그런데 뭔가 이상하지 않은가?

행렬로 미분

마지막으로 한 걸음 더 나아간다. 행렬 \(\X = (x_{ij})_{m\times n}\) 으로 미분해 보자.

위와 마찬가지로 스칼라 함수 \(y\)부터 미분해 본다.

$$\dfrac{\d y}{\d\X} = \left(\frac{\d y}{\d x_{ij}}\right)_{m\times n} = \begin{pmatrix} \frac{\d y}{\d x_{11}} & \frac{\d y}{\d x_{12}} & \cdots & \frac{\d y}{\d x_{1n}} \\ \frac{\d y}{\d x_{21}} & \frac{\d y}{\d x_{22}} & \cdots & \frac{\d y}{\d x_{2n}} \\ \vdots & \vdots & \ddots & \vdots \\ \frac{\d y}{\d x_{m1}} & \frac{\d y}{\d x_{m2}} & \cdots & \frac{\d y}{\d x_{mn}} \end{pmatrix}
$$

이제는 행렬로 미분했더니 스칼라를 미분해도 행렬이 나오면서 상황이 복잡하다. 그래도 자세히 관찰해 보면 벡터로 미분했을 때 벡터의 각 성분으로 미분하여 벡터를 만들었던 것처럼, 행렬의 경우에도 행렬의 각 성분으로 미분하여 행렬을 만들 수 있을 것이다.

이제 벡터 \(\y=(y_1, y_2, \dots, y_m)\) 를 행렬 \(\X\)로 미분해보자. 이 경우에도 큰 흐름을 해치지 않는 방법으로 정의해보자면... 미분한 결과 \(\dfrac{\d\y}{\d\X}\) 의 \(i\)-번째 행은 스칼라 함수 \(y_i\)를 \(\X\)의 각 성분으로 미분한 \(\dfrac{\d y_i}{\d\X}\) 로 정의할 수는 있을 것이다.

또, 행렬 \(\Y = (y_{ij})_{m\times n}\) 를 행렬 \(\X\)로 미분하는 경우에도 마찬가지로 생각해 보면, 미분한 결과 \(\dfrac{\d\Y}{\d\X}\) 의 \((i, j)\)-성분은 스칼라 함수 \(y_{ij}\)를 \(\X\)의 각 성분으로 미분한 \(\dfrac{\d y_{ij}}{\d\X}\) 가 될 것이다.

억지로 정의를 확장시키는 느낌이 든다면 정상이다.

참고

\(\dfrac{\d\Y}{\d\x}\) 의 \((i, j)\)-성분을 \(\dfrac{\d y_{ij}}{\d\x}\) 로 정의할 수는 있을 것이라고 위에서 구렁이 담 넘어가듯 언급하고 지나갔다. 그런데 \(\dfrac{\d y_{ij}}{\d\x}\) 의 차원은? 얘는 생각해 보면 사실 벡터이다. 그러면 \(\dfrac{\d\Y}{\d\x}\) 는 각 성분이 벡터인 행렬이다. 물론 이런 게 존재하지 않아야 할 이유는 없다.

마찬가지로 \(\dfrac{\d\Y}{\d\X}\) 의 \((i, j)\)-성분은 \(\dfrac{\d y_{ij}}{\d\X}\) 로 정의하면 될 것이라고 했지만, 이 때는 \((i, j)\)-성분이 \(m\times n\) 행렬이 되며, \(\dfrac{\d\Y}{\d\X}\) 는 총 \(m^2n^2\) 개의 미분값을 갖고 있게 된다. 물론 이런 object (??) 또한 존재하지 말라는 법은 없다.

행렬을 벡터로 미분하거나, 벡터를 행렬로 미분하거나, 행렬을 행렬로 미분하는 경우에 대해서 위키백과는 'not as widely considered and a notation is not widely agreed upon' 이라고 언급하고 있다. 그래서 이 3가지 경우는 회색 박스에 담지 않았다.

이런 게 있구나 하고 적당히 넘어가는 편이 현명한지도 모르겠다.


연쇄 법칙

두 번째로 짚고 넘어가야 할 부분은 연쇄 법칙 (chain rule)이다. Stewart Calculus 7th Ed.의 명제를 복사해 왔다.

Chain Rule (General Version). Suppose that \(u\) is a differentiable funtion of the \(n\) variables \(x_1, x_2, \dots, x_n\) and each \(x_j\) is a differentiable function of the \(m\) variables \(t_1, t_2, \dots, t_m\). Then \(u\) is a function of \(t_1, t_2, \dots, t_m\) and
$$\frac{\d u}{\d t_i} = \sum_{j=1}^n \frac{\d u}{\d x_j}\frac{\d x_j}{\d t_i} = \frac{\d u}{\d x_1}\frac{\d x_1}{\d t_i}+\frac{\d u}{\d x_2}\frac{\d x_2}{\d t_i}+\cdots+\frac{\d u}{\d x_n}\frac{\d x_n}{\d t_i}$$ for each \(i=1, 2, \dots, m\).

식만 복잡하지 해석해 보면 일변수에서의 연쇄 법칙과 크게 다를 바 없다. 직관적으로 설명해 보자. 현재 상황은 \(u\)가 \(x_j\) 들의 함수이고, \(x_j\) 들이 각각 \(t_i\) 들의 함수이며, 우리는 \(t_i\)에 대한 \(u\)의 변화율을 구하고 싶다. 그렇다면, \(x_j\)에 대하여 (일변수 때와 마찬가지로) \(x_j\)에 대한 \(u\)의 변화율과 \(t_i\)에 대한 \(x_j\)의 변화율을 곱해 \(\dfrac{\d u}{\d x_j} \dfrac{\d x_j}{\d t_i}\) 를 얻고, 이 값들을 모두 더해주면 총 변화율이 나오지 않을까?

더불어, 스칼라로 미분하는 경우 뿐만 아니라 벡터나 행렬로 미분하는 경우에도 연쇄 법칙은 동일하게 성립한다.


증명

이제 식 \((1), (2)\)를 증명할 준비가 되었다. 우선 증명하기 전에 행렬의 dimension을 정해주자.

$$\Y = (y_{ij})_{N\times M} \quad \X = (x_{ij})_{N\times D} \quad \W = (w_{ij})_{D\times M}$$

이렇게 잡으면 \(\Y = \X\cdot \W\) 를 만족한다.

\((1)\)의 증명

\(\dfrac{\d L}{\d \X}\) 를 구하려고 한다. \(L\)은 \(\Y\)에 대한 함수이고 \(\Y\) 는 \(\X\)의 함수이므로 연쇄 법칙이 자연스럽게 떠오른다.

$$\dfrac{\d L}{\d \X} = \dfrac{\d L}{\d \Y}\dfrac{\d \Y}{\d \X}$$

하나 문제가 있다면 \(\dfrac{\d\Y}{\d\X}\) 의 정의가 명확하지 않다는 것이다. 일단 위키백과의 철학을 받아들이고, 다른 방식으로 접근해야 한다. 설령 정의가 명확하다고 하더라도, 실제로 신경망을 구현할 때 계산이 불가능하다. 위에서 언급한대로 이 '행렬'의 dimension은 무지하게 크다. 메모리가 과하게 많이 소비된다.

그렇다면 행렬의 각 성분별로 따져보면 어떨까? \(\dfrac{\d L}{\d \X}\) 의 \((i, j)\)-성분은 \(\dfrac{\d L}{\d x_{ij}}\) 이다. 얘를 구해보자. 이제는 연쇄 법칙을 자신있게 적용할 수 있다. \(L\)이 \(y_{\alpha\beta}\)에 대한 함수이므로, 연쇄 법칙을 사용하면 다음을 얻는다.

$$\begin{aligned}\dfrac{\d L}{\d x_{ij}} &= \sum_{\alpha = 1}^N \sum_{\beta=1}^M \dfrac{\d L}{\d y_{\alpha\beta}}\dfrac{\d y_{\alpha\beta}}{\d x_{ij}}\end{aligned} \tag{3}$$

한편 우리는 \(\Y = \X\cdot \W\) 으로부터

$$y_{\alpha\beta} = \sum_{k=1}^D x_{\alpha k}w_{k\beta} \tag{4}$$

임도 알고 있다. Index가 지나치게 많아 헷갈리지만, 조심스럽게 \(x_{ij}\)에 대해 미분해 보면

$$\dfrac{\d y_{\alpha\beta}}{\d x_{ij}} = \delta_{\alpha i}w_{j\beta} = \begin{cases} w_{j\beta} & (i = \alpha) \\ 0 & (i \neq \alpha)\end{cases}$$

를 얻는다. 여기서 \(\delta_{ij}\) 는 Kronecker Delta function 이다. 이제 \((3)\)을 정리할 수 있다!

$$\begin{aligned}\dfrac{\d L}{\d x_{ij}} &= \sum_{\alpha = 1}^N \sum_{\beta=1}^M \dfrac{\d L}{\d y_{\alpha\beta}}\dfrac{\d y_{\alpha\beta}}{\d x_{ij}} = \sum_{\alpha = 1}^N \sum_{\beta=1}^M \dfrac{\d L}{\d y_{\alpha\beta}}\delta_{\alpha i}w_{j\beta} = \sum_{\beta = 1}^M \dfrac{\d L}{\d y_{i\beta}}w_{j\beta}\end{aligned}$$

마지막 등호는 \(\alpha = i\) 인 경우만 남고 나머지는 모두 \(0\) 이기 때문에 성립한다.

이제 이 식을 자세히 관찰하면, \(\dfrac{\d L}{\d y_{i\beta}}\) 는 행렬 \(\dfrac{\d L}{\d \Y}\) 의 \((i, \beta)\)-성분이다. 또, \(w_{j\beta}\) 는 \(\W^T\) 의 \((\beta, j)\)-성분임을 알 수 있다. 행렬 곱셈의 정의로부터 \(\dfrac{\d L}{\d x_{ij}}\) 는 \(\dfrac{\d L}{\d Y}\cdot \W^T\) 의 \((i, j)\)-성분임을 알 수 있다.

모든 \(i, j\) 에 대하여 성립하므로, $$\dfrac{\partial L}{\partial \X} = \dfrac{\partial L}{\partial \Y} \cdot \W^T $$ 를 얻는다.  \( \square\)

\((2)\)의 증명

\((1)\)의 증명에서 몇 글자를 고치면 \((2)\)의 증명을 얻는다. 연습문제로 남기고 싶지만, 데이터를 낭비하도록 하겠다.

\(\dfrac{\d L}{\d \W}\) 를 구하려고 한다. \(L\)은 \(\Y\)에 대한 함수이고 \(\Y\) 는 \(\W\)의 함수이므로 연쇄 법칙이 자연스럽게 떠오른다.

$$\dfrac{\d L}{\d \W} = \dfrac{\d L}{\d \Y}\dfrac{\d \Y}{\d \W}$$

여기서도 비슷한 이유로 \(\dfrac{\d\Y}{\d\W}\) 를 계산할 수 없기 때문에, 각 성분별로 따져본다.

\(\dfrac{\d L}{\d \W}\) 의 \((i, j)\)-성분은 \(\dfrac{\d L}{\d w_{ij}}\) 이므로, 연쇄 법칙을 자신있게 적용하여 다음을 얻는다.

$$\begin{aligned}\dfrac{\d L}{\d w_{ij}} &= \sum_{\alpha = 1}^N \sum_{\beta=1}^M \dfrac{\d L}{\d y_{\alpha\beta}}\dfrac{\d y_{\alpha\beta}}{\d w_{ij}}\end{aligned}$$

이제 \((4)\)를 직접 \(w_{ij}\)에 대해 미분하여

$$\dfrac{\d y_{\alpha\beta}}{\d w_{ij}} = \delta_{\beta j}x_{\alpha i} = \begin{cases} x_{\alpha i} & (j = \beta) \\ 0 & (j \neq \beta)\end{cases}$$

를 얻고 \((5)\)를 정리하면,

$$\begin{aligned}\dfrac{\d L}{\d w_{ij}} &= \sum_{\alpha = 1}^N \sum_{\beta=1}^M \dfrac{\d L}{\d y_{\alpha\beta}}\dfrac{\d y_{\alpha\beta}}{\d w_{ij}} = \sum_{\alpha = 1}^N \sum_{\beta=1}^M \dfrac{\d L}{\d y_{\alpha\beta}}\delta_{\beta j}x_{\alpha i} = \sum_{\alpha = 1}^N x_{\alpha i} \dfrac{\d L}{\d y_{\alpha j}}\end{aligned}$$

마지막 등호는 \(\beta = j\) 인 경우만 남고 나머지는 모두 \(0\) 이기 때문에 성립한다.

이제 이 식을 자세히 관찰하면, \(\dfrac{\d L}{\d y_{\alpha j}}\) 는 행렬 \(\dfrac{\d L}{\d \Y}\) 의 \((\alpha, j)\)-성분이다. 또, \(x_{\alpha i}\) 는 \(\X^T\) 의 \((i,\alpha)\)-성분임을 알 수 있다. 행렬 곱셈의 정의로부터 \(\dfrac{\d L}{\d w_{ij}}\) 는 \(\X^T\cdot \dfrac{\d L}{\d Y}\) 의 \((i, j)\)-성분임을 알 수 있다.

모든 \(i, j\) 에 대하여 성립하므로, $$\dfrac{\d L}{\d \W} = \X^T \cdot \dfrac{\d L}{\d \Y} $$ 를 얻는다.  \( \square\)


이제 편안하게 affine layer에서 back-propagation을 할 수 있다. 정의와 표기법, 그리고 핵심이 되는 연쇄법칙만 알고 있으면 어렵지 않게 증명할 수 있는 내용이었는데, 직관적으로 '대충 이렇게 되겠지~' 하고 넘어갔다가, 다시 자세히 증명을 하려다 보니 막상 잘 되지 않았다. 정의와 표기법이 중요하다는 사실을 다시 한 번 깨닫게 된다.

내가 보는 책에서는 이 글에서 증명한 식들을 대충 언급하고 넘어가며, 증명은 커녕 계산해야 하는 행렬들의 size를 바탕으로 식을 유추하면 결론이 나온다는 식으로 얼렁뚱땅 넘어간다. 개인적으로 굉장히 맘에 들지 않는다. 이것이 직접 증명을 시도한 이유 중 하나이기도 하다. 구현을 하거나, 실무에 활용하는 입장에서는 이론적 배경이 중요하지 않을 수 있다. 하지만, 이론적 배경을 잊어버리더라도 최소한 한 번쯤은 읽어보고 이해하고 지나가야, 더욱 근본이 탄탄한 지식을 쌓을 수 있지 않을까.

댓글