LUCENE-6352: Added a new query time join to the join module that uses global ordinals, which is faster for subsequent joins between reopens.

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1670990 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Martijn van Groningen 2015-04-02 22:00:26 +00:00
parent 548edc5406
commit c4d9d6b3f1
8 changed files with 1334 additions and 83 deletions

View File

@ -45,6 +45,10 @@ New Features
faster intersection by avoiding loading positions in certain cases.
(Paul Elschot, Robert Muir via Mike McCandless)
* LUCENE-6352: Added a new query time join to the join module that uses
global ordinals, which is faster for subsequent joins between reopens.
(Martijn van Groningen, Adrien Grand)
Optimizations
* LUCENE-6379: IndexWriter.deleteDocuments(Query...) now detects if

View File

@ -0,0 +1,97 @@
package org.apache.lucene.search.join;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.LongBitSet;
import java.io.IOException;
/**
 * Common base of the join scorers that decide whether a document matches by looking the ordinal
 * of its join value up in a bit set of collected ordinals.
 *
 * <p>Document iteration is delegated to {@link #approximationScorer}. Subclasses provide
 * {@link #advance(int)} and the confirmation step of the two-phase iterator, both of which filter
 * the approximation through {@link #foundOrds}.
 */
abstract class BaseGlobalOrdinalScorer extends Scorer {

  /** The ordinals a document's join value must be part of in order to match. */
  final LongBitSet foundOrds;

  /** Sorted doc values used by subclasses to resolve a document's ordinal. */
  final SortedDocValues values;

  /** Supplies the candidate documents; every doc ID reported by this scorer comes from it. */
  final Scorer approximationScorer;

  /** The value handed out by {@link #score()}; assigned by subclasses. */
  float score;

  /**
   * @param weight the weight that created this scorer
   * @param foundOrds the ordinals that count as a match
   * @param values the sorted doc values to read ordinals from
   * @param approximationScorer the scorer whose documents are the candidates
   */
  public BaseGlobalOrdinalScorer(Weight weight, LongBitSet foundOrds, SortedDocValues values, Scorer approximationScorer) {
    super(weight);
    this.approximationScorer = approximationScorer;
    this.values = values;
    this.foundOrds = foundOrds;
  }

  /**
   * Builds the two-phase iterator whose {@code matches()} confirms a candidate document.
   *
   * @param approximation iterator over the candidate documents
   */
  protected abstract TwoPhaseIterator createTwoPhaseIterator(DocIdSetIterator approximation);

  @Override
  public TwoPhaseIterator asTwoPhaseIterator() {
    return createTwoPhaseIterator(new ApproximationIterator());
  }

  @Override
  public int docID() {
    return approximationScorer.docID();
  }

  @Override
  public int nextDoc() throws IOException {
    // Moving to the next match is the same as advancing just past the current document.
    final int next = approximationScorer.docID() + 1;
    return advance(next);
  }

  @Override
  public float score() throws IOException {
    return score;
  }

  @Override
  public int freq() throws IOException {
    return 1;
  }

  @Override
  public long cost() {
    return approximationScorer.cost();
  }

  /** Plain view on {@link #approximationScorer} that forwards every call unchanged. */
  private final class ApproximationIterator extends DocIdSetIterator {

    @Override
    public int docID() {
      return approximationScorer.docID();
    }

    @Override
    public int nextDoc() throws IOException {
      return approximationScorer.nextDoc();
    }

    @Override
    public int advance(int target) throws IOException {
      return approximationScorer.advance(target);
    }

    @Override
    public long cost() {
      return approximationScorer.cost();
    }
  }
}

View File

@ -0,0 +1,114 @@
package org.apache.lucene.search.join;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiDocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.util.LongBitSet;
import org.apache.lucene.util.LongValues;
import java.io.IOException;
/**
 * Collector that records, in a {@link LongBitSet}, the ordinal of the join value of every
 * document it is asked to collect. Scores are not needed and are ignored.
 *
 * <p>When an ordinal map is supplied the segment ordinals are translated to global ordinals
 * before being recorded; without one the segment ordinals are recorded as-is.
 *
 * @lucene.experimental
 */
final class GlobalOrdinalsCollector implements Collector {

  final String field;
  final LongBitSet collectedOrds;
  final MultiDocValues.OrdinalMap ordinalMap;

  /**
   * @param field name of the sorted doc values field to read ordinals from
   * @param ordinalMap segment-to-global ordinal mapping, or {@code null} to use segment ordinals
   * @param valueCount size of the bit set holding the collected ordinals
   */
  GlobalOrdinalsCollector(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) {
    this.field = field;
    this.ordinalMap = ordinalMap;
    this.collectedOrds = new LongBitSet(valueCount);
  }

  /** Returns the bit set holding every ordinal collected so far. */
  public LongBitSet getCollectorOrdinals() {
    return collectedOrds;
  }

  @Override
  public boolean needsScores() {
    return false;
  }

  @Override
  public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
    final SortedDocValues segmentValues = DocValues.getSorted(context.reader(), field);
    if (ordinalMap == null) {
      return new SegmentOrdinalCollector(segmentValues);
    }
    final LongValues toGlobalOrd = ordinalMap.getGlobalOrds(context.ord);
    return new OrdinalMapCollector(segmentValues, toGlobalOrd);
  }

  /** Leaf collector that maps each segment ordinal to its global ordinal before recording it. */
  final class OrdinalMapCollector implements LeafCollector {

    private final SortedDocValues docTermOrds;
    private final LongValues segmentOrdToGlobalOrdLookup;

    OrdinalMapCollector(SortedDocValues docTermOrds, LongValues segmentOrdToGlobalOrdLookup) {
      this.docTermOrds = docTermOrds;
      this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup;
    }

    @Override
    public void collect(int doc) throws IOException {
      final long ord = docTermOrds.getOrd(doc);
      if (ord == -1) {
        // The document has no value for the field: nothing to record.
        return;
      }
      collectedOrds.set(segmentOrdToGlobalOrdLookup.get(ord));
    }

    @Override
    public void setScorer(Scorer scorer) throws IOException {
      // Scores are not used by this collector.
    }
  }

  /** Leaf collector that records segment ordinals directly. */
  final class SegmentOrdinalCollector implements LeafCollector {

    private final SortedDocValues docTermOrds;

    SegmentOrdinalCollector(SortedDocValues docTermOrds) {
      this.docTermOrds = docTermOrds;
    }

    @Override
    public void collect(int doc) throws IOException {
      final long ord = docTermOrds.getOrd(doc);
      if (ord == -1) {
        // The document has no value for the field: nothing to record.
        return;
      }
      collectedOrds.set(ord);
    }

    @Override
    public void setScorer(Scorer scorer) throws IOException {
      // Scores are not used by this collector.
    }
  }
}

View File

@ -0,0 +1,245 @@
package org.apache.lucene.search.join;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiDocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.ComplexExplanation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LongBitSet;
import org.apache.lucene.util.LongValues;
import java.io.IOException;
import java.util.Set;
/**
 * Query that matches the documents accepted by {@code toQuery} whose join value has an ordinal
 * contained in {@code foundOrds}. Every match receives the same score.
 *
 * <p>When {@code globalOrds} is non-null, segment ordinals are translated to global ordinals
 * before they are looked up in {@code foundOrds}; otherwise segment ordinals are looked up
 * directly.
 */
final class GlobalOrdinalsQuery extends Query {

  // All the ords of matching docs found with OrdinalsCollector.
  private final LongBitSet foundOrds;
  private final String joinField;
  private final MultiDocValues.OrdinalMap globalOrds;
  // Is also an approximation of the docs that will match. Can be all docs that have toField or something more specific.
  private final Query toQuery;

  // just for hashcode and equals:
  private final Query fromQuery;
  private final IndexReader indexReader;

