Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.util.collection.unsafe.sort;
0019 
0020 import com.google.common.primitives.Ints;
0021 
0022 import org.apache.spark.unsafe.Platform;
0023 import org.apache.spark.unsafe.array.LongArray;
0024 
0025 public class RadixSort {
0026 
0027   /**
0028    * Sorts a given array of longs using least-significant-digit radix sort. This routine assumes
0029    * you have extra space at the end of the array at least equal to the number of records. The
0030    * sort is destructive and may relocate the data positioned within the array.
0031    *
0032    * @param array array of long elements followed by at least that many empty slots.
0033    * @param numRecords number of data records in the array.
0034    * @param startByteIndex the first byte (in range [0, 7]) to sort each long by, counting from the
0035    *                       least significant byte.
0036    * @param endByteIndex the last byte (in range [0, 7]) to sort each long by, counting from the
0037    *                     least significant byte. Must be greater than startByteIndex.
0038    * @param desc whether this is a descending (binary-order) sort.
0039    * @param signed whether this is a signed (two's complement) sort.
0040    *
0041    * @return The starting index of the sorted data within the given array. We return this instead
0042    *         of always copying the data back to position zero for efficiency.
0043    */
0044   public static int sort(
0045       LongArray array, long numRecords, int startByteIndex, int endByteIndex,
0046       boolean desc, boolean signed) {
0047     assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
0048     assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
0049     assert endByteIndex > startByteIndex;
0050     assert numRecords * 2 <= array.size();
0051     long inIndex = 0;
0052     long outIndex = numRecords;
0053     if (numRecords > 0) {
0054       long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
0055       for (int i = startByteIndex; i <= endByteIndex; i++) {
0056         if (counts[i] != null) {
0057           sortAtByte(
0058             array, numRecords, counts[i], i, inIndex, outIndex,
0059             desc, signed && i == endByteIndex);
0060           long tmp = inIndex;
0061           inIndex = outIndex;
0062           outIndex = tmp;
0063         }
0064       }
0065     }
0066     return Ints.checkedCast(inIndex);
0067   }
0068 
0069   /**
0070    * Performs a partial sort by copying data into destination offsets for each byte value at the
0071    * specified byte offset.
0072    *
0073    * @param array array to partially sort.
0074    * @param numRecords number of data records in the array.
0075    * @param counts counts for each byte value. This routine destructively modifies this array.
0076    * @param byteIdx the byte in a long to sort at, counting from the least significant byte.
0077    * @param inIndex the starting index in the array where input data is located.
0078    * @param outIndex the starting index where sorted output data should be written.
0079    * @param desc whether this is a descending (binary-order) sort.
0080    * @param signed whether this is a signed (two's complement) sort (only applies to last byte).
0081    */
0082   private static void sortAtByte(
0083       LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
0084       boolean desc, boolean signed) {
0085     assert counts.length == 256;
0086     long[] offsets = transformCountsToOffsets(
0087       counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed);
0088     Object baseObject = array.getBaseObject();
0089     long baseOffset = array.getBaseOffset() + inIndex * 8L;
0090     long maxOffset = baseOffset + numRecords * 8L;
0091     for (long offset = baseOffset; offset < maxOffset; offset += 8) {
0092       long value = Platform.getLong(baseObject, offset);
0093       int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
0094       Platform.putLong(baseObject, offsets[bucket], value);
0095       offsets[bucket] += 8;
0096     }
0097   }
0098 
0099   /**
0100    * Computes a value histogram for each byte in the given array.
0101    *
0102    * @param array array to count records in.
0103    * @param numRecords number of data records in the array.
0104    * @param startByteIndex the first byte to compute counts for (the prior are skipped).
0105    * @param endByteIndex the last byte to compute counts for.
0106    *
0107    * @return an array of eight 256-byte count arrays, one for each byte starting from the least
0108    *         significant byte. If the byte does not need sorting the array will be null.
0109    */
0110   private static long[][] getCounts(
0111       LongArray array, long numRecords, int startByteIndex, int endByteIndex) {
0112     long[][] counts = new long[8][];
0113     // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
0114     // If all the byte values at a particular index are the same we don't need to count it.
0115     long bitwiseMax = 0;
0116     long bitwiseMin = -1L;
0117     long maxOffset = array.getBaseOffset() + numRecords * 8L;
0118     Object baseObject = array.getBaseObject();
0119     for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
0120       long value = Platform.getLong(baseObject, offset);
0121       bitwiseMax |= value;
0122       bitwiseMin &= value;
0123     }
0124     long bitsChanged = bitwiseMin ^ bitwiseMax;
0125     // Compute counts for each byte index.
0126     for (int i = startByteIndex; i <= endByteIndex; i++) {
0127       if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
0128         counts[i] = new long[256];
0129         // TODO(ekl) consider computing all the counts in one pass.
0130         for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
0131           counts[i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++;
0132         }
0133       }
0134     }
0135     return counts;
0136   }
0137 
0138   /**
0139    * Transforms counts into the proper unsafe output offsets for the sort type.
0140    *
0141    * @param counts counts for each byte value. This routine destructively modifies this array.
0142    * @param numRecords number of data records in the original data array.
0143    * @param outputOffset output offset in bytes from the base array object.
0144    * @param bytesPerRecord size of each record (8 for plain sort, 16 for key-prefix sort).
0145    * @param desc whether this is a descending (binary-order) sort.
0146    * @param signed whether this is a signed (two's complement) sort.
0147    *
0148    * @return the input counts array.
0149    */
0150   private static long[] transformCountsToOffsets(
0151       long[] counts, long numRecords, long outputOffset, long bytesPerRecord,
0152       boolean desc, boolean signed) {
0153     assert counts.length == 256;
0154     int start = signed ? 128 : 0;  // output the negative records first (values 129-255).
0155     if (desc) {
0156       long pos = numRecords;
0157       for (int i = start; i < start + 256; i++) {
0158         pos -= counts[i & 0xff];
0159         counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
0160       }
0161     } else {
0162       long pos = 0;
0163       for (int i = start; i < start + 256; i++) {
0164         long tmp = counts[i & 0xff];
0165         counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
0166         pos += tmp;
0167       }
0168     }
0169     return counts;
0170   }
0171 
0172   /**
0173    * Specialization of sort() for key-prefix arrays. In this type of array, each record consists
0174    * of two longs, only the second of which is sorted on.
0175    *
0176    * @param startIndex starting index in the array to sort from. This parameter is not supported
0177    *    in the plain sort() implementation.
0178    */
0179   public static int sortKeyPrefixArray(
0180       LongArray array,
0181       long startIndex,
0182       long numRecords,
0183       int startByteIndex,
0184       int endByteIndex,
0185       boolean desc,
0186       boolean signed) {
0187     assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
0188     assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
0189     assert endByteIndex > startByteIndex;
0190     assert numRecords * 4 <= array.size();
0191     long inIndex = startIndex;
0192     long outIndex = startIndex + numRecords * 2L;
0193     if (numRecords > 0) {
0194       long[][] counts = getKeyPrefixArrayCounts(
0195         array, startIndex, numRecords, startByteIndex, endByteIndex);
0196       for (int i = startByteIndex; i <= endByteIndex; i++) {
0197         if (counts[i] != null) {
0198           sortKeyPrefixArrayAtByte(
0199             array, numRecords, counts[i], i, inIndex, outIndex,
0200             desc, signed && i == endByteIndex);
0201           long tmp = inIndex;
0202           inIndex = outIndex;
0203           outIndex = tmp;
0204         }
0205       }
0206     }
0207     return Ints.checkedCast(inIndex);
0208   }
0209 
0210   /**
0211    * Specialization of getCounts() for key-prefix arrays. We could probably combine this with
0212    * getCounts with some added parameters but that seems to hurt in benchmarks.
0213    */
0214   private static long[][] getKeyPrefixArrayCounts(
0215       LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) {
0216     long[][] counts = new long[8][];
0217     long bitwiseMax = 0;
0218     long bitwiseMin = -1L;
0219     long baseOffset = array.getBaseOffset() + startIndex * 8L;
0220     long limit = baseOffset + numRecords * 16L;
0221     Object baseObject = array.getBaseObject();
0222     for (long offset = baseOffset; offset < limit; offset += 16) {
0223       long value = Platform.getLong(baseObject, offset + 8);
0224       bitwiseMax |= value;
0225       bitwiseMin &= value;
0226     }
0227     long bitsChanged = bitwiseMin ^ bitwiseMax;
0228     for (int i = startByteIndex; i <= endByteIndex; i++) {
0229       if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
0230         counts[i] = new long[256];
0231         for (long offset = baseOffset; offset < limit; offset += 16) {
0232           counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++;
0233         }
0234       }
0235     }
0236     return counts;
0237   }
0238 
0239   /**
0240    * Specialization of sortAtByte() for key-prefix arrays.
0241    */
0242   private static void sortKeyPrefixArrayAtByte(
0243       LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
0244       boolean desc, boolean signed) {
0245     assert counts.length == 256;
0246     long[] offsets = transformCountsToOffsets(
0247       counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed);
0248     Object baseObject = array.getBaseObject();
0249     long baseOffset = array.getBaseOffset() + inIndex * 8L;
0250     long maxOffset = baseOffset + numRecords * 16L;
0251     for (long offset = baseOffset; offset < maxOffset; offset += 16) {
0252       long key = Platform.getLong(baseObject, offset);
0253       long prefix = Platform.getLong(baseObject, offset + 8);
0254       int bucket = (int)((prefix >>> (byteIdx * 8)) & 0xff);
0255       long dest = offsets[bucket];
0256       Platform.putLong(baseObject, dest, key);
0257       Platform.putLong(baseObject, dest + 8, prefix);
0258       offsets[bucket] += 16;
0259     }
0260   }
0261 }