diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 3c524401152..5d97fdda6e5 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -13,6 +13,8 @@ Other
 
 * LUCENE-7328: Remove LegacyNumericEncoding from GeoPointField. (Nick Knize)
 
+* LUCENE-6968: LSH Filter (Tommaso Teofili, Andy Hind, Cao Manh Dat)
+
 ======================= Lucene 6.2.0 =======================
 
 New Features
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/minhash/MinHashFilter.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/minhash/MinHashFilter.java
new file mode 100644
index 00000000000..60df1c051c4
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/minhash/MinHashFilter.java
@@ -0,0 +1,504 @@
+package org.apache.lucene.analysis.minhash;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.TreeSet;
+
+import org.apache.lucene.analysis.TokenFilter;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.analysis.tokenattributes.OffsetAttribute;
+import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
+import org.apache.lucene.analysis.tokenattributes.PositionLengthAttribute;
+import org.apache.lucene.analysis.tokenattributes.TypeAttribute;
+
+/**
+ * Generate min hash tokens from an incoming stream of tokens. The incoming tokens would typically be 5-word
+ * shingles.
+ * <p>
+ * The number of hashes used and the number of minimum values kept per hash can both be configured. You could use a
+ * single hash and keep the 100 lowest values, or 100 hashes and keep only the lowest value for each. Hashes can also
+ * be bucketed into ranges over the 128-bit hash space.
+ * <p>
+ * A 128-bit hash is used internally: 5-word shingles over a vocabulary of 10^5 words give on the order of 10^25
+ * combinations, so a 64-bit hash (about 1.8 x 10^19 values) would suffer collisions.
+ * <p>
+ * When several hashes are used, 32 bits encode the hash position, leaving 96 bits (roughly 8 x 10^28 values) for the
+ * hash itself; a single hash uses the full 128 bits.
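+ * <p>
+ * For orientation, a sketch of the emitted token text, as assembled by the append calls in
+ * {@code incrementToken()} below (the characters carry raw hash bits and are not readable text):
+ * <pre>
+ *   hashCount &gt; 1:   [2 chars: hash position][6 chars: top 96 bits of the 128-bit hash]
+ *   hashCount == 1:  [8 chars: all 128 bits of the hash]
+ * </pre>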
+ */
+public class MinHashFilter extends TokenFilter {
+  private static final int HASH_CACHE_SIZE = 512;
+
+  private static final LongPair[] cachedIntHashes = new LongPair[HASH_CACHE_SIZE];
+
+  static final int DEFAULT_HASH_COUNT = 1;
+
+  static final int DEFAULT_HASH_SET_SIZE = 1;
+
+  static final int DEFAULT_BUCKET_COUNT = 512;
+
+  static final String MIN_HASH_TYPE = "MIN_HASH";
+
+  private final List<List<FixedSizeTreeSet<LongPair>>> minHashSets;
+
+  private int hashSetSize = DEFAULT_HASH_SET_SIZE;
+
+  private int bucketCount = DEFAULT_BUCKET_COUNT;
+
+  private int hashCount = DEFAULT_HASH_COUNT;
+
+  private boolean requiresInitialisation = true;
+
+  private State endState;
+
+  private int hashPosition = -1;
+
+  private int bucketPosition = -1;
+
+  private long bucketSize;
+
+  private final boolean withRotation;
+
+  private int endOffset;
+
+  private boolean exhausted = false;
+
+  private final CharTermAttribute termAttribute = addAttribute(CharTermAttribute.class);
+  private final OffsetAttribute offsetAttribute = addAttribute(OffsetAttribute.class);
+  private final TypeAttribute typeAttribute = addAttribute(TypeAttribute.class);
+  private final PositionIncrementAttribute posIncAttribute = addAttribute(PositionIncrementAttribute.class);
+  private final PositionLengthAttribute posLenAttribute = addAttribute(PositionLengthAttribute.class);
+
+  static {
+    for (int i = 0; i < HASH_CACHE_SIZE; i++) {
+      cachedIntHashes[i] = new LongPair();
+      murmurhash3_x64_128(getBytes(i), 0, 4, 0, cachedIntHashes[i]);
+    }
+  }
+
+  static byte[] getBytes(int i) {
+    byte[] answer = new byte[4];
+    answer[3] = (byte) (i);
+    answer[2] = (byte) (i >> 8);
+    answer[1] = (byte) (i >> 16);
+    answer[0] = (byte) (i >> 24);
+    return answer;
+  }
+
+  /**
+   * Create a MinHash filter.
+   *
+   * @param input the token stream
+   * @param hashCount the number of hashes
+   * @param bucketCount the number of buckets into which the hash space is split
+   * @param hashSetSize the number of min hashes to keep per bucket
+   * @param withRotation whether to rotate hashes from neighbouring buckets into empty buckets
+   */
+  MinHashFilter(TokenStream input, int hashCount, int bucketCount, int hashSetSize, boolean withRotation) {
+    super(input);
+    this.hashCount = hashCount;
+    this.bucketCount = bucketCount;
+    this.hashSetSize = hashSetSize;
+    this.withRotation = withRotation;
+    this.bucketSize = (1L << 32) / bucketCount;
+    if ((1L << 32) % bucketCount != 0) {
+      bucketSize++;
+    }
+    minHashSets = new ArrayList<>(this.hashCount);
+    for (int i = 0; i < this.hashCount; i++) {
+      ArrayList<FixedSizeTreeSet<LongPair>> buckets = new ArrayList<>(this.bucketCount);
+      minHashSets.add(buckets);
+      for (int j = 0; j < this.bucketCount; j++) {
+        FixedSizeTreeSet<LongPair> minSet = new FixedSizeTreeSet<>(this.hashSetSize);
+        buckets.add(minSet);
+      }
+    }
+    doReset();
+  }
+
+  @Override
+  public final boolean incrementToken() throws IOException {
+    // Pull the underlying stream of tokens
+    // Hash each token found
+    // Generate the required number of variants of this hash
+    // Keep the minimum hash value found so far of each variant
+
+    int positionIncrement = 0;
+    if (requiresInitialisation) {
+      requiresInitialisation = false;
+      boolean found = false;
+      // First time through, so pull and hash everything
+      while (input.incrementToken()) {
+        found = true;
+        String current = new String(termAttribute.buffer(), 0, termAttribute.length());
+        // Hash the shingle bytes once; each hash position then derives its own variant below.
+        byte[] bytes = current.getBytes("UTF-16LE");
+        LongPair hash = new LongPair();
+        murmurhash3_x64_128(bytes, 0, bytes.length, 0, hash);
+        for (int i = 0; i < hashCount; i++) {
+          LongPair rehashed = combineOrdered(hash, getIntHash(i));
+          minHashSets.get(i).get((int) ((rehashed.val2 >>> 32) / bucketSize)).add(rehashed);
+        }
+        endOffset = offsetAttribute.endOffset();
+      }
+      exhausted = true;
+      input.end();
+      // We need the end state so an underlying shingle filter can have its state restored correctly.
+      endState = captureState();
+      if (!found) {
+        return false;
+      }
+
+      positionIncrement = 1;
+      // Fix up any wrap-around bucket values: an empty bucket borrows the minimum of the next occupied bucket.
+      if (withRotation && (hashSetSize == 1)) {
+        for (int hashLoop = 0; hashLoop < hashCount; hashLoop++) {
+          for (int bucketLoop = 0; bucketLoop < bucketCount; bucketLoop++) {
+            if (minHashSets.get(hashLoop).get(bucketLoop).size() == 0) {
+              for (int bucketOffset = 1; bucketOffset < bucketCount; bucketOffset++) {
+                if (minHashSets.get(hashLoop).get((bucketLoop + bucketOffset) % bucketCount).size() > 0) {
+                  LongPair replacementHash = minHashSets.get(hashLoop).get((bucketLoop + bucketOffset) % bucketCount)
+                      .first();
+                  minHashSets.get(hashLoop).get(bucketLoop).add(replacementHash);
+                  break;
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+
+    clearAttributes();
+
+    while (hashPosition < hashCount) {
+      if (hashPosition == -1) {
+        hashPosition++;
+      } else {
+        while (bucketPosition < bucketCount) {
+          if (bucketPosition == -1) {
+            bucketPosition++;
+          } else {
+            LongPair hash = minHashSets.get(hashPosition).get(bucketPosition).pollFirst();
+            if (hash != null) {
+              termAttribute.setEmpty();
+              if (hashCount > 1) {
+                termAttribute.append(int0(hashPosition));
+                termAttribute.append(int1(hashPosition));
+              }
+              long high = hash.val2;
+              termAttribute.append(long0(high));
+              termAttribute.append(long1(high));
+              termAttribute.append(long2(high));
+              termAttribute.append(long3(high));
+              long low = hash.val1;
+              termAttribute.append(long0(low));
+              termAttribute.append(long1(low));
+              if (hashCount == 1) {
+                termAttribute.append(long2(low));
+                termAttribute.append(long3(low));
+              }
+              posIncAttribute.setPositionIncrement(positionIncrement);
+              offsetAttribute.setOffset(0, endOffset);
+              typeAttribute.setType(MIN_HASH_TYPE);
+              posLenAttribute.setPositionLength(1);
+              return true;
+            } else {
+              bucketPosition++;
+            }
+          }
+        }
+        bucketPosition = -1;
+        hashPosition++;
+      }
+    }
+    return false;
+  }
+
+  private static LongPair getIntHash(int i) {
+    if (i < HASH_CACHE_SIZE) {
+      return cachedIntHashes[i];
+    } else {
+      LongPair answer = new LongPair();
+      murmurhash3_x64_128(getBytes(i), 0, 4, 0, answer);
+      return answer;
+    }
+  }
+
+  @Override
+  public void end() throws IOException {
+    if (!exhausted) {
+      input.end();
+    }
+    restoreState(endState);
+  }
+
+  @Override
+  public void reset() throws IOException {
+    super.reset();
+    doReset();
+  }
+
+  private void doReset() {
+    for (int i = 0; i < hashCount; i++) {
+      for (int j = 0; j < bucketCount; j++) {
+        minHashSets.get(i).get(j).clear();
+      }
+    }
+    endState = null;
+    hashPosition = -1;
+    bucketPosition = -1;
+    requiresInitialisation = true;
+    exhausted = false;
+  }
+
+  private static char long0(long x) {
+    return (char) (x >> 48);
+  }
+
+  private static char long1(long x) {
+    return (char) (x >> 32);
+  }
+
+  private static char long2(long x) {
+    return (char) (x >> 16);
+  }
+
+  private static char long3(long x) {
+    return (char) (x);
+  }
+
+  private static char int0(int x) {
+    return (char) (x >> 16);
+  }
+
+  private static char int1(int x) {
+    return (char) (x);
+  }
+
+  static boolean isLessThanUnsigned(long n1, long n2) {
+    return (n1 < n2) ^ ((n1 < 0) != (n2 < 0));
+  }
+
+  /** A {@link TreeSet} that keeps only the {@code capacity} smallest elements added to it. */
+  static class FixedSizeTreeSet<E extends Comparable<E>> extends TreeSet<E> {
+
+    private static final long serialVersionUID = -8237117170340299630L;
+    private final int capacity;
+
+    FixedSizeTreeSet() {
+      this(20);
+    }
+
+    FixedSizeTreeSet(int capacity) {
+      super();
+      this.capacity = capacity;
+    }
+
+    @Override
+    public boolean add(final E toAdd) {
+      if (capacity <= size()) {
+        final E lastElm = last();
+        if (toAdd.compareTo(lastElm) >= 0) {
+          // reject anything not smaller than the current largest element
+          return false;
+        } else {
+          pollLast();
+        }
+      }
+      return super.add(toAdd);
+    }
+  }
+
+  private static LongPair combineOrdered(LongPair... hashCodes) {
+    LongPair result = new LongPair();
+    for (LongPair hashCode : hashCodes) {
+      result.val1 = result.val1 * 37 + hashCode.val1;
+      result.val2 = result.val2 * 37 + hashCode.val2;
+    }
+    return result;
+  }
+
+  /** 128 bits of state */
+  static final class LongPair implements Comparable<LongPair> {
+    public long val1;
+    public long val2;
+
+    /*
+     * (non-Javadoc)
+     *
+     * @see java.lang.Comparable#compareTo(java.lang.Object)
+     */
+    @Override
+    public int compareTo(LongPair other) {
+      if (isLessThanUnsigned(val2, other.val2)) {
+        return -1;
+      } else if (val2 == other.val2) {
+        if (isLessThanUnsigned(val1, other.val1)) {
+          return -1;
+        } else if (val1 == other.val1) {
+          return 0;
+        } else {
+          return 1;
+        }
+      } else {
+        return 1;
+      }
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+
+      LongPair longPair = (LongPair) o;
+
+      return val1 == longPair.val1 && val2 == longPair.val2;
+    }
+
+    @Override
+    public int hashCode() {
+      int result = (int) (val1 ^ (val1 >>> 32));
+      result = 31 * result + (int) (val2 ^ (val2 >>> 32));
+      return result;
+    }
+  }
+
+  /** Gets a long from a byte buffer in little endian byte order. */
+  private static long getLongLittleEndian(byte[] buf, int offset) {
+    return ((long) buf[offset + 7] << 56) // no mask needed
+        | ((buf[offset + 6] & 0xffL) << 48)
+        | ((buf[offset + 5] & 0xffL) << 40)
+        | ((buf[offset + 4] & 0xffL) << 32)
+        | ((buf[offset + 3] & 0xffL) << 24)
+        | ((buf[offset + 2] & 0xffL) << 16)
+        | ((buf[offset + 1] & 0xffL) << 8)
+        | ((buf[offset] & 0xffL)); // no shift needed
+  }
+
+  /** Returns the MurmurHash3_x64_128 hash, placing the result in "out". */
+  static void murmurhash3_x64_128(byte[] key, int offset, int len, int seed, LongPair out) {
+    // The original algorithm does have a 32 bit unsigned seed.
+    // We have to mask to match the behavior of the unsigned types and prevent sign extension.
+    long h1 = seed & 0x00000000FFFFFFFFL;
+    long h2 = seed & 0x00000000FFFFFFFFL;
+
+    final long c1 = 0x87c37b91114253d5L;
+    final long c2 = 0x4cf5ad432745937fL;
+
+    int roundedEnd = offset + (len & 0xFFFFFFF0); // round down to 16 byte block
+    for (int i = offset; i < roundedEnd; i += 16) {
+      long k1 = getLongLittleEndian(key, i);
+      long k2 = getLongLittleEndian(key, i + 8);
+      k1 *= c1;
+      k1 = Long.rotateLeft(k1, 31);
+      k1 *= c2;
+      h1 ^= k1;
+      h1 = Long.rotateLeft(h1, 27);
+      h1 += h2;
+      h1 = h1 * 5 + 0x52dce729;
+      k2 *= c2;
+      k2 = Long.rotateLeft(k2, 33);
+      k2 *= c1;
+      h2 ^= k2;
+      h2 = Long.rotateLeft(h2, 31);
+      h2 += h1;
+      h2 = h2 * 5 + 0x38495ab5;
+    }
+
+    long k1 = 0;
+    long k2 = 0;
+
+    // tail: the switch deliberately falls through to accumulate the remaining 1-15 bytes
+    switch (len & 15) {
+      case 15:
+        k2 = (key[roundedEnd + 14] & 0xffL) << 48;
+      case 14:
+        k2 |= (key[roundedEnd + 13] & 0xffL) << 40;
+      case 13:
+        k2 |= (key[roundedEnd + 12] & 0xffL) << 32;
+      case 12:
+        k2 |= (key[roundedEnd + 11] & 0xffL) << 24;
+      case 11:
+        k2 |= (key[roundedEnd + 10] & 0xffL) << 16;
+      case 10:
+        k2 |= (key[roundedEnd + 9] & 0xffL) << 8;
+      case 9:
+        k2 |= (key[roundedEnd + 8] & 0xffL);
+        k2 *= c2;
+        k2 = Long.rotateLeft(k2, 33);
+        k2 *= c1;
+        h2 ^= k2;
+      case 8:
+        k1 = ((long) key[roundedEnd + 7]) << 56;
+      case 7:
+        k1 |= (key[roundedEnd + 6] & 0xffL) << 48;
+      case 6:
+        k1 |= (key[roundedEnd + 5] & 0xffL) << 40;
+      case 5:
+        k1 |= (key[roundedEnd + 4] & 0xffL) << 32;
+      case 4:
+        k1 |= (key[roundedEnd + 3] & 0xffL) << 24;
+      case 3:
+        k1 |= (key[roundedEnd + 2] & 0xffL) << 16;
+      case 2:
+        k1 |= (key[roundedEnd + 1] & 0xffL) << 8;
+      case 1:
+        k1 |= (key[roundedEnd] & 0xffL);
+        k1 *= c1;
+        k1 = Long.rotateLeft(k1, 31);
+        k1 *= c2;
+        h1 ^= k1;
+    }
+
+    // ----------
+    // finalization
+
+    h1 ^= len;
+    h2 ^= len;
+
+    h1 += h2;
+    h2 += h1;
+
+    h1 = fmix64(h1);
+    h2 = fmix64(h2);
+
+    h1 += h2;
+    h2 += h1;
+
+    out.val1 = h1;
+    out.val2 = h2;
+  }
+
+  private static long fmix64(long k) {
+    k ^= k >>> 33;
+    k *= 0xff51afd7ed558ccdL;
+    k ^= k >>> 33;
+    k *= 0xc4ceb9fe1a85ec53L;
+    k ^= k >>> 33;
+    return k;
+  }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/minhash/MinHashFilterFactory.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/minhash/MinHashFilterFactory.java
new file mode 100644
index 00000000000..d951b9bb1c0
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/minhash/MinHashFilterFactory.java
@@ -0,0 +1,57 @@
+package org.apache.lucene.analysis.minhash;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.util.Map;
+
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.util.TokenFilterFactory;
+
+/**
+ * {@link TokenFilterFactory} for {@link MinHashFilter}.
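+ * <p>
+ * A typical chain places this filter after a shingle filter. The configuration below is an illustrative sketch
+ * (the field type name is made up for this example), not part of the patch itself:
+ * <pre class="prettyprint">
+ * &lt;fieldType name="text_minhash" class="solr.TextField"&gt;
+ *   &lt;analyzer&gt;
+ *     &lt;tokenizer class="solr.WhitespaceTokenizerFactory"/&gt;
+ *     &lt;filter class="solr.ShingleFilterFactory" minShingleSize="5" maxShingleSize="5"
+ *             outputUnigrams="false" tokenSeparator=" "/&gt;
+ *     &lt;filter class="solr.MinHashFilterFactory" hashCount="1" bucketCount="512" hashSetSize="1"/&gt;
+ *   &lt;/analyzer&gt;
+ * &lt;/fieldType&gt;
+ * </pre>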
+ */
+public class MinHashFilterFactory extends TokenFilterFactory {
+  private int hashCount = MinHashFilter.DEFAULT_HASH_COUNT;
+
+  private int bucketCount = MinHashFilter.DEFAULT_BUCKET_COUNT;
+
+  private int hashSetSize = MinHashFilter.DEFAULT_HASH_SET_SIZE;
+
+  private boolean withRotation;
+
+  /**
+   * Create a {@link MinHashFilterFactory}.
+   */
+  public MinHashFilterFactory(Map<String, String> args) {
+    super(args);
+    hashCount = getInt(args, "hashCount", MinHashFilter.DEFAULT_HASH_COUNT);
+    bucketCount = getInt(args, "bucketCount", MinHashFilter.DEFAULT_BUCKET_COUNT);
+    hashSetSize = getInt(args, "hashSetSize", MinHashFilter.DEFAULT_HASH_SET_SIZE);
+    withRotation = getBoolean(args, "withRotation", bucketCount > 1);
+    // standard factory contract: the get* helpers consume their keys, so anything left over is unknown
+    if (!args.isEmpty()) {
+      throw new IllegalArgumentException("Unknown parameters: " + args);
+    }
+  }
+
+  /*
+   * (non-Javadoc)
+   *
+   * @see org.apache.lucene.analysis.util.TokenFilterFactory#create(org.apache.lucene.analysis.TokenStream)
+   */
+  @Override
+  public TokenStream create(TokenStream input) {
+    return new MinHashFilter(input, hashCount, bucketCount, hashSetSize, withRotation);
+  }
+}
diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/minhash/package-info.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/minhash/package-info.java
new file mode 100644
index 00000000000..b47634791bb
--- /dev/null
+++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/minhash/package-info.java
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * MinHash filtering (for LSH).
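+ * <p>
+ * Min hashing rests on the observation that, for two sets of shingles, the probability that their minimum hash
+ * value coincides equals the Jaccard similarity of the sets (intersection size over union size), so overlap between
+ * the min hash tokens produced here can be used to estimate document resemblance.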
+ */
+package org.apache.lucene.analysis.minhash;
diff --git a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.util.TokenFilterFactory b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.util.TokenFilterFactory
index 8b19cdfc45f..70120c5221b 100644
--- a/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.util.TokenFilterFactory
+++ b/lucene/analysis/common/src/resources/META-INF/services/org.apache.lucene.analysis.util.TokenFilterFactory
@@ -58,6 +58,7 @@ org.apache.lucene.analysis.id.IndonesianStemFilterFactory
 org.apache.lucene.analysis.in.IndicNormalizationFilterFactory
 org.apache.lucene.analysis.it.ItalianLightStemFilterFactory
 org.apache.lucene.analysis.lv.LatvianStemFilterFactory
+org.apache.lucene.analysis.minhash.MinHashFilterFactory
 org.apache.lucene.analysis.miscellaneous.ASCIIFoldingFilterFactory
 org.apache.lucene.analysis.miscellaneous.CapitalizationFilterFactory
 org.apache.lucene.analysis.miscellaneous.CodepointCountFilterFactory
diff --git a/lucene/analysis/common/src/test/org/apache/lucene/analysis/minhash/MinHashFilterTest.java b/lucene/analysis/common/src/test/org/apache/lucene/analysis/minhash/MinHashFilterTest.java
new file mode 100644
index 00000000000..7c02d3b3aed
--- /dev/null
+++ b/lucene/analysis/common/src/test/org/apache/lucene/analysis/minhash/MinHashFilterTest.java
@@ -0,0 +1,330 @@
+package org.apache.lucene.analysis.minhash;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.io.StringReader;
+import java.io.UnsupportedEncodingException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+
+import org.apache.lucene.analysis.BaseTokenStreamTestCase;
+import org.apache.lucene.analysis.MockTokenizer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.Tokenizer;
+import org.apache.lucene.analysis.minhash.MinHashFilter.FixedSizeTreeSet;
+import org.apache.lucene.analysis.minhash.MinHashFilter.LongPair;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.util.automaton.CharacterRunAutomaton;
+import org.apache.lucene.util.automaton.RegExp;
+import org.junit.Test;
+
+/**
+ * Tests for {@link MinHashFilter}.
+ */
+public class MinHashFilterTest extends BaseTokenStreamTestCase {
+
+  @Test
+  public void testIntHash() {
+    LongPair hash = new LongPair();
+    MinHashFilter.murmurhash3_x64_128(MinHashFilter.getBytes(0), 0, 4, 0, hash);
+    assertEquals(-3485513579396041028L, hash.val1);
+    assertEquals(6383328099726337777L, hash.val2);
+  }
+
+  @Test
+  public void testStringHash() throws UnsupportedEncodingException {
+    LongPair hash = new LongPair();
+    byte[] bytes = "woof woof woof woof woof".getBytes("UTF-16LE");
+    MinHashFilter.murmurhash3_x64_128(bytes, 0, bytes.length, 0, hash);
+    assertEquals(7638079586852243959L, hash.val1);
+    assertEquals(4378804943379391304L, hash.val2);
+  }
+
+  @Test
+  public void testSimpleOrder() {
+    LongPair hash1 = new LongPair();
+    hash1.val1 = 1;
+    hash1.val2 = 2;
+    LongPair hash2 = new LongPair();
+    hash2.val1 = 2;
+    hash2.val2 = 1;
+    // use assertTrue rather than a bare assert so the check also runs with assertions disabled
+    assertTrue(hash1.compareTo(hash2) > 0);
+  }
+
+  @Test
+  public void testHashOrder() {
+    assertFalse(MinHashFilter.isLessThanUnsigned(0L, 0L));
+    assertTrue(MinHashFilter.isLessThanUnsigned(0L, -1L));
+    assertTrue(MinHashFilter.isLessThanUnsigned(1L, -1L));
+    assertTrue(MinHashFilter.isLessThanUnsigned(-2L, -1L));
+    assertTrue(MinHashFilter.isLessThanUnsigned(1L, 2L));
+    assertTrue(MinHashFilter.isLessThanUnsigned(Long.MAX_VALUE, Long.MIN_VALUE));
+
+    FixedSizeTreeSet<LongPair> minSet = new FixedSizeTreeSet<>(500);
+    HashSet<LongPair> unadded = new HashSet<>();
+    for (int i = 0; i < 100; i++) {
+      LongPair hash = new LongPair();
+      MinHashFilter.murmurhash3_x64_128(MinHashFilter.getBytes(i), 0, 4, 0, hash);
+      LongPair peek = null;
+      if (minSet.size() > 0) {
+        peek = minSet.last();
+      }
+
+      if (!minSet.add(hash)) {
+        unadded.add(hash);
+      } else {
+        if (peek != null) {
+          if ((minSet.size() == 500) && !peek.equals(minSet.last())) {
+            unadded.add(peek);
+          }
+        }
+      }
+    }
+    assertEquals(100, minSet.size());
+    assertEquals(0, unadded.size());
+
+    HashSet<LongPair> collisionDetection = new HashSet<>();
+    unadded = new HashSet<>();
+    minSet = new FixedSizeTreeSet<>(500);
+    for (int i = 0; i < 1000000; i++) {
+      LongPair hash = new LongPair();
+      MinHashFilter.murmurhash3_x64_128(MinHashFilter.getBytes(i), 0, 4, 0, hash);
+      collisionDetection.add(hash);
+      LongPair peek = null;
+      if (minSet.size() > 0) {
+        peek = minSet.last();
+      }
+
+      if (!minSet.add(hash)) {
+        unadded.add(hash);
+      } else {
+        if (peek != null) {
+          if ((minSet.size() == 500) && !peek.equals(minSet.last())) {
+            unadded.add(peek);
+          }
+        }
+      }
+    }
+    assertEquals(1000000, collisionDetection.size());
+    assertEquals(500, minSet.size());
+    assertEquals(999500, unadded.size());
+
+    LongPair last = null;
+    LongPair current = null;
+    while ((current = minSet.pollLast()) != null) {
+      if (last != null) {
+        assertTrue(isLessThan(current, last));
+      }
+      last = current;
+    }
+  }
+
+  @Test
+  public void testHashNotRepeated() {
+    FixedSizeTreeSet<LongPair> minSet = new FixedSizeTreeSet<>(500);
+    HashSet<LongPair> unadded = new HashSet<>();
+    for (int i = 0; i < 10000; i++) {
+      LongPair hash = new LongPair();
+      MinHashFilter.murmurhash3_x64_128(MinHashFilter.getBytes(i), 0, 4, 0, hash);
+      LongPair peek = null;
+      if (minSet.size() > 0) {
+        peek = minSet.last();
+      }
+      if (!minSet.add(hash)) {
+        unadded.add(hash);
+      } else {
+        if (peek != null) {
+          if ((minSet.size() == 500) && !peek.equals(minSet.last())) {
+            unadded.add(peek);
+          }
+        }
+      }
+    }
+    assertEquals(500, minSet.size());
+
+    LongPair last = null;
+    LongPair current = null;
+    while ((current = minSet.pollLast()) != null) {
+      if (last != null) {
+        assertTrue(isLessThan(current, last));
+      }
+      last = current;
+    }
+  }
+
+  @Test
+  public void testMockShingleTokenizer() throws IOException {
+    Tokenizer mockShingleTokenizer = createMockShingleTokenizer(5,
+        "woof woof woof woof woof" + " " + "woof woof woof woof puff");
+    assertTokenStreamContents(mockShingleTokenizer,
+        new String[]{"woof woof woof woof woof", "woof woof woof woof puff"});
+  }
+
+  @Test
+  public void testTokenStreamSingleInput() throws IOException {
+    String[] hashes = new String[]{"℁팽徭聙↝ꇁ홱杯"};
+    TokenStream ts = createTokenStream(5, "woof woof woof woof woof", 1, 1, 100, false);
+    assertTokenStreamContents(ts, hashes, new int[]{0},
+        new int[]{24}, new String[]{MinHashFilter.MIN_HASH_TYPE}, new int[]{1}, new int[]{1}, 24, 0, null,
+        true);
+
+    ts = createTokenStream(5, "woof woof woof woof woof", 2, 1, 1, false);
+    assertTokenStreamContents(ts, new String[]{new String(new char[]{0, 0, 8449, 54077, 64133, 32857, 8605, 41409}),
+            new String(new char[]{0, 1, 16887, 58164, 39536, 14926, 6529, 17276})}, new int[]{0, 0},
+        new int[]{24, 24}, new String[]{MinHashFilter.MIN_HASH_TYPE, MinHashFilter.MIN_HASH_TYPE}, new int[]{1, 0},
+        new int[]{1, 1}, 24, 0, null,
+        true);
+  }
+
+  @Test
+  public void testTokenStream1() throws IOException {
+    // The expected token text is raw hash bits packed into chars, so it is degenerate as readable characters.
+    String[] hashes = new String[]{"℁팽徭聙↝ꇁ홱杯",
+        new String(new char[]{36347, 63457, 43013, 56843, 52284, 34231, 57934, 42302})};
+
+    TokenStream ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 1, 1, 100,
+        false);
+    assertTokenStreamContents(ts, hashes, new int[]{0, 0},
+        new int[]{49, 49}, new String[]{MinHashFilter.MIN_HASH_TYPE, MinHashFilter.MIN_HASH_TYPE}, new int[]{1, 0},
+        new int[]{1, 1}, 49, 0, null, true);
+  }
+
+  private ArrayList<String> getTokens(TokenStream ts) throws IOException {
+    ArrayList<String> tokens = new ArrayList<>();
+    ts.reset();
+    while (ts.incrementToken()) {
+      CharTermAttribute termAttribute = ts.getAttribute(CharTermAttribute.class);
+      String token = new String(termAttribute.buffer(), 0, termAttribute.length());
+      tokens.add(token);
+    }
+    ts.end();
+    ts.close();
+
+    return tokens;
+  }
+
+  @Test
+  public void testTokenStream2() throws IOException {
+    TokenStream ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 100, 1, 1,
+        false);
+    ArrayList<String> tokens = getTokens(ts); // getTokens() also closes the stream
+
+    assertEquals(100, tokens.size());
+  }
+
+  @Test
+  public void testTokenStream3() throws IOException {
+    TokenStream ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 10, 1, 10,
+        false);
+    ArrayList<String> tokens = getTokens(ts);
+
+    assertEquals(20, tokens.size());
+  }
+
+  @Test
+  public void testTokenStream4() throws IOException {
+    TokenStream ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 10, 10, 1,
+        false);
+    ArrayList<String> tokens = getTokens(ts);
+
+    assertEquals(20, tokens.size());
+
+    ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 10, 10, 1, true);
+    tokens = getTokens(ts);
+
+    assertEquals(100, tokens.size());
+  }
+
+  @Test
+  public void testTokenStream5() throws IOException {
+    TokenStream ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 1, 100, 1,
+        false);
+    ArrayList<String> tokens = getTokens(ts);
+
+    assertEquals(2, tokens.size());
+
+    ts = createTokenStream(5, "woof woof woof woof woof" + " " + "woof woof woof woof puff", 1, 100, 1, true);
+    tokens = getTokens(ts);
+
+    assertEquals(100, tokens.size());
+    HashSet<String> set = new HashSet<>(tokens);
+    assertEquals(2, set.size());
+
+    boolean rolled = false;
+    String first = null;
+    String last = null;
+    for (String current : tokens) {
+      if (first == null) {
+        first = current;
+      }
+      if (last != null) {
+        if (!rolled) {
+          if (current.compareTo(last) >= 0) {
+            // fine: still in ascending order
+          } else if (current.equals(first)) {
+            rolled = true;
+          } else {
+            fail("Incorrect hash order");
+          }
+        } else {
+          if (!current.equals(first)) {
+            fail("Incorrect hash order");
+          }
+        }
+      }
+      last = current;
+    }
+  }
+
+  public static TokenStream createTokenStream(int shingleSize, String shingles, int hashCount, int bucketCount,
+      int hashSetSize, boolean withRotation) {
+    Tokenizer tokenizer = createMockShingleTokenizer(shingleSize, shingles);
+    HashMap<String, String> lshffargs = new HashMap<>();
+    lshffargs.put("hashCount", "" + hashCount);
+    lshffargs.put("bucketCount", "" + bucketCount);
+    lshffargs.put("hashSetSize", "" + hashSetSize);
+    lshffargs.put("withRotation", "" + withRotation);
+    MinHashFilterFactory lshff = new MinHashFilterFactory(lshffargs);
+    return lshff.create(tokenizer);
+  }
+
+  private static Tokenizer createMockShingleTokenizer(int shingleSize, String shingles) {
+    // accepts runs of shingleSize whitespace-separated words as single tokens, mimicking a shingle filter
+    MockTokenizer tokenizer = new MockTokenizer(
+        new CharacterRunAutomaton(new RegExp("[^ \t\r\n]+([ \t\r\n]+[^ \t\r\n]+){" + (shingleSize - 1)
"}").toAutomaton()), + true); + tokenizer.setEnableChecks(true); + if (shingles != null) { + tokenizer.setReader(new StringReader(shingles)); + } + return tokenizer; + } + + private boolean isLessThan(LongPair hash1, LongPair hash2) { + return MinHashFilter.isLessThanUnsigned(hash1.val2, hash2.val2) || hash1.val2 == hash2.val2 && (MinHashFilter.isLessThanUnsigned(hash1.val1, hash2.val1)); + } +}