Merge remote-tracking branch 'origin/master' into gradle-master

This commit is contained in:
Dawid Weiss 2019-12-09 22:37:08 +01:00
commit 511bcaa4c1
10 changed files with 148 additions and 66 deletions

View File

@@ -187,11 +187,11 @@ public class BM25Similarity extends Similarity {
float[] cache = new float[256];
for (int i = 0; i < cache.length; i++) {
cache[i] = k1 * ((1 - b) + b * LENGTH_TABLE[i] / avgdl);
cache[i] = 1f / (k1 * ((1 - b) + b * LENGTH_TABLE[i] / avgdl));
}
return new BM25Scorer(boost, k1, b, idf, avgdl, cache);
}
/** Collection statistics for the BM25 model. */
private static class BM25Scorer extends SimScorer {
/** query boost */
@@ -221,8 +221,17 @@ public class BM25Similarity extends Similarity {
@Override
public float score(float freq, long encodedNorm) {
double norm = cache[((byte) encodedNorm) & 0xFF];
return weight * (float) (freq / (freq + norm));
// In order to guarantee monotonicity with both freq and norm without
// promoting to doubles, we rewrite freq / (freq + norm) to
// 1 - 1 / (1 + freq * 1/norm).
// freq * 1/norm is guaranteed to be monotonic for both freq and norm due
// to the fact that multiplication and division round to the nearest
// float. And then monotonicity is preserved through composition via
// x -> 1 + x and x -> 1 - 1/x.
// Finally we expand weight * (1 - 1 / (1 + freq * 1/norm)) to
// weight - weight / (1 + freq * 1/norm), which runs slightly faster.
float normInverse = cache[((byte) encodedNorm) & 0xFF];
return weight - weight / (1f + freq * normInverse);
}
@Override
@@ -230,8 +239,11 @@ public class BM25Similarity extends Similarity {
List<Explanation> subs = new ArrayList<>(explainConstantFactors());
Explanation tfExpl = explainTF(freq, encodedNorm);
subs.add(tfExpl);
return Explanation.match(weight * tfExpl.getValue().floatValue(),
"score(freq="+freq.getValue()+"), product of:", subs);
float normInverse = cache[((byte) encodedNorm) & 0xFF];
// not using "product of" since the rewrite that we do in score()
// introduces a small rounding error that CheckHits complains about
return Explanation.match(weight - weight / (1f + freq.getValue().floatValue() * normInverse),
"score(freq="+freq.getValue()+"), computed as boost * idf * tf from:", subs);
}
private Explanation explainTF(Explanation freq, long norm) {
@@ -246,9 +258,9 @@ public class BM25Similarity extends Similarity {
subs.add(Explanation.match(doclen, "dl, length of field"));
}
subs.add(Explanation.match(avgdl, "avgdl, average length of field"));
float normValue = k1 * ((1 - b) + b * doclen / avgdl);
float normInverse = 1f / (k1 * ((1 - b) + b * doclen / avgdl));
return Explanation.match(
(float) (freq.getValue().floatValue() / (freq.getValue().floatValue() + (double) normValue)),
1f - 1f / (1 + freq.getValue().floatValue() * normInverse),
"tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:", subs);
}

View File

@@ -115,6 +115,11 @@ public class Facet2DStream extends TupleStream implements Expressible {
params.add(namedParam.getName(), namedParam.getParameter().toString().trim());
}
}
if(params.get("q") == null) {
params.set("q", "*:*");
}
Bucket x = null;
if (bucketXExpression != null) {
if (bucketXExpression.getParameter() instanceof StreamExpressionValue) {
@@ -148,8 +153,8 @@ public class Facet2DStream extends TupleStream implements Expressible {
String bucketSortString = metric.getIdentifier() + " desc";
FieldComparator bucketSort = parseBucketSort(bucketSortString, x, y);
int dimensionX = 0;
int dimensionY = 0;
int dimensionX = 10;
int dimensionY = 10;
if (dimensionsExpression != null) {
if (dimensionsExpression.getParameter() instanceof StreamExpressionValue) {
String[] strDimensions = ((StreamExpressionValue) dimensionsExpression.getParameter()).getValue().split(",");

View File

@@ -47,6 +47,7 @@ import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.io.stream.metrics.Bucket;
import org.apache.solr.client.solrj.io.stream.metrics.CountMetric;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.common.params.ModifiableSolrParams;
@@ -148,6 +149,10 @@ public class FacetStream extends TupleStream implements Expressible {
}
}
if(params.get("q") == null) {
params.set("q", "*:*");
}
// buckets, required - comma separated
Bucket[] buckets = null;
if(null != bucketExpression){
@@ -166,12 +171,31 @@
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one bucket expected. eg. 'buckets=\"name\"'",expression,collectionName));
}
// Construct the metrics
Metric[] metrics = new Metric[metricExpressions.size()];
for(int idx = 0; idx < metricExpressions.size(); ++idx) {
metrics[idx] = factory.constructMetric(metricExpressions.get(idx));
}
if(metrics.length == 0) {
metrics = new Metric[1];
metrics[0] = new CountMetric();
}
String bucketSortString = null;
if(bucketSortExpression == null) {
throw new IOException("The bucketSorts parameter is required for the facet function.");
bucketSortString = metrics[0].getIdentifier()+" desc";
} else {
bucketSortString = ((StreamExpressionValue)bucketSortExpression.getParameter()).getValue();
if(bucketSortString.contains("(") &&
metricExpressions.size() == 0 &&
(!bucketSortExpression.equals("count(*) desc") &&
!bucketSortExpression.equals("count(*) asc"))) {
//Attempting bucket sort on a metric that is not going to be calculated.
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - the bucketSort is being performed on a metric that is not being calculated.",expression,collectionName));
}
}
FieldComparator[] bucketSorts = parseBucketSorts(bucketSortString, buckets);
@@ -180,15 +204,7 @@
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one bucket sort expected. eg. 'bucketSorts=\"name asc\"'",expression,collectionName));
}
// Construct the metrics
Metric[] metrics = new Metric[metricExpressions.size()];
for(int idx = 0; idx < metricExpressions.size(); ++idx) {
metrics[idx] = factory.constructMetric(metricExpressions.get(idx));
}
if(0 == metrics.length) {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one metric expected.",expression,collectionName));
}
boolean refine = false;
@@ -210,7 +226,7 @@
methodStr = ((StreamExpressionValue) methodExpression.getParameter()).getValue();
}
int overfetchInt = 150;
int overfetchInt = 250;
if(overfetchExpression != null) {
String overfetchStr = ((StreamExpressionValue) overfetchExpression.getParameter()).getValue();
overfetchInt = Integer.parseInt(overfetchStr);

View File

@@ -53,11 +53,6 @@ public class RandomFacadeStream extends TupleStream implements Expressible {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
}
// Named parameters - passed directly to solr as solrparams
if(0 == namedParams.size()){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
}
// pull out known named params
Map<String,String> params = new HashMap<String,String>();
for(StreamExpressionNamedParameter namedParam : namedParams){
@@ -66,6 +61,20 @@ public class RandomFacadeStream extends TupleStream implements Expressible {
}
}
//Add sensible defaults
if(!params.containsKey("q")) {
params.put("q", "*:*");
}
if(!params.containsKey("fl")) {
params.put("fl", "*");
}
if(!params.containsKey("rows")) {
params.put("rows", "500");
}
// zkHost, optional - if not provided then will look into factory list to get
String zkHost = null;
if(null == zkHostExpression){

View File

@@ -88,10 +88,6 @@ public class RandomStream extends TupleStream implements Expressible {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
}
// Named parameters - passed directly to solr as solrparams
if(0 == namedParams.size()){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
}
// pull out known named params
Map<String,String> params = new HashMap<String,String>();
@@ -101,6 +97,7 @@ public class RandomStream extends TupleStream implements Expressible {
}
}
// zkHost, optional - if not provided then will look into factory list to get
String zkHost = null;
if(null == zkHostExpression){

View File

@@ -91,21 +91,12 @@ public class StatsStream extends TupleStream implements Expressible {
List<StreamExpression> metricExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Metric.class);
StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
// Validate there are no unknown parameters - zkHost is namedParameter so we don't need to count it twice
if(expression.getParameters().size() != 1 + namedParams.size() + metricExpressions.size()){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - unknown operands found",expression));
}
// Collection Name
if(null == collectionName){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
}
// Named parameters - passed directly to solr as solrparams
if(0 == namedParams.size()){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
}
ModifiableSolrParams params = new ModifiableSolrParams();
for(StreamExpressionNamedParameter namedParam : namedParams){
if(!namedParam.getName().equals("zkHost")){
@@ -113,6 +104,10 @@ public class StatsStream extends TupleStream implements Expressible {
}
}
if(params.get("q") == null) {
params.set("q", "*:*");
}
// zkHost, optional - if not provided then will look into factory list to get
String zkHost = null;
if(null == zkHostExpression){

View File

@@ -45,6 +45,7 @@ import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParamete
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.io.stream.metrics.CountMetric;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.common.params.ModifiableSolrParams;
@@ -100,34 +101,36 @@ public class TimeSeriesStream extends TupleStream implements Expressible {
StreamExpressionNamedParameter fieldExpression = factory.getNamedOperand(expression, "field");
StreamExpressionNamedParameter gapExpression = factory.getNamedOperand(expression, "gap");
StreamExpressionNamedParameter formatExpression = factory.getNamedOperand(expression, "format");
StreamExpressionNamedParameter qExpression = factory.getNamedOperand(expression, "q");
StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
List<StreamExpression> metricExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, Metric.class);
if(qExpression == null) {
throw new IOException("The timeseries expression requires the q parameter");
}
String start = null;
if(startExpression != null) {
start = ((StreamExpressionValue)startExpression.getParameter()).getValue();
} else {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - start parameter is required",expression));
}
String end = null;
if(endExpression != null) {
end = ((StreamExpressionValue)endExpression.getParameter()).getValue();
} else {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - end parameter is required",expression));
}
String gap = null;
if(gapExpression != null) {
gap = ((StreamExpressionValue)gapExpression.getParameter()).getValue();
}
} else {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - gap parameter is required",expression));
}
String field = null;
if(fieldExpression != null) {
field = ((StreamExpressionValue)fieldExpression.getParameter()).getValue();
} else {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - field parameter is required",expression));
}
String format = null;
@@ -146,13 +149,15 @@ public class TimeSeriesStream extends TupleStream implements Expressible {
}
// Construct the metrics
Metric[] metrics = new Metric[metricExpressions.size()];
for(int idx = 0; idx < metricExpressions.size(); ++idx){
metrics[idx] = factory.constructMetric(metricExpressions.get(idx));
}
if(0 == metrics.length){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one metric expected.",expression,collectionName));
Metric[] metrics = null;
if(metricExpressions.size() > 0) {
metrics = new Metric[metricExpressions.size()];
for(int idx = 0; idx < metricExpressions.size(); ++idx){
metrics[idx] = factory.constructMetric(metricExpressions.get(idx));
}
} else {
metrics = new Metric[1];
metrics[0] = new CountMetric();
}
// pull out known named params
@@ -163,6 +168,10 @@ public class TimeSeriesStream extends TupleStream implements Expressible {
}
}
if(params.get("q") == null) {
params.set("q", "*:*");
}
// zkHost, optional - if not provided then will look into factory list to get
String zkHost = null;
if(null == zkHostExpression){

View File

@@ -237,10 +237,10 @@ public class UpdateStream extends TupleStream implements Expressible {
private int extractBatchSize(StreamExpression expression, StreamFactory factory) throws IOException {
StreamExpressionNamedParameter batchSizeParam = factory.getNamedOperand(expression, "batchSize");
if(null == batchSizeParam || null == batchSizeParam.getParameter() || !(batchSizeParam.getParameter() instanceof StreamExpressionValue)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting a 'batchSize' parameter of type positive integer but didn't find one",expression));
if(batchSizeParam == null) {
// Sensible default batch size
return 250;
}
String batchSizeStr = ((StreamExpressionValue)batchSizeParam.getParameter()).getValue();
return parseBatchSize(batchSizeStr, expression);
}

View File

@@ -664,12 +664,16 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assert (tuples4.size() == 1);
sParams = new ModifiableSolrParams(StreamingTest.mapParams(CommonParams.QT, "/stream"));
sParams.add("expr", "random(" + COLLECTIONORALIAS + ", q=\"*:*\", rows=\"10001\", fl=\"id, a_i\")");
sParams.add("expr", "random(" + COLLECTIONORALIAS + ")");
jetty = cluster.getJettySolrRunner(0);
solrStream = new SolrStream(jetty.getBaseUrl().toString() + "/collection1", sParams);
tuples4 = getTuples(solrStream);
assert (tuples4.size() == 1000);
assert(tuples4.size() == 500);
Map fields = tuples4.get(0).fields;
assert(fields.containsKey("id"));
assert(fields.containsKey("a_f"));
assert(fields.containsKey("a_i"));
assert(fields.containsKey("a_s"));
} finally {
cache.close();
}
@@ -792,6 +796,41 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertTrue(avgf.doubleValue() == 5.5D);
assertTrue(count.doubleValue() == 10);
//Test without query
expr = "stats(" + COLLECTIONORALIAS + ", sum(a_i), sum(a_f), min(a_i), min(a_f), max(a_i), max(a_f), avg(a_i), avg(a_f), count(*))";
expression = StreamExpressionParser.parse(expr);
stream = factory.constructStream(expression);
stream.setStreamContext(streamContext);
tuples = getTuples(stream);
assert (tuples.size() == 1);
//Test Long and Double Sums
tuple = tuples.get(0);
sumi = tuple.getDouble("sum(a_i)");
sumf = tuple.getDouble("sum(a_f)");
mini = tuple.getDouble("min(a_i)");
minf = tuple.getDouble("min(a_f)");
maxi = tuple.getDouble("max(a_i)");
maxf = tuple.getDouble("max(a_f)");
avgi = tuple.getDouble("avg(a_i)");
avgf = tuple.getDouble("avg(a_f)");
count = tuple.getDouble("count(*)");
assertTrue(sumi.longValue() == 70);
assertTrue(sumf.doubleValue() == 55.0D);
assertTrue(mini.doubleValue() == 0.0D);
assertTrue(minf.doubleValue() == 1.0D);
assertTrue(maxi.doubleValue() == 14.0D);
assertTrue(maxf.doubleValue() == 10.0D);
assertTrue(avgi.doubleValue() == 7.0D);
assertTrue(avgf.doubleValue() == 5.5D);
assertTrue(count.doubleValue() == 10);
//Test with shards parameter
List<String> shardUrls = TupleStream.getShards(cluster.getZkServer().getZkAddress(), COLLECTIONORALIAS, streamContext);
@@ -939,7 +978,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
paramsLoc = new ModifiableSolrParams();
expr = "facet2D(collection1, q=\"*:*\", x=\"diseases_s\", y=\"symptoms_s\", dimensions=\"3,1\")";
expr = "facet2D(collection1, x=\"diseases_s\", y=\"symptoms_s\", dimensions=\"3,1\")";
paramsLoc.set("expr", expr);
paramsLoc.set("qt", "/stream");
@@ -2780,7 +2819,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertTrue(tuples.get(3).get("term_s").equals("f"));
// update
expression = StreamExpressionParser.parse("update(destinationCollection, batchSize=5, " + featuresExpression + ")");
expression = StreamExpressionParser.parse("update(destinationCollection, " + featuresExpression + ")");
stream = new UpdateStream(expression, factory);
stream.setStreamContext(streamContext);
getTuples(stream);

View File

@@ -296,7 +296,7 @@ public class StreamExpressionToExpessionTest extends SolrTestCase {
assertTrue(expressionString.contains("bucketSorts=\"sum(a_i) asc\""));
assertTrue(expressionString.contains("rows=10"));
assertTrue(expressionString.contains("offset=0"));
assertTrue(expressionString.contains("overfetch=150"));
assertTrue(expressionString.contains("overfetch=250"));
assertTrue(expressionString.contains("sum(a_i)"));
assertTrue(expressionString.contains("sum(a_f)"));
assertTrue(expressionString.contains("min(a_i)"));
@@ -306,8 +306,8 @@ public class StreamExpressionToExpessionTest extends SolrTestCase {
assertTrue(expressionString.contains("avg(a_i,false)"));
assertTrue(expressionString.contains("avg(a_f,false)"));
assertTrue(expressionString.contains("count(*)"));
assertEquals(stream.getOverfetch(), 150);
assertEquals(stream.getBucketSizeLimit(), 160);
assertEquals(stream.getOverfetch(), 250);
assertEquals(stream.getBucketSizeLimit(), 260);
assertEquals(stream.getRows(), 10);
assertEquals(stream.getOffset(), 0);
}
@@ -332,7 +332,7 @@ public class StreamExpressionToExpessionTest extends SolrTestCase {
assertTrue(!expressionString.contains("bucketSizeLimit"));
assertTrue(expressionString.contains("rows=10"));
assertTrue(expressionString.contains("offset=0"));
assertTrue(expressionString.contains("overfetch=150"));
assertTrue(expressionString.contains("overfetch=250"));
assertTrue(expressionString.contains("method=dvhash"));
assertTrue(expressionString.contains("sum(a_i)"));
assertTrue(expressionString.contains("sum(a_f)"));
@@ -343,10 +343,10 @@ public class StreamExpressionToExpessionTest extends SolrTestCase {
assertTrue(expressionString.contains("avg(a_i,false)"));
assertTrue(expressionString.contains("avg(a_f,false)"));
assertTrue(expressionString.contains("count(*)"));
assertEquals(stream.getBucketSizeLimit(), 160);
assertEquals(stream.getBucketSizeLimit(), 260);
assertEquals(stream.getRows(), 10);
assertEquals(stream.getOffset(), 0);
assertEquals(stream.getOverfetch(), 150);
assertEquals(stream.getOverfetch(), 250);
}