[SPARK-7190] [SPARK-8804] [SPARK-7815] [SQL] unsafe UTF8String

Let UTF8String work with a binary buffer. Until we have a better idea of how to manage the lifecycle of a UTF8String in a Row, we still make a copy when calling `UnsafeRow.get()` for StringType.

cc rxin JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #7197 from davies/unsafe_string and squashes the following commits:

51b0ea0 [Davies Liu] fix test
50c1ebf [Davies Liu] remove optimization for upper/lower case
315d491 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_string
93fce17 [Davies Liu] address comment
e9ff7ba [Davies Liu] clean up
67ec266 [Davies Liu] fix bug
7b74b1f [Davies Liu] fallback to String if local dependent
ab7857c [Davies Liu] address comments
7da92f5 [Davies Liu] handle local in toUpperCase/toLowerCase
59dbb23 [Davies Liu] revert python change
d1e0716 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_string
002e35f [Davies Liu] rollback hashCode change
a87b7a8 [Davies Liu] improve toLowerCase and toUpperCase
76e794a [Davies Liu] fix test
8b2d5ce [Davies Liu] fix tests
fd3f0a6 [Davies Liu] bug fix
c4e9c88 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_string
c45d921 [Davies Liu] address comments
175405f [Davies Liu] unsafe UTF8String
This commit is contained in:
Davies Liu 2015-07-07 17:57:17 -07:00 committed by Reynold Xin
parent 770ff1025e
commit 4ca90935c5
8 changed files with 210 additions and 112 deletions

View file

@ -264,6 +264,7 @@ public final class UnsafeRow extends MutableRow {
int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE);
int size = (int) (v & Integer.MAX_VALUE);
final byte[] bytes = new byte[size];
// TODO(davies): Avoid the copy once we can manage the life cycle of Row well.
PlatformDependent.copyMemory(
baseObject,
baseOffset + offset,

View file

@ -139,7 +139,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// UDFToBoolean
private[this] def castToBoolean(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, _.length() != 0)
buildCast[UTF8String](_, _.numBytes() != 0)
case TimestampType =>
buildCast[Long](_, t => t != 0)
case DateType =>

View file

@ -250,7 +250,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
val (st, end) = slicePos(start, length, () => ba.length)
ba.slice(st, end)
case s: UTF8String =>
val (st, end) = slicePos(start, length, () => s.length())
val (st, end) = slicePos(start, length, () => s.numChars())
s.substring(st, end)
}
}
@ -265,10 +265,10 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI
override def inputTypes: Seq[DataType] = Seq(StringType)
protected override def nullSafeEval(string: Any): Any =
string.asInstanceOf[UTF8String].length
string.asInstanceOf[UTF8String].numChars
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"($c).length()")
defineCodeGen(ctx, ev, c => s"($c).numChars()")
}
override def prettyName: String = "length"

View file

