ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 오차역전파법 - 덧셈 계층, 곱셈 계층
    딥러닝 2020. 1. 27. 00:39

    덧셈 계층과 곱셈 계층을 구현해봅시다.

    각각 순전파와 역전파 모두 구현해야겠죠.

     

    우선 덧셈 계층부터 봅시다.

    순전파는 그냥 덧셈해주면 됩니다.

    역전파는 각 입력에 대해 미분을 해주고 그 값을 곱해주면 됩니다.(합성함수 미분의 원리)

    두 개 모두 미분값이 1이군요.

     

    구현해봅시다.

    class AddLayer:
        def __init__(self):
            pass
    	
        #순전파
        def forward(self, x, y):
            return x + y
    	
        #역전파
        def backward(self, dout):
            dx = dout * 1
            dy = dout * 1
            return dx, dy

     

    덧셈 계층은 역전파에서 순전파의 값을 이용하지 않습니다.

    그래서 초기화가 필요없습니다.(pass : 아무것도 하지 말라)

     

     

     

    곱셈 계층을 봅시다.

    순전파는 그냥 곱해주면 됩니다.

    역전파는 각 입력에 대해 미분을 해주고 그 값을 곱해주면 됩니다.(합성함수 미분의 원리)

    미분값은 다른 입력값이군요.

    입력값이 서로 바뀌어 곱해지는 꼴입니다.

     

    구현해봅시다.

    class MulLayer:
        def __init__(self):
            self.x = None
            self.y = None
    
        def forward(self, x, y):
            self.x = x
            self.y = y
            return x*y
    
        def backward(self, dout):
            dx = dout * self.y
            dy = dout * self.x
            return dx, dy
    

     

     

     

    역전파에서 순전파의 값이 필요합니다.

    그래서 초기화에서 순전파값 저장을 위한 변수를 만듭니다.

     

     

     

    그럼 전 포스팅에서 본 이 예시를 구현해봅시다.

     

     

    apple = 100
    orange = 150
    apple_num = 2
    orange_num = 3
    tax = 1.1
    
    #계층들
    mul_apple_layer = MulLayer()
    mul_orange_layer = MulLayer()
    add_apple_orange_layer = AddLayer()
    mul_tax_layer = MulLayer()
    
    #순전파
    apple_price = mul_apple_layer.forward(apple, apple_num)
    orange_price = mul_orange_layer.forward(orange, orange_num)
    all_price = add_apple_orange_layer.forward(apple_price, orange_price)
    price = mul_tax_layer.forward(all_price, tax)
    
    #역전파
    dprice = 1
    dall_price, dtax = mul_tax_layer.backward(dprice)
    
    dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price)
    
    dapple, dapple_num = mul_apple_layer.backward(dapple_price)
    
    dorange, dorange_num = mul_orange_layer.backward(dorange_price)
    
    print(price) # 715
    print(dapple_num, dapple, dorange, dorange_num, dtax) # 110 2.2 3.3 165 650

    참고로

    -역전파의 리턴값은 2개 입니다.

    -순전파에서 순전파의 값이 저장됩니다.

     

    이것을 보고 우리가 알 수 있는 사실은

    순전파로 오른쪽 끝에 도달한 후, 역전파로 왼쪽 끝에 되돌아오면 모든 매개변수의 미분값을 알 수 있다는 것입니다.

     

     

     

     

     

     

    [참고]

    -밑바닥부터 시작하는 딥러닝

    -https://excelsior-cjh.tistory.com/171

Designed by Tistory.