Importing Data(数据导入)
导入数据
该Dataset
API使您能够构建从简单的,可重用代码的复杂的输入管道。例如,图像模型的管道可能会聚合来自分布式文件系统中的文件的数据,将随机扰动应用于每个图像,并将随机选择的图像合并为一批以进行操练。文本模型的流水线可能涉及从原始文本数据中提取符号,将它们转换为使用查找表嵌入标识符,以及将不同长度的序列进行批处理。该Dataset
API使处理大量数据,不同数据格式和复杂转换变得容易。
该Dataset
API为TensorFlow引入了两个新的概念:
tf.data.Dataset
表示一系列元素,其中每个元素包含一个或多个Tensor
对象。例如,在图像流水线中,元素可能是单个操练样例,其中一对张量表示图像数据和标签。有两种不同的方法来创建数据集:
- 创建
源
(例如Dataset.from_tensor_slices()
)从一个或多个tf.Tensor
对象构建数据集。
- 应用
转换
(例如Dataset.batch()
)从一个或多个tf.data.Dataset
对象构建数据集。
tf.data.Iterator
提供了从数据集中提取元素的主要方法。Iterator.get_next()
返回的操作产生Dataset
执行时的下一个元素,并且通常充当输入管道代码和模型之间的接口。最简单的迭代器是一个“一次迭代器”,它与特定的Dataset
迭代器相关联并迭代一次。对于更复杂的用途,该Iterator.initializer
操作使您可以使用不同的数据集重新初始化和参数化迭代器,以便您可以在同一个程序中多次迭代操练和验证数据。
基本原理
本指南的这一部分介绍了创建不同种类Dataset
和Iterator
对象的基础知识,以及如何从中提取数据。
要启动输入管道,您必须定义一个源
。例如,Dataset
要从内存中构造一些张量,可以使用tf.data.Dataset.from_tensors()
或tf.data.Dataset.from_tensor_slices()
。或者,如果您的输入数据以建议的TFRecord格式存储在磁盘上,则可以构建一个tf.data.TFRecordDataset
。
一旦你有一个Dataset
对象,你可以通过对对象的链式方法调用将它转换
成新Dataset
的tf.data.Dataset
对象。例如,您可以应用每元素转换
,例如Dataset.map()
(为每个元素应用一个函数)以及多元转换
,如Dataset.batch()
。请参阅文档以tf.data.Dataset
获取转换
的完整列表。
从一个Dataset
消费值的最常用方法是创建一个迭代器
对象,以便一次提供对数据集的一个元素的访问权限(例如,通过调用Dataset.make_one_shot_iterator()
)。一个tf.data.Iterator
提供了两个操作:Iterator.initializer
,它使您能够(重新)初始化迭代器
的状态; 并Iterator.get_next()
返回tf.Tensor
对应于符号下一个元素的对象。根据您的使用情况,您可能会选择不同类型的迭代器
,并在下面列出选项。
数据集结构
数据集包含每个都具有相同结构的元素。元素包含一个或多个tf.Tensor
称为组件的
对象。每个组件都有一个tf.DType
代表张量中元素的类型,并且tf.TensorShape
代表每个元素的(可能部分指定的)静态形状。Dataset.output_types
和Dataset.output_shapes
特性允许用户检查推断类型和数据集元素的每个部件的形状。这些属性的嵌套结构
映射到元素的结构,该元素可以是单张量,张量元组或张量的嵌套元组。例如:
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print(dataset1.output_types) # ==> "tf.float32"
print(dataset1.output_shapes) # ==> "(10,)"
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random_uniform([4]),
tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)))
print(dataset2.output_types) # ==> "(tf.float32, tf.int32)"
print(dataset2.output_shapes) # ==> "((), (100,))"
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print(dataset3.output_types) # ==> (tf.float32, (tf.float32, tf.int32))
print(dataset3.output_shapes) # ==> "(10, ((), (100,)))"
给元素的每个元素命名通常比较方便,例如,如果它们表示操练示例的不同特征。除了元组之外,还可以使用collections.namedtuple
字典或字典将字符串映射到张量以表示一个Dataset
的单个元素。
dataset = tf.data.Dataset.from_tensor_slices(
{"a": tf.random_uniform([4]),
"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes) # ==> "{'a': (), 'b': (100,)}"
Dataset
转换支持任何结构的数据集。当使用Dataset.map()
,Dataset.flat_map()
和Dataset.filter()
转换,其中应用一个函数到每个元件中,元件结构决定了函数的自变量:
dataset1 = dataset1.map(lambda x: ...)
dataset2 = dataset2.flat_map(lambda x, y: ...)
# Note: Argument destructuring is not available in Python 3.
dataset3 = dataset3.filter(lambda x, (y, z): ...)
创建迭代器
一旦你建立了一个Dataset
表示你的输入数据,下一步就是创建一个Iterator
访问该数据集中的元素。该Dataset
API目前支持以下迭代器,其级别越来越高:
一次性
可初始化
,
可重新初始化,
和
可馈入
一次性
的迭代器是迭代器的最简单的形式,只支持一次迭代通过数据集,而无需显式初始化。一次迭代器处理几乎所有现有的基于队列的输入流水线支持的情况,但它们不支持参数化。使用以下示例Dataset.range()
:
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):
value = sess.run(next_element)
assert i == value
注意:
目前,一次迭代器是唯一可以轻松使用的Estimator
类型。
可初始化的
迭代器需要你在使用iterator.initializer
之前运行一个明确的操作。为了交换这种不便,您可以使用一个或多个tf.placeholder()
张量进行参数化
数据集的定义,使得初始化迭代器时,张量可以输入。继续Dataset.range()
举例说明:
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Initialize an iterator over a dataset with 10 elements.
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):
value = sess.run(next_element)
assert i == value
# Initialize the same iterator over a dataset with 100 elements.
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):
value = sess.run(next_element)
assert i == value
可重新初始化的
迭代器可以从多个不同的Dataset
对象初始化。例如,您可能有一个操练输入流水线,它使用输入图像的随机扰动来改善泛化,以及一个验证输入流水线,用于评估对未修改数据的预测。这些管道通常会使用Dataset
具有相同结构的不同对象(即每个组件具有相同类型和兼容形状)。
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)
# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = Iterator.from_structure(training_dataset.output_types,
training_dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)
# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
# Initialize an iterator over the training dataset.
sess.run(training_init_op)
for _ in range(100):
sess.run(next_element)
# Initialize an iterator over the validation dataset.
sess.run(validation_init_op)
for _ in range(50):
sess.run(next_element)
可馈入的
迭代器可以与tf.placeholder
一起通过熟悉的feed_dict
机制,在每次tf.Session.run
调用中选择Iterator
使用的内容。它提供了与可重新初始化的迭代器相同的功能,但不需要在迭代器之间切换时,从数据集的起始处初始化迭代器。例如,使用上述相同的操练和验证示例,您可以使用tf.data.Iterator.from_string_handle
定义可供给的迭代器,以便您在两个数据集之间切换:
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)
# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()
# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()
# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())
# Loop forever, alternating between training and validation.
while True:
# Run 200 steps using the training dataset. Note that the training dataset is
# infinite, and we resume from where we left off in the previous `while` loop
# iteration.
for _ in range(200):
sess.run(next_element, feed_dict={handle: training_handle})
# Run one pass over the validation dataset.
sess.run(validation_iterator.initializer)
for _ in range(50):
sess.run(next_element, feed_dict={handle: validation_handle})
从迭代器中消费值
该Iterator.get_next()
方法返回一个或多个tf.Tensor
与迭代器的符号下一个元素相对应的对象。每次评估这些张量时,它们都会获取基础数据集中下一个元素的值。(请注意,与TensorFlow中的其他有状态对象一样,调用Iterator.get_next()
并不会立即推进迭代器,而必须在TensorFlow表达式中使用返回的tf.Tensor
对象,并将该表达式的结果传递tf.Session.run()
给下一个元素并推进迭代器。)
如果迭代器到达数据集的末尾,执行该Iterator.get_next()
操作将引发一次tf.errors.OutOfRangeError
。在此之后,迭代器将处于不可用状态,如果要进一步使用它,则必须重新初始化它。
dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Typically `result` will be the output of a model, or an optimizer's
# training operation.
result = tf.add(next_element, next_element)
sess.run(iterator.initializer)
print(sess.run(result)) # ==> "0"
print(sess.run(result)) # ==> "2"
print(sess.run(result)) # ==> "4"
print(sess.run(result)) # ==> "6"
print(sess.run(result)) # ==> "8"
try:
sess.run(result)
except tf.errors.OutOfRangeError:
print("End of dataset") # ==> "End of dataset"
一个常见的模式是包装在一个“操练循环” try
- except
块:
sess.run(iterator.initializer)
while True:
try:
sess.run(result)
except tf.errors.OutOfRangeError:
break
如果数据集的每个元素都有嵌套结构,则返回值Iterator.get_next()
将是同一个嵌套结构中的一个或多个tf.Tensor
对象:
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
iterator = dataset3.make_initializable_iterator()
sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()
需要注意的是评估任何
的next1
,next2
或next3
将推动迭代器的所有组件。迭代器的典型使用者将在一个表达式中包含所有组件。
读取输入数据
消费NumPy数组
如果所有的输入数据都适应内存,那么Dataset
从它们创建一个最简单的方法就是将它们转换为tf.Tensor
对象并使用Dataset.from_tensor_slices()
。
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
请注意,上面的代码片段会将TensorFlow图形features
和labels
阵列作为tf.constant()
操作嵌入。这适用于小数据集,但会浪费内存---因为数组的内容将被复制多次---并且可以运行到tf.GraphDef
协议缓冲区的2GB限制。
作为替代方案,可以定义Dataset
在tf.placeholder()
张量,和嵌入
NumPy阵列,当你在数据集初始化Iterator
。
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
features = data["features"]
labels = data["labels"]
# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
labels_placeholder: labels})
消费TFRecord数据
该Dataset
API支持多种文件格式,因此您可以处理不适合内存的大型数据集。例如,TFRecord文件格式是一种简单的面向记录的二进制格式,许多TensorFlow应用程序用于操练数据。tf.data.TFRecordDataset
类可以一个或多个TFRecord文件的内容流过作为输入管道的一部分。
# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
初始值设定项的filenames
参数TFRecordDataset
可以是字符串,字符串列表或tf.Tensor
字符串。因此,如果您有两组文件用于操练和验证目的,您可以使用tf.placeholder(tf.string)
来表示文件名,并使用适当的文件名初始化迭代器:
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...) # Parse the record into tensors.
dataset = dataset.repeat() # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
# You can feed the initializer with the appropriate filenames for the current
# phase of execution, e.g. training vs. validation.
# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
消费文本数据
许多数据集都是作为一个或多个文本文件分发的。在tf.data.TextLineDataset
提供了一种简单的方法来提取一个或多个文本文件行。给定一个或多个文件名,TextLineDataset
会为这些文件的每行生成一个字符串值元素。像TFRecordDataset
,TextLineDataset
接受filenames
为a tf.Tensor
,所以你可以通过传递一个tf.placeholder(tf.string)
来化参数给它。
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filenames)
默认情况下,每个文件的每一
行都会TextLineDataset
产生一个
文件,这可能不是预期的,例如文件以标题行开头或包含注释。这些行可以使用Dataset.skip()
和Dataset.filter()
转换来删除。要将这些转换分别应用于每个文件,我们使用Dataset.flat_map()
为每个文件创建一个
嵌套Dataset
。
filenames = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
# Use `Dataset.flat_map()` to transform each file as a separate nested dataset,
# and then concatenate their contents sequentially into a single "flat" dataset.
# * Skip the first line (header row).
# * Filter out lines beginning with "#" (comments).
dataset = dataset.flat_map(
lambda filename: (
tf.data.TextLineDataset(filename)
.skip(1)
.filter(lambda line: tf.not_equal(tf.substr(line, 0, 1), "#"))))
有关使用数据集解析CSV文件的完整示例,请参阅imports85.py
回归示例。
使用预处理数据 Dataset.map()
该Dataset.map(f)
转换通过将给定函数f
应用于输入数据集的每个元素来产生新的数据集。它是基于map()
函数(其被共同地应用于在功能编程语言列表(和其它结构),该函数f
取tf.Tensor
表示对象输入中的单个元素,并返回tf.Tensor
将在新数据集中表示单个元素的对象。其实现使用标准的TensorFlow操作将一个元素转换为另一个元素。
本节涵盖如何使用的常见示例Dataset.map()
。
解析tf.Example协议缓冲区消息
许多输入管道tf.train.Example
从TFRecord格式文件(例如,使用tf.python_io.TFRecordWriter
)中提取协议缓冲区消息。每条tf.train.Example
记录都包含一个或多个“特征”,输入管道通常会将这些特征转换为张量。
# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):
features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int32, default_value=0)}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["image"], parsed_features["label"]
# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)
解码图像数据并调整其大小
当在真实世界的图像数据上训练神经网络时,经常需要将不同尺寸的图像转换为通用尺寸,以便它们可以批量化为固定尺寸。
# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_image(image_string)
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])
# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
应用任意的Python逻辑 tf.py_func()
出于性能原因,我们鼓励您尽可能使用TensorFlow操作预处理数据。但是,解析输入数据时调用外部Python库有时很有用。为此,请tf.py_func()
在Dataset.map()
转换中调用该操作。
import cv2
# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
image_decoded = cv2.imread(image_string, cv2.IMREAD_GRAYSCALE)
return image_decoded, label
# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
image_decoded.set_shape([None, None, None])
image_resized = tf.image.resize_images(image_decoded, [28, 28])
return image_resized, label
filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
lambda filename, label: tuple(tf.py_func(
_read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)
批处理数据集元素
简单批处理
最简单的批处理形式n
将数据集的连续元素堆叠到单个元素中。这种Dataset.batch()
转换完全是这样做的,与tf.stack()
操作符一样,约束条件应用于元素的每个元素:即对于每个元素i
,所有元素都必须具有完全相同形状的张量。
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)
iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element)) # ==> ([0, 1, 2, 3], [ 0, -1, -2, -3])
print(sess.run(next_element)) # ==> ([4, 5, 6, 7], [-4, -5, -6, -7])
print(sess.run(next_element)) # ==> ([8, 9, 10, 11], [-8, -9, -10, -11])
填充张量
上述配方适用于所有尺寸相同的张量。然而,许多模型(例如序列模型)与可能具有不同大小的输入数据(例如,不同长度的序列)一起工作。为了处理这种情况,该Dataset.padded_batch()
转换使您能够通过指定可能填充的一个或多个尺寸来批量处理不同形状的张量。
dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
print(sess.run(next_element)) # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element)) # ==> [[4, 4, 4, 4, 0, 0, 0],
# [5, 5, 5, 5, 5, 0, 0],
# [6, 6, 6, 6, 6, 6, 0],
# [7, 7, 7, 7, 7, 7, 7]]
该Dataset.padded_batch()
转换允许您为每个组件的每个维设置不同的填充,并且它可以是可变长度(None
在上面的示例中表示)或恒定长度。也可以重写填充值,该值默认为0。
操练工作流程
多时期处理
该Dataset
API提供了两种主要方式来处理相同数据的多个时期。
在多个时期迭代数据集的最简单方法是使用Dataset.repeat()
变换。例如,要创建一个重复10个时期输入的数据集:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.repeat(10)
dataset = dataset.batch(32)
应用Dataset.repeat()
不带参数的转换将无限期地重复输入。这个Dataset.repeat()
转换连接了它的论点,而没有表明一个时期的结束和下一个时期的开始。
如果你想在每个时期结束时收到一个信号,你可以编写一个训练循环来捕捉tf.errors.OutOfRangeError
数据集的末尾。此时,您可能会收集一些统计数据(例如验证错误)。
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Compute for 100 epochs.
for _ in range(100):
sess.run(iterator.initializer)
while True:
try:
sess.run(next_element)
except tf.errors.OutOfRangeError:
break
# [Perform end-of-epoch calculations here.]
随机清洗输入数据
The Dataset.shuffle()
transformation randomly shuffles the input dataset using a similar algorithm to tf.RandomShuffleQueue
: it maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat()
高级API使用
该tf.train.MonitoredTrainingSession
API简化了在分布式环境中运行TensorFlow的许多方面。MonitoredTrainingSession
使用该tf.errors.OutOfRangeError
信号表示操练已完成,因此要将其与Dataset
API 一起使用,建议使用Dataset.make_one_shot_iterator()
。例如:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)
training_op = tf.train.AdagradOptimizer(...).minimize(loss)
with tf.train.MonitoredTrainingSession(...) as sess:
while not sess.should_stop():
sess.run(training_op)
要使用Dataset
中的input_fn
某个tf.estimator.Estimator
,我们也推荐使用Dataset.make_one_shot_iterator()
。例如:
def dataset_input_fn():
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
# Use `tf.parse_single_example()` to extract data from a `tf.Example`
# protocol buffer, and perform any additional per-record preprocessing.
def parser(record):
keys_to_features = {
"image_data": tf.FixedLenFeature((), tf.string, default_value=""),
"date_time": tf.FixedLenFeature((), tf.int64, default_value=""),
"label": tf.FixedLenFeature((), tf.int64,
default_value=tf.zeros([], dtype=tf.int64)),
}
parsed = tf.parse_single_example(record, keys_to_features)
# Perform additional preprocessing on the parsed data.
image = tf.decode_jpeg(parsed["image_data"])
image = tf.reshape(image, [299, 299, 1])
label = tf.cast(parsed["label"], tf.int32)
return {"image_data": image, "date_time": parsed["date_time"]}, label
# Use `Dataset.map()` to build a pair of a feature dictionary and a label
# tensor for each example.
dataset = dataset.map(parser)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
# `features` is a dictionary in which each value is a batch of values for
# that feature; `labels` is a batch of labels.
features, labels = iterator.get_next()
return features, labels