From e26a44319429caa8aa28f5c27d620dcbad0dc240 Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Fri, 25 Mar 2016 17:21:36 +0100 Subject: [PATCH] LUCENE-7095: Add point values support to the numeric field query time join --- lucene/CHANGES.txt | 3 + .../apache/lucene/search/join/JoinUtil.java | 274 +++++++++++++- .../join/PointInSetIncludingScoreQuery.java | 340 ++++++++++++++++++ .../lucene/search/join/TestJoinUtil.java | 54 ++- 4 files changed, 660 insertions(+), 11 deletions(-) create mode 100644 lucene/join/src/java/org/apache/lucene/search/join/PointInSetIncludingScoreQuery.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index bef81cada97..fa9ab3165db 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -233,6 +233,9 @@ Other * LUCENE-7093: Add point values support to MemoryIndex (Martijn van Groningen, Mike McCandless) +* LUCENE-7095: Add point values support to the numeric field query time join. + (Martijn van Groningen, Mike McCandless) + ======================= Lucene 5.5.0 ======================= New Features diff --git a/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java b/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java index 11a44a57dd0..dd253431355 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java @@ -17,21 +17,39 @@ package org.apache.lucene.search.join; import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; import java.util.Locale; +import java.util.Map; +import java.util.TreeSet; +import java.util.function.BiConsumer; +import java.util.function.LongFunction; +import org.apache.lucene.document.DoublePoint; import org.apache.lucene.document.FieldType.LegacyNumericType; +import org.apache.lucene.document.FloatPoint; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.LongPoint; import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DocValues; import org.apache.lucene.index.DocValuesType; - import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.MultiDocValues; +import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.search.Collector; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.PointInSetQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.SimpleCollector; import org.apache.lucene.search.join.DocValuesTermsCollector.Function; +import org.apache.lucene.util.BytesRef; /** * Utility for query time joining. @@ -94,6 +112,8 @@ public final class JoinUtil { } /** + * @deprecated Because {@link LegacyNumericType} is deprecated, instead use {@link #createJoinQuery(String, boolean, String, Class, Query, IndexSearcher, ScoreMode)} + * * Method for query time joining for numeric fields. It supports multi- and single- values longs and ints. * All considerations from {@link JoinUtil#createJoinQuery(String, boolean, String, Query, IndexSearcher, ScoreMode)} are applicable here too, * though memory consumption might be higher. @@ -112,7 +132,7 @@ public final class JoinUtil { * terms in the from and to field * @throws IOException If I/O related errors occur */ - + @Deprecated public static Query createJoinQuery(String fromField, boolean multipleValuesPerDocument, String toField, LegacyNumericType numericType, @@ -134,7 +154,255 @@ public final class JoinUtil { termsCollector); } - + + /** + * Method for query time joining for numeric fields. It supports multi- and single- values longs, ints, floats and longs. + * All considerations from {@link JoinUtil#createJoinQuery(String, boolean, String, Query, IndexSearcher, ScoreMode)} are applicable here too, + * though memory consumption might be higher. + *

+ * + * @param fromField The from field to join from + * @param multipleValuesPerDocument Whether the from field has multiple terms per document + * when true fromField might be {@link DocValuesType#SORTED_NUMERIC}, + * otherwise fromField should be {@link DocValuesType#NUMERIC} + * @param toField The to field to join to, should be {@link IntPoint}, {@link LongPoint}, {@link FloatPoint} + * or {@link DoublePoint}. + * @param numericType either {@link java.lang.Integer}, {@link java.lang.Long}, {@link java.lang.Float} + * or {@link java.lang.Double} it should correspond to toField types + * @param fromQuery The query to match documents on the from side + * @param fromSearcher The searcher that executed the specified fromQuery + * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query + * @return a {@link Query} instance that can be used to join documents based on the + * terms in the from and to field + * @throws IOException If I/O related errors occur + */ + public static Query createJoinQuery(String fromField, + boolean multipleValuesPerDocument, + String toField, + Class numericType, + Query fromQuery, + IndexSearcher fromSearcher, + ScoreMode scoreMode) throws IOException { + TreeSet joinValues = new TreeSet<>(); + Map aggregatedScores = new HashMap<>(); + Map occurrences = new HashMap<>(); + boolean needsScore = scoreMode != ScoreMode.None; + BiConsumer scoreAggregator; + if (scoreMode == ScoreMode.Max) { + scoreAggregator = (key, score) -> { + Float currentValue = aggregatedScores.putIfAbsent(key, score); + if (currentValue != null) { + aggregatedScores.put(key, Math.max(currentValue, score)); + } + }; + } else if (scoreMode == ScoreMode.Min) { + scoreAggregator = (key, score) -> { + Float currentValue = aggregatedScores.putIfAbsent(key, score); + if (currentValue != null) { + aggregatedScores.put(key, Math.min(currentValue, score)); + } + }; + } else if (scoreMode == ScoreMode.Total) { + scoreAggregator = (key, score) -> { + Float currentValue = aggregatedScores.putIfAbsent(key, score); + if (currentValue != null) { + aggregatedScores.put(key, currentValue + score); + } + }; + } else if (scoreMode == ScoreMode.Avg) { + scoreAggregator = (key, score) -> { + Float currentSore = aggregatedScores.putIfAbsent(key, score); + if (currentSore != null) { + aggregatedScores.put(key, currentSore + score); + } + Integer currentOccurrence = occurrences.putIfAbsent(key, 1); + if (currentOccurrence != null) { + occurrences.put(key, ++currentOccurrence); + } + + }; + } else { + scoreAggregator = (key, score) -> { + throw new UnsupportedOperationException(); + }; + } + + LongFunction joinScorer; + if (scoreMode == ScoreMode.Avg) { + joinScorer = (joinValue) -> { + Float aggregatedScore = aggregatedScores.get(joinValue); + Integer occurrence = occurrences.get(joinValue); + return aggregatedScore / occurrence; + }; + } else { + joinScorer = aggregatedScores::get; + } + + Collector collector; + if (multipleValuesPerDocument) { + collector = new SimpleCollector() { + + SortedNumericDocValues sortedNumericDocValues; + Scorer scorer; + + @Override + public void collect(int doc) throws IOException { + sortedNumericDocValues.setDocument(doc); + for (int i = 0; i < sortedNumericDocValues.count(); i++) { + long value = sortedNumericDocValues.valueAt(i); + joinValues.add(value); + if (needsScore) { + scoreAggregator.accept(value, scorer.score()); + } + } + } + + @Override + protected void doSetNextReader(LeafReaderContext context) throws IOException { + sortedNumericDocValues = DocValues.getSortedNumeric(context.reader(), fromField); + } + + @Override + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + @Override + public boolean needsScores() { + return needsScore; + } + }; + } else { + collector = new SimpleCollector() { + + NumericDocValues numericDocValues; + Scorer scorer; + + @Override + public void collect(int doc) throws IOException { + long value = numericDocValues.get(doc); + joinValues.add(value); + if (needsScore) { + scoreAggregator.accept(value, scorer.score()); + } + } + + @Override + protected void doSetNextReader(LeafReaderContext context) throws IOException { + numericDocValues = DocValues.getNumeric(context.reader(), fromField); + } + + @Override + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + @Override + public boolean needsScores() { + return needsScore; + } + }; + } + fromSearcher.search(fromQuery, collector); + + Iterator iterator = joinValues.iterator(); + + final int bytesPerDim; + final BytesRef encoded = new BytesRef(); + final PointInSetIncludingScoreQuery.Stream stream; + if (Integer.class.equals(numericType)) { + bytesPerDim = Integer.BYTES; + stream = new PointInSetIncludingScoreQuery.Stream() { + @Override + public BytesRef next() { + if (iterator.hasNext()) { + long value = iterator.next(); + IntPoint.encodeDimension((int) value, encoded.bytes, 0); + if (needsScore) { + score = joinScorer.apply(value); + } + return encoded; + } else { + return null; + } + } + }; + } else if (Long.class.equals(numericType)) { + bytesPerDim = Long.BYTES; + stream = new PointInSetIncludingScoreQuery.Stream() { + @Override + public BytesRef next() { + if (iterator.hasNext()) { + long value = iterator.next(); + LongPoint.encodeDimension(value, encoded.bytes, 0); + if (needsScore) { + score = joinScorer.apply(value); + } + return encoded; + } else { + return null; + } + } + }; + } else if (Float.class.equals(numericType)) { + bytesPerDim = Float.BYTES; + stream = new PointInSetIncludingScoreQuery.Stream() { + @Override + public BytesRef next() { + if (iterator.hasNext()) { + long value = iterator.next(); + FloatPoint.encodeDimension(Float.intBitsToFloat((int) value), encoded.bytes, 0); + if (needsScore) { + score = joinScorer.apply(value); + } + return encoded; + } else { + return null; + } + } + }; + } else if (Double.class.equals(numericType)) { + bytesPerDim = Double.BYTES; + stream = new PointInSetIncludingScoreQuery.Stream() { + @Override + public BytesRef next() { + if (iterator.hasNext()) { + long value = iterator.next(); + DoublePoint.encodeDimension(Double.longBitsToDouble(value), encoded.bytes, 0); + if (needsScore) { + score = joinScorer.apply(value); + } + return encoded; + } else { + return null; + } + } + }; + } else { + throw new IllegalArgumentException("unsupported numeric type, only Integer, Long, Float and Double are supported"); + } + + encoded.bytes = new byte[bytesPerDim]; + encoded.length = bytesPerDim; + + if (needsScore) { + return new PointInSetIncludingScoreQuery(fromQuery, multipleValuesPerDocument, toField, bytesPerDim, stream) { + + @Override + protected String toString(byte[] value) { + return toString.apply(value, numericType); + } + }; + } else { + return new PointInSetQuery(toField, 1, bytesPerDim, stream) { + @Override + protected String toString(byte[] value) { + return PointInSetIncludingScoreQuery.toString.apply(value, numericType); + } + }; + } + } + private static Query createJoinQuery(boolean multipleValuesPerDocument, String toField, Query fromQuery, IndexSearcher fromSearcher, ScoreMode scoreMode, final GenericTermsCollector collector) throws IOException { diff --git a/lucene/join/src/java/org/apache/lucene/search/join/PointInSetIncludingScoreQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/PointInSetIncludingScoreQuery.java new file mode 100644 index 00000000000..df6aa98b63e --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/PointInSetIncludingScoreQuery.java @@ -0,0 +1,340 @@ +package org.apache.lucene.search.join; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; + +import org.apache.lucene.document.DoublePoint; +import org.apache.lucene.document.FloatPoint; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PointValues; +import org.apache.lucene.index.PointValues.IntersectVisitor; +import org.apache.lucene.index.PointValues.Relation; +import org.apache.lucene.index.PrefixCodedTerms; +import org.apache.lucene.index.PrefixCodedTerms.TermIterator; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.ConstantScoreScorer; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.PointInSetQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.BitSetIterator; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.BytesRefBuilder; +import org.apache.lucene.util.DocIdSetBuilder; +import org.apache.lucene.util.FixedBitSet; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// A TermsIncludingScoreQuery variant for point values: +abstract class PointInSetIncludingScoreQuery extends Query { + + static BiFunction, String> toString = (value, numericType) -> { + if (Integer.class.equals(numericType)) { + return Integer.toString(IntPoint.decodeDimension(value, 0)); + } else if (Long.class.equals(numericType)) { + return Long.toString(LongPoint.decodeDimension(value, 0)); + } else if (Float.class.equals(numericType)) { + return Float.toString(FloatPoint.decodeDimension(value, 0)); + } else if (Double.class.equals(numericType)) { + return Double.toString(DoublePoint.decodeDimension(value, 0)); + } else { + return "unsupported"; + } + }; + + final Query originalQuery; + final boolean multipleValuesPerDocument; + final PrefixCodedTerms sortedPackedPoints; + final int sortedPackedPointsHashCode; + final String field; + final int bytesPerDim; + + final List aggregatedJoinScores; + + static abstract class Stream extends PointInSetQuery.Stream { + + float score; + + } + + PointInSetIncludingScoreQuery(Query originalQuery, boolean multipleValuesPerDocument, String field, int bytesPerDim, + Stream packedPoints) { + this.originalQuery = originalQuery; + this.multipleValuesPerDocument = multipleValuesPerDocument; + this.field = field; + if (bytesPerDim < 1 || bytesPerDim > PointValues.MAX_NUM_BYTES) { + throw new IllegalArgumentException("bytesPerDim must be > 0 and <= " + PointValues.MAX_NUM_BYTES + "; got " + bytesPerDim); + } + this.bytesPerDim = bytesPerDim; + + aggregatedJoinScores = new ArrayList<>(); + PrefixCodedTerms.Builder builder = new PrefixCodedTerms.Builder(); + BytesRefBuilder previous = null; + BytesRef current; + while ((current = packedPoints.next()) != null) { + if (current.length != bytesPerDim) { + throw new IllegalArgumentException("packed point length should be " + (bytesPerDim) + " but got " + current.length + "; field=\"" + field + "\"bytesPerDim=" + bytesPerDim); + } + if (previous == null) { + previous = new BytesRefBuilder(); + } else { + int cmp = previous.get().compareTo(current); + if (cmp == 0) { + throw new IllegalArgumentException("unexpected duplicated value: " + current); + } else if (cmp >= 0) { + throw new IllegalArgumentException("values are out of order: saw " + previous + " before " + current); + } + } + builder.add(field, current); + aggregatedJoinScores.add(packedPoints.score); + previous.copyBytes(current); + } + sortedPackedPoints = builder.finish(); + sortedPackedPointsHashCode = sortedPackedPoints.hashCode(); + } + + @Override + public final Weight createWeight(IndexSearcher searcher, boolean needsScores) throws IOException { + final Weight originalWeight = originalQuery.createWeight(searcher, needsScores); + return new Weight(this) { + + @Override + public void extractTerms(Set terms) { + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + Scorer scorer = scorer(context); + if (scorer != null) { + int target = scorer.iterator().advance(doc); + if (doc == target) { + return Explanation.match(scorer.score(), "A match"); + } + } + return Explanation.noMatch("Not a match"); + } + + @Override + public float getValueForNormalization() throws IOException { + return originalWeight.getValueForNormalization(); + } + + @Override + public void normalize(float norm, float boost) { + originalWeight.normalize(norm, boost); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + LeafReader reader = context.reader(); + PointValues values = reader.getPointValues(); + if (values == null) { + return null; + } + FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(field); + if (fieldInfo == null) { + return null; + } + if (fieldInfo.getPointDimensionCount() != 1) { + throw new IllegalArgumentException("field=\"" + field + "\" was indexed with numDims=" + fieldInfo.getPointDimensionCount() + " but this query has numDims=1"); + } + if (fieldInfo.getPointNumBytes() != bytesPerDim) { + throw new IllegalArgumentException("field=\"" + field + "\" was indexed with bytesPerDim=" + fieldInfo.getPointNumBytes() + " but this query has bytesPerDim=" + bytesPerDim); + } + + FixedBitSet result = new FixedBitSet(reader.maxDoc()); + float[] scores = new float[reader.maxDoc()]; + values.intersect(field, new MergePointVisitor(sortedPackedPoints, result, scores)); + return new Scorer(this) { + + DocIdSetIterator disi = new BitSetIterator(result, 10L); + + @Override + public float score() throws IOException { + return scores[docID()]; + } + + @Override + public int freq() throws IOException { + return 1; + } + + @Override + public int docID() { + return disi.docID(); + } + + @Override + public DocIdSetIterator iterator() { + return disi; + } + + }; + } + }; + } + + private class MergePointVisitor implements IntersectVisitor { + + private final FixedBitSet result; + private final float[] scores; + + private TermIterator iterator; + private Iterator scoreIterator; + private BytesRef nextQueryPoint; + float nextScore; + private final BytesRef scratch = new BytesRef(); + + private MergePointVisitor(PrefixCodedTerms sortedPackedPoints, FixedBitSet result, float[] scores) throws IOException { + this.result = result; + this.scores = scores; + scratch.length = bytesPerDim; + this.iterator = sortedPackedPoints.iterator(); + this.scoreIterator = aggregatedJoinScores.iterator(); + nextQueryPoint = iterator.next(); + if (scoreIterator.hasNext()) { + nextScore = scoreIterator.next(); + } + } + + @Override + public void visit(int docID) { + throw new IllegalStateException("shouldn't get here, since CELL_INSIDE_QUERY isn't emitted"); + } + + @Override + public void visit(int docID, byte[] packedValue) { + scratch.bytes = packedValue; + while (nextQueryPoint != null) { + int cmp = nextQueryPoint.compareTo(scratch); + if (cmp == 0) { + // Query point equals index point, so collect and return + if (multipleValuesPerDocument) { + if (result.get(docID) == false) { + result.set(docID); + scores[docID] = nextScore; + } + } else { + result.set(docID); + scores[docID] = nextScore; + } + break; + } else if (cmp < 0) { + // Query point is before index point, so we move to next query point + nextQueryPoint = iterator.next(); + if (scoreIterator.hasNext()) { + nextScore = scoreIterator.next(); + } + } else { + // Query point is after index point, so we don't collect and we return: + break; + } + } + } + + @Override + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + while (nextQueryPoint != null) { + scratch.bytes = minPackedValue; + int cmpMin = nextQueryPoint.compareTo(scratch); + if (cmpMin < 0) { + // query point is before the start of this cell + nextQueryPoint = iterator.next(); + if (scoreIterator.hasNext()) { + nextScore = scoreIterator.next(); + } + continue; + } + scratch.bytes = maxPackedValue; + int cmpMax = nextQueryPoint.compareTo(scratch); + if (cmpMax > 0) { + // query point is after the end of this cell + return Relation.CELL_OUTSIDE_QUERY; + } + + return Relation.CELL_CROSSES_QUERY; + } + + // We exhausted all points in the query: + return Relation.CELL_OUTSIDE_QUERY; + } + } + + @Override + public final int hashCode() { + int hash = super.hashCode(); + hash = 31 * hash + field.hashCode(); + hash = 31 * hash + originalQuery.hashCode(); + hash = 31 * hash + sortedPackedPointsHashCode; + hash = 31 * hash + bytesPerDim; + return hash; + } + + @Override + public final boolean equals(Object other) { + if (super.equals(other)) { + final PointInSetIncludingScoreQuery q = (PointInSetIncludingScoreQuery) other; + return q.field.equals(field) && + q.originalQuery.equals(originalQuery) && + q.bytesPerDim == bytesPerDim && + q.sortedPackedPointsHashCode == sortedPackedPointsHashCode && + q.sortedPackedPoints.equals(sortedPackedPoints); + } + + return false; + } + + @Override + public final String toString(String field) { + final StringBuilder sb = new StringBuilder(); + if (this.field.equals(field) == false) { + sb.append(this.field); + sb.append(':'); + } + + sb.append("{"); + + TermIterator iterator = sortedPackedPoints.iterator(); + byte[] pointBytes = new byte[bytesPerDim]; + boolean first = true; + for (BytesRef point = iterator.next(); point != null; point = iterator.next()) { + if (first == false) { + sb.append(" "); + } + first = false; + System.arraycopy(point.bytes, point.offset, pointBytes, 0, pointBytes.length); + sb.append(toString(pointBytes)); + } + sb.append("}"); + return sb.toString(); + } + + protected abstract String toString(byte[] value); +} diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java b/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java index 2796e015b51..c67811fd589 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java @@ -34,10 +34,16 @@ import java.util.TreeSet; import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.analysis.MockTokenizer; import org.apache.lucene.document.Document; +import org.apache.lucene.document.DoubleDocValuesField; +import org.apache.lucene.document.DoublePoint; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType.LegacyNumericType; +import org.apache.lucene.document.FloatDocValuesField; +import org.apache.lucene.document.FloatPoint; +import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.LegacyIntField; import org.apache.lucene.document.LegacyLongField; +import org.apache.lucene.document.LongPoint; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.SortedDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField; @@ -962,12 +968,36 @@ public class TestJoinUtil extends LuceneTestCase { final boolean muliValsQuery = multipleValuesPerDocument || random().nextBoolean(); final String fromField = from ? "from":"to"; final String toField = from ? "to":"from"; - - if (random().nextBoolean()) { // numbers - final LegacyNumericType numType = random().nextBoolean() ? LegacyNumericType.INT: LegacyNumericType.LONG ; - joinQuery = JoinUtil.createJoinQuery(fromField+numType, muliValsQuery, toField+numType, numType, actualQuery, indexSearcher, scoreMode); - } else { - joinQuery = JoinUtil.createJoinQuery(fromField, muliValsQuery, toField, actualQuery, indexSearcher, scoreMode); + + int surpriseMe = random().nextInt(3); + switch (surpriseMe) { + case 0: + Class numType; + String suffix; + if (random().nextBoolean()) { + numType = Integer.class; + suffix = "INT"; + } else if (random().nextBoolean()) { + numType = Float.class; + suffix = "FLOAT"; + } else if (random().nextBoolean()) { + numType = Long.class; + suffix = "LONG"; + } else { + numType = Double.class; + suffix = "DOUBLE"; + } + joinQuery = JoinUtil.createJoinQuery(fromField + suffix, muliValsQuery, toField + suffix, numType, actualQuery, indexSearcher, scoreMode); + break; + case 1: + final LegacyNumericType legacyNumType = random().nextBoolean() ? LegacyNumericType.INT: LegacyNumericType.LONG ; + joinQuery = JoinUtil.createJoinQuery(fromField+legacyNumType, muliValsQuery, toField+legacyNumType, legacyNumType, actualQuery, indexSearcher, scoreMode); + break; + case 2: + joinQuery = JoinUtil.createJoinQuery(fromField, muliValsQuery, toField, actualQuery, indexSearcher, scoreMode); + break; + default: + throw new RuntimeException("unexpected value " + surpriseMe); } } if (VERBOSE) { @@ -1304,19 +1334,27 @@ public class TestJoinUtil extends LuceneTestCase { document.add(newTextField(random, fieldName, linkValue, Field.Store.NO)); final int linkInt = Integer.parseUnsignedInt(linkValue,16); - document.add(new LegacyIntField(fieldName+ LegacyNumericType.INT, linkInt, Field.Store.NO)); + document.add(new LegacyIntField(fieldName + LegacyNumericType.INT, linkInt, Field.Store.NO)); + document.add(new IntPoint(fieldName + LegacyNumericType.INT, linkInt)); + document.add(new FloatPoint(fieldName + "FLOAT", linkInt)); final long linkLong = linkInt<<32 | linkInt; - document.add(new LegacyLongField(fieldName+ LegacyNumericType.LONG, linkLong, Field.Store.NO)); + document.add(new LegacyLongField(fieldName + LegacyNumericType.LONG, linkLong, Field.Store.NO)); + document.add(new LongPoint(fieldName + LegacyNumericType.LONG, linkLong)); + document.add(new DoublePoint(fieldName + "DOUBLE", linkLong)); if (multipleValuesPerDocument) { document.add(new SortedSetDocValuesField(fieldName, new BytesRef(linkValue))); document.add(new SortedNumericDocValuesField(fieldName+ LegacyNumericType.INT, linkInt)); + document.add(new SortedNumericDocValuesField(fieldName+ "FLOAT", Float.floatToRawIntBits(linkInt))); document.add(new SortedNumericDocValuesField(fieldName+ LegacyNumericType.LONG, linkLong)); + document.add(new SortedNumericDocValuesField(fieldName+ "DOUBLE", Double.doubleToRawLongBits(linkLong))); } else { document.add(new SortedDocValuesField(fieldName, new BytesRef(linkValue))); document.add(new NumericDocValuesField(fieldName+ LegacyNumericType.INT, linkInt)); + document.add(new FloatDocValuesField(fieldName+ "FLOAT", linkInt)); document.add(new NumericDocValuesField(fieldName+ LegacyNumericType.LONG, linkLong)); + document.add(new DoubleDocValuesField(fieldName+ "DOUBLE", linkLong)); } if (globalOrdinalJoin) { document.add(new SortedDocValuesField("join_field", new BytesRef(linkValue)));