  /**
   * @param foundOrds ordinals that count as a match
   * @param joinField name of the sorted doc values field holding the join values
   * @param globalOrds segment-to-global ordinal mapping, or {@code null} if {@code foundOrds}
   *        holds segment ordinals
   * @param toQuery query producing the candidate documents
   * @param fromQuery query the ordinals were collected with (used for equals/hashCode/extractTerms)
   * @param indexReader reader the ordinals were collected against (used for equals/hashCode)
   */
  GlobalOrdinalsQuery(LongBitSet foundOrds, String joinField, MultiDocValues.OrdinalMap globalOrds, Query toQuery, Query fromQuery, IndexReader indexReader) {
    this.foundOrds = foundOrds;
    this.joinField = joinField;
    this.globalOrds = globalOrds;
    this.toQuery = toQuery;
    this.fromQuery = fromQuery;
    this.indexReader = indexReader;
  }

  @Override
  public Weight createWeight(IndexSearcher searcher, boolean needsScores) throws IOException {
    // The approximation only selects candidates, so its scores are never needed.
    return new W(this, toQuery.createWeight(searcher, false));
  }

  @Override
  public void extractTerms(Set<Term> terms) {
    fromQuery.extractTerms(terms);
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    if (!super.equals(o)) return false;

    GlobalOrdinalsQuery that = (GlobalOrdinalsQuery) o;

    if (!fromQuery.equals(that.fromQuery)) return false;
    if (!joinField.equals(that.joinField)) return false;
    if (!toQuery.equals(that.toQuery)) return false;
    if (!indexReader.equals(that.indexReader)) return false;

    return true;
  }

  @Override
  public int hashCode() {
    int result = super.hashCode();
    result = 31 * result + joinField.hashCode();
    result = 31 * result + toQuery.hashCode();
    result = 31 * result + fromQuery.hashCode();
    result = 31 * result + indexReader.hashCode();
    return result;
  }

  @Override
  public String toString(String field) {
    return "GlobalOrdinalsQuery{" +
        "joinField=" + joinField +
        '}';
  }

  /**
   * Weight that hands {@code queryNorm} to its scorers as the score of every match.
   *
   * <p>NOTE(review): {@code queryWeight} is maintained by the normalization methods but is not
   * read when scoring; confirm whether the boost is meant to be part of the score.
   */
  final class W extends Weight {

    private final Weight approximationWeight;

    private float queryNorm;
    private float queryWeight;

    W(Query query, Weight approximationWeight) {
      super(query);
      this.approximationWeight = approximationWeight;
    }

    /**
     * Explains {@code doc} as a match only if it has a join value whose ordinal is part of
     * {@code foundOrds}, which is the same test the scorers apply. The approximation query is
     * not consulted here.
     */
    @Override
    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
      SortedDocValues values = DocValues.getSorted(context.reader(), joinField);
      if (values != null) {
        int segmentOrd = values.getOrd(doc);
        if (segmentOrd != -1) {
          // foundOrds is keyed by global ordinal when an ordinal map is in use.
          final long ord;
          if (globalOrds != null) {
            ord = globalOrds.getGlobalOrds(context.ord).get(segmentOrd);
          } else {
            ord = segmentOrd;
          }
          if (foundOrds.get(ord)) {
            BytesRef joinValue = values.lookupOrd(segmentOrd);
            return new ComplexExplanation(true, queryNorm, "Score based on join value " + joinValue.utf8ToString());
          }
        }
      }
      return new ComplexExplanation(false, 0.0f, "Not a match");
    }

    @Override
    public float getValueForNormalization() throws IOException {
      queryWeight = getBoost();
      return queryWeight * queryWeight;
    }

    @Override
    public void normalize(float norm, float topLevelBoost) {
      this.queryNorm = norm * topLevelBoost;
      queryWeight *= this.queryNorm;
    }

    /**
     * Returns a scorer over the approximation filtered by {@code foundOrds}, or {@code null}
     * when the approximation has no scorer for this segment.
     */
    @Override
    public Scorer scorer(LeafReaderContext context, Bits acceptDocs) throws IOException {
      SortedDocValues values = DocValues.getSorted(context.reader(), joinField);
      if (values == null) {
        return null;
      }

      Scorer approximationScorer = approximationWeight.scorer(context, acceptDocs);
      if (approximationScorer == null) {
        return null;
      }
      if (globalOrds != null) {
        return new OrdinalMapScorer(this, queryNorm, foundOrds, values, approximationScorer, globalOrds.getGlobalOrds(context.ord));
      } else {
        return new SegmentOrdinalScorer(this, queryNorm, foundOrds, values, approximationScorer);
      }
    }

  }

  /** Scorer that translates segment ordinals to global ordinals before testing {@code foundOrds}. */
  final static class OrdinalMapScorer extends BaseGlobalOrdinalScorer {

    final LongValues segmentOrdToGlobalOrdLookup;

    public OrdinalMapScorer(Weight weight, float score, LongBitSet foundOrds, SortedDocValues values, Scorer approximationScorer, LongValues segmentOrdToGlobalOrdLookup) {
      super(weight, foundOrds, values, approximationScorer);
      this.score = score;
      this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup;
    }

    @Override
    public int advance(int target) throws IOException {
      // Walk the approximation from target until a candidate's ordinal is found.
      for (int docID = approximationScorer.advance(target); docID < NO_MORE_DOCS; docID = approximationScorer.nextDoc()) {
        final long segmentOrd = values.getOrd(docID);
        if (segmentOrd != -1) {
          final long globalOrd = segmentOrdToGlobalOrdLookup.get(segmentOrd);
          if (foundOrds.get(globalOrd)) {
            return docID;
          }
        }
      }
      return NO_MORE_DOCS;
    }

    @Override
    protected TwoPhaseIterator createTwoPhaseIterator(DocIdSetIterator approximation) {
      return new TwoPhaseIterator(approximation) {

        @Override
        public boolean matches() throws IOException {
          final long segmentOrd = values.getOrd(approximationScorer.docID());
          if (segmentOrd != -1) {
            final long globalOrd = segmentOrdToGlobalOrdLookup.get(segmentOrd);
            if (foundOrds.get(globalOrd)) {
              return true;
            }
          }
          return false;
        }
      };
    }

  }

  /** Scorer that tests segment ordinals against {@code foundOrds} directly. */
  final static class SegmentOrdinalScorer extends BaseGlobalOrdinalScorer {

    public SegmentOrdinalScorer(Weight weight, float score, LongBitSet foundOrds, SortedDocValues values, Scorer approximationScorer) {
      super(weight, foundOrds, values, approximationScorer);
      this.score = score;
    }

    @Override
    public int advance(int target) throws IOException {
      // Walk the approximation from target until a candidate's ordinal is found.
      for (int docID = approximationScorer.advance(target); docID < NO_MORE_DOCS; docID = approximationScorer.nextDoc()) {
        final long segmentOrd = values.getOrd(docID);
        if (segmentOrd != -1) {
          if (foundOrds.get(segmentOrd)) {
            return docID;
          }
        }
      }
      return NO_MORE_DOCS;
    }

    @Override
    protected TwoPhaseIterator createTwoPhaseIterator(DocIdSetIterator approximation) {
      return new TwoPhaseIterator(approximation) {

        @Override
        public boolean matches() throws IOException {
          final long segmentOrd = values.getOrd(approximationScorer.docID());
          if (segmentOrd != -1) {
            if (foundOrds.get(segmentOrd)) {
              return true;
            }
          }
          return false;
        }
      };
    }

  }
}

View File

@ -0,0 +1,250 @@
package org.apache.lucene.search.join;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiDocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.util.LongBitSet;
import org.apache.lucene.util.LongValues;
import java.io.IOException;
/**
 * Collector that records the ordinal of the join value of every collected document together
 * with a score aggregated per ordinal. How scores of documents sharing an ordinal are combined
 * is decided by the concrete subclass ({@link Max}, {@link Sum}, {@link Avg}).
 *
 * <p>Ordinals are used as {@code int} keys into the score storage, so at most
 * {@link Integer#MAX_VALUE} distinct values are supported.
 */
