Add minimal sanity checks to custom/scripted similarities. (#33564)

Add minimal sanity checks to custom/scripted similarities.

Lucene 8 introduced more constraints on similarities, in particular:
 - scores must not be negative,
 - scores must not decrease when term freq increases,
 - scores must not increase when norm (interpreted as an unsigned long)
   increases.

We can't check every single case, but could at least run some sanity checks.

Relates #33309
This commit is contained in:
Adrien Grand 2018-09-19 09:19:13 +02:00 committed by GitHub
parent 7f473b683d
commit c4261bab44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 339 additions and 10 deletions

View File

@ -0,0 +1,96 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.similarity;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.Similarity;
/**
* A {@link Similarity} that rejects negative scores. This class exists so that users get
* an error instead of silently corrupt top hits. It should be applied to any custom or
* scripted similarity.
*/
// public for testing
public final class NonNegativeScoresSimilarity extends Similarity {
// Escape hatch
private static final String ES_ENFORCE_POSITIVE_SCORES = "es.enforce.positive.scores";
private static final boolean ENFORCE_POSITIVE_SCORES;
static {
String enforcePositiveScores = System.getProperty(ES_ENFORCE_POSITIVE_SCORES);
if (enforcePositiveScores == null) {
ENFORCE_POSITIVE_SCORES = true;
} else if ("false".equals(enforcePositiveScores)) {
ENFORCE_POSITIVE_SCORES = false;
} else {
throw new IllegalArgumentException(ES_ENFORCE_POSITIVE_SCORES + " may only be unset or set to [false], but got [" +
enforcePositiveScores + "]");
}
}
private final Similarity in;
public NonNegativeScoresSimilarity(Similarity in) {
this.in = in;
}
public Similarity getDelegate() {
return in;
}
@Override
public long computeNorm(FieldInvertState state) {
return in.computeNorm(state);
}
@Override
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
final SimScorer inScorer = in.scorer(boost, collectionStats, termStats);
return new SimScorer() {
@Override
public float score(float freq, long norm) {
float score = inScorer.score(freq, norm);
if (score < 0f) {
if (ENFORCE_POSITIVE_SCORES) {
throw new IllegalArgumentException("Similarities must not produce negative scores, but got:\n" +
inScorer.explain(Explanation.match(freq, "term frequency"), norm));
} else {
return 0f;
}
}
return score;
}
@Override
public Explanation explain(Explanation freq, long norm) {
Explanation expl = inScorer.explain(freq, norm);
if (expl.isMatch() && expl.getValue().floatValue() < 0) {
expl = Explanation.match(0f, "max of:",
expl, Explanation.match(0f, "Minimum allowed score"));
}
return expl;
}
};
}
}

View File

