[밑바닥딥러닝] 9. 오차역전파법(backpropagation) - 계산그래프

2021. 10. 3. 17:56Deep Learning

본 게시글은 한빛미디어 『밑바닥부터 시작하는 딥러닝, 사이토 고키, 2020』의 내용을 참조하였음을 밝힙니다.

 

 

 지난 장에서는 예측 모델에서 예측한 결과와 실제 값과의 차이를 통해 손실값을 계산하고 

 

계산된 손실 값에서 가중치(및 편향)에 대해서 (가중치마다) 수치미분을 수행해 이를 가중치 갱신에 

 

반영하여 학습시키는 방법에 대해서 알아보았다. 

 

 하지만 이는 시간 복잡도 측면에서 매우 비효율적인 방법이므로 이번 장에서

 

오차역전파법(Backpropagation)에 대해서 알아보도록 하겠다. 

 

 

 

계산 그래프


 오차역전파에 대해 알아보기 앞서서 계산 그래프가 무엇인지 살펴보고 계산 그래프를 사용했을 때의 

 

이점에 대해서 알아보도록 하겠다. 

 

개당 가격이 a원인 사과 n개를 주문하고 이에 대한 소비세가 t %라고 할 때 이를 수식으로 나타내면 

 

최종 가격 = (a * n) * (1 + t/100) 로 나타낼 수 있다. 

 

계산 그래프(Computatinal graph)란 계산 과정에서 나타나는 연산의 대상(사과, 세금 등)과 대상에 수행하는 연산을

 

그래프의 형태로 도식화한 그래프이다. 아래의 그림은 위에 서술한 연산 과정을 계산 그래프로 나타낸 것이다.

계산 그래프를 통해 계산한 최종 결제 가격

 계산 그래프를 사용하는 것처럼 코드 내에서 각각의 연산과 그 대상들을 모듈화하여 구현하면 

 

해당 모듈들을 단순히 조합하는 것으로 순전파와 역전파를 간단하게 구현할 수 있다. 

 

계산 그래프에서는 합성함수의 미분 등 계산이 복잡해질 때 '국소적 계산'을 함으로써 용이하게 연산 수행이 

 

가능해진다. (합성 함수 미분에 대한 자세한 설명은 아래를 참조하도록 하자.)

https://en.wikipedia.org/wiki/Chain_rule

 

Chain rule - Wikipedia

From Wikipedia, the free encyclopedia Jump to navigation Jump to search Formula for the derivative of composed functions In calculus, the chain rule is a formula that expresses the derivative of the composition of two differentiable functions f and g in te

en.wikipedia.org

 

연산을 계산 그래프화하지 않았다면, 다시 말해 연산을 모듈화해두지 않았다면 

 

역전파 시 최종 출력으로부터 전해져오는 역전파부터 모두 고려해야하지만 계산 그래프에서는 

 

전해져오는 값이 무엇이든 그에 관계없이 자신 노드에서의 역전파 연산만 수행하면된다.

 

https://www.tutorialspoint.com/python_deep_learning/python_deep_learning_computational_graphs.htm 

 

Computational Graphs

Computational Graphs Backpropagation is implemented in deep learning frameworks like Tensorflow, Torch, Theano, etc., by using computational graphs. More significantly, understanding back propagation on computational graphs combines several different algor

www.tutorialspoint.com

 

계산 그래프 덧셈 노드, 곱셈 노드의 순전파와 역전파에 대한 기본적인 설명은 위의 글을 참조하도록 하자. 

 

요약하자면 덧셈 노드의 역전파는 흘러온 역전파를 단순히 1을 곱하는 방식으로 그대로 흘려보내고 

 

곱셈 노드의 역전파에서는 흘려온 역전파 값에 자신의 곱셈 상대였던 것을 곱해주면 된다. 

 

계산 그래프에서의 역전파 수행

계산 그래프에서 최종 출력인 z에서부터 역전파는 시작된다. (지금부터 빨간 글씨에 주목하자)

 

z를 z 자신에 대해서 미분하면 dz/dz = 1이 된다. (z는 변하는 정도가 z 자신에게 그대로 반영되기 때문이다)

 

