Як сказати Керасу припинити тренування на основі втрат?


82

В даний час я використовую такий код:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

Це говорить Керасу припинити тренування, коли втрата не покращилася протягом 2 епох. Але я хочу припинити тренування після того, як втрата стала меншою за деякий постійний "THR":

if val_loss < THR:
    break

Я бачив у документації, що є можливість зробити власний зворотний дзвінок: http://keras.io/callbacks/ Але нічого не знайдено, як зупинити навчальний процес. Мені потрібна порада.

Відповіді:


85

Я знайшов відповідь. Я заглянув у джерела Keras і з’ясував код для EarlyStopping. Я зробив власний зворотний дзвінок, спираючись на нього:

class EarlyStoppingByLossVal(Callback):
    def __init__(self, monitor='val_loss', value=0.00001, verbose=0):
        super(Callback, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)

        if current < self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

І використання:

callbacks = [
    EarlyStoppingByLossVal(monitor='val_loss', value=0.00001, verbose=1),
    # EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

1
Просто якщо це комусь буде корисно - у моєму випадку я використовував monitor = 'loss', він спрацював добре.
QtRoS

15
Здається, Keras оновлено. Функція зворотного виклику EarlyStopping має вбудовану функцію min_delta. Не потрібно більше зламати вихідний код, так! stackoverflow.com/a/41459368/3345375
jkdev

3
Перечитавши запитання та відповіді, мені потрібно виправитись: min_delta означає "Зупинитись рано, якщо недостатньо покращень за епоху (або за кілька епох)". Однак ОП запитав, як "Зупинитись достроково, коли втрата опуститься нижче певного рівня".
jkdev

NameError: ім'я 'Callback' не визначено ... Як це виправити?
alyssaeliyah

2
Елія спробуй це: from keras.callbacks import Callback
ZFTurbo

26

Зворотний виклик keras.callbacks.EarlyStopping має аргумент min_delta. З документації Keras:

min_delta: мінімальна зміна відстежуваної кількості, яка може кваліфікуватися як поліпшення, тобто абсолютна зміна менше min_delta, не буде враховуватися як покращення.


3
Для довідки, ось документи для попередньої версії Keras (1.1.0), в яких аргумент min_delta ще не був включений: faroit.github.io/keras-docs/1.1.0/callbacks/#earlystopping
jkdev

як я міг зробити так, щоб він не зупинявся, поки не min_deltaтриватиме кілька епох?
zyxue

є ще один параметр EarlyStopping, який називається терпінням: кількість епох без покращення, після яких навчання буде припинено.
Девін

13

Одним із рішень є виклик model.fit(nb_epoch=1, ...)всередині циклу for, тоді ви можете помістити оператор break всередину циклу for і виконати будь-який інший користувацький потік управління, який ви хочете.


Було б непогано, якби вони зробили зворотний дзвінок, який приймає одну функцію, яка може це зробити.
Чесність

8

Я вирішив ту ж проблему, використовуючи користувальницький зворотний виклик.

У наступному користувацькому коді зворотного виклику призначте THR із значенням, при якому ви хочете припинити навчання, і додайте зворотний виклик до своєї моделі.

from keras.callbacks import Callback

class stopAtLossValue(Callback):

        def on_batch_end(self, batch, logs={}):
            THR = 0.03 #Assign THR with the value at which you want to stop training.
            if logs.get('loss') <= THR:
                 self.model.stop_training = True

2

Поки я приймав TensorFlow на спеціалізації , я навчився дуже елегантній техніці. Тільки трохи змінено з прийнятої відповіді.

Давайте подамо приклад з нашими улюбленими даними MNIST.

import tensorflow as tf

class new_callback(tf.keras.callbacks.Callback):
    def epoch_end(self, epoch, logs={}): 
        if(logs.get('accuracy')> 0.90): # select the accuracy
            print("\n !!! 90% accuracy, no further training !!!")
            self.model.stop_training = True

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 #normalize

callbacks = new_callback()

# model = tf.keras.models.Sequential([# define your model here])

model.compile(optimizer=tf.optimizers.Adam(),
          loss='sparse_categorical_crossentropy',
          metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

Отже, тут я встановив metrics=['accuracy'], і, отже, у класі зворотного виклику умова встановлена ​​на 'accuracy'> 0.90.

Ви можете вибрати будь-яку метрику та відстежувати навчання, як у цьому прикладі. Найголовніше, що ви можете встановити різні умови для різних метрик і використовувати їх одночасно.

Сподіваємось, це допомагає!


1
ім'я функції має бути on_epoch_end
xarion

0

Для мене модель зупинила б навчання лише в тому випадку, якщо б я додав оператор return після встановлення параметра stop_training значення True, оскільки я закликав після self.model.evaluate. Тож обов’язково поставте stop_training = True в кінці функції або додайте оператор return.

def on_epoch_end(self, batch, logs):
        self.epoch += 1
        self.stoppingCounter += 1
        print('\nstopping counter \n',self.stoppingCounter)

        #Stop training if there hasn't been any improvement in 'Patience' epochs
        if self.stoppingCounter >= self.patience:
            self.model.stop_training = True
            return

        # Test on additional set if there is one
        if self.testingOnAdditionalSet:
            evaluation = self.model.evaluate(self.val2X, self.val2Y, verbose=0)
            self.validationLoss2.append(evaluation[0])
            self.validationAcc2.append(evaluation[1])enter code here

0

Якщо ви використовуєте спеціальний цикл тренувань, ви можете використовувати a collections.deque, який є "рухомим" списком, який можна додати, а ліві елементи вискакують, коли список довший за maxlen. Ось рядок:

loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())
    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history)
            break

Ось повний приклад:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Dense
from collections import deque

data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True)

data = data.map(lambda x, y: (tf.cast(x, tf.int32), y))

train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)

model = tf.keras.models.Sequential([
    Dense(8, activation='relu'),
    Dense(16, activation='relu'),
    Dense(info.features['label'].num_classes)])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        loss = loss_object(labels, logits)

    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, logits)


@tf.function
def test_step(inputs, labels):
    logits = model(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)


def fit(epoch):
    template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \
               'Train Acc {:.2f} Test Acc {:.2f}'

    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result()
    ))


def main(epochs=50, early_stopping=10):
    loss_history = deque(maxlen=early_stopping + 1)

    for epoch in range(epochs):
        fit(epoch)
        loss_history.append(test_loss.result().numpy())
        if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break

if __name__ == '__main__':
    main(epochs=250, early_stopping=10)
Epoch  1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch  2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch  3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch  4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch  5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87

Early stopping. No validation loss improvement in 10 epochs.
Використовуючи наш веб-сайт, ви визнаєте, що прочитали та зрозуміли наші Політику щодо файлів cookie та Політику конфіденційності.
Licensed under cc by-sa 3.0 with attribution required.