diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java index e33730043..883872094 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java @@ -210,6 +210,10 @@ public class ElasticsearchTemplate implements ElasticsearchOperations { searchRequestBuilder.setQuery(QueryBuilders.matchAllQuery()); } + if(criteriaQuery.getMinScore()>0){ + searchRequestBuilder.setMinScore(criteriaQuery.getMinScore()); + } + if (elasticsearchFilter != null) searchRequestBuilder.setFilter(elasticsearchFilter); @@ -520,6 +524,10 @@ public class ElasticsearchTemplate implements ElasticsearchOperations { : SortOrder.ASC); } } + + if(query.getMinScore()>0){ + searchRequestBuilder.setMinScore(query.getMinScore()); + } return searchRequestBuilder; } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java index 2bb79698f..8fb89fd19 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java @@ -37,6 +37,7 @@ abstract class AbstractQuery implements Query { protected List indices = new ArrayList(); protected List types = new ArrayList(); protected List fields = new ArrayList(); + protected float minScore; @Override public Sort getSort() { @@ -99,4 +100,12 @@ abstract class AbstractQuery implements Query { return (T) this; } + + public float getMinScore() { + return minScore; + } + + public void setMinScore(float minScore) { + this.minScore = minScore; + } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java index db667a245..7f885e0dd 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java @@ -16,6 +16,7 @@ package org.springframework.data.elasticsearch.core.query; import org.apache.commons.collections.CollectionUtils; +import org.elasticsearch.common.cache.CacheBuilder; import org.elasticsearch.index.query.FilterBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.highlight.HighlightBuilder; @@ -45,6 +46,7 @@ public class NativeSearchQueryBuilder { private String[] indices; private String[] types; private String[] fields; + private float minScore; public NativeSearchQueryBuilder withQuery(QueryBuilder queryBuilder) { this.queryBuilder = queryBuilder; @@ -91,6 +93,11 @@ public class NativeSearchQueryBuilder { return this; } + public NativeSearchQueryBuilder withMinScore(float minScore) { + this.minScore = minScore; + return this; + } + public NativeSearchQuery build() { NativeSearchQuery nativeSearchQuery = new NativeSearchQuery(queryBuilder, filterBuilder, sortBuilder, highlightFields); if (pageable != null) { @@ -108,6 +115,10 @@ public class NativeSearchQueryBuilder { if (CollectionUtils.isNotEmpty(facetRequests)) { nativeSearchQuery.setFacets(facetRequests); } + + if(minScore>0){ + nativeSearchQuery.setMinScore(minScore); + } return nativeSearchQuery; } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java b/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java index 3d85f9bf7..1d57e6968 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java @@ -108,4 +108,10 @@ public interface Query { * @return */ List getFields(); + + /** + * Get minimum score + * @return + */ + float getMinScore(); } diff --git a/src/test/java/org/springframework/data/elasticsearch/SampleEntity.java b/src/test/java/org/springframework/data/elasticsearch/SampleEntity.java index aceb9588a..795ac6266 100644 --- a/src/test/java/org/springframework/data/elasticsearch/SampleEntity.java +++ b/src/test/java/org/springframework/data/elasticsearch/SampleEntity.java @@ -112,4 +112,17 @@ public class SampleEntity { return new HashCodeBuilder().append(id).append(type).append(message).append(rate).append(available).append(version) .toHashCode(); } + + @Override + public String toString() { + return "SampleEntity{" + + "id='" + id + '\'' + + ", type='" + type + '\'' + + ", message='" + message + '\'' + + ", rate=" + rate + + ", available=" + available + + ", highlightedMessage='" + highlightedMessage + '\'' + + ", version=" + version + + '}'; + } } diff --git a/src/test/java/org/springframework/data/elasticsearch/SampleEntityBuilder.java b/src/test/java/org/springframework/data/elasticsearch/SampleEntityBuilder.java new file mode 100644 index 000000000..a3d50f6dc --- /dev/null +++ b/src/test/java/org/springframework/data/elasticsearch/SampleEntityBuilder.java @@ -0,0 +1,59 @@ +package org.springframework.data.elasticsearch; + +import org.springframework.data.elasticsearch.core.query.IndexQuery; + +/** + * User: dead + * Date: 23/01/14 + * Time: 18:25 + */ +public class SampleEntityBuilder { + + private SampleEntity result; + + public SampleEntityBuilder(String id) { + result = new SampleEntity(); + result.setId(id); + } + + public SampleEntityBuilder type(String type) { + result.setType(type); + return this; + } + + public SampleEntityBuilder message(String message) { + result.setMessage(message); + return this; + } + + public SampleEntityBuilder rate(int rate) { + result.setRate(rate); + return this; + } + + public SampleEntityBuilder available(boolean available) { + result.setAvailable(available); + return this; + } + + public SampleEntityBuilder highlightedMessage(String highlightedMessage) { + result.setHighlightedMessage(highlightedMessage); + return this; + } + + public SampleEntityBuilder version(Long version) { + result.setVersion(version); + return this; + } + + public SampleEntity build() { + return result; + } + + public IndexQuery buildIndex() { + IndexQuery indexQuery = new IndexQuery(); + indexQuery.setId(result.getId()); + indexQuery.setObject(result); + return indexQuery; + } +} diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java index 9953bc24b..1269d400b 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java @@ -33,6 +33,7 @@ import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.elasticsearch.ElasticsearchException; import org.springframework.data.elasticsearch.SampleEntity; +import org.springframework.data.elasticsearch.SampleEntityBuilder; import org.springframework.data.elasticsearch.SampleMappingEntity; import org.springframework.data.elasticsearch.core.query.*; import org.springframework.test.context.ContextConfiguration; @@ -1007,4 +1008,30 @@ public class ElasticsearchTemplateTests { assertThat(ids, is(notNullValue())); assertThat(ids.size(), is(30)); } + + @Test + public void shouldReturnDocumentAboveMinimalScoreGivenQuery() { + // given + List indexQueries = new ArrayList(); + + indexQueries.add(new SampleEntityBuilder("1").message("ab").buildIndex()); + indexQueries.add(new SampleEntityBuilder("2").message("bc").buildIndex()); + indexQueries.add(new SampleEntityBuilder("3").message("ac").buildIndex()); + + elasticsearchTemplate.bulkIndex(indexQueries); + elasticsearchTemplate.refresh(SampleEntity.class, true); + + // when + SearchQuery searchQuery = new NativeSearchQueryBuilder() + .withQuery(boolQuery().must(wildcardQuery("message", "*a*")).should(wildcardQuery("message", "*b*"))) + .withIndices("test-index") + .withTypes("test-type") + .withMinScore(0.5F) + .build(); + + Page page = elasticsearchTemplate.queryForPage(searchQuery, SampleEntity.class); + // then + assertThat(page.getTotalElements(),is(1L)); + assertThat(page.getContent().get(0).getMessage(), is("ab")); + } } diff --git a/src/test/java/org/springframework/data/elasticsearch/core/query/CriteriaQueryTests.java b/src/test/java/org/springframework/data/elasticsearch/core/query/CriteriaQueryTests.java index 43014c219..e7c71a34a 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/query/CriteriaQueryTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/query/CriteriaQueryTests.java @@ -15,12 +15,15 @@ */ package org.springframework.data.elasticsearch.core.query; +import org.elasticsearch.search.sort.FieldSortBuilder; +import org.elasticsearch.search.sort.SortOrder; import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.data.domain.Page; import org.springframework.data.elasticsearch.SampleEntity; +import org.springframework.data.elasticsearch.SampleEntityBuilder; import org.springframework.data.elasticsearch.core.ElasticsearchTemplate; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; @@ -30,6 +33,7 @@ import java.util.ArrayList; import java.util.List; import static org.apache.commons.lang.RandomStringUtils.randomNumeric; +import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery; import static org.hamcrest.Matchers.*; import static org.junit.Assert.*; @@ -700,4 +704,25 @@ public class CriteriaQueryTests { // then assertThat(page.getTotalElements(), is(greaterThanOrEqualTo(1L))); } + + @Test + public void shouldReturnDocumentAboveMinimalScoreGivenCriteria() { + // given + List indexQueries = new ArrayList(); + + indexQueries.add(new SampleEntityBuilder("1").message("ab").buildIndex()); + indexQueries.add(new SampleEntityBuilder("2").message("bc").buildIndex()); + indexQueries.add(new SampleEntityBuilder("3").message("ac").buildIndex()); + + elasticsearchTemplate.bulkIndex(indexQueries); + elasticsearchTemplate.refresh(SampleEntity.class, true); + + // when + CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria("message").contains("a").or(new Criteria("message").contains("b"))); + criteriaQuery.setMinScore(0.5F); + Page page = elasticsearchTemplate.queryForPage(criteriaQuery, SampleEntity.class); + // then + assertThat(page.getTotalElements(),is(1L)); + assertThat(page.getContent().get(0).getMessage(), is("ab")); + } }