@ -17,7 +17,7 @@
package org.apache.spark.unsafe.array;
import org.apache.spark.unsafe.PlatformDependent;
import static org.apache.spark.unsafe.PlatformDependent.*;
public class ByteArrayMethods {
@ -35,21 +35,27 @@ public class ByteArrayMethods {
}
/**
* Optimized byte array equality check for 8-byte-word-aligned byte arrays.
* Optimized byte array equality check for byte arrays.
* @return true if the arrays are equal, false otherwise
*/
public static boolean wordAlignedArrayEquals(
Object leftBaseObject,
long leftBaseOffset,
Object rightBaseObject,
long rightBaseOffset,
long arrayLengthInBytes) {
for (int i = 0; i < arrayLengthInBytes; i += 8) {
final long left =
PlatformDependent.UNSAFE.getLong(leftBaseObject, leftBaseOffset + i);
final long right =
PlatformDependent.UNSAFE.getLong(rightBaseObject, rightBaseOffset + i);
if (left != right) return false;
public static boolean arrayEquals(
    Object leftBase,
    long leftOffset,
    Object rightBase,
    long rightOffset,
    final long length) {
  // Compares `length` bytes starting at (leftBase, leftOffset) with the bytes starting at
  // (rightBase, rightOffset) and returns true only if all of them are equal. A `length` of
  // zero (or less) compares nothing and yields true.
  //
  // The index is a long, like `length`, so that it cannot wrap around to a negative value
  // when `length` is larger than Integer.MAX_VALUE.
  long i = 0;
  // Compare eight bytes at a time for as long as a full 8-byte word remains.
  while (i <= length - 8) {
    if (UNSAFE.getLong(leftBase, leftOffset + i) != UNSAFE.getLong(rightBase, rightOffset + i)) {
      return false;
    }
    i += 8;
  }
  // Compare the remaining tail (fewer than eight bytes) one byte at a time.
  while (i < length) {
    if (UNSAFE.getByte(leftBase, leftOffset + i) != UNSAFE.getByte(rightBase, rightOffset + i)) {
      return false;
    }
    i += 1;
  }
  return true;
}

View file

@ -277,7 +277,7 @@ public final class BytesToBytesMap {
final MemoryLocation keyAddress = loc.getKeyAddress();
final Object storedKeyBaseObject = keyAddress.getBaseObject();
final long storedKeyBaseOffset = keyAddress.getBaseOffset();
final boolean areEqual = ByteArrayMethods.wordAlignedArrayEquals(
final boolean areEqual = ByteArrayMethods.arrayEquals(
keyBaseObject,
keyBaseOffset,
storedKeyBaseObject,

View file

@ -20,9 +20,10 @@ package org.apache.spark.unsafe.types;
import javax.annotation.Nonnull;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import static org.apache.spark.unsafe.PlatformDependent.*;
/**
* A UTF-8 String for internal Spark use.
@ -35,7 +36,9 @@ import org.apache.spark.unsafe.PlatformDependent;
public final class UTF8String implements Comparable<UTF8String>, Serializable {
@Nonnull
private byte[] bytes;
private final Object base;
private final long offset;
private final int numBytes;
private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
@ -44,60 +47,82 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
5, 5, 5, 5,
6, 6, 6, 6};
/**
* Creates a UTF8String from a byte array, which should be encoded in UTF-8.
*
* Note: `bytes` will be held by the returned UTF8String.
*/
public static UTF8String fromBytes(byte[] bytes) {
return (bytes != null) ? new UTF8String().set(bytes) : null;
}
public static UTF8String fromString(String str) {
return (str != null) ? new UTF8String().set(str) : null;
if (bytes != null) {
return new UTF8String(bytes, BYTE_ARRAY_OFFSET, bytes.length);
} else {
return null;
}
}
/**
* Updates the UTF8String with String.
* Creates an UTF8String from String.
*/
protected UTF8String set(final String str) {
public static UTF8String fromString(String str) {
if (str == null) return null;
try {
bytes = str.getBytes("utf-8");
return fromBytes(str.getBytes("utf-8"));
} catch (UnsupportedEncodingException e) {
// Turn the exception into unchecked so we can find out about it at runtime, but
// don't need to add lots of boilerplate code everywhere.
PlatformDependent.throwException(e);
throwException(e);
return null;
}
return this;
}
/**
* Updates the UTF8String with byte[], which should be encoded in UTF-8.
*/
protected UTF8String set(final byte[] bytes) {
this.bytes = bytes;
return this;
protected UTF8String(Object base, long offset, int size) {
this.base = base;
this.offset = offset;
this.numBytes = size;
}
/**
* Returns the number of bytes for a code point with the first byte as `b`
* @param b The first byte of a code point
*/
public int numBytes(final byte b) {
private static int numBytesForFirstByte(final byte b) {
final int offset = (b & 0xFF) - 192;
return (offset >= 0) ? bytesOfCodePointInUTF8[offset] : 1;
}
/**
* Returns the number of bytes in this string, i.e. the value of the `numBytes` field.
*
* This is the size of the encoded data, not the number of code points; use `numChars()` for
* the latter.
*/
public int numBytes() {
return numBytes;
}
/**
* Returns the number of code points in it.
*
* This is only used by Substring() when `start` is negative.
*/
public int length() {
public int numChars() {
int len = 0;
for (int i = 0; i < bytes.length; i+= numBytes(bytes[i])) {
for (int i = 0; i < numBytes; i += numBytesForFirstByte(getByte(i))) {
len += 1;
}
return len;
}
/**
* Returns the underlying bytes; the result is a copy if they are part of another array.
*/
public byte[] getBytes() {
return bytes;
// avoid copy if `base` is `byte[]`
if (offset == BYTE_ARRAY_OFFSET && base instanceof byte[]
&& ((byte[]) base).length == numBytes) {
return (byte[]) base;
} else {
byte[] bytes = new byte[numBytes];
copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes);
return bytes;
}
}
/**
@ -106,92 +131,110 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
* @param until the position after last code point, exclusive.
*/
public UTF8String substring(final int start, final int until) {
if (until <= start || start >= bytes.length) {
return UTF8String.fromBytes(new byte[0]);
if (until <= start || start >= numBytes) {
return fromBytes(new byte[0]);
}
int i = 0;
int c = 0;
for (; i < bytes.length && c < start; i += numBytes(bytes[i])) {
while (i < numBytes && c < start) {
i += numBytesForFirstByte(getByte(i));
c += 1;
}
int j = i;
for (; j < bytes.length && c < until; j += numBytes(bytes[i])) {
while (i < numBytes && c < until) {
i += numBytesForFirstByte(getByte(i));
c += 1;
}
return UTF8String.fromBytes(Arrays.copyOfRange(bytes, i, j));
byte[] bytes = new byte[i - j];
copyMemory(base, offset + j, bytes, BYTE_ARRAY_OFFSET, i - j);
return fromBytes(bytes);
}
/**
* Returns whether this contains `substring` or not.
*/
public boolean contains(final UTF8String substring) {
final byte[] b = substring.getBytes();
if (b.length == 0) {
if (substring.numBytes == 0) {
return true;
}
for (int i = 0; i <= bytes.length - b.length; i++) {
if (bytes[i] == b[0] && startsWith(b, i)) {
byte first = substring.getByte(0);
for (int i = 0; i <= numBytes - substring.numBytes; i++) {
if (getByte(i) == first && matchAt(substring, i)) {
return true;
}
}
return false;
}
private boolean startsWith(final byte[] prefix, int offsetInBytes) {
if (prefix.length + offsetInBytes > bytes.length || offsetInBytes < 0) {
/**
* Returns the byte at position `i`.
*
* `i` is a byte index relative to `offset` within `base`; the value is read directly through
* UNSAFE and this method performs no bounds check of its own.
*/
private byte getByte(int i) {
return UNSAFE.getByte(base, offset + i);
}
private boolean matchAt(final UTF8String s, int pos) {
if (s.numBytes + pos > numBytes || pos < 0) {
return false;
}
int i = 0;
while (i < prefix.length && prefix[i] == bytes[i + offsetInBytes]) {
i++;
}
return i == prefix.length;
return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes);
}
public boolean startsWith(final UTF8String prefix) {
return startsWith(prefix.getBytes(), 0);
return matchAt(prefix, 0);
}
public boolean endsWith(final UTF8String suffix) {
return startsWith(suffix.getBytes(), bytes.length - suffix.getBytes().length);
return matchAt(suffix, numBytes - suffix.numBytes);
}
/**
* Returns the upper case of this string
*/
public UTF8String toUpperCase() {
return UTF8String.fromString(toString().toUpperCase());
return fromString(toString().toUpperCase());
}
/**
* Returns the lower case of this string
*/
public UTF8String toLowerCase() {
return UTF8String.fromString(toString().toLowerCase());
return fromString(toString().toLowerCase());
}
@Override
public String toString() {
try {
return new String(bytes, "utf-8");
return new String(getBytes(), "utf-8");
} catch (UnsupportedEncodingException e) {
// Turn the exception into unchecked so we can find out about it at runtime, but
// don't need to add lots of boilerplate code everywhere.
PlatformDependent.throwException(e);
throwException(e);
return "unknown"; // we will never reach here.
}
}
@Override
public UTF8String clone() {
return new UTF8String().set(bytes);
return fromBytes(getBytes());
}
@Override
public int compareTo(final UTF8String other) {
final byte[] b = other.getBytes();
for (int i = 0; i < bytes.length && i < b.length; i++) {
int res = bytes[i] - b[i];
int len = Math.min(numBytes, other.numBytes);
// TODO: compare 8 bytes as unsigned long
for (int i = 0; i < len; i ++) {
// In UTF-8, the byte should be unsigned, so we should compare them as unsigned int.
int res = (getByte(i) & 0xFF) - (other.getByte(i) & 0xFF);
if (res != 0) {
return res;
}
}
return bytes.length - b.length;
return numBytes - other.numBytes;
}
public int compare(final UTF8String other) {
@ -201,7 +244,11 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
@Override
public boolean equals(final Object other) {
if (other instanceof UTF8String) {
return Arrays.equals(bytes, ((UTF8String) other).getBytes());
UTF8String o = (UTF8String) other;
if (numBytes != o.numBytes){
return false;
}
return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes);
} else {
return false;
}
@ -209,6 +256,10 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
@Override
public int hashCode() {
return Arrays.hashCode(bytes);
int result = 1;
for (int i = 0; i < numBytes; i ++) {
result = 31 * result + getByte(i);
}
return result;
}
}

View file

@ -99,7 +99,7 @@ public abstract class AbstractBytesToBytesMapSuite {
byte[] expected,
MemoryLocation actualAddr,
long actualLengthBytes) {
return (actualLengthBytes == expected.length) && ByteArrayMethods.wordAlignedArrayEquals(
return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals(
expected,
BYTE_ARRAY_OFFSET,
actualAddr.getBaseObject(),

View file

@ -19,73 +19,113 @@ package org.apache.spark.unsafe.types;
import java.io.UnsupportedEncodingException;
import junit.framework.Assert;
import org.junit.Test;
import static junit.framework.Assert.*;
import static org.apache.spark.unsafe.types.UTF8String.*;
public class UTF8StringSuite {
private void checkBasic(String str, int len) throws UnsupportedEncodingException {
Assert.assertEquals(UTF8String.fromString(str).length(), len);
Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).length(), len);
UTF8String s1 = fromString(str);
UTF8String s2 = fromBytes(str.getBytes("utf8"));
assertEquals(s1.numChars(), len);
assertEquals(s2.numChars(), len);
Assert.assertEquals(UTF8String.fromString(str).toString(), str);
Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).toString(), str);
Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), UTF8String.fromString(str));
assertEquals(s1.toString(), str);
assertEquals(s2.toString(), str);
assertEquals(s1, s2);
Assert.assertEquals(UTF8String.fromString(str).hashCode(),
UTF8String.fromBytes(str.getBytes("utf8")).hashCode());
assertEquals(s1.hashCode(), s2.hashCode());
assertEquals(s1.compareTo(s2), 0);
assertEquals(s1.contains(s2), true);
assertEquals(s2.contains(s1), true);
assertEquals(s1.startsWith(s1), true);
assertEquals(s1.endsWith(s1), true);
}
@Test
public void basicTest() throws UnsupportedEncodingException {
  // Sample strings paired (by index) with the number of code points each one contains.
  final String[] samples = {"", "hello", "世 界", "大 千 世 界"};
  final int[] expectedNumChars = {0, 5, 3, 7};
  for (int k = 0; k < samples.length; k++) {
    checkBasic(samples[k], expectedNumChars[k]);
  }
}
@Test
public void compareTo() {
  // Each row is {left, right}; expectedSigns[k] is the required sign of
  // left.compareTo(right) for row k (1: greater, 0: equal, -1: smaller).
  final String[][] operands = {
    {"abc", "ABC"},
    {"abc0", "abc"},
    {"abcabcabc", "abcabcabc"},
    {"aBcabcabc", "Abcabcabc"},
    {"Abcabcabc", "abcabcabC"},
    {"abcabcabc", "abcabcabC"},
    {"abc", "世界"},
    {"你好", "世界"},
    {"你好123", "你好122"},
  };
  final int[] expectedSigns = {1, 1, 0, 1, -1, 1, -1, 1, 1};
  for (int k = 0; k < operands.length; k++) {
    final int outcome = fromString(operands[k][0]).compareTo(fromString(operands[k][1]));
    assertTrue(Integer.signum(outcome) == expectedSigns[k]);
  }
}
protected void testUpperandLower(String upper, String lower) {
  // Encode both spellings once; every check below works on these two values.
  final UTF8String upperCased = fromString(upper);
  final UTF8String lowerCased = fromString(lower);

  // Converting in either direction must yield the other spelling.
  final UTF8String loweredFromUpper = upperCased.toLowerCase();
  assertEquals(lowerCased, loweredFromUpper);
  final UTF8String raisedFromLower = lowerCased.toUpperCase();
  assertEquals(upperCased, raisedFromLower);

  // Converting a string to the case it already has must leave it unchanged.
  final UTF8String raisedAgain = upperCased.toUpperCase();
  assertEquals(upperCased, raisedAgain);
  final UTF8String loweredAgain = lowerCased.toLowerCase();
  assertEquals(lowerCased, loweredAgain);
}
@Test
public void upperAndLower() {
  // Each row is {upper-case spelling, lower-case spelling}; strings without cased letters
  // (empty, digits, CJK) are expected to map to themselves.
  final String[][] upperLowerPairs = {
    {"", ""},
    {"0123456", "0123456"},
    {"ABCXYZ", "abcxyz"},
    {"ЀЁЂѺΏỀ", "ѐёђѻώề"},
    {"大千世界 数据砖头", "大千世界 数据砖头"},
  };
  for (String[] pair : upperLowerPairs) {
    testUpperandLower(pair[0], pair[1]);
  }
}
@Test
// Checks UTF8String.contains() on ASCII and multi-byte inputs: substrings that are present,
// substrings that are absent, and candidates longer than the string being searched.
public void contains() {
// Assertions written with the qualified Assert / UTF8String names.
Assert.assertTrue(UTF8String.fromString("hello").contains(UTF8String.fromString("ello")));
Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("vello")));
Assert.assertFalse(UTF8String.fromString("hello").contains(UTF8String.fromString("hellooo")));
Assert.assertTrue(UTF8String.fromString("大千世界").contains(UTF8String.fromString("千世")));
Assert.assertFalse(UTF8String.fromString("大千世界").contains(UTF8String.fromString("世千")));
Assert.assertFalse(
UTF8String.fromString("大千世界").contains(UTF8String.fromString("大千世界好")));
// The same kind of checks written with the statically imported helpers, plus the
// empty-string case: an empty string contains the empty string.
assertTrue(fromString("").contains(fromString("")));
assertTrue(fromString("hello").contains(fromString("ello")));
assertFalse(fromString("hello").contains(fromString("vello")));
assertFalse(fromString("hello").contains(fromString("hellooo")));
assertTrue(fromString("大千世界").contains(fromString("千世界")));
assertFalse(fromString("大千世界").contains(fromString("世千")));
assertFalse(fromString("大千世界").contains(fromString("大千世界好")));
}
@Test
public void startsWith() {
  // Verifies UTF8String.startsWith() for ASCII and multi-byte prefixes, including prefixes
  // that occur only in the middle of the string and prefixes longer than the string.

  // Assertions written with the qualified Assert / UTF8String names.
  Assert.assertTrue(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hell")));
  Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("ell")));
  Assert.assertFalse(UTF8String.fromString("hello").startsWith(UTF8String.fromString("hellooo")));
  Assert.assertTrue(UTF8String.fromString("数据砖头").startsWith(UTF8String.fromString("数据")));
  // "千" occurs inside "大千世界" but not at its start. The prefix must be non-empty here:
  // an empty prefix matches every string, so asserting false for it could never pass.
  Assert.assertFalse(UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("千")));
  Assert.assertFalse(
    UTF8String.fromString("大千世界").startsWith(UTF8String.fromString("大千世界好")));

  // The same checks written with the statically imported helpers.
  assertTrue(fromString("").startsWith(fromString("")));
  assertTrue(fromString("hello").startsWith(fromString("hell")));
  assertFalse(fromString("hello").startsWith(fromString("ell")));
  assertFalse(fromString("hello").startsWith(fromString("hellooo")));
  assertTrue(fromString("数据砖头").startsWith(fromString("数据")));
  // Non-empty, non-matching prefix (see the note above about empty prefixes).
  assertFalse(fromString("大千世界").startsWith(fromString("千")));
  assertFalse(fromString("大千世界").startsWith(fromString("大千世界好")));
}
@Test
public void endsWith() {
  // Verifies UTF8String.endsWith() for ASCII and multi-byte suffixes, including suffixes
  // that occur only in the middle of the string and suffixes longer than the string.

  // Assertions written with the qualified Assert / UTF8String names.
  Assert.assertTrue(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ello")));
  Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("ellov")));
  Assert.assertFalse(UTF8String.fromString("hello").endsWith(UTF8String.fromString("hhhello")));
  Assert.assertTrue(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世界")));
  // "世" occurs inside "大千世界" but not at its end. The suffix must be non-empty here:
  // an empty suffix matches every string, so asserting false for it could never pass.
  Assert.assertFalse(UTF8String.fromString("大千世界").endsWith(UTF8String.fromString("世")));
  Assert.assertFalse(
    UTF8String.fromString("数据砖头").endsWith(UTF8String.fromString("我的数据砖头")));

  // The same checks written with the statically imported helpers.
  assertTrue(fromString("").endsWith(fromString("")));
  assertTrue(fromString("hello").endsWith(fromString("ello")));
  assertFalse(fromString("hello").endsWith(fromString("ellov")));
  assertFalse(fromString("hello").endsWith(fromString("hhhello")));
  // An empty suffix matches a non-empty string as well.
  assertTrue(fromString("大千世界").endsWith(fromString("")));
  // Non-empty, non-matching suffix (see the note above about empty suffixes).
  assertFalse(fromString("大千世界").endsWith(fromString("世")));
  assertFalse(fromString("数据砖头").endsWith(fromString("我的数据砖头")));
}
@Test
public void substring() {
  // Verifies UTF8String.substring(start, until), whose positions count code points (not
  // bytes): `start` is inclusive and `until` is exclusive.

  // Assertions written with the qualified Assert / UTF8String names.
  Assert.assertEquals(
    UTF8String.fromString("hello").substring(0, 0), UTF8String.fromString(""));
  Assert.assertEquals(
    UTF8String.fromString("hello").substring(1, 3), UTF8String.fromString("el"));
  // The first code point of "数据砖头" is "数", so [0, 1) is "数" rather than empty.
  Assert.assertEquals(
    UTF8String.fromString("数据砖头").substring(0, 1), UTF8String.fromString("数"));
  Assert.assertEquals(
    UTF8String.fromString("数据砖头").substring(1, 3), UTF8String.fromString("据砖"));
  // Only one code point ("头") exists at or after position 3, so [3, 5) is clamped to "头".
  Assert.assertEquals(
    UTF8String.fromString("数据砖头").substring(3, 5), UTF8String.fromString("头"));

  // The same checks written with the statically imported helpers.
  assertEquals(fromString("hello").substring(0, 0), fromString(""));
  assertEquals(fromString("hello").substring(1, 3), fromString("el"));
  assertEquals(fromString("数据砖头").substring(0, 1), fromString("数"));
  assertEquals(fromString("数据砖头").substring(1, 3), fromString("据砖"));
  assertEquals(fromString("数据砖头").substring(3, 5), fromString("头"));
  // Code points of different encoded widths in the same string.
  assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷"));
}
}