DATAES-13 added support for minimum score

This commit is contained in:
Artur Konczak 2014-01-29 22:06:45 +00:00
parent 6c5e8fee22
commit bbc46d95df
8 changed files with 158 additions and 0 deletions

View File

@ -210,6 +210,10 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
searchRequestBuilder.setQuery(QueryBuilders.matchAllQuery()); searchRequestBuilder.setQuery(QueryBuilders.matchAllQuery());
} }
if(criteriaQuery.getMinScore()>0){
searchRequestBuilder.setMinScore(criteriaQuery.getMinScore());
}
if (elasticsearchFilter != null) if (elasticsearchFilter != null)
searchRequestBuilder.setFilter(elasticsearchFilter); searchRequestBuilder.setFilter(elasticsearchFilter);
@ -520,6 +524,10 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
: SortOrder.ASC); : SortOrder.ASC);
} }
} }
if(query.getMinScore()>0){
searchRequestBuilder.setMinScore(query.getMinScore());
}
return searchRequestBuilder; return searchRequestBuilder;
} }

View File

@ -37,6 +37,7 @@ abstract class AbstractQuery implements Query {
protected List<String> indices = new ArrayList<String>(); protected List<String> indices = new ArrayList<String>();
protected List<String> types = new ArrayList<String>(); protected List<String> types = new ArrayList<String>();
protected List<String> fields = new ArrayList<String>(); protected List<String> fields = new ArrayList<String>();
protected float minScore;
@Override @Override
public Sort getSort() { public Sort getSort() {
@ -99,4 +100,12 @@ abstract class AbstractQuery implements Query {
return (T) this; return (T) this;
} }
public float getMinScore() {
return minScore;
}
public void setMinScore(float minScore) {
this.minScore = minScore;
}
} }

View File

@ -16,6 +16,7 @@
package org.springframework.data.elasticsearch.core.query; package org.springframework.data.elasticsearch.core.query;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.index.query.FilterBuilder; import org.elasticsearch.index.query.FilterBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.highlight.HighlightBuilder; import org.elasticsearch.search.highlight.HighlightBuilder;
@ -45,6 +46,7 @@ public class NativeSearchQueryBuilder {
private String[] indices; private String[] indices;
private String[] types; private String[] types;
private String[] fields; private String[] fields;
private float minScore;
public NativeSearchQueryBuilder withQuery(QueryBuilder queryBuilder) { public NativeSearchQueryBuilder withQuery(QueryBuilder queryBuilder) {
this.queryBuilder = queryBuilder; this.queryBuilder = queryBuilder;
@ -91,6 +93,11 @@ public class NativeSearchQueryBuilder {
return this; return this;
} }
public NativeSearchQueryBuilder withMinScore(float minScore) {
this.minScore = minScore;
return this;
}
public NativeSearchQuery build() { public NativeSearchQuery build() {
NativeSearchQuery nativeSearchQuery = new NativeSearchQuery(queryBuilder, filterBuilder, sortBuilder, highlightFields); NativeSearchQuery nativeSearchQuery = new NativeSearchQuery(queryBuilder, filterBuilder, sortBuilder, highlightFields);
if (pageable != null) { if (pageable != null) {
@ -108,6 +115,10 @@ public class NativeSearchQueryBuilder {
if (CollectionUtils.isNotEmpty(facetRequests)) { if (CollectionUtils.isNotEmpty(facetRequests)) {
nativeSearchQuery.setFacets(facetRequests); nativeSearchQuery.setFacets(facetRequests);
} }
if(minScore>0){
nativeSearchQuery.setMinScore(minScore);
}
return nativeSearchQuery; return nativeSearchQuery;
} }
} }

View File

@ -108,4 +108,10 @@ public interface Query {
* @return * @return
*/ */
List<String> getFields(); List<String> getFields();
/**
* Get minimum score
* @return
*/
float getMinScore();
} }

View File

@ -112,4 +112,17 @@ public class SampleEntity {
return new HashCodeBuilder().append(id).append(type).append(message).append(rate).append(available).append(version) return new HashCodeBuilder().append(id).append(type).append(message).append(rate).append(available).append(version)
.toHashCode(); .toHashCode();
} }
@Override
public String toString() {
return "SampleEntity{" +
"id='" + id + '\'' +
", type='" + type + '\'' +
", message='" + message + '\'' +
", rate=" + rate +
", available=" + available +
", highlightedMessage='" + highlightedMessage + '\'' +
", version=" + version +
'}';
}
} }

View File