@ -19,15 +19,22 @@
package org.elasticsearch.index.similarity; package org.elasticsearch.index.similarity;
import org.apache.logging.log4j.LogManager;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.BooleanSimilarity; import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper; import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper;
import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.AbstractIndexComponent; import org.elasticsearch.index.AbstractIndexComponent;
import org.elasticsearch.index.IndexModule; import org.elasticsearch.index.IndexModule;
@ -44,7 +51,7 @@ import java.util.function.Supplier;
public final class SimilarityService extends AbstractIndexComponent { public final class SimilarityService extends AbstractIndexComponent {
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class)); private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class));
public static final String DEFAULT_SIMILARITY = "BM25"; public static final String DEFAULT_SIMILARITY = "BM25";
private static final String CLASSIC_SIMILARITY = "classic"; private static final String CLASSIC_SIMILARITY = "classic";
private static final Map<String, Function<Version, Supplier<Similarity>>> DEFAULTS; private static final Map<String, Function<Version, Supplier<Similarity>>> DEFAULTS;
@ -131,8 +138,14 @@ public final class SimilarityService extends AbstractIndexComponent {
} }
TriFunction<Settings, Version, ScriptService, Similarity> defaultFactory = BUILT_IN.get(typeName); TriFunction<Settings, Version, ScriptService, Similarity> defaultFactory = BUILT_IN.get(typeName);
TriFunction<Settings, Version, ScriptService, Similarity> factory = similarities.getOrDefault(typeName, defaultFactory); TriFunction<Settings, Version, ScriptService, Similarity> factory = similarities.getOrDefault(typeName, defaultFactory);
final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService); Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
providers.put(name, () -> similarity); validateSimilarity(indexSettings.getIndexVersionCreated(), similarity);
if (BUILT_IN.containsKey(typeName) == false || "scripted".equals(typeName)) {
// We don't trust custom similarities
similarity = new NonNegativeScoresSimilarity(similarity);
}
final Similarity similarityF = similarity; // like similarity but final
providers.put(name, () -> similarityF);
} }
for (Map.Entry<String, Function<Version, Supplier<Similarity>>> entry : DEFAULTS.entrySet()) { for (Map.Entry<String, Function<Version, Supplier<Similarity>>> entry : DEFAULTS.entrySet()) {
providers.put(entry.getKey(), entry.getValue().apply(indexSettings.getIndexVersionCreated())); providers.put(entry.getKey(), entry.getValue().apply(indexSettings.getIndexVersionCreated()));
@ -151,7 +164,7 @@ public final class SimilarityService extends AbstractIndexComponent {
defaultSimilarity; defaultSimilarity;
} }
public SimilarityProvider getSimilarity(String name) { public SimilarityProvider getSimilarity(String name) {
Supplier<Similarity> sim = similarities.get(name); Supplier<Similarity> sim = similarities.get(name);
if (sim == null) { if (sim == null) {
@ -182,4 +195,80 @@ public final class SimilarityService extends AbstractIndexComponent {
return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity; return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity;
} }
} }
static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) {
validateScoresArePositive(indexCreatedVersion, similarity);
validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity);
validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity);
}
private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
for (int freq = 1; freq <= 10; ++freq) {
float score = scorer.score(freq, norm);
if (score < 0) {
fail(indexCreatedVersion, "Similarities should not return negative scores:\n" +
scorer.explain(Explanation.match(freq, "term freq"), norm));
}
}
}
private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
float previousScore = 0;
for (int freq = 1; freq <= 10; ++freq) {
float score = scorer.score(freq, norm);
if (score < previousScore) {
fail(indexCreatedVersion, "Similarity scores should not decrease when term frequency increases:\n" +
scorer.explain(Explanation.match(freq - 1, "term freq"), norm) + "\n" +
scorer.explain(Explanation.match(freq, "term freq"), norm));
}
previousScore = score;
}
}
private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) {
CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);
long previousNorm = 0;
float previousScore = Float.MAX_VALUE;
for (int length = 1; length <= 10; ++length) {
FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
IndexOptions.DOCS_AND_FREQS, length, length, 0, 50, 10, 3); // length = 20, no overlap
final long norm = similarity.computeNorm(state);
if (Long.compareUnsigned(previousNorm, norm) > 0) {
// esoteric similarity, skip this check
break;
}
float score = scorer.score(1, norm);
if (score > previousScore) {
fail(indexCreatedVersion, "Similarity scores should not increase when norm increases:\n" +
scorer.explain(Explanation.match(1, "term freq"), norm - 1) + "\n" +
scorer.explain(Explanation.match(1, "term freq"), norm));
}
previousScore = score;
previousNorm = norm;
}
}
private static void fail(Version indexCreatedVersion, String message) {
if (indexCreatedVersion.onOrAfter(Version.V_7_0_0_alpha1)) {
throw new IllegalArgumentException(message);
} else if (indexCreatedVersion.onOrAfter(Version.V_6_5_0)) {
DEPRECATION_LOGGER.deprecated(message);
}
}
} }

View File

@ -59,6 +59,7 @@ import org.elasticsearch.index.shard.IndexSearcherWrapper;
import org.elasticsearch.index.shard.IndexingOperationListener; import org.elasticsearch.index.shard.IndexingOperationListener;
import org.elasticsearch.index.shard.SearchOperationListener; import org.elasticsearch.index.shard.SearchOperationListener;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.similarity.NonNegativeScoresSimilarity;
import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.index.similarity.SimilarityService;
import org.elasticsearch.index.store.IndexStore; import org.elasticsearch.index.store.IndexStore;
import org.elasticsearch.indices.IndicesModule; import org.elasticsearch.indices.IndicesModule;
@ -77,6 +78,7 @@ import org.elasticsearch.test.TestSearchContext;
import org.elasticsearch.test.engine.MockEngineFactory; import org.elasticsearch.test.engine.MockEngineFactory;
import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.hamcrest.Matchers;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
@ -295,10 +297,13 @@ public class IndexModuleTests extends ESTestCase {
IndexService indexService = newIndexService(module); IndexService indexService = newIndexService(module);
SimilarityService similarityService = indexService.similarityService(); SimilarityService similarityService = indexService.similarityService();
assertNotNull(similarityService.getSimilarity("my_similarity")); Similarity similarity = similarityService.getSimilarity("my_similarity").get();
assertTrue(similarityService.getSimilarity("my_similarity").get() instanceof TestSimilarity); assertNotNull(similarity);
assertThat(similarity, Matchers.instanceOf(NonNegativeScoresSimilarity.class));
similarity = ((NonNegativeScoresSimilarity) similarity).getDelegate();
assertThat(similarity, Matchers.instanceOf(TestSimilarity.class));
assertEquals("my_similarity", similarityService.getSimilarity("my_similarity").name()); assertEquals("my_similarity", similarityService.getSimilarity("my_similarity").name());
assertEquals("there is a key", ((TestSimilarity) similarityService.getSimilarity("my_similarity").get()).key); assertEquals("there is a key", ((TestSimilarity) similarity).key);
indexService.close("simon says", false); indexService.close("simon says", false);
} }

