关于使用TensorFlow计算单热嵌入有一些堆栈溢出问题,这是公认的解决方案:
num_labels = 10
sparse_labels = tf.reshape(label_batch, [-1, 1])
derived_size = tf.shape(label_batch)[0]
indices = tf.reshape(tf.range(0, derived_size, 1), [-1, 1])
concated = tf.concat(1, [indices, sparse_labels])
outshape = tf.reshape(tf.concat(0, [derived_size, [num_labels]]), [-1])
labels = tf.sparse_to_dense(concated, outshape, 1.0, 0.0)
这几乎与官方教程中的代码相同:https : //www.tensorflow.org/versions/0.6.0/tutorials/mnist/tf/index.html
在我看来,既然tf.nn.embedding_lookup
存在,它可能会更有效率。这是使用此版本的版本,它支持任意形状的输入:
def one_hot(inputs, num_classes):
with tf.device('/cpu:0'):
table = tf.constant(np.identity(num_classes, dtype=np.float32))
embeddings = tf.nn.embedding_lookup(table, inputs)
return embeddings
您是否希望此实现更快?还有其他原因的缺陷吗?
one_hot()
您问题中的函数看起来正确。但是,我们不建议以此方式编写代码的原因是它的内存效率非常低。为了理解原因,假设您的批处理大小为32,并且具有1,000,000个类。
在本教程建议的版本中,最大张量将是的结果tf.sparse_to_dense()
,即32 x 1000000
。
在所讨论的one_hot()
函数中,最大张量将是的结果np.identity(1000000)
,即4 TB。当然,分配此张量可能不会成功。即使类的数量小得多,它仍然会浪费内存来显式存储所有这些零。TensorFlow不会自动将数据转换为稀疏表示,即使这样做可能会有利可图。
最后,我想提供一个新功能的插件,该插件最近已添加到开源存储库中,并且将在下一个版本中提供。tf.nn.sparse_softmax_cross_entropy_with_logits()
允许您指定一个整数向量作为标签,并使您不必构建密集的一键表示。相对于大量类的任何一种解决方案,效率都应该更高。
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句