Added support for 'more like this' in ElasticsearchOperations

https://github.com/BioMedCentralLtd/spring-data-elasticsearch/issues/4
This commit is contained in:
Rizwan Idrees 2013-03-19 15:08:34 +00:00
parent f53c1bc660
commit 93a512f014
4 changed files with 360 additions and 49 deletions

View File

@ -202,5 +202,13 @@ public interface ElasticsearchOperations {
*/
<T> Page<T> scroll(String scrollId, long scrollTimeInMillis, ResultsMapper<T> resultsMapper);
/**
* Runs a "more like this" query, i.e. searches for documents that are "like" a specific document.
*
* @param query the more-like-this query describing the reference document and the search options
* @param clazz the class of the entities contained in the returned page
* @param <T> the entity type of the returned page
* @return a page of {@code T}
*/
<T> Page<T> moreLikeThis(MoreLikeThisQuery query, Class<T> clazz);
}

View File

@ -24,6 +24,7 @@ import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.count.CountRequestBuilder;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.mlt.MoreLikeThisRequestBuilder;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
@ -54,7 +55,9 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.apache.commons.collections.CollectionUtils.isNotEmpty;
import static org.apache.commons.lang.StringUtils.isBlank;
import static org.apache.commons.lang.StringUtils.isNotBlank;
import static org.elasticsearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
import static org.elasticsearch.action.search.SearchType.SCAN;
import static org.elasticsearch.client.Requests.indicesExistsRequest;
@ -127,7 +130,7 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
return mapResults(response, clazz, query.getPageable());
}
@Override
public <T> Page<T> queryForPage(SearchQuery query, ResultsMapper<T> resultsMapper) {
SearchResponse response = doSearch(prepareSearch(query), query.getElasticsearchQuery(), query.getElasticsearchFilter(),query.getElasticsearchSort());
return resultsMapper.mapResults(response);
@ -144,18 +147,6 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
return extractIds(response);
}
/**
 * Attaches the optional filter and sort to the given request, sets the query and executes the search.
 *
 * @param searchRequest the prepared search request
 * @param query the query to run
 * @param filter filter to apply; skipped when {@code null}
 * @param sortBuilder sort to add; skipped when {@code null}
 * @return the response of the executed search
 */
private SearchResponse doSearch(SearchRequestBuilder searchRequest, QueryBuilder query, FilterBuilder filter, SortBuilder sortBuilder) {
    final boolean hasFilter = filter != null;
    if (hasFilter) {
        searchRequest.setFilter(filter);
    }

    final boolean hasSort = sortBuilder != null;
    if (hasSort) {
        searchRequest.addSort(sortBuilder);
    }

    SearchResponse response = searchRequest.setQuery(query).execute().actionGet();
    return response;
}
@Override
public <T> Page<T> queryForPage(CriteriaQuery query, Class<T> clazz) {
QueryBuilder elasticsearchQuery = new CriteriaQueryProcessor().createQueryFromCriteria(query.getCriteria());
@ -229,6 +220,114 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
.execute().actionGet();
}
/**
 * Starts a {@code SCAN} search for the given query and returns the scroll id of the response.
 *
 * @param query the search query; its indices, types and pageable must not be {@code null}
 * @param scrollTimeInMillis scroll keep-alive passed to the request, in milliseconds
 * @param noFields when {@code true}, the request is configured with {@code setNoFields()}
 * @return the scroll id taken from the search response
 */
@Override
public String scan(SearchQuery query, long scrollTimeInMillis, boolean noFields) {
    Assert.notNull(query.getIndices(), "No index defined for Query");
    // Message typo fixed ("define" -> "defined") to match the same assertion in prepareSearch.
    Assert.notNull(query.getTypes(), "No type defined for Query");
    Assert.notNull(query.getPageable(), "Query.pageable is required for scan & scroll");

    SearchRequestBuilder requestBuilder = client.prepareSearch(toArray(query.getIndices()))
            .setSearchType(SCAN)
            .setQuery(query.getElasticsearchQuery())
            .setTypes(toArray(query.getTypes()))
            .setScroll(TimeValue.timeValueMillis(scrollTimeInMillis))
            .setFrom(0)
            .setSize(query.getPageable().getPageSize());

    if (query.getElasticsearchFilter() != null) {
        requestBuilder.setFilter(query.getElasticsearchFilter());
    }

    if (noFields) {
        requestBuilder.setNoFields();
    }

    return requestBuilder.execute().actionGet().getScrollId();
}
/**
 * Fetches the next batch for the given scroll id and maps the response with the supplied mapper.
 *
 * @param scrollId the scroll id to continue
 * @param scrollTimeInMillis scroll keep-alive passed to the request, in milliseconds
 * @param resultsMapper mapper that turns the search response into a page
 * @return the page produced by {@code resultsMapper}
 */
