diff --git a/app/src/main/java/eu/faircode/email/MessageClassifier.java b/app/src/main/java/eu/faircode/email/MessageClassifier.java index aaedf12d69..7f66276f9b 100644 --- a/app/src/main/java/eu/faircode/email/MessageClassifier.java +++ b/app/src/main/java/eu/faircode/email/MessageClassifier.java @@ -39,7 +39,7 @@ import java.util.Collections; import java.util.Comparator; import java.util.Date; import java.util.HashMap; -import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; @@ -282,47 +282,37 @@ public class MessageClassifier { for (String clazz : classFrequency.keySet()) { Frequency frequency = classFrequency.get(clazz); + if (frequency.count > 0) { + Stat stat = state.classStats.get(clazz); + if (stat == null) { + stat = new Stat(); + state.classStats.put(clazz, stat); + } - Stat stat = state.classStats.get(clazz); - if (stat == null) { - stat = new Stat(); - state.classStats.put(clazz, stat); + int c = frequency.count; + Integer b = (before == null ? null : frequency.before.get(before)); + Integer a = (after == null ? null : frequency.after.get(after)); + stat.totalFrequency += + ((b == null ? 0.0 : (double) b / c) + c + (a == null ? 0.0 : (double) a / c)) / 3; + + stat.matchedWords++; + if (stat.matchedWords > state.maxMatchedWords) + state.maxMatchedWords = stat.matchedWords; + + if (BuildConfig.DEBUG) + stat.words.add(current); } - - stat.matchedWords++; - - boolean b = (before != null && frequency.before.contains(before)); - boolean a = (after != null && frequency.after.contains(after)); - if (b && a) - stat.totalFrequency += frequency.count; - else if (b || a) - stat.totalFrequency += frequency.count * 0.5; - else - stat.totalFrequency += frequency.count * 0.25; - - if (BuildConfig.DEBUG) - stat.words.add(current); - - if (stat.matchedWords > state.maxMatchedWords) - state.maxMatchedWords = stat.matchedWords; } Frequency c = classFrequency.get(currentClass); if (c == null) c = new Frequency(); - c.count++; - if (before != null && !c.before.contains(before)) - c.before.add(before); - if (after != null && !c.after.contains(after)) - c.after.add(after); + c.add(before, after, 1); classFrequency.put(currentClass, c); } else { Frequency c = (classFrequency == null ? null : classFrequency.get(currentClass)); if (c != null) - if (c.count > 0) - c.count--; - else - classFrequency.remove(currentClass); + c.add(before, after, -1); } } @@ -411,11 +401,11 @@ public class MessageClassifier { return jroot; } - private static JSONArray from(HashSet list) { - JSONArray jarray = new JSONArray(); - for (String item : list) - jarray.put(item); - return jarray; + private static JSONObject from(Map map) throws JSONException { + JSONObject jmap = new JSONObject(); + for (String key : map.keySet()) + jmap.put(key, map.get(key)); + return jmap; } static void fromJson(JSONObject jroot) throws JSONException { @@ -443,30 +433,50 @@ public class MessageClassifier { Frequency f = new Frequency(); f.count = jword.getInt("frequency"); if (jword.has("before")) - f.before = from(jword.getJSONArray("before")); + f.before = from(jword.getJSONObject("before")); if (jword.has("after")) - f.after = from(jword.getJSONArray("after")); + f.after = from(jword.getJSONObject("after")); classFrequency.put(jword.getString("class"), f); } } - private static HashSet from(JSONArray jarray) throws JSONException { - HashSet result = new HashSet<>(jarray.length()); - for (int i = 0; i < jarray.length(); i++) - result.add((String) jarray.get(i)); + private static Map from(JSONObject jmap) throws JSONException { + Map result = new HashMap<>(jmap.length()); + Iterator iterator = jmap.keys(); + while (iterator.hasNext()) { + String key = iterator.next(); + result.put(key, jmap.getInt(key)); + } return result; } private static class State { - int maxMatchedWords = 0; - List words = new ArrayList<>(); - Map classStats = new HashMap<>(); + private int maxMatchedWords = 0; + private List words = new ArrayList<>(); + private Map classStats = new HashMap<>(); } private static class Frequency { - int count; - HashSet before = new HashSet<>(); - HashSet after = new HashSet<>(); + private int count = 0; + private Map before = new HashMap<>(); + private Map after = new HashMap<>(); + + private void add(String b, String a, int c) { + if (count + c < 0) + return; + + count += c; + + if (b != null) { + Integer x = before.get(b); + before.put(b, (x == null ? 0 : x) + c); + } + + if (a != null) { + Integer x = after.get(a); + after.put(a, (x == null ? 0 : x) + c); + } + } } private static class Stat { @@ -476,10 +486,10 @@ public class MessageClassifier { } private static class Chance { - String clazz; - Double chance; + private String clazz; + private Double chance; - Chance(String clazz, Double chance) { + private Chance(String clazz, Double chance) { this.clazz = clazz; this.chance = chance; }