Determine batch size during `tensorflow.keras` Custom Class `call` method

Jed

I already asked this question here, but I thought StackOverflow would have more traffic/people that might know the answer.

I'm building a custom Keras layer similar to an example found here. I want the call method inside the class to know the batch size of the inputs flowing through it, but inputs.shape shows as (None, 3) during model prediction. Here's a concrete example:

I initialize a simple data set like this:

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model

# Create fake data to use for model testing
n = 1000
np.random.seed(123)
x1 = np.random.random(n)
x2 = np.random.normal(0, 1, size=n)
x3 = np.random.lognormal(0, 1, size=n)

X = pd.DataFrame(np.concatenate([
    np.reshape(x1, (-1, 1)),
    np.reshape(x2, (-1, 1)),
    np.reshape(x3, (-1, 1)),
], axis=1))

Then I define a custom class to test/show what I'm talking about:

class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)

    def get_config(self):
        config = super(TestClass, self).get_config()
        return config

    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)

        print(inputs)
        record_count, n = inputs.shape
        print(f'inputs.shape = {inputs.shape}')

        return inputs

Then, when I create a simple model and force it to do a forward pass...

input_layer = layers.Input(shape=(3,))
test = TestClass()(input_layer)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)
model = Model(input_layer, test)
model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
model.predict(X.loc[:9, :])

... I get this output printed to the screen

model.predict(X.loc[:9, :])
Tensor("model_1/Cast:0", shape=(None, 3), dtype=float32)
inputs.shape = (None, 3)
1/1 [==============================] - 0s 28ms/step
Out[34]: 
array([[ 0.5335418 ,  0.7788839 ,  0.64132416],
       [ 0.2924202 , -0.08321562,  0.412311  ],
       [ 0.5118007 , -0.6822934 ,  1.1782378 ],
       [ 0.03780456, -0.19350041,  0.7637337 ],
       [ 0.86494124, -3.196387  ,  4.8535166 ],
       [ 0.26708454, -0.49397194,  0.91296834],
       [ 0.49734482, -1.6618049 ,  0.50054324],
       [ 0.8563762 ,  0.7956695 ,  0.29466265],
       [ 0.7682351 ,  0.86538637,  0.6633331 ],
       [ 0.85322225,  0.868021  ,  0.1776046 ]], dtype=float32)

You can see that during the model.predict call, inputs.shape prints as (None, 3), yet the call method returns an output of shape (10, 3). How can I capture the value 10 inside the call method in this example?

UPDATE 1

When I use tf.shape as suggested in the current answer, I can print the value to the screen, but I get an error when I try to capture that value in a variable.

class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)

    def get_config(self):
        config = super(TestClass, self).get_config()
        return config

    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        record_count, n = tf.shape(inputs)
        tf.print("Dynamic batch size", tf.shape(inputs)[0])
        return inputs

This code causes an error on the record_count, ... line.

Traceback (most recent call last):
  File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3378, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-22-104d812c32e6>", line 1, in <module>
    test = TestClass()(input_layer)
  File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/username/opt/miniconda3/envs/myenv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 692, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: Exception encountered when calling layer "test_class_4" (type TestClass).
in user code:
    File "<ipython-input-21-2dec1d5b9547>", line 12, in call  *
        record_count, n = tf.shape(inputs)
    OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
Call arguments received by layer "test_class_4" (type TestClass):
  • inputs=tf.Tensor(shape=(None, 3), dtype=float32)

I tried decorating the call method with @tf.function, but I get the same error.

UPDATE 2

I tried a couple of other things and found that, oddly, TensorFlow doesn't seem to like the tuple assignment. It works fine if it's coded like this instead:

class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)

    def get_config(self):
        config = super(TestClass, self).get_config()
        return config

    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        shape = tf.shape(inputs)
        record_count = shape[0]
        n = shape[1]
        tf.print("Dynamic batch size", tf.shape(inputs)[0])
        return inputs
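
The tuple assignment most likely fails because unpacking makes Python iterate over the symbolic tensor returned by tf.shape, which graph mode does not allow. Below is a minimal sketch of one alternative, assuming the rank of the input is known (2 here). tf.unstack splits the shape tensor into a Python list of scalar tensors, and that list can be unpacked normally. The function name shape_parts is made up for illustration, and a plain tf.function stands in for the layer's call method.

```python
import tensorflow as tf


# input_signature leaves the batch dimension unknown (None), as in the layer above.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def shape_parts(x):
    # tf.shape(x) is a rank-1 tensor of statically known length 2, so
    # tf.unstack can split it into a Python list of two scalar tensors.
    # Unpacking that list does not iterate over a symbolic tensor.
    record_count, n = tf.unstack(tf.shape(x))
    return record_count, n


record_count, n = shape_parts(tf.zeros([10, 3]))
print(int(record_count), int(n))
```

Indexing with shape[0] and shape[1], as in the code above, is equivalent; tf.unstack only saves a line when several dimensions are needed.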

Frightera

TL;DR: Use tf.shape(inputs)[0] to capture the dynamic batch size in the call method, or use a static batch size, which can be specified when the model is created.
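
For the static option, here is a minimal sketch, assuming the batch size really is fixed for every batch fed to the model. It passes batch_size to layers.Input so that inputs.shape is fully defined inside call. The layer name StaticShapeLayer and the list seen_static_shapes are made up for illustration.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

# Module-level list filled by plain Python code while the layer is traced.
seen_static_shapes = []


class StaticShapeLayer(layers.Layer):
    """Hypothetical layer that records the static shape it sees in call."""

    def call(self, inputs):
        # With a fixed batch_size on the Input, no dimension is None here.
        seen_static_shapes.append(tuple(inputs.shape))
        return inputs


input_layer = layers.Input(shape=(3,), batch_size=10)  # batch size fixed at 10
model = Model(input_layer, StaticShapeLayer()(input_layer))

print(seen_static_shapes[0])
```

The trade-off is that the model then expects batches of exactly that size, so this only fits cases where the batch size never varies.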


Under the hood, TensorFlow runs call (which is invoked through __call__) in graph mode, as if it were decorated with tf.function. Python's print and the static .shape attribute will not work as expected there.

With tf.function, Python code is traced and converted to native TensorFlow operations. This produces a static graph, which is an instance of tf.Graph. The operations are then executed in that graph.

Python's print function runs only during the first step (tracing), so it is not the correct way to print things in graph mode (code decorated with tf.function).
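
A minimal sketch of that difference, using a standalone tf.function rather than a Keras layer. The names forward and trace_log are made up for illustration. The Python side effect runs once, while the function is traced, whereas tf.print is part of the graph and runs on every call.

```python
import tensorflow as tf

trace_log = []  # filled by plain Python code, i.e. only while tracing


# A single input signature means the function is traced only once.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def forward(x):
    # Python side effect: executed during tracing, not on later calls.
    trace_log.append(x.shape.as_list())
    # Graph operation: executed every time the function is called.
    tf.print("Dynamic batch size", tf.shape(x)[0])
    return x


forward(tf.zeros([10, 3]))
forward(tf.zeros([4, 3]))

# Two calls, but only one traced entry, and its batch dimension is unknown.
print(trace_log)
```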

Tensor shapes can be dynamic at run time, so use tf.shape(inputs)[0], which gives the batch size of the current batch.
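
The value returned by tf.shape(inputs)[0] is itself a tensor, so it can be passed to other TensorFlow operations but not treated as a Python int during tracing. Here is a minimal sketch, assuming you want to build something whose size depends on the batch. The function append_row_index is made up for illustration and uses a plain tf.function in place of a layer's call method.

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def append_row_index(x):
    # Scalar int32 tensor; its value is known only when the graph runs.
    batch_size = tf.shape(x)[0]
    # Use it like any other tensor: build one index value per row.
    row_index = tf.cast(tf.range(batch_size), x.dtype)
    # Append the index as a fourth column.
    return tf.concat([x, tf.reshape(row_index, [-1, 1])], axis=1)


out = append_row_index(tf.zeros([10, 3]))
out_shape = tuple(out.shape)
print(out_shape)
```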

If you really want to see that 10 in call:

class TestClass(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestClass, self).__init__(**kwargs)

    def get_config(self):
        config = super(TestClass, self).get_config()
        return config

    def call(self, inputs: tf.Tensor):
        if inputs.dtype.base_dtype != self._compute_dtype_object.base_dtype:
            inputs = tf.cast(inputs, dtype=self._compute_dtype_object)
        tf.print("Dynamic batch size", tf.shape(inputs)[0])
        return inputs

Running:

input_layer = layers.Input(shape=(3,))
test = TestClass()(input_layer)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)
model = Model(input_layer, test)
model.compile(loss='mse', optimizer=optimizer, metrics=['mae', 'mse'])
model.predict(X.loc[:9, :])

Will return:

Dynamic batch size 10
1/1 [==============================] - 0s 65ms/step
array([[ 6.9646919e-01, -1.0032653e-02,  3.7556963e+00],
       [ 2.8613934e-01, -8.4564441e-01,  9.9685013e-01],
       [ 2.2685145e-01,  9.1146064e-01,  6.5008003e-01],
       [ 5.5131477e-01, -1.3744969e+00,  8.6379850e-01],
       [ 7.1946895e-01, -5.4706562e-01,  3.1904945e+00],
       [ 4.2310646e-01, -7.5526608e-05,  5.2649558e-01],
       [ 9.8076421e-01, -1.2116680e-01,  7.4064606e-01],
       [ 6.8482971e-01, -2.0085855e+00,  5.3138912e-01],
       [ 4.8093191e-01, -9.2064655e-01,  8.1520426e-01],
       [ 3.9211753e-01,  1.6823435e-01,  1.2382457e+00]], dtype=float32)