@Override
public <T> Page<T> scroll(String scrollId, long scrollTimeInMillis, ResultsMapper<T> resultsMapper) {
    final TimeValue keepAlive = TimeValue.timeValueMillis(scrollTimeInMillis);
    final SearchResponse scrollResponse = client.prepareSearchScroll(scrollId)
            .setScroll(keepAlive)
            .execute()
            .actionGet();
    return resultsMapper.mapResults(scrollResponse);
}
/**
 * Executes a "more like this" request for the document identified by {@code query.getId()} and
 * maps the hits to a page of {@code clazz}.
 *
 * <p>The index name and type are taken from the query when they are not blank, otherwise from the
 * persistent entity resolved for {@code clazz}. Every optional setting of the query is only
 * forwarded to the request builder when it is set (non-null / non-empty / non-blank).
 *
 * @param query the more-like-this query; its document id must not be {@code null}
 * @param clazz the entity class used to resolve default index/type and to map the results
 * @param <T> the entity type of the returned page
 * @return the mapped results
 */
@Override
public <T> Page<T> moreLikeThis(MoreLikeThisQuery query, Class<T> clazz) {
    ElasticsearchPersistentEntity persistentEntity = getPersistentEntityFor(clazz);
    String indexName = isNotBlank(query.getIndexName()) ? query.getIndexName() : persistentEntity.getIndexName();
    String type = isNotBlank(query.getType()) ? query.getType() : persistentEntity.getIndexType();

    Assert.notNull(indexName, "No 'indexName' defined for MoreLikeThisQuery");
    Assert.notNull(type, "No 'type' defined for MoreLikeThisQuery");
    Assert.notNull(query.getId(), "No document id defined for MoreLikeThisQuery");

    MoreLikeThisRequestBuilder requestBuilder =
            client.prepareMoreLikeThis(indexName, type, query.getId());

    int startRecord = 0;
    if (query.getPageable() != null) {
        // Spring Data page numbers are zero-based (the query's default is page 0), so the offset
        // is pageNumber * pageSize. The previous "(pageNumber - 1)" made page 0 and page 1 both
        // start at record 0 and shifted every later page back by one.
        startRecord = query.getPageable().getPageNumber() * query.getPageable().getPageSize();
        requestBuilder.setSearchSize(query.getPageable().getPageSize());
    }
    // Keep the defensive clamp in case a custom Pageable reports a negative page number.
    requestBuilder.setSearchFrom(Math.max(startRecord, 0));

    if (isNotEmpty(query.getSearchIndices())) {
        requestBuilder.setSearchIndices(toArray(query.getSearchIndices()));
    }
    if (isNotEmpty(query.getSearchTypes())) {
        requestBuilder.setSearchTypes(toArray(query.getSearchTypes()));
    }
    if (isNotEmpty(query.getFields())) {
        requestBuilder.setField(toArray(query.getFields()));
    }
    if (isNotBlank(query.getRouting())) {
        requestBuilder.setRouting(query.getRouting());
    }
    if (query.getPercentTermsToMatch() != null) {
        requestBuilder.setPercentTermsToMatch(query.getPercentTermsToMatch());
    }
    if (query.getMinTermFreq() != null) {
        requestBuilder.setMinTermFreq(query.getMinTermFreq());
    }
    if (query.getMaxQueryTerms() != null) {
        requestBuilder.maxQueryTerms(query.getMaxQueryTerms());
    }
    if (isNotEmpty(query.getStopWords())) {
        requestBuilder.setStopWords(toArray(query.getStopWords()));
    }
    if (query.getMinDocFreq() != null) {
        requestBuilder.setMinDocFreq(query.getMinDocFreq());
    }
    if (query.getMaxDocFreq() != null) {
        requestBuilder.setMaxDocFreq(query.getMaxDocFreq());
    }
    if (query.getMinWordLen() != null) {
        requestBuilder.setMinWordLen(query.getMinWordLen());
    }
    if (query.getMaxWordLen() != null) {
        requestBuilder.setMaxWordLen(query.getMaxWordLen());
    }
    if (query.getBoostTerms() != null) {
        requestBuilder.setBoostTerms(query.getBoostTerms());
    }

    SearchResponse response = requestBuilder.execute().actionGet();
    return mapResults(response, clazz, query.getPageable());
}
/**
 * Attaches the optional filter and sort to the given request, sets the query and executes the search.
 *
 * @param searchRequest the prepared search request
 * @param query the query to run
 * @param filter filter to apply; skipped when {@code null}
 * @param sortBuilder sort to add; skipped when {@code null}
 * @return the response of the executed search
 */
