Skip to content

Commit 91ce4a5

Browse files
committed
Improve byte array sort perf by unify getPrefix function of UTF8String and ByteArray
1 parent b9a8165 commit 91ce4a5

File tree

2 files changed

+30
-44
lines changed

2 files changed

+30
-44
lines changed

common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.unsafe.types;
1919

20+
import java.nio.ByteOrder;
2021
import java.util.Arrays;
2122

2223
import com.google.common.primitives.Ints;
@@ -26,6 +27,8 @@
2627
public final class ByteArray {
2728

2829
public static final byte[] EMPTY_BYTE = new byte[0];
30+
private static final boolean IS_LITTLE_ENDIAN =
31+
ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN;
2932

3033
/**
3134
* Writes the content of a byte array into a memory address, identified by an object and an
@@ -42,15 +45,34 @@ public static void writeToMemory(byte[] src, Object target, long targetOffset) {
4245
public static long getPrefix(byte[] bytes) {
4346
if (bytes == null) {
4447
return 0L;
48+
}
49+
return getPrefix(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length);
50+
}
51+
52+
protected static long getPrefix(Object base, long offset, int numBytes) {
53+
// Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the bytes.
54+
// If size is 0, just return 0.
55+
// If size is between 1 and 4 (inclusive), assume data is 4-byte aligned under the hood and
56+
// use a getInt to fetch the prefix.
57+
// If size is greater than 4, assume we have at least 8 bytes of data to fetch.
58+
// After getting the data, we use a mask to mask out data that is not part of the bytes.
59+
final long p;
60+
final long mask;
61+
if (numBytes >= 8) {
62+
p = Platform.getLong(base, offset);
63+
mask = 0;
64+
} else if (numBytes > 4) {
65+
p = Platform.getLong(base, offset);
66+
mask = (1L << (8 - numBytes) * 8) - 1;
67+
} else if (numBytes > 0) {
68+
long pRaw = Platform.getInt(base, offset);
69+
p = IS_LITTLE_ENDIAN ? pRaw : (pRaw << 32);
70+
mask = (1L << (8 - numBytes) * 8) - 1;
4571
} else {
46-
final int minLen = Math.min(bytes.length, 8);
47-
long p = 0;
48-
for (int i = 0; i < minLen; ++i) {
49-
p |= ((long) Platform.getByte(bytes, Platform.BYTE_ARRAY_OFFSET + i) & 0xff)
50-
<< (56 - 8 * i);
51-
}
52-
return p;
72+
p = 0;
73+
mask = 0;
5374
}
75+
return (IS_LITTLE_ENDIAN ? java.lang.Long.reverseBytes(p) : p) & ~mask;
5476
}
5577

5678
public static byte[] subStringSQL(byte[] bytes, int pos, int len) {

common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -246,43 +246,7 @@ public int numChars() {
246246
* Returns a 64-bit integer that can be used as the prefix used in sorting.
247247
*/
248248
public long getPrefix() {
249-
// Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the string.
250-
// If size is 0, just return 0.
251-
// If size is between 0 and 4 (inclusive), assume data is 4-byte aligned under the hood and
252-
// use a getInt to fetch the prefix.
253-
// If size is greater than 4, assume we have at least 8 bytes of data to fetch.
254-
// After getting the data, we use a mask to mask out data that is not part of the string.
255-
long p;
256-
long mask = 0;
257-
if (IS_LITTLE_ENDIAN) {
258-
if (numBytes >= 8) {
259-
p = Platform.getLong(base, offset);
260-
} else if (numBytes > 4) {
261-
p = Platform.getLong(base, offset);
262-
mask = (1L << (8 - numBytes) * 8) - 1;
263-
} else if (numBytes > 0) {
264-
p = (long) Platform.getInt(base, offset);
265-
mask = (1L << (8 - numBytes) * 8) - 1;
266-
} else {
267-
p = 0;
268-
}
269-
p = java.lang.Long.reverseBytes(p);
270-
} else {
271-
// byteOrder == ByteOrder.BIG_ENDIAN
272-
if (numBytes >= 8) {
273-
p = Platform.getLong(base, offset);
274-
} else if (numBytes > 4) {
275-
p = Platform.getLong(base, offset);
276-
mask = (1L << (8 - numBytes) * 8) - 1;
277-
} else if (numBytes > 0) {
278-
p = ((long) Platform.getInt(base, offset)) << 32;
279-
mask = (1L << (8 - numBytes) * 8) - 1;
280-
} else {
281-
p = 0;
282-
}
283-
}
284-
p &= ~mask;
285-
return p;
249+
return ByteArray.getPrefix(base, offset, numBytes);
286250
}
287251

288252
/**

0 commit comments

Comments
 (0)