What is the fastest way to get the k smallest (or largest) elements of an array in Java?

Antoine

I have an array of elements (in the example, these are simply integers), which are compared using some custom comparator. In this example, I simulate this comparator by defining i SMALLER j if and only if scores[i] <= scores[j].

I have two approaches:

  • using a heap of the current k candidates
  • using an array of the current k candidates

I update these two structures as follows:

  • heap: with the methods PriorityQueue.poll and PriorityQueue.offer,
  • array: I store the index top of the worst of the current k candidates. If a newly seen example is better than the element at index top, it replaces that element, and top is recomputed by iterating through all k elements of the array.

However, when I tested which of the approaches is faster, it turned out to be the second one. My questions are:

  • Is my use of PriorityQueue suboptimal?
  • What is the fastest way to compute k smallest elements?

I am interested in the case where the number of examples can be large but the number of neighbours is relatively small (between 10 and 20).

Here is the code:

public static void main(String[] args) {
    // kopica = heap time, navadno = simple (array) time
    long kopica, navadno;

    int numTries = 10000;
    int numExamples = 1000;
    int numNeighbours = 10;

    navadno = testSimple(numExamples, numNeighbours, numTries);
    kopica = testHeap(numExamples, numNeighbours, numTries);

    // testSort is not shown in the question and its result was never printed,
    // so the call is left commented out:
    // sortiranje = testSort(numExamples, numNeighbours, numTries, false);
    System.out.println(String.format("tries: %d examples: %d neighbours: %d\n time heap[ms]: %d\n time simple[ms]: %d", numTries, numExamples, numNeighbours, kopica, navadno));
}

public static long testHeap(int numberExamples, int numberNeighbours, int numberTries){
    Random rnd = new Random(123);   
    long startTime = System.currentTimeMillis();
    for(int iteration = 0; iteration < numberTries; iteration++){
        final double[] scores = new double[numberExamples];
        for(int i = 0; i < numberExamples; i++){
            scores[i] = rnd.nextDouble();
        }
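        // Reversed comparator: the heap root is the index with the LARGEST score,
        // i.e. the worst of the current candidates.
        PriorityQueue<Integer> myHeap = new PriorityQueue<>(numberNeighbours, new Comparator<Integer>(){
            @Override
            public int compare(Integer o1, Integer o2) {
                return -Double.compare(scores[o1], scores[o2]);
            }
        });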

        int top;
        for(int i = 0; i < numberExamples; i++){
            if(i < numberNeighbours){
                myHeap.offer(i);
            } else{
                top = myHeap.peek();
                if(scores[top] > scores[i]){
                    myHeap.poll();
                    myHeap.offer(i);
                }
            }
        }

    }
    long endTime = System.currentTimeMillis();
    return endTime - startTime;     
}

public static long testSimple(int numberExamples, int numberNeighbours, int numberTries){
    Random rnd = new Random(123);   
    long startTime = System.currentTimeMillis();
    for(int iteration = 0; iteration < numberTries; iteration++){
        final double[] scores = new double[numberExamples];
        for(int i = 0; i < numberExamples; i++){
            scores[i] = rnd.nextDouble();
        }
        int[] candidates = new int[numberNeighbours];
        int top = 0;
        for(int i = 0; i < numberExamples; i++){
            if(i < numberNeighbours){
                candidates[i] = i;
                if(scores[candidates[top]] < scores[candidates[i]]) top = i;
            } else{
                if(scores[candidates[top]] > scores[i]){
                    candidates[top] = i;
                    top = 0;
                    for(int j = 1; j < numberNeighbours; j++){
                        if(scores[candidates[top]] < scores[candidates[j]]) top = j;                            
                    }
                }
            }
        }

    }
    long endTime = System.currentTimeMillis();
    return endTime - startTime;     
}

This produces the following result:

tries: 10000 examples: 1000 neighbours: 10
   time heap[ms]: 393
   time simple[ms]: 388

user3707125

First of all, your benchmarking method is flawed. You are measuring input data creation along with the algorithm's performance, and you aren't warming up the JVM before measuring. Results for your code when run through JMH:

Benchmark                     Mode  Cnt      Score   Error  Units
CounterBenchmark.testHeap    thrpt    2  18103,296          ops/s
CounterBenchmark.testSimple  thrpt    2  59490,384          ops/s

Modified benchmark pastebin.
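
The two benchmarking problems mentioned above can be illustrated without JMH. The following is only a rough sketch, not a substitute for JMH: the class name TimingSketch and the helpers randomScores, kSmallestSimple and averageNanos are names made up for this illustration. It generates the input once outside the timed region, runs untimed warm-up calls first, and keeps the results alive so the calls are harder for the JIT to discard.

```java
import java.util.Random;

public class TimingSketch {
    /** Input data is generated once, outside the timed region. */
    public static double[] randomScores(int n, long seed) {
        Random rnd = new Random(seed);
        double[] scores = new double[n];
        for (int i = 0; i < n; i++) {
            scores[i] = rnd.nextDouble();
        }
        return scores;
    }

    /** The array-based approach from the question; assumes 1 <= k <= scores.length. */
    public static int[] kSmallestSimple(double[] scores, int k) {
        int[] candidates = new int[k];
        int top = 0; // position of the worst (largest-score) candidate
        for (int i = 0; i < scores.length; i++) {
            if (i < k) {
                candidates[i] = i;
                if (scores[candidates[top]] < scores[i]) top = i;
            } else if (scores[candidates[top]] > scores[i]) {
                candidates[top] = i;
                top = 0;
                for (int j = 1; j < k; j++) {
                    if (scores[candidates[top]] < scores[candidates[j]]) top = j;
                }
            }
        }
        return candidates;
    }

    /** Average nanoseconds per call, measured after untimed warm-up calls. */
    public static long averageNanos(double[] scores, int k, int warmupRuns, int timedRuns) {
        long sink = 0; // accumulate results so the calls are harder to optimise away
        for (int r = 0; r < warmupRuns; r++) {
            sink += kSmallestSimple(scores, k)[0];
        }
        long start = System.nanoTime();
        for (int r = 0; r < timedRuns; r++) {
            sink += kSmallestSimple(scores, k)[0];
        }
        long elapsed = System.nanoTime() - start;
        if (sink == Long.MIN_VALUE) System.out.println(sink); // never true here; keeps sink live
        return elapsed / timedRuns;
    }

    public static void main(String[] args) {
        double[] scores = randomScores(1000, 123);
        System.out.println("avg ns per call: " + averageNanos(scores, 10, 20_000, 20_000));
    }
}
```

The warm-up and run counts are arbitrary choices; a hand-rolled loop like this still lacks the forking and statistics JMH provides, so treat its numbers as indicative only.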

Regarding the 3x difference between the two provided solutions: in terms of big-O notation your first algorithm may seem better, but big-O notation only tells you how well an algorithm scales; it never tells you how fast it actually runs (see this question also). In your case scaling is not the issue, as your numNeighbours is limited to 20. In other words, big-O notation describes how many ticks an algorithm needs to complete, but it doesn't limit the duration of a tick; it only says that the tick duration doesn't change when the inputs change. And in terms of the cost of a single tick, your second algorithm clearly wins.
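
One way to see why the cost of the update step barely matters here is to count how often it runs at all. The sketch below is an illustration added under assumptions: the class ReplacementCount and the method countReplacements are hypothetical names, and the input is uniformly random as in the question. For randomly ordered input the expected number of replacements is roughly k·ln(n/k), which is about 46 for n = 1000 and k = 10, so most elements are rejected after a single comparison.

```java
import java.util.Random;

public class ReplacementCount {
    /**
     * Counts how often the worst of the current k candidates is replaced
     * while scanning scores once. Assumes 1 <= k <= scores.length.
     */
    public static int countReplacements(double[] scores, int k) {
        double[] best = new double[k];
        int top = 0; // position of the largest value currently kept
        int replacements = 0;
        for (int i = 0; i < scores.length; i++) {
            if (i < k) {
                best[i] = scores[i];
                if (best[top] < best[i]) top = i;
            } else if (best[top] > scores[i]) {
                best[top] = scores[i];
                replacements++;
                // linear rescan for the new worst candidate, as in the question
                top = 0;
                for (int j = 1; j < k; j++) {
                    if (best[top] < best[j]) top = j;
                }
            }
        }
        return replacements;
    }

    public static void main(String[] args) {
        Random rnd = new Random(123);
        double[] scores = new double[1000];
        for (int i = 0; i < scores.length; i++) {
            scores[i] = rnd.nextDouble();
        }
        System.out.println("replacements: " + countReplacements(scores, 10));
    }
}
```

Sorted input changes the picture: ascending input triggers no replacements, while descending input triggers one for every element after the first k.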

What is the fastest way to compute k smallest elements?

I came up with the following solution, which I believe allows branch prediction to do its job:

@Benchmark
public void testModified(Blackhole bh) {
    final double[] scores = sampleData;
    int[] candidates = new int[numberNeighbours];
    for (int i = 0; i < numberNeighbours; i++) {
        candidates[i] = i;
    }
    // sorting candidates so scores[candidates[0]] is the largest
    for (int i = 0; i < numberNeighbours; i++) {
        for (int j = i+1; j < numberNeighbours; j++) {
            if (scores[candidates[i]] < scores[candidates[j]]) {
                int temp = candidates[i];
                candidates[i] = candidates[j];
                candidates[j] = temp;
            }
        }
    }
    // processing other scores, while keeping candidates array sorted in the descending order
    for (int i = numberNeighbours; i < numberExamples; i++) {
        if (scores[i] > scores[candidates[0]]) {
            continue;
        }
        // moving all larger candidates to the left, to keep the array sorted
        int j; // here the branch prediction should kick-in
        for (j = 1; j < numberNeighbours && scores[i] < scores[candidates[j]]; j++) {
            candidates[j - 1] = candidates[j];
        }
        // inserting the new item
        candidates[j - 1] = i;
    }
    bh.consume(candidates);
}

Benchmark results (2x faster than your current solution):

(10 neighbours) CounterBenchmark.testModified    thrpt    2  136492,151          ops/s
(20 neighbours) CounterBenchmark.testModified    thrpt    2  118395,598          ops/s
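
To try the same idea outside the JMH harness, here is a minimal standalone sketch. The class name SortedInsertionTopK and the method kSmallest are made up for this example, and the initial ordering uses a plain insertion sort instead of the swap loop above; the main loop follows the benchmark code. It assumes 1 <= k <= scores.length.

```java
import java.util.Arrays;

public class SortedInsertionTopK {
    /**
     * Returns the indices of the k smallest scores, ordered so that the
     * index of the largest kept score comes first.
     */
    public static int[] kSmallest(double[] scores, int k) {
        int[] candidates = new int[k];
        for (int i = 0; i < k; i++) {
            candidates[i] = i;
        }
        // insertion sort of the first k indices, descending by score
        for (int i = 1; i < k; i++) {
            int current = candidates[i];
            int j = i - 1;
            while (j >= 0 && scores[candidates[j]] < scores[current]) {
                candidates[j + 1] = candidates[j];
                j--;
            }
            candidates[j + 1] = current;
        }
        // process the remaining scores while keeping candidates sorted
        for (int i = k; i < scores.length; i++) {
            if (scores[i] > scores[candidates[0]]) {
                continue; // worse than the current worst candidate
            }
            // shift larger candidates left, dropping the old worst at position 0
            int j;
            for (j = 1; j < k && scores[i] < scores[candidates[j]]; j++) {
                candidates[j - 1] = candidates[j];
            }
            candidates[j - 1] = i;
        }
        return candidates;
    }

    public static void main(String[] args) {
        double[] scores = {0.9, 0.1, 0.5, 0.3, 0.7, 0.2};
        System.out.println(Arrays.toString(kSmallest(scores, 3))); // prints [3, 5, 1]
    }
}
```

In the usage example the three smallest scores are 0.3, 0.2 and 0.1, at indices 3, 5 and 1, which is the descending-by-score order the method returns.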

Others mentioned quickselect, but as one might expect, the overhead of that algorithm negates its strengths in your case:

@Benchmark
public void testQuickSelect(Blackhole bh) {
    final int[] candidates = new int[sampleData.length];
    for (int i = 0; i < candidates.length; i++) {
        candidates[i] = i;
    }
    final int[] resultIndices = new int[numberNeighbours];
    int neighboursToAdd = numberNeighbours;

    int left = 0;
    int right = candidates.length - 1;
    while (neighboursToAdd > 0) {
        int partitionIndex = partition(candidates, left, right);
        int smallerItemsPartitioned = partitionIndex - left;
        if (smallerItemsPartitioned <= neighboursToAdd) {
            while (left < partitionIndex) {
                resultIndices[numberNeighbours - neighboursToAdd--] = candidates[left++];
            }
        } else {
            right = partitionIndex - 1;
        }
    }
    bh.consume(resultIndices);
}

private int partition(int[] locations, int left, int right) {
    final int pivotIndex = ThreadLocalRandom.current().nextInt(left, right + 1);
    final double pivotValue = sampleData[locations[pivotIndex]];
    int storeIndex = left;
    for (int i = left; i <= right; i++) {
        if (sampleData[locations[i]] <= pivotValue) {
            final int temp = locations[storeIndex];
            locations[storeIndex] = locations[i];
            locations[i] = temp;

            storeIndex++;
        }
    }
    return storeIndex;
}

Benchmark results are pretty disappointing in this case:

CounterBenchmark.testQuickSelect  thrpt    2   11586,761          ops/s
