Чому нам потрібно викликати zero_grad () у PyTorch?


Відповіді:


144

У PyTorch, нам потрібно встановити градієнти на нуль, перш ніж починати робити зворотне розповсюдження, оскільки PyTorch накопичує градієнти при наступних зворотних проходах. Це зручно під час навчання RNN. Отже, типовою дією є накопичення (тобто підсумовування) градієнтів кожного loss.backward()виклику.

Через це, починаючи цикл тренувань, в ідеалі ви повинні zero out the gradientsзробити так, щоб правильно виконувати оновлення параметрів. В іншому випадку градієнт вказував би в іншому напрямку, крім передбачуваного, до мінімуму (або максимуму , у разі цілей максимізації).

Ось простий приклад:

import torch
from torch.autograd import Variable
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all Variables 
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()
    optimizer.step()

Як варіант, якщо ви робите ванільний градієнтний спуск , тоді:

W = Variable(torch.randn(4, 3), requires_grad=True)
b = Variable(torch.randn(3), requires_grad=True)

for sample, target in zip(data, targets):
    # clear out the gradients of Variables 
    # (i.e. W, b)
    W.grad.data.zero_()
    b.grad.data.zero_()

    output = linear_model(sample, W, b)
    loss = (output - target) ** 2
    loss.backward()

    W -= learning_rate * W.grad.data
    b -= learning_rate * b.grad.data

Примітка : Накопичення (тобто сума ) градієнтів відбувається, коли .backward()його викликають на lossтензор .


3
велике спасибі, це справді корисно! Ви випадково знаєте, чи має поведінку тензорпотік?
layser

Просто щоб бути впевненим .. якщо ви цього не зробите, то зіткнетеся з проблемою вибуху градієнта, так?
zwep

2
@zwep Якщо ми накопичуємо градієнти, це не означає, що їх величина зростає: прикладом може бути, якщо знак градієнта постійно гортається. Тож це не буде гарантувати, що ви зіткнетеся з проблемою вибуху градієнта. Крім того, вибухові градієнти існують, навіть якщо ви правильно обнулите.
Том Рот,

Коли ви запускаєте ванільний градієнтний спуск, чи не виникає помилка "змінна листа, яка вимагає, щоб град використовувався в операції на місці", коли ви намагаєтесь оновити ваги?
MUAS

1

zero_grad () - це перезапуск циклу без втрат з останнього кроку, якщо ви використовуєте метод градієнта для зменшення помилки (або втрат)

якщо ви не використовуєте zero_grad (), втрата буде зменшуватися, а не збільшуватись відповідно до вимог

наприклад, якщо ви використовуєте zero_grad (), ви знайдете такий результат:

model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2

якщо ви не використовуєте zero_grad (), ви знайдете такий результат:

model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5
Використовуючи наш веб-сайт, ви визнаєте, що прочитали та зрозуміли наші Політику щодо файлів cookie та Політику конфіденційності.
Licensed under cc by-sa 3.0 with attribution required.