From 9cd6437d4b21dd6d9c16688eedb5af012ea67e86 Mon Sep 17 00:00:00 2001 From: Joel Bernstein Date: Thu, 29 Sep 2016 17:44:20 -0400 Subject: [PATCH] SOLR-9258: Optimizing, storing and deploying AI models with Streaming Expressions --- .../apache/solr/handler/ClassifyStream.java | 230 ++++++++++++++++++ .../apache/solr/handler/StreamHandler.java | 12 +- .../solr/client/solrj/io/ModelCache.java | 154 ++++++++++++ .../client/solrj/io/stream/ModelStream.java | 200 +++++++++++++++ .../client/solrj/io/stream/StreamContext.java | 11 + .../client/solrj/io/stream/TopicStream.java | 10 +- .../solrj/solr/configsets/ml/conf/schema.xml | 2 +- .../solrj/io/stream/StreamExpressionTest.java | 145 ++++++++++- 8 files changed, 758 insertions(+), 6 deletions(-) create mode 100644 solr/core/src/java/org/apache/solr/handler/ClassifyStream.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java diff --git a/solr/core/src/java/org/apache/solr/handler/ClassifyStream.java b/solr/core/src/java/org/apache/solr/handler/ClassifyStream.java new file mode 100644 index 00000000000..6b0a02a3fad --- /dev/null +++ b/solr/core/src/java/org/apache/solr/handler/ClassifyStream.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.handler; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Set; +import java.util.ArrayList; +import java.util.Map; +import java.util.Locale; + +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.comp.StreamComparator; +import org.apache.solr.client.solrj.io.stream.StreamContext; +import org.apache.solr.client.solrj.io.stream.TupleStream; +import org.apache.solr.client.solrj.io.stream.expr.Explanation; +import org.apache.solr.client.solrj.io.stream.expr.Expressible; +import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; +import org.apache.solr.common.SolrException; +import org.apache.solr.core.SolrCore; +import org.apache.lucene.analysis.*; + +/** + * The classify expression retrieves a model trained by the train expression and uses it to classify documents from a stream + * Syntax: + * classif(model(...), anyStream(...), field="body") + **/ + +public class ClassifyStream extends TupleStream implements Expressible { + private TupleStream docStream; + private TupleStream modelStream; + + private String field; + private String analyzerField; + private Tuple modelTuple; + + Analyzer analyzer; + private Map termToIndex; + private List idfs; + private List modelWeights; + + public ClassifyStream(StreamExpression expression, StreamFactory factory) throws IOException { + List streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class); + if (streamExpressions.size() != 2) { + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting two stream but found %d",expression, streamExpressions.size())); + } + + modelStream = factory.constructStream(streamExpressions.get(0)); + docStream = factory.constructStream(streamExpressions.get(1)); + + StreamExpressionNamedParameter fieldParameter = factory.getNamedOperand(expression, "field"); + if (fieldParameter == null) { + throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - field parameter must be specified",expression, streamExpressions.size())); + } + analyzerField = field = fieldParameter.getParameter().toString(); + + StreamExpressionNamedParameter analyzerFieldParameter = factory.getNamedOperand(expression, "analyzerField"); + if (analyzerFieldParameter != null) { + analyzerField = analyzerFieldParameter.getParameter().toString(); + } + } + + @Override + public void setStreamContext(StreamContext context) { + Object solrCoreObj = context.get("solr-core"); + if (solrCoreObj == null || !(solrCoreObj instanceof SolrCore) ) { + throw new SolrException(SolrException.ErrorCode.INVALID_STATE, "StreamContext must have SolrCore in solr-core key"); + } + SolrCore solrCore = (SolrCore) solrCoreObj; + analyzer = solrCore.getLatestSchema().getFieldType(analyzerField).getIndexAnalyzer(); + + this.docStream.setStreamContext(context); + this.modelStream.setStreamContext(context); + } + + @Override + public List children() { + List l = new ArrayList<>(); + l.add(docStream); + l.add(modelStream); + return l; + } + + @Override + public void open() throws IOException { + this.docStream.open(); + this.modelStream.open(); + } + + @Override + public void close() throws IOException { + this.docStream.close(); + this.modelStream.close(); + } + + @Override + public Tuple read() throws IOException { + if (modelTuple == null) { + + modelTuple = modelStream.read(); + if (modelTuple == null || modelTuple.EOF) { + throw new IOException("Model tuple not found for classify stream!"); + } + + termToIndex = new HashMap<>(); + + List terms = modelTuple.getStrings("terms_ss"); + + for (int i = 0; i < terms.size(); i++) { + termToIndex.put(terms.get(i), i); + } + + idfs = modelTuple.getDoubles("idfs_ds"); + modelWeights = modelTuple.getDoubles("weights_ds"); + } + + Tuple docTuple = docStream.read(); + if (docTuple.EOF) return docTuple; + + String text = docTuple.getString(field); + + double tfs[] = new double[termToIndex.size()]; + + TokenStream tokenStream = analyzer.tokenStream(analyzerField, text); + CharTermAttribute termAtt = tokenStream.getAttribute(CharTermAttribute.class); + tokenStream.reset(); + + int termCount = 0; + while (tokenStream.incrementToken()) { + termCount++; + if (termToIndex.containsKey(termAtt.toString())) { + tfs[termToIndex.get(termAtt.toString())]++; + } + } + + tokenStream.end(); + tokenStream.close(); + + List tfidfs = new ArrayList<>(termToIndex.size()); + tfidfs.add(1.0); + for (int i = 0; i < tfs.length; i++) { + if (tfs[i] != 0) { + tfs[i] = 1 + Math.log(tfs[i]); + } + tfidfs.add(this.idfs.get(i) * tfs[i]); + } + + double total = 0.0; + for (int i = 0; i < tfidfs.size(); i++) { + total += tfidfs.get(i) * modelWeights.get(i); + } + + double score = total * ((float) (1.0 / Math.sqrt(termCount))); + double positiveProb = sigmoid(total); + + docTuple.put("probability_d", positiveProb); + docTuple.put("score_d", score); + + return docTuple; + } + + private double sigmoid(double in) { + double d = 1.0 / (1+Math.exp(-in)); + return d; + } + + @Override + public StreamComparator getStreamSort() { + return null; + } + + @Override + public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException { + return toExpression(factory, true); + } + + private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException { + // function name + StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass())); + + if (includeStreams) { + if (docStream instanceof Expressible && modelStream instanceof Expressible) { + expression.addParameter(((Expressible)modelStream).toExpression(factory)); + expression.addParameter(((Expressible)docStream).toExpression(factory)); + } else { + throw new IOException("This ClassifyStream contains a non-expressible TupleStream - it cannot be converted to an expression"); + } + } + + expression.addParameter(new StreamExpressionNamedParameter("field", field)); + expression.addParameter(new StreamExpressionNamedParameter("analyzerField", analyzerField)); + + return expression; + } + + @Override + public Explanation toExplanation(StreamFactory factory) throws IOException { + StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString()); + + explanation.setFunctionName(factory.getFunctionName(this.getClass())); + explanation.setImplementingClass(this.getClass().getName()); + explanation.setExpressionType(Explanation.ExpressionType.STREAM_DECORATOR); + explanation.setExpression(toExpression(factory, false).toString()); + + explanation.addChild(docStream.toExplanation(factory)); + explanation.addChild(modelStream.toExplanation(factory)); + + return explanation; + } +} \ No newline at end of file diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java index d190e508884..6dbfdbe8fd1 100644 --- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java +++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import org.apache.solr.client.solrj.io.ModelCache; import org.apache.solr.client.solrj.io.SolrClientCache; import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.comp.StreamComparator; @@ -36,10 +37,10 @@ import org.apache.solr.client.solrj.io.ops.GroupOperation; import org.apache.solr.client.solrj.io.ops.ReplaceOperation; import org.apache.solr.client.solrj.io.stream.*; import org.apache.solr.client.solrj.io.stream.expr.Explanation; +import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType; import org.apache.solr.client.solrj.io.stream.expr.Expressible; import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; -import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType; import org.apache.solr.client.solrj.io.stream.metrics.CountMetric; import org.apache.solr.client.solrj.io.stream.metrics.MaxMetric; import org.apache.solr.client.solrj.io.stream.metrics.MeanMetric; @@ -64,6 +65,7 @@ import org.slf4j.LoggerFactory; public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, PermissionNameProvider { static SolrClientCache clientCache = new SolrClientCache(); + static ModelCache modelCache = null; private StreamFactory streamFactory = new StreamFactory(); private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private String coreName; @@ -96,6 +98,9 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, defaultZkhost = core.getCoreDescriptor().getCoreContainer().getZkController().getZkServerAddress(); streamFactory.withCollectionZkHost(defaultCollection, defaultZkhost); streamFactory.withDefaultZkHost(defaultZkhost); + modelCache = new ModelCache(250, + defaultZkhost, + clientCache); } streamFactory @@ -130,6 +135,8 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("gatherNodes", GatherNodesStream.class) .withFunctionName("select", SelectStream.class) .withFunctionName("scoreNodes", ScoreNodesStream.class) + .withFunctionName("model", ModelStream.class) + .withFunctionName("classify", ClassifyStream.class) // metrics .withFunctionName("min", MinMetric.class) @@ -197,7 +204,9 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, context.workerID = worker; context.numWorkers = numWorkers; context.setSolrClientCache(clientCache); + context.setModelCache(modelCache); context.put("core", this.coreName); + context.put("solr-core", req.getCore()); tupleStream.setStreamContext(context); // if asking for explanation then go get it @@ -454,5 +463,4 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, return tuple; } } - } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java new file mode 100644 index 00000000000..4fe3d8af790 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/ModelCache.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.solr.client.solrj.io; + +import java.io.IOException; +import java.io.Serializable; +import java.lang.invoke.MethodHandles; +import java.util.Collections; +import java.util.Date; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.HashMap; + +import org.apache.solr.client.solrj.SolrClient; +import org.apache.solr.client.solrj.impl.CloudSolrClient; +import org.apache.solr.client.solrj.impl.HttpSolrClient; +import org.apache.solr.client.solrj.io.stream.CloudSolrStream; +import org.apache.solr.client.solrj.io.stream.StreamContext; +import org.apache.solr.client.solrj.io.stream.TopicStream; +import org.apache.solr.common.params.ModifiableSolrParams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * The Model cache keeps a local in-memory copy of models + */ + +public class ModelCache implements Serializable { + + private final LRU models; + private String defaultZkHost; + private SolrClientCache solrClientCache; + + public ModelCache(int size, + String defaultZkHost, + SolrClientCache solrClientCache) { + this.models = new LRU(size); + this.defaultZkHost = defaultZkHost; + this.solrClientCache = solrClientCache; + } + + public Tuple getModel(String collection, + String modelID, + long checkMillis) throws IOException { + return getModel(defaultZkHost, collection, modelID, checkMillis); + } + + public Tuple getModel(String zkHost, + String collection, + String modelID, + long checkMillis) throws IOException { + Model model = null; + long currentTime = new Date().getTime(); + synchronized (this) { + model = models.get(modelID); + if(model != null && ((currentTime - model.getLastChecked()) <= checkMillis)) { + return model.getTuple(); + } + + if(model != null){ + //model is expired + models.remove(modelID); + } + } + + //Model is not in cache or has expired so fetch the model + ModifiableSolrParams params = new ModifiableSolrParams(); + params.set("q","name_s:"+modelID); + params.set("fl", "terms_ss, idfs_ds, weights_ds, iteration_i, _version_"); + params.set("sort", "iteration_i desc"); + StreamContext streamContext = new StreamContext(); + streamContext.setSolrClientCache(solrClientCache); + CloudSolrStream stream = new CloudSolrStream(zkHost, collection, params); + stream.setStreamContext(streamContext); + Tuple tuple = null; + try { + stream.open(); + tuple = stream.read(); + if (tuple.EOF) { + return null; + } + } finally { + stream.close(); + } + + synchronized (this) { + //check again to see if another thread has updated the same model + Model m = models.get(modelID); + if (m != null) { + Tuple t = m.getTuple(); + long v = t.getLong("_version_"); + if (v >= tuple.getLong("_version_")) { + return t; + } else { + models.put(modelID, new Model(tuple, currentTime)); + return tuple; + } + } else { + models.put(modelID, new Model(tuple, currentTime)); + return tuple; + } + } + } + + private static class Model { + private Tuple tuple; + private long lastChecked; + + public Model(Tuple tuple, long lastChecked) { + this.tuple = tuple; + this.lastChecked = lastChecked; + } + + public Tuple getTuple() { + return tuple; + } + + public long getLastChecked() { + return lastChecked; + } + } + + private static class LRU extends LinkedHashMap { + + private int maxSize; + + public LRU(int maxSize) { + this.maxSize = maxSize; + } + + public boolean removeEldestEntry(Map.Entry eldest) { + if(size()> maxSize) { + return true; + } else { + return false; + } + } + } +} diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java new file mode 100644 index 00000000000..70b740d5d81 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ModelStream.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.stream; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.apache.solr.client.solrj.io.ModelCache; +import org.apache.solr.client.solrj.io.SolrClientCache; +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.comp.StreamComparator; +import org.apache.solr.client.solrj.io.stream.expr.Explanation; +import org.apache.solr.client.solrj.io.stream.expr.Expressible; +import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; + +/** +* The ModelStream retrieves a stored model from a Solr Cloud collection. +* +* Syntax: model(collection, id="modelID") +**/ + +public class ModelStream extends TupleStream implements Expressible { + + private static final long serialVersionUID = 1; + + protected String zkHost; + protected String collection; + protected String modelID; + protected ModelCache modelCache; + protected SolrClientCache solrClientCache; + protected Tuple model; + protected long cacheMillis; + + public ModelStream(String zkHost, + String collectionName, + String modelID, + long cacheMillis) throws IOException { + + init(collectionName, zkHost, modelID, cacheMillis); + } + + + public ModelStream(StreamExpression expression, StreamFactory factory) throws IOException{ + // grab all parameters out + String collectionName = factory.getValueOperand(expression, 0); + List namedParams = factory.getNamedOperands(expression); + StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost"); + + // Collection Name + if(null == collectionName){ + throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression)); + } + + // Named parameters - passed directly to solr as solrparams + if(0 == namedParams.size()){ + throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression)); + } + + Map params = new HashMap(); + for(StreamExpressionNamedParameter namedParam : namedParams){ + if(!namedParam.getName().equals("zkHost")) { + params.put(namedParam.getName(), namedParam.getParameter().toString().trim()); + } + } + + String modelID = params.get("id"); + if (modelID == null) { + throw new IOException("id param cannot be null for ModelStream"); + } + + long cacheMillis = 300000; + String cacheMillisParam = params.get("cacheMillis"); + + if(cacheMillisParam != null) { + cacheMillis = Long.parseLong(cacheMillisParam); + } + + // zkHost, optional - if not provided then will look into factory list to get + String zkHost = null; + if(null == zkHostExpression) { + zkHost = factory.getCollectionZkHost(collectionName); + } + else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){ + zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue(); + } + + if (zkHost == null) { + zkHost = factory.getDefaultZkHost(); + } + + if(null == zkHost){ + throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName)); + } + + // We've got all the required items + init(collectionName, zkHost, modelID, cacheMillis); + } + + @Override + public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException { + return toExpression(factory, true); + } + + private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException { + // function name + StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass())); + // collection + expression.addParameter(collection); + + // zkHost + expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost)); + expression.addParameter(new StreamExpressionNamedParameter("id", modelID)); + expression.addParameter(new StreamExpressionNamedParameter("cacheMillis", Long.toString(cacheMillis))); + + return expression; + } + + private void init(String collectionName, + String zkHost, + String modelID, + long cacheMillis) throws IOException { + this.zkHost = zkHost; + this.collection = collectionName; + this.modelID = modelID; + this.cacheMillis = cacheMillis; + } + + public void setStreamContext(StreamContext context) { + this.solrClientCache = context.getSolrClientCache(); + this.modelCache = context.getModelCache(); + } + + public void open() throws IOException { + this.model = modelCache.getModel(collection, modelID, cacheMillis); + } + + public List children() { + List l = new ArrayList(); + return l; + } + + public void close() throws IOException { + + } + + /** Return the stream sort - ie, the order in which records are returned */ + public StreamComparator getStreamSort(){ + return null; + } + + @Override + public Explanation toExplanation(StreamFactory factory) throws IOException { + StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString()); + explanation.setFunctionName(factory.getFunctionName(this.getClass())); + explanation.setImplementingClass(this.getClass().getName()); + explanation.setExpressionType(Explanation.ExpressionType.MACHINE_LEARNING_MODEL); + explanation.setExpression(toExpression(factory).toString()); + + return explanation; + } + + public Tuple read() throws IOException { + Tuple tuple = null; + + if(model != null) { + tuple = model; + model = null; + } else { + Map map = new HashMap(); + map.put("EOF", true); + tuple = new Tuple(map); + } + + return tuple; + } +} diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java index 8ca808feb68..6cbf09053fe 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java @@ -19,6 +19,8 @@ package org.apache.solr.client.solrj.io.stream; import java.io.Serializable; import java.util.Map; import java.util.HashMap; + +import org.apache.solr.client.solrj.io.ModelCache; import org.apache.solr.client.solrj.io.SolrClientCache; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; @@ -37,6 +39,7 @@ public class StreamContext implements Serializable{ public int workerID; public int numWorkers; private SolrClientCache clientCache; + private ModelCache modelCache; private StreamFactory streamFactory; public Object get(Object key) { @@ -55,10 +58,18 @@ public class StreamContext implements Serializable{ this.clientCache = clientCache; } + public void setModelCache(ModelCache modelCache) { + this.modelCache = modelCache; + } + public SolrClientCache getSolrClientCache() { return this.clientCache; } + public ModelCache getModelCache() { + return this.modelCache; + } + public void setStreamFactory(StreamFactory streamFactory) { this.streamFactory = streamFactory; } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java index 97317a05698..d81391d4210 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TopicStream.java @@ -72,6 +72,7 @@ public class TopicStream extends CloudSolrStream implements Expressible { private long count; private int runCount; + private boolean initialRun = true; private String id; protected long checkpointEvery; private Map checkpoints = new HashMap(); @@ -350,9 +351,14 @@ public class TopicStream extends CloudSolrStream implements Expressible { } public void close() throws IOException { - runCount = 0; try { - persistCheckpoints(); + + if (initialRun || runCount > 0) { + persistCheckpoints(); + initialRun = false; + runCount = 0; + } + } finally { if(solrStreams != null) { diff --git a/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml b/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml index 32068110394..c70b9fd5857 100644 --- a/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml +++ b/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml @@ -50,7 +50,7 @@ - +