abstract class GlobalOrdinalsWithScoreCollector implements Collector {

  final String field;
  final MultiDocValues.OrdinalMap ordinalMap;
  final LongBitSet collectedOrds;
  protected final Scores scores;

  /**
   * @param field name of the sorted doc values field to read ordinals from
   * @param ordinalMap segment-to-global ordinal mapping, or {@code null} to use segment ordinals
   * @param valueCount number of ordinals that can be collected
   * @throws IllegalStateException if {@code valueCount} exceeds {@link Integer#MAX_VALUE}
   */
  GlobalOrdinalsWithScoreCollector(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) {
    if (valueCount > Integer.MAX_VALUE) {
      // Ordinals are narrowed to int to address the score blocks, so more than
      // Integer.MAX_VALUE distinct values cannot be supported.
      throw new IllegalStateException("Can't collect more than [" + Integer.MAX_VALUE + "] ids");
    }
    this.field = field;
    this.ordinalMap = ordinalMap;
    this.collectedOrds = new LongBitSet(valueCount);
    this.scores = new Scores(valueCount);
  }

  /** Returns the bit set holding every ordinal collected so far. */
  public LongBitSet getCollectorOrdinals() {
    return collectedOrds;
  }

  /** Returns the aggregated score of the given ordinal. */
  public float score(int globalOrdinal) {
    return scores.getScore(globalOrdinal);
  }

  /**
   * Combines the score of a newly collected document with the score stored for its ordinal.
   *
   * @param globalOrd the ordinal of the collected document
   * @param existingScore the score currently stored for the ordinal (0 if none was stored yet)
   * @param newScore the score of the collected document
   */
  protected abstract void doScore(int globalOrd, float existingScore, float newScore);

  @Override
  public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
    SortedDocValues docTermOrds = DocValues.getSorted(context.reader(), field);
    if (ordinalMap != null) {
      LongValues segmentOrdToGlobalOrdLookup = ordinalMap.getGlobalOrds(context.ord);
      return new OrdinalMapCollector(docTermOrds, segmentOrdToGlobalOrdLookup);
    } else {
      return new SegmentOrdinalCollector(docTermOrds);
    }
  }

  @Override
  public boolean needsScores() {
    return true;
  }

  /** Leaf collector that maps each segment ordinal to its global ordinal before recording it. */
  final class OrdinalMapCollector implements LeafCollector {

    private final SortedDocValues docTermOrds;
    private final LongValues segmentOrdToGlobalOrdLookup;
    private Scorer scorer;

    OrdinalMapCollector(SortedDocValues docTermOrds, LongValues segmentOrdToGlobalOrdLookup) {
      this.docTermOrds = docTermOrds;
      this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup;
    }

    @Override
    public void collect(int doc) throws IOException {
      final long segmentOrd = docTermOrds.getOrd(doc);
      if (segmentOrd != -1) {
        // The cast is safe because the constructor rejects value counts above Integer.MAX_VALUE.
        final int globalOrd = (int) segmentOrdToGlobalOrdLookup.get(segmentOrd);
        collectedOrds.set(globalOrd);
        float existingScore = scores.getScore(globalOrd);
        float newScore = scorer.score();
        doScore(globalOrd, existingScore, newScore);
      }
    }

    @Override
    public void setScorer(Scorer scorer) throws IOException {
      this.scorer = scorer;
    }
  }

  /** Leaf collector that records segment ordinals directly. */
  final class SegmentOrdinalCollector implements LeafCollector {

    private final SortedDocValues docTermOrds;
    private Scorer scorer;

    SegmentOrdinalCollector(SortedDocValues docTermOrds) {
      this.docTermOrds = docTermOrds;
    }

    @Override
    public void collect(int doc) throws IOException {
      final int segmentOrd = docTermOrds.getOrd(doc);
      if (segmentOrd != -1) {
        collectedOrds.set(segmentOrd);
        float existingScore = scores.getScore(segmentOrd);
        float newScore = scorer.score();
        doScore(segmentOrd, existingScore, newScore);
      }
    }

    @Override
    public void setScorer(Scorer scorer) throws IOException {
      this.scorer = scorer;
    }
  }

  /** Keeps the highest score seen per ordinal. */
  static final class Max extends GlobalOrdinalsWithScoreCollector {

    public Max(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) {
      super(field, ordinalMap, valueCount);
    }

    @Override
    protected void doScore(int globalOrd, float existingScore, float newScore) {
      // NOTE(review): existingScore is 0 for an ordinal that was not scored before, so negative
      // scores end up stored as 0.
      scores.setScore(globalOrd, Math.max(existingScore, newScore));
    }

  }

  /** Keeps the sum of all scores seen per ordinal. */
  static final class Sum extends GlobalOrdinalsWithScoreCollector {

    public Sum(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) {
      super(field, ordinalMap, valueCount);
    }

    @Override
    protected void doScore(int globalOrd, float existingScore, float newScore) {
      scores.setScore(globalOrd, existingScore + newScore);
    }

  }

  /** Keeps the sum and the number of scores seen per ordinal and reports their quotient. */
  static final class Avg extends GlobalOrdinalsWithScoreCollector {

    private final Occurrences occurrences;

    public Avg(String field, MultiDocValues.OrdinalMap ordinalMap, long valueCount) {
      super(field, ordinalMap, valueCount);
      this.occurrences = new Occurrences(valueCount);
    }

    @Override
    protected void doScore(int globalOrd, float existingScore, float newScore) {
      occurrences.increment(globalOrd);
      scores.setScore(globalOrd, existingScore + newScore);
    }

    /**
     * Returns the average score of the given ordinal, or 0 for an ordinal that was never
     * collected (instead of dividing by zero).
     */
    @Override
    public float score(int globalOrdinal) {
      final int count = occurrences.getOccurence(globalOrdinal);
      if (count == 0) {
        return 0f;
      }
      return scores.getScore(globalOrdinal) / count;
    }
  }

  // Because the global ordinal is directly used as a key to a score we should be somewhat smart about allocating
  // the scores array. Most of the time not all docs match, so splitting the scores array up in blocks can prevent
  // creation of huge arrays. Also working with smaller arrays is supposed to be more gc friendly.
  //
  // At first a hash map implementation would make sense, but in the case that more than half of the docs match this
  // becomes more expensive than just using an array.

  // Maybe this should become a method parameter?
  static final int arraySize = 4096;

  /** Score per ordinal, stored in lazily allocated blocks of {@link #arraySize} entries. */
  static final class Scores {

    final float[][] blocks;

    private Scores(long valueCount) {
      // Round up so that a partially filled last block is still addressable.
      long numBlocks = (valueCount + arraySize - 1) / arraySize;
      blocks = new float[(int) numBlocks][];
    }

    /** Stores {@code score} for the ordinal, allocating the ordinal's block on first use. */
    public void setScore(int globalOrdinal, float score) {
      int block = globalOrdinal / arraySize;
      int offset = globalOrdinal % arraySize;
      float[] scores = blocks[block];
      if (scores == null) {
        blocks[block] = scores = new float[arraySize];
      }
      scores[offset] = score;
    }

    /** Returns the score stored for the ordinal, or 0 if its block was never allocated. */
    public float getScore(int globalOrdinal) {
      int block = globalOrdinal / arraySize;
      int offset = globalOrdinal % arraySize;
      float[] scores = blocks[block];
      float score;
      if (scores != null) {
        score = scores[offset];
      } else {
        score = 0f;
      }
      return score;
    }

  }

  /** Counter per ordinal, stored in lazily allocated blocks of {@link #arraySize} entries. */
  static final class Occurrences {

    final int[][] blocks;

    private Occurrences(long valueCount) {
      // Round up so that a partially filled last block is still addressable.
      long numBlocks = (valueCount + arraySize - 1) / arraySize;
      blocks = new int[(int) numBlocks][];
    }

    /** Adds one to the counter of the ordinal, allocating the ordinal's block on first use. */
    public void increment(int globalOrdinal) {
      int block = globalOrdinal / arraySize;
      int offset = globalOrdinal % arraySize;
      int[] occurrences = blocks[block];
      if (occurrences == null) {
        blocks[block] = occurrences = new int[arraySize];
      }
      occurrences[offset]++;
    }

    /** Returns the counter of the ordinal, or 0 if its block was never allocated. */
    public int getOccurence(int globalOrdinal) {
      int block = globalOrdinal / arraySize;
      int offset = globalOrdinal % arraySize;
      int[] occurrences = blocks[block];
      return occurrences == null ? 0 : occurrences[offset];
    }

  }

}

