Які параметри слід використовувати для ранньої зупинки?


97

Я навчаю нейронну мережу для свого проекту за допомогою Keras. Keras забезпечив функцію ранньої зупинки. Чи можу я знати, які параметри слід дотримуватись, щоб уникнути перенапруження нейронної мережі за допомогою ранньої зупинки?

Відповіді:


157

рання зупинка

Рання зупинка - це, в основному, припинення навчання, як тільки ваша втрата починає збільшуватися (або іншими словами, точність перевірки починає зменшуватися). Згідно з документами він використовується наступним чином;

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=0,
                              verbose=0, mode='auto')

Значення залежать від вашої реалізації (проблема, розмір партії тощо ...), але загалом для запобігання переобладнанню я б використав;

  1. Відстежуйте втрати при валідації (потрібно використовувати перехресну перевірку або принаймні тренувальні / тестові набори), встановивши monitor аргумент на 'val_loss'.
  2. min_delta- це поріг для того, щоб кількісно визначити втрату в якусь епоху як поліпшення чи ні. Якщо різниця у збитках нижча min_delta, це кількісно визначається як відсутність покращення. Краще залишити це як 0, оскільки нас цікавить, коли збиток стає гіршим.
  3. patienceаргумент представляє кількість епох до зупинки, як тільки ваша втрата починає збільшуватися (перестає покращуватися). Це залежить від вашої реалізації, якщо ви використовуєте дуже маленькі партії або велику швидкість навчання, ваша втрата зигзагоподібна (точність буде більш галасливою), тому краще встановіть великий patienceаргумент. Якщо ви використовуєте великі партії та малу швидкість навчання, ваші втрати будуть більш плавними, тому ви можете використовувати менший patienceаргумент. У будь-якому випадку я залишу це як 2, щоб я дав моделі більше шансів.
  4. verbose вирішує, що друкувати, залиште це за замовчуванням (0).
  5. modeаргумент залежить від того, в якому напрямку знаходиться ваша відстежувана кількість (чи має вона зменшуватися чи збільшуватися), оскільки ми контролюємо втрати, ми можемо використати min. Але залишмо кери обробляти це для нас і встановимо це наauto

Тому я б використав щось подібне та експериментував, будуючи графік втрати помилки з попередньою зупинкою та без неї.

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=2,
                              verbose=0, mode='auto')

Для можливої ​​неоднозначності щодо роботи зворотних викликів я спробую пояснити більше. Після того, як ви зателефонуєте fit(... callbacks=[es])на свою модель, Keras викликає задані об'єкти зворотного виклику заздалегідь визначеними функціями. Ці функції можуть бути викликані on_train_begin, on_train_end, on_epoch_begin, on_epoch_endі on_batch_begin, on_batch_end. Рання зупинка зворотного виклику викликається на кожному кінці епохи, порівнюється найкраще відстежуване значення з поточним і зупиняється, якщо виконуються умови (скільки епох минуло з моменту спостереження найкращого відстежуваного значення і чи це більше аргументу терпіння, різниця між останнє значення більше, ніж min_delta тощо.).

Як зазначає @BrentFaust в коментарях, навчання моделі триватиме доти, доки не будуть дотримані умови раннього зупинення або не буде виконано epochsпараметр (за замовчуванням = 10) fit(). Встановлення зворотного виклику з достроковим зупиненням не призведе до того, що модель тренується за межами своїх epochsпараметрів. Тож fit()функція виклику з більшим epochsзначенням отримає більше вигоди від дострокового зупинки зворотного дзвінка.


3
@AizuddinAzman close min_delta- це поріг для того, щоб кількісно оцінити зміну відстежуваного значення як поліпшення чи ні. Так що, якщо ми дамо, monitor = 'val_loss'тоді це стосуватиметься різниці між поточними втратами перевірки та попередніми втратами перевірки. На практиці, якщо ви даєте min_delta=0.1зменшення втрат під час перевірки (поточні - попередні), менші за 0,1, це не буде кількісно, ​​таким чином, ви припините навчання (якщо у вас є patience = 0).
умутто

3
Зверніть увагу, що callbacks=[EarlyStopping(patience=2)]не має жодного ефекту, якщо лише епохи не надані model.fit(..., epochs=max_epochs).
Brent Faust,

1
@BrentFaust Це також я розумію, я написав відповідь, припускаючи, що модель навчається щонайменше 10 епох (за замовчуванням). Після Вашого коментаря я зрозумів, що може бути випадок, коли програміст викликає fit epoch=1у циклі for (для різних випадків використання), коли цей зворотний виклик не вдався. Якщо у моїй відповіді є двозначність, я спробую викласти це краще.
умуто

4
@AdmiralWen З моменту написання відповіді код трохи змінився. Якщо ви використовуєте останню версію Keras, ви можете використовувати restore_best_weightsаргумент (поки що не в документації), який завантажує модель з найкращими вагами після тренування. Але для ваших цілей я використав би ModelCheckpointзворотний виклик з save_best_onlyаргументом. Ви можете перевірити документацію, її просто використовувати, але вам потрібно вручну завантажувати найкращі ваги після тренувань.
умутто

1
@umutto Привіт, дякую за пропозицію відновити_бест_ваги, однак я не можу її використовувати, `es = EarlyStopping (monitor = 'val_acc', min_delta = 1e-4, patience = patience_, verbose = 1, restore_best_weights = True) TypeError: __init __ () отримав несподіваний аргумент ключового слова 'restore_best_weights' '. Якісь ідеї? keras 2.2.2, tf, 1.10 яка ваша версія?
Харамоз
Використовуючи наш веб-сайт, ви визнаєте, що прочитали та зрозуміли наші Політику щодо файлів cookie та Політику конфіденційності.
Licensed under cc by-sa 3.0 with attribution required.