How to sample batch from a specific class?

Richard_wth

I'd like to train a classifier on the ImageNet dataset (1000 classes, each with around 1300 images). For some reason, I need each batch to contain 64 images from a single specific class (provided as an int or a placeholder). How can I do this efficiently with the latest TensorFlow?
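For concreteness, here is a framework-free sketch of the requirement itself, assuming the labels of all examples are known up front (e.g. from a hypothetical index file). The helper names `build_class_index` and `sample_class_batch` are hypothetical. The idea is to group example indices by label once, so that drawing a batch from one class costs O(batch size) instead of a scan over the whole dataset.

```python
import random
from collections import defaultdict


def build_class_index(labels):
    """Map each class label to the list of example indices carrying it."""
    index = defaultdict(list)
    for i, y in enumerate(labels):
        index[y].append(i)
    return dict(index)


def sample_class_batch(index, cls, batch_size=64, rng=random):
    """Draw `batch_size` example indices, all from class `cls`.

    Samples without replacement within one batch, so the class must have
    at least `batch_size` examples (ImageNet classes have ~1300).
    """
    return rng.sample(index[cls], batch_size)


# Toy labels: 10 classes with 100 examples each (stand-in for 1000 x ~1300).
labels = [i % 10 for i in range(1000)]
index = build_class_index(labels)
batch = sample_class_batch(index, cls=2, batch_size=64, rng=random.Random(0))
print(len(batch), all(labels[i] == 2 for i in batch))  # 64 True
```

The trade-off is that this needs random access to examples by index, which a single sequential TFRecord stream does not provide.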

This is a follow-up to the question "How to sample batch from only one class at each iteration".

My current thought is to use tf.data.Dataset.filter:

specific_class = 2  # as an example

dataset = tf.data.TFRecordDataset(filenames)
# __parser_fun__ produces datum tuple (x, y)
dataset = dataset.map(__parser_fun__, num_parallel_calls=num_threads)
dataset = dataset.shuffle(20000)
# print(dataset) gives <ShuffleDataset shapes: ((3, 128, 128), (1,)), 
# types: (tf.float32, tf.int64)>

dataset = dataset.filter(lambda x, y: tf.equal(y[0], specific_class))
dataset = dataset.batch(64)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
x_batch, y_batch = iterator.get_next()
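The cost of filtering can be estimated with a quick simulation. Assuming balanced classes, filter keeps only 1 in num_classes records, so every batch of 64 forces roughly 64 × 1000 = 64,000 records through __parser_fun__ and the shuffle buffer. The sketch below uses a deterministic round-robin label stream (an assumption standing in for real shuffled data) so that the count is exact; `records_consumed_for_batch` is a hypothetical helper.

```python
import itertools


def records_consumed_for_batch(label_stream, cls, batch_size=64):
    """Count how many records a filter-then-batch pipeline reads
    before it can emit one full batch of class `cls`."""
    consumed = 0
    kept = 0
    for y in label_stream:
        consumed += 1
        if y == cls:  # the filter predicate
            kept += 1
            if kept == batch_size:
                break
    return consumed


num_classes = 1000
# Round-robin stream: labels 0, 1, ..., 999, 0, 1, ... (perfectly balanced).
stream = (i % num_classes for i in itertools.count())
print(records_consumed_for_batch(stream, cls=2))  # 63003
```

Every rejected record is still read and parsed, which is what makes this approach expensive with 1000 classes, regardless of how the iterator is constructed.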

A minor problem with filter is that I need to construct an iterator every time I want to sample from a new class.

Another idea is to use tf.contrib.data.rejection_resample, but it seems computationally prohibitive (or is it?).
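Whether rejection resampling helps can be checked with the textbook acceptance-probability formula for rejection sampling: accept an example of class i with probability (t_i / p_i) / max_j (t_j / p_j), where p is the initial and t the target class distribution. This formula is assumed here, not copied from TensorFlow's implementation. With a target that puts all its mass on one class k, the overall acceptance rate collapses to p_k, i.e. 1/1000 for balanced classes, which is the same rate as filter.

```python
def acceptance_probs(initial, target):
    """Per-class acceptance probabilities for rejection sampling
    from `initial` towards `target` (both lists summing to 1)."""
    ratios = [t / p for t, p in zip(target, initial)]
    m = max(ratios)
    return [r / m for r in ratios]


def overall_acceptance_rate(initial, target):
    """Expected fraction of input examples that are kept."""
    probs = acceptance_probs(initial, target)
    return sum(p * a for p, a in zip(initial, probs))


num_classes = 1000
initial = [1.0 / num_classes] * num_classes  # balanced, ImageNet-like classes
target = [0.0] * num_classes
target[2] = 1.0                              # all mass on class 2
rate = overall_acceptance_rate(initial, target)
print(rate)       # about 0.001
print(64 / rate)  # about 64,000 input examples per batch of 64
```

Under these assumptions, resampling towards a single class discards the same 999 out of 1000 examples that filter does.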

Is there another efficient way to sample batches from a particular class?

Allen Lavoie

Conceptually your Dataset is parameterized by a variable (the label to sample). This is totally doable!
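The same idea can be sketched without TensorFlow: the filter predicate closes over a mutable cell instead of a constant, so changing the cell changes which examples pass from then on, and the iterator never has to be rebuilt. This pure-Python analogy uses hypothetical names (`LabelRef`, `class_batches`) and also adds the batching step from the question.

```python
import itertools


class LabelRef:
    """Mutable holder playing the role of the tf.Variable."""

    def __init__(self, value):
        self.value = value

    def assign(self, value):
        self.value = value


def class_batches(examples, label_ref, batch_size):
    """Endlessly yield batches whose labels all equal label_ref.value.

    The predicate reads label_ref.value every time it is called, so an
    assign() between batches switches the class without a new iterator.
    """
    stream = itertools.cycle(examples)    # like .repeat()
    while True:
        batch = []
        while len(batch) < batch_size:
            x, y = next(stream)
            if y == label_ref.value:      # like .filter(...)
                batch.append((x, y))
        yield batch                       # like .batch(batch_size)


examples = [(1.0, 1), (2.0, 2), (3.0, 1), (4.0, 2)]
requested_label = LabelRef(1)
it = class_batches(examples, requested_label, batch_size=3)

first = next(it)
requested_label.assign(2)
second = next(it)
print([x for x, _ in first])   # [1.0, 3.0, 1.0]
print([x for x, _ in second])  # [2.0, 4.0, 2.0]
```

Because each batch is completed before it is yielded, an assign() between next() calls never produces a mixed-class batch in this sketch; whether a real tf.data pipeline behaves the same depends on any prefetching or buffering placed after the filter.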

Executing eagerly:

import tensorflow as tf  # TF 1.x API: tf.enable_eager_execution() is not in TF 2.x
tf.enable_eager_execution()

data = dict(
    x=tf.constant([1., 2., 3., 4.]),
    y=tf.constant([1, 2, 1, 2])
)

requested_label = tf.Variable(1)
dataset = (
    tf.data.Dataset.from_tensor_slices(data)
    .repeat()
    .filter(lambda d: tf.equal(d["y"], requested_label)))


it = dataset.make_one_shot_iterator()
for i, datum in enumerate(it):
  assert int(datum["y"]) == 1
  assert float(datum["x"]) in [1., 3.]
  if i > 5:
    break

requested_label.assign(2)

for i, datum in enumerate(it):
  assert int(datum["y"]) == 2
  assert float(datum["x"]) in [2., 4.]
  if i > 5:
    break

Graph building:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  data = dict(
      x=tf.constant([1., 2., 3., 4.]),
      y=tf.constant([1, 2, 1, 2])
  )

  requested_label = tf.Variable(1)
  dataset = (
      tf.data.Dataset.from_tensor_slices(data)
      .repeat()
      .filter(lambda d: tf.equal(d["y"], requested_label)))


  it = dataset.make_initializable_iterator()
  datum_tensors = it.get_next()
  switch_label_op = requested_label.assign(2)

  graph.finalize()
  with tf.Session() as session:
    session.run(requested_label.initializer)  # label=1
    session.run(it.initializer)
    for _ in range(5):
      datum = session.run(datum_tensors)
      assert int(datum["y"]) == 1
      assert float(datum["x"]) in [1., 3.]

    session.run(switch_label_op)  # label=2

    for _ in range(5):
      datum = session.run(datum_tensors)
      assert int(datum["y"]) == 2
      assert float(datum["x"]) in [2., 4.]

