본문 바로가기
Computer Science/Machine Learning

Autodiff 직접 구현하기

by zxcvber 2021. 3. 21.

원래 Julia 로 구현해둔 것이 있었는데, 친구가 코드가 안예쁘다고 해서 파이썬으로 다시 짰다.

Julia 구현은 calofmijuck.tistory.com/30 참고!

Disclaimer. 저는 torch, tf 등의 라이브러리에서 autodiff 를 어떻게 하는지 모릅니다.

Autodiff

Autodiff (Automatic Differentiation) 는 주어진 함수의 미분계수를 자동으로 계산하는 방법이다.

예를 들어, 함수 \(f(x) = x^2\) 에 입력 \(x = 2\) 를 주면 미분계수 \(f'(2) = 4\) 를 출력한다.

주로 ML 에서 역전파(backpropagation) 할 때 사용한다고 들었다. Forward 한 번 할 때 미분계수가 자동으로 계산돼서 좋을 것 같다. (잘 모른다 ㅎㅎ)

아이디어 & 계산 과정

아이디어는 backprop 과 비슷하다. 함수가 합, 곱, 상수배 또는 함수 합성으로 표현 혹은 근사될 수 있다면, 연쇄 법칙 (chain rule) 을 적절히 활용하여 복잡한 함수도 간단한 함수로 나누어 계산 과정에서 미분계수를 함께 계산할 수 있게 되는 것이다.

예를 들어 함수 \(f(x) = (2x + 3)^3\) 을 생각해 보자. 이 함수의 계산 과정은 다음처럼 (비교적) 간단한 함수들의 합성으로 나타낼 수 있다.

\[x \xrightarrow{p(x) = 2x} 2x \xrightarrow{q(x) = x + 3} 2x + 3 \xrightarrow{r(x) = x^3} (2x+3)^3\]

그렇다면, \(f\) 의 미분계수는 다음과 같이 연쇄 법칙을 이용하여 구하면 된다.

\[\frac{df}{dx} = \frac{d(r\circ q\circ p)}{dx} = \frac{dr}{dq} \frac{dq}{dp} \frac{dp}{dx}\]

\(x = 1\) 에서의 미분계수를 한번 구해보자. 입력으로 (함숫값, 미분계수) pair 를 넘겨주면 된다. 미분계수의 초기 값은 \(1\) 로 설정한다. 이유는 항등함수 \(id(x) = x\) 의 미분계수가 \(1\) 이기 때문이다.

우선 \((1, 1)\) 을 입력으로 넣어주면, \(\dfrac{dp}{dx} = 2x\) 이므로 \((2, 2)\) 가 출력될 것이다.

이제 이 중간 결과물은 함수 \(q(x)\) 의 입력으로 주어진다. 다만 여기서 연쇄 법칙의 식을 보면 \(\dfrac{dq}{dp}\) 이지 \(\dfrac{dq}{dx}\) 가 아니므로, 최초 입력이었던 \((1, 1)\) 을 사용하는 것이 아니라 위에서 \(p(x)\) 를 한 번 통과한 \((2, 2)\) 를 사용한다. 따라서 함숫값은 \(5\) 가 되고, 미분계수는 \(\dfrac{dq}{dp} = 1\) 이므로 변화 없다. 여기까지의 결과물은 \((5, 2)\) 가 된다.

마찬가지 방법으로 위 결과를 \(r(x)\) 의 입력으로 넘겨주면, 함숫값은 \(125\) 가 되고, 미분계수는 \(\dfrac{dr}{dq} = 3q^2\) 이므로 \(75\) 를 미분계수에 추가로 곱해 \(150\) 을 얻는다.[1] 그러므로 최종 결과물은 \((125, 150)\) 이 된다.

실제로 \(f\) 를 미분해 보면, \(f'(x) = 6(2x+3)^2\) 이므로 \(f'(1) = 150\) 을 얻어 계산이 잘 되었음을 확인할 수 있다.

[1]: \(r(x)\) 에 넘겨주기 전까지의 결과는 \((5, 2)\) 였다. 그러므로 \(\dfrac{dr}{dq} = 3q^2\) 의 \(q\) 에는 \(5\) 를 대입해야 하며, 곱하는 이유는 연쇄 법칙에서 미분계수들을 곱하기 때문이다. 아래 식을 보면 이해하는데 도움이 될지도 모르겠다.

\[\frac{df}{dx}\bigg\lvert_{x=1} = \frac{dr}{dq}\bigg\lvert_{q=q(p(1))=5} \cdot \dfrac{dq}{dp}\bigg\lvert_{p=p(1)=2} \cdot \dfrac{dp}{dx}\bigg\lvert_{x=1}\]

구현

구현할 때는 함숫값과 미분계수를 함께 들고있는 객체를 사용하면 된다. 함숫값을 함께 들고 있는 이유는, 애초에 함숫값 계산 과정에서 함께 미분계수를 계산하고 싶었던 것일 뿐만 아니라, 미분계수 계산시에 함숫값이 필요하기도 하기 때문이다.

다음과 같이 DiffObject 를 정의한다. 함숫값은 x 에, 미분계수는 dx 에 저장하며, 항등함수의 미분계수가 \(1\) 이기 때문에 dx 는 default parameter 로 \(1\) 을 준다. __repr__ 는 적당히 해둔다.

class DiffObject:
    def __init__(self, x, dx=1):
        self.x = x
        self.dx = dx

    def __repr__(self):
        return 'DiffObject(x={}, dx={})'.format(self.x, self.dx)

이제 함숫값 계산 과정에서 미분계수를 함께 계산하도록 각종 연산자를 오버로딩 하면 된다!

우선 가장 쉬운, 두 결과를 합하는 경우부터 구현한다. \(y = f + g\) 라면

\[dy = df + dg\]

이므로, 다음과 같이 구현하면 된다.

def __add__(self, other):
    return self.__class__(self.x + other.x, self.dx + other.dx)

def __radd__(self, other):
    return self.__add__(other)

즉, 함숫값은 더하고 미분계수도 더한다.

곱을 구현하기 위해서는 곱의 미분법 (product rule) 을 사용한다. \(y = f\cdot g\) 라면

\[dy = df \cdot g + f \cdot dg \]

이므로, 다음과 같이 구현한다.

def __mul__(self, other):
    product = self.x * other.x
    dproduct = self.dx * other.x + self.x * other.dx
    return self.__class__(product, dproduct)

상수배

상수배를 처리하기 위해 __rmul__ 을 구현한다. \(y = kf\) (\(k\): 상수) 라면 \(dy = k \cdot df\) 이므로, 다음과 같이 구현한다.

def __rmul__(self, k):
    return self.__class__(k * self.x, k * self.dx)

Data Promotion

다만 위 구현에서 한 가지 문제가 있는데, other 의 타입이 DiffObject 로 제한되지 않는다! 그래서 DiffObject 를 만든 후 + 3, * 2.5 와 같이 int / float 를 연산하게 되면 에러가 발생한다.

위와 같은 경우를 처리하기 위해 int / float 타입과 연산하는 경우 값을 DiffObject 로 변환해야 한다.

def promote(self, other):
    if type(other) in (int, float):
        other = self.__class__(other, 0)
    return other

other 의 타입이 intfloat 이면 DiffObject 로 변환하여 돌려준다. 상수이므로 dx=0 으로 설정해야 한다.

중간 점검

잘 동작하는지 테스트 해본다. \(f(x) = 2x^2+3\) 에 대해서 구해보자. 다만 아직 거듭제곱은 구현하지 않은 상태이므로 x * x 로 적어야 한다.

def f(x):
    return 2 * x * x + 3

x = DiffObject(2)
print(f(x)) # DiffObject(x=5, dx=4)

단항 연산자와 차

단항 연산자는 +, - 이렇게 2개가 있다. 이는 상수배로 표현할 수 있으므로 다음과 같이 구현하면 된다.

def __pos__(self):
    return self.__rmul__(1)

def __neg__(self):
    return self.__rmul__(-1)

그리고 차의 경우 \(-1\) 을 곱하고 더하는 것이므로 다음과 같이 구현한다. __rsub__ 의 경우 promote 이후 __sub__ 를 호출한다.

def __sub__(self, other):
    return self.__add__(-other)

def __rsub__(self, other):
    other = self.promote(other)
    return other.__sub__(self)

몫을 구현하기 위해서는 몫의 미분법 (quotient rule) 을 사용한다. \(y = f/g\) 라면

\[dy = \frac{df \cdot g - f \cdot dg}{g^2}\]

이므로, 다음과 같이 구현한다.

def __truediv__(self, other):
    other = self.promote(other)
    ddiv = (self.dx * other.x - self.x * other.dx) / (other.x ** 2)
    return self.__class__(self.x / other.x, ddiv)

def __rtruediv__(self, other):
    other = self.promote(other)
    return other.__truediv__(self)

거듭제곱

마지막으로 거듭제곱을 구현하기 위해서는 연쇄 법칙을 사용한다. (ㅋㅋ) \(y = f^g\) 라면

\[dy = \frac{\partial y}{\partial f} df + \frac{\partial y}{\partial g}dg =gf^{g-1} \cdot df + f^g \log f \cdot dg\]

이므로, 다음과 같이 구현한다.

def __pow__(self, other):
    other = self.promote(other)
    dexp = other.x * self.x ** (other.x - 1) * self.dx + \
        self.x ** other.x * log(self.x) * other.dx
    return self.__class__(self.x ** other.x, dexp)

def __rpow__(self, other):
    other = self.promote(other)
    return other.__pow__(self)

끝! 이 정도면 어느 정도 커버 가능하다.

테스트

이제 다양한 함수들을 만들어서 잘 동작하는지 확인해보자!

import numpy as np
from diff_object import DiffObject


def poly1(x):
    return x - 2


def poly2(x):
    return -x ** 2


def poly3(x):
    return 2 * x ** 2 + 3 * x


def poly4(x):
    return 3 * x ** 4 - 3 * x ** 2 + x - 2


def quotient1(x):
    return 1 / x


def quotient2(x):
    return (x ** 2 + 1) / (x - 1)


def irrational1(x):
    return x ** 0.5


def irrational2(x):
    return 1 / (x ** 2 - 1) ** (1 / 3)


def mySin(x):
    return x - x ** 3 / 6 + x ** 5 / 120


def myCos(x):
    return 1 - x ** 2 / 2 + x ** 4 / 24

def myTan(x):
    return mySin(x) / myCos(x)

def exp(x):
    return np.e ** x

def sigmoid(x):
    return 1 / (1 + np.e ** (-x))

x = DiffObject(2)

print(poly1(x)) # DiffObject(x=0, dx=1)
print(poly2(x)) # DiffObject(x=-4, dx=-4.0)
print(poly3(x)) # DiffObject(x=14, dx=11.0)
print(poly4(x)) # DiffObject(x=36, dx=85.0)
print(quotient1(x)) # DiffObject(x=0.5, dx=-0.25)
print(quotient2(x)) # DiffObject(x=5.0, dx=-1.0)
print(irrational1(x)) # DiffObject(x=1.4142135623730951, dx=0.3535533905932738)
print(irrational2(x)) # DiffObject(x=0.6933612743506348, dx=-0.3081605663780598)

y = DiffObject(0.1)

print(mySin(y)) # DiffObject(x=0.09983341666666667, dx=0.9950041666666667)
print(myCos(y)) # DiffObject(x=0.9950041666666667, dx=-0.09983333333333334)
print(myTan(y)) # DiffObject(x=0.10033467196536028, dx=1.0100670379951928)

print(exp(x)) # DiffObject(x=7.3890560989306495, dx=7.3890560989306495)
print(sigmoid(x)) # DiffObject(x=0.8807970779778823, dx=0.1049935854035065)

def complicated1(x):
    return (-x ** 2 + 1) * mySin(x)

def complicated2(x):
    return exp(x) * myCos(x)

print(complicated1(y)) # DiffObject(x=0.0988350825, dx=0.9650874416666667)
print(complicated2(y)) # DiffObject(x=1.0996496683640948, dx=0.9893167717095427)

# Compositions

print(exp(myCos(y))) # DiffObject(x=2.704735610987686, dx=-0.2700227718302707)
print(mySin(myCos(y))) # DiffObject(x=0.8389502013815642, dx=-0.0544913894766819)

아주 잘 된다! (WolframAlpha 로 확인했다 ㅋㅋ)

복잡한 함수를 비롯하여 합성 함수까지도 처리되는 것을 확인할 수 있다!

TODO

  • np.sin 등을 비롯한 함수에 대해서 동작하지 않는데 한 번 이유 찾아보기
  • 이렇게 하면 행렬이나 벡터에 대해서도 처리가 되는지 확인해보기

후기

실제로 라이브러리들이 autodiff 를 어떻게 처리하는지 몰라서, 라이브러리 등의 실제 구현과는 다를 수 있다.

특별히 참고한 문서가 있는 것도 아니고 그냥 이렇게 하면 될 것 같아서 내 생각대로, 의식의 흐름대로 짠 코드라 굉장히 비효율적일 수 있다.

기회가 되면 실제 라이브러리들의 코드도 한 번 파헤쳐 볼 생각이다!

댓글