SOLR-10351: Add analyze Stream Evaluator to support streaming NLP

This commit is contained in:
Joel Bernstein 2017-03-30 17:34:28 +01:00
parent edafcbad14
commit 6c2155c024
10 changed files with 278 additions and 9 deletions

View File

@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
*
*/
package org.apache.solr.handler;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.io.eval.*;
import org.apache.solr.common.SolrException;
import org.apache.lucene.analysis.*;
import org.apache.solr.core.SolrCore;
/**
 * Stream evaluator that runs the value of a tuple field through a schema field type's
 * index analyzer and returns the resulting terms as a list of strings.
 *
 * <p>The analyzer is looked up in {@link #setStreamContext(StreamContext)} from the
 * {@link SolrCore} stored under the {@code "solr-core"} key of the stream context, so that
 * method must be called before {@link #evaluate(Tuple)}.
 */
public class AnalyzeEvaluator extends SimpleEvaluator {
  private static final long serialVersionUID = 1L;

  /** Name of the tuple field whose value is analyzed. */
  private String fieldName;

  /** Name of the schema field whose field type supplies the index analyzer. */
  private String analyzerField;

  /** Index analyzer resolved from the schema; assigned in setStreamContext. */
  private Analyzer analyzer;

  /**
   * @param _fieldName tuple field to read the text from
   * @param _analyzerField schema field whose index analyzer is used; when {@code null} the
   *        value of {@code _fieldName} is used instead
   */
  public AnalyzeEvaluator(String _fieldName, String _analyzerField) {
    init(_fieldName, _analyzerField);
  }

  /**
   * Builds the evaluator from a parsed expression: operand 0 is the tuple field, operand 1
   * the analyzer field (presumably {@code null} when the second operand is omitted, in which
   * case the tuple field name is reused — see {@link #init(String, String)}).
   */
  public AnalyzeEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
    String _fieldName = factory.getValueOperand(expression, 0);
    String _analyzerField = factory.getValueOperand(expression, 1);
    init(_fieldName, _analyzerField);
  }

  /**
   * Resolves the index analyzer for {@code analyzerField} from the latest schema of the
   * {@link SolrCore} found in the context.
   *
   * @throws SolrException with {@code INVALID_STATE} when the context has no {@link SolrCore}
   *         under the {@code "solr-core"} key
   */
  @Override
  public void setStreamContext(StreamContext context) {
    // Keep the inherited streamContext field populated, as the base implementation does.
    super.setStreamContext(context);

    Object solrCoreObj = context.get("solr-core");
    // instanceof is already false for null, so no separate null test is needed.
    if (!(solrCoreObj instanceof SolrCore)) {
      throw new SolrException(SolrException.ErrorCode.INVALID_STATE, "StreamContext must have SolrCore in solr-core key");
    }
    SolrCore solrCore = (SolrCore) solrCoreObj;
    analyzer = solrCore.getLatestSchema().getFieldType(analyzerField).getIndexAnalyzer();
  }

  /** Stores the field names, falling back to {@code fieldName} when no analyzer field is given. */
  private void init(String fieldName, String analyzerField) {
    this.fieldName = fieldName;
    this.analyzerField = (analyzerField == null) ? fieldName : analyzerField;
  }

  /**
   * Analyzes the tuple's {@code fieldName} value and collects every emitted term.
   *
   * @return the terms in the order the analyzer produced them, or {@code null} when the
   *         tuple has no value for the field
   * @throws IOException if the token stream fails while being consumed
   */
  @Override
  public Object evaluate(Tuple tuple) throws IOException {
    // Read through get() and null-check explicitly so an absent field yields null
    // instead of depending on how getString() treats a missing key.
    Object raw = tuple.get(fieldName);
    if (raw == null) {
      return null;
    }
    String value = raw.toString();

    List<String> tokens = new ArrayList<>();
    // try-with-resources guarantees the stream is closed even if reset() or
    // incrementToken() throws.
    try (TokenStream tokenStream = analyzer.tokenStream(analyzerField, value)) {
      CharTermAttribute termAtt = tokenStream.getAttribute(CharTermAttribute.class);
      tokenStream.reset();
      while (tokenStream.incrementToken()) {
        tokens.add(termAtt.toString());
      }
      tokenStream.end();
    }
    return tokens;
  }

  /**
   * NOTE(review): only {@code fieldName} is emitted here; the analyzer field is not part of
   * the returned expression — confirm whether round-tripping needs both operands.
   */
  @Override
  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
    return new StreamExpressionValue(fieldName);
  }

  /** Describes this node as an evaluator, using the expression from {@link #toExpression}. */
  @Override
  public Explanation toExplanation(StreamFactory factory) throws IOException {
    return new Explanation(nodeId.toString())
        .withExpressionType(ExpressionType.EVALUATOR)
        .withImplementingClass(getClass().getName())
        .withExpression(toExpression(factory).toString());
  }
}

