N-gram metrics overhaul & UI improvements
This commit is contained in:
@@ -15,6 +15,12 @@ pub struct KeyStat {
|
||||
pub error_count: usize,
|
||||
#[serde(default)]
|
||||
pub total_count: usize,
|
||||
#[serde(default = "default_error_rate_ema")]
|
||||
pub error_rate_ema: f64,
|
||||
}
|
||||
|
||||
fn default_error_rate_ema() -> f64 {
|
||||
0.5
|
||||
}
|
||||
|
||||
impl Default for KeyStat {
|
||||
@@ -27,6 +33,7 @@ impl Default for KeyStat {
|
||||
recent_times: Vec::new(),
|
||||
error_count: 0,
|
||||
total_count: 0,
|
||||
error_rate_ema: 0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -67,6 +74,13 @@ impl KeyStatsStore {
|
||||
if stat.recent_times.len() > 30 {
|
||||
stat.recent_times.remove(0);
|
||||
}
|
||||
|
||||
// Update error rate EMA (correct stroke = 0.0 signal)
|
||||
if stat.total_count == 1 {
|
||||
stat.error_rate_ema = 0.0;
|
||||
} else {
|
||||
stat.error_rate_ema = EMA_ALPHA * 0.0 + (1.0 - EMA_ALPHA) * stat.error_rate_ema;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_confidence(&self, key: char) -> f64 {
|
||||
@@ -84,13 +98,20 @@ impl KeyStatsStore {
|
||||
let stat = self.stats.entry(key).or_default();
|
||||
stat.error_count += 1;
|
||||
stat.total_count += 1;
|
||||
|
||||
// Update error rate EMA (error stroke = 1.0 signal)
|
||||
if stat.total_count == 1 {
|
||||
stat.error_rate_ema = 1.0;
|
||||
} else {
|
||||
stat.error_rate_ema = EMA_ALPHA * 1.0 + (1.0 - EMA_ALPHA) * stat.error_rate_ema;
|
||||
}
|
||||
}
|
||||
|
||||
/// Laplace-smoothed error rate: (errors + 1) / (total + 2).
|
||||
/// EMA-based error rate for a key.
|
||||
pub fn smoothed_error_rate(&self, key: char) -> f64 {
|
||||
match self.stats.get(&key) {
|
||||
Some(s) => (s.error_count as f64 + 1.0) / (s.total_count as f64 + 2.0),
|
||||
None => 0.5, // (0 + 1) / (0 + 2) = 0.5
|
||||
Some(s) => s.error_rate_ema,
|
||||
None => 0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -142,4 +163,50 @@ mod tests {
|
||||
"confidence should be < 1.0 for slow typing, got {conf}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ema_error_rate_correct_strokes() {
|
||||
let mut store = KeyStatsStore::default();
|
||||
// All correct strokes → EMA should be 0.0 for first, stay near 0
|
||||
store.update_key('a', 200.0);
|
||||
assert!((store.smoothed_error_rate('a') - 0.0).abs() < f64::EPSILON);
|
||||
for _ in 0..10 {
|
||||
store.update_key('a', 200.0);
|
||||
}
|
||||
assert!(
|
||||
store.smoothed_error_rate('a') < 0.01,
|
||||
"All correct → EMA near 0"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ema_error_rate_error_strokes() {
|
||||
let mut store = KeyStatsStore::default();
|
||||
// First stroke is error
|
||||
store.update_key_error('b');
|
||||
assert!((store.smoothed_error_rate('b') - 1.0).abs() < f64::EPSILON);
|
||||
// Follow with correct strokes → EMA decays
|
||||
for _ in 0..20 {
|
||||
store.update_key('b', 200.0);
|
||||
}
|
||||
let rate = store.smoothed_error_rate('b');
|
||||
assert!(
|
||||
rate < 0.15,
|
||||
"After 20 correct, EMA should be < 0.15, got {rate}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ema_error_rate_default_for_missing_key() {
|
||||
let store = KeyStatsStore::default();
|
||||
assert!((store.smoothed_error_rate('z') - 0.5).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ema_error_rate_serde_default() {
|
||||
// Verify backward compat: deserializing old data without error_rate_ema gets 0.5
|
||||
let json = r#"{"filtered_time_ms":200.0,"best_time_ms":200.0,"confidence":1.0,"sample_count":10,"recent_times":[],"error_count":2,"total_count":10}"#;
|
||||
let stat: KeyStat = serde_json::from_str(json).unwrap();
|
||||
assert!((stat.error_rate_ema - 0.5).abs() < f64::EPSILON);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,4 +5,4 @@ pub mod ngram_stats;
|
||||
pub mod scoring;
|
||||
pub mod skill_tree;
|
||||
|
||||
pub use ngram_stats::FocusTarget;
|
||||
pub use ngram_stats::FocusSelection;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -567,9 +567,7 @@ impl SkillTree {
|
||||
let newly_mastered: Vec<char> = if let Some(before) = before_stats {
|
||||
before_unlocked
|
||||
.iter()
|
||||
.filter(|&&ch| {
|
||||
before.get_confidence(ch) < 1.0 && stats.get_confidence(ch) >= 1.0
|
||||
})
|
||||
.filter(|&&ch| before.get_confidence(ch) < 1.0 && stats.get_confidence(ch) >= 1.0)
|
||||
.copied()
|
||||
.collect()
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user