Custom Data Readers(自定义数据读取器)
自定义数据读取器
先决条件:
- 熟悉C ++。
- 必须下载TensorFlow源代码,并且能够构建它。
我们将支持文件格式的任务分成两部分:
- 文件格式:我们使用
Reader
Op 从文件中读取记录
(可以是任何字符串)。
- 记录格式:我们使用解码器或解析Ops将一个字符串记录转换为TensorFlow可用的张量。
例如,要读取CSV文件,我们使用Reader作为文本文件,然后使用Op来分析一行文本中的CSV数据。
为文件格式编写Reader
Reader
是从文件读取记录的东西。TensorFlow中已经内置了一些Reader
Ops的例子:
tf.TFRecordReader
(来源于kernels/tf_record_reader_op.cc
)
tf.FixedLengthRecordReader
(来源于kernels/fixed_length_record_reader_op.cc
)
tf.TextLineReader
(来源于kernels/text_line_reader_op.cc
)
你可以看到这些都暴露了相同的接口,唯一的区别在于它们的构造函数。最重要的方法是read
。它需要一个队列参数,这是它获取文件名以从需要时读取的文件名(例如,当read
op首次运行时,或前read
一次从文件读取最后一个记录时)。它产生两个标量张量:一个字符串键和一个字符串值。
要创建一个新的读者SomeReader
,你需要:
1. 在C ++中,定义一个tensorflow::ReaderBase
被调用的子类SomeReader
。
2. 在C ++中,用名称注册一个新的读取器操作系统和内核"SomeReader"
。
3. 在Python中,定义一个tf.ReaderBase
被调用的子类SomeReader
。
你可以把所有的C ++代码放在一个文件tensorflow/core/user_ops/some_reader_op.cc
中。读取文件的代码将存放在C ++ ReaderBase
类的后代中,C ++ 类定义在后者中tensorflow/core/kernels/reader_base.h
。您将需要实施以下方法:
OnWorkStartedLocked
:打开下一个文件
ReadLocked
:读取记录或报告EOF /错误
OnWorkFinishedLocked
:关闭当前文件,并
ResetLocked
:在例如错误之后得到干净的平板
这些方法的名称以“Locked”结尾,因为ReaderBase
在调用这些方法之前确保获得互斥体,所以您通常不必担心线程安全性(尽管只保护类的成员,而不是全局状态) 。
对于OnWorkStartedLocked
要打开的文件的名称是该current_work()
方法返回的值。ReadLocked
有这样的签名:
Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)
如果ReadLocked
成功从文件中读取记录,则应填写:
*key
:带有记录的标识符,人可以用来再次查找该记录。你可以包含文件名current_work()
,并附加一个记录号码或其他。
*value
:与记录的内容。
*produced
:设为true
。
如果您点击文件末尾(EOF),请设置*at_end
为true
。无论哪种情况,都会返回Status::OK()
。如果出现错误,只需使用其中一个辅助函数即可返回它,tensorflow/core/lib/core/errors.h
而无需修改任何参数。
接下来,您将创建实际的Reader操作。如果您熟悉添加操作方法,这将有所帮助。主要步骤是:
- 注册操作。
- 定义并注册一个
OpKernel
。
要注册该操作,您将使用在中REGISTER_OP
定义的呼叫tensorflow/core/framework/op.h
。读者操作系统从不接受任何输入,并且始终只有一个带有类型的输出resource
。他们应该有字符串container
和shared_name
attrs。您可以选择定义额外的attrs进行配置或在文档中包含一个Doc
。例如,请参阅tensorflow/core/ops/io_ops.cc
:例如:
#include "tensorflow/core/framework/op.h"
REGISTER_OP("TextLineReader")
.Output("reader_handle: resource")
.Attr("skip_header_lines: int = 0")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
A Reader that outputs the lines of a file delimited by '\n'.
)doc"
要定义一个OpKernel
,读者可以使用降序的快捷方式ReaderOpKernel
,定义tensorflow/core/framework/reader_op_kernel.h
和实现调用的构造函数SetReaderFactory
。定义你的课程后,你需要使用注册REGISTER_KERNEL_BUILDER(...)
。没有attrs的例子:
#include "tensorflow/core/framework/reader_op_kernel.h"
class TFRecordReaderOp : public ReaderOpKernel {
public:
explicit TFRecordReaderOp(OpKernelConstruction* context)
: ReaderOpKernel(context) {
Env* env = context->env(
SetReaderFactory([this, env]() { return new TFRecordReader(name(), env }
}
};
REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
TFRecordReaderOp
有attrs的一个例子:
#include "tensorflow/core/framework/reader_op_kernel.h"
class TextLineReaderOp : public ReaderOpKernel {
public:
explicit TextLineReaderOp(OpKernelConstruction* context)
: ReaderOpKernel(context) {
int skip_header_lines = -1;
OP_REQUIRES_OK(context,
context->GetAttr("skip_header_lines", &skip_header_lines)
OP_REQUIRES(context, skip_header_lines >= 0,
errors::InvalidArgument("skip_header_lines must be >= 0 not ",
skip_header_lines)
Env* env = context->env(
SetReaderFactory([this, skip_header_lines, env]() {
return new TextLineReader(name(), skip_header_lines, env
}
}
};
REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
TextLineReaderOp
最后一步是添加Python包装器。你可以通过编译一个动态库来实现,或者如果你是从源代码构建TensorFlow,添加到user_ops.py
。对于后者,您将导入tensorflow.python.ops.io_ops
在tensorflow/python/user_ops/user_ops.py
添加的后裔io_ops.ReaderBase
。
from tensorflow.python.framework import ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import io_ops
class SomeReader(io_ops.ReaderBase):
def __init__(self, name=None):
rr = gen_user_ops.some_reader(name=name)
super(SomeReader, self).__init__(rr)
ops.NotDifferentiable("SomeReader")
你可以看到一些例子tensorflow/python/ops/io_ops.py
。
为记录格式编写操作
通常这是一个普通的操作,它将标量字符串记录作为输入,因此按照说明添加操作。您可以选择使用标量字符串键作为输入,并将其包含在报告格式不正确的数据的错误消息中。这样用户可以更轻松地追踪坏数据的来源。
可用于解码记录的Ops示例:
tf.parse_single_example
(和tf.parse_example
)
tf.decode_csv
tf.decode_raw
请注意,使用多个Ops来解码特定的记录格式会很有用。例如,可能必须保存为一个字符串的图像一个tf.train.Example
协议缓冲器。根据该图像的格式,你可能会采取相应的输出从tf.parse_single_example
OP和呼叫tf.image.decode_jpeg
,tf.image.decode_png
或tf.decode_raw
。采用输出tf.decode_raw
和使用tf.slice
以及tf.reshape
提取碎片是很常见的。