Add minimal sanity checks to custom/scripted similarities. (#33564)
Add minimal sanity checks to custom/scripted similarities. Lucene 8 introduced more constraints on similarities, in particular: - scores must not be negative, - scores must not decrease when term freq increases, - scores must not increase when norm (interpreted as an unsigned long) increases. We can't check every single case, but could at least run some sanity checks. Relates #33309
This commit is contained in:
parent
7f473b683d
commit
c4261bab44
|
@ -0,0 +1,96 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.similarity;
|
||||
|
||||
import org.apache.lucene.index.FieldInvertState;
|
||||
import org.apache.lucene.search.CollectionStatistics;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.search.TermStatistics;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
|
||||
/**
|
||||
* A {@link Similarity} that rejects negative scores. This class exists so that users get
|
||||
* an error instead of silently corrupt top hits. It should be applied to any custom or
|
||||
* scripted similarity.
|
||||
*/
|
||||
// public for testing
|
||||
public final class NonNegativeScoresSimilarity extends Similarity {
|
||||
|
||||
// Escape hatch
|
||||
private static final String ES_ENFORCE_POSITIVE_SCORES = "es.enforce.positive.scores";
|
||||
private static final boolean ENFORCE_POSITIVE_SCORES;
|
||||
static {
|
||||
String enforcePositiveScores = System.getProperty(ES_ENFORCE_POSITIVE_SCORES);
|
||||
if (enforcePositiveScores == null) {
|
||||
ENFORCE_POSITIVE_SCORES = true;
|
||||
} else if ("false".equals(enforcePositiveScores)) {
|
||||
ENFORCE_POSITIVE_SCORES = false;
|
||||
} else {
|
||||
throw new IllegalArgumentException(ES_ENFORCE_POSITIVE_SCORES + " may only be unset or set to [false], but got [" +
|
||||
enforcePositiveScores + "]");
|
||||
}
|
||||
}
|
||||
|
||||
private final Similarity in;
|
||||
|
||||
public NonNegativeScoresSimilarity(Similarity in) {
|
||||
this.in = in;
|
||||
}
|
||||
|
||||
public Similarity getDelegate() {
|
||||
return in;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long computeNorm(FieldInvertState state) {
|
||||
return in.computeNorm(state);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
|
||||
final SimScorer inScorer = in.scorer(boost, collectionStats, termStats);
|
||||
return new SimScorer() {
|
||||
|
||||
@Override
|
||||
public float score(float freq, long norm) {
|
||||
float score = inScorer.score(freq, norm);
|
||||
if (score < 0f) {
|
||||
if (ENFORCE_POSITIVE_SCORES) {
|
||||
throw new IllegalArgumentException("Similarities must not produce negative scores, but got:\n" +
|
||||
inScorer.explain(Explanation.match(freq, "term frequency"), norm));
|
||||
} else {
|
||||
return 0f;
|
||||
}
|
||||
}
|
||||
return score;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(Explanation freq, long norm) {
|
||||
Explanation expl = inScorer.explain(freq, norm);
|
||||
if (expl.isMatch() && expl.getValue().floatValue() < 0) {
|
||||
expl = Explanation.match(0f, "max of:",
|
||||
expl, Explanation.match(0f, "Minimum allowed score"));
|
||||
}
|
||||
return expl;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -19,15 +19,22 @@
|
|||
|
||||
package org.elasticsearch.index.similarity;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.lucene.index.FieldInvertState;
|
||||
import org.apache.lucene.index.IndexOptions;
|
||||
import org.apache.lucene.search.CollectionStatistics;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.search.TermStatistics;
|
||||
import org.apache.lucene.search.similarities.BM25Similarity;
|
||||
import org.apache.lucene.search.similarities.BooleanSimilarity;
|
||||
import org.apache.lucene.search.similarities.ClassicSimilarity;
|
||||
import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
import org.apache.lucene.search.similarities.Similarity.SimScorer;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.TriFunction;
|
||||
import org.elasticsearch.common.logging.DeprecationLogger;
|
||||
import org.elasticsearch.common.logging.Loggers;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.index.AbstractIndexComponent;
|
||||
import org.elasticsearch.index.IndexModule;
|
||||
|
@ -44,7 +51,7 @@ import java.util.function.Supplier;
|
|||
|
||||
public final class SimilarityService extends AbstractIndexComponent {
|
||||
|
||||
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class));
|
||||
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class));
|
||||
public static final String DEFAULT_SIMILARITY = "BM25";
|
||||
private static final String CLASSIC_SIMILARITY = "classic";
|
||||
private static final Map<String, Function<Version, Supplier<Similarity>>> DEFAULTS;
|
||||
|
@ -131,8 +138,14 @@ public final class SimilarityService extends AbstractIndexComponent {
|
|||
}
|
||||
TriFunction<Settings, Version, ScriptService, Similarity> defaultFactory = BUILT_IN.get(typeName);
|
||||
TriFunction<Settings, Version, ScriptService, Similarity> factory = similarities.getOrDefault(typeName, defaultFactory);
|
||||
final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
|
||||
providers.put(name, () -> similarity);
|
||||
Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
|
||||
validateSimilarity(indexSettings.getIndexVersionCreated(), similarity);
|
||||
if (BUILT_IN.containsKey(typeName) == false || "scripted".equals(typeName)) {
|
||||
// We don't trust custom similarities
|
||||
similarity = new NonNegativeScoresSimilarity(similarity);
|
||||
}
|
||||
final Similarity similarityF = similarity; // like similarity but final
|
||||
providers.put(name, () -> similarityF);
|
||||
}
|
||||
for (Map.Entry<String, Function<Version, Supplier<Similarity>>> entry : DEFAULTS.entrySet()) {
|
||||
providers.put(entry.getKey(), entry.getValue().apply(indexSettings.getIndexVersionCreated()));
|
||||
|
@ -182,4 +195,80 @@ public final class SimilarityService extends AbstractIndexComponent {
|
|||
return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity;
|
||||
}
|
||||
}
|
||||
|
||||
static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) {
|
||||
validateScoresArePositive(indexCreatedVersion, similarity);
|
||||
validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity);
|
||||
validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity);
|
||||
}
|
||||
|
||||
private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) {
|
||||
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
|
||||
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
|
||||
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
|
||||
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
|
||||
IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap
|
||||
final long norm = similarity.computeNorm(state);
|
||||
for (int freq = 1; freq <= 10; ++freq) {
|
||||
float score = scorer.score(freq, norm);
|
||||
if (score < 0) {
|
||||
fail(indexCreatedVersion, "Similarities should not return negative scores:\n" +
|
||||
scorer.explain(Explanation.match(freq, "term freq"), norm));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) {
|
||||
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
|
||||
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
|
||||
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
|
||||
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
|
||||
IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap
|
||||
final long norm = similarity.computeNorm(state);
|
||||
float previousScore = 0;
|
||||
for (int freq = 1; freq <= 10; ++freq) {
|
||||
float score = scorer.score(freq, norm);
|
||||
if (score < previousScore) {
|
||||
fail(indexCreatedVersion, "Similarity scores should not decrease when term frequency increases:\n" +
|
||||
scorer.explain(Explanation.match(freq - 1, "term freq"), norm) + "\n" +
|
||||
scorer.explain(Explanation.match(freq, "term freq"), norm));
|
||||
}
|
||||
previousScore = score;
|
||||
}
|
||||
}
|
||||
|
||||
private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) {
|
||||
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
|
||||
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
|
||||
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
|
||||
|
||||
long previousNorm = 0;
|
||||
float previousScore = Float.MAX_VALUE;
|
||||
for (int length = 1; length <= 10; ++length) {
|
||||
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
|
||||
IndexOptions.DOCS_AND_FREQS, length, length, 0, 50, 10, 3); // length = 20, no overlap
|
||||
final long norm = similarity.computeNorm(state);
|
||||
if (Long.compareUnsigned(previousNorm, norm) > 0) {
|
||||
// esoteric similarity, skip this check
|
||||
break;
|
||||
}
|
||||
float score = scorer.score(1, norm);
|
||||
if (score > previousScore) {
|
||||
fail(indexCreatedVersion, "Similarity scores should not increase when norm increases:\n" +
|
||||
scorer.explain(Explanation.match(1, "term freq"), norm - 1) + "\n" +
|
||||
scorer.explain(Explanation.match(1, "term freq"), norm));
|
||||
}
|
||||
previousScore = score;
|
||||
previousNorm = norm;
|
||||
}
|
||||
}
|
||||
|
||||
private static void fail(Version indexCreatedVersion, String message) {
|
||||
if (indexCreatedVersion.onOrAfter(Version.V_7_0_0_alpha1)) {
|
||||
throw new IllegalArgumentException(message);
|
||||
} else if (indexCreatedVersion.onOrAfter(Version.V_6_5_0)) {
|
||||
DEPRECATION_LOGGER.deprecated(message);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -59,6 +59,7 @@ import org.elasticsearch.index.shard.IndexSearcherWrapper;
|
|||
import org.elasticsearch.index.shard.IndexingOperationListener;
|
||||
import org.elasticsearch.index.shard.SearchOperationListener;
|
||||
import org.elasticsearch.index.shard.ShardId;
|
||||
import org.elasticsearch.index.similarity.NonNegativeScoresSimilarity;
|
||||
import org.elasticsearch.index.similarity.SimilarityService;
|
||||
import org.elasticsearch.index.store.IndexStore;
|
||||
import org.elasticsearch.indices.IndicesModule;
|
||||
|
@ -77,6 +78,7 @@ import org.elasticsearch.test.TestSearchContext;
|
|||
import org.elasticsearch.test.engine.MockEngineFactory;
|
||||
import org.elasticsearch.threadpool.TestThreadPool;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.hamcrest.Matchers;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
|
@ -295,10 +297,13 @@ public class IndexModuleTests extends ESTestCase {
|
|||
|
||||
IndexService indexService = newIndexService(module);
|
||||
SimilarityService similarityService = indexService.similarityService();
|
||||
assertNotNull(similarityService.getSimilarity("my_similarity"));
|
||||
assertTrue(similarityService.getSimilarity("my_similarity").get() instanceof TestSimilarity);
|
||||
Similarity similarity = similarityService.getSimilarity("my_similarity").get();
|
||||
assertNotNull(similarity);
|
||||
assertThat(similarity, Matchers.instanceOf(NonNegativeScoresSimilarity.class));
|
||||
similarity = ((NonNegativeScoresSimilarity) similarity).getDelegate();
|
||||
assertThat(similarity, Matchers.instanceOf(TestSimilarity.class));
|
||||
assertEquals("my_similarity", similarityService.getSimilarity("my_similarity").name());
|
||||
assertEquals("there is a key", ((TestSimilarity) similarityService.getSimilarity("my_similarity").get()).key);
|
||||
assertEquals("there is a key", ((TestSimilarity) similarity).key);
|
||||
indexService.close("simon says", false);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/*
|
||||
* Licensed to Elasticsearch under one or more contributor
|
||||
* license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright
|
||||
* ownership. Elasticsearch licenses this file to you under
|
||||
* the Apache License, Version 2.0 (the "License"); you may
|
||||
* not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing,
|
||||
* software distributed under the License is distributed on an
|
||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied. See the License for the
|
||||
* specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.index.similarity;
|
||||
|
||||
import org.apache.lucene.index.FieldInvertState;
|
||||
import org.apache.lucene.search.CollectionStatistics;
|
||||
import org.apache.lucene.search.TermStatistics;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
import org.apache.lucene.search.similarities.Similarity.SimScorer;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.hamcrest.Matchers;
|
||||
|
||||
public class NonNegativeScoresSimilarityTests extends ESTestCase {
|
||||
|
||||
public void testBasics() {
|
||||
Similarity negativeScoresSim = new Similarity() {
|
||||
|
||||
@Override
|
||||
public long computeNorm(FieldInvertState state) {
|
||||
return state.getLength();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
|
||||
return new SimScorer() {
|
||||
@Override
|
||||
public float score(float freq, long norm) {
|
||||
return freq - 5;
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
Similarity assertingSimilarity = new NonNegativeScoresSimilarity(negativeScoresSim);
|
||||
SimScorer scorer = assertingSimilarity.scorer(1f, null);
|
||||
assertEquals(2f, scorer.score(7f, 1L), 0f);
|
||||
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> scorer.score(2f, 1L));
|
||||
assertThat(e.getMessage(), Matchers.containsString("Similarities must not produce negative scores"));
|
||||
}
|
||||
|
||||
}
|
|
@ -18,12 +18,18 @@
|
|||
*/
|
||||
package org.elasticsearch.index.similarity;
|
||||
|
||||
import org.apache.lucene.index.FieldInvertState;
|
||||
import org.apache.lucene.search.CollectionStatistics;
|
||||
import org.apache.lucene.search.TermStatistics;
|
||||
import org.apache.lucene.search.similarities.BM25Similarity;
|
||||
import org.apache.lucene.search.similarities.BooleanSimilarity;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.index.IndexSettings;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.IndexSettingsModule;
|
||||
import org.hamcrest.Matchers;
|
||||
|
||||
import java.util.Collections;
|
||||
|
||||
|
@ -56,4 +62,76 @@ public class SimilarityServiceTests extends ESTestCase {
|
|||
SimilarityService service = new SimilarityService(indexSettings, null, Collections.emptyMap());
|
||||
assertTrue(service.getDefaultSimilarity() instanceof BooleanSimilarity);
|
||||
}
|
||||
|
||||
public void testSimilarityValidation() {
|
||||
Similarity negativeScoresSim = new Similarity() {
|
||||
|
||||
@Override
|
||||
public long computeNorm(FieldInvertState state) {
|
||||
return state.getLength();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
|
||||
return new SimScorer() {
|
||||
|
||||
@Override
|
||||
public float score(float freq, long norm) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
};
|
||||
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
|
||||
() -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, negativeScoresSim));
|
||||
assertThat(e.getMessage(), Matchers.containsString("Similarities should not return negative scores"));
|
||||
|
||||
Similarity decreasingScoresWithFreqSim = new Similarity() {
|
||||
|
||||
@Override
|
||||
public long computeNorm(FieldInvertState state) {
|
||||
return state.getLength();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
|
||||
return new SimScorer() {
|
||||
|
||||
@Override
|
||||
public float score(float freq, long norm) {
|
||||
return 1 / (freq + norm);
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
};
|
||||
e = expectThrows(IllegalArgumentException.class,
|
||||
() -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, decreasingScoresWithFreqSim));
|
||||
assertThat(e.getMessage(), Matchers.containsString("Similarity scores should not decrease when term frequency increases"));
|
||||
|
||||
Similarity increasingScoresWithNormSim = new Similarity() {
|
||||
|
||||
@Override
|
||||
public long computeNorm(FieldInvertState state) {
|
||||
return state.getLength();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
|
||||
return new SimScorer() {
|
||||
|
||||
@Override
|
||||
public float score(float freq, long norm) {
|
||||
return freq + norm;
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
};
|
||||
e = expectThrows(IllegalArgumentException.class,
|
||||
() -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, increasingScoresWithNormSim));
|
||||
assertThat(e.getMessage(), Matchers.containsString("Similarity scores should not increase when norm increases"));
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
package org.elasticsearch.indices;
|
||||
|
||||
import org.apache.lucene.search.similarities.BM25Similarity;
|
||||
import org.apache.lucene.search.similarities.Similarity;
|
||||
import org.apache.lucene.store.AlreadyClosedException;
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.admin.indices.stats.CommonStatsFlags;
|
||||
|
@ -56,6 +57,7 @@ import org.elasticsearch.index.shard.IndexShard;
|
|||
import org.elasticsearch.index.shard.IndexShardState;
|
||||
import org.elasticsearch.index.shard.ShardId;
|
||||
import org.elasticsearch.index.shard.ShardPath;
|
||||
import org.elasticsearch.index.similarity.NonNegativeScoresSimilarity;
|
||||
import org.elasticsearch.indices.IndicesService.ShardDeletionCheckResult;
|
||||
import org.elasticsearch.plugins.EnginePlugin;
|
||||
import org.elasticsearch.plugins.MapperPlugin;
|
||||
|
@ -448,8 +450,10 @@ public class IndicesServiceTests extends ESSingleNodeTestCase {
|
|||
.build();
|
||||
MapperService mapperService = indicesService.createIndexMapperService(indexMetaData);
|
||||
assertNotNull(mapperService.documentMapperParser().parserContext("type").typeParser("fake-mapper"));
|
||||
assertThat(mapperService.documentMapperParser().parserContext("type").getSimilarity("test").get(),
|
||||
instanceOf(BM25Similarity.class));
|
||||
Similarity sim = mapperService.documentMapperParser().parserContext("type").getSimilarity("test").get();
|
||||
assertThat(sim, instanceOf(NonNegativeScoresSimilarity.class));
|
||||
sim = ((NonNegativeScoresSimilarity) sim).getDelegate();
|
||||
assertThat(sim, instanceOf(BM25Similarity.class));
|
||||
}
|
||||
|
||||
public void testStatsByShardDoesNotDieFromExpectedExceptions() {
|
||||
|
|
Loading…
Reference in New Issue