Cơ bản về Backpropagation

Nội dung bài viết này được tham khảo chủ yếu từ: https://leonardoaraujosantos.gitbooks.io/artificial-inteligence/content/backpropagation.html

1. Giới thiệu

    Backpropagation (Truyền ngược) là một thuật toán mà ta rất hay gặp trong các mô hình mạng học sâu (Deep Learning), nó tính toán đạo hàm thành phần phần trên các nút của mô hình (Ví dụ: Convnet, Neural Network). Các đạo hàm thành phần này được sử dụng trong suốt quá trình huấn luyện mạng. Để hiểu rõ hơn các định nghĩa và các cơ chế đi kèm, bạn đọc có thể tham khảo tại: https://en.wikipedia.org/wiki/Backpropagation. Trong bài viết này, chúng ta sẽ cùng nhau xem xét cách thực hiện backpropagation trực quan và đơn giản nhất.
    Để thực hiện backpropagation một cách đơn giản,  ta sẽ biểu diễn mô hình như một đồ thị tính toán. Sau đó, ta sẽ tính forward propagation (Truyền xuôi) và đạo hàm trên mỗi block (khối).


2. Các block cơ bản

    Các hình dưới đây minh họa cách tính forward và backward propagation trên block cơ bản như add (cộng), nhân (multiply), exp, và max. Đường màu đỏ thể hiện forward và đường màu xanh thể hiện backforward. Chúng ta sẽ tuân theo và áp dụng các quy tắc tính toán này cho các đồ thị tính toán phức tạp hơn.





Một số đạo hàm khác: 

Để ý rằng chúng ta có 2 đạo hàm thành phần (gradient) bởi vì chúng ta có 2 input (đầu vào). Cũng lưu ý rằng chúng ta cần phải lưu các input trước đó vào bộ nhớ cache.

3. Quy tắc dây chuyền (Chain rule)

    Giả sử rằng ta có một output y, là một hàm g, g lại là hàm hợp của f, f là một hàm của x. Nếu ta muốn biết g sẽ thay đổi như thế nào nếu như có một sự thay đổi nhỏ ở dx (dg/dx), ta sẽ sử dụng quy tắc dây chuyền. Đó là một quy tắc để tính toán đạo hàm của một hàm là hàm hợp của hai hay nhiều hàm khác.



    Để hiểu rõ hơn về quy tắc dây chuyền, ta hãy cùng xem hai hình bên dưới. Hình bên trái biểu diễn một nút f(x, y) sẽ tính một hàm f với 2 input x và y và cho ra một output (đầu ra) z = f(x, y). Ở hình bên phải, chúng ta có một nút tương tự nhận được đạo hàm dL/dz từ một hàm L nào đó (ví dụ hàm mất mát - loss function) với ý nghĩa: "L sẽ thay đổi như thế nào nếu có một sự thay đổi nhỏ ở z ?". Do nút có 2 input nên nó sẽ có 2 đạo hàm thành phần tương ứng. Một đạo hàm thể hiện sự thay đổi của L khi có một sự thay đổi nhỏ dx và đạo hàm còn lại tương ứng với sự thay đổi nhỏ dy.


    Để tính toán các đạo hàm thành phần, chúng ta cần đạo hàm dL/dz (dout), đạo hàm thành phần của hàm f(x, y) và sau đó ta nhân chúng lại với nhau. Chúng ta cũng cần bộ nhớ cache của input trước đó, được lưu trong suốt quá trình forward propagation.

4. Thực hiện với Python

    Trong phần này chúng ta sẽ thực hiện tính forward và backpropagation với 2 hàm đơn giản là nhân và cộng.



5. Các ví dụ

    Với những kiến thức cơ bản vừa trình bày, chúng ta hãy cùng tính đạo hàm thành phần của một số đồ thị.

5.1. Ví dụ đơn giản

    Dưới đây ta có đồ thị của hàm f(x, y, z) = (x + y).z


1. Bắt đầu ở nút output f, và giả sử đạo hàm của f liên quan đến một số tiêu chí là 1 (dout).
2. dq = dout(1) * z = 1 * (-4) = -4. 
3. dz = dout(1) * q = 1 * 3 = 3.
4. Ở cổng sum, như đã trình bày ở mục 2, ta có dx = dy = dq = -4.

5.2. Perceptron với 2 input

    Đồ thị dưới đây biểu diễn forward và backpropagation của một neural network đơn giản với 2 input và 1 output layer với hàm kích hoạt (activation function) là hàm sigmoid.




1. Bắt đầu ở nút ouput, giả sử rằng dout = 1.
2. Đạo hàm tại input của 1/x là (-1/(1.37^2) ) * 1 = -0.53.
3. Nút "+1" không làm thay đổi đạo hàm tại input của nó, vì vậy nó bằng -0.53 * 1 = -0.53.
4. Nút exp có đạo hàm tại input là exp(cached input) * -0.53 = exp(-1) * -0.53 = -0.2.
5. Nút "*-1" có đạo hàm tại input là (-1 * -0.2) = 0.2.
6. Nút sum sẽ gồm có 2 thành phần đạo hàm, dw2 = 0.2 và nút sum ở thành phần còn lại có đạo hàm 0.2.
7. Tương tự như 6, nút sum có 2 thành phần đạo hàm bằng 0.2.
8. dw0 = (0.2 * -1) = -0.2, dx0 = (0.2 * 2) = 0.4.
9. dw1 = (0.2 * -2) = -0.4, dx1 = (0.2 * -3) = 0.6.

    Hi vọng với những kiến thức đã trình bày trên đây, bạn đọc đã có cái nhìn cơ bản nhất về backpropagation cũng như cách tính nó trong những đồ thị tính toán phức tạp hơn. Chúng ta sẽ còn gặp lại nó trong rất nhiều những bài viết về các mô hình học sâu mà tôi sẽ trình bày trong những bài sắp tới. 


Comments

Popular posts from this blog

Intersection over Union (IoU) cho object detection

Giới thiệu về Generative Adversarial Networks (GANs)

Inception modules