SOLR-12780: Add support for Leaky ReLU and TanH activations in contrib/ltr NeuralNetworkModel class.

(Kamuela Lau, Christine Poerschke)
Christine Poerschke 2018-10-12 17:08:35 +01:00
parent 42ac07d11b
commit 9c8ffabfe3
2 changed files with 34 additions and 3 deletions

solr/CHANGES.txt

@@ -144,6 +144,9 @@ New Features
* SOLR-12843: Implement a MultiContentWriter in SolrJ to post multiple files/payload at once (noble)
* SOLR-12780: Add support for Leaky ReLU and TanH activations in contrib/ltr NeuralNetworkModel class.
(Kamuela Lau, Christine Poerschke)
Other Changes
----------------------

solr/contrib/ltr/src/java/org/apache/solr/ltr/model/NeuralNetworkModel.java

@@ -31,7 +31,7 @@ import org.apache.solr.util.SolrPluginUtils;
* A scoring model that computes document scores using a neural network.
* <p>
* Supported <a href="https://en.wikipedia.org/wiki/Activation_function">activation functions</a> are:
* <code>identity</code>, <code>relu</code>, <code>sigmoid</code>, <code>tanh</code>, <code>leakyrelu</code> and
* contributions to support additional activation functions are welcome.
* <p>
* Example configuration:
@@ -60,8 +60,20 @@ import org.apache.solr.util.SolrPluginUtils;
"activation" : "relu"
},
{
"matrix" : [ [ 27.0, 28.0 ],
[ 29.0, 30.0 ] ],
"bias" : [ 31.0, 32.0 ],
"activation" : "leakyrelu"
},
{
"matrix" : [ [ 33.0, 34.0 ],
[ 35.0, 36.0 ] ],
"bias" : [ 37.0, 38.0 ],
"activation" : "tanh"
},
{
"matrix" : [ [ 39.0, 40.0 ] ],
"bias" : [ 41.0 ],
"activation" : "identity" "activation" : "identity"
} }
] ]
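
Each layer in the example configuration above is a plain dense layer: the weight matrix is applied to the layer's input vector, the bias is added, and the result is passed element-wise through the named activation. The following standalone sketch (illustrative only, not code from this class; the LayerSketch name and the sample input vector are made up) evaluates the new leakyrelu layer from the example:

// Sketch of one dense layer: out[i] = activation(sum_j(matrix[i][j] * in[j]) + bias[i]).
// Assumed layer semantics for illustration; not the NeuralNetworkModel implementation.
public class LayerSketch {
  interface Activation { float apply(float in); }

  static float[] forward(float[][] matrix, float[] bias, float[] input, Activation act) {
    final float[] out = new float[matrix.length];
    for (int i = 0; i < matrix.length; i++) {
      float sum = bias[i];
      for (int j = 0; j < input.length; j++) {
        sum += matrix[i][j] * input[j];
      }
      out[i] = act.apply(sum);
    }
    return out;
  }

  public static void main(String[] args) {
    // Leaky ReLU as added in this commit: f(x) = x for x >= 0, otherwise 0.01 * x.
    final Activation leakyRelu = in -> in < 0 ? 0.01f * in : in;
    final float[] out = forward(new float[][] { { 27.0f, 28.0f }, { 29.0f, 30.0f } },
        new float[] { 31.0f, 32.0f }, new float[] { 1.0f, -1.0f }, leakyRelu);
    System.out.println(java.util.Arrays.toString(out)); // [30.0, 31.0]
  }
}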
@@ -144,6 +156,22 @@ public class NeuralNetworkModel extends LTRScoringModel {
}
};
break;
case "leakyrelu":
this.activation = new Activation() {
@Override
public float apply(float in) {
return in < 0 ? 0.01f * in : in;
}
};
break;
case "tanh":
this.activation = new Activation() {
@Override
public float apply(float in) {
return (float)Math.tanh(in);
}
};
break;
case "sigmoid": case "sigmoid":
this.activation = new Activation() { this.activation = new Activation() {
@Override @Override
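
For a quick sanity check of the two new cases above, a throwaway snippet along these lines (illustrative only, not part of the commit; the class name is made up) evaluates the same formulas: leakyrelu passes non-negative inputs through unchanged and scales negative inputs by 0.01, while tanh delegates to Math.tanh and squashes its input into (-1, 1).

// Illustrative check of the formulas used by the two new activation cases.
public class NewActivationsCheck {
  public static void main(String[] args) {
    for (float in : new float[] { -2.0f, 0.0f, 2.0f }) {
      final float leaky = in < 0 ? 0.01f * in : in; // -0.02, 0.0, 2.0
      final float tanh = (float) Math.tanh(in);     // ~ -0.964, 0.0, ~ 0.964
      System.out.println("in=" + in + " leakyrelu=" + leaky + " tanh=" + tanh);
    }
  }
}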