TensorFlow збереження в / завантаження графіка з файлу


98

З того, що я зібрав до цих пір, існує кілька різних способів скидання графіка TensorFlow у файл і потім завантаження його в іншу програму, але я не зміг знайти чітких прикладів / інформації про те, як вони працюють. Я вже знаю це:

  1. Збережіть змінні моделі у файл контрольної точки (.ckpt) за допомогою a tf.train.Saver()та відновіть їх пізніше ( джерело )
  2. Збережіть модель у .pb-файл та завантажте її знову за допомогою tf.train.write_graph()та tf.import_graph_def()( source )
  3. Завантажте модель з .pb-файлу, перепідготовте її та скиньте її в новий .pb-файл за допомогою Bazel ( джерело )
  4. Заморожте графік, щоб зберегти графік та ваги разом ( джерело )
  5. Використовуйте as_graph_def()для збереження моделі, а для ваг / змінних наносіть їх на константи ( джерело )

Однак мені не вдалося прояснити кілька питань стосовно цих різних методів:

  1. Що стосується файлів контрольних точок, чи зберігають вони лише навчені ваги моделі? Чи можуть файли контрольної точки завантажуватися в нову програму та використовуватися для запуску моделі, чи вони просто служать способом збереження ваг у моделі в певний час / етап?
  2. Що стосується того tf.train.write_graph(), чи зберігаються також ваги / змінні?
  3. Що стосується Bazel, чи може вона зберігати / завантажувати лише з .pb файлів для перепідготовки? Чи є проста команда Bazel просто скинути графік у .pb?
  4. Що стосується заморожування, чи можна завантажувати заморожений графік за допомогою tf.import_graph_def()?
  5. Демо-версія Android для TensorFlow завантажується в модель Inception Google з файлу .pb. Якби я хотів замінити власний .pb файл, як би я це зробив? Чи потрібно мені змінити будь-який нативний код / ​​методи?
  6. Загалом, у чому саме різниця між усіма цими методами? Або, ширше, в чому різниця між as_graph_def()/.ckpt/.pb?

Якщо коротко, то, що я шукаю, це метод збереження як графіка (як, у різних операціях і подібних), так і його ваг / змінних у файл, який потім може бути використаний для завантаження графіка та ваг в іншу програму , для використання (не обов’язково продовження / перекваліфікація).

Документація на цю тему не є однозначною, тому будь-які відповіді / інформація будуть дуже вдячні.


2
Найновіший / найповніший API - це мета-графік, який дає вам можливість зберегти всі три одночасно - 1) графік 2) значення параметрів 3) колекції: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Ярослав Булатов

Відповіді:


80

Існує багато способів підходу до проблеми збереження моделі в TensorFlow, що може зробити її трохи заплутаною. Виконуючи по черзі кожне з ваших підзапитів:

  1. Файли контрольних точок (наприклад , проводиться шляхом виклику saver.save()на tf.train.Saverоб'єкт) містять тільки ваги, і будь-які інші змінні , визначені в одній і тій же програмі. Щоб використовувати їх в іншій програмі, ви повинні заново створити пов’язану структуру графа (наприклад, запустівши код, щоб знову створити її, або зателефонувавши tf.import_graph_def()), яка підкаже TensorFlow, що робити з цими вагами. Зауважте, що виклик saver.save()також створює файл, що містить a MetaGraphDef, який містить графік та детальну інформацію про те, як пов’язати ваги з контрольної точки з цим графіком. Докладніше див. У підручнику .

  2. tf.train.write_graph()записує лише структуру графа; не ваги.

  3. Базель не пов'язаний з читанням або написанням графіків TensorFlow. (Можливо, я неправильно розумію ваше запитання: сміливо уточнюйте це у коментарі.)

  4. Заморожений графік можна завантажити за допомогою tf.import_graph_def(). У цьому випадку ваги (як правило) вбудовуються у графік, тому не потрібно завантажувати окрему контрольну точку.

  5. Основна зміна полягала б у тому, щоб оновити імена тензора (ів), які подаються в модель, та імена тензорів (ив), які отримані з моделі. У демке TensorFlow Android, це буде відповідати inputNameі outputNameрядки, які передаються TensorFlowClassifier.initializeTensorFlow().

  6. Це GraphDefструктура програми, яка зазвичай не змінюється в процесі навчання. Контрольна точка - це знімок стану тренувального процесу, який зазвичай змінюється на кожному етапі навчального процесу. Як результат, TensorFlow використовує різні формати зберігання даних для цих типів даних, а API низького рівня забезпечує різні способи їх збереження та завантаження. Бібліотеки більш високого рівня, такі як MetaGraphDefбібліотеки, Keras і skflow на основі цих механізмів , щоб забезпечити більш зручні способи збереження і відновлення цілої моделі.


Чи означає це, що документація API C ++ лежить, коли вона говорить про те, що ви можете завантажити збережений графік tf.train.write_graph()і потім виконати його?
mnicky

2
Документація API C ++ не бреше, але в ній відсутні кілька деталей. Найважливіша деталь полягає в тому, що, крім GraphDefзбережених tf.train.write_graph(), вам також потрібно пам’ятати імена тензорів, які ви хочете подати та отримати під час виконання графіка (пункт 5 вище).
mrry

@mrry: Я намагався використовувати приклад DeensDream tensorflows. але, схоже, йому потрібні перевірені моделі у форматі pb! Я запустив приклад Cifar10, але він створює лише контрольні точки! Я не міг знайти будь-яких файлів pb чи що б там не було! як я можу конвертувати свої контрольні точки у формат pb, який використовує приклад deepdream?
Ріка

2
@ Coderx7 Я дійсно думаю, що ви не можете перетворити .ckpt в .pb, оскільки контрольна точка містить лише ваги та змінні і нічого не знає про структуру графіка
davidivad

1
чи є простий код для завантаження .pb-файлу та його запуску?
Конг

1

Ви можете спробувати наступний код:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Використовуючи наш веб-сайт, ви визнаєте, що прочитали та зрозуміли наші Політику щодо файлів cookie та Політику конфіденційності.
Licensed under cc by-sa 3.0 with attribution required.