Slope One refactoring (#926)

* @Async and Spring Security

* @Async with SecurityContext propagated

* Spring and @Async

* Simulated Annealing algorithm

* Simulated Annealing algorithm

* Rebase

* Rebase

* SA further fixes

* Slope One plus package refactoring

* SlopeOne refactoring
This commit is contained in:
maibin 2016-12-26 21:21:17 +01:00 committed by GitHub
parent bc518a5bc1
commit ed89b85f18
1 changed files with 46 additions and 42 deletions

View File

@ -11,8 +11,8 @@ import java.util.Map.Entry;
*/
public class SlopeOne {
private static Map<Item, Map<Item, Double>> differencesMatrix = new HashMap<>();
private static Map<Item, Map<Item, Integer>> frequenciesMatrix = new HashMap<>();
private static Map<Item, Map<Item, Double>> diff = new HashMap<>();
private static Map<Item, Map<Item, Integer>> freq = new HashMap<>();
private static Map<User, HashMap<Item, Double>> inputData;
private static Map<User, HashMap<Item, Double>> outputData = new HashMap<>();
@ -28,33 +28,36 @@ public class SlopeOne {
* Based on the available data, calculate the relationships between the
* items and number of occurences
*
* @param data existing user data and their items' ratings
* @param data
* existing user data and their items' ratings
*/
private static void buildDifferencesMatrix(Map<User, HashMap<Item, Double>> data) {
for (HashMap<Item, Double> user : data.values()) {
for (Entry<Item, Double> entry : user.entrySet()) {
if (!differencesMatrix.containsKey(entry.getKey())) {
differencesMatrix.put(entry.getKey(), new HashMap<Item, Double>());
frequenciesMatrix.put(entry.getKey(), new HashMap<Item, Integer>());
for (Entry<Item, Double> e : user.entrySet()) {
if (!diff.containsKey(e.getKey())) {
diff.put(e.getKey(), new HashMap<Item, Double>());
freq.put(e.getKey(), new HashMap<Item, Integer>());
}
for (Entry<Item, Double> entry2 : user.entrySet()) {
for (Entry<Item, Double> e2 : user.entrySet()) {
int oldCount = 0;
if (frequenciesMatrix.get(entry.getKey()).containsKey(entry2.getKey()))
oldCount = frequenciesMatrix.get(entry.getKey()).get(entry2.getKey()).intValue();
if (freq.get(e.getKey()).containsKey(e2.getKey())) {
oldCount = freq.get(e.getKey()).get(e2.getKey()).intValue();
}
double oldDiff = 0.0;
if (differencesMatrix.get(entry.getKey()).containsKey(entry2.getKey()))
oldDiff = differencesMatrix.get(entry.getKey()).get(entry2.getKey()).doubleValue();
double observedDiff = entry.getValue() - entry2.getValue();
frequenciesMatrix.get(entry.getKey()).put(entry2.getKey(), oldCount + 1);
differencesMatrix.get(entry.getKey()).put(entry2.getKey(), oldDiff + observedDiff);
if (diff.get(e.getKey()).containsKey(e2.getKey())) {
oldDiff = diff.get(e.getKey()).get(e2.getKey()).doubleValue();
}
double observedDiff = e.getValue() - e2.getValue();
freq.get(e.getKey()).put(e2.getKey(), oldCount + 1);
diff.get(e.getKey()).put(e2.getKey(), oldDiff + observedDiff);
}
}
}
for (Item j : differencesMatrix.keySet()) {
for (Item i : differencesMatrix.get(j).keySet()) {
double oldvalue = differencesMatrix.get(j).get(i).doubleValue();
int count = frequenciesMatrix.get(j).get(i).intValue();
differencesMatrix.get(j).put(i, oldvalue / count);
for (Item j : diff.keySet()) {
for (Item i : diff.get(j).keySet()) {
double oldValue = diff.get(j).get(i).doubleValue();
int count = freq.get(j).get(i).intValue();
diff.get(j).put(i, oldValue / count);
}
}
printData(data);
@ -64,41 +67,42 @@ public class SlopeOne {
* Based on existing data predict all missing ratings. If prediction is not
* possible, the value will be equal to -1
*
* @param data existing user data and their items' ratings
* @param data
* existing user data and their items' ratings
*/
private static void predict(Map<User, HashMap<Item, Double>> data) {
HashMap<Item, Double> predictions = new HashMap<Item, Double>();
HashMap<Item, Integer> frequencies = new HashMap<Item, Integer>();
for (Item j : differencesMatrix.keySet()) {
frequencies.put(j, 0);
predictions.put(j, 0.0);
HashMap<Item, Double> uPred = new HashMap<Item, Double>();
HashMap<Item, Integer> uFreq = new HashMap<Item, Integer>();
for (Item j : diff.keySet()) {
uFreq.put(j, 0);
uPred.put(j, 0.0);
}
for (Entry<User, HashMap<Item, Double>> entry : data.entrySet()) {
for (Item j : entry.getValue().keySet()) {
for (Item k : differencesMatrix.keySet()) {
for (Entry<User, HashMap<Item, Double>> e : data.entrySet()) {
for (Item j : e.getValue().keySet()) {
for (Item k : diff.keySet()) {
try {
double newValue = (differencesMatrix.get(k).get(j).doubleValue()
+ entry.getValue().get(j).doubleValue()) * frequenciesMatrix.get(k).get(j).intValue();
predictions.put(k, predictions.get(k) + newValue);
frequencies.put(k, frequencies.get(k) + frequenciesMatrix.get(k).get(j).intValue());
} catch (NullPointerException e) {
double predictedValue = diff.get(k).get(j).doubleValue() + e.getValue().get(j).doubleValue();
double finalValue = predictedValue * freq.get(k).get(j).intValue();
uPred.put(k, uPred.get(k) + finalValue);
uFreq.put(k, uFreq.get(k) + freq.get(k).get(j).intValue());
} catch (NullPointerException e1) {
}
}
}
HashMap<Item, Double> cleanPredictions = new HashMap<Item, Double>();
for (Item j : predictions.keySet()) {
if (frequencies.get(j) > 0) {
cleanPredictions.put(j, predictions.get(j).doubleValue() / frequencies.get(j).intValue());
HashMap<Item, Double> clean = new HashMap<Item, Double>();
for (Item j : uPred.keySet()) {
if (uFreq.get(j) > 0) {
clean.put(j, uPred.get(j).doubleValue() / uFreq.get(j).intValue());
}
}
for (Item j : InputData.items) {
if (entry.getValue().containsKey(j)) {
cleanPredictions.put(j, entry.getValue().get(j));
if (e.getValue().containsKey(j)) {
clean.put(j, e.getValue().get(j));
} else {
cleanPredictions.put(j, -1.0);
clean.put(j, -1.0);
}
}
outputData.put(entry.getKey(), cleanPredictions);
outputData.put(e.getKey(), clean);
}
printData(outputData);
}