@ -0,0 +1,59 @@
package org.springframework.data.elasticsearch;
import org.springframework.data.elasticsearch.core.query.IndexQuery;
/**
* User: dead
* Date: 23/01/14
* Time: 18:25
*/
public class SampleEntityBuilder {
private SampleEntity result;
public SampleEntityBuilder(String id) {
result = new SampleEntity();
result.setId(id);
}
public SampleEntityBuilder type(String type) {
result.setType(type);
return this;
}
public SampleEntityBuilder message(String message) {
result.setMessage(message);
return this;
}
public SampleEntityBuilder rate(int rate) {
result.setRate(rate);
return this;
}
public SampleEntityBuilder available(boolean available) {
result.setAvailable(available);
return this;
}
public SampleEntityBuilder highlightedMessage(String highlightedMessage) {
result.setHighlightedMessage(highlightedMessage);
return this;
}
public SampleEntityBuilder version(Long version) {
result.setVersion(version);
return this;
}
public SampleEntity build() {
return result;
}
public IndexQuery buildIndex() {
IndexQuery indexQuery = new IndexQuery();
indexQuery.setId(result.getId());
indexQuery.setObject(result);
return indexQuery;
}
}

View File

@ -33,6 +33,7 @@ import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort;
import org.springframework.data.elasticsearch.ElasticsearchException; import org.springframework.data.elasticsearch.ElasticsearchException;
import org.springframework.data.elasticsearch.SampleEntity; import org.springframework.data.elasticsearch.SampleEntity;
import org.springframework.data.elasticsearch.SampleEntityBuilder;
import org.springframework.data.elasticsearch.SampleMappingEntity; import org.springframework.data.elasticsearch.SampleMappingEntity;
import org.springframework.data.elasticsearch.core.query.*; import org.springframework.data.elasticsearch.core.query.*;
import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.ContextConfiguration;
@ -1007,4 +1008,30 @@ public class ElasticsearchTemplateTests {
assertThat(ids, is(notNullValue())); assertThat(ids, is(notNullValue()));
assertThat(ids.size(), is(30)); assertThat(ids.size(), is(30));
} }
@Test
public void shouldReturnDocumentAboveMinimalScoreGivenQuery() {
// given
List<IndexQuery> indexQueries = new ArrayList<IndexQuery>();
indexQueries.add(new SampleEntityBuilder("1").message("ab").buildIndex());
indexQueries.add(new SampleEntityBuilder("2").message("bc").buildIndex());
indexQueries.add(new SampleEntityBuilder("3").message("ac").buildIndex());
elasticsearchTemplate.bulkIndex(indexQueries);
elasticsearchTemplate.refresh(SampleEntity.class, true);
// when
SearchQuery searchQuery = new NativeSearchQueryBuilder()
.withQuery(boolQuery().must(wildcardQuery("message", "*a*")).should(wildcardQuery("message", "*b*")))
.withIndices("test-index")
.withTypes("test-type")
.withMinScore(0.5F)
.build();
Page<SampleEntity> page = elasticsearchTemplate.queryForPage(searchQuery, SampleEntity.class);
// then
assertThat(page.getTotalElements(),is(1L));
assertThat(page.getContent().get(0).getMessage(), is("ab"));
}
} }

View File

@ -15,12 +15,15 @@
*/ */
package org.springframework.data.elasticsearch.core.query; package org.springframework.data.elasticsearch.core.query;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.junit.Before; import org.junit.Before;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.springframework.data.domain.Page; import org.springframework.data.domain.Page;
import org.springframework.data.elasticsearch.SampleEntity; import org.springframework.data.elasticsearch.SampleEntity;
import org.springframework.data.elasticsearch.SampleEntityBuilder;
import org.springframework.data.elasticsearch.core.ElasticsearchTemplate; import org.springframework.data.elasticsearch.core.ElasticsearchTemplate;
import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
@ -30,6 +33,7 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.apache.commons.lang.RandomStringUtils.randomNumeric; import static org.apache.commons.lang.RandomStringUtils.randomNumeric;
import static org.elasticsearch.index.query.QueryBuilders.matchAllQuery;
import static org.hamcrest.Matchers.*; import static org.hamcrest.Matchers.*;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -700,4 +704,25 @@ public class CriteriaQueryTests {
// then // then
assertThat(page.getTotalElements(), is(greaterThanOrEqualTo(1L))); assertThat(page.getTotalElements(), is(greaterThanOrEqualTo(1L)));
} }
@Test
public void shouldReturnDocumentAboveMinimalScoreGivenCriteria() {
// given
List<IndexQuery> indexQueries = new ArrayList<IndexQuery>();
indexQueries.add(new SampleEntityBuilder("1").message("ab").buildIndex());
indexQueries.add(new SampleEntityBuilder("2").message("bc").buildIndex());
indexQueries.add(new SampleEntityBuilder("3").message("ac").buildIndex());
elasticsearchTemplate.bulkIndex(indexQueries);
elasticsearchTemplate.refresh(SampleEntity.class, true);
// when
CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria("message").contains("a").or(new Criteria("message").contains("b")));
criteriaQuery.setMinScore(0.5F);
Page<SampleEntity> page = elasticsearchTemplate.queryForPage(criteriaQuery, SampleEntity.class);
// then
assertThat(page.getTotalElements(),is(1L));
assertThat(page.getContent().get(0).getMessage(), is("ab"));
}
} }