Keras generator keeps shuffling though it is asked not to

Sergey Petrov

I use a Keras data generator whose shuffle argument defaults to False:

class data_generator(keras.utils.Sequence):
    def __init__(self, frames, labels, batch_size, data_dir, shuffle=False):
        'Initialization'
        self.batch_size = batch_size
        self.labels = labels
        self.frames = frames
        self.data_dir = data_dir
        self.shuffle = shuffle
        self.size = len(self.frames)
        self.on_epoch_end()

  ...

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.frames))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

   ...

And this is how I create instances for training and validation:

train_generator = data_generator(x_train[:num_train_examples], y_train[:num_train_examples], batch_size, data_dir)
val_generator = data_generator(x_train[num_train_examples:], y_train[num_train_examples:], batch_size, data_dir)

And then train the model:

model.fit_generator(train_generator,
                    validation_data=val_generator,
                    callbacks=[history],
                    epochs=num_epochs)

But the generator keeps producing random indexes:

starting training
Epoch 1/1

batch start: 0, batch end: 2

batch start: 24, batch end: 26

batch start: 2, batch end: 4

batch start: 114, batch end: 116

batch start: 4, batch end: 6

batch start: 60, batch end: 62

batch start: 6, batch end: 8

batch start: 68, batch end: 70

batch start: 8, batch end: 10

batch start: 94, batch end: 96

How can I stop it from shuffling?

Here is the __getitem__ method from the generator class:

    def __getitem__(self, index):
        'Generate one batch of data'
        x_batch, y_batch = self.__data_generation(index)

        return x_batch, y_batch

    def __data_generation(self, index):
        'Generates data containing batch_size samples'
        limit = min(self.size, (index + 1)*self.batch_size)
        x_batch = []
        print('\nbatch start: ' + str(index*self.batch_size) + ', batch end: ' + str(limit))
        for frame in self.frames[index*self.batch_size:limit]:
            video_array = np.load(self.data_dir + '/' + frame + '.npy')
            x_batch.append(np.array(video_array))

        return np.array(x_batch), self.labels[index*self.batch_size:limit]

EDIT: I can now see a pattern: sequential batches alternate with random ones.

Bashir Kazimi

I suspect the problem is in your __len__(self) method (if you have defined one). I added a __len__(self) method to your code and tested it, and it no longer shuffles. Here is the code:

import numpy as np
import keras

class data_generator(keras.utils.Sequence):
    def __init__(self, frames, labels, batch_size, data_dir, shuffle=False):
        'Initialization'
        self.batch_size = batch_size
        self.labels = labels
        self.frames = frames
        self.data_dir = data_dir
        self.shuffle = shuffle
        self.size = len(self.frames)
        self.on_epoch_end()

    def __len__(self):
        'Number of batches per epoch (the last batch may be smaller)'
        return int(np.ceil(self.size/self.batch_size))

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.frames))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        'Generate one batch of data'
        x_batch, y_batch = self.__data_generation(index)
        return x_batch, y_batch

    # Alternative that reads self.indexes, so it honors the shuffle flag:
    # def __data_generation(self, index):
    #     'Generates data containing batch_size samples'
    #     current_indices = self.indexes[index*self.batch_size:(index + 1)*self.batch_size]
    #     x_batch = []
    #     y_batch = []
    #     for idx in current_indices:
    #         video_array = np.load(self.data_dir + '/' + self.frames[idx] + '.npy')
    #         x_batch.append(np.array(video_array))
    #         y_batch.append(self.labels[idx])

    #     return np.array(x_batch), y_batch

    # Original version: slices self.frames directly, so the order is always fixed.
    def __data_generation(self, index):
        'Generates data containing batch_size samples'
        limit = min(self.size, (index + 1)*self.batch_size)
        x_batch = []
        print('\nbatch start: ' + str(index*self.batch_size) + ', batch end: ' + str(limit))
        for frame in self.frames[index*self.batch_size:limit]:
            video_array = np.load(self.data_dir + '/' + frame + '.npy')
            x_batch.append(np.array(video_array))
        return np.array(x_batch), self.labels[index*self.batch_size:limit]

The code above works as you expected: it does not shuffle. However, your __data_generation function slices self.frames directly and never reads self.indexes, so it cannot shuffle even when shuffle is True. I therefore wrote my own __data_generation function, which you can see commented out. With it, the generator shuffles when shuffle is True and keeps the original order when shuffle is False. Hope it helps.
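To make the index-based idea concrete without needing Keras or the .npy files, here is a minimal sketch of the same batching logic in plain Python. The class name IndexedBatcher and the seed parameter are hypothetical, introduced only for illustration; the class does not inherit from keras.utils.Sequence and returns the frame names themselves instead of loading arrays.

```python
import math
import random


class IndexedBatcher:
    """Hypothetical stand-in for a keras.utils.Sequence subclass (no Keras needed)."""

    def __init__(self, frames, labels, batch_size, shuffle=False, seed=None):
        self.frames = frames
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)  # seed is an illustrative addition
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch; the last one may be shorter.
        return math.ceil(len(self.frames) / self.batch_size)

    def on_epoch_end(self):
        # Rebuild the sample order; reshuffle only when asked to.
        self.indexes = list(range(len(self.frames)))
        if self.shuffle:
            self.rng.shuffle(self.indexes)

    def __getitem__(self, index):
        # Select samples through self.indexes so the shuffle flag takes effect.
        start = index * self.batch_size
        batch_ids = self.indexes[start:start + self.batch_size]
        x_batch = [self.frames[i] for i in batch_ids]
        y_batch = [self.labels[i] for i in batch_ids]
        return x_batch, y_batch


frames = ["f0", "f1", "f2", "f3", "f4"]
labels = [0, 1, 2, 3, 4]

# shuffle=False: batches come out in the original order.
ordered = IndexedBatcher(frames, labels, batch_size=2)
ordered_batches = [ordered[i] for i in range(len(ordered))]
print(ordered_batches)

# shuffle=True: same samples, permuted order, frames and labels still paired.
shuffled = IndexedBatcher(frames, labels, batch_size=2, shuffle=True, seed=0)
shuffled_x = [x for i in range(len(shuffled)) for x in shuffled[i][0]]
print(shuffled_x)
```

Note that this only controls the order of samples inside the generator. fit_generator also takes its own shuffle argument, which defaults to True and shuffles the order in which batches are requested from a Sequence, so it may be worth passing shuffle=False there as well if a strictly sequential order is needed.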
