Найкращий спосіб зберегти навчену модель в PyTorch?


193

Я шукав альтернативні способи збереження навченої моделі в PyTorch. Поки що я знайшов дві альтернативи.

  1. torch.save () для збереження моделі та torch.load () для завантаження моделі.
  2. model.state_dict () для збереження навченої моделі та model.load_state_dict () для завантаження збереженої моделі.

Я натрапив на цю дискусію, де рекомендується підхід 2 над підходом 1.

Моє запитання: чому перевагу надається другому підходу? Це лише тому, що модулі torch.nn мають ці дві функції, і нам рекомендується їх використовувати?


2
Я думаю, це тому, що torch.save () також зберігає всі проміжні змінні, як проміжні виходи для використання зворотного розповсюдження. Але вам потрібно лише зберегти параметри моделі, наприклад вагу / зміщення тощо. Іноді перший може бути набагато більшим, ніж другий.
Давей Ян

2
Я тестував torch.save(model, f)і torch.save(model.state_dict(), f). Збережені файли мають однаковий розмір. Тепер я розгублений. Крім того, я знайшов використання соління для збереження model.state_dict () надзвичайно повільним. Я думаю, що найкращим способом є використання, torch.save(model.state_dict(), f)оскільки ви керуєтесь створенням моделі, а факел справляється з завантаженням ваг моделі, тим самим усуваючи можливі проблеми. Довідка: обговорити.pytorch.org/t/saving-torch-models/838/4
Yang

Схоже, PyTorch вирішив це дещо виразніше у своєму розділі навчальних посібників - там є багато корисної інформації, яка не вказана у відповідях тут, включаючи збереження більше однієї моделі одночасно та теплі стартові моделі.
whlteXbread

що не так з використанням pickle?
Чарлі Паркер

1
@CharlieParker torch.save заснований на соління. Далі йде підручник, зв'язаний вище: "[torch.save] збереже весь модуль, використовуючи модуль підбору Python. Недоліком такого підходу є те, що серіалізовані дані прив'язані до конкретних класів та точної структури каталогів, що використовується при моделі Причина цього полягає в тому, що pickle не зберігає сам клас моделі. Скоріше, він зберігає шлях до файлу, що містить клас, який використовується під час завантаження. Через це ваш код може розбиватися різними способами, коли використовується в інших проектах або після рефакторів. "
Девід Міллер

Відповіді:


215

Я знайшов цю сторінку на їхньому github repo, я просто вставлю тут вміст.


Рекомендований підхід для збереження моделі

Існує два основні підходи до серіалізації та відновлення моделі.

Перший (рекомендований) зберігає та завантажує лише параметри моделі:

torch.save(the_model.state_dict(), PATH)

Потім пізніше:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

Друга економить і завантажує всю модель:

torch.save(the_model, PATH)

Потім пізніше:

the_model = torch.load(PATH)

Однак у цьому випадку серіалізовані дані прив'язані до конкретних класів та точної використовуваної структури каталогів, тому вони можуть порушуватися різними способами при використанні в інших проектах або після деяких серйозних рефакторів.


8
Згідно з @smth дискусією.pytorch.org/ t/ saving- and- loading- a- model- in- pytorch/… модель перезавантажується для тренування моделі за замовчуванням. тому після завантаження потрібно вручну викликати the_model.eval (), якщо ви завантажуєте його для висновку, а не продовжуючи навчання.
WillZ

другий метод дає stackoverflow.com/questions/53798009/… помилку в Windows 10. не вдалося її вирішити
Gulzar

Чи є можливість зберегти без доступу до класу моделі?
Майкл Д

З таким підходом, як ви відстежуєте * args та ** kwargs, які вам потрібно пройти для випадку завантаження?
Маріано Камп

що не так з використанням pickle?
Чарлі Паркер

144

Це залежить від того, що ви хочете зробити.

Випадок №1: Збережіть модель, щоб використовувати її самостійно для висновку : Ви зберігаєте модель, відновлюєте її, а потім змінюєте модель на режим оцінки. Це робиться тому , що ви , як правило, BatchNormі Dropoutшари , які за замовчуванням знаходяться в режимі поїзди на будівництво:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Випадок №2: Збережіть модель, щоб відновити навчання пізніше : Якщо вам потрібно продовжувати навчати модель, яку ви збираєтеся зберегти, вам потрібно зберегти більше, ніж просто модель. Вам також потрібно зберегти стан оптимізатора, епохи, оцінка тощо. Ви зробите це так:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Щоб відновити навчання, ви зробили б такі речі, як: state = torch.load(filepath)а потім відновити стан кожного окремого об'єкта:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Оскільки ви поновлюєте навчання, НЕ дзвоніть, model.eval()коли ви відновите стани під час завантаження.

