N-gram error tracking for adaptive drill selection

This commit is contained in:
2026-02-24 14:55:51 -05:00
parent 0c5a70d5c4
commit e7f57dd497
11 changed files with 2244 additions and 10 deletions

View File

@@ -11,6 +11,11 @@ use rand::rngs::SmallRng;
use crate::config::Config;
use crate::engine::filter::CharFilter;
use crate::engine::key_stats::KeyStatsStore;
use crate::engine::FocusTarget;
use crate::engine::ngram_stats::{
self, BigramKey, BigramStatsStore, TrigramStatsStore, extract_ngram_events,
select_focus_target,
};
use crate::engine::scoring;
use crate::engine::skill_tree::{BranchId, BranchStatus, DrillScope, SkillTree};
use crate::generator::TextGenerator;
@@ -35,7 +40,7 @@ use crate::keyboard::display::BACKSPACE;
use crate::session::drill::DrillState;
use crate::session::input::{self, KeystrokeEvent};
use crate::session::result::DrillResult;
use crate::session::result::{DrillResult, KeyTime};
use crate::store::json_store::JsonStore;
use crate::store::schema::{DrillHistoryData, ExportData, KeyStatsData, ProfileData, EXPORT_VERSION};
use crate::ui::components::menu::Menu;
@@ -260,6 +265,13 @@ pub struct App {
pub keyboard_explorer_selected: Option<char>,
pub explorer_accuracy_cache_overall: Option<(char, usize, usize)>,
pub explorer_accuracy_cache_ranked: Option<(char, usize, usize)>,
pub bigram_stats: BigramStatsStore,
pub ranked_bigram_stats: BigramStatsStore,
pub trigram_stats: TrigramStatsStore,
pub ranked_trigram_stats: TrigramStatsStore,
pub user_median_transition_ms: f64,
pub transition_buffer: Vec<f64>,
pub trigram_gain_history: Vec<f64>,
rng: SmallRng,
transition_table: TransitionTable,
#[allow(dead_code)]
@@ -402,6 +414,13 @@ impl App {
keyboard_explorer_selected: None,
explorer_accuracy_cache_overall: None,
explorer_accuracy_cache_ranked: None,
bigram_stats: BigramStatsStore::default(),
ranked_bigram_stats: BigramStatsStore::default(),
trigram_stats: TrigramStatsStore::default(),
ranked_trigram_stats: TrigramStatsStore::default(),
user_median_transition_ms: 0.0,
transition_buffer: Vec::new(),
trigram_gain_history: Vec::new(),
rng: SmallRng::from_entropy(),
transition_table,
dictionary,
@@ -419,6 +438,9 @@ impl App {
});
}
// Rebuild n-gram stats from drill history
app.rebuild_ngram_stats();
app.start_drill();
app
}
@@ -591,6 +613,9 @@ impl App {
self.skill_tree = SkillTree::new(self.profile.skill_tree.clone());
self.keyboard_model = KeyboardModel::from_name(&self.config.keyboard_layout);
// Rebuild n-gram stats from imported drill history
self.rebuild_ngram_stats();
// Check theme availability
let theme_name = self.config.theme.clone();
let loaded_theme = Theme::load(&theme_name).unwrap_or_default();
@@ -633,7 +658,18 @@ impl App {
DrillMode::Adaptive => {
let scope = self.drill_scope;
let all_keys = self.skill_tree.unlocked_keys(scope);
let focused = self.skill_tree.focused_key(scope, &self.ranked_key_stats);
// Select focus target: single char or bigram
let focus_target = select_focus_target(
&self.skill_tree,
scope,
&self.ranked_key_stats,
&self.ranked_bigram_stats,
);
let (focused_char, focused_bigram) = match &focus_target {
FocusTarget::Char(ch) => (Some(*ch), None),
FocusTarget::Bigram(key) => (Some(key.0[0]), Some(key.clone())),
};
// Generate base lowercase text using only lowercase keys from scope
let lowercase_keys: Vec<char> = all_keys
@@ -643,7 +679,7 @@ impl App {
.collect();
let filter = CharFilter::new(lowercase_keys);
// Only pass focused to phonetic generator if it's a lowercase letter
let lowercase_focused = focused.filter(|ch| ch.is_ascii_lowercase());
let lowercase_focused = focused_char.filter(|ch| ch.is_ascii_lowercase());
let table = self.transition_table.clone();
let dict = Dictionary::load();
let rng = SmallRng::from_rng(&mut self.rng).unwrap();
@@ -658,7 +694,7 @@ impl App {
.collect();
if !cap_keys.is_empty() {
let mut rng = SmallRng::from_rng(&mut self.rng).unwrap();
text = capitalize::apply_capitalization(&text, &cap_keys, focused, &mut rng);
text = capitalize::apply_capitalization(&text, &cap_keys, focused_char, &mut rng);
}
// Apply punctuation if punctuation keys are in scope
@@ -674,7 +710,7 @@ impl App {
.collect();
if !punct_keys.is_empty() {
let mut rng = SmallRng::from_rng(&mut self.rng).unwrap();
text = punctuate::apply_punctuation(&text, &punct_keys, focused, &mut rng);
text = punctuate::apply_punctuation(&text, &punct_keys, focused_char, &mut rng);
}
// Apply numbers if digit keys are in scope
@@ -686,7 +722,7 @@ impl App {
if !digit_keys.is_empty() {
let has_dot = all_keys.contains(&'.');
let mut rng = SmallRng::from_rng(&mut self.rng).unwrap();
text = numbers::apply_numbers(&text, &digit_keys, has_dot, focused, &mut rng);
text = numbers::apply_numbers(&text, &digit_keys, has_dot, focused_char, &mut rng);
}
// Apply code symbols only if this drill is for the CodeSymbols branch,
@@ -734,7 +770,7 @@ impl App {
text = code_patterns::apply_code_symbols(
&text,
&symbol_keys,
focused,
focused_char,
&mut rng,
);
}
@@ -745,6 +781,11 @@ impl App {
text = insert_line_breaks(&text);
}
// After all generation: if bigram focus, swap some words for bigram-containing words
if let Some(ref bigram) = focused_bigram {
text = self.apply_bigram_focus(&text, &filter, bigram);
}
(text, None)
}
DrillMode::Code => {
@@ -843,15 +884,39 @@ impl App {
for kt in &result.per_key_times {
if kt.correct {
self.key_stats.update_key(kt.key, kt.time_ms);
} else {
self.key_stats.update_key_error(kt.key);
}
}
// Extract and update n-gram stats for all drill modes
let drill_index = self.drill_history.len() as u32;
let hesitation_thresh = ngram_stats::hesitation_threshold(self.user_median_transition_ms);
let (bigram_events, trigram_events) =
extract_ngram_events(&result.per_key_times, hesitation_thresh);
for ev in &bigram_events {
self.bigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
self.bigram_stats.update_redundancy_streak(&ev.key, &self.key_stats);
}
for ev in &trigram_events {
self.trigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
}
if ranked {
for kt in &result.per_key_times {
if kt.correct {
self.ranked_key_stats.update_key(kt.key, kt.time_ms);
} else {
self.ranked_key_stats.update_key_error(kt.key);
}
}
for ev in &bigram_events {
self.ranked_bigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
self.ranked_bigram_stats.update_redundancy_streak(&ev.key, &self.ranked_key_stats);
}
for ev in &trigram_events {
self.ranked_trigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
}
let update = self
.skill_tree
.update(&self.ranked_key_stats, before_stats.as_ref());
@@ -919,6 +984,19 @@ impl App {
self.profile.last_practice_date = Some(today);
}
// Update transition buffer for hesitation baseline
self.update_transition_buffer(&result.per_key_times);
// Periodic trigram marginal gain analysis (every 50 drills)
if self.profile.total_drills % 50 == 0 && self.profile.total_drills > 0 {
let gain = ngram_stats::trigram_marginal_gain(
&self.ranked_trigram_stats,
&self.ranked_bigram_stats,
&self.ranked_key_stats,
);
self.trigram_gain_history.push(gain);
}
self.drill_history.push(result.clone());
if self.drill_history.len() > 500 {
self.drill_history.remove(0);
@@ -951,9 +1029,27 @@ impl App {
for kt in &result.per_key_times {
if kt.correct {
self.key_stats.update_key(kt.key, kt.time_ms);
} else {
self.key_stats.update_key_error(kt.key);
}
}
// Extract and update n-gram stats
let drill_index = self.drill_history.len() as u32;
let hesitation_thresh = ngram_stats::hesitation_threshold(self.user_median_transition_ms);
let (bigram_events, trigram_events) =
extract_ngram_events(&result.per_key_times, hesitation_thresh);
for ev in &bigram_events {
self.bigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
self.bigram_stats.update_redundancy_streak(&ev.key, &self.key_stats);
}
for ev in &trigram_events {
self.trigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
}
// Update transition buffer for hesitation baseline
self.update_transition_buffer(&result.per_key_times);
self.drill_history.push(result.clone());
if self.drill_history.len() > 500 {
self.drill_history.remove(0);
@@ -983,6 +1079,177 @@ impl App {
}
}
/// Replace up to 40% of words with dictionary words containing the target bigram.
/// No more than 3 consecutive bigram-focused words to prevent repetitive feel.
fn apply_bigram_focus(&mut self, text: &str, filter: &CharFilter, bigram: &BigramKey) -> String {
let bigram_str: String = bigram.0.iter().collect();
let words: Vec<&str> = text.split(' ').collect();
if words.is_empty() {
return text.to_string();
}
// Find dictionary words that contain the bigram and pass the filter
let dict = Dictionary::load();
let candidates: Vec<&str> = dict
.find_matching(filter, None)
.into_iter()
.filter(|w| w.contains(&bigram_str))
.collect();
if candidates.is_empty() {
return text.to_string();
}
let max_replacements = (words.len() * 2 + 4) / 5; // ~40%
let mut replaced = 0;
let mut consecutive = 0;
let mut result_words: Vec<String> = Vec::with_capacity(words.len());
for word in &words {
let already_has = word.contains(&bigram_str);
if already_has {
consecutive += 1;
result_words.push(word.to_string());
continue;
}
if replaced < max_replacements && consecutive < 3 {
let candidate = candidates[self.rng.gen_range(0..candidates.len())];
result_words.push(candidate.to_string());
replaced += 1;
consecutive += 1;
} else {
consecutive = 0;
result_words.push(word.to_string());
}
}
result_words.join(" ")
}
/// Update the rolling transition buffer with new inter-keystroke intervals.
fn update_transition_buffer(&mut self, per_key_times: &[KeyTime]) {
for kt in per_key_times {
if kt.key == BACKSPACE {
continue;
}
self.transition_buffer.push(kt.time_ms);
}
// Keep only last 200 entries
if self.transition_buffer.len() > 200 {
let excess = self.transition_buffer.len() - 200;
self.transition_buffer.drain(..excess);
}
// Recompute median
let mut buf = self.transition_buffer.clone();
self.user_median_transition_ms = ngram_stats::compute_median(&mut buf);
}
/// Rebuild all n-gram stats and char-level error/total counts from drill history.
/// This is the sole source of truth for error_count/total_count on KeyStat
/// and all n-gram stores. Timing EMA on KeyStat is NOT touched here
/// (it is either loaded from disk or rebuilt by `rebuild_from_history`).
fn rebuild_ngram_stats(&mut self) {
// Reset n-gram stores
self.bigram_stats = BigramStatsStore::default();
self.bigram_stats.target_cpm = self.config.target_cpm();
self.ranked_bigram_stats = BigramStatsStore::default();
self.ranked_bigram_stats.target_cpm = self.config.target_cpm();
self.trigram_stats = TrigramStatsStore::default();
self.trigram_stats.target_cpm = self.config.target_cpm();
self.ranked_trigram_stats = TrigramStatsStore::default();
self.ranked_trigram_stats.target_cpm = self.config.target_cpm();
self.transition_buffer.clear();
self.user_median_transition_ms = 0.0;
// Reset char-level error/total counts (timing fields are untouched)
for stat in self.key_stats.stats.values_mut() {
stat.error_count = 0;
stat.total_count = 0;
}
for stat in self.ranked_key_stats.stats.values_mut() {
stat.error_count = 0;
stat.total_count = 0;
}
// Take drill_history out temporarily to avoid borrow conflict
let history = std::mem::take(&mut self.drill_history);
for (drill_index, result) in history.iter().enumerate() {
let hesitation_thresh = ngram_stats::hesitation_threshold(self.user_median_transition_ms);
let (bigram_events, trigram_events) =
extract_ngram_events(&result.per_key_times, hesitation_thresh);
// Rebuild char-level error/total counts from history
for kt in &result.per_key_times {
if kt.correct {
let stat = self.key_stats.stats.entry(kt.key).or_default();
stat.total_count += 1;
} else {
self.key_stats.update_key_error(kt.key);
}
}
for ev in &bigram_events {
self.bigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index as u32);
self.bigram_stats.update_redundancy_streak(&ev.key, &self.key_stats);
}
for ev in &trigram_events {
self.trigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index as u32);
}
if result.ranked {
for kt in &result.per_key_times {
if kt.correct {
let stat = self.ranked_key_stats.stats.entry(kt.key).or_default();
stat.total_count += 1;
} else {
self.ranked_key_stats.update_key_error(kt.key);
}
}
for ev in &bigram_events {
self.ranked_bigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index as u32);
self.ranked_bigram_stats.update_redundancy_streak(&ev.key, &self.ranked_key_stats);
}
for ev in &trigram_events {
self.ranked_trigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index as u32);
}
}
// Update transition buffer
for kt in &result.per_key_times {
if kt.key != BACKSPACE {
self.transition_buffer.push(kt.time_ms);
}
}
if self.transition_buffer.len() > 200 {
let excess = self.transition_buffer.len() - 200;
self.transition_buffer.drain(..excess);
}
let mut buf = self.transition_buffer.clone();
self.user_median_transition_ms = ngram_stats::compute_median(&mut buf);
}
// Put drill_history back
self.drill_history = history;
// Prune trigrams — use drill_history.len() as total, matching the drill_index
// space used in last_seen_drill_index above (history position, includes partials)
let total_history_entries = self.drill_history.len() as u32;
self.trigram_stats.prune(
ngram_stats::MAX_TRIGRAMS,
total_history_entries,
&self.bigram_stats,
&self.key_stats,
);
self.ranked_trigram_stats.prune(
ngram_stats::MAX_TRIGRAMS,
total_history_entries,
&self.ranked_bigram_stats,
&self.ranked_key_stats,
);
}
pub fn retry_drill(&mut self) {
if let Some(ref drill) = self.drill {
let text: String = drill.target.iter().collect();
@@ -1111,6 +1378,9 @@ impl App {
}
self.profile.skill_tree = self.skill_tree.progress.clone();
// Rebuild n-gram stats from the replayed history
self.rebuild_ngram_stats();
}
pub fn go_to_skill_tree(&mut self) {

View File

@@ -11,6 +11,10 @@ pub struct KeyStat {
pub confidence: f64,
pub sample_count: usize,
pub recent_times: Vec<f64>,
#[serde(default)]
pub error_count: usize,
#[serde(default)]
pub total_count: usize,
}
impl Default for KeyStat {
@@ -21,6 +25,8 @@ impl Default for KeyStat {
confidence: 0.0,
sample_count: 0,
recent_times: Vec::new(),
error_count: 0,
total_count: 0,
}
}
}
@@ -44,6 +50,7 @@ impl KeyStatsStore {
pub fn update_key(&mut self, key: char, time_ms: f64) {
let stat = self.stats.entry(key).or_default();
stat.sample_count += 1;
stat.total_count += 1;
if stat.sample_count == 1 {
stat.filtered_time_ms = time_ms;
@@ -70,6 +77,22 @@ impl KeyStatsStore {
pub fn get_stat(&self, key: char) -> Option<&KeyStat> {
self.stats.get(&key)
}
/// Record an error for a key (increments error_count and total_count).
/// Does NOT update timing/confidence (those are only updated for correct strokes).
pub fn update_key_error(&mut self, key: char) {
let stat = self.stats.entry(key).or_default();
stat.error_count += 1;
stat.total_count += 1;
}
/// Laplace-smoothed error rate: (errors + 1) / (total + 2).
pub fn smoothed_error_rate(&self, key: char) -> f64 {
match self.stats.get(&key) {
Some(s) => (s.error_count as f64 + 1.0) / (s.total_count as f64 + 2.0),
None => 0.5, // (0 + 1) / (0 + 2) = 0.5
}
}
}
#[cfg(test)]

View File

@@ -1,5 +1,8 @@
pub mod filter;
pub mod key_stats;
pub mod learning_rate;
pub mod ngram_stats;
pub mod scoring;
pub mod skill_tree;
pub use ngram_stats::FocusTarget;

1221
src/engine/ngram_stats.rs Normal file

File diff suppressed because it is too large Load Diff

18
src/lib.rs Normal file
View File

@@ -0,0 +1,18 @@
// Library target exists solely for criterion benchmarks.
// The binary entry point is main.rs; this file re-declares the module tree so
// that bench harnesses can import types via `keydr::engine::*` / `keydr::session::*`.
// Most code is only exercised through the binary, so suppress dead_code warnings.
#![allow(dead_code)]
// Public: used directly by benchmarks
pub mod engine;
pub mod session;
// Private: required transitively by engine/session (won't compile without them)
mod app;
mod config;
mod event;
mod generator;
mod keyboard;
mod store;
mod ui;

View File

@@ -101,6 +101,7 @@ impl JsonStore {
}
/// Bundle all persisted data + config into an ExportData struct.
/// N-gram stats are not included — they are always rebuilt from drill history.
pub fn export_all(&self, config: &Config) -> ExportData {
let profile = self.load_profile().unwrap_or_default();
let key_stats = self.load_key_stats();

View File

@@ -74,6 +74,9 @@ impl Default for DrillHistoryData {
pub const EXPORT_VERSION: u32 = 1;
/// Export contract: drill_history is the sole source of truth for n-gram stats.
/// N-gram data is always rebuilt from history on import/startup, so it is not
/// included in the export payload.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExportData {
pub keydr_export_version: u32,