View File

@ -209,6 +209,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("log", NaturalLogEvaluator.class)
// Conditional Stream Evaluators
.withFunctionName("if", IfThenElseEvaluator.class)
.withFunctionName("analyze", AnalyzeEvaluator.class)
;
// This pulls all the overrides and additions from the config

View File

@ -24,11 +24,13 @@ import java.util.ArrayList;
import java.util.List;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public abstract class BooleanEvaluator extends ComplexEvaluator {
protected static final long serialVersionUID = 1L;
protected StreamContext streamContext;
public BooleanEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
@ -46,6 +48,11 @@ public abstract class BooleanEvaluator extends ComplexEvaluator {
return results;
}
/**
 * Stores the supplied stream context in this evaluator's {@code streamContext} field.
 *
 * NOTE(review): this method only assigns the field; nothing here forwards the context to
 * any nested evaluators — confirm whether wrapped evaluators receive it elsewhere.
 */
public void setStreamContext(StreamContext streamContext) {
this.streamContext = streamContext;
}
public interface Checker {
default boolean isNullAllowed(){
return false;

View File

@ -24,11 +24,13 @@ import java.util.ArrayList;
import java.util.List;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public abstract class ConditionalEvaluator extends ComplexEvaluator {
protected static final long serialVersionUID = 1L;
protected StreamContext streamContext;
public ConditionalEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
@ -43,6 +45,10 @@ public abstract class ConditionalEvaluator extends ComplexEvaluator {
return results;
}
/**
 * Stores the supplied stream context in this evaluator's {@code streamContext} field.
 *
 * NOTE(review): this method only assigns the field; nothing here forwards the context to
 * any nested evaluators — confirm whether wrapped evaluators receive it elsewhere.
 */
public void setStreamContext(StreamContext streamContext) {
this.streamContext = streamContext;
}
public interface Checker {
default boolean isNullAllowed(){
return false;

View File

@ -26,11 +26,13 @@ import java.util.List;
import java.util.Locale;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public abstract class NumberEvaluator extends ComplexEvaluator {
protected static final long serialVersionUID = 1L;
protected StreamContext streamContext;
public NumberEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
@ -39,6 +41,10 @@ public abstract class NumberEvaluator extends ComplexEvaluator {
// restrict result to a Number
public abstract Number evaluate(Tuple tuple) throws IOException;
/**
 * Stores the supplied stream context in this evaluator's {@code streamContext} field.
 *
 * NOTE(review): this method only assigns the field; nothing here forwards the context to
 * any nested evaluators — confirm whether wrapped evaluators receive it elsewhere.
 */
public void setStreamContext(StreamContext context) {
this.streamContext = context;
}
public List<BigDecimal> evaluateAll(final Tuple tuple) throws IOException {
// evaluate each and confirm they are all either null or numeric
List<BigDecimal> results = new ArrayList<BigDecimal>();

View File

@ -21,9 +21,16 @@ package org.apache.solr.client.solrj.io.eval;
import java.util.UUID;
import org.apache.solr.client.solrj.io.stream.StreamContext;
public abstract class SimpleEvaluator implements StreamEvaluator {
private static final long serialVersionUID = 1L;
protected UUID nodeId = UUID.randomUUID();
protected StreamContext streamContext;
public void setStreamContext(StreamContext streamContext) {
this.streamContext = streamContext;
}
}

View File

@ -23,8 +23,10 @@ import java.io.IOException;
import java.io.Serializable;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
/**
 * An expressible, serializable unit that computes a value from a {@link Tuple} and can be
 * handed a {@link StreamContext}.
 */
public interface StreamEvaluator extends Expressible, Serializable {
/**
 * Computes this evaluator's value for the given tuple.
 *
 * @param tuple the tuple to evaluate against
 * @return the computed value
 * @throws IOException if evaluation fails
 */
Object evaluate(final Tuple tuple) throws IOException;
/**
 * Supplies the stream context to this evaluator.
 *
 * @param streamContext the context of the stream the evaluator is used in
 */
void setStreamContext(StreamContext streamContext);
}

View File

@ -59,7 +59,6 @@ public class CartesianProductStream extends TupleStream implements Expressible {
List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
List<StreamExpressionParameter> evaluateAsExpressions = factory.getOperandsOfType(expression, StreamExpressionValue.class);
StreamExpressionNamedParameter orderByExpression = factory.getNamedOperand(expression, "productSort");
// validate expression contains only what we want.
if(expression.getParameters().size() != streamExpressions.size() + evaluateAsExpressions.size() + (null == orderByExpression ? 0 : 1)){
throw new IOException(String.format(Locale.ROOT,"Invalid %s expression %s - unknown operands found", functionName, expression));
@ -259,6 +258,9 @@ public class CartesianProductStream extends TupleStream implements Expressible {
/**
 * Hands the stream context to the wrapped stream first and then to the evaluator behind
 * each configured named evaluator.
 */
public void setStreamContext(StreamContext context) {
  this.stream.setStreamContext(context);

  for (NamedEvaluator namedEvaluator : evaluators) {
    namedEvaluator.getEvaluator().setStreamContext(context);
  }
}
public List<TupleStream> children() {

View File

@ -22,6 +22,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
@ -213,6 +214,11 @@ public class SelectStream extends TupleStream implements Expressible {
/**
 * Hands the stream context to the wrapped stream first and then to every evaluator that
 * is a key of {@code selectedEvaluators}.
 */
public void setStreamContext(StreamContext context) {
  this.stream.setStreamContext(context);

  for (StreamEvaluator selected : selectedEvaluators.keySet()) {
    selected.setStreamContext(context);
  }
}
public List<TupleStream> children() {

View File

@ -61,6 +61,7 @@ import org.apache.solr.cloud.AbstractDistribZkTestBase;
import org.apache.solr.cloud.SolrCloudTestCase;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.handler.AnalyzeEvaluator;
import org.junit.Assume;
import org.junit.Before;
import org.junit.BeforeClass;
@ -379,7 +380,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
stream = factory.constructStream("sort(search(" + COLLECTIONORALIAS + ", q=*:*, fl=\"id,a_s,a_i,a_f\", sort=\"a_f asc\"), by=\"a_i asc\")");
tuples = getTuples(stream);
assert(tuples.size() == 6);
assertOrder(tuples, 0,1,5,2,3,4);
assertOrder(tuples, 0, 1, 5, 2, 3, 4);
// Basic test desc
stream = factory.constructStream("sort(search(" + COLLECTIONORALIAS + ", q=*:*, fl=\"id,a_s,a_i,a_f\", sort=\"a_f asc\"), by=\"a_i desc\")");
@ -1908,7 +1909,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
stream = new InnerJoinStream(expression, factory);
tuples = getTuples(stream);
assert(tuples.size() == 8);
assertOrder(tuples, 1,1,15,15,3,4,5,7);
assertOrder(tuples, 1, 1, 15, 15, 3, 4, 5, 7);
// Basic desc
expression = StreamExpressionParser.parse("innerJoin("
@ -1922,9 +1923,9 @@ public class StreamExpressionTest extends SolrCloudTestCase {
// Results in both searches, no join matches
expression = StreamExpressionParser.parse("innerJoin("
+ "search(" + COLLECTIONORALIAS + ", q=\"side_s:left\", fl=\"id,join1_i,join2_s,ident_s\", sort=\"ident_s asc\"),"
+ "search(" + COLLECTIONORALIAS + ", q=\"side_s:right\", fl=\"id,join1_i,join2_s,ident_s\", sort=\"ident_s asc\", aliases=\"id=right.id, join1_i=right.join1_i, join2_s=right.join2_s, ident_s=right.ident_s\"),"
+ "on=\"ident_s=right.ident_s\")");
+ "search(" + COLLECTIONORALIAS + ", q=\"side_s:left\", fl=\"id,join1_i,join2_s,ident_s\", sort=\"ident_s asc\"),"
+ "search(" + COLLECTIONORALIAS + ", q=\"side_s:right\", fl=\"id,join1_i,join2_s,ident_s\", sort=\"ident_s asc\", aliases=\"id=right.id, join1_i=right.join1_i, join2_s=right.join2_s, ident_s=right.ident_s\"),"
+ "on=\"ident_s=right.ident_s\")");
stream = new InnerJoinStream(expression, factory);
tuples = getTuples(stream);
assert(tuples.size() == 0);
@ -1938,7 +1939,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
tuples = getTuples(stream);
assert(tuples.size() == 8);
assertOrder(tuples, 1,1,15,15,3,4,5,7);
assertOrder(tuples, 1, 1, 15, 15, 3, 4, 5, 7);
}
@ -4347,6 +4348,126 @@ public class StreamExpressionTest extends SolrCloudTestCase {
CollectionAdminRequest.deleteCollection("checkpointCollection").process(cluster.getSolrClient());
}
/**
 * Indexes one document whose {@code test_t} value is {@code "l b c d c"} and checks the
 * {@code analyze} evaluator through the {@code /stream} handler: expanded by
 * {@code cartesianProduct} (two-operand and one-operand forms), against tuples that lack
 * the field, and as a list annotation added by {@code select}.
 */
@Test
public void testAnalyzeEvaluator() throws Exception {
  UpdateRequest updateRequest = new UpdateRequest();
  updateRequest.add(id, "1", "test_t", "l b c d c");
  updateRequest.commit(cluster.getSolrClient(), COLLECTIONORALIAS);

  // Terms the assertions expect for "l b c d c", in order.
  final String[] expectedTerms = {"l", "b", "c", "d", "c"};

  SolrClientCache cache = new SolrClientCache();
  try {
    String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;

    // Two-operand form: one output tuple per term.
    String expr = "cartesianProduct(search("+COLLECTIONORALIAS+", q=\"*:*\", fl=\"id, test_t\", sort=\"id desc\"), analyze(test_t, test_t) as test_t)";
    List<Tuple> tuples = runAnalyzeExpression(url, expr);
    assertTrue(tuples.size() == 5);
    for (int i = 0; i < expectedTerms.length; i++) {
      Tuple t = tuples.get(i);
      assertTrue(t.getString("test_t").equals(expectedTerms[i]));
      assertTrue(t.getString("id").equals("1"));
    }

    //Try with single param
    expr = "cartesianProduct(search("+COLLECTIONORALIAS+", q=\"*:*\", fl=\"id, test_t\", sort=\"id desc\"), analyze(test_t) as test_t)";
    tuples = runAnalyzeExpression(url, expr);
    assertTrue(tuples.size() == 5);
    for (int i = 0; i < expectedTerms.length; i++) {
      Tuple t = tuples.get(i);
      assertTrue(t.getString("test_t").equals(expectedTerms[i]));
      assertTrue(t.getString("id").equals("1"));
    }

    //Try with null in the test_t field
    expr = "cartesianProduct(search("+COLLECTIONORALIAS+", q=\"*:*\", fl=\"id\", sort=\"id desc\"), analyze(test_t, test_t) as test_t)";
    tuples = runAnalyzeExpression(url, expr);
    assertTrue(tuples.size() == 1);

    //Test annotating tuple
    expr = "select(search("+COLLECTIONORALIAS+", q=\"*:*\", fl=\"id, test_t\", sort=\"id desc\"), analyze(test_t, test_t) as test1_t)";
    tuples = runAnalyzeExpression(url, expr);
    assertTrue(tuples.size() == 1);
    List<?> terms = (List<?>) tuples.get(0).get("test1_t");
    for (int i = 0; i < expectedTerms.length; i++) {
      assertTrue(terms.get(i).equals(expectedTerms[i]));
    }
  } finally {
    cache.close();
  }
}

/** Sends {@code expr} to the {@code /stream} handler at {@code url} and returns the tuples read back. */
private List<Tuple> runAnalyzeExpression(String url, String expr) throws Exception {
  ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
  paramsLoc.set("expr", expr);
  paramsLoc.set("qt", "/stream");

  SolrStream solrStream = new SolrStream(url, paramsLoc);
  StreamContext context = new StreamContext();
  solrStream.setStreamContext(context);
  return getTuples(solrStream);
}
@Test
public void testExecutorStream() throws Exception {
CollectionAdminRequest.createCollection("workQueue", "conf", 2, 1).process(cluster.getSolrClient());