그 다음으로 마주치는 것이 곱셈 노드인데,  세금이 반영되기 이전, 다시말해 tax값과 곱셈을 수행하기 이전인 

 

전체과일가격(total_fruit_price)에 대해서 z를 미분할 경우(dz/dtotal_fruit_price) 흘러온 역전파 값인 1에 

 

자신의 곱셈 상대인 tax (1.1)이 곱해진다. tax에 대해 미분한 값도 마찬가지고 역전파에 전체 과일 가격이 곱해진다. 

 

그 다음으로 만나는 것이 덧셈 노드인데, 덧셈 노드에서는 역전파 시 1을 곱한다. (앞선 역전파 값 그대로 흘려보냄)

 

이를 코드로 구현해보자. 

class AddLayer:
    def __init__(self):
        pass

    def forward(self, x, y):
        return x+y

    def backward(self, dout):
        return dout

덧셈 노드(AddLayer)는 순전파 시 덧셈 대상 2개를 서로 더해서 반환하고, 역전파 시 흘러온 역전파를 그대로 보낸다. 

 

class MulLayer:
    def __init__(self):
        self.x = None
        self.y = None

    def forward(self, x, y):
        self.x = x
        self.y = y  ### 역전파를 위해 두 값을 인스턴스 변수에 저장
        out = x * y

        return out

    def backward(self, dout):
        dx = dout * self.y
        dy = dout * self.x
        return dx, dy

곱셈 노드(MulLayer)는 순전파시 곱셈 대상 2개를 곱한 값을 반환하고, 역전파 시 앞선 역전파 값에 곱셈 대상을 서로

 

바꿔 곱하여 각각 반환한다. 초기화 함수(init)에서 곱셈 대상인 x, y를 인스턴스 변수에 저장해두는데, 이는 역전파를 

 

계산할 때 순전파 시 수행했던 연산 대상들을 기억하기 위한 목적이다. 

    apple_price = 100
    apple_num = 4
    mandarin_price = 70
    mandarin_num = 3
    tax = 1.1
    
    mul_apple_layer = MulLayer()
    mul_mandarin_layer = MulLayer()
    add_price_Layer = AddLayer()
    mul_tax_layer = MulLayer()

    total_apple_price = mul_apple_layer.forward(apple_price, apple_num)
    total_mandarin_price = mul_mandarin_layer.forward(mandarin_price, mandarin_num)
    total_fruit_price = add_price_Layer.forward(total_apple_price, total_mandarin_price)
    final_price = mul_tax_layer.forward(total_fruit_price, tax)

    print(final_price)
    
    결과 >>> 671.0

 연산 대상들을 변수로 선언 및 정의해주고, 곱셈노드 3개와 덧셈노드 1개를 선언하여 순전파를 수행하였다.

 

물론 결과는 계산 그래프에 나온 것과 일치한다.

 

다음은 역전파 수행이다.

    d_final_price = 1
    d_total_fruit_price, d_tax = mul_tax_layer.backward(d_final_price)
    d_total_apple_price = add_price_Layer.backward(d_total_fruit_price)
    d_total_mandarin_price = add_price_Layer.backward(d_total_fruit_price)
    d_apple_price, d_apple_num = mul_apple_layer.backward(d_total_apple_price)
    d_mandarin_price, d_mandarin_num = mul_mandarin_layer.backward(d_total_mandarin_price)

앞에서부터 순서대로 노드들을 순회하면서 역전파를 수행하였다. 

 

    d_var_list = [d_apple_price, d_apple_num, d_mandarin_price, d_mandarin_num, d_total_apple_price\
                  ,d_total_mandarin_price, d_total_fruit_price, d_tax]

    for i in range(len(d_var_list)):
        print(f"{d_var_list[i]}")
    
    결과 >>>
    4.4
    110.00000000000001
    3.3000000000000003
    77.0
    1.1
    1.1
    1.1
    610

계산 그래프를 통해 살펴본 역전파 값들과 일치한다. 

 

지금까지 계산그래프를 통한 연산 방법(순전파)과 역전파하는 방법에 대해 알아보았다. 

 

다음 장에서는 신경망 내에서 오차역전파를 수행하는 방법에 대해 알아보도록 하자.