Маринована бібліотека Python реалізує виконавчі протоколи для сериализации і де-сериализации об'єкта Python.
Коли ви import torch
(або коли ви використовуєте PyTorch), це стане import pickle
для вас, і вам не потрібно дзвонити pickle.dump()
і pickle.load()
безпосередньо, що є методами збереження та завантаження об'єкта.
Насправді, torch.save()
і torch.load()
буде обгортати pickle.dump()
і pickle.load()
для вас.
state_dict
Згадується інший відповідь заслуговує лише кілька нот.
Що state_dict
ми маємо всередині PyTorch? Насправді два state_dict
с.
Модель PyTorch - torch.nn.Module
це model.parameters()
дзвінок, щоб отримати параметри, що засвоюються (w і b). Ці параметри, що навчаються, колись випадково встановлені, з часом буде оновлюватися з часом. Параметри, що вивчаються, є першими state_dict
.
Другий state_dict
- дикта про стан оптимізатора. Ви пам'ятаєте, що оптимізатор використовується для покращення наших навчальних параметрів. Але оптимізатор state_dict
виправлений. Нічого навчитися там.
Оскільки state_dict
об'єкти є словниками Python, їх можна легко зберігати, оновлювати, змінювати та відновлювати, додаючи велику модульність моделям та оптимізаторам PyTorch.
Створимо надзвичайно просту модель для пояснення цього:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Цей код виведе наступне:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
Зауважте, це мінімальна модель. Ви можете спробувати додати стек послідовних
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Зауважте, що лише шари з вивченими параметрами (згорткові шари, лінійні шари тощо) та зареєстровані буфери (шари batchnorm) мають записи у моделях state_dict
.
Речі, які не вивчаються, належать до об'єкта оптимізатора state_dict
, який містить інформацію про стан оптимізатора, а також використовувані гіперпараметри.
Решта історії та сама; у фазі умовиводу (це фаза, коли ми використовуємо модель після тренувань) для прогнозування; ми робимо прогноз на основі вивчених нами параметрів. Тож для висновку нам просто потрібно зберегти параметри model.state_dict()
.
torch.save(model.state_dict(), filepath)
І використовувати пізніший model.load_state_dict (torch.load (filepath)) model.eval ()
Примітка. Не забувайте останній рядок, який model.eval()
є надзвичайно важливим після завантаження моделі.
Також не намагайтеся зберегти torch.save(model.parameters(), filepath)
. model.parameters()
Тільки об'єкт генератора.
З іншого боку, torch.save(model, filepath)
зберігає сам об’єкт моделі, але майте на увазі, що модель не має оптимізатора state_dict
. Перевірте іншу чудову відповідь від @Jadiel de Armas, щоб зберегти диктат про стан оптимізатора.