在线文档教程

A Tool Developer's Guide to TensorFlow Model Files(工具开发者指南:TensorFlow模型文件)

TensorFlow模型文件的工具开发人员指南

大多数用户不需要关心TensorFlow如何将数据存储在磁盘上的内部细节,但如果您是工具开发人员,则可能会这样做。例如,您可能需要分析模型,或者在TensorFlow和其他格式之间来回转换。本指南试图解释一些关于如何使用保存模型数据的主要文件的细节,以便更容易地开发这些类型的工具。

协议缓冲区

所有TensorFlow的文件格式都基于协议缓冲区,所以开始值得熟悉它们的工作方式。总结是你在文本文件中定义数据结构,protobuf工具以C,Python和其他语言生成类,这些语言可以以友好的方式加载,保存和访问数据。我们经常将Protocol Buffers称为protobufs,我将在本指南中使用该约定。

GraphDef

TensorFlow中计算的基础是Graph对象。它拥有一个节点网络,每个节点代表一个操作,相互连接为输入和输出。创建Graph对象后,可以通过调用将其保存as_graph_def(),该GraphDef对象将返回一个对象。

GraphDef类是由定义在tensorflow / core / framework / graph.proto中的ProtoBuf库创建的对象。protobuf工具分析这个文本文件,并生成加载,存储和操作图形定义的代码。如果您看到代表模型的独立TensorFlow文件,则可能包含GraphDef由protobuf代码保存的这些对象之一的序列化版本。

此生成的代码用于保存和加载磁盘中的GraphDef文件。实际加载模型的代码如下所示:

graph_def = graph_pb2.GraphDef()

该行创建一个空GraphDef对象,该对象是从graph.proto中的文本定义创建的类。这是我们要用我们文件中的数据填充的对象。

with open(FLAGS.graph, "rb") as f:

这里我们得到了一个文件句柄,用于我们传递给脚本的路径

if FLAGS.input_binary: graph_def.ParseFromString(f.read()) else: text_format.Merge(f.read(), graph_def)

文字还是二进制?

实际上有两种不同的格式可以保存ProtoBuf。TextFormat是一种人类可读的格式,它可以很好地进行调试和编辑,但是当数字数据(比如存储在其中的重量)时可以变大。你可以在graph_run_run2.pbtxt中看到一个小例子。

二进制格式文件比它们的文本等值文件小很多,尽管它们对我们来说不够可读。在这个脚本中,我们要求用户提供一个标志,指示输入文件是二进制还是文本,所以我们知道正确的调用函数。您可以在inception_v3存档中找到一个大型二进制文件的示例,如inception_v3_2016_08_28_frozen.pb

API本身可能有点令人困惑 - 二进制调用实际上是ParseFromString(),而您使用text_format模块中的实用程序函数来加载文本文件。

节点

一旦你将一个文件加载到graph_def变量中,你现在可以访问它里面的数据。对于大多数实际用途,重要部分是存储在节点成员中的节点列表。以下是循环访问的代码:

for node in graph_def.node

每个节点都是一个NodeDef对象,在tensorflow / core / framework / node_def.proto中定义。这些是TensorFlow图形的基本组成部分,每个图形都定义了一个操作以及其输入连接。这里是NodeDef的成员,以及它们的含义。

name

每个节点都应该有一个唯一的标识符,图中任何其他节点都不会使用该标识符。如果您在使用Python API构建图表时没有指定一个图表,则会为您选择一个反映操作名称的图表,例如“MatMul”,并与一个单调递增的数字(如“5”)相连接。在定义节点之间的连接时以及在运行时为整个图设置输入和输出时使用该名称。

op

这定义执行何种操作,例如"Add""MatMul""Conv2D"。运行图形时,会在注册表中查找该op名称以查找实现。注册表由对REGISTER_OP()宏的调用填充,如tensorflow / core / ops / nn_ops.cc中的调用。

input

字符串列表,每个字符串都是另一个节点的名称,可以选择后跟冒号和输出端口号。例如,一个有两个输入的节点可能有一个类似于["some_node_name", "another_node_name"](这相当于)的列表["some_node_name:0", "another_node_name:0"],并将该节点的第一个输入定义为具有该名称的节点的第一个输出,并将第一个输入定义为具有该名称的节点"some_node_name"的第一个输出"another_node_name"

device

在大多数情况下,您可以忽略它,因为它定义了在分布式环境中运行节点的位置,或者要将操作强制到CPU或GPU上。

attr

这是一个包含节点所有属性的关键/值存储。这些是节点的永久属性,在运行时不会改变的东西,例如卷积过滤器的大小,或常量操作的值。因为可以有很多不同类型的属性值,从字符串到整数,到张量值数组,在tensorflow / core / framework / attr_value.proto中有一个单独的protobuf文件来定义保存它们的数据结构。

每个属性都有唯一的名称字符串,并且在定义操作时会列出预期的属性。如果属性不存在于节点中,但其操作定义中列出了默认值,则在创建图形时使用该默认值。

你可以通过调用node.name来访问Python中的所有这些成员,node.op等。存储在节点中的节点列表GraphDef是模型体系结构的完整定义。

冷冻

一个令人困惑的部分是,在训练过程中权重通常不会存储在文件格式中。相反,它们被保存在单独的检查点文件中,并且Variable图表中有ops在初始化时加载最新值。在部署到生产环境中时,单独的文件通常不是很方便,因此有一个freeze_graph.py脚本,它接受图形定义和一组检查点,并将它们冻结成一个文件。

这样做是加载GraphDef,从最新的检查点文件中提取所有变量的值,然后将每个Variable操作替换Const为具有存储在其属性中权重的数字数据的操作。然后,将所有无关的节点用于前向推理,并将结果GraphDef保存到输出文件中。

重量格式

如果您正在处理代表神经网络的TensorFlow模型,最常见的问题之一是提取和解释重量值。存储它们的常用方法(例如,在由freeze_graph脚本创建的图表中)与Const包含权重as的ops一样Tensors。这些在tensorflow / core / framework / tensor.proto中定义,并包含有关数据大小和类型的信息以及值本身。在Python中,你可以通过调用类似的方法TensorProtoNodeDef代表Constop中获得一个对象some_node_def.attr['value'].tensor

这会给你一个表示权重数据的对象。数据本身将存储在其中一个带有后缀_val的列表中,如对象类型所示,例如float_val对于32位浮点数据类型。

在不同的框架之间转换时,卷积权值的排序通常很难处理。在TensorFlow中,操作的Conv2D的滤波器权重存储在第二个输入中,并且预期为顺序[filter_height, filter_width, input_depth, output_depth],其中filter_count增加1意味着移动到内存中的相邻值。

希望这个概要能够让你更好地了解TensorFlow模型文件中发生了什么,并且如果你需要操纵它们,它将会对你有所帮助。