Eu tenho um TF Dataset com o seguinte esquema:
tf_features = {
'searched_destination_ufi': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'booked_hotel_ufi': tf.io.FixedLenFeature([], dtype=tf.int64, default_value=0),
'user_id': tf.io.FixedLenFeature([], dtype=tf.int64, default_value=0),;
}
Eu também tenho um dict como:
candidates = {'111': [123, 444, ...], '222': [555, 888, ...]...}
Eu gostaria de realizar uma operação de mapa da seguinte maneira:
ds.map(lambda x, y: {**x, 'candidates': candidates[x['searched_destination_ufi'].numpy()]})
No entanto, sempre recebo:AttributeError: 'Tensor' object has no attribute 'numpy'
quando eu removo .numpy()
eu receboTypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
Você sugere alguma solução?
A função dataset.map
funciona no modo gráfico, onde .numpy()
não é possível chamar um tensor. Você pode tentar usar tf.py_function
para incluir os candidatos dict
em seu conjunto de dados:
import tensorflow as tf
tf_features = {
'searched_destination_ufi': ['111', '222'],
'booked_hotel_ufi': [2, 4],
'user_id': [3, 2]
}
ds = tf.data.Dataset.from_tensor_slices(tf_features)
candidates = {'111': [123, 444], '222': [555, 888]}
def py_func(x):
x = x.numpy().decode('utf-8')
return candidates[x]
ds = ds.map(lambda x: {**x, 'candidates': tf.py_function(py_func, [x['searched_destination_ufi']], [tf.int32]*2)})
for x in ds:
print(x)
{'searched_destination_ufi': <tf.Tensor: shape=(), dtype=string, numpy=b'111'>, 'booked_hotel_ufi': <tf.Tensor: shape=(), dtype=int32, numpy=2>, 'user_id': <tf.Tensor: shape=(), dtype=int32, numpy=3>, 'candidates': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([123, 444], dtype=int32)>}
{'searched_destination_ufi': <tf.Tensor: shape=(), dtype=string, numpy=b'222'>, 'booked_hotel_ufi': <tf.Tensor: shape=(), dtype=int32, numpy=4>, 'user_id': <tf.Tensor: shape=(), dtype=int32, numpy=2>, 'candidates': <tf.Tensor: shape=(2,), dtype=int32, numpy=array([555, 888], dtype=int32)>}
Observe que [tf.int32]*2
corresponde ao comprimento das listas em candidates
.
Para uma abordagem mais sofisticada, você pode usar tf.lookup.StaticHashTable
and tf.gather
, que funcionará no modo gráfico:
import tensorflow as tf
tf_features = {
'searched_destination_ufi': ['111', '222'],
'booked_hotel_ufi': [2, 4],
'user_id': [3, 2]
}
ds = tf.data.Dataset.from_tensor_slices(tf_features)
candidates = {'111': [123, 444], '222': [555, 888]}
keys = list(candidates.keys())
values = tf.constant(list(candidates.values()))
table = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(tf.constant(keys), tf.range(len(keys))),
default_value=-1)
ds = ds.map(lambda x: {**x, 'candidates': tf.gather(values, [table.lookup(x['searched_destination_ufi'])])})
for x in ds:
print(x)
{'searched_destination_ufi': <tf.Tensor: shape=(), dtype=string, numpy=b'111'>, 'booked_hotel_ufi': <tf.Tensor: shape=(), dtype=int32, numpy=2>, 'user_id': <tf.Tensor: shape=(), dtype=int32, numpy=3>, 'candidates': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[123, 444]], dtype=int32)>}
{'searched_destination_ufi': <tf.Tensor: shape=(), dtype=string, numpy=b'222'>, 'booked_hotel_ufi': <tf.Tensor: shape=(), dtype=int32, numpy=4>, 'user_id': <tf.Tensor: shape=(), dtype=int32, numpy=2>, 'candidates': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[555, 888]], dtype=int32)>}
Se o campo de candidatos for de comprimento variável, use um tensor irregular e a segunda abordagem, o restante do código permanece o mesmo:
candidates = {'111': [123, 444], '222': [555, 888, 323]}
keys = list(candidates.keys())
values = tf.ragged.constant(list(candidates.values()))
{'searched_destination_ufi': <tf.Tensor: shape=(), dtype=string, numpy=b'111'>, 'booked_hotel_ufi': <tf.Tensor: shape=(), dtype=int32, numpy=2>, 'user_id': <tf.Tensor: shape=(), dtype=int32, numpy=3>, 'candidates': <tf.RaggedTensor [[123, 444]]>}
{'searched_destination_ufi': <tf.Tensor: shape=(), dtype=string, numpy=b'222'>, 'booked_hotel_ufi': <tf.Tensor: shape=(), dtype=int32, numpy=4>, 'user_id': <tf.Tensor: shape=(), dtype=int32, numpy=2>, 'candidates': <tf.RaggedTensor [[555, 888, 323]]>}
Este artigo é coletado da Internet.
Se houver alguma infração, entre em [email protected] Delete.
deixe-me dizer algumas palavras