Comment décorer une fonction qui prend un tf.variable comme paramètre avec tf.function et surtout en utilisant la signature d'entrée

Boubaker

J'ai un problème où je dois modifier une variable dans une fonction Tensorflow. Ensuite, je dois convertir cette fonction en un graphique tensorflow. Le problème est que la taille de la variable n'est pas fixe. Exemple : il peut s'agir d'un ténosr de forme (3,) ou (2,). C'est pourquoi la fonction prend cette variable en paramètre, pour pouvoir la modifier et la retourner.

Voici un exemple de classe qui contient un appel de fonction , cette fonction prend deux arguments (x,v). x est un Tf.tenseur et v est un tf.Variable. v est affecté à la multiplication de x*v.

import tensorflow as tf

class MyModule(tf.Module):
  def __init__(self):
    pass

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32), tf.TensorSpec(shape=[None], dtype=tf.int32)])
  def __call__(self, x, v):
    v.assign(x*v, read_value=False)
    return v

tf.config.run_functions_eagerly(False)
x = tf.constant([10,10])
v = tf.Variable(2*tf.ones_like(x), trainable=False)

module = MyModule()
module(x, v)

Cela fonctionne comme prévu en mode impatient, mais en mode graphique, j'obtiens l'erreur suivante : AttributeError : l'objet 'Tensor' n'a pas d'attribut 'assign'

Je sais que c'est à cause de la signature de tf.Variable. Ma question est comment puis-je spécifier la signature de tf.Variable étant donné que l'actuel produit une erreur?

Effrayer

En fait, il existe une opération qui peut réaliser ce que vous voulez, mais elle n'est pas répertoriée dans l'API publique. Attention, ce n'est peut-être pas la meilleure pratique.

Vous avez besoin resource_variable_opsde ce que vous pouvez trouver sous tensorflow.python.ops.

import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

class MyModule(tf.Module):
  def __init__(self):
    pass

  @tf.function(input_signature=[
                                tf.TensorSpec(shape=[None], dtype=tf.int32), 
                                resource_variable_ops.VariableSpec(shape=[None], dtype=tf.int32)
                                ])
  def __call__(self, x, v):
    v.assign(x*v, read_value=False)
    return v

x = tf.constant([10,10])
v = tf.Variable(2*tf.ones_like(x), trainable=False)

module = MyModule()
module(x, v)

Cet article est collecté sur Internet, veuillez indiquer la source lors de la réimpression.

En cas d'infraction, veuillez [email protected] Supprimer.

modifier le
0

laisse moi dire quelques mots

0commentaires
connexionAprès avoir participé à la revue

Articles connexes

TOP liste

  1. 1

    comment afficher un bouton au-dessus d'un autre élément ?

  2. 2

    impossible d'obtenir l'image d'arrière-plan en plein écran dans reactjs

  3. 3

    Je continue à obtenir l'objet 'WSGIRequest' n'a pas d'attribut 'Get' sur django

  4. 4

    comment supprimer "compte de connexion google" à des fins de développement - actions sur google

  5. 5

    Conversion double en BigDecimal en Java

  6. 6

    Impossible d'accéder à la vue personnalisée pendant le test de l'interface utilisateur dans XCode

  7. 7

    Algorithme: diviser de manière optimale une chaîne en 3 sous-chaînes

  8. 8

    Passer la taille d'un tableau 2D à une fonction ?

  9. 9

    Comment obtenir l'intégration contextuelle d'une phrase dans une phrase à l'aide de BERT ?

  10. 10

    Comment changer le navigateur par défaut en Microsoft Edge pour Jupyter Notebook sous Windows 10 ?

  11. 11

    CSS: before ne fonctionne pas sur certains éléments,: after fonctionne très bien

  12. 12

    Comment créer un bot à compte à rebours dans Discord en utilisant Python

  13. 13

    Comment ajouter une entrée à une table de base de données pour une combinaison de deux tables

  14. 14

    Exporter la table de l'arborescence vers CSV avec mise en forme

  15. 15

    Comment activer le message Pylint "too-many-locals" dans VS Code?

  16. 16

    Créer un système Buzzer à l'aide de python

  17. 17

    Spring @RequestParam DateTime format comme ISO 8601 Date Heure facultative

  18. 18

    Empêcher l'allocation de mémoire dans la génération de combinaison récursive

  19. 19

    Déplacement des moindres carrés d'ajustement pour les déplacements de points ayant des problèmes

  20. 20

    Comment choisir le nombre de fragments et de répliques Elasticsearch

  21. 21

    Microsoft.WebApplication.targets

chaudétiquette

Archive