mirror of https://github.com/apache/lucene.git
SOLR-10351: Add analyze Stream Evaluator to support streaming NLP
This commit is contained in:
parent
b54b08db7d
commit
f5b7738da8
|
@ -0,0 +1,111 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
/**
|
||||
*
|
||||
*/
|
||||
package org.apache.solr.handler;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.StreamContext;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
import org.apache.solr.client.solrj.io.eval.*;
|
||||
|
||||
import org.apache.solr.common.SolrException;
|
||||
import org.apache.lucene.analysis.*;
|
||||
import org.apache.solr.core.SolrCore;
|
||||
|
||||
public class AnalyzeEvaluator extends SimpleEvaluator {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
private String fieldName;
|
||||
private String analyzerField;
|
||||
private Analyzer analyzer;
|
||||
|
||||
public AnalyzeEvaluator(String _fieldName, String _analyzerField) {
|
||||
init(_fieldName, _analyzerField);
|
||||
}
|
||||
|
||||
public AnalyzeEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
String _fieldName = factory.getValueOperand(expression, 0);
|
||||
String _analyzerField = factory.getValueOperand(expression, 1);
|
||||
init(_fieldName, _analyzerField);
|
||||
}
|
||||
|
||||
public void setStreamContext(StreamContext context) {
|
||||
Object solrCoreObj = context.get("solr-core");
|
||||
if (solrCoreObj == null || !(solrCoreObj instanceof SolrCore) ) {
|
||||
throw new SolrException(SolrException.ErrorCode.INVALID_STATE, "StreamContext must have SolrCore in solr-core key");
|
||||
}
|
||||
SolrCore solrCore = (SolrCore) solrCoreObj;
|
||||
|
||||
analyzer = solrCore.getLatestSchema().getFieldType(analyzerField).getIndexAnalyzer();
|
||||
}
|
||||
|
||||
private void init(String fieldName, String analyzerField) {
|
||||
this.fieldName = fieldName;
|
||||
if(analyzerField == null) {
|
||||
this.analyzerField = fieldName;
|
||||
} else {
|
||||
this.analyzerField = analyzerField;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object evaluate(Tuple tuple) throws IOException {
|
||||
String value = tuple.getString(fieldName);
|
||||
if(value == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
TokenStream tokenStream = analyzer.tokenStream(analyzerField, value);
|
||||
CharTermAttribute termAtt = tokenStream.getAttribute(CharTermAttribute.class);
|
||||
tokenStream.reset();
|
||||
List<String> tokens = new ArrayList();
|
||||
while (tokenStream.incrementToken()) {
|
||||
tokens.add(termAtt.toString());
|
||||
}
|
||||
|
||||
tokenStream.end();
|
||||
tokenStream.close();
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
@Override
|
||||
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
|
||||
return new StreamExpressionValue(fieldName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation toExplanation(StreamFactory factory) throws IOException {
|
||||
return new Explanation(nodeId.toString())
|
||||
.withExpressionType(ExpressionType.EVALUATOR)
|
||||
.withImplementingClass(getClass().getName())
|
||||
.withExpression(toExpression(factory).toString());
|
||||
}
|
||||
|
||||
}
|
|
@ -209,6 +209,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
|
|||
.withFunctionName("log", NaturalLogEvaluator.class)
|
||||
// Conditional Stream Evaluators
|
||||
.withFunctionName("if", IfThenElseEvaluator.class)
|
||||
.withFunctionName("analyze", AnalyzeEvaluator.class)
|
||||
;
|
||||
|
||||
// This pulls all the overrides and additions from the config
|
||||
|
|
|
@ -24,11 +24,13 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.StreamContext;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public abstract class BooleanEvaluator extends ComplexEvaluator {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
protected StreamContext streamContext;
|
||||
|
||||
public BooleanEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
@ -45,7 +47,12 @@ public abstract class BooleanEvaluator extends ComplexEvaluator {
|
|||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
public void setStreamContext(StreamContext streamContext) {
|
||||
this.streamContext = streamContext;
|
||||
}
|
||||
|
||||
|
||||
public interface Checker {
|
||||
default boolean isNullAllowed(){
|
||||
return false;
|
||||
|
|
|
@ -24,11 +24,13 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.StreamContext;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public abstract class ConditionalEvaluator extends ComplexEvaluator {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
protected StreamContext streamContext;
|
||||
|
||||
public ConditionalEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
@ -42,6 +44,10 @@ public abstract class ConditionalEvaluator extends ComplexEvaluator {
|
|||
|
||||
return results;
|
||||
}
|
||||
|
||||
public void setStreamContext(StreamContext streamContext) {
|
||||
this.streamContext = streamContext;
|
||||
}
|
||||
|
||||
public interface Checker {
|
||||
default boolean isNullAllowed(){
|
||||
|
|
|
@ -26,11 +26,13 @@ import java.util.List;
|
|||
import java.util.Locale;
|
||||
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.StreamContext;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public abstract class NumberEvaluator extends ComplexEvaluator {
|
||||
protected static final long serialVersionUID = 1L;
|
||||
protected StreamContext streamContext;
|
||||
|
||||
public NumberEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
super(expression, factory);
|
||||
|
@ -38,6 +40,10 @@ public abstract class NumberEvaluator extends ComplexEvaluator {
|
|||
|
||||
// restrict result to a Number
|
||||
public abstract Number evaluate(Tuple tuple) throws IOException;
|
||||
|
||||
public void setStreamContext(StreamContext context) {
|
||||
this.streamContext = context;
|
||||
}
|
||||
|
||||
public List<BigDecimal> evaluateAll(final Tuple tuple) throws IOException {
|
||||
// evaluate each and confirm they are all either null or numeric
|
||||
|
|
|
@ -21,9 +21,16 @@ package org.apache.solr.client.solrj.io.eval;
|
|||
|
||||
import java.util.UUID;
|
||||
|
||||
import org.apache.solr.client.solrj.io.stream.StreamContext;
|
||||
|
||||
public abstract class SimpleEvaluator implements StreamEvaluator {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
protected UUID nodeId = UUID.randomUUID();
|
||||
protected StreamContext streamContext;
|
||||
|
||||
public void setStreamContext(StreamContext streamContext) {
|
||||
this.streamContext = streamContext;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -23,8 +23,10 @@ import java.io.IOException;
|
|||
import java.io.Serializable;
|
||||
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.stream.StreamContext;
|
||||
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
|
||||
|
||||
public interface StreamEvaluator extends Expressible, Serializable {
|
||||
Object evaluate(final Tuple tuple) throws IOException;
|
||||
void setStreamContext(StreamContext streamContext);
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ public class CartesianProductStream extends TupleStream implements Expressible {
|
|||
private List<NamedEvaluator> evaluators;
|
||||
private StreamComparator orderBy;
|
||||
|
||||
// Used to contain the sorted queue of generated tuples
|
||||
// Used to contain the sorted queue of generated tuples
|
||||
private LinkedList<Tuple> generatedTuples;
|
||||
|
||||
public CartesianProductStream(StreamExpression expression,StreamFactory factory) throws IOException {
|
||||
|
@ -59,7 +59,6 @@ public class CartesianProductStream extends TupleStream implements Expressible {
|
|||
List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
|
||||
List<StreamExpressionParameter> evaluateAsExpressions = factory.getOperandsOfType(expression, StreamExpressionValue.class);
|
||||
StreamExpressionNamedParameter orderByExpression = factory.getNamedOperand(expression, "productSort");
|
||||
|
||||
// validate expression contains only what we want.
|
||||
if(expression.getParameters().size() != streamExpressions.size() + evaluateAsExpressions.size() + (null == orderByExpression ? 0 : 1)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid %s expression %s - unknown operands found", functionName, expression));
|
||||
|
@ -259,6 +258,9 @@ public class CartesianProductStream extends TupleStream implements Expressible {
|
|||
|
||||
public void setStreamContext(StreamContext context) {
|
||||
this.stream.setStreamContext(context);
|
||||
for(NamedEvaluator evaluator : evaluators) {
|
||||
evaluator.getEvaluator().setStreamContext(context);
|
||||
}
|
||||
}
|
||||
|
||||
public List<TupleStream> children() {
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.HashMap;
|
|||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.comp.StreamComparator;
|
||||
|
@ -213,6 +214,11 @@ public class SelectStream extends TupleStream implements Expressible {
|
|||
|
||||
public void setStreamContext(StreamContext context) {
|
||||
this.stream.setStreamContext(context);
|
||||
Set<StreamEvaluator> evaluators = selectedEvaluators.keySet();
|
||||
|
||||
for(StreamEvaluator evaluator : evaluators) {
|
||||
evaluator.setStreamContext(context);
|
||||
}
|
||||
}
|
||||
|
||||
public List<TupleStream> children() {
|
||||
|
|
|
@ -61,6 +61,7 @@ import org.apache.solr.cloud.AbstractDistribZkTestBase;
|
|||
import org.apache.solr.cloud.SolrCloudTestCase;
|
||||
import org.apache.solr.common.params.CommonParams;
|
||||
import org.apache.solr.common.params.ModifiableSolrParams;
|
||||
import org.apache.solr.handler.AnalyzeEvaluator;
|
||||
import org.junit.Assume;
|
||||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
|
@ -379,7 +380,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
stream = factory.constructStream("sort(search(" + COLLECTIONORALIAS + ", q=*:*, fl=\"id,a_s,a_i,a_f\", sort=\"a_f asc\"), by=\"a_i asc\")");
|
||||
tuples = getTuples(stream);
|
||||
assert(tuples.size() == 6);
|
||||
assertOrder(tuples, 0,1,5,2,3,4);
|
||||
assertOrder(tuples, 0, 1, 5, 2, 3, 4);
|
||||
|
||||
// Basic test desc
|
||||
stream = factory.constructStream("sort(search(" + COLLECTIONORALIAS + ", q=*:*, fl=\"id,a_s,a_i,a_f\", sort=\"a_f asc\"), by=\"a_i desc\")");
|
||||
|
@ -1908,7 +1909,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
stream = new InnerJoinStream(expression, factory);
|
||||
tuples = getTuples(stream);
|
||||
assert(tuples.size() == 8);
|
||||
assertOrder(tuples, 1,1,15,15,3,4,5,7);
|
||||
assertOrder(tuples, 1, 1, 15, 15, 3, 4, 5, 7);
|
||||
|
||||
// Basic desc
|
||||
expression = StreamExpressionParser.parse("innerJoin("
|
||||
|
@ -1922,9 +1923,9 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
|
||||
// Results in both searches, no join matches
|
||||
expression = StreamExpressionParser.parse("innerJoin("
|
||||
+ "search(" + COLLECTIONORALIAS + ", q=\"side_s:left\", fl=\"id,join1_i,join2_s,ident_s\", sort=\"ident_s asc\"),"
|
||||
+ "search(" + COLLECTIONORALIAS + ", q=\"side_s:right\", fl=\"id,join1_i,join2_s,ident_s\", sort=\"ident_s asc\", aliases=\"id=right.id, join1_i=right.join1_i, join2_s=right.join2_s, ident_s=right.ident_s\"),"
|
||||
+ "on=\"ident_s=right.ident_s\")");
|
||||
+ "search(" + COLLECTIONORALIAS + ", q=\"side_s:left\", fl=\"id,join1_i,join2_s,ident_s\", sort=\"ident_s asc\"),"
|
||||
+ "search(" + COLLECTIONORALIAS + ", q=\"side_s:right\", fl=\"id,join1_i,join2_s,ident_s\", sort=\"ident_s asc\", aliases=\"id=right.id, join1_i=right.join1_i, join2_s=right.join2_s, ident_s=right.ident_s\"),"
|
||||
+ "on=\"ident_s=right.ident_s\")");
|
||||
stream = new InnerJoinStream(expression, factory);
|
||||
tuples = getTuples(stream);
|
||||
assert(tuples.size() == 0);
|
||||
|
@ -1938,7 +1939,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
tuples = getTuples(stream);
|
||||
|
||||
assert(tuples.size() == 8);
|
||||
assertOrder(tuples, 1,1,15,15,3,4,5,7);
|
||||
assertOrder(tuples, 1, 1, 15, 15, 3, 4, 5, 7);
|
||||
|
||||
}
|
||||
|
||||
|
@ -4347,6 +4348,126 @@ public class StreamExpressionTest extends SolrCloudTestCase {
|
|||
CollectionAdminRequest.deleteCollection("checkpointCollection").process(cluster.getSolrClient());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAnalyzeEvaluator() throws Exception {
|
||||
|
||||
UpdateRequest updateRequest = new UpdateRequest();
|
||||
updateRequest.add(id, "1", "test_t", "l b c d c");
|
||||
updateRequest.commit(cluster.getSolrClient(), COLLECTIONORALIAS);
|
||||
|
||||
|
||||
SolrClientCache cache = new SolrClientCache();
|
||||
try {
|
||||
|
||||
String expr = "cartesianProduct(search("+COLLECTIONORALIAS+", q=\"*:*\", fl=\"id, test_t\", sort=\"id desc\"), analyze(test_t, test_t) as test_t)";
|
||||
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", expr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
|
||||
|
||||
SolrStream solrStream = new SolrStream(url, paramsLoc);
|
||||
|
||||
StreamContext context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
List<Tuple> tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 5);
|
||||
|
||||
Tuple t = tuples.get(0);
|
||||
assertTrue(t.getString("test_t").equals("l"));
|
||||
assertTrue(t.getString("id").equals("1"));
|
||||
|
||||
t = tuples.get(1);
|
||||
assertTrue(t.getString("test_t").equals("b"));
|
||||
assertTrue(t.getString("id").equals("1"));
|
||||
|
||||
|
||||
t = tuples.get(2);
|
||||
assertTrue(t.getString("test_t").equals("c"));
|
||||
assertTrue(t.getString("id").equals("1"));
|
||||
|
||||
|
||||
t = tuples.get(3);
|
||||
assertTrue(t.getString("test_t").equals("d"));
|
||||
assertTrue(t.getString("id").equals("1"));
|
||||
|
||||
t = tuples.get(4);
|
||||
assertTrue(t.getString("test_t").equals("c"));
|
||||
assertTrue(t.getString("id").equals("1"));
|
||||
|
||||
|
||||
//Try with single param
|
||||
expr = "cartesianProduct(search("+COLLECTIONORALIAS+", q=\"*:*\", fl=\"id, test_t\", sort=\"id desc\"), analyze(test_t) as test_t)";
|
||||
paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", expr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
|
||||
solrStream = new SolrStream(url, paramsLoc);
|
||||
|
||||
context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 5);
|
||||
|
||||
t = tuples.get(0);
|
||||
assertTrue(t.getString("test_t").equals("l"));
|
||||
assertTrue(t.getString("id").equals("1"));
|
||||
|
||||
t = tuples.get(1);
|
||||
assertTrue(t.getString("test_t").equals("b"));
|
||||
assertTrue(t.getString("id").equals("1"));
|
||||
|
||||
|
||||
t = tuples.get(2);
|
||||
assertTrue(t.getString("test_t").equals("c"));
|
||||
assertTrue(t.getString("id").equals("1"));
|
||||
|
||||
|
||||
t = tuples.get(3);
|
||||
assertTrue(t.getString("test_t").equals("d"));
|
||||
assertTrue(t.getString("id").equals("1"));
|
||||
|
||||
t = tuples.get(4);
|
||||
assertTrue(t.getString("test_t").equals("c"));
|
||||
assertTrue(t.getString("id").equals("1"));
|
||||
|
||||
|
||||
//Try with null in the test_t field
|
||||
expr = "cartesianProduct(search("+COLLECTIONORALIAS+", q=\"*:*\", fl=\"id\", sort=\"id desc\"), analyze(test_t, test_t) as test_t)";
|
||||
paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", expr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
|
||||
solrStream = new SolrStream(url, paramsLoc);
|
||||
|
||||
context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
|
||||
//Test annotating tuple
|
||||
expr = "select(search("+COLLECTIONORALIAS+", q=\"*:*\", fl=\"id, test_t\", sort=\"id desc\"), analyze(test_t, test_t) as test1_t)";
|
||||
paramsLoc = new ModifiableSolrParams();
|
||||
paramsLoc.set("expr", expr);
|
||||
paramsLoc.set("qt", "/stream");
|
||||
|
||||
solrStream = new SolrStream(url, paramsLoc);
|
||||
|
||||
context = new StreamContext();
|
||||
solrStream.setStreamContext(context);
|
||||
tuples = getTuples(solrStream);
|
||||
assertTrue(tuples.size() == 1);
|
||||
List l = (List)tuples.get(0).get("test1_t");
|
||||
assertTrue(l.get(0).equals("l"));
|
||||
assertTrue(l.get(1).equals("b"));
|
||||
assertTrue(l.get(2).equals("c"));
|
||||
assertTrue(l.get(3).equals("d"));
|
||||
assertTrue(l.get(4).equals("c"));
|
||||
} finally {
|
||||
cache.close();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testExecutorStream() throws Exception {
|
||||
CollectionAdminRequest.createCollection("workQueue", "conf", 2, 1).process(cluster.getSolrClient());
|
||||
|
|
Loading…
Reference in New Issue