3
\$\begingroup\$

I have this MSD (most-significant digit) radix sort for sorting Entry objects, each holding a sorting key of type long and a satellite datum. It handles the sign bit so that the result respects signed order: every Entry whose key has the sign bit set (a negative key) precedes all Entry objects whose key has the sign bit clear. It is highly efficient on large arrays (almost 3 times faster than java.util.Arrays.sort when sorting 1e7 Entry objects).

So what do you think?

Arrays.java:

package net.coderodde.util; public class Arrays { private static final int BUCKETS = 256; private static final int BITS_PER_BYTE = 8; private static final int RIGHT_SHIFT_AMOUNT = 56; private static final int MOST_SIGNIFICANT_BYTE_INDEX = 7; private static final int MERGESORT_THRESHOLD = 4096; private static final int LEAST_SIGNED_BUCKET_INDEX = 128; public static final <E> void sort(final Entry<E>[] array, final int fromIndex, final int toIndex) { if (toIndex - fromIndex < 2) { // Trivially sorted or indices ass-backwards. return; } final Entry<E>[] buffer = array.clone(); sortTopImpl(array, buffer, fromIndex, toIndex); } public static final <E> void sort(final Entry<E>[] array) { sort(array, 0, array.length); } public static final <E extends Comparable<? super E>> boolean isSorted(final E[] array, final int fromIndex, final int toIndex) { for (int i = fromIndex; i < toIndex - 1; ++i) { if (array[i].compareTo(array[i + 1]) > 0) { return false; } } return true; } public static final <E extends Comparable<? super E>> boolean isSorted(final E[] array) { return isSorted(array, 0, array.length); } public static final <E> boolean areEqual(final Entry<E>[]... arrays) { for (int i = 0; i < arrays.length - 1; ++i) { if (arrays[i].length != arrays[i + 1].length) { return false; } } for (int i = 0; i < arrays[0].length; ++i) { for (int j = 0; j < arrays.length - 1; ++j) { if (arrays[j][i] != arrays[j + 1][i]) { return false; } } } return true; } /** * This static method sorts sequentially an entry array by most-significant * bytes. * * @param <E> the type of satellite data of each entry. * @param array the actual array to sort. * @param buffer the auxiliary buffer. * @param fromIndex the least index of the range to sort. * @param toIndex the index one past the last index of the range to sort. */ private static <E> void sortTopImpl(final Entry<E>[] array, final Entry<E>[] buffer, final int fromIndex, final int toIndex) { // The amount of elements in the requested range. 
final int RANGE_LENGTH = toIndex - fromIndex; if (RANGE_LENGTH <= MERGESORT_THRESHOLD) { // Once here, the range is too small, use merge sort. // The amount of merge passes needed to sort the input range. final int PASSES = getPassAmount(RANGE_LENGTH); // Here, both 'array' and 'buffer' are identical in content. if ((PASSES & 1) == 0) { // Once here, there will be an even amount of merge passes, so // it makes sense to pass 'array' as the source array so that // the actual sorted data ends up in it, so there is no need // to copy the sorted range from 'buffer' to 'array'. mergesort(array, buffer, fromIndex, toIndex); } else { // A symmetric case: sort using 'buffer' as the source array // so that the sorted data ends up in 'array' and we don't have // to do additional copying. mergesort(buffer, array, fromIndex, toIndex); } return; } final int[] bucketSizeMap = new int[BUCKETS]; final int[] startIndexMap = new int[BUCKETS]; final int[] processedMap = new int[BUCKETS]; // Determine the size of each bucket. for (int i = fromIndex; i < toIndex; ++i) { bucketSizeMap[(int)(array[i].key >>> RIGHT_SHIFT_AMOUNT)]++; } // BEGIN: Special sign bit magic. startIndexMap[LEAST_SIGNED_BUCKET_INDEX] = fromIndex; for (int i = LEAST_SIGNED_BUCKET_INDEX + 1; i != BUCKETS; ++i) { startIndexMap[i] = startIndexMap[i - 1] + bucketSizeMap[i - 1]; } startIndexMap[0] = startIndexMap[BUCKETS - 1] + bucketSizeMap[BUCKETS - 1]; for (int i = 1; i != LEAST_SIGNED_BUCKET_INDEX; ++i) { startIndexMap[i] = startIndexMap[i - 1] + bucketSizeMap[i - 1]; } // END: Special sign bit magic. // Insert elements into their respective buckets in the buffer array. for (int i = fromIndex; i < toIndex; ++i) { final Entry<E> e = array[i]; final int index = (int)(e.key >>> RIGHT_SHIFT_AMOUNT); buffer[startIndexMap[index] + processedMap[index]++] = e; } // Recur to sort all non-empty buckets. 
for (int i = 0; i != BUCKETS; ++i) { if (bucketSizeMap[i] != 0) { sortImpl(buffer, array, MOST_SIGNIFICANT_BYTE_INDEX - 1, startIndexMap[i], startIndexMap[i] + bucketSizeMap[i]); } } } /** * This static method sorts the entry array by bytes that are not * most-significant. * * @param <E> the type of satellite data in the entry array. * @param source the source array. * @param target the target array. * @param byteIndex the index of the byte to use as the sorting key. 0 * represents the least-significant byte. * @param fromIndex the least index of the range to sort. * @param toIndex the index one past the greatest index of the range to * sort. */ private static <E> void sortImpl(final Entry<E>[] source, final Entry<E>[] target, final int byteIndex, final int fromIndex, final int toIndex) { // Try merge sort. if (toIndex - fromIndex <= MERGESORT_THRESHOLD) { // If 'even' is true, the sorted ranged ended up in 'source'. final boolean even = mergesort(source, target, fromIndex, toIndex); if (even) { // source contains the sorted bucket. if ((byteIndex & 1) == 0) { // byteIndex = 6, 4, 2, 0. // source is buffer, copy to target. System.arraycopy(source, fromIndex, target, fromIndex, toIndex - fromIndex); } } else { // target contains the sorted bucket. if ((byteIndex & 1) == 1) { // byteIndex = 5, 3, 1. // target is buffer, copy to source. System.arraycopy(target, fromIndex, source, fromIndex, toIndex - fromIndex); } } return; } final int[] bucketSizeMap = new int[BUCKETS]; final int[] startIndexMap = new int[BUCKETS]; final int[] processedMap = new int[BUCKETS]; // We need this as to get rid of the bits on the left from the byte we // are interesed in. final int LEFT_SHIFT_AMOUNT = BITS_PER_BYTE * (MOST_SIGNIFICANT_BYTE_INDEX - byteIndex); // Compute the size of each bucket. for (int i = fromIndex; i < toIndex; ++i) { bucketSizeMap[(int)((source[i].key << LEFT_SHIFT_AMOUNT) >>> RIGHT_SHIFT_AMOUNT)]++; } // Initialize the start index map. 
startIndexMap[0] = fromIndex; // Compute the start index map in its entirety. for (int i = 1; i != BUCKETS; ++i) { startIndexMap[i] = startIndexMap[i - 1] + bucketSizeMap[i - 1]; } // Insert the entries from 'source' into their respective 'target'. for (int i = fromIndex; i < toIndex; ++i) { final Entry<E> e = source[i]; final int index = (int)((e.key << LEFT_SHIFT_AMOUNT) >>> RIGHT_SHIFT_AMOUNT); target[startIndexMap[index] + processedMap[index]++] = e; } if (byteIndex == 0) { // There is nowhere to recur, return. return; } // Recur to sort each bucket. for (int i = 0; i != BUCKETS; ++i) { if (bucketSizeMap[i] != 0) { sortImpl(target, source, byteIndex - 1, startIndexMap[i], startIndexMap[i] + bucketSizeMap[i]); } } } /** * Sorts the range <code>[fromIndex, toIndex)</code> between the arrays * <code>source</code> and <code>target</code>. * * @param <E> the type of entries' satellite data. * @param source the source array; the data to sort is assumed to be in this * array. * @param target acts as an auxiliary array. * @param fromIndex the least component index of the range to sort. * @param toIndex <code>toIndex - 1</code> is the index of the rightmost * component in the range to sort. * @return <code>true</code> if there was an even amount of merge passes, * which implies that the sorted range ended up in <code>source</code>. * Otherwise <code>false</code> is returned, and the sorted range ended up * in the array <code>target</code>. 
*/ private static final <E> boolean mergesort(final Entry<E>[] source, final Entry<E>[] target, final int fromIndex, final int toIndex) { final int RANGE_LENGTH = toIndex - fromIndex; Entry<E>[] s = source; Entry<E>[] t = target; int passes = 0; for (int width = 1; width < RANGE_LENGTH; width <<= 1) { ++passes; int c = 0; for (; c < RANGE_LENGTH / width; c += 2) { int left = fromIndex + c * width; int right = left + width; int i = left; final int leftBound = right; final int rightBound = Math.min(toIndex, right + width); while (left < leftBound && right < rightBound) { t[i++] = s[right].key < s[left].key ? s[right++] : s[left++]; } while (left < leftBound) { t[i++] = s[left++]; } while (right < rightBound) { t[i++] = s[right++]; } } if (c * width < RANGE_LENGTH) { for (int i = fromIndex + c * width; i < toIndex; ++i) { t[i] = s[i]; } } final Entry<E>[] tmp = s; s = t; t = tmp; } return (passes & 1) == 0; } private static int getPassAmount(int length) { if (length < 1) { // Should not get here. length = 1; } return 32 - Integer.numberOfLeadingZeros(length - 1); } } 

Entry.java:

package net.coderodde.util; /** * The wrapper class holding a satellite datum and the key. * * @param <E> the type of a satellite datum. */ public final class Entry<E> implements Comparable<Entry<E>> { /** * The sorting key. */ public long key; /** * The satellite data. */ public E satelliteData; /** * Constructs a new <code>Entry</code> with key <code>key</code> and * the satellite datum <code>satelliteData</code>. * * @param key the key of this entry. * @param satelliteData the satellite data associated with the key. */ public Entry(final long key, final E satelliteData) { this.key = key; this.satelliteData = satelliteData; } /** * Compares this <code>Entry</code> with another. * * @param o the entry to compare against. * * @return a negative value if this entry's key is less than that of * <code>o</code>, a positive value if this entry's key is greater than that * of <code>o</code>, or 0 if the two keys are equal. */ @Override public int compareTo(Entry<E> o) { if (key < o.key) { return -1; } else if (key > o.key) { return 1; } else { return 0; } } } 

Demo.java:

package net.coderodde.util;

import java.util.Random;

/**
 * Small benchmark comparing {@code net.coderodde.util.Arrays.sort} against
 * {@code java.util.Arrays.sort} on the same random entries, and checking that
 * both produce the same permutation.
 */
public class Demo {

    /** Number of entries to sort. */
    private static final int N = 10000000;

    /** Conversion factor from {@link System#nanoTime()} units to milliseconds. */
    private static final long NANOS_PER_MILLI = 1_000_000L;

    /**
     * Runs both sorts once, prints the seed and the elapsed times, and
     * reports whether the results agree and are sorted.
     *
     * @param args ignored.
     */
    public static void main(final String... args) {
        final long seed = System.currentTimeMillis();
        final Random rnd = new Random(seed);
        final Entry<Integer>[] array1 = getRandomEntryArray(N, rnd);
        // Shallow copy: both arrays refer to the same Entry objects, which is
        // what allows areEqual to compare the results by reference.
        final Entry<Integer>[] array2 = array1.clone();

        System.out.println("Seed: " + seed);

        // nanoTime is monotonic, unlike the wall clock, so it is the right
        // source for measuring elapsed time.
        long ta = System.nanoTime();
        net.coderodde.util.Arrays.sort(array1);
        long tb = System.nanoTime();

        System.out.println("net.coderodde.util.Arrays.sort in "
                + (tb - ta) / NANOS_PER_MILLI + " ms.");

        ta = System.nanoTime();
        java.util.Arrays.sort(array2);
        tb = System.nanoTime();

        System.out.println("java.util.Arrays.sort in "
                + (tb - ta) / NANOS_PER_MILLI + " ms.");

        System.out.println("Arrays are equal: "
                + Arrays.areEqual(array1, array2));
        System.out.println("Sorted: " + Arrays.isSorted(array1));
    }

    /**
     * Creates an array of entries with random keys and {@code null}
     * satellite data.
     *
     * @param size the number of entries to create.
     * @param rnd  the random number generator supplying the keys.
     * @return the freshly created array.
     */
    private static Entry<Integer>[] getRandomEntryArray(final int size,
                                                        final Random rnd) {
        // Generic array creation is impossible in Java; the raw array only
        // ever receives Entry<Integer> instances below, so the cast is safe.
        @SuppressWarnings("unchecked")
        final Entry<Integer>[] array = new Entry[size];

        for (int i = 0; i < size; ++i) {
            array[i] = new Entry<>(rnd.nextLong(), null);
        }

        return array;
    }
}
\$\endgroup\$

    1 Answer 1

    3
    \$\begingroup\$

    Entry

    Your Entry compareTo(...) method is correct, but you should delegate to Long.compare() instead. Your code:

    /* Quoted from the question: a hand-written three-way comparison of the two keys that returns -1, 1 or 0. */
    @Override public int compareTo(Entry<E> o) { if (key < o.key) { return -1; } else if (key > o.key) { return 1; } else { return 0; } } 

    could be:

    /**
     * Orders entries by their signed keys.
     *
     * @param o the entry to compare against.
     * @return the result of {@code Long.compare} applied to this entry's key
     *         and the key of {@code o}.
     */
    @Override
    public int compareTo(final Entry<E> o) {
        return Long.compare(this.key, o.key);
    }

    The key and satelliteData fields in Entry should also be final and, instead of being public, should be private and exposed through getters.

    Sorting

    You have special handling for the buckets and the ranges, depending on negative values.

    This special handling has also resulted in a lot of code duplication. You essentially have two near-identical methods: one for sorting by the high byte (which has to deal with negative values), and the other for sorting by the remaining lower bytes. Your code distributes the data (or a subset of it) into buckets based on one byte of the key. The challenge here is that the most significant byte has a different sort order than the other bytes, because it carries the sign bit.

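    The trick to solving this is to flip the most significant bit (the sign bit) of the key before extracting the bucket: the unsigned order of the flipped keys is exactly the signed order of the original keys, so every byte, including the most significant one, can then be bucketed the same way.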

    Your code would boil down to something like:

    private static final int BITS_PER_BUCKET = 8; private static final int BUCKETS = 1 << BITS_PER_BUCKET; private static final int BUCKET_MASK = BUCKETS - 1; private static final long SIGN_MASK = 1L << 63; /* * Converts a section of a key in to a bucket. Treats sign bit properly. */ private static final int getBucket(final long key, final int recursionDepth) { final int bitShift = 64 - (recursionDepth + 1) * BITS_PER_BUCKET; return (int)((key ^ SIGN_MASK) >>> bitShift) & BUCKET_MASK } private static <E> void sortImpl(final Entry<E>[] source, final Entry<E>[] target, final int recursionDepth, final int fromIndex, final int toIndex) { // Try merge sort. if (toIndex - fromIndex <= MERGESORT_THRESHOLD) { // perform merge sort ..... .... return; } final int[] bucketSizeMap = new int[BUCKETS]; final int[] startIndexMap = new int[BUCKETS]; final int[] processedMap = new int[BUCKETS]; // Compute the size of each bucket. for (int i = fromIndex; i < toIndex; ++i) { bucketSizeMap[getBucket(source[i].key, recursionDepth)]++; } // Initialize the start index map. startIndexMap[0] = fromIndex; // Compute the start index map in its entirety. for (int i = 1; i != BUCKETS; ++i) { startIndexMap[i] = startIndexMap[i - 1] + bucketSizeMap[i - 1]; } // Insert the entries from 'source' into their respective 'target'. for (int i = fromIndex; i < toIndex; ++i) { final Entry<E> e = source[i]; final int index = getBucket(source[i].key, recursionDepth); target[startIndexMap[index] + processedMap[index]++] = e; } // Recur to sort each bucket. for (int i = 0; i != BUCKETS; ++i) { if (bucketSizeMap[i] != 0) { sortImpl(target, source, recursionDepth + 1, startIndexMap[i], startIndexMap[i] + bucketSizeMap[i]); } } } 

    With this code, there is no need for the 'Top' sort method at all, it's superfluous. Your recursion entry can change from:

    sortTopImpl(array, buffer, fromIndex, toIndex); 

    to:

    sortImpl(array, buffer, 0, fromIndex, toIndex); 

    and you can delete the sortTopImpl method entirely.

    Note that the bucket-constants are all based off the one BITS_PER_BUCKET size, and the rest is calculated from that. You should be able to easily change the bucket size by just changing that one constant.

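    Additionally, as long as the keys are distinct enough that every bucket eventually shrinks below the merge-sort threshold, the recursion ends by itself in the merge-sort branch. That is not guaranteed, though: if more than MERGESORT_THRESHOLD entries share the same key, their bucket never shrinks, so you should still stop recursing once the last byte of the key has been processed.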

    \$\endgroup\$
    0

      Start asking to get answers

      Find the answer to your question by asking.

      Ask question

      Explore related questions

      See similar questions with these tags.