tf.GradientTape giving None gradient while writing custom training loop

Al Shahreyaj

I'm trying to write a custom training loop. Here is some sample code showing what I'm trying to do. I have two trainable parameters, and one parameter is used to update the other. See the code below:

import tensorflow as tf

x1 = tf.Variable(1.0, dtype=float)
x2 = tf.Variable(1.0, dtype=float)

with tf.GradientTape() as tape:
    n = x2 + 4
    x1.assign(n)  # update x1 from x2
    x = x1 + 1
    y = x**2

val = tape.gradient(y, [x1, x2])
for v in val:
    print(v)

and the output is

tf.Tensor(12.0, shape=(), dtype=float32)
None

It seems like GradientTape is not watching the second parameter (x2). Both parameters are of type tf.Variable, so GradientTape should watch both of them. I also tried, which is also not working. Am I missing something?


Check the docs regarding a gradient of None. To get the gradients for x1, you have to track x with

import tensorflow as tf

x1 = tf.Variable(1.0, dtype=float)
x2 = tf.Variable(1.0, dtype=float)

with tf.GradientTape() as tape:
    n = x2 + 4
    x1.assign(n)  # the assignment itself is not recorded by the tape
    x = x1 + 1
    y = x**2

dv0, dv1 = tape.gradient(y, [x1, x2])

However, as far as the tape is concerned, the output y is not connected to x2 at all: x1.assign(n) is not tracked, so there is no recorded path from y back to x2, and that is why the gradient for x2 is None. This is consistent with the docs:

State stops gradients. When you read from a stateful object, the tape can only observe the current state, not the history that lead to it.

A tf.Tensor is immutable. You can't change a tensor once it's created. It has a value, but no state. All the operations discussed so far are also stateless: the output of a tf.matmul only depends on its inputs.

A tf.Variable has internal state—its value. When you use the variable, the state is read. It's normal to calculate a gradient with respect to a variable, but the variable's state blocks gradient calculations from going farther back
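The quoted behaviour can be reproduced in isolation. The following is a minimal sketch modelled on the kind of example the docs use; the variable names source and state are made up for illustration and are not part of the question's code.

```python
import tensorflow as tf

source = tf.Variable(3.0)
state = tf.Variable(0.0)

with tf.GradientTape() as tape:
    # The write into `state` is a state change, not a differentiable op,
    # so the tape keeps no link from `state` back to `source`.
    state.assign_add(source)
    # Reading `state` here gives its current value (3.0); recording
    # effectively starts from this read.
    y = state ** 2

grad_state, grad_source = tape.gradient(y, [state, source])
print(grad_state)   # gradient w.r.t. state: 2 * 3.0 = 6.0
print(grad_source)  # None, the history behind `state` is not visible
```

The gradient with respect to the variable that was read is available, but nothing flows further back to the variable whose value was written into it.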

If, for example, you do something like this:

import tensorflow as tf

x1 = tf.Variable(1.0, dtype=float)
x2 = tf.Variable(1.0, dtype=float)

with tf.GradientTape() as tape:
    n = x2 + 4
    x1 = n  # rebinds the name x1 to the tensor n; no variable is updated
    x = x1 + 1
    y = x**2

dv0, dv1 = tape.gradient(y, [x1, x2])

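It should work: y now depends on x2 through ordinary tensor operations, so both gradients are defined (12.0 each for these values). Keep in mind that x1 = n rebinds the Python name x1 to a tensor, so the original tf.Variable is no longer updated.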
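If x1 still has to hold the new value as a tf.Variable (for example because a later training step reads it), one possible arrangement is to use the tensor n inside the tape and call assign only after the gradient has been computed. This is a sketch under that assumption, not the only way to structure the loop:

```python
import tensorflow as tf

x1 = tf.Variable(1.0)
x2 = tf.Variable(1.0)

with tf.GradientTape() as tape:
    n = x2 + 4   # differentiable function of x2
    x = n + 1    # use the tensor directly instead of reading x1
    y = x ** 2

# y = (x2 + 5)^2, so dy/dx2 = 2 * (x2 + 5) = 12.0 at x2 = 1
grad_x2 = tape.gradient(y, x2)
print(grad_x2)

# Store the new value in the variable once the gradient has been taken.
x1.assign(n)
print(x1.numpy())  # 5.0
```

Here the differentiable path runs through n, while the variable x1 only serves as storage for the result.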