Випадок №3: Модель, яку повинен використовувати хтось інший, не маючи доступу до вашого коду : У Tensorflow ви можете створити .pbфайл, який визначає і архітектуру, і ваги моделі. Це дуже зручно, особливо при використанні Tensorflow serve. Еквівалентним способом зробити це в Pytorch було б:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Цей спосіб все ще не є кулезахисним, і оскільки pytorch все ще зазнає багатьох змін, я б не рекомендував його.


1
Чи є рекомендований файл, який закінчується для 3-х справ? Або це завжди .pth?
Верена Хауншмід

1
У справі №3 torch.loadповертається просто OrdersDict. Як ви отримуєте модель для того, щоб робити прогнози?
Alber8295

Привіт, чи можу я знати, як зробити згаданий "Випадок №2: Збережіть модель, щоб пізніше відновити навчання"? Мені вдалося завантажити контрольну точку на модель, тоді я не зміг запустити або відновити тренування моделі типу "model.to (пристрій) модель = train_model_epoch (модель, критерій, оптимізатор, графік, епохи)"
день

1
Привіт, для випадку, який призначений для висновку, в офіційному документі pytorch говорять, що необхідно зберегти optim_er state_dict або для виводу, або для завершення навчання. "Під час збереження загальної контрольної точки, яка буде використовуватися для виводу або відновлення навчання, ви повинні зберегти більше, ніж просто стан_виконання моделі. Важливо також зберегти state_dict оптимізатора, оскільки це містить буфери та параметри, які оновлюються як модель поїздів. . "
Мухаммед Ауні

1
У випадку №3 десь слід визначити клас моделі.
Майкл Д

12

Маринована бібліотека Python реалізує виконавчі протоколи для сериализации і де-сериализации об'єкта Python.

Коли ви import torch(або коли ви використовуєте PyTorch), це стане import pickleдля вас, і вам не потрібно дзвонити pickle.dump()і pickle.load()безпосередньо, що є методами збереження та завантаження об'єкта.

Насправді, torch.save()і torch.load()буде обгортати pickle.dump()і pickle.load()для вас.

state_dictЗгадується інший відповідь заслуговує лише кілька нот.

Що state_dictми маємо всередині PyTorch? Насправді два state_dictс.

Модель PyTorch - torch.nn.Moduleце model.parameters()дзвінок, щоб отримати параметри, що засвоюються (w і b). Ці параметри, що навчаються, колись випадково встановлені, з часом буде оновлюватися з часом. Параметри, що вивчаються, є першими state_dict.

Другий state_dict- дикта про стан оптимізатора. Ви пам'ятаєте, що оптимізатор використовується для покращення наших навчальних параметрів. Але оптимізатор state_dictвиправлений. Нічого навчитися там.

Оскільки state_dictоб'єкти є словниками Python, їх можна легко зберігати, оновлювати, змінювати та відновлювати, додаючи велику модульність моделям та оптимізаторам PyTorch.

Створимо надзвичайно просту модель для пояснення цього:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Цей код виведе наступне:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Зауважте, це мінімальна модель. Ви можете спробувати додати стек послідовних

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Зауважте, що лише шари з вивченими параметрами (згорткові шари, лінійні шари тощо) та зареєстровані буфери (шари batchnorm) мають записи у моделях state_dict.

Речі, які не вивчаються, належать до об'єкта оптимізатора state_dict, який містить інформацію про стан оптимізатора, а також використовувані гіперпараметри.

Решта історії та сама; у фазі умовиводу (це фаза, коли ми використовуємо модель після тренувань) для прогнозування; ми робимо прогноз на основі вивчених нами параметрів. Тож для висновку нам просто потрібно зберегти параметри model.state_dict().

torch.save(model.state_dict(), filepath)

І використовувати пізніший model.load_state_dict (torch.load (filepath)) model.eval ()

Примітка. Не забувайте останній рядок, який model.eval()є надзвичайно важливим після завантаження моделі.

Також не намагайтеся зберегти torch.save(model.parameters(), filepath). model.parameters()Тільки об'єкт генератора.

З іншого боку, torch.save(model, filepath)зберігає сам об’єкт моделі, але майте на увазі, що модель не має оптимізатора state_dict. Перевірте іншу чудову відповідь від @Jadiel de Armas, щоб зберегти диктат про стан оптимізатора.


Хоча це не є прямолінійним рішенням, суть проблеми глибоко аналізується! Оновлення
Джейсон Янг

7

Поширена умова PyTorch - це збереження моделей, використовуючи розширення файлу .pt або .pth.

Зберегти / завантажити цілу модель Зберегти:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Завантажте:

Клас моделі потрібно десь визначити

model = torch.load(PATH)
model.eval()

4

Якщо ви хочете зберегти модель і хочете відновити навчання пізніше:

Єдиний GPU: Зберегти:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Завантажте:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Кілька GPU: Зберегти

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Завантажте:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
Використовуючи наш веб-сайт, ви визнаєте, що прочитали та зрозуміли наші Політику щодо файлів cookie та Політику конфіденційності.
Licensed under cc by-sa 3.0 with attribution required.