diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 0468133a323..4ef155e4551 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -45,6 +45,10 @@ New Features faster intersection by avoiding loading positions in certain cases. (Paul Elschot, Robert Muir via Mike McCandless) +* LUCENE-6352: Added a new query time join to the join module that uses + global ordinals, which is faster for subsequent joins between reopens. + (Martijn van Groningen, Adrien Grand) + Optimizations * LUCENE-6379: IndexWriter.deleteDocuments(Query...) now detects if diff --git a/lucene/join/src/java/org/apache/lucene/search/join/BaseGlobalOrdinalScorer.java b/lucene/join/src/java/org/apache/lucene/search/join/BaseGlobalOrdinalScorer.java new file mode 100644 index 00000000000..4d81d58f92a --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/BaseGlobalOrdinalScorer.java @@ -0,0 +1,97 @@ +package org.apache.lucene.search.join; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.LongBitSet; + +import java.io.IOException; + +abstract class BaseGlobalOrdinalScorer extends Scorer { + + final LongBitSet foundOrds; + final SortedDocValues values; + final Scorer approximationScorer; + + float score; + + public BaseGlobalOrdinalScorer(Weight weight, LongBitSet foundOrds, SortedDocValues values, Scorer approximationScorer) { + super(weight); + this.foundOrds = foundOrds; + this.values = values; + this.approximationScorer = approximationScorer; + } + + @Override + public float score() throws IOException { + return score; + } + + @Override + public int docID() { + return approximationScorer.docID(); + } + + @Override + public int nextDoc() throws IOException { + return advance(approximationScorer.docID() + 1); + } + + @Override + public TwoPhaseIterator asTwoPhaseIterator() { + final DocIdSetIterator approximation = new DocIdSetIterator() { + @Override + public int docID() { + return approximationScorer.docID(); + } + + @Override + public int nextDoc() throws IOException { + return approximationScorer.nextDoc(); + } + + @Override + public int advance(int target) throws IOException { + return approximationScorer.advance(target); + } + + @Override + public long cost() { + return approximationScorer.cost(); + } + }; + return createTwoPhaseIterator(approximation); + } + + @Override + public long cost() { + return approximationScorer.cost(); + } + + @Override + public int freq() throws IOException { + return 1; + } + + protected abstract TwoPhaseIterator createTwoPhaseIterator(DocIdSetIterator approximation); + +} \ No newline at end of file diff --git a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsCollector.java b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsCollector.java new file mode 100644 index 00000000000..8a874621a79 --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsCollector.java @@ -0,0 +1,114 @@ +package org.apache.lucene.search.join; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.MultiDocValues; +import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.util.LongBitSet; +import org.apache.lucene.util.LongValues; + +import java.io.IOException; + +/** + * A collector that collects all ordinals from a specified field matching the query. + * + * @lucene.experimental + */ +final class GlobalOrdinalsCollector implements Collector { + + final String field; + final LongBitSet collectedOrds; + final MultiDocValues.OrdinalMap ordinalMap; + + GlobalOrdinalsCollector(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) { + this.field = field; + this.ordinalMap = ordinalMap; + this.collectedOrds = new LongBitSet(valueCount); + } + + public LongBitSet getCollectorOrdinals() { + return collectedOrds; + } + + @Override + public boolean needsScores() { + return false; + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + SortedDocValues docTermOrds = DocValues.getSorted(context.reader(), field); + if (ordinalMap != null) { + LongValues segmentOrdToGlobalOrdLookup = ordinalMap.getGlobalOrds(context.ord); + return new OrdinalMapCollector(docTermOrds, segmentOrdToGlobalOrdLookup); + } else { + return new SegmentOrdinalCollector(docTermOrds); + } + } + + final class OrdinalMapCollector implements LeafCollector { + + private final SortedDocValues docTermOrds; + private final LongValues segmentOrdToGlobalOrdLookup; + + OrdinalMapCollector(SortedDocValues docTermOrds, LongValues segmentOrdToGlobalOrdLookup) { + this.docTermOrds = docTermOrds; + this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup; + } + + @Override + public void collect(int doc) throws IOException { + final long segmentOrd = docTermOrds.getOrd(doc); + if (segmentOrd != -1) { + final long globalOrd = segmentOrdToGlobalOrdLookup.get(segmentOrd); + collectedOrds.set(globalOrd); + } + } + + @Override + public void setScorer(Scorer scorer) throws IOException { + } + } + + final class SegmentOrdinalCollector implements LeafCollector { + + private final SortedDocValues docTermOrds; + + SegmentOrdinalCollector(SortedDocValues docTermOrds) { + this.docTermOrds = docTermOrds; + } + + @Override + public void collect(int doc) throws IOException { + final long segmentOrd = docTermOrds.getOrd(doc); + if (segmentOrd != -1) { + collectedOrds.set(segmentOrd); + } + } + + @Override + public void setScorer(Scorer scorer) throws IOException { + } + } + +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsQuery.java new file mode 100644 index 00000000000..a96881df1a9 --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsQuery.java @@ -0,0 +1,245 @@ +package org.apache.lucene.search.join; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.MultiDocValues; +import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.ComplexExplanation; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.LongBitSet; +import org.apache.lucene.util.LongValues; + +import java.io.IOException; +import java.util.Set; + +final class GlobalOrdinalsQuery extends Query { + + // All the ords of matching docs found with OrdinalsCollector. + private final LongBitSet foundOrds; + private final String joinField; + private final MultiDocValues.OrdinalMap globalOrds; + // Is also an approximation of the docs that will match. Can be all docs that have toField or something more specific. + private final Query toQuery; + + // just for hashcode and equals: + private final Query fromQuery; + private final IndexReader indexReader; + + GlobalOrdinalsQuery(LongBitSet foundOrds, String joinField, MultiDocValues.OrdinalMap globalOrds, Query toQuery, Query fromQuery, IndexReader indexReader) { + this.foundOrds = foundOrds; + this.joinField = joinField; + this.globalOrds = globalOrds; + this.toQuery = toQuery; + this.fromQuery = fromQuery; + this.indexReader = indexReader; + } + + @Override + public Weight createWeight(IndexSearcher searcher, boolean needsScores) throws IOException { + return new W(this, toQuery.createWeight(searcher, false)); + } + + @Override + public void extractTerms(Set terms) { + fromQuery.extractTerms(terms); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + + GlobalOrdinalsQuery that = (GlobalOrdinalsQuery) o; + + if (!fromQuery.equals(that.fromQuery)) return false; + if (!joinField.equals(that.joinField)) return false; + if (!toQuery.equals(that.toQuery)) return false; + if (!indexReader.equals(that.indexReader)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + joinField.hashCode(); + result = 31 * result + toQuery.hashCode(); + result = 31 * result + fromQuery.hashCode(); + result = 31 * result + indexReader.hashCode(); + return result; + } + + @Override + public String toString(String field) { + return "GlobalOrdinalsQuery{" + + "joinField=" + joinField + + '}'; + } + + final class W extends Weight { + + private final Weight approximationWeight; + + private float queryNorm; + private float queryWeight; + + W(Query query, Weight approximationWeight) { + super(query); + this.approximationWeight = approximationWeight; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + SortedDocValues values = DocValues.getSorted(context.reader(), joinField); + if (values != null) { + int segmentOrd = values.getOrd(doc); + if (segmentOrd != -1) { + BytesRef joinValue = values.lookupOrd(segmentOrd); + return new ComplexExplanation(true, queryNorm, "Score based on join value " + joinValue.utf8ToString()); + } + } + return new ComplexExplanation(false, 0.0f, "Not a match"); + } + + @Override + public float getValueForNormalization() throws IOException { + queryWeight = getBoost(); + return queryWeight * queryWeight; + } + + @Override + public void normalize(float norm, float topLevelBoost) { + this.queryNorm = norm * topLevelBoost; + queryWeight *= this.queryNorm; + } + + @Override + public Scorer scorer(LeafReaderContext context, Bits acceptDocs) throws IOException { + SortedDocValues values = DocValues.getSorted(context.reader(), joinField); + if (values == null) { + return null; + } + + Scorer approximationScorer = approximationWeight.scorer(context, acceptDocs); + if (approximationScorer == null) { + return null; + } + if (globalOrds != null) { + return new OrdinalMapScorer(this, queryNorm, foundOrds, values, approximationScorer, globalOrds.getGlobalOrds(context.ord)); + } { + return new SegmentOrdinalScorer(this, queryNorm, foundOrds, values, approximationScorer); + } + } + + } + + final static class OrdinalMapScorer extends BaseGlobalOrdinalScorer { + + final LongValues segmentOrdToGlobalOrdLookup; + + public OrdinalMapScorer(Weight weight, float score, LongBitSet foundOrds, SortedDocValues values, Scorer approximationScorer, LongValues segmentOrdToGlobalOrdLookup) { + super(weight, foundOrds, values, approximationScorer); + this.score = score; + this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup; + } + + @Override + public int advance(int target) throws IOException { + for (int docID = approximationScorer.advance(target); docID < NO_MORE_DOCS; docID = approximationScorer.nextDoc()) { + final long segmentOrd = values.getOrd(docID); + if (segmentOrd != -1) { + final long globalOrd = segmentOrdToGlobalOrdLookup.get(segmentOrd); + if (foundOrds.get(globalOrd)) { + return docID; + } + } + } + return NO_MORE_DOCS; + } + + @Override + protected TwoPhaseIterator createTwoPhaseIterator(DocIdSetIterator approximation) { + return new TwoPhaseIterator(approximation) { + + @Override + public boolean matches() throws IOException { + final long segmentOrd = values.getOrd(approximationScorer.docID()); + if (segmentOrd != -1) { + final long globalOrd = segmentOrdToGlobalOrdLookup.get(segmentOrd); + if (foundOrds.get(globalOrd)) { + return true; + } + } + return false; + } + }; + } + } + + final static class SegmentOrdinalScorer extends BaseGlobalOrdinalScorer { + + public SegmentOrdinalScorer(Weight weight, float score, LongBitSet foundOrds, SortedDocValues values, Scorer approximationScorer) { + super(weight, foundOrds, values, approximationScorer); + this.score = score; + } + + @Override + public int advance(int target) throws IOException { + for (int docID = approximationScorer.advance(target); docID < NO_MORE_DOCS; docID = approximationScorer.nextDoc()) { + final long segmentOrd = values.getOrd(docID); + if (segmentOrd != -1) { + if (foundOrds.get(segmentOrd)) { + return docID; + } + } + } + return NO_MORE_DOCS; + } + + @Override + protected TwoPhaseIterator createTwoPhaseIterator(DocIdSetIterator approximation) { + return new TwoPhaseIterator(approximation) { + + @Override + public boolean matches() throws IOException { + final long segmentOrd = values.getOrd(approximationScorer.docID()); + if (segmentOrd != -1) { + if (foundOrds.get(segmentOrd)) { + return true; + } + } + return false; + } + }; + } + + } +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreCollector.java b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreCollector.java new file mode 100644 index 00000000000..37ac699280f --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreCollector.java @@ -0,0 +1,250 @@ +package org.apache.lucene.search.join; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.MultiDocValues; +import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.util.LongBitSet; +import org.apache.lucene.util.LongValues; + +import java.io.IOException; + +abstract class GlobalOrdinalsWithScoreCollector implements Collector { + + final String field; + final MultiDocValues.OrdinalMap ordinalMap; + final LongBitSet collectedOrds; + protected final Scores scores; + + GlobalOrdinalsWithScoreCollector(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) { + if (valueCount > Integer.MAX_VALUE) { + // We simply don't support more than + throw new IllegalStateException("Can't collect more than [" + Integer.MAX_VALUE + "] ids"); + } + this.field = field; + this.ordinalMap = ordinalMap; + this.collectedOrds = new LongBitSet(valueCount); + this.scores = new Scores(valueCount); + } + + public LongBitSet getCollectorOrdinals() { + return collectedOrds; + } + + public float score(int globalOrdinal) { + return scores.getScore(globalOrdinal); + } + + protected abstract void doScore(int globalOrd, float existingScore, float newScore); + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + SortedDocValues docTermOrds = DocValues.getSorted(context.reader(), field); + if (ordinalMap != null) { + LongValues segmentOrdToGlobalOrdLookup = ordinalMap.getGlobalOrds(context.ord); + return new OrdinalMapCollector(docTermOrds, segmentOrdToGlobalOrdLookup); + } else { + return new SegmentOrdinalCollector(docTermOrds); + } + } + + @Override + public boolean needsScores() { + return true; + } + + final class OrdinalMapCollector implements LeafCollector { + + private final SortedDocValues docTermOrds; + private final LongValues segmentOrdToGlobalOrdLookup; + private Scorer scorer; + + OrdinalMapCollector(SortedDocValues docTermOrds, LongValues segmentOrdToGlobalOrdLookup) { + this.docTermOrds = docTermOrds; + this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup; + } + + @Override + public void collect(int doc) throws IOException { + final long segmentOrd = docTermOrds.getOrd(doc); + if (segmentOrd != -1) { + final int globalOrd = (int) segmentOrdToGlobalOrdLookup.get(segmentOrd); + collectedOrds.set(globalOrd); + float existingScore = scores.getScore(globalOrd); + float newScore = scorer.score(); + doScore(globalOrd, existingScore, newScore); + } + } + + @Override + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + } + + final class SegmentOrdinalCollector implements LeafCollector { + + private final SortedDocValues docTermOrds; + private Scorer scorer; + + SegmentOrdinalCollector(SortedDocValues docTermOrds) { + this.docTermOrds = docTermOrds; + } + + @Override + public void collect(int doc) throws IOException { + final int segmentOrd = docTermOrds.getOrd(doc); + if (segmentOrd != -1) { + collectedOrds.set(segmentOrd); + float existingScore = scores.getScore(segmentOrd); + float newScore = scorer.score(); + doScore(segmentOrd, existingScore, newScore); + } + } + + @Override + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + } + + static final class Max extends GlobalOrdinalsWithScoreCollector { + + public Max(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) { + super(field, ordinalMap, valueCount); + } + + @Override + protected void doScore(int globalOrd, float existingScore, float newScore) { + scores.setScore(globalOrd, Math.max(existingScore, newScore)); + } + + } + + static final class Sum extends GlobalOrdinalsWithScoreCollector { + + public Sum(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) { + super(field, ordinalMap, valueCount); + } + + @Override + protected void doScore(int globalOrd, float existingScore, float newScore) { + scores.setScore(globalOrd, existingScore + newScore); + } + + } + + static final class Avg extends GlobalOrdinalsWithScoreCollector { + + private final Occurrences occurrences; + + public Avg(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) { + super(field, ordinalMap, valueCount); + this.occurrences = new Occurrences(valueCount); + } + + @Override + protected void doScore(int globalOrd, float existingScore, float newScore) { + occurrences.increment(globalOrd); + scores.setScore(globalOrd, existingScore + newScore); + } + + @Override + public float score(int globalOrdinal) { + return scores.getScore(globalOrdinal) / occurrences.getOccurence(globalOrdinal); + } + } + + // Because the global ordinal is directly used as a key to a score we should be somewhat smart about allocation + // the scores array. Most of the times not all docs match so splitting the scores array up in blocks can prevent creation of huge arrays. + // Also working with smaller arrays is supposed to be more gc friendly + // + // At first a hash map implementation would make sense, but in the case that more than half of docs match this becomes more expensive + // then just using an array. + + // Maybe this should become a method parameter? + static final int arraySize = 4096; + + static final class Scores { + + final float[][] blocks; + + private Scores(long valueCount) { + long blockSize = valueCount + arraySize - 1; + blocks = new float[(int) ((blockSize) / arraySize)][]; + } + + public void setScore(int globalOrdinal, float score) { + int block = globalOrdinal / arraySize; + int offset = globalOrdinal % arraySize; + float[] scores = blocks[block]; + if (scores == null) { + blocks[block] = scores = new float[arraySize]; + } + scores[offset] = score; + } + + public float getScore(int globalOrdinal) { + int block = globalOrdinal / arraySize; + int offset = globalOrdinal % arraySize; + float[] scores = blocks[block]; + float score; + if (scores != null) { + score = scores[offset]; + } else { + score = 0f; + } + return score; + } + + } + + static final class Occurrences { + + final int[][] blocks; + + private Occurrences(long valueCount) { + long blockSize = valueCount + arraySize - 1; + blocks = new int[(int) (blockSize / arraySize)][]; + } + + public void increment(int globalOrdinal) { + int block = globalOrdinal / arraySize; + int offset = globalOrdinal % arraySize; + int[] occurrences = blocks[block]; + if (occurrences == null) { + blocks[block] = occurrences = new int[arraySize]; + } + occurrences[offset]++; + } + + public int getOccurence(int globalOrdinal) { + int block = globalOrdinal / arraySize; + int offset = globalOrdinal % arraySize; + int[] occurrences = blocks[block]; + return occurrences[offset]; + } + + } + +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreQuery.java new file mode 100644 index 00000000000..f9d4df7418d --- /dev/null +++ b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreQuery.java @@ -0,0 +1,256 @@ +package org.apache.lucene.search.join; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.MultiDocValues; +import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.ComplexExplanation; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.LongValues; + +import java.io.IOException; +import java.util.Set; + +final class GlobalOrdinalsWithScoreQuery extends Query { + + private final GlobalOrdinalsWithScoreCollector collector; + private final String joinField; + private final MultiDocValues.OrdinalMap globalOrds; + // Is also an approximation of the docs that will match. Can be all docs that have toField or something more specific. + private final Query toQuery; + + // just for hashcode and equals: + private final Query fromQuery; + private final IndexReader indexReader; + + GlobalOrdinalsWithScoreQuery(GlobalOrdinalsWithScoreCollector collector, String joinField, MultiDocValues.OrdinalMap globalOrds, Query toQuery, Query fromQuery, IndexReader indexReader) { + this.collector = collector; + this.joinField = joinField; + this.globalOrds = globalOrds; + this.toQuery = toQuery; + this.fromQuery = fromQuery; + this.indexReader = indexReader; + } + + @Override + public Weight createWeight(IndexSearcher searcher, boolean needsScores) throws IOException { + return new W(this, toQuery.createWeight(searcher, false)); + } + + @Override + public void extractTerms(Set terms) { + fromQuery.extractTerms(terms); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + if (!super.equals(o)) return false; + + GlobalOrdinalsWithScoreQuery that = (GlobalOrdinalsWithScoreQuery) o; + + if (!fromQuery.equals(that.fromQuery)) return false; + if (!joinField.equals(that.joinField)) return false; + if (!toQuery.equals(that.toQuery)) return false; + if (!indexReader.equals(that.indexReader)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + joinField.hashCode(); + result = 31 * result + toQuery.hashCode(); + result = 31 * result + fromQuery.hashCode(); + result = 31 * result + indexReader.hashCode(); + return result; + } + + @Override + public String toString(String field) { + return "GlobalOrdinalsQuery{" + + "joinField=" + joinField + + '}'; + } + + final class W extends Weight { + + private final Weight approximationWeight; + + private float queryNorm; + private float queryWeight; + + W(Query query, Weight approximationWeight) { + super(query); + this.approximationWeight = approximationWeight; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + SortedDocValues values = DocValues.getSorted(context.reader(), joinField); + if (values != null) { + int segmentOrd = values.getOrd(doc); + if (segmentOrd != -1) { + final float score; + if (globalOrds != null) { + long globalOrd = globalOrds.getGlobalOrds(context.ord).get(segmentOrd); + score = collector.scores.getScore((int) globalOrd); + } else { + score = collector.score(segmentOrd); + } + BytesRef joinValue = values.lookupOrd(segmentOrd); + return new ComplexExplanation(true, score, "Score based on join value " + joinValue.utf8ToString()); + } + } + return new ComplexExplanation(false, 0.0f, "Not a match"); + } + + @Override + public float getValueForNormalization() throws IOException { + queryWeight = getBoost(); + return queryWeight * queryWeight; + } + + @Override + public void normalize(float norm, float topLevelBoost) { + this.queryNorm = norm * topLevelBoost; + queryWeight *= this.queryNorm; + } + + @Override + public Scorer scorer(LeafReaderContext context, Bits acceptDocs) throws IOException { + SortedDocValues values = DocValues.getSorted(context.reader(), joinField); + if (values == null) { + return null; + } + + Scorer approximationScorer = approximationWeight.scorer(context, acceptDocs); + if (approximationScorer == null) { + return null; + } else if (globalOrds != null) { + return new OrdinalMapScorer(this, collector, values, approximationScorer, globalOrds.getGlobalOrds(context.ord)); + } else { + return new SegmentOrdinalScorer(this, collector, values, approximationScorer); + } + } + + } + + final static class OrdinalMapScorer extends BaseGlobalOrdinalScorer { + + final LongValues segmentOrdToGlobalOrdLookup; + final GlobalOrdinalsWithScoreCollector collector; + + public OrdinalMapScorer(Weight weight, GlobalOrdinalsWithScoreCollector collector, SortedDocValues values, Scorer approximationScorer, LongValues segmentOrdToGlobalOrdLookup) { + super(weight, collector.getCollectorOrdinals(), values, approximationScorer); + this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup; + this.collector = collector; + } + + @Override + public int advance(int target) throws IOException { + for (int docID = approximationScorer.advance(target); docID < NO_MORE_DOCS; docID = approximationScorer.nextDoc()) { + final long segmentOrd = values.getOrd(docID); + if (segmentOrd != -1) { + final long globalOrd = segmentOrdToGlobalOrdLookup.get(segmentOrd); + if (foundOrds.get(globalOrd)) { + score = collector.score((int) globalOrd); + return docID; + } + } + } + return NO_MORE_DOCS; + } + + @Override + protected TwoPhaseIterator createTwoPhaseIterator(DocIdSetIterator approximation) { + return new TwoPhaseIterator(approximation) { + + @Override + public boolean matches() throws IOException { + final long segmentOrd = values.getOrd(approximationScorer.docID()); + if (segmentOrd != -1) { + final long globalOrd = segmentOrdToGlobalOrdLookup.get(segmentOrd); + if (foundOrds.get(globalOrd)) { + score = collector.score((int) globalOrd); + return true; + } + } + return false; + } + + }; + } + } + + final static class SegmentOrdinalScorer extends BaseGlobalOrdinalScorer { + + final GlobalOrdinalsWithScoreCollector collector; + + public SegmentOrdinalScorer(Weight weight, GlobalOrdinalsWithScoreCollector collector, SortedDocValues values, Scorer approximationScorer) { + super(weight, collector.getCollectorOrdinals(), values, approximationScorer); + this.collector = collector; + } + + @Override + public int advance(int target) throws IOException { + for (int docID = approximationScorer.advance(target); docID < NO_MORE_DOCS; docID = approximationScorer.nextDoc()) { + final int segmentOrd = values.getOrd(docID); + if (segmentOrd != -1) { + if (foundOrds.get(segmentOrd)) { + score = collector.score(segmentOrd); + return docID; + } + } + } + return NO_MORE_DOCS; + } + + @Override + protected TwoPhaseIterator createTwoPhaseIterator(DocIdSetIterator approximation) { + return new TwoPhaseIterator(approximation) { + + @Override + public boolean matches() throws IOException { + final int segmentOrd = values.getOrd(approximationScorer.docID()); + if (segmentOrd != -1) { + if (foundOrds.get(segmentOrd)) { + score = collector.score(segmentOrd); + return true; + } + } + return false; + } + }; + } + } +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java b/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java index 44abe7bc464..89ac5089797 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/JoinUtil.java @@ -17,7 +17,12 @@ package org.apache.lucene.search.join; * limitations under the License. */ +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.MultiDocValues; +import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import java.io.IOException; @@ -90,4 +95,78 @@ public final class JoinUtil { } } + /** + * A query time join using global ordinals over a dedicated join field. + * + * This join has certain restrictions and requirements: + * 1) A document can only refer to one other document. (but can be referred by one or more documents) + * 2) Documents on each side of the join must be distinguishable. Typically this can be done by adding an extra field + * that identifies the "from" and "to" side and then the fromQuery and toQuery must take the this into account. + * 3) There must be a single sorted doc values join field used by both the "from" and "to" documents. This join field + * should store the join values as UTF-8 strings. + * 4) An ordinal map must be provided that is created on top of the join field. + * + * @param joinField The {@link org.apache.lucene.index.SortedDocValues} field containing the join values + * @param fromQuery The query containing the actual user query. Also the fromQuery can only match "from" documents. + * @param toQuery The query identifying all documents on the "to" side. + * @param searcher The index searcher used to execute the from query + * @param scoreMode Instructs how scores from the fromQuery are mapped to the returned query + * @param ordinalMap The ordinal map constructed over the joinField. In case of a single segment index, no ordinal map + * needs to be provided. + * @return a {@link Query} instance that can be used to join documents based on the join field + * @throws IOException If I/O related errors occur + */ + public static Query createJoinQuery(String joinField, + Query fromQuery, + Query toQuery, + IndexSearcher searcher, + ScoreMode scoreMode, + MultiDocValues.OrdinalMap ordinalMap) throws IOException { + IndexReader indexReader = searcher.getIndexReader(); + int numSegments = indexReader.leaves().size(); + final long valueCount; + if (numSegments == 0) { + return new MatchNoDocsQuery(); + } else if (numSegments == 1) { + // No need to use the ordinal map, because there is just one segment. + ordinalMap = null; + LeafReader leafReader = searcher.getIndexReader().leaves().get(0).reader(); + SortedDocValues joinSortedDocValues = leafReader.getSortedDocValues(joinField); + if (joinSortedDocValues != null) { + valueCount = joinSortedDocValues.getValueCount(); + } else { + return new MatchNoDocsQuery(); + } + } else { + if (ordinalMap == null) { + throw new IllegalArgumentException("OrdinalMap is required, because there is more than 1 segment"); + } + valueCount = ordinalMap.getValueCount(); + } + + Query rewrittenFromQuery = searcher.rewrite(fromQuery); + if (scoreMode == ScoreMode.None) { + GlobalOrdinalsCollector globalOrdinalsCollector = new GlobalOrdinalsCollector(joinField, ordinalMap, valueCount); + searcher.search(fromQuery, globalOrdinalsCollector); + return new GlobalOrdinalsQuery(globalOrdinalsCollector.getCollectorOrdinals(), joinField, ordinalMap, toQuery, rewrittenFromQuery, indexReader); + } + + GlobalOrdinalsWithScoreCollector globalOrdinalsWithScoreCollector; + switch (scoreMode) { + case Total: + globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Sum(joinField, ordinalMap, valueCount); + break; + case Max: + globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Max(joinField, ordinalMap, valueCount); + break; + case Avg: + globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Avg(joinField, ordinalMap, valueCount); + break; + default: + throw new IllegalArgumentException(String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode)); + } + searcher.search(fromQuery, globalOrdinalsWithScoreCollector); + return new GlobalOrdinalsWithScoreQuery(globalOrdinalsWithScoreCollector, joinField, ordinalMap, toQuery, rewrittenFromQuery, indexReader); + } + } diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java b/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java index 46fa0c1979b..bb2fb5ff0ac 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java @@ -17,19 +17,6 @@ package org.apache.lucene.search.join; * limitations under the License. */ -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Set; -import java.util.SortedSet; -import java.util.TreeSet; - import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.analysis.MockTokenizer; import org.apache.lucene.document.Document; @@ -38,27 +25,29 @@ import org.apache.lucene.document.SortedDocValuesField; import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.document.TextField; import org.apache.lucene.index.BinaryDocValues; +import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DocValues; -import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.MultiDocValues; import org.apache.lucene.index.MultiFields; +import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.SlowCompositeReaderWrapper; +import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.index.Term; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.Collector; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.FilterLeafCollector; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MultiCollector; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.Scorer; @@ -74,8 +63,22 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.TestUtil; +import org.apache.lucene.util.packed.PackedInts; import org.junit.Test; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; + public class TestJoinUtil extends LuceneTestCase { public void testSimple() throws Exception { @@ -169,6 +172,180 @@ public class TestJoinUtil extends LuceneTestCase { dir.close(); } + public void testSimpleOrdinalsJoin() throws Exception { + final String idField = "id"; + final String productIdField = "productId"; + // A field indicating to what type a document belongs, which is then used to distinques between documents during joining. + final String typeField = "type"; + // A single sorted doc values field that holds the join values for all document types. + // Typically during indexing a schema will automatically create this field with the values + final String joinField = idField + productIdField; + + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter( + random(), + dir, + newIndexWriterConfig(new MockAnalyzer(random())).setMergePolicy(NoMergePolicy.INSTANCE)); + + // 0 + Document doc = new Document(); + doc.add(new TextField(idField, "1", Field.Store.NO)); + doc.add(new TextField(typeField, "product", Field.Store.NO)); + doc.add(new TextField("description", "random text", Field.Store.NO)); + doc.add(new TextField("name", "name1", Field.Store.NO)); + doc.add(new SortedDocValuesField(joinField, new BytesRef("1"))); + w.addDocument(doc); + + // 1 + doc = new Document(); + doc.add(new TextField(productIdField, "1", Field.Store.NO)); + doc.add(new TextField(typeField, "price", Field.Store.NO)); + doc.add(new TextField("price", "10.0", Field.Store.NO)); + doc.add(new SortedDocValuesField(joinField, new BytesRef("1"))); + w.addDocument(doc); + + // 2 + doc = new Document(); + doc.add(new TextField(productIdField, "1", Field.Store.NO)); + doc.add(new TextField(typeField, "price", Field.Store.NO)); + doc.add(new TextField("price", "20.0", Field.Store.NO)); + doc.add(new SortedDocValuesField(joinField, new BytesRef("1"))); + w.addDocument(doc); + + // 3 + doc = new Document(); + doc.add(new TextField(idField, "2", Field.Store.NO)); + doc.add(new TextField(typeField, "product", Field.Store.NO)); + doc.add(new TextField("description", "more random text", Field.Store.NO)); + doc.add(new TextField("name", "name2", Field.Store.NO)); + doc.add(new SortedDocValuesField(joinField, new BytesRef("2"))); + w.addDocument(doc); + w.commit(); + + // 4 + doc = new Document(); + doc.add(new TextField(productIdField, "2", Field.Store.NO)); + doc.add(new TextField(typeField, "price", Field.Store.NO)); + doc.add(new TextField("price", "10.0", Field.Store.NO)); + doc.add(new SortedDocValuesField(joinField, new BytesRef("2"))); + w.addDocument(doc); + + // 5 + doc = new Document(); + doc.add(new TextField(productIdField, "2", Field.Store.NO)); + doc.add(new TextField(typeField, "price", Field.Store.NO)); + doc.add(new TextField("price", "20.0", Field.Store.NO)); + doc.add(new SortedDocValuesField(joinField, new BytesRef("2"))); + w.addDocument(doc); + + IndexSearcher indexSearcher = new IndexSearcher(w.getReader()); + w.close(); + + IndexReader r = indexSearcher.getIndexReader(); + SortedDocValues[] values = new SortedDocValues[r.leaves().size()]; + for (int i = 0; i < values.length; i++) { + LeafReader leafReader = r.leaves().get(i).reader(); + values[i] = DocValues.getSorted(leafReader, joinField); + } + MultiDocValues.OrdinalMap ordinalMap = MultiDocValues.OrdinalMap.build( + r.getCoreCacheKey(), values, PackedInts.DEFAULT + ); + + Query toQuery = new TermQuery(new Term(typeField, "price")); + Query fromQuery = new TermQuery(new Term("name", "name2")); + // Search for product and return prices + Query joinQuery = JoinUtil.createJoinQuery(joinField, fromQuery, toQuery, indexSearcher, ScoreMode.None, ordinalMap); + TopDocs result = indexSearcher.search(joinQuery, 10); + assertEquals(2, result.totalHits); + assertEquals(4, result.scoreDocs[0].doc); + assertEquals(5, result.scoreDocs[1].doc); + + fromQuery = new TermQuery(new Term("name", "name1")); + joinQuery = JoinUtil.createJoinQuery(joinField, fromQuery, toQuery, indexSearcher, ScoreMode.None, ordinalMap); + result = indexSearcher.search(joinQuery, 10); + assertEquals(2, result.totalHits); + assertEquals(1, result.scoreDocs[0].doc); + assertEquals(2, result.scoreDocs[1].doc); + + // Search for prices and return products + fromQuery = new TermQuery(new Term("price", "20.0")); + toQuery = new TermQuery(new Term(typeField, "product")); + joinQuery = JoinUtil.createJoinQuery(joinField, fromQuery, toQuery, indexSearcher, ScoreMode.None, ordinalMap); + result = indexSearcher.search(joinQuery, 10); + assertEquals(2, result.totalHits); + assertEquals(0, result.scoreDocs[0].doc); + assertEquals(3, result.scoreDocs[1].doc); + + indexSearcher.getIndexReader().close(); + dir.close(); + } + + public void testRandomOrdinalsJoin() throws Exception { + Directory dir = newDirectory(); + RandomIndexWriter w = new RandomIndexWriter( + random(), + dir, + newIndexWriterConfig(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false)).setMergePolicy(newLogMergePolicy()) + ); + IndexIterationContext context = createContext(100, w, false, true); + + w.forceMerge(1); + + w.close(); + IndexReader topLevelReader = DirectoryReader.open(dir); + + SortedDocValues[] values = new SortedDocValues[topLevelReader.leaves().size()]; + for (LeafReaderContext leadContext : topLevelReader.leaves()) { + values[leadContext.ord] = DocValues.getSorted(leadContext.reader(), "join_field"); + } + context.ordinalMap = MultiDocValues.OrdinalMap.build( + topLevelReader.getCoreCacheKey(), values, PackedInts.DEFAULT + ); + IndexSearcher indexSearcher = newSearcher(topLevelReader); + + int r = random().nextInt(context.randomUniqueValues.length); + boolean from = context.randomFrom[r]; + String randomValue = context.randomUniqueValues[r]; + BitSet expectedResult = createExpectedResult(randomValue, from, indexSearcher.getIndexReader(), context); + + final Query actualQuery = new TermQuery(new Term("value", randomValue)); + if (VERBOSE) { + System.out.println("actualQuery=" + actualQuery); + } + final ScoreMode scoreMode = ScoreMode.values()[random().nextInt(ScoreMode.values().length)]; + if (VERBOSE) { + System.out.println("scoreMode=" + scoreMode); + } + + final Query joinQuery; + if (from) { + BooleanQuery fromQuery = new BooleanQuery(); + fromQuery.add(new TermQuery(new Term("type", "from")), BooleanClause.Occur.FILTER); + fromQuery.add(actualQuery, BooleanClause.Occur.MUST); + Query toQuery = new TermQuery(new Term("type", "to")); + joinQuery = JoinUtil.createJoinQuery("join_field", fromQuery, toQuery, indexSearcher, scoreMode, context.ordinalMap); + } else { + BooleanQuery fromQuery = new BooleanQuery(); + fromQuery.add(new TermQuery(new Term("type", "to")), BooleanClause.Occur.FILTER); + fromQuery.add(actualQuery, BooleanClause.Occur.MUST); + Query toQuery = new TermQuery(new Term("type", "from")); + joinQuery = JoinUtil.createJoinQuery("join_field", fromQuery, toQuery, indexSearcher, scoreMode, context.ordinalMap); + } + if (VERBOSE) { + System.out.println("joinQuery=" + joinQuery); + } + + final BitSet actualResult = new FixedBitSet(indexSearcher.getIndexReader().maxDoc()); + final TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(10); + indexSearcher.search(joinQuery, MultiCollector.wrap(new BitSetCollector(actualResult), topScoreDocCollector)); + assertBitSet(expectedResult, actualResult, indexSearcher); + TopDocs expectedTopDocs = createExpectedTopDocs(randomValue, from, scoreMode, context); + TopDocs actualTopDocs = topScoreDocCollector.topDocs(); + assertTopDocs(expectedTopDocs, actualTopDocs, scoreMode, indexSearcher, joinQuery); + topLevelReader.close(); + dir.close(); + } + // TermsWithScoreCollector.MV.Avg forgets to grow beyond TermsWithScoreCollector.INITIAL_ARRAY_SIZE public void testOverflowTermsWithScoreCollector() throws Exception { test300spartans(true, ScoreMode.Avg); @@ -218,7 +395,7 @@ public class TestJoinUtil extends LuceneTestCase { TopDocs result = indexSearcher.search(joinQuery, 10); assertEquals(1, result.totalHits); assertEquals(0, result.scoreDocs[0].doc); - + indexSearcher.getIndexReader().close(); dir.close(); @@ -310,7 +487,7 @@ public class TestJoinUtil extends LuceneTestCase { assertFalse("optimized bulkScorer was not used for join query embedded in boolean query!", sawFive); } } - + @Override public boolean needsScores() { return false; @@ -448,7 +625,7 @@ public class TestJoinUtil extends LuceneTestCase { dir, newIndexWriterConfig(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false)).setMergePolicy(newLogMergePolicy()) ); - IndexIterationContext context = createContext(numberOfDocumentsToIndex, w, multipleValuesPerDocument); + IndexIterationContext context = createContext(numberOfDocumentsToIndex, w, multipleValuesPerDocument, false); IndexReader topLevelReader = w.getReader(); w.close(); @@ -485,73 +662,64 @@ public class TestJoinUtil extends LuceneTestCase { // Need to know all documents that have matches. TopDocs doesn't give me that and then I'd be also testing TopDocsCollector... final BitSet actualResult = new FixedBitSet(indexSearcher.getIndexReader().maxDoc()); final TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(10); - indexSearcher.search(joinQuery, new Collector() { - - @Override - public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { - final int docBase = context.docBase; - final LeafCollector in = topScoreDocCollector.getLeafCollector(context); - return new FilterLeafCollector(in) { - - @Override - public void collect(int doc) throws IOException { - super.collect(doc); - actualResult.set(doc + docBase); - } - }; - } - - @Override - public boolean needsScores() { - return topScoreDocCollector.needsScores(); - } - }); + indexSearcher.search(joinQuery, MultiCollector.wrap(new BitSetCollector(actualResult), topScoreDocCollector)); // Asserting bit set... - if (VERBOSE) { - System.out.println("expected cardinality:" + expectedResult.cardinality()); - DocIdSetIterator iterator = new BitSetIterator(expectedResult, expectedResult.cardinality()); - for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) { - System.out.println(String.format(Locale.ROOT, "Expected doc[%d] with id value %s", doc, indexSearcher.doc(doc).get("id"))); - } - System.out.println("actual cardinality:" + actualResult.cardinality()); - iterator = new BitSetIterator(actualResult, actualResult.cardinality()); - for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) { - System.out.println(String.format(Locale.ROOT, "Actual doc[%d] with id value %s", doc, indexSearcher.doc(doc).get("id"))); - } - } - assertEquals(expectedResult, actualResult); - + assertBitSet(expectedResult, actualResult, indexSearcher); // Asserting TopDocs... TopDocs expectedTopDocs = createExpectedTopDocs(randomValue, from, scoreMode, context); TopDocs actualTopDocs = topScoreDocCollector.topDocs(); - assertEquals(expectedTopDocs.totalHits, actualTopDocs.totalHits); - assertEquals(expectedTopDocs.scoreDocs.length, actualTopDocs.scoreDocs.length); - if (scoreMode == ScoreMode.None) { - continue; - } - - assertEquals(expectedTopDocs.getMaxScore(), actualTopDocs.getMaxScore(), 0.0f); - for (int i = 0; i < expectedTopDocs.scoreDocs.length; i++) { - if (VERBOSE) { - System.out.printf(Locale.ENGLISH, "Expected doc: %d | Actual doc: %d\n", expectedTopDocs.scoreDocs[i].doc, actualTopDocs.scoreDocs[i].doc); - System.out.printf(Locale.ENGLISH, "Expected score: %f | Actual score: %f\n", expectedTopDocs.scoreDocs[i].score, actualTopDocs.scoreDocs[i].score); - } - assertEquals(expectedTopDocs.scoreDocs[i].doc, actualTopDocs.scoreDocs[i].doc); - assertEquals(expectedTopDocs.scoreDocs[i].score, actualTopDocs.scoreDocs[i].score, 0.0f); - Explanation explanation = indexSearcher.explain(joinQuery, expectedTopDocs.scoreDocs[i].doc); - assertEquals(expectedTopDocs.scoreDocs[i].score, explanation.getValue(), 0.0f); - } + assertTopDocs(expectedTopDocs, actualTopDocs, scoreMode, indexSearcher, joinQuery); } topLevelReader.close(); dir.close(); } } - private IndexIterationContext createContext(int nDocs, RandomIndexWriter writer, boolean multipleValuesPerDocument) throws IOException { - return createContext(nDocs, writer, writer, multipleValuesPerDocument); + private void assertBitSet(BitSet expectedResult, BitSet actualResult, IndexSearcher indexSearcher) throws IOException { + if (VERBOSE) { + System.out.println("expected cardinality:" + expectedResult.cardinality()); + DocIdSetIterator iterator = new BitSetIterator(expectedResult, expectedResult.cardinality()); + for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) { + System.out.println(String.format(Locale.ROOT, "Expected doc[%d] with id value %s", doc, indexSearcher.doc(doc).get("id"))); + } + System.out.println("actual cardinality:" + actualResult.cardinality()); + iterator = new BitSetIterator(actualResult, actualResult.cardinality()); + for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) { + System.out.println(String.format(Locale.ROOT, "Actual doc[%d] with id value %s", doc, indexSearcher.doc(doc).get("id"))); + } + } + assertEquals(expectedResult, actualResult); } - private IndexIterationContext createContext(int nDocs, RandomIndexWriter fromWriter, RandomIndexWriter toWriter, boolean multipleValuesPerDocument) throws IOException { + private void assertTopDocs(TopDocs expectedTopDocs, TopDocs actualTopDocs, ScoreMode scoreMode, IndexSearcher indexSearcher, Query joinQuery) throws IOException { + assertEquals(expectedTopDocs.totalHits, actualTopDocs.totalHits); + assertEquals(expectedTopDocs.scoreDocs.length, actualTopDocs.scoreDocs.length); + if (scoreMode == ScoreMode.None) { + return; + } + + assertEquals(expectedTopDocs.getMaxScore(), actualTopDocs.getMaxScore(), 0.0f); + for (int i = 0; i < expectedTopDocs.scoreDocs.length; i++) { + if (VERBOSE) { + System.out.printf(Locale.ENGLISH, "Expected doc: %d | Actual doc: %d\n", expectedTopDocs.scoreDocs[i].doc, actualTopDocs.scoreDocs[i].doc); + System.out.printf(Locale.ENGLISH, "Expected score: %f | Actual score: %f\n", expectedTopDocs.scoreDocs[i].score, actualTopDocs.scoreDocs[i].score); + } + assertEquals(expectedTopDocs.scoreDocs[i].doc, actualTopDocs.scoreDocs[i].doc); + assertEquals(expectedTopDocs.scoreDocs[i].score, actualTopDocs.scoreDocs[i].score, 0.0f); + Explanation explanation = indexSearcher.explain(joinQuery, expectedTopDocs.scoreDocs[i].doc); + assertEquals(expectedTopDocs.scoreDocs[i].score, explanation.getValue(), 0.0f); + } + } + + private IndexIterationContext createContext(int nDocs, RandomIndexWriter writer, boolean multipleValuesPerDocument, boolean ordinalJoin) throws IOException { + return createContext(nDocs, writer, writer, multipleValuesPerDocument, ordinalJoin); + } + + private IndexIterationContext createContext(int nDocs, RandomIndexWriter fromWriter, RandomIndexWriter toWriter, boolean multipleValuesPerDocument, boolean globalOrdinalJoin) throws IOException { + if (globalOrdinalJoin) { + assertFalse("ordinal join doesn't support multiple join values per document", multipleValuesPerDocument); + } + IndexIterationContext context = new IndexIterationContext(); int numRandomValues = nDocs / 2; context.randomUniqueValues = new String[numRandomValues]; @@ -560,8 +728,8 @@ public class TestJoinUtil extends LuceneTestCase { for (int i = 0; i < numRandomValues; i++) { String uniqueRandomValue; do { - uniqueRandomValue = TestUtil.randomRealisticUnicodeString(random()); -// uniqueRandomValue = _TestUtil.randomSimpleString(random); +// uniqueRandomValue = TestUtil.randomRealisticUnicodeString(random()); + uniqueRandomValue = TestUtil.randomSimpleString(random()); } while ("".equals(uniqueRandomValue) || trackSet.contains(uniqueRandomValue)); // Generate unique values and empty strings aren't allowed. trackSet.add(uniqueRandomValue); @@ -581,15 +749,18 @@ public class TestJoinUtil extends LuceneTestCase { boolean from = context.randomFrom[randomI]; int numberOfLinkValues = multipleValuesPerDocument ? 2 + random().nextInt(10) : 1; docs[i] = new RandomDoc(id, numberOfLinkValues, value, from); + if (globalOrdinalJoin) { + document.add(newStringField("type", from ? "from" : "to", Field.Store.NO)); + } for (int j = 0; j < numberOfLinkValues; j++) { String linkValue = context.randomUniqueValues[random().nextInt(context.randomUniqueValues.length)]; docs[i].linkValues.add(linkValue); if (from) { if (!context.fromDocuments.containsKey(linkValue)) { - context.fromDocuments.put(linkValue, new ArrayList()); + context.fromDocuments.put(linkValue, new ArrayList<>()); } if (!context.randomValueFromDocs.containsKey(value)) { - context.randomValueFromDocs.put(value, new ArrayList()); + context.randomValueFromDocs.put(value, new ArrayList<>()); } context.fromDocuments.get(linkValue).add(docs[i]); @@ -600,12 +771,15 @@ public class TestJoinUtil extends LuceneTestCase { } else { document.add(new SortedDocValuesField("from", new BytesRef(linkValue))); } + if (globalOrdinalJoin) { + document.add(new SortedDocValuesField("join_field", new BytesRef(linkValue))); + } } else { if (!context.toDocuments.containsKey(linkValue)) { - context.toDocuments.put(linkValue, new ArrayList()); + context.toDocuments.put(linkValue, new ArrayList<>()); } if (!context.randomValueToDocs.containsKey(value)) { - context.randomValueToDocs.put(value, new ArrayList()); + context.randomValueToDocs.put(value, new ArrayList<>()); } context.toDocuments.get(linkValue).add(docs[i]); @@ -616,6 +790,9 @@ public class TestJoinUtil extends LuceneTestCase { } else { document.add(new SortedDocValuesField("to", new BytesRef(linkValue))); } + if (globalOrdinalJoin) { + document.add(new SortedDocValuesField("join_field", new BytesRef(linkValue))); + } } } @@ -707,6 +884,9 @@ public class TestJoinUtil extends LuceneTestCase { if (joinScore == null) { joinValueToJoinScores.put(BytesRef.deepCopyOf(joinValue), joinScore = new JoinScore()); } + if (VERBOSE) { + System.out.println("expected val=" + joinValue.utf8ToString() + " expected score=" + scorer.score()); + } joinScore.addScore(scorer.score()); } @@ -720,7 +900,7 @@ public class TestJoinUtil extends LuceneTestCase { public void setScorer(Scorer scorer) { this.scorer = scorer; } - + @Override public boolean needsScores() { return true; @@ -777,7 +957,7 @@ public class TestJoinUtil extends LuceneTestCase { @Override public void setScorer(Scorer scorer) {} - + @Override public boolean needsScores() { return false; @@ -875,6 +1055,7 @@ public class TestJoinUtil extends LuceneTestCase { Map> fromHitsToJoinScore = new HashMap<>(); Map> toHitsToJoinScore = new HashMap<>(); + MultiDocValues.OrdinalMap ordinalMap; } private static class RandomDoc { @@ -922,4 +1103,29 @@ public class TestJoinUtil extends LuceneTestCase { } + private static class BitSetCollector extends SimpleCollector { + + private final BitSet bitSet; + private int docBase; + + private BitSetCollector(BitSet bitSet) { + this.bitSet = bitSet; + } + + @Override + public void collect(int doc) throws IOException { + bitSet.set(docBase + doc); + } + + @Override + protected void doSetNextReader(LeafReaderContext context) throws IOException { + docBase = context.docBase; + } + + @Override + public boolean needsScores() { + return false; + } + } + }