View File

@ -0,0 +1,256 @@
package org.apache.lucene.search.join;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiDocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.ComplexExplanation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.LongValues;
import java.io.IOException;
import java.util.Set;
/**
 * Query that matches the documents accepted by {@code toQuery} whose join value has an ordinal
 * recorded by the given collector, and scores each match with the score the collector holds for
 * that ordinal.
 *
 * <p>When {@code globalOrds} is non-null, segment ordinals are translated to global ordinals
 * before the collector is consulted; otherwise segment ordinals are used directly.
 */
final class GlobalOrdinalsWithScoreQuery extends Query {

  private final GlobalOrdinalsWithScoreCollector collector;
  private final String joinField;
  private final MultiDocValues.OrdinalMap globalOrds;
  // Is also an approximation of the docs that will match. Can be all docs that have toField or something more specific.
  private final Query toQuery;

  // just for hashcode and equals:
  private final Query fromQuery;
  private final IndexReader indexReader;

  /**
   * @param collector holder of the collected ordinals and their scores
   * @param joinField name of the sorted doc values field holding the join values
   * @param globalOrds segment-to-global ordinal mapping, or {@code null} if the collector is
   *        keyed by segment ordinals
   * @param toQuery query producing the candidate documents
   * @param fromQuery query the ordinals were collected with (used for equals/hashCode/extractTerms)
   * @param indexReader reader the ordinals were collected against (used for equals/hashCode)
   */
  GlobalOrdinalsWithScoreQuery(GlobalOrdinalsWithScoreCollector collector, String joinField, MultiDocValues.OrdinalMap globalOrds, Query toQuery, Query fromQuery, IndexReader indexReader) {
    this.collector = collector;
    this.joinField = joinField;
    this.globalOrds = globalOrds;
    this.toQuery = toQuery;
    this.fromQuery = fromQuery;
    this.indexReader = indexReader;
  }

  @Override
  public Weight createWeight(IndexSearcher searcher, boolean needsScores) throws IOException {
    // The approximation only selects candidates, so its scores are never needed.
    return new W(this, toQuery.createWeight(searcher, false));
  }

  @Override
  public void extractTerms(Set<Term> terms) {
    fromQuery.extractTerms(terms);
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    if (!super.equals(o)) return false;

    GlobalOrdinalsWithScoreQuery that = (GlobalOrdinalsWithScoreQuery) o;

    if (!fromQuery.equals(that.fromQuery)) return false;
    if (!joinField.equals(that.joinField)) return false;
    if (!toQuery.equals(that.toQuery)) return false;
    if (!indexReader.equals(that.indexReader)) return false;

    return true;
  }

  @Override
  public int hashCode() {
    int result = super.hashCode();
    result = 31 * result + joinField.hashCode();
    result = 31 * result + toQuery.hashCode();
    result = 31 * result + fromQuery.hashCode();
    result = 31 * result + indexReader.hashCode();
    return result;
  }

  @Override
  public String toString(String field) {
    // Report this class's own name so the output is not mistaken for GlobalOrdinalsQuery.
    return "GlobalOrdinalsWithScoreQuery{" +
        "joinField=" + joinField +
        '}';
  }

  /**
   * Weight whose scorers take their scores from the collector.
   *
   * <p>NOTE(review): {@code queryNorm} and {@code queryWeight} are maintained by the
   * normalization methods but are not read when scoring; confirm whether that is intended.
   */
  final class W extends Weight {

    private final Weight approximationWeight;

    private float queryNorm;
    private float queryWeight;

    W(Query query, Weight approximationWeight) {
      super(query);
      this.approximationWeight = approximationWeight;
    }

    /**
     * Explains {@code doc} as a match only if it has a join value whose ordinal was recorded by
     * the collector, and reports the same {@code collector.score(ord)} value the scorers use.
     * The approximation query is not consulted here.
     */
    @Override
    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
      SortedDocValues values = DocValues.getSorted(context.reader(), joinField);
      if (values != null) {
        int segmentOrd = values.getOrd(doc);
        if (segmentOrd != -1) {
          // The collector is keyed by global ordinal when an ordinal map is in use.
          final int ord;
          if (globalOrds != null) {
            ord = (int) globalOrds.getGlobalOrds(context.ord).get(segmentOrd);
          } else {
            ord = segmentOrd;
          }
          if (collector.getCollectorOrdinals().get(ord)) {
            // Go through score(int) rather than the raw score storage so that subclasses
            // overriding it (Avg) are explained with the value they actually score with.
            final float score = collector.score(ord);
            BytesRef joinValue = values.lookupOrd(segmentOrd);
            return new ComplexExplanation(true, score, "Score based on join value " + joinValue.utf8ToString());
          }
        }
      }
      return new ComplexExplanation(false, 0.0f, "Not a match");
    }

    @Override
    public float getValueForNormalization() throws IOException {
      queryWeight = getBoost();
      return queryWeight * queryWeight;
    }

    @Override
    public void normalize(float norm, float topLevelBoost) {
      this.queryNorm = norm * topLevelBoost;
      queryWeight *= this.queryNorm;
    }

