0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.util.collection.unsafe.sort;
0019
0020 import com.google.common.primitives.Ints;
0021
0022 import org.apache.spark.unsafe.Platform;
0023 import org.apache.spark.unsafe.array.LongArray;
0024
0025 public class RadixSort {
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044 public static int sort(
0045 LongArray array, long numRecords, int startByteIndex, int endByteIndex,
0046 boolean desc, boolean signed) {
0047 assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
0048 assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
0049 assert endByteIndex > startByteIndex;
0050 assert numRecords * 2 <= array.size();
0051 long inIndex = 0;
0052 long outIndex = numRecords;
0053 if (numRecords > 0) {
0054 long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
0055 for (int i = startByteIndex; i <= endByteIndex; i++) {
0056 if (counts[i] != null) {
0057 sortAtByte(
0058 array, numRecords, counts[i], i, inIndex, outIndex,
0059 desc, signed && i == endByteIndex);
0060 long tmp = inIndex;
0061 inIndex = outIndex;
0062 outIndex = tmp;
0063 }
0064 }
0065 }
0066 return Ints.checkedCast(inIndex);
0067 }
0068
0069
0070
0071
0072
0073
0074
0075
0076
0077
0078
0079
0080
0081
0082 private static void sortAtByte(
0083 LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
0084 boolean desc, boolean signed) {
0085 assert counts.length == 256;
0086 long[] offsets = transformCountsToOffsets(
0087 counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed);
0088 Object baseObject = array.getBaseObject();
0089 long baseOffset = array.getBaseOffset() + inIndex * 8L;
0090 long maxOffset = baseOffset + numRecords * 8L;
0091 for (long offset = baseOffset; offset < maxOffset; offset += 8) {
0092 long value = Platform.getLong(baseObject, offset);
0093 int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
0094 Platform.putLong(baseObject, offsets[bucket], value);
0095 offsets[bucket] += 8;
0096 }
0097 }
0098
0099
0100
0101
0102
0103
0104
0105
0106
0107
0108
0109
0110 private static long[][] getCounts(
0111 LongArray array, long numRecords, int startByteIndex, int endByteIndex) {
0112 long[][] counts = new long[8][];
0113
0114
0115 long bitwiseMax = 0;
0116 long bitwiseMin = -1L;
0117 long maxOffset = array.getBaseOffset() + numRecords * 8L;
0118 Object baseObject = array.getBaseObject();
0119 for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
0120 long value = Platform.getLong(baseObject, offset);
0121 bitwiseMax |= value;
0122 bitwiseMin &= value;
0123 }
0124 long bitsChanged = bitwiseMin ^ bitwiseMax;
0125
0126 for (int i = startByteIndex; i <= endByteIndex; i++) {
0127 if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
0128 counts[i] = new long[256];
0129
0130 for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
0131 counts[i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++;
0132 }
0133 }
0134 }
0135 return counts;
0136 }
0137
0138
0139
0140
0141
0142
0143
0144
0145
0146
0147
0148
0149
0150 private static long[] transformCountsToOffsets(
0151 long[] counts, long numRecords, long outputOffset, long bytesPerRecord,
0152 boolean desc, boolean signed) {
0153 assert counts.length == 256;
0154 int start = signed ? 128 : 0;
0155 if (desc) {
0156 long pos = numRecords;
0157 for (int i = start; i < start + 256; i++) {
0158 pos -= counts[i & 0xff];
0159 counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
0160 }
0161 } else {
0162 long pos = 0;
0163 for (int i = start; i < start + 256; i++) {
0164 long tmp = counts[i & 0xff];
0165 counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
0166 pos += tmp;
0167 }
0168 }
0169 return counts;
0170 }
0171
0172
0173
0174
0175
0176
0177
0178
0179 public static int sortKeyPrefixArray(
0180 LongArray array,
0181 long startIndex,
0182 long numRecords,
0183 int startByteIndex,
0184 int endByteIndex,
0185 boolean desc,
0186 boolean signed) {
0187 assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
0188 assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
0189 assert endByteIndex > startByteIndex;
0190 assert numRecords * 4 <= array.size();
0191 long inIndex = startIndex;
0192 long outIndex = startIndex + numRecords * 2L;
0193 if (numRecords > 0) {
0194 long[][] counts = getKeyPrefixArrayCounts(
0195 array, startIndex, numRecords, startByteIndex, endByteIndex);
0196 for (int i = startByteIndex; i <= endByteIndex; i++) {
0197 if (counts[i] != null) {
0198 sortKeyPrefixArrayAtByte(
0199 array, numRecords, counts[i], i, inIndex, outIndex,
0200 desc, signed && i == endByteIndex);
0201 long tmp = inIndex;
0202 inIndex = outIndex;
0203 outIndex = tmp;
0204 }
0205 }
0206 }
0207 return Ints.checkedCast(inIndex);
0208 }
0209
0210
0211
0212
0213
0214 private static long[][] getKeyPrefixArrayCounts(
0215 LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) {
0216 long[][] counts = new long[8][];
0217 long bitwiseMax = 0;
0218 long bitwiseMin = -1L;
0219 long baseOffset = array.getBaseOffset() + startIndex * 8L;
0220 long limit = baseOffset + numRecords * 16L;
0221 Object baseObject = array.getBaseObject();
0222 for (long offset = baseOffset; offset < limit; offset += 16) {
0223 long value = Platform.getLong(baseObject, offset + 8);
0224 bitwiseMax |= value;
0225 bitwiseMin &= value;
0226 }
0227 long bitsChanged = bitwiseMin ^ bitwiseMax;
0228 for (int i = startByteIndex; i <= endByteIndex; i++) {
0229 if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
0230 counts[i] = new long[256];
0231 for (long offset = baseOffset; offset < limit; offset += 16) {
0232 counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++;
0233 }
0234 }
0235 }
0236 return counts;
0237 }
0238
0239
0240
0241
0242 private static void sortKeyPrefixArrayAtByte(
0243 LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
0244 boolean desc, boolean signed) {
0245 assert counts.length == 256;
0246 long[] offsets = transformCountsToOffsets(
0247 counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed);
0248 Object baseObject = array.getBaseObject();
0249 long baseOffset = array.getBaseOffset() + inIndex * 8L;
0250 long maxOffset = baseOffset + numRecords * 16L;
0251 for (long offset = baseOffset; offset < maxOffset; offset += 16) {
0252 long key = Platform.getLong(baseObject, offset);
0253 long prefix = Platform.getLong(baseObject, offset + 8);
0254 int bucket = (int)((prefix >>> (byteIdx * 8)) & 0xff);
0255 long dest = offsets[bucket];
0256 Platform.putLong(baseObject, dest, key);
0257 Platform.putLong(baseObject, dest + 8, prefix);
0258 offsets[bucket] += 16;
0259 }
0260 }
0261 }