Я навчаю нейронну мережу для свого проекту за допомогою Keras. Keras забезпечив функцію ранньої зупинки. Чи можу я знати, які параметри слід дотримуватись, щоб уникнути перенапруження нейронної мережі за допомогою ранньої зупинки?
Я навчаю нейронну мережу для свого проекту за допомогою Keras. Keras забезпечив функцію ранньої зупинки. Чи можу я знати, які параметри слід дотримуватись, щоб уникнути перенапруження нейронної мережі за допомогою ранньої зупинки?
Відповіді:
Рання зупинка - це, в основному, припинення навчання, як тільки ваша втрата починає збільшуватися (або іншими словами, точність перевірки починає зменшуватися). Згідно з документами він використовується наступним чином;
keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=0,
verbose=0, mode='auto')
Значення залежать від вашої реалізації (проблема, розмір партії тощо ...), але загалом для запобігання переобладнанню я б використав;
monitor
аргумент на 'val_loss'.min_delta- це поріг для того, щоб кількісно визначити втрату в якусь епоху як поліпшення чи ні. Якщо різниця у збитках нижча min_delta, це кількісно визначається як відсутність покращення. Краще залишити це як 0, оскільки нас цікавить, коли збиток стає гіршим.patienceаргумент представляє кількість епох до зупинки, як тільки ваша втрата починає збільшуватися (перестає покращуватися). Це залежить від вашої реалізації, якщо ви використовуєте дуже маленькі партії
або велику швидкість навчання, ваша втрата зигзагоподібна (точність буде більш галасливою), тому краще встановіть великий patienceаргумент. Якщо ви використовуєте великі партії та малу швидкість навчання, ваші втрати будуть більш плавними, тому ви можете використовувати менший patienceаргумент. У будь-якому випадку я залишу це як 2, щоб я дав моделі більше шансів.verbose вирішує, що друкувати, залиште це за замовчуванням (0).modeаргумент залежить від того, в якому напрямку знаходиться ваша відстежувана кількість (чи має вона зменшуватися чи збільшуватися), оскільки ми контролюємо втрати, ми можемо використати min. Але залишмо кери обробляти це для нас і встановимо це наautoТому я б використав щось подібне та експериментував, будуючи графік втрати помилки з попередньою зупинкою та без неї.
keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=2,
verbose=0, mode='auto')
Для можливої неоднозначності щодо роботи зворотних викликів я спробую пояснити більше. Після того, як ви зателефонуєте fit(... callbacks=[es])на свою модель, Keras викликає задані об'єкти зворотного виклику заздалегідь визначеними функціями. Ці функції можуть бути викликані on_train_begin, on_train_end, on_epoch_begin, on_epoch_endі on_batch_begin, on_batch_end. Рання зупинка зворотного виклику викликається на кожному кінці епохи, порівнюється найкраще відстежуване значення з поточним і зупиняється, якщо виконуються умови (скільки епох минуло з моменту спостереження найкращого відстежуваного значення і чи це більше аргументу терпіння, різниця між останнє значення більше, ніж min_delta тощо.).
Як зазначає @BrentFaust в коментарях, навчання моделі триватиме доти, доки не будуть дотримані умови раннього зупинення або не буде виконано epochsпараметр (за замовчуванням = 10) fit(). Встановлення зворотного виклику з достроковим зупиненням не призведе до того, що модель тренується за межами своїх epochsпараметрів. Тож fit()функція виклику з більшим epochsзначенням отримає більше вигоди від дострокового зупинки зворотного дзвінка.
callbacks=[EarlyStopping(patience=2)]не має жодного ефекту, якщо лише епохи не надані model.fit(..., epochs=max_epochs).
epoch=1у циклі for (для різних випадків використання), коли цей зворотний виклик не вдався. Якщо у моїй відповіді є двозначність, я спробую викласти це краще.
restore_best_weightsаргумент (поки що не в документації), який завантажує модель з найкращими вагами після тренування. Але для ваших цілей я використав би ModelCheckpointзворотний виклик з save_best_onlyаргументом. Ви можете перевірити документацію, її просто використовувати, але вам потрібно вручну завантажувати найкращі ваги після тренувань.
min_delta- це поріг для того, щоб кількісно оцінити зміну відстежуваного значення як поліпшення чи ні. Так що, якщо ми дамо,monitor = 'val_loss'тоді це стосуватиметься різниці між поточними втратами перевірки та попередніми втратами перевірки. На практиці, якщо ви даєтеmin_delta=0.1зменшення втрат під час перевірки (поточні - попередні), менші за 0,1, це не буде кількісно, таким чином, ви припините навчання (якщо у вас єpatience = 0).