    /**
     * Returns a scorer over the approximation filtered by the collected ordinals, or
     * {@code null} when the approximation has no scorer for this segment.
     */
    @Override
    public Scorer scorer(LeafReaderContext context, Bits acceptDocs) throws IOException {
      SortedDocValues values = DocValues.getSorted(context.reader(), joinField);
      if (values == null) {
        return null;
      }

      Scorer approximationScorer = approximationWeight.scorer(context, acceptDocs);
      if (approximationScorer == null) {
        return null;
      } else if (globalOrds != null) {
        return new OrdinalMapScorer(this, collector, values, approximationScorer, globalOrds.getGlobalOrds(context.ord));
      } else {
        return new SegmentOrdinalScorer(this, collector, values, approximationScorer);
      }
    }

  }

  /** Scorer that translates segment ordinals to global ordinals before consulting the collector. */
  final static class OrdinalMapScorer extends BaseGlobalOrdinalScorer {

    final LongValues segmentOrdToGlobalOrdLookup;
    final GlobalOrdinalsWithScoreCollector collector;

    public OrdinalMapScorer(Weight weight, GlobalOrdinalsWithScoreCollector collector, SortedDocValues values, Scorer approximationScorer, LongValues segmentOrdToGlobalOrdLookup) {
      super(weight, collector.getCollectorOrdinals(), values, approximationScorer);
      this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup;
      this.collector = collector;
    }

    @Override
    public int advance(int target) throws IOException {
      // Walk the approximation from target until a candidate's ordinal was collected.
      for (int docID = approximationScorer.advance(target); docID < NO_MORE_DOCS; docID = approximationScorer.nextDoc()) {
        final long segmentOrd = values.getOrd(docID);
        if (segmentOrd != -1) {
          final long globalOrd = segmentOrdToGlobalOrdLookup.get(segmentOrd);
          if (foundOrds.get(globalOrd)) {
            // Remember the score of the matched ordinal for the following score() call.
            score = collector.score((int) globalOrd);
            return docID;
          }
        }
      }
      return NO_MORE_DOCS;
    }

    @Override
    protected TwoPhaseIterator createTwoPhaseIterator(DocIdSetIterator approximation) {
      return new TwoPhaseIterator(approximation) {

        @Override
        public boolean matches() throws IOException {
          final long segmentOrd = values.getOrd(approximationScorer.docID());
          if (segmentOrd != -1) {
            final long globalOrd = segmentOrdToGlobalOrdLookup.get(segmentOrd);
            if (foundOrds.get(globalOrd)) {
              score = collector.score((int) globalOrd);
              return true;
            }
          }
          return false;
        }

      };
    }
  }

  /** Scorer that consults the collector with segment ordinals directly. */
  final static class SegmentOrdinalScorer extends BaseGlobalOrdinalScorer {

    final GlobalOrdinalsWithScoreCollector collector;

    public SegmentOrdinalScorer(Weight weight, GlobalOrdinalsWithScoreCollector collector, SortedDocValues values, Scorer approximationScorer) {
      super(weight, collector.getCollectorOrdinals(), values, approximationScorer);
      this.collector = collector;
    }

    @Override
    public int advance(int target) throws IOException {
      // Walk the approximation from target until a candidate's ordinal was collected.
      for (int docID = approximationScorer.advance(target); docID < NO_MORE_DOCS; docID = approximationScorer.nextDoc()) {
        final int segmentOrd = values.getOrd(docID);
        if (segmentOrd != -1) {
          if (foundOrds.get(segmentOrd)) {
            // Remember the score of the matched ordinal for the following score() call.
            score = collector.score(segmentOrd);
            return docID;
          }
        }
      }
      return NO_MORE_DOCS;
    }

    @Override
    protected TwoPhaseIterator createTwoPhaseIterator(DocIdSetIterator approximation) {
      return new TwoPhaseIterator(approximation) {

        @Override
        public boolean matches() throws IOException {
          final int segmentOrd = values.getOrd(approximationScorer.docID());
          if (segmentOrd != -1) {
            if (foundOrds.get(segmentOrd)) {
              score = collector.score(segmentOrd);
              return true;
            }
          }
          return false;
        }
      };
    }

  }
}

View File

@ -17,7 +17,12 @@ package org.apache.lucene.search.join;
* limitations under the License.
*/
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.MultiDocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import java.io.IOException;
@ -90,4 +95,78 @@ public final class JoinUtil {
}
}
/**
 * A query time join using global ordinals over a dedicated join field.
 *
 * <p>This join has certain restrictions and requirements:
 * <ol>
 *   <li>A document can only refer to one other document (but can be referred to by one or more
 *       documents).</li>
 *   <li>Documents on each side of the join must be distinguishable. Typically this can be done by
 *       adding an extra field that identifies the "from" and "to" side, and then the fromQuery
 *       and toQuery must take this into account.</li>
 *   <li>There must be a single sorted doc values join field used by both the "from" and "to"
 *       documents. This join field should store the join values as UTF-8 strings.</li>
 *   <li>An ordinal map must be provided that is created on top of the join field.</li>
 * </ol>
 *
 * <p>The fromQuery is executed eagerly by this method in order to collect the matching ordinals;
 * the returned query holds on to the result of that search.
 *
 * @param joinField   The {@link org.apache.lucene.index.SortedDocValues} field containing the join values
 * @param fromQuery   The query containing the actual user query. Also the fromQuery can only match "from" documents.
 * @param toQuery     The query identifying all documents on the "to" side.
 * @param searcher    The index searcher used to execute the from query
 * @param scoreMode   Instructs how scores from the fromQuery are mapped to the returned query
 * @param ordinalMap  The ordinal map constructed over the joinField. In case of a single segment index, no ordinal map
 *                    needs to be provided.
 * @return a {@link Query} instance that can be used to join documents based on the join field
 * @throws IOException If I/O related errors occur
 * @throws IllegalArgumentException if the index has more than one segment and no ordinal map is
 *         given, or if the score mode isn't supported
 */
public static Query createJoinQuery(String joinField,
                                    Query fromQuery,
                                    Query toQuery,
                                    IndexSearcher searcher,
                                    ScoreMode scoreMode,
                                    MultiDocValues.OrdinalMap ordinalMap) throws IOException {
  IndexReader indexReader = searcher.getIndexReader();
  int numSegments = indexReader.leaves().size();
  final long valueCount;
  if (numSegments == 0) {
    return new MatchNoDocsQuery();
  } else if (numSegments == 1) {
    // No need to use the ordinal map, because there is just one segment.
    ordinalMap = null;
    LeafReader leafReader = indexReader.leaves().get(0).reader();
    SortedDocValues joinSortedDocValues = leafReader.getSortedDocValues(joinField);
    if (joinSortedDocValues != null) {
      valueCount = joinSortedDocValues.getValueCount();
    } else {
      // The only segment has no join field, so nothing can be joined.
      return new MatchNoDocsQuery();
    }
  } else {
    if (ordinalMap == null) {
      throw new IllegalArgumentException("OrdinalMap is required, because there is more than 1 segment");
    }
    valueCount = ordinalMap.getValueCount();
  }

  // Rewrite once and use the rewritten query both for collecting the ordinals and as the
  // identity of the returned query, so the searcher does not have to rewrite it a second time.
  final Query rewrittenFromQuery = searcher.rewrite(fromQuery);
  if (scoreMode == ScoreMode.None) {
    GlobalOrdinalsCollector globalOrdinalsCollector = new GlobalOrdinalsCollector(joinField, ordinalMap, valueCount);
    searcher.search(rewrittenFromQuery, globalOrdinalsCollector);
    return new GlobalOrdinalsQuery(globalOrdinalsCollector.getCollectorOrdinals(), joinField, ordinalMap, toQuery, rewrittenFromQuery, indexReader);
  }

  GlobalOrdinalsWithScoreCollector globalOrdinalsWithScoreCollector;
  switch (scoreMode) {
    case Total:
      globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Sum(joinField, ordinalMap, valueCount);
      break;
    case Max:
      globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Max(joinField, ordinalMap, valueCount);
      break;
    case Avg:
      globalOrdinalsWithScoreCollector = new GlobalOrdinalsWithScoreCollector.Avg(joinField, ordinalMap, valueCount);
      break;
    default:
      throw new IllegalArgumentException(String.format(Locale.ROOT, "Score mode %s isn't supported.", scoreMode));
  }
  searcher.search(rewrittenFromQuery, globalOrdinalsWithScoreCollector);
  return new GlobalOrdinalsWithScoreQuery(globalOrdinalsWithScoreCollector, joinField, ordinalMap, toQuery, rewrittenFromQuery, indexReader);
}
}

View File

