LUCENE-6968 - MinHash filter, thanks to Andy Hind and Cao Manh Dat for patches

Tommaso Teofili 2016-06-14 11:49:53 +02:00
parent bd7ddb8fbf
commit 82a9244193
6 changed files with 915 additions and 0 deletions

View File: lucene/CHANGES.txt

@@ -13,6 +13,8 @@ Other
* LUCENE-7328: Remove LegacyNumericEncoding from GeoPointField. (Nick Knize)
* LUCENE-6968: LSH Filter (Tommaso Teofili, Andy Hind, Cao Manh Dat)
======================= Lucene 6.2.0 =======================
New Features

View File: org/apache/lucene/analysis/minhash/MinHashFilter.java

@@ -0,0 +1,504 @@
package org.apache.lucene.analysis.minhash;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
import org.apache.lucene.analysis.tokenattributes.TypeAttribute;
/**
* Generate min hash tokens from an incoming stream of tokens. The incoming tokens would typically be 5 word shingles.
*
* The number of hashes used and the number of minimum values kept for each hash can both be set. You could have 1 hash
* and keep the 100 lowest values, or 100 hashes and keep the lowest value for each. Hashes can also be bucketed in
* ranges over the 128-bit hash space.
*
* A 128-bit hash is used internally. 5 word shingles drawn from a vocabulary of 10^5 words give 10^25 possible
* shingles, so a 64-bit hash (only about 1.8 * 10^19 values) would suffer collisions.
*
* When several hashes are used, 32 bits encode the hash position, leaving 96 bits per value and scope for about
* 8 * 10^28 unique hashes. A single hash uses all 128 bits.
*/
public class MinHashFilter extends TokenFilter {
private static final int HASH_CACHE_SIZE = 512;
private static final LongPair[] cachedIntHashes = new LongPair[HASH_CACHE_SIZE];
static final int DEFAULT_HASH_COUNT = 1;
static final int DEFAULT_HASH_SET_SIZE = 1;
static final int DEFAULT_BUCKET_COUNT = 512;
static final String MIN_HASH_TYPE = "MIN_HASH";
private final List<List<FixedSizeTreeSet<LongPair>>> minHashSets;
private int hashSetSize = DEFAULT_HASH_SET_SIZE;
private int bucketCount = DEFAULT_BUCKET_COUNT;
private int hashCount = DEFAULT_HASH_COUNT;
private boolean requiresInitialisation = true;
private State endState;
private int hashPosition = -1;
private int bucketPosition = -1;
private long bucketSize;
private final boolean withRotation;
private int endOffset;
private boolean exhausted = false;
private final CharTermAttribute termAttribute = addAttribute(CharTermAttribute.class);
private final OffsetAttribute offsetAttribute = addAttribute(OffsetAttribute.class);
private final TypeAttribute typeAttribute = addAttribute(TypeAttribute.class);
private final PositionIncrementAttribute posIncAttribute = addAttribute(PositionIncrementAttribute.class);
private final PositionLengthAttribute posLenAttribute = addAttribute(PositionLengthAttribute.class);
static {
for (int i = 0; i < HASH_CACHE_SIZE; i++) {
cachedIntHashes[i] = new LongPair();
murmurhash3_x64_128(getBytes(i), 0, 4, 0, cachedIntHashes[i]);
}
}
static byte[] getBytes(int i) {
byte[] answer = new byte[4];
answer[3] = (byte) (i);
answer[2] = (byte) (i >> 8);
answer[1] = (byte) (i >> 16);
answer[0] = (byte) (i >> 24);
return answer;
}
/**
* Create a MinHash filter.
*
* @param input the token stream
* @param hashCount the number of hashes to use
* @param bucketCount the number of buckets each hash range is divided into
* @param hashSetSize the number of minimum hash values to keep per bucket
* @param withRotation whether to fill empty buckets by rotating in values from neighbouring buckets
*/
MinHashFilter(TokenStream input, int hashCount, int bucketCount, int hashSetSize, boolean withRotation) {
super(input);
this.hashCount = hashCount;
this.bucketCount = bucketCount;
this.hashSetSize = hashSetSize;
this.withRotation = withRotation;
// Spread the top 32 bits of the hash evenly across the buckets, rounding the bucket size up so the buckets cover the whole range.
this.bucketSize = (1L << 32) / bucketCount;
if ((1L << 32) % bucketCount != 0) {
bucketSize++;
}
minHashSets = new ArrayList<>(this.hashCount);
for (int i = 0; i < this.hashCount; i++) {
ArrayList<FixedSizeTreeSet<LongPair>> buckets = new ArrayList<>(this.bucketCount);
minHashSets.add(buckets);
for (int j = 0; j < this.bucketCount; j++) {
FixedSizeTreeSet<LongPair> minSet = new FixedSizeTreeSet<>(this.hashSetSize);
buckets.add(minSet);
}
}
doReset();
}
@Override
public final boolean incrementToken() throws IOException {
// Pull the underlying stream of tokens
// Hash each token found
// Generate the required number of variants of this hash
// Keep the minimum hash value found so far of each variant
int positionIncrement = 0;
if (requiresInitialisation) {
requiresInitialisation = false;
boolean found = false;
// First time through so we pull and hash everything
while (input.incrementToken()) {
found = true;
String current = new String(termAttribute.buffer(), 0, termAttribute.length());
// Hash the token once; each of the hashCount variants is derived by combining with a cached hash of the variant index.
byte[] bytes = current.getBytes("UTF-16LE");
LongPair hash = new LongPair();
murmurhash3_x64_128(bytes, 0, bytes.length, 0, hash);
for (int i = 0; i < hashCount; i++) {
LongPair rehashed = combineOrdered(hash, getIntHash(i));
minHashSets.get(i).get((int) ((rehashed.val2 >>> 32) / bucketSize)).add(rehashed);
}
endOffset = offsetAttribute.endOffset();
}
exhausted = true;
input.end();
// We need the end state so an underlying shingle filter can have its state restored correctly.
endState = captureState();
if (!found) {
return false;
}
positionIncrement = 1;
// With rotation enabled (and a single minimum kept per set), fill each empty bucket with the smallest hash from the next non-empty bucket, wrapping around.
if (withRotation && (hashSetSize == 1)) {
for (int hashLoop = 0; hashLoop < hashCount; hashLoop++) {
for (int bucketLoop = 0; bucketLoop < bucketCount; bucketLoop++) {
if (minHashSets.get(hashLoop).get(bucketLoop).size() == 0) {
for (int bucketOffset = 1; bucketOffset < bucketCount; bucketOffset++) {
if (minHashSets.get(hashLoop).get((bucketLoop + bucketOffset) % bucketCount).size() > 0) {
LongPair replacementHash = minHashSets.get(hashLoop).get((bucketLoop + bucketOffset) % bucketCount)
.first();
minHashSets.get(hashLoop).get(bucketLoop).add(replacementHash);
break;
}
}
}
}
}
}
}
clearAttributes();
while (hashPosition < hashCount) {
if (hashPosition == -1) {
hashPosition++;
} else {
while (bucketPosition < bucketCount) {
if (bucketPosition == -1) {
bucketPosition++;
} else {
LongPair hash = minHashSets.get(hashPosition).get(bucketPosition).pollFirst();
if (hash != null) {
termAttribute.setEmpty();
if (hashCount > 1) {
termAttribute.append(int0(hashPosition));
termAttribute.append(int1(hashPosition));
}
long high = hash.val2;
termAttribute.append(long0(high));
termAttribute.append(long1(high));
termAttribute.append(long2(high));
termAttribute.append(long3(high));
long low = hash.val1;
termAttribute.append(long0(low));
termAttribute.append(long1(low));
if (hashCount == 1) {
termAttribute.append(long2(low));
termAttribute.append(long3(low));
}
posIncAttribute.setPositionIncrement(positionIncrement);
offsetAttribute.setOffset(0, endOffset);
typeAttribute.setType(MIN_HASH_TYPE);
posLenAttribute.setPositionLength(1);
return true;
} else {
bucketPosition++;
}
}
}
bucketPosition = -1;
hashPosition++;
}
}
return false;
}
private static LongPair getIntHash(int i) {
if (i < HASH_CACHE_SIZE) {
return cachedIntHashes[i];
} else {
LongPair answer = new LongPair();
murmurhash3_x64_128(getBytes(i), 0, 4, 0, answer);
return answer;
}
}
@Override
public void end() throws IOException {
if(!exhausted) {
input.end();
}
restoreState(endState);
}
@Override
public void reset() throws IOException {
super.reset();
doReset();
}
private void doReset() {
for (int i = 0; i < hashCount; i++) {
for (int j = 0; j < bucketCount; j++) {
minHashSets.get(i).get(j).clear();
}
}
endState = null;
hashPosition = -1;
bucketPosition = -1;
requiresInitialisation = true;
exhausted = false;
}
private static char long0(long x) {
return (char) (x >> 48);
}
private static char long1(long x) {
return (char) (x >> 32);
}
private static char long2(long x) {
return (char) (x >> 16);
}
private static char long3(long x) {
return (char) (x);
}
private static char int0(int x) {
return (char) (x >> 16);
}
private static char int1(int x) {
return (char) (x);
}
/** Unsigned 64-bit less-than; the XOR flips the signed comparison exactly when the two operands differ in sign. */
static boolean isLessThanUnsigned(long n1, long n2) {
return (n1 < n2) ^ ((n1 < 0) != (n2 < 0));
}
/** A TreeSet that keeps only the {@code capacity} smallest elements offered to it. */
static class FixedSizeTreeSet<E extends Comparable<E>> extends TreeSet<E> {
private static final long serialVersionUID = -8237117170340299630L;
private final int capacity;
FixedSizeTreeSet() {
this(20);
}
FixedSizeTreeSet(int capacity) {
super();
this.capacity = capacity;
}
@Override
public boolean add(final E toAdd) {
if (capacity <= size()) {
final E lastElm = last();
// full, and the candidate is no smaller than the current largest element: reject it
if (toAdd.compareTo(lastElm) > -1) {
return false;
} else {
pollLast();
}
}
return super.add(toAdd);
}
}
private static LongPair combineOrdered(LongPair... hashCodes) {
LongPair result = new LongPair();
for (LongPair hashCode : hashCodes) {
result.val1 = result.val1 * 37 + hashCode.val1;
result.val2 = result.val2 * 37 + hashCode.val2;
}
return result;
}
/** 128 bits of state */
static final class LongPair implements Comparable<LongPair> {
public long val1;
public long val2;
@Override
public int compareTo(LongPair other) {
if (isLessThanUnsigned(val2, other.val2)) {
return -1;
} else if (val2 == other.val2) {
if (isLessThanUnsigned(val1, other.val1)) {
return -1;
} else if (val1 == other.val1) {
return 0;
} else {
return 1;
}
} else {
return 1;
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
LongPair longPair = (LongPair) o;
return val1 == longPair.val1 && val2 == longPair.val2;
}
@Override
public int hashCode() {
int result = (int) (val1 ^ (val1 >>> 32));
result = 31 * result + (int) (val2 ^ (val2 >>> 32));
return result;
}
}
/** Gets a long from a byte buffer in little endian byte order. */
private static long getLongLittleEndian(byte[] buf, int offset) {
return ((long) buf[offset + 7] << 56) // no mask needed
| ((buf[offset + 6] & 0xffL) << 48)
| ((buf[offset + 5] & 0xffL) << 40)
| ((buf[offset + 4] & 0xffL) << 32)
| ((buf[offset + 3] & 0xffL) << 24)
| ((buf[offset + 2] & 0xffL) << 16)
| ((buf[offset + 1] & 0xffL) << 8)
| ((buf[offset] & 0xffL)); // no shift needed
}
/** Returns the MurmurHash3_x64_128 hash, placing the result in "out". */
static void murmurhash3_x64_128(byte[] key, int offset, int len, int seed, LongPair out) {
// The original algorithm does have a 32 bit unsigned seed.
// We have to mask to match the behavior of the unsigned types and prevent sign extension.
long h1 = seed & 0x00000000FFFFFFFFL;
long h2 = seed & 0x00000000FFFFFFFFL;
final long c1 = 0x87c37b91114253d5L;
final long c2 = 0x4cf5ad432745937fL;
int roundedEnd = offset + (len & 0xFFFFFFF0); // round down to 16 byte block
for (int i = offset; i < roundedEnd; i += 16) {
long k1 = getLongLittleEndian(key, i);
long k2 = getLongLittleEndian(key, i + 8);
k1 *= c1;
k1 = Long.rotateLeft(k1, 31);
k1 *= c2;
h1 ^= k1;
h1 = Long.rotateLeft(h1, 27);
h1 += h2;
h1 = h1 * 5 + 0x52dce729;
k2 *= c2;
k2 = Long.rotateLeft(k2, 33);
k2 *= c1;
h2 ^= k2;
h2 = Long.rotateLeft(h2, 31);
h2 += h1;
h2 = h2 * 5 + 0x38495ab5;
}
long k1 = 0;
long k2 = 0;
switch (len & 15) {
case 15:
k2 = (key[roundedEnd + 14] & 0xffL) << 48;
case 14:
k2 |= (key[roundedEnd + 13] & 0xffL) << 40;
case 13:
k2 |= (key[roundedEnd + 12] & 0xffL) << 32;
case 12:
k2 |= (key[roundedEnd + 11] & 0xffL) << 24;
case 11:
k2 |= (key[roundedEnd + 10] & 0xffL) << 16;
case 10:
k2 |= (key[roundedEnd + 9] & 0xffL) << 8;
case 9:
k2 |= (key[roundedEnd + 8] & 0xffL);
k2 *= c2;
k2 = Long.rotateLeft(k2, 33);
k2 *= c1;
h2 ^= k2;
case 8:
k1 = ((long) key[roundedEnd + 7]) << 56;
case 7:
k1 |= (key[roundedEnd + 6] & 0xffL) << 48;
case 6:
k1 |= (key[roundedEnd + 5] & 0xffL) << 40;
case 5:
k1 |= (key[roundedEnd + 4] & 0xffL) << 32;
case 4:
k1 |= (key[roundedEnd + 3] & 0xffL) << 24;
case 3:
k1 |= (key[roundedEnd + 2] & 0xffL) << 16;
case 2:
k1 |= (key[roundedEnd + 1] & 0xffL) << 8;
case 1:
k1 |= (key[roundedEnd] & 0xffL);
k1 *= c1;
k1 = Long.rotateLeft(k1, 31);
k1 *= c2;
h1 ^= k1;
}
// ----------
// finalization
h1 ^= len;
h2 ^= len;
h1 += h2;
h2 += h1;
h1 = fmix64(h1);
h2 = fmix64(h2);
h1 += h2;
h2 += h1;
out.val1 = h1;
out.val2 = h2;
}
private static long fmix64(long k) {
k ^= k >>> 33;
k *= 0xff51afd7ed558ccdL;
k ^= k >>> 33;
k *= 0xc4ceb9fe1a85ec53L;
k ^= k >>> 33;
return k;
}
}
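
A minimal usage sketch (editor's illustration, not part of this commit): the filter's constructor is package-private, so client code would go through MinHashFilterFactory. The class name MinHashChainSketch below is hypothetical; it wires 5-word shingles from ShingleFilter into the filter using the defaults discussed in the Javadoc above.

import java.util.HashMap;
import java.util.Map;

import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.minhash.MinHashFilterFactory;
import org.apache.lucene.analysis.shingle.ShingleFilter;

public class MinHashChainSketch {
  /** Builds source -> 5-word shingles -> MinHashFilter. */
  public static TokenStream buildChain(Tokenizer source) {
    ShingleFilter shingles = new ShingleFilter(source, 5, 5);
    shingles.setOutputUnigrams(false); // emit only full shingles, the filter's typical input

    Map<String, String> args = new HashMap<>();
    args.put("hashCount", "1");     // one hash function,
    args.put("hashSetSize", "1");   // keeping a single minimum value,
    args.put("bucketCount", "512"); // per bucket of the hash range
    return new MinHashFilterFactory(args).create(shingles);
  }
}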

View File: org/apache/lucene/analysis/minhash/MinHashFilterFactory.java

@@ -0,0 +1,57 @@
package org.apache.lucene.analysis.minhash;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.util.Map;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.util.TokenFilterFactory;
/**
* {@link TokenFilterFactory} for {@link MinHashFilter}.
*/
public class MinHashFilterFactory extends TokenFilterFactory {
private int hashCount = MinHashFilter.DEFAULT_HASH_COUNT;
private int bucketCount = MinHashFilter.DEFAULT_BUCKET_COUNT;
private int hashSetSize = MinHashFilter.DEFAULT_HASH_SET_SIZE;
private boolean withRotation;
/**
* Create a {@link MinHashFilterFactory}.
*/
public MinHashFilterFactory(Map<String,String> args) {
super(args);
hashCount = getInt(args, "hashCount", MinHashFilter.DEFAULT_HASH_COUNT);
bucketCount = getInt(args, "bucketCount", MinHashFilter.DEFAULT_BUCKET_COUNT);
hashSetSize = getInt(args, "hashSetSize", MinHashFilter.DEFAULT_HASH_SET_SIZE);
withRotation = getBoolean(args, "withRotation", bucketCount > 1);
}
@Override
public TokenStream create(TokenStream input) {
return new MinHashFilter(input, hashCount, bucketCount, hashSetSize, withRotation);
}
}
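
A hedged aside (again an editor's sketch, not commit code): once the services file below registers the factory with Lucene's analysis SPI, it should also be reachable by name. Lucene derives the lookup key by stripping the FilterFactory suffix and matches it case-insensitively, so the name is assumed here to resolve as "minhash"; the class MinHashLookupSketch is hypothetical.

import java.util.HashMap;
import java.util.Map;

import org.apache.lucene.analysis.util.TokenFilterFactory;

public class MinHashLookupSketch {
  public static TokenFilterFactory lookup() {
    Map<String, String> args = new HashMap<>(); // consumed by the factory's constructor
    args.put("bucketCount", "512");
    return TokenFilterFactory.forName("minhash", args);
  }
}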

View File: org/apache/lucene/analysis/minhash/package-info.java

@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* MinHash filtering (for LSH).
*/
package org.apache.lucene.analysis.minhash;
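
The package Javadoc only names the LSH use case, so here is a rough sketch of the downstream arithmetic (editor's illustration with the hypothetical class JaccardEstimateSketch, not code from this commit): if two documents emit MinHash token sets A and B, then |A ∩ B| / |A ∪ B| estimates the Jaccard similarity of the underlying shingle sets.

import java.util.HashSet;
import java.util.Set;

public class JaccardEstimateSketch {
  /** Estimates the Jaccard similarity of two documents from their MinHash token sets. */
  public static double estimate(Set<String> minHashesA, Set<String> minHashesB) {
    Set<String> union = new HashSet<>(minHashesA);
    union.addAll(minHashesB);
    Set<String> intersection = new HashSet<>(minHashesA);
    intersection.retainAll(minHashesB);
    return union.isEmpty() ? 0.0 : (double) intersection.size() / union.size();
  }
}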

View File: META-INF/services/org.apache.lucene.analysis.util.TokenFilterFactory

@@ -58,6 +58,7 @@ org.apache.lucene.analysis.id.IndonesianStemFilterFactory
org.apache.lucene.analysis.in.IndicNormalizationFilterFactory
org.apache.lucene.analysis.it.ItalianLightStemFilterFactory
org.apache.lucene.analysis.lv.LatvianStemFilterFactory
org.apache.lucene.analysis.minhash.MinHashFilterFactory
org.apache.lucene.analysis.miscellaneous.ASCIIFoldingFilterFactory
org.apache.lucene.analysis.miscellaneous.CapitalizationFilterFactory
org.apache.lucene.analysis.miscellaneous.CodepointCountFilterFactory

View File: org/apache/lucene/analysis/minhash/MinHashFilterTest.java

@@ -0,0 +1,330 @@
package org.apache.lucene.analysis.minhash;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.io.StringReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import org.apache.lucene.analysis.BaseTokenStreamTestCase;
import org.apache.lucene.analysis.MockTokenizer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.Tokenizer;
import org.apache.lucene.analysis.minhash.MinHashFilter.FixedSizeTreeSet;
import org.apache.lucene.analysis.minhash.MinHashFilter.LongPair;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.automaton.CharacterRunAutomaton;
import org.apache.lucene.util.automaton.RegExp;
import org.junit.Test;
/**
* Tests for {@link MinHashFilter}
*/
public class MinHashFilterTest extends BaseTokenStreamTestCase {
@Test
public void testIntHash() {
LongPair hash = new LongPair();
MinHashFilter.murmurhash3_x64_128(MinHashFilter.getBytes(0), 0, 4, 0, hash);
assertEquals(-3485513579396041028L, hash.val1);
assertEquals(6383328099726337777L, hash.val2);
}
@Test
public void testStringHash() throws UnsupportedEncodingException {
LongPair hash = new LongPair();
byte[] bytes = "woof woof woof woof woof".getBytes("UTF-16LE");
MinHashFilter.murmurhash3_x64_128(bytes, 0, bytes.length, 0, hash);
assertEquals(7638079586852243959L, hash.val1);
assertEquals(4378804943379391304L, hash.val2);
}
@Test
public void testSimpleOrder() throws UnsupportedEncodingException {
LongPair hash1 = new LongPair();
hash1.val1 = 1;
hash1.val2 = 2;
LongPair hash2 = new LongPair();
hash2.val1 = 2;
hash2.val2 = 1;
assertTrue(hash1.compareTo(hash2) > 0);
}
@Test
public void testHashOrder() {
assertFalse(MinHashFilter.isLessThanUnsigned(0L, 0L));
assertTrue(MinHashFilter.isLessThanUnsigned(0L, -1L));
assertTrue(MinHashFilter.isLessThanUnsigned(1L, -1L));
assertTrue(MinHashFilter.isLessThanUnsigned(-2L, -1L));
assertTrue(MinHashFilter.isLessThanUnsigned(1L, 2L));
assertTrue(MinHashFilter.isLessThanUnsigned(Long.MAX_VALUE, Long.MIN_VALUE));
FixedSizeTreeSet<LongPair> minSet = new FixedSizeTreeSet<LongPair>(500);
HashSet<LongPair> unadded = new HashSet<LongPair>();
for (int i = 0; i < 100; i++) {
LongPair hash = new LongPair();
MinHashFilter.murmurhash3_x64_128(MinHashFilter.getBytes(i), 0, 4, 0, hash);
LongPair peek = null;
if (minSet.size() > 0) {
peek = minSet.last();
}
if (!minSet.add(hash)) {
unadded.add(hash);
} else {
if (peek != null) {
if ((minSet.size() == 500) && !peek.equals(minSet.last())) {
unadded.add(peek);
}
}
}
}
assertEquals(100, minSet.size());
assertEquals(0, unadded.size());
HashSet<LongPair> collisionDetection = new HashSet<LongPair>();
unadded = new HashSet<LongPair>();
minSet = new FixedSizeTreeSet<LongPair>(500);
for (int i = 0; i < 1000000; i++) {
LongPair hash = new LongPair();
MinHashFilter.murmurhash3_x64_128(MinHashFilter.getBytes(i), 0, 4, 0, hash);
collisionDetection.add(hash);
LongPair peek = null;
if (minSet.size() > 0) {
peek = minSet.last();
}
if (!minSet.add(hash)) {
unadded.add(hash);
} else {
if (peek != null) {
if ((minSet.size() == 500) && !peek.equals(minSet.last())) {
unadded.add(peek);
}
}
}
}
assertEquals(1000000, collisionDetection.size());
assertEquals(500, minSet.size());
assertEquals(999500, unadded.size());
LongPair last = null;
LongPair current = null;
while ((current = minSet.pollLast()) != null) {
if (last != null) {
assertTrue(isLessThan(current, last));
}
last = current;
}
}
@Test
public void testHashNotRepeated() {
FixedSizeTreeSet<LongPair> minSet = new FixedSizeTreeSet<LongPair>(500);
HashSet<LongPair> unadded = new HashSet<LongPair>();
for (int i = 0; i < 10000; i++) {
LongPair hash = new LongPair();
MinHashFilter.murmurhash3_x64_128(MinHashFilter.getBytes(i), 0, 4, 0, hash);
LongPair peek = null;
if (minSet.size() > 0) {
peek = minSet.last();
}
if (!minSet.add(hash)) {
unadded.add(hash);
} else {
if (peek != null) {
if ((minSet.size() == 500) && !peek.equals(minSet.last())) {
unadded.add(peek);
}
}
}
}
assertEquals(500, minSet.size());
LongPair last = null;
LongPair current = null;
while ((current = minSet.pollLast()) != null) {
if (last != null) {
assertTrue(isLessThan(current, last));
}
last = current;
}
}
@Test
public void testMockShingleTokenizer() throws IOException {
Tokenizer mockShingleTokenizer = createMockShingleTokenizer(5,
"woof woof woof woof woof" + " " + "woof woof woof woof puff");
assertTokenStreamContents(mockShingleTokenizer,
new String[]{"woof woof woof woof woof", "woof woof woof woof puff"});
}
@Test
public void testTokenStreamSingleInput() throws IOException {
String[] hashes = new String[]{"℁팽徭聙↝ꇁ홱杯"};
TokenStream ts = createTokenStream(5, "woof woof woof woof woof", 1, 1, 100, false);
assertTokenStreamContents(ts, hashes, new int[]{0},
new int[]{24}, new String[]{MinHashFilter.MIN_HASH_TYPE}, new int[]{1}, new int[]{1}, 24, 0, null,
true);
ts = createTokenStream(5, "woof woof woof woof woof", 2, 1, 1, false);
assertTokenStreamContents(ts, new String[]{new String(new char[]{0, 0, 8449, 54077, 64133, 32857, 8605, 41409}),
new String(new char[]{0, 1, 16887, 58164, 39536, 14926, 6529, 17276})}, new int[]{0, 0},
new int[]{24, 24}, new String[]{MinHashFilter.MIN_HASH_TYPE, MinHashFilter.MIN_HASH_TYPE}, new int[]{1, 0},
new int[]{1, 1}, 24, 0, null,
true);
}
@Test
public void testTokenStream1() throws IOException {
String[] hashes = new String[]{"℁팽徭聙↝ꇁ홱杯",
new String(new char[]{36347, 63457, 43013, 56843, 52284, 34231, 57934, 42302})}; // hashes are emitted as raw UTF-16 code units, so the String can be degenerate as characters
TokenStream ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 1, 1, 100,
false);
assertTokenStreamContents(ts, hashes, new int[]{0, 0},
new int[]{49, 49}, new String[]{MinHashFilter.MIN_HASH_TYPE, MinHashFilter.MIN_HASH_TYPE}, new int[]{1, 0},
new int[]{1, 1}, 49, 0, null, true);
}
private ArrayList<String> getTokens(TokenStream ts) throws IOException {
ArrayList<String> tokens = new ArrayList<>();
ts.reset();
while (ts.incrementToken()) {
CharTermAttribute termAttribute = ts.getAttribute(CharTermAttribute.class);
String token = new String(termAttribute.buffer(), 0, termAttribute.length());
tokens.add(token);
}
ts.end();
ts.close();
return tokens;
}
@Test
public void testTokenStream2() throws IOException {
TokenStream ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 100, 1, 1,
false);
ArrayList<String> tokens = getTokens(ts);
ts.close();
assertEquals(100, tokens.size());
}
@Test
public void testTokenStream3() throws IOException {
TokenStream ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 10, 1, 10,
false);
ArrayList<String> tokens = getTokens(ts);
ts.close();
assertEquals(20, tokens.size());
}
@Test
public void testTokenStream4() throws IOException {
TokenStream ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 10, 10, 1,
false);
ArrayList<String> tokens = getTokens(ts);
ts.close();
assertEquals(20, tokens.size());
ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 10, 10, 1, true);
tokens = getTokens(ts);
ts.close();
assertEquals(100, tokens.size());
}
@Test
public void testTokenStream5() throws IOException {
TokenStream ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 1, 100, 1,
false);
ArrayList<String> tokens = getTokens(ts);
ts.close();
assertEquals(2, tokens.size());
ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 1, 100, 1, true);
tokens = getTokens(ts);
ts.close();
assertEquals(100, tokens.size());
HashSet<String> set = new HashSet<>(tokens);
assertEquals(2, set.size());
boolean rolled = false;
String first = null;
String last = null;
for (String current : tokens) {
if (first == null) {
first = current;
}
if (last != null) {
if (!rolled) {
if (current.compareTo(last) >= 0) {
// fine
} else if (current.equals(first)) {
rolled = true;
} else {
fail("Incorrect hash order");
}
} else {
if (!current.equals(first)) {
fail("Incorrect hash order");
}
}
}
last = current;
}
}
public static TokenStream createTokenStream(int shingleSize, String shingles, int hashCount, int bucketCount,
int hashSetSize, boolean withRotation) {
Tokenizer tokenizer = createMockShingleTokenizer(shingleSize, shingles);
HashMap<String, String> lshffargs = new HashMap<>();
lshffargs.put("hashCount", "" + hashCount);
lshffargs.put("bucketCount", "" + bucketCount);
lshffargs.put("hashSetSize", "" + hashSetSize);
lshffargs.put("withRotation", "" + withRotation);
MinHashFilterFactory lshff = new MinHashFilterFactory(lshffargs);
return lshff.create(tokenizer);
}
private static Tokenizer createMockShingleTokenizer(int shingleSize, String shingles) {
MockTokenizer tokenizer = new MockTokenizer(
new CharacterRunAutomaton(new RegExp("[^ \t\r\n]+([ \t\r\n]+[^ \t\r\n]+){" + (shingleSize - 1) + "}").toAutomaton()),
true);
tokenizer.setEnableChecks(true);
if (shingles != null) {
tokenizer.setReader(new StringReader(shingles));
}
return tokenizer;
}
private boolean isLessThan(LongPair hash1, LongPair hash2) {
return MinHashFilter.isLessThanUnsigned(hash1.val2, hash2.val2)
|| (hash1.val2 == hash2.val2 && MinHashFilter.isLessThanUnsigned(hash1.val1, hash2.val1));
}
}