我正在使用TF Hub教程通过文本分类学习tensorflow 2。它使用了来自TF集线器的嵌入模块。我想知道是否可以修改模型以包含LSTM层。这是我尝试过的:
train_data, validation_data, test_data = tfds.load(
name="imdb_reviews",
split=('train[:60%]', 'train[60%:]', 'test'),
as_supervised=True)
embedding = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1"
hub_layer = hub.KerasLayer(embedding, input_shape=[],
dtype=tf.string, trainable=True)
model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Embedding(10000, 50))
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(1))
model.summary()
model.compile(optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
history = model.fit(train_data.shuffle(10000).batch(512),
epochs=10,
validation_data=validation_data.batch(512),
verbose=1)
results = model.evaluate(test_data.batch(512), verbose=2)
for name, value in zip(model.metrics_names, results):
print("%s: %.3f" % (name, value))
我不知道如何从hub_layer获取词汇量。所以我只把10000放在那里。运行它时,它将引发以下异常:
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[480,1] = -6 is not in [0, 10000)
[[node sequential/embedding/embedding_lookup (defined at .../learning/tensorflow/text_classify.py:36) ]] [Op:__inference_train_function_36284]
Errors may have originated from an input operation.
Input Source operations connected to node sequential/embedding/embedding_lookup:
sequential/embedding/embedding_lookup/34017 (defined at Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/contextlib.py:112)
Function call stack:
train_function
我困在这里。我的问题是:
我应该如何使用TF集线器中的嵌入模块来填充LSTM层?看起来嵌入查询的设置存在一些问题。
如何从中心层获取词汇量?
谢谢
最终找到了将预训练的嵌入链接到LSTM或其他层的方法。只要在这里张贴步骤,以防任何人觉得有帮助。
嵌入层必须是模型中的第一层。(hub_layer与嵌入层相同。)不是很直观的部分是,输入到集线器层的任何文本都将仅转换为形状为[embedding_dim]的一个矢量。您需要进行句子拆分和标记化,以确保对模型的任何输入都是数组形式的序列。例如,“让我们准备数据”。应该转换为[[“ let”],[“ us”],[“ prepare”],[“ the”],[“ data”]]。如果您使用批处理模式,则还需要填充序列。
此外,如果您的训练标签是字符串,则需要将目标标记转换为int。模型的输入是形状为[batch,seq_length]的字符串数组,集线器嵌入层将其转换为[batch,seq_length,embed_dim]。(如果添加LSTM或其他RNN层,则该层的输出为[batch,seq_length,rnn_units]。)输出密集层将输出文本索引,而不是实际文本。文本索引作为“ tokens.txt”存储在下载的tfhub目录中。您可以加载文件并将文本转换为相应的索引。否则,您将无法计算损失。
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句