@ -17,19 +17,6 @@ package org.apache.lucene.search.join;
* limitations under the License.
*/
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.MockTokenizer;
import org.apache.lucene.document.Document;
@ -38,27 +25,29 @@ import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.MultiDocValues;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.SlowCompositeReaderWrapper;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.FilterLeafCollector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Scorer;
@ -74,8 +63,22 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
import org.apache.lucene.util.packed.PackedInts;
import org.junit.Test;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
public class TestJoinUtil extends LuceneTestCase {
public void testSimple() throws Exception {
@ -169,6 +172,180 @@ public class TestJoinUtil extends LuceneTestCase {
dir.close();
}
/**
 * Indexes two "product" documents and four "price" documents that share one sorted doc values
 * join field, then checks the global ordinals join in both directions (product to prices and
 * price to products) using {@link ScoreMode#None}.
 */
public void testSimpleOrdinalsJoin() throws Exception {
  final String idField = "id";
  final String productIdField = "productId";
  // Tells which kind of document this is; the join queries use it to distinguish the two sides.
  final String typeField = "type";
  // One sorted doc values field carrying the join value for every document type.
  // Typically a schema would populate this field automatically while indexing.
  final String joinField = idField + productIdField;

  Directory directory = newDirectory();
  RandomIndexWriter writer = new RandomIndexWriter(
      random(),
      directory,
      newIndexWriterConfig(new MockAnalyzer(random())).setMergePolicy(NoMergePolicy.INSTANCE));

  // doc 0: product "1"
  Document firstProduct = new Document();
  firstProduct.add(new TextField(idField, "1", Field.Store.NO));
  firstProduct.add(new TextField(typeField, "product", Field.Store.NO));
  firstProduct.add(new TextField("description", "random text", Field.Store.NO));
  firstProduct.add(new TextField("name", "name1", Field.Store.NO));
  firstProduct.add(new SortedDocValuesField(joinField, new BytesRef("1")));
  writer.addDocument(firstProduct);

  // doc 1: price 10.0 of product "1"
  Document firstProductLowPrice = new Document();
  firstProductLowPrice.add(new TextField(productIdField, "1", Field.Store.NO));
  firstProductLowPrice.add(new TextField(typeField, "price", Field.Store.NO));
  firstProductLowPrice.add(new TextField("price", "10.0", Field.Store.NO));
  firstProductLowPrice.add(new SortedDocValuesField(joinField, new BytesRef("1")));
  writer.addDocument(firstProductLowPrice);

  // doc 2: price 20.0 of product "1"
  Document firstProductHighPrice = new Document();
  firstProductHighPrice.add(new TextField(productIdField, "1", Field.Store.NO));
  firstProductHighPrice.add(new TextField(typeField, "price", Field.Store.NO));
  firstProductHighPrice.add(new TextField("price", "20.0", Field.Store.NO));
  firstProductHighPrice.add(new SortedDocValuesField(joinField, new BytesRef("1")));
  writer.addDocument(firstProductHighPrice);

  // doc 3: product "2"
  Document secondProduct = new Document();
  secondProduct.add(new TextField(idField, "2", Field.Store.NO));
  secondProduct.add(new TextField(typeField, "product", Field.Store.NO));
  secondProduct.add(new TextField("description", "more random text", Field.Store.NO));
  secondProduct.add(new TextField("name", "name2", Field.Store.NO));
  secondProduct.add(new SortedDocValuesField(joinField, new BytesRef("2")));
  writer.addDocument(secondProduct);

  // Commit here so that the remaining documents are added after a commit point.
  writer.commit();

  // doc 4: price 10.0 of product "2"
  Document secondProductLowPrice = new Document();
  secondProductLowPrice.add(new TextField(productIdField, "2", Field.Store.NO));
  secondProductLowPrice.add(new TextField(typeField, "price", Field.Store.NO));
  secondProductLowPrice.add(new TextField("price", "10.0", Field.Store.NO));
  secondProductLowPrice.add(new SortedDocValuesField(joinField, new BytesRef("2")));
  writer.addDocument(secondProductLowPrice);

  // doc 5: price 20.0 of product "2"
  Document secondProductHighPrice = new Document();
  secondProductHighPrice.add(new TextField(productIdField, "2", Field.Store.NO));
  secondProductHighPrice.add(new TextField(typeField, "price", Field.Store.NO));
  secondProductHighPrice.add(new TextField("price", "20.0", Field.Store.NO));
  secondProductHighPrice.add(new SortedDocValuesField(joinField, new BytesRef("2")));
  writer.addDocument(secondProductHighPrice);

  IndexSearcher searcher = new IndexSearcher(writer.getReader());
  writer.close();

  // Build the ordinal map over the join field of every leaf of the top-level reader.
  IndexReader reader = searcher.getIndexReader();
  SortedDocValues[] perLeafValues = new SortedDocValues[reader.leaves().size()];
  for (LeafReaderContext leaf : reader.leaves()) {
    perLeafValues[leaf.ord] = DocValues.getSorted(leaf.reader(), joinField);
  }
  MultiDocValues.OrdinalMap ordinalMap = MultiDocValues.OrdinalMap.build(
      reader.getCoreCacheKey(), perLeafValues, PackedInts.DEFAULT
  );

  // Search for a product and return its prices.
  Query toQuery = new TermQuery(new Term(typeField, "price"));
  Query fromQuery = new TermQuery(new Term("name", "name2"));
  Query joinQuery = JoinUtil.createJoinQuery(joinField, fromQuery, toQuery, searcher, ScoreMode.None, ordinalMap);
  TopDocs hits = searcher.search(joinQuery, 10);
  assertEquals(2, hits.totalHits);
  assertEquals(4, hits.scoreDocs[0].doc);
  assertEquals(5, hits.scoreDocs[1].doc);

  fromQuery = new TermQuery(new Term("name", "name1"));
  joinQuery = JoinUtil.createJoinQuery(joinField, fromQuery, toQuery, searcher, ScoreMode.None, ordinalMap);
  hits = searcher.search(joinQuery, 10);
  assertEquals(2, hits.totalHits);
  assertEquals(1, hits.scoreDocs[0].doc);
  assertEquals(2, hits.scoreDocs[1].doc);

  // Search for prices and return the products they belong to.
  fromQuery = new TermQuery(new Term("price", "20.0"));
  toQuery = new TermQuery(new Term(typeField, "product"));
  joinQuery = JoinUtil.createJoinQuery(joinField, fromQuery, toQuery, searcher, ScoreMode.None, ordinalMap);
  hits = searcher.search(joinQuery, 10);
  assertEquals(2, hits.totalHits);
  assertEquals(0, hits.scoreDocs[0].doc);
  assertEquals(3, hits.scoreDocs[1].doc);

  reader.close();
  directory.close();
}
/**
 * Builds a random index with a shared "join_field", picks one random value and score mode, and
 * verifies that the global ordinals join returns the expected documents and top docs.
 */
public void testRandomOrdinalsJoin() throws Exception {
  Directory directory = newDirectory();
  RandomIndexWriter writer = new RandomIndexWriter(
      random(),
      directory,
      newIndexWriterConfig(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false)).setMergePolicy(newLogMergePolicy())
  );
  IndexIterationContext context = createContext(100, writer, false, true);
  writer.forceMerge(1);
  writer.close();

  // Collect the join field's doc values per leaf and build the ordinal map on top of them.
  IndexReader topLevelReader = DirectoryReader.open(directory);
  List<LeafReaderContext> leaves = topLevelReader.leaves();
  SortedDocValues[] perLeafValues = new SortedDocValues[leaves.size()];
  for (int i = 0; i < leaves.size(); i++) {
    LeafReaderContext leaf = leaves.get(i);
    perLeafValues[leaf.ord] = DocValues.getSorted(leaf.reader(), "join_field");
  }
  context.ordinalMap = MultiDocValues.OrdinalMap.build(
      topLevelReader.getCoreCacheKey(), perLeafValues, PackedInts.DEFAULT
  );
  IndexSearcher indexSearcher = newSearcher(topLevelReader);

  int pick = random().nextInt(context.randomUniqueValues.length);
  boolean from = context.randomFrom[pick];
  String randomValue = context.randomUniqueValues[pick];
  BitSet expectedResult = createExpectedResult(randomValue, from, indexSearcher.getIndexReader(), context);

  final Query actualQuery = new TermQuery(new Term("value", randomValue));
  if (VERBOSE) {
    System.out.println("actualQuery=" + actualQuery);
  }
  final ScoreMode[] scoreModes = ScoreMode.values();
  final ScoreMode scoreMode = scoreModes[random().nextInt(scoreModes.length)];
  if (VERBOSE) {
    System.out.println("scoreMode=" + scoreMode);
  }

  // The picked value lives on one side of the join; the other side is what gets returned.
  final String fromType = from ? "from" : "to";
  final String toType = from ? "to" : "from";
  BooleanQuery fromQuery = new BooleanQuery();
  fromQuery.add(new TermQuery(new Term("type", fromType)), BooleanClause.Occur.FILTER);
  fromQuery.add(actualQuery, BooleanClause.Occur.MUST);
  Query toQuery = new TermQuery(new Term("type", toType));
  final Query joinQuery = JoinUtil.createJoinQuery("join_field", fromQuery, toQuery, indexSearcher, scoreMode, context.ordinalMap);
  if (VERBOSE) {
    System.out.println("joinQuery=" + joinQuery);
  }

  // Record every matching doc in a bit set while also collecting the top 10 by score.
  final BitSet actualResult = new FixedBitSet(indexSearcher.getIndexReader().maxDoc());
  final TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(10);
  Collector bitSetAndTopDocs = MultiCollector.wrap(new BitSetCollector(actualResult), topScoreDocCollector);
  indexSearcher.search(joinQuery, bitSetAndTopDocs);
  assertBitSet(expectedResult, actualResult, indexSearcher);

  TopDocs expectedTopDocs = createExpectedTopDocs(randomValue, from, scoreMode, context);
  TopDocs actualTopDocs = topScoreDocCollector.topDocs();
  assertTopDocs(expectedTopDocs, actualTopDocs, scoreMode, indexSearcher, joinQuery);

  topLevelReader.close();
  directory.close();
}
// TermsWithScoreCollector.MV.Avg forgets to grow beyond TermsWithScoreCollector.INITIAL_ARRAY_SIZE
public void testOverflowTermsWithScoreCollector() throws Exception {
test300spartans(true, ScoreMode.Avg);
@ -218,7 +395,7 @@ public class TestJoinUtil extends LuceneTestCase {
TopDocs result = indexSearcher.search(joinQuery, 10);
assertEquals(1, result.totalHits);
assertEquals(0, result.scoreDocs[0].doc);
indexSearcher.getIndexReader().close();
dir.close();
@ -310,7 +487,7 @@ public class TestJoinUtil extends LuceneTestCase {
assertFalse("optimized bulkScorer was not used for join query embedded in boolean query!", sawFive);
}
}
@Override
public boolean needsScores() {
return false;
@ -448,7 +625,7 @@ public class TestJoinUtil extends LuceneTestCase {
dir,
newIndexWriterConfig(new MockAnalyzer(random(), MockTokenizer.KEYWORD, false)).setMergePolicy(newLogMergePolicy())
);
IndexIterationContext context = createContext(numberOfDocumentsToIndex, w, multipleValuesPerDocument);
IndexIterationContext context = createContext(numberOfDocumentsToIndex, w, multipleValuesPerDocument, false);
IndexReader topLevelReader = w.getReader();
w.close();
@ -485,73 +662,64 @@ public class TestJoinUtil extends LuceneTestCase {
// Need to know all documents that have matches. TopDocs doesn't give me that and then I'd be also testing TopDocsCollector...
final BitSet actualResult = new FixedBitSet(indexSearcher.getIndexReader().maxDoc());
final TopScoreDocCollector topScoreDocCollector = TopScoreDocCollector.create(10);
indexSearcher.search(joinQuery, new Collector() {
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
final int docBase = context.docBase;
final LeafCollector in = topScoreDocCollector.getLeafCollector(context);
return new FilterLeafCollector(in) {
@Override
public void collect(int doc) throws IOException {
super.collect(doc);
actualResult.set(doc + docBase);
}
};
}
@Override
public boolean needsScores() {
return topScoreDocCollector.needsScores();
}
});
indexSearcher.search(joinQuery, MultiCollector.wrap(new BitSetCollector(actualResult), topScoreDocCollector));
// Asserting bit set...
if (VERBOSE) {
System.out.println("expected cardinality:" + expectedResult.cardinality());
DocIdSetIterator iterator = new BitSetIterator(expectedResult, expectedResult.cardinality());
for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) {
System.out.println(String.format(Locale.ROOT, "Expected doc[%d] with id value %s", doc, indexSearcher.doc(doc).get("id")));
}
System.out.println("actual cardinality:" + actualResult.cardinality());
iterator = new BitSetIterator(actualResult, actualResult.cardinality());
for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) {
System.out.println(String.format(Locale.ROOT, "Actual doc[%d] with id value %s", doc, indexSearcher.doc(doc).get("id")));
}
}
assertEquals(expectedResult, actualResult);
assertBitSet(expectedResult, actualResult, indexSearcher);
// Asserting TopDocs...
TopDocs expectedTopDocs = createExpectedTopDocs(randomValue, from, scoreMode, context);
TopDocs actualTopDocs = topScoreDocCollector.topDocs();
assertEquals(expectedTopDocs.totalHits, actualTopDocs.totalHits);
assertEquals(expectedTopDocs.scoreDocs.length, actualTopDocs.scoreDocs.length);
if (scoreMode == ScoreMode.None) {
continue;
}
assertEquals(expectedTopDocs.getMaxScore(), actualTopDocs.getMaxScore(), 0.0f);
for (int i = 0; i < expectedTopDocs.scoreDocs.length; i++) {
if (VERBOSE) {
System.out.printf(Locale.ENGLISH, "Expected doc: %d | Actual doc: %d\n", expectedTopDocs.scoreDocs[i].doc, actualTopDocs.scoreDocs[i].doc);
System.out.printf(Locale.ENGLISH, "Expected score: %f | Actual score: %f\n", expectedTopDocs.scoreDocs[i].score, actualTopDocs.scoreDocs[i].score);
}
assertEquals(expectedTopDocs.scoreDocs[i].doc, actualTopDocs.scoreDocs[i].doc);
assertEquals(expectedTopDocs.scoreDocs[i].score, actualTopDocs.scoreDocs[i].score, 0.0f);
Explanation explanation = indexSearcher.explain(joinQuery, expectedTopDocs.scoreDocs[i].doc);
assertEquals(expectedTopDocs.scoreDocs[i].score, explanation.getValue(), 0.0f);
}
assertTopDocs(expectedTopDocs, actualTopDocs, scoreMode, indexSearcher, joinQuery);
}
topLevelReader.close();
dir.close();
}
}
private IndexIterationContext createContext(int nDocs, RandomIndexWriter writer, boolean multipleValuesPerDocument) throws IOException {
return createContext(nDocs, writer, writer, multipleValuesPerDocument);
/**
 * Asserts that both bit sets are equal. In verbose mode the cardinality of each set and the "id"
 * stored field of every document in it are printed first.
 */
private void assertBitSet(BitSet expectedResult, BitSet actualResult, IndexSearcher indexSearcher) throws IOException {
  if (VERBOSE) {
    int docId;

    System.out.println("expected cardinality:" + expectedResult.cardinality());
    DocIdSetIterator expectedDocs = new BitSetIterator(expectedResult, expectedResult.cardinality());
    while ((docId = expectedDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
      String idValue = indexSearcher.doc(docId).get("id");
      System.out.println(String.format(Locale.ROOT, "Expected doc[%d] with id value %s", docId, idValue));
    }

    System.out.println("actual cardinality:" + actualResult.cardinality());
    DocIdSetIterator actualDocs = new BitSetIterator(actualResult, actualResult.cardinality());
    while ((docId = actualDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
      String idValue = indexSearcher.doc(docId).get("id");
      System.out.println(String.format(Locale.ROOT, "Actual doc[%d] with id value %s", docId, idValue));
    }
  }
  assertEquals(expectedResult, actualResult);
}
private IndexIterationContext createContext(int nDocs, RandomIndexWriter fromWriter, RandomIndexWriter toWriter, boolean multipleValuesPerDocument) throws IOException {
/**
 * Asserts that the actual top docs agree with the expected ones. Total hits and the number of
 * score docs are always compared; unless the score mode is {@link ScoreMode#None}, the max score,
 * every doc id, every score and the score reported by the query's explanation are compared too.
 */
private void assertTopDocs(TopDocs expectedTopDocs, TopDocs actualTopDocs, ScoreMode scoreMode, IndexSearcher indexSearcher, Query joinQuery) throws IOException {
  assertEquals(expectedTopDocs.totalHits, actualTopDocs.totalHits);
  assertEquals(expectedTopDocs.scoreDocs.length, actualTopDocs.scoreDocs.length);

  // Scores are only compared when the join actually produces scores.
  if (scoreMode != ScoreMode.None) {
    assertEquals(expectedTopDocs.getMaxScore(), actualTopDocs.getMaxScore(), 0.0f);
    final int numDocs = expectedTopDocs.scoreDocs.length;
    for (int i = 0; i < numDocs; i++) {
      final ScoreDoc expected = expectedTopDocs.scoreDocs[i];
      final ScoreDoc actual = actualTopDocs.scoreDocs[i];
      if (VERBOSE) {
        System.out.printf(Locale.ENGLISH, "Expected doc: %d | Actual doc: %d\n", expected.doc, actual.doc);
        System.out.printf(Locale.ENGLISH, "Expected score: %f | Actual score: %f\n", expected.score, actual.score);
      }
      assertEquals(expected.doc, actual.doc);
      assertEquals(expected.score, actual.score, 0.0f);

      // The explanation must report the same score as the search did.
      Explanation explanation = indexSearcher.explain(joinQuery, expected.doc);
      assertEquals(expected.score, explanation.getValue(), 0.0f);
    }
  }
}
/**
 * Convenience overload that passes the same writer as both the "from" and the "to" writer.
 */
private IndexIterationContext createContext(int nDocs, RandomIndexWriter writer, boolean multipleValuesPerDocument, boolean ordinalJoin) throws IOException {
  final RandomIndexWriter fromWriter = writer;
  final RandomIndexWriter toWriter = writer;
  return createContext(nDocs, fromWriter, toWriter, multipleValuesPerDocument, ordinalJoin);
}
private IndexIterationContext createContext(int nDocs, RandomIndexWriter fromWriter, RandomIndexWriter toWriter, boolean multipleValuesPerDocument, boolean globalOrdinalJoin) throws IOException {
if (globalOrdinalJoin) {
assertFalse("ordinal join doesn't support multiple join values per document", multipleValuesPerDocument);
}
IndexIterationContext context = new IndexIterationContext();
int numRandomValues = nDocs / 2;
context.randomUniqueValues = new String[numRandomValues];
@ -560,8 +728,8 @@ public class TestJoinUtil extends LuceneTestCase {
for (int i = 0; i < numRandomValues; i++) {
String uniqueRandomValue;
do {
uniqueRandomValue = TestUtil.randomRealisticUnicodeString(random());
// uniqueRandomValue = _TestUtil.randomSimpleString(random);
// uniqueRandomValue = TestUtil.randomRealisticUnicodeString(random());
uniqueRandomValue = TestUtil.randomSimpleString(random());
} while ("".equals(uniqueRandomValue) || trackSet.contains(uniqueRandomValue));
// Generate unique values and empty strings aren't allowed.
trackSet.add(uniqueRandomValue);
@ -581,15 +749,18 @@ public class TestJoinUtil extends LuceneTestCase {
boolean from = context.randomFrom[randomI];
int numberOfLinkValues = multipleValuesPerDocument ? 2 + random().nextInt(10) : 1;
docs[i] = new RandomDoc(id, numberOfLinkValues, value, from);
if (globalOrdinalJoin) {
document.add(newStringField("type", from ? "from" : "to", Field.Store.NO));
}
for (int j = 0; j < numberOfLinkValues; j++) {
String linkValue = context.randomUniqueValues[random().nextInt(context.randomUniqueValues.length)];
docs[i].linkValues.add(linkValue);
if (from) {
if (!context.fromDocuments.containsKey(linkValue)) {
context.fromDocuments.put(linkValue, new ArrayList<RandomDoc>());
context.fromDocuments.put(linkValue, new ArrayList<>());
}
if (!context.randomValueFromDocs.containsKey(value)) {
context.randomValueFromDocs.put(value, new ArrayList<RandomDoc>());
context.randomValueFromDocs.put(value, new ArrayList<>());
}
context.fromDocuments.get(linkValue).add(docs[i]);
@ -600,12 +771,15 @@ public class TestJoinUtil extends LuceneTestCase {
} else {
document.add(new SortedDocValuesField("from", new BytesRef(linkValue)));
}
if (globalOrdinalJoin) {
document.add(new SortedDocValuesField("join_field", new BytesRef(linkValue)));
}
} else {
if (!context.toDocuments.containsKey(linkValue)) {
context.toDocuments.put(linkValue, new ArrayList<RandomDoc>());
context.toDocuments.put(linkValue, new ArrayList<>());
}
if (!context.randomValueToDocs.containsKey(value)) {
context.randomValueToDocs.put(value, new ArrayList<RandomDoc>());
context.randomValueToDocs.put(value, new ArrayList<>());
}
context.toDocuments.get(linkValue).add(docs[i]);
@ -616,6 +790,9 @@ public class TestJoinUtil extends LuceneTestCase {
} else {
document.add(new SortedDocValuesField("to", new BytesRef(linkValue)));
}
if (globalOrdinalJoin) {
document.add(new SortedDocValuesField("join_field", new BytesRef(linkValue)));
}
}
}
@ -707,6 +884,9 @@ public class TestJoinUtil extends LuceneTestCase {
if (joinScore == null) {
joinValueToJoinScores.put(BytesRef.deepCopyOf(joinValue), joinScore = new JoinScore());
}
if (VERBOSE) {
System.out.println("expected val=" + joinValue.utf8ToString() + " expected score=" + scorer.score());
}
joinScore.addScore(scorer.score());
}
@ -720,7 +900,7 @@ public class TestJoinUtil extends LuceneTestCase {
public void setScorer(Scorer scorer) {
this.scorer = scorer;
}
@Override
public boolean needsScores() {
return true;
@ -777,7 +957,7 @@ public class TestJoinUtil extends LuceneTestCase {
@Override
public void setScorer(Scorer scorer) {}
@Override
public boolean needsScores() {
return false;
@ -875,6 +1055,7 @@ public class TestJoinUtil extends LuceneTestCase {
Map<String, Map<Integer, JoinScore>> fromHitsToJoinScore = new HashMap<>();
Map<String, Map<Integer, JoinScore>> toHitsToJoinScore = new HashMap<>();
MultiDocValues.OrdinalMap ordinalMap;
}
private static class RandomDoc {
@ -922,4 +1103,29 @@ public class TestJoinUtil extends LuceneTestCase {
}
/**
 * Collector that sets a bit for every collected document, using the top-level doc id
 * (the leaf's doc base plus the leaf-relative doc id). It does not need scores.
 */
private static class BitSetCollector extends SimpleCollector {

  /** Receives one bit per collected document. */
  private final BitSet collected;

  /** Doc base of the leaf that is currently being collected. */
  private int currentDocBase;

  private BitSetCollector(BitSet bitSet) {
    collected = bitSet;
  }

  @Override
  protected void doSetNextReader(LeafReaderContext context) throws IOException {
    currentDocBase = context.docBase;
  }

  @Override
  public void collect(int doc) throws IOException {
    // Translate the leaf-relative doc id into a top-level doc id.
    final int topLevelDoc = currentDocBase + doc;
    collected.set(topLevelDoc);
  }

  @Override
  public boolean needsScores() {
    return false;
  }

}
}