Я навчаю нейронну мережу для свого проекту за допомогою Keras. Keras забезпечив функцію ранньої зупинки. Чи можу я знати, які параметри слід дотримуватись, щоб уникнути перенапруження нейронної мережі за допомогою ранньої зупинки?
Я навчаю нейронну мережу для свого проекту за допомогою Keras. Keras забезпечив функцію ранньої зупинки. Чи можу я знати, які параметри слід дотримуватись, щоб уникнути перенапруження нейронної мережі за допомогою ранньої зупинки?
Відповіді:
Рання зупинка - це, в основному, припинення навчання, як тільки ваша втрата починає збільшуватися (або іншими словами, точність перевірки починає зменшуватися). Згідно з документами він використовується наступним чином;
keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=0,
verbose=0, mode='auto')
Значення залежать від вашої реалізації (проблема, розмір партії тощо ...), але загалом для запобігання переобладнанню я б використав;
monitor
аргумент на 'val_loss'
.min_delta
- це поріг для того, щоб кількісно визначити втрату в якусь епоху як поліпшення чи ні. Якщо різниця у збитках нижча min_delta
, це кількісно визначається як відсутність покращення. Краще залишити це як 0, оскільки нас цікавить, коли збиток стає гіршим.patience
аргумент представляє кількість епох до зупинки, як тільки ваша втрата починає збільшуватися (перестає покращуватися). Це залежить від вашої реалізації, якщо ви використовуєте дуже маленькі партії
або велику швидкість навчання, ваша втрата зигзагоподібна (точність буде більш галасливою), тому краще встановіть великий patience
аргумент. Якщо ви використовуєте великі партії та малу швидкість навчання, ваші втрати будуть більш плавними, тому ви можете використовувати менший patience
аргумент. У будь-якому випадку я залишу це як 2, щоб я дав моделі більше шансів.verbose
вирішує, що друкувати, залиште це за замовчуванням (0).mode
аргумент залежить від того, в якому напрямку знаходиться ваша відстежувана кількість (чи має вона зменшуватися чи збільшуватися), оскільки ми контролюємо втрати, ми можемо використати min
. Але залишмо кери обробляти це для нас і встановимо це наauto
Тому я б використав щось подібне та експериментував, будуючи графік втрати помилки з попередньою зупинкою та без неї.
keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=2,
verbose=0, mode='auto')
Для можливої неоднозначності щодо роботи зворотних викликів я спробую пояснити більше. Після того, як ви зателефонуєте fit(... callbacks=[es])
на свою модель, Keras викликає задані об'єкти зворотного виклику заздалегідь визначеними функціями. Ці функції можуть бути викликані on_train_begin
, on_train_end
, on_epoch_begin
, on_epoch_end
і on_batch_begin
, on_batch_end
. Рання зупинка зворотного виклику викликається на кожному кінці епохи, порівнюється найкраще відстежуване значення з поточним і зупиняється, якщо виконуються умови (скільки епох минуло з моменту спостереження найкращого відстежуваного значення і чи це більше аргументу терпіння, різниця між останнє значення більше, ніж min_delta тощо.).
Як зазначає @BrentFaust в коментарях, навчання моделі триватиме доти, доки не будуть дотримані умови раннього зупинення або не буде виконано epochs
параметр (за замовчуванням = 10) fit()
. Встановлення зворотного виклику з достроковим зупиненням не призведе до того, що модель тренується за межами своїх epochs
параметрів. Тож fit()
функція виклику з більшим epochs
значенням отримає більше вигоди від дострокового зупинки зворотного дзвінка.
callbacks=[EarlyStopping(patience=2)]
не має жодного ефекту, якщо лише епохи не надані model.fit(..., epochs=max_epochs)
.
epoch=1
у циклі for (для різних випадків використання), коли цей зворотний виклик не вдався. Якщо у моїй відповіді є двозначність, я спробую викласти це краще.
restore_best_weights
аргумент (поки що не в документації), який завантажує модель з найкращими вагами після тренування. Але для ваших цілей я використав би ModelCheckpoint
зворотний виклик з save_best_only
аргументом. Ви можете перевірити документацію, її просто використовувати, але вам потрібно вручну завантажувати найкращі ваги після тренувань.
min_delta
- це поріг для того, щоб кількісно оцінити зміну відстежуваного значення як поліпшення чи ні. Так що, якщо ми дамо,monitor = 'val_loss'
тоді це стосуватиметься різниці між поточними втратами перевірки та попередніми втратами перевірки. На практиці, якщо ви даєтеmin_delta=0.1
зменшення втрат під час перевірки (поточні - попередні), менші за 0,1, це не буде кількісно, таким чином, ви припините навчання (якщо у вас єpatience = 0
).