private SearchResponse doSearch(SearchRequestBuilder searchRequest, QueryBuilder query, FilterBuilder filter, SortBuilder sortBuilder) {
    final boolean hasFilter = filter != null;
    if (hasFilter) {
        searchRequest.setFilter(filter);
    }

    final boolean hasSort = sortBuilder != null;
    if (hasSort) {
        searchRequest.addSort(sortBuilder);
    }

    SearchResponse response = searchRequest.setQuery(query).execute().actionGet();
    return response;
}
/**
 * Ensures the index is present: reports {@code true} when it already exists, otherwise returns
 * the outcome of creating it.
 */
private boolean createIndexIfNotCreated(String indexName) {
    if (indexExists(indexName)) {
        return true;
    }
    return createIndex(indexName);
}
@ -256,7 +355,7 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
private SearchRequestBuilder prepareSearch(Query query){
Assert.notNull(query.getIndices(), "No index defined for Query");
Assert.notNull(query.getTypes(), "No type define for Query");
Assert.notNull(query.getTypes(), "No type defined for Query");
int startRecord = 0;
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(toArray(query.getIndices()))
@ -320,38 +419,6 @@ public class ElasticsearchTemplate implements ElasticsearchOperations {
.refresh(refreshRequest(persistentEntity.getIndexName()).waitForOperations(waitForOperation)).actionGet();
}
/**
 * Starts a {@code SCAN} search for the given query and returns the scroll id of the response.
 *
 * @param query the search query; its indices, types and pageable must not be {@code null}
 * @param scrollTimeInMillis scroll keep-alive passed to the request, in milliseconds
 * @param noFields when {@code true}, the request is configured with {@code setNoFields()}
 * @return the scroll id taken from the search response
 */
@Override
public String scan(SearchQuery query, long scrollTimeInMillis, boolean noFields) {
    Assert.notNull(query.getIndices(), "No index defined for Query");
    Assert.notNull(query.getTypes(), "No type define for Query");
    Assert.notNull(query.getPageable(), "Query.pageable is required for scan & scroll");

    final TimeValue keepAlive = TimeValue.timeValueMillis(scrollTimeInMillis);
    final int batchSize = query.getPageable().getPageSize();

    SearchRequestBuilder scanRequest = client.prepareSearch(toArray(query.getIndices()))
            .setSearchType(SCAN)
            .setQuery(query.getElasticsearchQuery())
            .setTypes(toArray(query.getTypes()))
            .setScroll(keepAlive)
            .setFrom(0)
            .setSize(batchSize);

    final FilterBuilder filter = query.getElasticsearchFilter();
    if (filter != null) {
        scanRequest.setFilter(filter);
    }
    if (noFields) {
        scanRequest.setNoFields();
    }

    SearchResponse scanResponse = scanRequest.execute().actionGet();
    return scanResponse.getScrollId();
}
/**
 * Fetches the next batch for the given scroll id and maps the response with the supplied mapper.
 *
 * @param scrollId the scroll id to continue
 * @param scrollTimeInMillis scroll keep-alive passed to the request, in milliseconds
 * @param resultsMapper mapper that turns the search response into a page
 * @return the page produced by {@code resultsMapper}
 */
@Override
public <T> Page<T> scroll(String scrollId, long scrollTimeInMillis, ResultsMapper<T> resultsMapper) {
    final TimeValue keepAlive = TimeValue.timeValueMillis(scrollTimeInMillis);
    final SearchResponse scrollResponse = client.prepareSearchScroll(scrollId)
            .setScroll(keepAlive)
            .execute()
            .actionGet();
    return resultsMapper.mapResults(scrollResponse);
}
private ElasticsearchPersistentEntity getPersistentEntityFor(Class clazz){
Assert.isTrue(clazz.isAnnotationPresent(Document.class), "Unable to identify index name. " +
clazz.getSimpleName() + " is not a Document. Make sure the document class is annotated with @Document(indexName=\"foo\")");

View File

@ -0,0 +1,190 @@
/*
* Copyright 2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.elasticsearch.core.query;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import java.util.ArrayList;
import java.util.List;
import static org.apache.commons.collections.CollectionUtils.addAll;
import static org.springframework.data.elasticsearch.core.query.Query.DEFAULT_PAGE_SIZE;
/**
 * Parameter holder for a "more like this" request: the reference document (id, index name, type),
 * the search scope, optional tuning values and the paging information.
 *
 * <p>All optional values start out as {@code null} (or as empty lists); the paging defaults to the
 * first page with {@code DEFAULT_PAGE_SIZE} entries.
 *
 * @author Rizwan Idrees
 * @author Mohsin Husen
 */
public class MoreLikeThisQuery {

    // Multi-valued settings; never reassigned, only appended to through the add* methods.
    private final List<String> searchIndices = new ArrayList<String>();
    private final List<String> searchTypes = new ArrayList<String>();
    private final List<String> fields = new ArrayList<String>();
    private final List<String> stopWords = new ArrayList<String>();

    // Reference document.
    private String id;
    private String indexName;
    private String type;

    // Optional single-valued settings; null means "not set".
    private String routing;
    private Float percentTermsToMatch;
    private Integer minTermFreq;
    private Integer maxQueryTerms;
    private Integer minDocFreq;
    private Integer maxDocFreq;
    private Integer minWordLen;
    private Integer maxWordLen;
    private Float boostTerms;

    // Paging; zero-based first page by default.
    private Pageable pageable = new PageRequest(0, DEFAULT_PAGE_SIZE);

    /** @return the id of the reference document */
    public String getId() {
        return this.id;
    }

    /** @param id the id of the reference document */
    public void setId(String id) {
        this.id = id;
    }

    /** @return the index name of the reference document, may be {@code null} */
    public String getIndexName() {
        return this.indexName;
    }

    /** @param indexName the index name of the reference document */
    public void setIndexName(String indexName) {
        this.indexName = indexName;
    }

    /** @return the type of the reference document, may be {@code null} */
    public String getType() {
        return this.type;
    }

    /** @param type the type of the reference document */
    public void setType(String type) {
        this.type = type;
    }

    /** @return the live list of search indices collected so far */
    public List<String> getSearchIndices() {
        return this.searchIndices;
    }

    /** @param searchIndices index names to append to the search indices */
    public void addSearchIndices(String... searchIndices) {
        addAll(this.searchIndices, searchIndices);
    }

    /** @return the live list of search types collected so far */
    public List<String> getSearchTypes() {
        return this.searchTypes;
    }

    /** @param searchTypes type names to append to the search types */
    public void addSearchTypes(String... searchTypes) {
        addAll(this.searchTypes, searchTypes);
    }

    /** @return the live list of field names collected so far */
    public List<String> getFields() {
        return this.fields;
    }

    /** @param fields field names to append */
    public void addFields(String... fields) {
        addAll(this.fields, fields);
    }

    /** @return the routing value, may be {@code null} */
    public String getRouting() {
        return this.routing;
    }

    /** @param routing the routing value */
    public void setRouting(String routing) {
        this.routing = routing;
    }

    /** @return the percent-terms-to-match value, {@code null} when not set */
    public Float getPercentTermsToMatch() {
        return this.percentTermsToMatch;
    }

    /** @param percentTermsToMatch the percent-terms-to-match value */
    public void setPercentTermsToMatch(Float percentTermsToMatch) {
        this.percentTermsToMatch = percentTermsToMatch;
    }

    /** @return the minimum term frequency, {@code null} when not set */
    public Integer getMinTermFreq() {
        return this.minTermFreq;
    }

    /** @param minTermFreq the minimum term frequency */
    public void setMinTermFreq(Integer minTermFreq) {
        this.minTermFreq = minTermFreq;
    }

    /** @return the maximum number of query terms, {@code null} when not set */
    public Integer getMaxQueryTerms() {
        return this.maxQueryTerms;
    }

    /** @param maxQueryTerms the maximum number of query terms */
    public void setMaxQueryTerms(Integer maxQueryTerms) {
        this.maxQueryTerms = maxQueryTerms;
    }

    /** @return the live list of stop words collected so far */
    public List<String> getStopWords() {
        return this.stopWords;
    }

    /** @param stopWords stop words to append */
    public void addStopWords(String... stopWords) {
        addAll(this.stopWords, stopWords);
    }

    /** @return the minimum document frequency, {@code null} when not set */
    public Integer getMinDocFreq() {
        return this.minDocFreq;
    }

    /** @param minDocFreq the minimum document frequency */
    public void setMinDocFreq(Integer minDocFreq) {
        this.minDocFreq = minDocFreq;
    }

    /** @return the maximum document frequency, {@code null} when not set */
    public Integer getMaxDocFreq() {
        return this.maxDocFreq;
    }

    /** @param maxDocFreq the maximum document frequency */
    public void setMaxDocFreq(Integer maxDocFreq) {
        this.maxDocFreq = maxDocFreq;
    }

    /** @return the minimum word length, {@code null} when not set */
    public Integer getMinWordLen() {
        return this.minWordLen;
    }

    /** @param minWordLen the minimum word length */
    public void setMinWordLen(Integer minWordLen) {
        this.minWordLen = minWordLen;
    }

    /** @return the maximum word length, {@code null} when not set */
    public Integer getMaxWordLen() {
        return this.maxWordLen;
    }

    /** @param maxWordLen the maximum word length */
    public void setMaxWordLen(Integer maxWordLen) {
        this.maxWordLen = maxWordLen;
    }

    /** @return the boost-terms value, {@code null} when not set */
    public Float getBoostTerms() {
        return this.boostTerms;
    }

    /** @param boostTerms the boost-terms value */
    public void setBoostTerms(Float boostTerms) {
        this.boostTerms = boostTerms;
    }

    /** @return the paging information */
    public Pageable getPageable() {
        return this.pageable;
    }

    /** @param pageable the paging information */
    public void setPageable(Pageable pageable) {
        this.pageable = pageable;
    }
}

View File

@ -111,7 +111,6 @@ public class ElasticsearchTemplateTest {
@Test
public void shouldReturnPageForGivenSearchQuery(){
//given
//given
String documentId = randomNumeric(5);
SampleEntity sampleEntity = new SampleEntity();
@ -247,7 +246,7 @@ public class ElasticsearchTemplateTest {
}
@Test
public void shouldTestFilterBuilder(){
public void shouldFilterSearchResultsGivenFilter(){
//given
String documentId = randomNumeric(5);
SampleEntity sampleEntity = new SampleEntity();
@ -271,7 +270,7 @@ public class ElasticsearchTemplateTest {
}
@Test
public void shouldTestSortBuilder(){
public void shouldSortResultsGivenSortCriteria(){
//given
List<IndexQuery> indexQueries = new ArrayList<IndexQuery>();
//first document
@ -494,4 +493,51 @@ public class ElasticsearchTemplateTest {
assertThat(page.getTotalElements(), is(equalTo(1L)));
assertThat(page.getContent().get(0), is(message));
}
@Test
public void shouldReturnSimilarResultsGivenMoreLikeThisQuery(){
    //given
    // Two documents share the same message so that each one is "like" the other.
    final String sharedMessage = "So we build a web site or an application and want to add search to it, " +
            "and then it hits us: getting search working is hard. We want our search solution to be fast," +
            " we want a painless setup and a completely free search schema, we want to be able to index data simply using JSON over HTTP, " +
            "we want our search server to be always available, we want to be able to start with one machine and scale to hundreds, " +
            "we want real-time search, we want simple multi-tenancy, and we want a solution that is built for the cloud.";

    // first document: the one expected in the result
    final String expectedId = randomNumeric(5);
    final SampleEntity expectedEntity = new SampleEntity();
    expectedEntity.setId(expectedId);
    expectedEntity.setMessage(sharedMessage);
    expectedEntity.setVersion(System.currentTimeMillis());

    final IndexQuery expectedIndexQuery = new IndexQuery();
    expectedIndexQuery.setId(expectedId);
    expectedIndexQuery.setObject(expectedEntity);
    elasticsearchTemplate.index(expectedIndexQuery);

    // second document: the reference document of the query
    final String referenceId = randomNumeric(5);
    final SampleEntity referenceEntity = new SampleEntity();
    referenceEntity.setId(referenceId);
    referenceEntity.setMessage(sharedMessage);
    referenceEntity.setVersion(System.currentTimeMillis());

    final IndexQuery referenceIndexQuery = new IndexQuery();
    referenceIndexQuery.setId(referenceId);
    referenceIndexQuery.setObject(referenceEntity);
    elasticsearchTemplate.index(referenceIndexQuery);

    elasticsearchTemplate.refresh(SampleEntity.class, true);

    final MoreLikeThisQuery similarToReference = new MoreLikeThisQuery();
    similarToReference.setId(referenceId);
    similarToReference.addFields("message");
    similarToReference.setMinDocFreq(1);

    //when
    final Page<SampleEntity> similarEntities =
            elasticsearchTemplate.moreLikeThis(similarToReference, SampleEntity.class);

    //then
    assertThat(similarEntities.getTotalElements(), is(equalTo(1L)));
    assertThat(similarEntities.getContent(), hasItem(expectedEntity));
}
}