View File

@ -0,0 +1,57 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.index.similarity;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matchers;
public class NonNegativeScoresSimilarityTests extends ESTestCase {
public void testBasics() {
Similarity negativeScoresSim = new Similarity() {
@Override
public long computeNorm(FieldInvertState state) {
return state.getLength();
}
@Override
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {
@Override
public float score(float freq, long norm) {
return freq - 5;
}
};
}
};
Similarity assertingSimilarity = new NonNegativeScoresSimilarity(negativeScoresSim);
SimScorer scorer = assertingSimilarity.scorer(1f, null);
assertEquals(2f, scorer.score(7f, 1L), 0f);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> scorer.score(2f, 1L));
assertThat(e.getMessage(), Matchers.containsString("Similarities must not produce negative scores"));
}
}

View File

@ -18,12 +18,18 @@
*/ */
package org.elasticsearch.index.similarity; package org.elasticsearch.index.similarity;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.BooleanSimilarity; import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.IndexSettingsModule; import org.elasticsearch.test.IndexSettingsModule;
import org.hamcrest.Matchers;
import java.util.Collections; import java.util.Collections;
@ -56,4 +62,76 @@ public class SimilarityServiceTests extends ESTestCase {
SimilarityService service = new SimilarityService(indexSettings, null, Collections.emptyMap()); SimilarityService service = new SimilarityService(indexSettings, null, Collections.emptyMap());
assertTrue(service.getDefaultSimilarity() instanceof BooleanSimilarity); assertTrue(service.getDefaultSimilarity() instanceof BooleanSimilarity);
} }
public void testSimilarityValidation() {
Similarity negativeScoresSim = new Similarity() {
@Override
public long computeNorm(FieldInvertState state) {
return state.getLength();
}
@Override
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {
@Override
public float score(float freq, long norm) {
return -1;
}
};
}
};
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
() -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, negativeScoresSim));
assertThat(e.getMessage(), Matchers.containsString("Similarities should not return negative scores"));
Similarity decreasingScoresWithFreqSim = new Similarity() {
@Override
public long computeNorm(FieldInvertState state) {
return state.getLength();
}
@Override
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {
@Override
public float score(float freq, long norm) {
return 1 / (freq + norm);
}
};
}
};
e = expectThrows(IllegalArgumentException.class,
() -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, decreasingScoresWithFreqSim));
assertThat(e.getMessage(), Matchers.containsString("Similarity scores should not decrease when term frequency increases"));
Similarity increasingScoresWithNormSim = new Similarity() {
@Override
public long computeNorm(FieldInvertState state) {
return state.getLength();
}
@Override
public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {
@Override
public float score(float freq, long norm) {
return freq + norm;
}
};
}
};
e = expectThrows(IllegalArgumentException.class,
() -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, increasingScoresWithNormSim));
assertThat(e.getMessage(), Matchers.containsString("Similarity scores should not increase when norm increases"));
}
} }

View File

@ -19,6 +19,7 @@
package org.elasticsearch.indices; package org.elasticsearch.indices;
import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.store.AlreadyClosedException; import org.apache.lucene.store.AlreadyClosedException;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.action.admin.indices.stats.CommonStatsFlags; import org.elasticsearch.action.admin.indices.stats.CommonStatsFlags;
@ -56,6 +57,7 @@ import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.IndexShardState; import org.elasticsearch.index.shard.IndexShardState;
import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.shard.ShardPath; import org.elasticsearch.index.shard.ShardPath;
import org.elasticsearch.index.similarity.NonNegativeScoresSimilarity;
import org.elasticsearch.indices.IndicesService.ShardDeletionCheckResult; import org.elasticsearch.indices.IndicesService.ShardDeletionCheckResult;
import org.elasticsearch.plugins.EnginePlugin; import org.elasticsearch.plugins.EnginePlugin;
import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.MapperPlugin;
@ -448,8 +450,10 @@ public class IndicesServiceTests extends ESSingleNodeTestCase {
.build(); .build();
MapperService mapperService = indicesService.createIndexMapperService(indexMetaData); MapperService mapperService = indicesService.createIndexMapperService(indexMetaData);
assertNotNull(mapperService.documentMapperParser().parserContext("type").typeParser("fake-mapper")); assertNotNull(mapperService.documentMapperParser().parserContext("type").typeParser("fake-mapper"));
assertThat(mapperService.documentMapperParser().parserContext("type").getSimilarity("test").get(), Similarity sim = mapperService.documentMapperParser().parserContext("type").getSimilarity("test").get();
instanceOf(BM25Similarity.class)); assertThat(sim, instanceOf(NonNegativeScoresSimilarity.class));
sim = ((NonNegativeScoresSimilarity) sim).getDelegate();
assertThat(sim, instanceOf(BM25Similarity.class));
} }
public void testStatsByShardDoesNotDieFromExpectedExceptions() { public void testStatsByShardDoesNotDieFromExpectedExceptions() {