[ML] Restore categoriser state after the anomaly detector (elastic/x-pack-elasticsearch#993)
Original commit: elastic/x-pack-elasticsearch@fc4205f1d6
This commit is contained in:
parent
e339cf82df
commit
9a9ae5edc7
|
@ -882,7 +882,27 @@ public class JobProvider {
|
|||
public void restoreStateToStream(String jobId, ModelSnapshot modelSnapshot, OutputStream restoreStream) throws IOException {
|
||||
String indexName = AnomalyDetectorsIndex.jobStateIndexName();
|
||||
|
||||
// First try to restore categorizer state. There are no snapshots for this, so the IDs simply
|
||||
|
||||
// First try to restore model state.
|
||||
int numDocs = modelSnapshot.getSnapshotDocCount();
|
||||
for (int docNum = 1; docNum <= numDocs; ++docNum) {
|
||||
String docId = String.format(Locale.ROOT, "%s#%d", ModelSnapshot.documentId(modelSnapshot), docNum);
|
||||
|
||||
LOGGER.trace("ES API CALL: get ID {} type {} from index {}", docId, ModelState.TYPE, indexName);
|
||||
|
||||
GetResponse stateResponse = client.prepareGet(indexName, ModelState.TYPE.getPreferredName(), docId).get();
|
||||
if (!stateResponse.isExists()) {
|
||||
LOGGER.error("Expected {} documents for model state for {} snapshot {} but failed to find {}",
|
||||
numDocs, jobId, modelSnapshot.getSnapshotId(), docId);
|
||||
break;
|
||||
}
|
||||
writeStateToStream(stateResponse.getSourceAsBytesRef(), restoreStream);
|
||||
}
|
||||
|
||||
|
||||
// Secondly try to restore categorizer state. This must come after model state because that's
|
||||
// the order the C++ process expects.
|
||||
// There are no snapshots for this, so the IDs simply
|
||||
// count up until a document is not found. It's NOT an error to have no categorizer state.
|
||||
int docNum = 0;
|
||||
while (true) {
|
||||
|
@ -897,22 +917,6 @@ public class JobProvider {
|
|||
writeStateToStream(stateResponse.getSourceAsBytesRef(), restoreStream);
|
||||
}
|
||||
|
||||
// Finally try to restore model state. This must come after categorizer state because that's
|
||||
// the order the C++ process expects.
|
||||
int numDocs = modelSnapshot.getSnapshotDocCount();
|
||||
for (docNum = 1; docNum <= numDocs; ++docNum) {
|
||||
String docId = String.format(Locale.ROOT, "%s#%d", ModelSnapshot.documentId(modelSnapshot), docNum);
|
||||
|
||||
LOGGER.trace("ES API CALL: get ID {} type {} from index {}", docId, ModelState.TYPE, indexName);
|
||||
|
||||
GetResponse stateResponse = client.prepareGet(indexName, ModelState.TYPE.getPreferredName(), docId).get();
|
||||
if (!stateResponse.isExists()) {
|
||||
LOGGER.error("Expected {} documents for model state for {} snapshot {} but failed to find {}",
|
||||
numDocs, jobId, modelSnapshot.getSnapshotId(), docId);
|
||||
break;
|
||||
}
|
||||
writeStateToStream(stateResponse.getSourceAsBytesRef(), restoreStream);
|
||||
}
|
||||
}
|
||||
|
||||
private void writeStateToStream(BytesReference source, OutputStream stream) throws IOException {
|
||||
|
|
|
@ -865,13 +865,6 @@ public class JobProviderTests extends ESTestCase {
|
|||
assertTrue(queryString.matches("(?s).*snapshot_id.*value. : .snappyId.*description.*value. : .description1.*"));
|
||||
}
|
||||
|
||||
/**
 * Test helper that builds an {@code AnomalyRecord} with the given partition field value,
 * timestamp and record score.
 *
 * @param partitionFieldValue value passed to {@code setPartitionFieldValue} on the new record
 * @param timestamp           timestamp passed to the {@code AnomalyRecord} constructor
 * @param recordScore         value passed to {@code setRecordScore} on the new record
 * @return the newly constructed and populated record
 */
private AnomalyRecord createAnomalyRecord(String partitionFieldValue, Date timestamp, double recordScore) {
|
||||
// The job id "foo" and the trailing numeric arguments are fixed for every record built here.
// NOTE(review): 600 and 42 are presumably the bucket span and sequence number - confirm
// against the AnomalyRecord constructor.
AnomalyRecord record = new AnomalyRecord("foo", timestamp, 600, 42);
|
||||
record.setPartitionFieldValue(partitionFieldValue);
|
||||
record.setRecordScore(recordScore);
|
||||
return record;
|
||||
}
|
||||
|
||||
public void testRestoreStateToStream() throws Exception {
|
||||
Map<String, Object> categorizerState = new HashMap<>();
|
||||
categorizerState.put("catName", "catVal");
|
||||
|
@ -901,9 +894,9 @@ public class JobProviderTests extends ESTestCase {
|
|||
|
||||
String[] restoreData = stream.toString(StandardCharsets.UTF_8.name()).split("\0");
|
||||
assertEquals(3, restoreData.length);
|
||||
assertEquals("{\"catName\":\"catVal\"}", restoreData[0]);
|
||||
assertEquals("{\"modName\":\"modVal1\"}", restoreData[1]);
|
||||
assertEquals("{\"modName\":\"modVal2\"}", restoreData[2]);
|
||||
assertEquals("{\"modName\":\"modVal1\"}", restoreData[0]);
|
||||
assertEquals("{\"modName\":\"modVal2\"}", restoreData[1]);
|
||||
assertEquals("{\"catName\":\"catVal\"}", restoreData[2]);
|
||||
}
|
||||
|
||||
public void testViolatedFieldCountLimit() throws Exception {
|
||||
|
|
Loading…
Reference in New Issue