N-gram metrics overhaul & UI improvements

This commit is contained in:
2026-02-26 01:26:25 -05:00
parent e7f57dd497
commit 54ddebf054
23 changed files with 3812 additions and 1008 deletions

View File

@@ -2,19 +2,18 @@ use std::collections::{HashSet, VecDeque};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::thread;
use std::time::Instant;
use std::time::{Duration, Instant};
use rand::Rng;
use rand::SeedableRng;
use rand::rngs::SmallRng;
use crate::config::Config;
use crate::engine::FocusSelection;
use crate::engine::filter::CharFilter;
use crate::engine::key_stats::KeyStatsStore;
use crate::engine::FocusTarget;
use crate::engine::ngram_stats::{
self, BigramKey, BigramStatsStore, TrigramStatsStore, extract_ngram_events,
select_focus_target,
self, BigramStatsStore, TrigramStatsStore, extract_ngram_events, select_focus,
};
use crate::engine::scoring;
use crate::engine::skill_tree::{BranchId, BranchStatus, DrillScope, SkillTree};
@@ -35,14 +34,16 @@ use crate::generator::passage::{
use crate::generator::phonetic::PhoneticGenerator;
use crate::generator::punctuate;
use crate::generator::transition_table::TransitionTable;
use crate::keyboard::model::KeyboardModel;
use crate::keyboard::display::BACKSPACE;
use crate::keyboard::model::KeyboardModel;
use crate::session::drill::DrillState;
use crate::session::input::{self, KeystrokeEvent};
use crate::session::result::{DrillResult, KeyTime};
use crate::store::json_store::JsonStore;
use crate::store::schema::{DrillHistoryData, ExportData, KeyStatsData, ProfileData, EXPORT_VERSION};
use crate::store::schema::{
DrillHistoryData, EXPORT_VERSION, ExportData, KeyStatsData, ProfileData,
};
use crate::ui::components::menu::Menu;
use crate::ui::theme::Theme;
@@ -108,6 +109,8 @@ const MASTERY_MESSAGES: &[&str] = &[
"One more key conquered!",
];
const POST_DRILL_INPUT_LOCK_MS: u64 = 800;
struct DownloadJob {
downloaded_bytes: Arc<AtomicU64>,
total_bytes: Arc<AtomicU64>,
@@ -135,7 +138,10 @@ pub fn next_available_path(path_str: &str) -> String {
let path = std::path::Path::new(path_str).to_path_buf();
let parent = path.parent().unwrap_or(std::path::Path::new("."));
let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("json");
let full_stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("export");
let full_stem = path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("export");
// Strip existing trailing -N suffix to get base stem
let base_stem = if let Some(pos) = full_stem.rfind('-') {
@@ -272,6 +278,8 @@ pub struct App {
pub user_median_transition_ms: f64,
pub transition_buffer: Vec<f64>,
pub trigram_gain_history: Vec<f64>,
pub current_focus: Option<FocusSelection>,
pub post_drill_input_lock_until: Option<Instant>,
rng: SmallRng,
transition_table: TransitionTable,
#[allow(dead_code)]
@@ -293,38 +301,39 @@ impl App {
let store = JsonStore::new().ok();
let (key_stats, ranked_key_stats, skill_tree, profile, drill_history) = if let Some(ref s) = store {
// load_profile returns None if file exists but can't parse (schema mismatch)
let pd = s.load_profile();
let (key_stats, ranked_key_stats, skill_tree, profile, drill_history) =
if let Some(ref s) = store {
// load_profile returns None if file exists but can't parse (schema mismatch)
let pd = s.load_profile();
match pd {
Some(pd) if !pd.needs_reset() => {
let ksd = s.load_key_stats();
let rksd = s.load_ranked_key_stats();
let lhd = s.load_drill_history();
let st = SkillTree::new(pd.skill_tree.clone());
(ksd.stats, rksd.stats, st, pd, lhd.drills)
match pd {
Some(pd) if !pd.needs_reset() => {
let ksd = s.load_key_stats();
let rksd = s.load_ranked_key_stats();
let lhd = s.load_drill_history();
let st = SkillTree::new(pd.skill_tree.clone());
(ksd.stats, rksd.stats, st, pd, lhd.drills)
}
_ => {
// Schema mismatch or parse failure: full reset of all stores
(
KeyStatsStore::default(),
KeyStatsStore::default(),
SkillTree::default(),
ProfileData::default(),
Vec::new(),
)
}
}
_ => {
// Schema mismatch or parse failure: full reset of all stores
(
KeyStatsStore::default(),
KeyStatsStore::default(),
SkillTree::default(),
ProfileData::default(),
Vec::new(),
)
}
}
} else {
(
KeyStatsStore::default(),
KeyStatsStore::default(),
SkillTree::default(),
ProfileData::default(),
Vec::new(),
)
};
} else {
(
KeyStatsStore::default(),
KeyStatsStore::default(),
SkillTree::default(),
ProfileData::default(),
Vec::new(),
)
};
let mut key_stats_with_target = key_stats;
key_stats_with_target.target_cpm = config.target_cpm();
@@ -421,6 +430,8 @@ impl App {
user_median_transition_ms: 0.0,
transition_buffer: Vec::new(),
trigram_gain_history: Vec::new(),
current_focus: None,
post_drill_input_lock_until: None,
rng: SmallRng::from_entropy(),
transition_table,
dictionary,
@@ -454,6 +465,23 @@ impl App {
self.settings_editing_download_dir = false;
}
/// Arm the post-drill input lock: ignore keystrokes for a short window
/// after a drill ends so a stray final key press cannot dismiss the
/// result screen accidentally.
pub fn arm_post_drill_input_lock(&mut self) {
    // Lock expires POST_DRILL_INPUT_LOCK_MS after the current instant.
    let expiry = Instant::now() + Duration::from_millis(POST_DRILL_INPUT_LOCK_MS);
    self.post_drill_input_lock_until = Some(expiry);
}
/// Disarm the post-drill input lock so input is accepted again immediately.
pub fn clear_post_drill_input_lock(&mut self) {
self.post_drill_input_lock_until = None;
}
/// Milliseconds left on the post-drill input lock, if one is active.
///
/// Returns `None` when no lock is armed or the deadline has already
/// passed; otherwise returns the remaining time, clamped to at least
/// 1 ms so callers never observe a zero countdown while still locked.
pub fn post_drill_input_lock_remaining_ms(&self) -> Option<u64> {
    let until = self.post_drill_input_lock_until?;
    // checked_duration_since yields None once `until` is in the past.
    let remaining = until.checked_duration_since(Instant::now())?;
    Some((remaining.as_millis() as u64).max(1))
}
pub fn export_data(&mut self) {
let path = std::path::Path::new(&self.settings_export_path);
@@ -643,6 +671,7 @@ impl App {
}
pub fn start_drill(&mut self) {
self.clear_post_drill_input_lock();
let (text, source_info) = self.generate_text();
self.drill = Some(DrillState::new(&text));
self.drill_source_info = source_info;
@@ -659,17 +688,16 @@ impl App {
let scope = self.drill_scope;
let all_keys = self.skill_tree.unlocked_keys(scope);
// Select focus target: single char or bigram
let focus_target = select_focus_target(
// Select focus targets: char and bigram independently
let selection = select_focus(
&self.skill_tree,
scope,
&self.ranked_key_stats,
&self.ranked_bigram_stats,
);
let (focused_char, focused_bigram) = match &focus_target {
FocusTarget::Char(ch) => (Some(*ch), None),
FocusTarget::Bigram(key) => (Some(key.0[0]), Some(key.clone())),
};
self.current_focus = Some(selection.clone());
let focused_char = selection.char_focus;
let focused_bigram = selection.bigram_focus.map(|(k, _, _)| k.0);
// Generate base lowercase text using only lowercase keys from scope
let lowercase_keys: Vec<char> = all_keys
@@ -684,7 +712,8 @@ impl App {
let dict = Dictionary::load();
let rng = SmallRng::from_rng(&mut self.rng).unwrap();
let mut generator = PhoneticGenerator::new(table, dict, rng);
let mut text = generator.generate(&filter, lowercase_focused, word_count);
let mut text =
generator.generate(&filter, lowercase_focused, focused_bigram, word_count);
// Apply capitalization if uppercase keys are in scope
let cap_keys: Vec<char> = all_keys
@@ -694,7 +723,8 @@ impl App {
.collect();
if !cap_keys.is_empty() {
let mut rng = SmallRng::from_rng(&mut self.rng).unwrap();
text = capitalize::apply_capitalization(&text, &cap_keys, focused_char, &mut rng);
text =
capitalize::apply_capitalization(&text, &cap_keys, focused_char, &mut rng);
}
// Apply punctuation if punctuation keys are in scope
@@ -722,7 +752,8 @@ impl App {
if !digit_keys.is_empty() {
let has_dot = all_keys.contains(&'.');
let mut rng = SmallRng::from_rng(&mut self.rng).unwrap();
text = numbers::apply_numbers(&text, &digit_keys, has_dot, focused_char, &mut rng);
text =
numbers::apply_numbers(&text, &digit_keys, has_dot, focused_char, &mut rng);
}
// Apply code symbols only if this drill is for the CodeSymbols branch,
@@ -781,11 +812,6 @@ impl App {
text = insert_line_breaks(&text);
}
// After all generation: if bigram focus, swap some words for bigram-containing words
if let Some(ref bigram) = focused_bigram {
text = self.apply_bigram_focus(&text, &filter, bigram);
}
(text, None)
}
DrillMode::Code => {
@@ -796,13 +822,10 @@ impl App {
.unwrap_or_else(|| self.config.code_language.clone());
self.last_code_drill_language = Some(lang.clone());
let rng = SmallRng::from_rng(&mut self.rng).unwrap();
let mut generator = CodeSyntaxGenerator::new(
rng,
&lang,
&self.config.code_download_dir,
);
let mut generator =
CodeSyntaxGenerator::new(rng, &lang, &self.config.code_download_dir);
self.code_drill_language_override = None;
let text = generator.generate(&filter, None, word_count);
let text = generator.generate(&filter, None, None, word_count);
(text, Some(generator.last_source().to_string()))
}
DrillMode::Passage => {
@@ -821,7 +844,7 @@ impl App {
self.config.passage_downloads_enabled,
);
self.passage_drill_selection_override = None;
let text = generator.generate(&filter, None, word_count);
let text = generator.generate(&filter, None, None, word_count);
(text, Some(generator.last_source().to_string()))
}
}
@@ -891,18 +914,43 @@ impl App {
// Extract and update n-gram stats for all drill modes
let drill_index = self.drill_history.len() as u32;
let hesitation_thresh = ngram_stats::hesitation_threshold(self.user_median_transition_ms);
let hesitation_thresh =
ngram_stats::hesitation_threshold(self.user_median_transition_ms);
let (bigram_events, trigram_events) =
extract_ngram_events(&result.per_key_times, hesitation_thresh);
// Collect unique bigram keys for per-drill streak updates
let mut seen_bigrams: std::collections::HashSet<ngram_stats::BigramKey> =
std::collections::HashSet::new();
for ev in &bigram_events {
self.bigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
self.bigram_stats.update_redundancy_streak(&ev.key, &self.key_stats);
seen_bigrams.insert(ev.key.clone());
self.bigram_stats.update(
ev.key.clone(),
ev.total_time_ms,
ev.correct,
ev.has_hesitation,
drill_index,
);
}
// Update streaks once per drill per unique bigram (not per event)
for key in &seen_bigrams {
self.bigram_stats
.update_error_anomaly_streak(key, &self.key_stats);
self.bigram_stats
.update_speed_anomaly_streak(key, &self.key_stats);
}
for ev in &trigram_events {
self.trigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
self.trigram_stats.update(
ev.key.clone(),
ev.total_time_ms,
ev.correct,
ev.has_hesitation,
drill_index,
);
}
if ranked {
let mut seen_ranked_bigrams: std::collections::HashSet<ngram_stats::BigramKey> =
std::collections::HashSet::new();
for kt in &result.per_key_times {
if kt.correct {
self.ranked_key_stats.update_key(kt.key, kt.time_ms);
@@ -911,11 +959,29 @@ impl App {
}
}
for ev in &bigram_events {
self.ranked_bigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
self.ranked_bigram_stats.update_redundancy_streak(&ev.key, &self.ranked_key_stats);
seen_ranked_bigrams.insert(ev.key.clone());
self.ranked_bigram_stats.update(
ev.key.clone(),
ev.total_time_ms,
ev.correct,
ev.has_hesitation,
drill_index,
);
}
for key in &seen_ranked_bigrams {
self.ranked_bigram_stats
.update_error_anomaly_streak(key, &self.ranked_key_stats);
self.ranked_bigram_stats
.update_speed_anomaly_streak(key, &self.ranked_key_stats);
}
for ev in &trigram_events {
self.ranked_trigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
self.ranked_trigram_stats.update(
ev.key.clone(),
ev.total_time_ms,
ev.correct,
ev.has_hesitation,
drill_index,
);
}
let update = self
.skill_tree
@@ -1003,6 +1069,9 @@ impl App {
}
self.last_result = Some(result);
if !self.milestone_queue.is_empty() || self.drill_mode != DrillMode::Adaptive {
self.arm_post_drill_input_lock();
}
// Adaptive mode auto-continues unless milestone popups must be shown first.
if self.drill_mode == DrillMode::Adaptive && self.milestone_queue.is_empty() {
@@ -1036,15 +1105,36 @@ impl App {
// Extract and update n-gram stats
let drill_index = self.drill_history.len() as u32;
let hesitation_thresh = ngram_stats::hesitation_threshold(self.user_median_transition_ms);
let hesitation_thresh =
ngram_stats::hesitation_threshold(self.user_median_transition_ms);
let (bigram_events, trigram_events) =
extract_ngram_events(&result.per_key_times, hesitation_thresh);
let mut seen_bigrams: std::collections::HashSet<ngram_stats::BigramKey> =
std::collections::HashSet::new();
for ev in &bigram_events {
self.bigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
self.bigram_stats.update_redundancy_streak(&ev.key, &self.key_stats);
seen_bigrams.insert(ev.key.clone());
self.bigram_stats.update(
ev.key.clone(),
ev.total_time_ms,
ev.correct,
ev.has_hesitation,
drill_index,
);
}
for key in &seen_bigrams {
self.bigram_stats
.update_error_anomaly_streak(key, &self.key_stats);
self.bigram_stats
.update_speed_anomaly_streak(key, &self.key_stats);
}
for ev in &trigram_events {
self.trigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index);
self.trigram_stats.update(
ev.key.clone(),
ev.total_time_ms,
ev.correct,
ev.has_hesitation,
drill_index,
);
}
// Update transition buffer for hesitation baseline
@@ -1056,6 +1146,7 @@ impl App {
}
self.last_result = Some(result);
self.arm_post_drill_input_lock();
self.screen = AppScreen::DrillResult;
self.save_data();
}
@@ -1081,52 +1172,6 @@ impl App {
/// Replace up to 40% of words with dictionary words containing the target bigram.
/// No more than 3 consecutive bigram-focused words to prevent repetitive feel.
fn apply_bigram_focus(&mut self, text: &str, filter: &CharFilter, bigram: &BigramKey) -> String {
    let target: String = bigram.0.iter().collect();
    let original_words: Vec<&str> = text.split(' ').collect();
    if original_words.is_empty() {
        return text.to_string();
    }

    // Candidate pool: dictionary words that pass the char filter AND
    // contain the focused bigram.
    let dict = Dictionary::load();
    let pool: Vec<&str> = dict
        .find_matching(filter, None)
        .into_iter()
        .filter(|w| w.contains(&target))
        .collect();
    if pool.is_empty() {
        return text.to_string();
    }

    // Ceiling of 40% of the word count.
    let budget = (original_words.len() * 2 + 4) / 5;
    let mut swapped = 0usize;
    let mut run_len = 0usize; // consecutive words already containing the bigram
    let mut out: Vec<String> = Vec::with_capacity(original_words.len());
    for &word in &original_words {
        if word.contains(&target) {
            // Word naturally contains the bigram: keep it, extend the run.
            run_len += 1;
            out.push(word.to_string());
        } else if swapped < budget && run_len < 3 {
            // Swap in a random bigram-bearing candidate.
            let pick = pool[self.rng.gen_range(0..pool.len())];
            out.push(pick.to_string());
            swapped += 1;
            run_len += 1;
        } else {
            // Keep the original word; this breaks the focused run.
            run_len = 0;
            out.push(word.to_string());
        }
    }
    out.join(" ")
}
/// Update the rolling transition buffer with new inter-keystroke intervals.
fn update_transition_buffer(&mut self, per_key_times: &[KeyTime]) {
for kt in per_key_times {
@@ -1152,67 +1197,121 @@ impl App {
fn rebuild_ngram_stats(&mut self) {
// Reset n-gram stores
self.bigram_stats = BigramStatsStore::default();
self.bigram_stats.target_cpm = self.config.target_cpm();
self.ranked_bigram_stats = BigramStatsStore::default();
self.ranked_bigram_stats.target_cpm = self.config.target_cpm();
self.trigram_stats = TrigramStatsStore::default();
self.trigram_stats.target_cpm = self.config.target_cpm();
self.ranked_trigram_stats = TrigramStatsStore::default();
self.ranked_trigram_stats.target_cpm = self.config.target_cpm();
self.transition_buffer.clear();
self.user_median_transition_ms = 0.0;
// Reset char-level error/total counts (timing fields are untouched)
// Reset char-level error/total counts and EMA (timing fields are untouched)
for stat in self.key_stats.stats.values_mut() {
stat.error_count = 0;
stat.total_count = 0;
stat.error_rate_ema = 0.5;
}
for stat in self.ranked_key_stats.stats.values_mut() {
stat.error_count = 0;
stat.total_count = 0;
stat.error_rate_ema = 0.5;
}
// Take drill_history out temporarily to avoid borrow conflict
let history = std::mem::take(&mut self.drill_history);
for (drill_index, result) in history.iter().enumerate() {
let hesitation_thresh = ngram_stats::hesitation_threshold(self.user_median_transition_ms);
let hesitation_thresh =
ngram_stats::hesitation_threshold(self.user_median_transition_ms);
let (bigram_events, trigram_events) =
extract_ngram_events(&result.per_key_times, hesitation_thresh);
// Rebuild char-level error/total counts from history
// Rebuild char-level error/total counts and EMA from history
for kt in &result.per_key_times {
if kt.correct {
let stat = self.key_stats.stats.entry(kt.key).or_default();
stat.total_count += 1;
// Update error rate EMA for correct stroke
if stat.total_count == 1 {
stat.error_rate_ema = 0.0;
} else {
stat.error_rate_ema = 0.1 * 0.0 + 0.9 * stat.error_rate_ema;
}
} else {
self.key_stats.update_key_error(kt.key);
}
}
// Collect unique bigram keys seen this drill for per-drill streak updates
let mut seen_bigrams: std::collections::HashSet<ngram_stats::BigramKey> =
std::collections::HashSet::new();
for ev in &bigram_events {
self.bigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index as u32);
self.bigram_stats.update_redundancy_streak(&ev.key, &self.key_stats);
seen_bigrams.insert(ev.key.clone());
self.bigram_stats.update(
ev.key.clone(),
ev.total_time_ms,
ev.correct,
ev.has_hesitation,
drill_index as u32,
);
}
// Update streaks once per drill per unique bigram (not per event)
for key in &seen_bigrams {
self.bigram_stats
.update_error_anomaly_streak(key, &self.key_stats);
self.bigram_stats
.update_speed_anomaly_streak(key, &self.key_stats);
}
for ev in &trigram_events {
self.trigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index as u32);
self.trigram_stats.update(
ev.key.clone(),
ev.total_time_ms,
ev.correct,
ev.has_hesitation,
drill_index as u32,
);
}
if result.ranked {
let mut seen_ranked_bigrams: std::collections::HashSet<ngram_stats::BigramKey> =
std::collections::HashSet::new();
for kt in &result.per_key_times {
if kt.correct {
let stat = self.ranked_key_stats.stats.entry(kt.key).or_default();
stat.total_count += 1;
if stat.total_count == 1 {
stat.error_rate_ema = 0.0;
} else {
stat.error_rate_ema = 0.1 * 0.0 + 0.9 * stat.error_rate_ema;
}
} else {
self.ranked_key_stats.update_key_error(kt.key);
}
}
for ev in &bigram_events {
self.ranked_bigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index as u32);
self.ranked_bigram_stats.update_redundancy_streak(&ev.key, &self.ranked_key_stats);
seen_ranked_bigrams.insert(ev.key.clone());
self.ranked_bigram_stats.update(
ev.key.clone(),
ev.total_time_ms,
ev.correct,
ev.has_hesitation,
drill_index as u32,
);
}
for key in &seen_ranked_bigrams {
self.ranked_bigram_stats
.update_error_anomaly_streak(key, &self.ranked_key_stats);
self.ranked_bigram_stats
.update_speed_anomaly_streak(key, &self.ranked_key_stats);
}
for ev in &trigram_events {
self.ranked_trigram_stats.update(ev.key.clone(), ev.total_time_ms, ev.correct, ev.has_hesitation, drill_index as u32);
self.ranked_trigram_stats.update(
ev.key.clone(),
ev.total_time_ms,
ev.correct,
ev.has_hesitation,
drill_index as u32,
);
}
}
@@ -1282,6 +1381,7 @@ impl App {
}
pub fn go_to_menu(&mut self) {
self.clear_post_drill_input_lock();
self.screen = AppScreen::Menu;
self.drill = None;
self.drill_source_info = None;
@@ -1289,6 +1389,7 @@ impl App {
}
pub fn go_to_stats(&mut self) {
self.clear_post_drill_input_lock();
self.stats_tab = 0;
self.history_selected = 0;
self.history_confirm_delete = false;
@@ -1562,10 +1663,8 @@ impl App {
}
pub fn start_code_downloads(&mut self) {
let queue = build_code_download_queue(
&self.config.code_language,
&self.code_intro_download_dir,
);
let queue =
build_code_download_queue(&self.config.code_language, &self.code_intro_download_dir);
self.code_intro_download_total = queue.len();
self.code_download_queue = queue;
@@ -1662,10 +1761,8 @@ impl App {
let snippets_limit = self.code_intro_snippets_per_repo;
// Get static references for thread
let repo_ref: &'static crate::generator::code_syntax::CodeRepo =
&lang.repos[repo_idx];
let block_style_ref: &'static crate::generator::code_syntax::BlockStyle =
&lang.block_style;
let repo_ref: &'static crate::generator::code_syntax::CodeRepo = &lang.repos[repo_idx];
let block_style_ref: &'static crate::generator::code_syntax::BlockStyle = &lang.block_style;
let handle = thread::spawn(move || {
let ok = download_code_repo_to_cache_with_progress(
@@ -1931,12 +2028,11 @@ impl App {
// Editable text field handled directly in key handler.
}
6 => {
self.config.code_snippets_per_repo =
match self.config.code_snippets_per_repo {
0 => 1,
n if n >= 200 => 0,
n => n + 10,
};
self.config.code_snippets_per_repo = match self.config.code_snippets_per_repo {
0 => 1,
n if n >= 200 => 0,
n => n + 10,
};
}
// 7 = Download Code Now (action button)
8 => {
@@ -1998,12 +2094,11 @@ impl App {
// Editable text field handled directly in key handler.
}
6 => {
self.config.code_snippets_per_repo =
match self.config.code_snippets_per_repo {
0 => 200,
1 => 0,
n => n.saturating_sub(10).max(1),
};
self.config.code_snippets_per_repo = match self.config.code_snippets_per_repo {
0 => 200,
1 => 0,
n => n.saturating_sub(10).max(1),
};
}
// 7 = Download Code Now (action button)
8 => {

View File

@@ -202,10 +202,19 @@ code_language = "go"
let config = Config::default();
let serialized = toml::to_string_pretty(&config).unwrap();
let deserialized: Config = toml::from_str(&serialized).unwrap();
assert_eq!(config.code_downloads_enabled, deserialized.code_downloads_enabled);
assert_eq!(
config.code_downloads_enabled,
deserialized.code_downloads_enabled
);
assert_eq!(config.code_download_dir, deserialized.code_download_dir);
assert_eq!(config.code_snippets_per_repo, deserialized.code_snippets_per_repo);
assert_eq!(config.code_onboarding_done, deserialized.code_onboarding_done);
assert_eq!(
config.code_snippets_per_repo,
deserialized.code_snippets_per_repo
);
assert_eq!(
config.code_onboarding_done,
deserialized.code_onboarding_done
);
}
#[test]

View File

@@ -15,6 +15,12 @@ pub struct KeyStat {
pub error_count: usize,
#[serde(default)]
pub total_count: usize,
#[serde(default = "default_error_rate_ema")]
pub error_rate_ema: f64,
}
/// Serde default for `error_rate_ema` when deserializing data saved before
/// the field existed. 0.5 matches the "no information" prior that
/// `smoothed_error_rate` returns for keys with no recorded strokes.
fn default_error_rate_ema() -> f64 {
0.5
}
impl Default for KeyStat {
@@ -27,6 +33,7 @@ impl Default for KeyStat {
recent_times: Vec::new(),
error_count: 0,
total_count: 0,
error_rate_ema: 0.5,
}
}
}
@@ -67,6 +74,13 @@ impl KeyStatsStore {
if stat.recent_times.len() > 30 {
stat.recent_times.remove(0);
}
// Update error rate EMA (correct stroke = 0.0 signal)
if stat.total_count == 1 {
stat.error_rate_ema = 0.0;
} else {
stat.error_rate_ema = EMA_ALPHA * 0.0 + (1.0 - EMA_ALPHA) * stat.error_rate_ema;
}
}
pub fn get_confidence(&self, key: char) -> f64 {
@@ -84,13 +98,20 @@ impl KeyStatsStore {
let stat = self.stats.entry(key).or_default();
stat.error_count += 1;
stat.total_count += 1;
// Update error rate EMA (error stroke = 1.0 signal)
if stat.total_count == 1 {
stat.error_rate_ema = 1.0;
} else {
stat.error_rate_ema = EMA_ALPHA * 1.0 + (1.0 - EMA_ALPHA) * stat.error_rate_ema;
}
}
/// Laplace-smoothed error rate: (errors + 1) / (total + 2).
/// EMA-based error rate for a key.
pub fn smoothed_error_rate(&self, key: char) -> f64 {
match self.stats.get(&key) {
Some(s) => (s.error_count as f64 + 1.0) / (s.total_count as f64 + 2.0),
None => 0.5, // (0 + 1) / (0 + 2) = 0.5
Some(s) => s.error_rate_ema,
None => 0.5,
}
}
}
@@ -142,4 +163,50 @@ mod tests {
"confidence should be < 1.0 for slow typing, got {conf}"
);
}
#[test]
fn test_ema_error_rate_correct_strokes() {
    let mut store = KeyStatsStore::default();
    // The very first correct stroke snaps the EMA to exactly 0.0.
    store.update_key('a', 200.0);
    assert!((store.smoothed_error_rate('a') - 0.0).abs() < f64::EPSILON);
    // Further correct strokes keep it pinned near zero.
    (0..10).for_each(|_| store.update_key('a', 200.0));
    assert!(
        store.smoothed_error_rate('a') < 0.01,
        "All correct → EMA near 0"
    );
}
#[test]
fn test_ema_error_rate_error_strokes() {
    let mut store = KeyStatsStore::default();
    // An error as the very first stroke snaps the EMA to exactly 1.0.
    store.update_key_error('b');
    assert!((store.smoothed_error_rate('b') - 1.0).abs() < f64::EPSILON);
    // A run of correct strokes then decays the EMA geometrically.
    (0..20).for_each(|_| store.update_key('b', 200.0));
    let rate = store.smoothed_error_rate('b');
    assert!(
        rate < 0.15,
        "After 20 correct, EMA should be < 0.15, got {rate}"
    );
}
#[test]
fn test_ema_error_rate_default_for_missing_key() {
// A key with no recorded strokes reports the 0.5 "unknown" prior.
let store = KeyStatsStore::default();
assert!((store.smoothed_error_rate('z') - 0.5).abs() < f64::EPSILON);
}
#[test]
fn test_ema_error_rate_serde_default() {
// Verify backward compat: deserializing old data without error_rate_ema gets 0.5
// (the JSON below deliberately omits the field so the serde default applies).
let json = r#"{"filtered_time_ms":200.0,"best_time_ms":200.0,"confidence":1.0,"sample_count":10,"recent_times":[],"error_count":2,"total_count":10}"#;
let stat: KeyStat = serde_json::from_str(json).unwrap();
assert!((stat.error_rate_ema - 0.5).abs() < f64::EPSILON);
}
}

View File

@@ -5,4 +5,4 @@ pub mod ngram_stats;
pub mod scoring;
pub mod skill_tree;
pub use ngram_stats::FocusTarget;
pub use ngram_stats::FocusSelection;

File diff suppressed because it is too large Load Diff

View File

@@ -567,9 +567,7 @@ impl SkillTree {
let newly_mastered: Vec<char> = if let Some(before) = before_stats {
before_unlocked
.iter()
.filter(|&&ch| {
before.get_confidence(ch) < 1.0 && stats.get_confidence(ch) >= 1.0
})
.filter(|&&ch| before.get_confidence(ch) < 1.0 && stats.get_confidence(ch) >= 1.0)
.copied()
.collect()
} else {

View File

@@ -51,24 +51,39 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
],
has_builtin: true,
block_style: BlockStyle::Braces(&[
"fn ", "pub fn ", "async fn ", "pub async fn ", "impl ", "trait ", "struct ", "enum ",
"macro_rules! ", "mod ", "const ", "static ", "type ", "pub struct ", "pub enum ",
"pub trait ", "pub mod ", "pub const ", "pub static ", "pub type ",
"fn ",
"pub fn ",
"async fn ",
"pub async fn ",
"impl ",
"trait ",
"struct ",
"enum ",
"macro_rules! ",
"mod ",
"const ",
"static ",
"type ",
"pub struct ",
"pub enum ",
"pub trait ",
"pub mod ",
"pub const ",
"pub static ",
"pub type ",
]),
},
CodeLanguage {
key: "python",
display_name: "Python",
extensions: &[".py", ".pyi"],
repos: &[
CodeRepo {
key: "cpython",
urls: &[
"https://raw.githubusercontent.com/python/cpython/main/Lib/json/encoder.py",
"https://raw.githubusercontent.com/python/cpython/main/Lib/pathlib/__init__.py",
],
},
],
repos: &[CodeRepo {
key: "cpython",
urls: &[
"https://raw.githubusercontent.com/python/cpython/main/Lib/json/encoder.py",
"https://raw.githubusercontent.com/python/cpython/main/Lib/pathlib/__init__.py",
],
}],
has_builtin: true,
block_style: BlockStyle::Indentation(&["def ", "class ", "async def ", "@"]),
},
@@ -76,15 +91,13 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
key: "javascript",
display_name: "JavaScript",
extensions: &[".js", ".mjs"],
repos: &[
CodeRepo {
key: "node-stdlib",
urls: &[
"https://raw.githubusercontent.com/nodejs/node/main/lib/path.js",
"https://raw.githubusercontent.com/nodejs/node/main/lib/url.js",
],
},
],
repos: &[CodeRepo {
key: "node-stdlib",
urls: &[
"https://raw.githubusercontent.com/nodejs/node/main/lib/path.js",
"https://raw.githubusercontent.com/nodejs/node/main/lib/url.js",
],
}],
has_builtin: true,
block_style: BlockStyle::Braces(&[
"function ",
@@ -101,14 +114,10 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
key: "go",
display_name: "Go",
extensions: &[".go"],
repos: &[
CodeRepo {
key: "go-stdlib",
urls: &[
"https://raw.githubusercontent.com/golang/go/master/src/fmt/print.go",
],
},
],
repos: &[CodeRepo {
key: "go-stdlib",
urls: &["https://raw.githubusercontent.com/golang/go/master/src/fmt/print.go"],
}],
has_builtin: true,
block_style: BlockStyle::Braces(&["func ", "type "]),
},
@@ -119,9 +128,7 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
repos: &[
CodeRepo {
key: "ts-node",
urls: &[
"https://raw.githubusercontent.com/TypeStrong/ts-node/main/src/index.ts",
],
urls: &["https://raw.githubusercontent.com/TypeStrong/ts-node/main/src/index.ts"],
},
CodeRepo {
key: "deno-std",
@@ -195,9 +202,7 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
},
CodeRepo {
key: "jq",
urls: &[
"https://raw.githubusercontent.com/jqlang/jq/master/src/builtin.c",
],
urls: &["https://raw.githubusercontent.com/jqlang/jq/master/src/builtin.c"],
},
],
has_builtin: true,
@@ -229,9 +234,7 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
},
CodeRepo {
key: "fmt",
urls: &[
"https://raw.githubusercontent.com/fmtlib/fmt/master/include/fmt/format.h",
],
urls: &["https://raw.githubusercontent.com/fmtlib/fmt/master/include/fmt/format.h"],
},
],
has_builtin: true,
@@ -274,7 +277,13 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
],
has_builtin: true,
block_style: BlockStyle::EndDelimited(&[
"def ", "class ", "module ", "attr_", "scope ", "describe ", "it ",
"def ",
"class ",
"module ",
"attr_",
"scope ",
"describe ",
"it ",
]),
},
CodeLanguage {
@@ -319,9 +328,7 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
repos: &[
CodeRepo {
key: "nvm",
urls: &[
"https://raw.githubusercontent.com/nvm-sh/nvm/master/nvm.sh",
],
urls: &["https://raw.githubusercontent.com/nvm-sh/nvm/master/nvm.sh"],
},
CodeRepo {
key: "oh-my-zsh",
@@ -340,9 +347,7 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
repos: &[
CodeRepo {
key: "kong",
urls: &[
"https://raw.githubusercontent.com/Kong/kong/master/kong/init.lua",
],
urls: &["https://raw.githubusercontent.com/Kong/kong/master/kong/init.lua"],
},
CodeRepo {
key: "luarocks",
@@ -359,41 +364,60 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
key: "kotlin",
display_name: "Kotlin",
extensions: &[".kt", ".kts"],
repos: &[
CodeRepo {
key: "kotlinx-coroutines",
urls: &[
"https://raw.githubusercontent.com/Kotlin/kotlinx.coroutines/master/kotlinx-coroutines-core/common/src/flow/Builders.kt",
"https://raw.githubusercontent.com/Kotlin/kotlinx.coroutines/master/kotlinx-coroutines-core/common/src/channels/Channel.kt",
],
},
],
repos: &[CodeRepo {
key: "kotlinx-coroutines",
urls: &[
"https://raw.githubusercontent.com/Kotlin/kotlinx.coroutines/master/kotlinx-coroutines-core/common/src/flow/Builders.kt",
"https://raw.githubusercontent.com/Kotlin/kotlinx.coroutines/master/kotlinx-coroutines-core/common/src/channels/Channel.kt",
],
}],
has_builtin: false,
block_style: BlockStyle::Braces(&[
"fun ", "class ", "object ", "interface ", "suspend fun ",
"public ", "private ", "internal ", "override fun ", "open ",
"data class ", "sealed ", "abstract ",
"val ", "var ", "enum ", "annotation ", "typealias ",
"fun ",
"class ",
"object ",
"interface ",
"suspend fun ",
"public ",
"private ",
"internal ",
"override fun ",
"open ",
"data class ",
"sealed ",
"abstract ",
"val ",
"var ",
"enum ",
"annotation ",
"typealias ",
]),
},
CodeLanguage {
key: "scala",
display_name: "Scala",
extensions: &[".scala"],
repos: &[
CodeRepo {
key: "scala-stdlib",
urls: &[
"https://raw.githubusercontent.com/scala/scala/2.13.x/src/library/scala/collection/immutable/List.scala",
"https://raw.githubusercontent.com/scala/scala/2.13.x/src/library/scala/collection/mutable/HashMap.scala",
"https://raw.githubusercontent.com/scala/scala/2.13.x/src/library/scala/Option.scala",
],
},
],
repos: &[CodeRepo {
key: "scala-stdlib",
urls: &[
"https://raw.githubusercontent.com/scala/scala/2.13.x/src/library/scala/collection/immutable/List.scala",
"https://raw.githubusercontent.com/scala/scala/2.13.x/src/library/scala/collection/mutable/HashMap.scala",
"https://raw.githubusercontent.com/scala/scala/2.13.x/src/library/scala/Option.scala",
],
}],
has_builtin: false,
block_style: BlockStyle::Braces(&[
"def ", "class ", "object ", "trait ", "case class ",
"val ", "var ", "type ", "implicit ", "given ", "extension ",
"def ",
"class ",
"object ",
"trait ",
"case class ",
"val ",
"var ",
"type ",
"implicit ",
"given ",
"extension ",
]),
},
CodeLanguage {
@@ -461,18 +485,29 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
key: "dart",
display_name: "Dart",
extensions: &[".dart"],
repos: &[
CodeRepo {
key: "flutter",
urls: &[
"https://raw.githubusercontent.com/flutter/flutter/master/packages/flutter/lib/src/widgets/framework.dart",
],
},
],
repos: &[CodeRepo {
key: "flutter",
urls: &[
"https://raw.githubusercontent.com/flutter/flutter/master/packages/flutter/lib/src/widgets/framework.dart",
],
}],
has_builtin: false,
block_style: BlockStyle::Braces(&[
"void ", "Future ", "Future<", "class ", "int ", "String ", "bool ", "static ", "factory ",
"Widget ", "get ", "set ", "enum ", "typedef ", "extension ",
"void ",
"Future ",
"Future<",
"class ",
"int ",
"String ",
"bool ",
"static ",
"factory ",
"Widget ",
"get ",
"set ",
"enum ",
"typedef ",
"extension ",
]),
},
CodeLanguage {
@@ -495,22 +530,23 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
],
has_builtin: false,
block_style: BlockStyle::EndDelimited(&[
"def ", "defp ", "defmodule ",
"defmacro ", "defstruct", "defprotocol ", "defimpl ",
"def ",
"defp ",
"defmodule ",
"defmacro ",
"defstruct",
"defprotocol ",
"defimpl ",
]),
},
CodeLanguage {
key: "perl",
display_name: "Perl",
extensions: &[".pl", ".pm"],
repos: &[
CodeRepo {
key: "mojolicious",
urls: &[
"https://raw.githubusercontent.com/mojolicious/mojo/main/lib/Mojolicious.pm",
],
},
],
repos: &[CodeRepo {
key: "mojolicious",
urls: &["https://raw.githubusercontent.com/mojolicious/mojo/main/lib/Mojolicious.pm"],
}],
has_builtin: false,
block_style: BlockStyle::Braces(&["sub "]),
},
@@ -518,30 +554,31 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
key: "zig",
display_name: "Zig",
extensions: &[".zig"],
repos: &[
CodeRepo {
key: "zig-stdlib",
urls: &[
"https://raw.githubusercontent.com/ziglang/zig/master/lib/std/mem.zig",
"https://raw.githubusercontent.com/ziglang/zig/master/lib/std/fmt.zig",
],
},
],
repos: &[CodeRepo {
key: "zig-stdlib",
urls: &[
"https://raw.githubusercontent.com/ziglang/zig/master/lib/std/mem.zig",
"https://raw.githubusercontent.com/ziglang/zig/master/lib/std/fmt.zig",
],
}],
has_builtin: false,
block_style: BlockStyle::Braces(&["pub fn ", "fn ", "const ", "pub const ", "test ", "var "]),
block_style: BlockStyle::Braces(&[
"pub fn ",
"fn ",
"const ",
"pub const ",
"test ",
"var ",
]),
},
CodeLanguage {
key: "julia",
display_name: "Julia",
extensions: &[".jl"],
repos: &[
CodeRepo {
key: "julia-stdlib",
urls: &[
"https://raw.githubusercontent.com/JuliaLang/julia/master/base/array.jl",
],
},
],
repos: &[CodeRepo {
key: "julia-stdlib",
urls: &["https://raw.githubusercontent.com/JuliaLang/julia/master/base/array.jl"],
}],
has_builtin: false,
block_style: BlockStyle::EndDelimited(&["function ", "macro "]),
},
@@ -549,14 +586,10 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
key: "nim",
display_name: "Nim",
extensions: &[".nim"],
repos: &[
CodeRepo {
key: "nim-stdlib",
urls: &[
"https://raw.githubusercontent.com/nim-lang/Nim/devel/lib/pure/strutils.nim",
],
},
],
repos: &[CodeRepo {
key: "nim-stdlib",
urls: &["https://raw.githubusercontent.com/nim-lang/Nim/devel/lib/pure/strutils.nim"],
}],
has_builtin: false,
block_style: BlockStyle::Indentation(&["proc ", "func ", "method ", "type "]),
},
@@ -564,14 +597,10 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
key: "ocaml",
display_name: "OCaml",
extensions: &[".ml", ".mli"],
repos: &[
CodeRepo {
key: "ocaml-stdlib",
urls: &[
"https://raw.githubusercontent.com/ocaml/ocaml/trunk/stdlib/list.ml",
],
},
],
repos: &[CodeRepo {
key: "ocaml-stdlib",
urls: &["https://raw.githubusercontent.com/ocaml/ocaml/trunk/stdlib/list.ml"],
}],
has_builtin: false,
block_style: BlockStyle::Indentation(&["let ", "type ", "module "]),
},
@@ -596,21 +625,24 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
has_builtin: false,
// Haskell: top-level declarations are indented blocks
block_style: BlockStyle::Indentation(&[
"data ", "type ", "class ", "instance ", "newtype ", "module ",
"data ",
"type ",
"class ",
"instance ",
"newtype ",
"module ",
]),
},
CodeLanguage {
key: "clojure",
display_name: "Clojure",
extensions: &[".clj", ".cljs"],
repos: &[
CodeRepo {
key: "clojure-core",
urls: &[
"https://raw.githubusercontent.com/clojure/clojure/master/src/clj/clojure/core.clj",
],
},
],
repos: &[CodeRepo {
key: "clojure-core",
urls: &[
"https://raw.githubusercontent.com/clojure/clojure/master/src/clj/clojure/core.clj",
],
}],
has_builtin: false,
block_style: BlockStyle::Indentation(&["(defn ", "(defn- ", "(defmacro "]),
},
@@ -618,15 +650,13 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
key: "r",
display_name: "R",
extensions: &[".r", ".R"],
repos: &[
CodeRepo {
key: "shiny",
urls: &[
"https://raw.githubusercontent.com/rstudio/shiny/main/R/bootstrap.R",
"https://raw.githubusercontent.com/rstudio/shiny/main/R/input-text.R",
],
},
],
repos: &[CodeRepo {
key: "shiny",
urls: &[
"https://raw.githubusercontent.com/rstudio/shiny/main/R/bootstrap.R",
"https://raw.githubusercontent.com/rstudio/shiny/main/R/input-text.R",
],
}],
has_builtin: false,
// R functions are defined as `name <- function(...)`. Since our extractor only
// supports `starts_with`, we match roxygen doc blocks that precede functions.
@@ -636,36 +666,30 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
key: "erlang",
display_name: "Erlang",
extensions: &[".erl"],
repos: &[
CodeRepo {
key: "cowboy",
urls: &[
"https://raw.githubusercontent.com/ninenines/cowboy/master/src/cowboy_req.erl",
"https://raw.githubusercontent.com/ninenines/cowboy/master/src/cowboy_http.erl",
],
},
],
repos: &[CodeRepo {
key: "cowboy",
urls: &[
"https://raw.githubusercontent.com/ninenines/cowboy/master/src/cowboy_req.erl",
"https://raw.githubusercontent.com/ninenines/cowboy/master/src/cowboy_http.erl",
],
}],
has_builtin: false,
// Erlang: -spec and -record use braces for types/fields.
// Erlang functions themselves don't use braces (they end with `.`),
// so extraction is limited to type specs and records.
block_style: BlockStyle::Braces(&[
"-spec ", "-record(", "-type ", "-callback ",
]),
block_style: BlockStyle::Braces(&["-spec ", "-record(", "-type ", "-callback "]),
},
CodeLanguage {
key: "groovy",
display_name: "Groovy",
extensions: &[".groovy"],
repos: &[
CodeRepo {
key: "nextflow",
urls: &[
"https://raw.githubusercontent.com/nextflow-io/nextflow/master/modules/nextflow/src/main/groovy/nextflow/processor/TaskProcessor.groovy",
"https://raw.githubusercontent.com/nextflow-io/nextflow/master/modules/nextflow/src/main/groovy/nextflow/Session.groovy",
],
},
],
repos: &[CodeRepo {
key: "nextflow",
urls: &[
"https://raw.githubusercontent.com/nextflow-io/nextflow/master/modules/nextflow/src/main/groovy/nextflow/processor/TaskProcessor.groovy",
"https://raw.githubusercontent.com/nextflow-io/nextflow/master/modules/nextflow/src/main/groovy/nextflow/Session.groovy",
],
}],
has_builtin: false,
block_style: BlockStyle::Braces(&["def ", "void ", "static ", "public ", "private "]),
},
@@ -673,14 +697,12 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
key: "fsharp",
display_name: "F#",
extensions: &[".fs", ".fsx"],
repos: &[
CodeRepo {
key: "fsharp-compiler",
urls: &[
"https://raw.githubusercontent.com/dotnet/fsharp/main/src/Compiler/Utilities/lib.fs",
],
},
],
repos: &[CodeRepo {
key: "fsharp-compiler",
urls: &[
"https://raw.githubusercontent.com/dotnet/fsharp/main/src/Compiler/Utilities/lib.fs",
],
}],
has_builtin: false,
block_style: BlockStyle::Indentation(&["let ", "member ", "type ", "module "]),
},
@@ -688,18 +710,23 @@ pub const CODE_LANGUAGES: &[CodeLanguage] = &[
key: "objective-c",
display_name: "Objective-C",
extensions: &[".m", ".h"],
repos: &[
CodeRepo {
key: "afnetworking",
urls: &[
"https://raw.githubusercontent.com/AFNetworking/AFNetworking/master/AFNetworking/AFURLSessionManager.m",
],
},
],
repos: &[CodeRepo {
key: "afnetworking",
urls: &[
"https://raw.githubusercontent.com/AFNetworking/AFNetworking/master/AFNetworking/AFURLSessionManager.m",
],
}],
has_builtin: false,
block_style: BlockStyle::Braces(&[
"- (", "+ (", "- (void)", "- (id)", "- (BOOL)",
"@interface ", "@implementation ", "@protocol ", "typedef ",
"- (",
"+ (",
"- (void)",
"- (id)",
"- (BOOL)",
"@interface ",
"@implementation ",
"@protocol ",
"typedef ",
]),
},
];
@@ -767,8 +794,8 @@ pub fn build_code_download_queue(lang_key: &str, cache_dir: &str) -> Vec<(String
for lk in &languages_to_download {
if let Some(lang) = language_by_key(lk) {
for (repo_idx, repo) in lang.repos.iter().enumerate() {
let cache_path = std::path::Path::new(cache_dir)
.join(format!("{}_{}.txt", lang.key, repo.key));
let cache_path =
std::path::Path::new(cache_dir).join(format!("{}_{}.txt", lang.key, repo.key));
if !cache_path.exists()
|| std::fs::metadata(&cache_path)
.map(|m| m.len() == 0)
@@ -1653,7 +1680,8 @@ impl TextGenerator for CodeSyntaxGenerator {
fn generate(
&mut self,
_filter: &CharFilter,
_focused: Option<char>,
_focused_char: Option<char>,
_focused_bigram: Option<[char; 2]>,
word_count: usize,
) -> String {
let embedded = self.get_snippets();
@@ -1721,7 +1749,10 @@ fn approx_token_count(text: &str) -> usize {
}
fn fit_snippet_to_target(snippet: &str, target_units: usize) -> String {
let max_units = target_units.saturating_mul(3).saturating_div(2).max(target_units);
let max_units = target_units
.saturating_mul(3)
.saturating_div(2)
.max(target_units);
if approx_token_count(snippet) <= max_units {
return snippet.to_string();
}
@@ -1777,8 +1808,8 @@ where
all_snippets.truncate(snippets_limit);
let cache_path = std::path::Path::new(cache_dir)
.join(format!("{}_{}.txt", language_key, repo.key));
let cache_path =
std::path::Path::new(cache_dir).join(format!("{}_{}.txt", language_key, repo.key));
let combined = all_snippets.join("\n---SNIPPET---\n");
fs::write(cache_path, combined).is_ok()
}
@@ -1811,8 +1842,12 @@ fn is_noise_snippet(snippet: &str) -> bool {
.lines()
.filter(|l| {
let t = l.trim();
!t.is_empty() && !t.starts_with("//") && !t.starts_with('#') && !t.starts_with("/*")
&& !t.starts_with('*') && !t.starts_with("*/")
!t.is_empty()
&& !t.starts_with("//")
&& !t.starts_with('#')
&& !t.starts_with("/*")
&& !t.starts_with('*')
&& !t.starts_with("*/")
})
.collect();
@@ -1828,8 +1863,15 @@ fn is_noise_snippet(snippet: &str) -> bool {
// Reject if body consists entirely of import/use/require/include statements
let import_prefixes = [
"import ", "from ", "use ", "require", "#include", "using ",
"package ", "module ", "extern crate ",
"import ",
"from ",
"use ",
"require",
"#include",
"using ",
"package ",
"module ",
"extern crate ",
];
let body_lines: Vec<&str> = meaningful_lines.iter().skip(1).copied().collect();
if !body_lines.is_empty()
@@ -2087,7 +2129,10 @@ fn structural_extract_indent(lines: &[&str]) -> Vec<String> {
}
}
while snippet_lines.last().map_or(false, |sl| sl.trim().is_empty()) {
while snippet_lines
.last()
.map_or(false, |sl| sl.trim().is_empty())
{
snippet_lines.pop();
}
@@ -2483,18 +2528,14 @@ z = 99
println!(" ({lines} lines, {bytes} bytes)");
total_ok += 1;
let snippets =
extract_code_snippets(&content, &lang.block_style);
let snippets = extract_code_snippets(&content, &lang.block_style);
println!(" Extracted {} snippets", snippets.len());
lang_total_snippets += snippets.len();
// Show first 2 snippets (truncated)
for (si, snippet) in snippets.iter().take(2).enumerate() {
let preview: String = snippet
.lines()
.take(5)
.collect::<Vec<_>>()
.join("\n");
let preview: String =
snippet.lines().take(5).collect::<Vec<_>>().join("\n");
let suffix = if snippet.lines().count() > 5 {
"\n ..."
} else {
@@ -2507,7 +2548,9 @@ z = 99
.join("\n");
println!(
" --- snippet {} ---\n{}{}",
si + 1, indented, suffix,
si + 1,
indented,
suffix,
);
}
}

View File

@@ -39,3 +39,26 @@ impl Dictionary {
matching
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn find_matching_focused_is_sort_only() {
let dictionary = Dictionary::load();
let filter = CharFilter::new(('a'..='z').collect());
let without_focus = dictionary.find_matching(&filter, None);
let with_focus = dictionary.find_matching(&filter, Some('k'));
// Same membership — focused param only reorders, never filters
let mut sorted_without: Vec<&str> = without_focus.clone();
let mut sorted_with: Vec<&str> = with_focus.clone();
sorted_without.sort();
sorted_with.sort();
assert_eq!(sorted_without, sorted_with);
assert_eq!(without_focus.len(), with_focus.len());
}
}

View File

@@ -12,6 +12,11 @@ pub mod transition_table;
use crate::engine::filter::CharFilter;
pub trait TextGenerator {
fn generate(&mut self, filter: &CharFilter, focused: Option<char>, word_count: usize)
-> String;
fn generate(
&mut self,
filter: &CharFilter,
focused_char: Option<char>,
focused_bigram: Option<[char; 2]>,
word_count: usize,
) -> String;
}

View File

@@ -176,7 +176,8 @@ impl TextGenerator for PassageGenerator {
fn generate(
&mut self,
_filter: &CharFilter,
_focused: Option<char>,
_focused_char: Option<char>,
_focused_bigram: Option<[char; 2]>,
word_count: usize,
) -> String {
let use_builtin = self.selection == "all" || self.selection == "builtin";

View File

@@ -56,9 +56,14 @@ impl PhoneticGenerator {
Some(filtered.last().unwrap().0)
}
fn generate_phonetic_word(&mut self, filter: &CharFilter, focused: Option<char>) -> String {
fn generate_phonetic_word(
&mut self,
filter: &CharFilter,
focused_char: Option<char>,
focused_bigram: Option<[char; 2]>,
) -> String {
for _attempt in 0..5 {
let word = self.try_generate_word(filter, focused);
let word = self.try_generate_word(filter, focused_char, focused_bigram);
if word.len() >= MIN_WORD_LEN {
return word;
}
@@ -67,14 +72,46 @@ impl PhoneticGenerator {
"the".to_string()
}
fn try_generate_word(&mut self, filter: &CharFilter, focused: Option<char>) -> String {
fn try_generate_word(
&mut self,
filter: &CharFilter,
focused: Option<char>,
focused_bigram: Option<[char; 2]>,
) -> String {
let mut word = Vec::new();
// Start with space prefix
let start_char = if let Some(focus) = focused {
// Try bigram-start: 30% chance to start word with bigram[0],bigram[1]
let bigram_eligible =
focused_bigram.filter(|b| filter.is_allowed(b[0]) && filter.is_allowed(b[1]));
let start_char = if let Some(bg) = bigram_eligible {
if self.rng.gen_bool(0.3) {
word.push(bg[0]);
word.push(bg[1]);
// Continue Markov chain from the bigram
let prefix = vec![' ', bg[0], bg[1]];
if let Some(probs) = self.table.segment(&prefix) {
Self::pick_weighted_from(&mut self.rng, probs, filter)
} else {
None
}
} else if let Some(focus) = focused {
if self.rng.gen_bool(0.4) && filter.is_allowed(focus) {
word.push(focus);
let prefix = vec![' ', ' ', focus];
if let Some(probs) = self.table.segment(&prefix) {
Self::pick_weighted_from(&mut self.rng, probs, filter)
} else {
None
}
} else {
None
}
} else {
None
}
} else if let Some(focus) = focused {
if self.rng.gen_bool(0.4) && filter.is_allowed(focus) {
word.push(focus);
// Get next char from transition table
let prefix = vec![' ', ' ', focus];
if let Some(probs) = self.table.segment(&prefix) {
Self::pick_weighted_from(&mut self.rng, probs, filter)
@@ -189,65 +226,151 @@ impl PhoneticGenerator {
word.iter().collect()
}
fn pick_tiered_word(
&mut self,
all_words: &[String],
bigram_indices: &[usize],
char_indices: &[usize],
other_indices: &[usize],
recent: &[String],
) -> String {
for _ in 0..6 {
let tier = self.select_tier(bigram_indices, char_indices, other_indices);
let idx = tier[self.rng.gen_range(0..tier.len())];
let word = &all_words[idx];
if !recent.contains(word) {
return word.clone();
}
}
// Fallback: accept any word from full pool
let idx = self.rng.gen_range(0..all_words.len());
all_words[idx].clone()
}
fn select_tier<'a>(
&mut self,
bigram_indices: &'a [usize],
char_indices: &'a [usize],
other_indices: &'a [usize],
) -> &'a [usize] {
let has_bigram = bigram_indices.len() >= 2;
let has_char = char_indices.len() >= 2;
// Tier selection probabilities:
// Both available: 40% bigram, 30% char, 30% other
// Only bigram: 50% bigram, 50% other
// Only char: 70% char, 30% other
// Neither: 100% other
let roll: f64 = self.rng.gen_range(0.0..1.0);
match (has_bigram, has_char) {
(true, true) => {
if roll < 0.4 {
bigram_indices
} else if roll < 0.7 {
char_indices
} else {
if other_indices.len() >= 2 {
other_indices
} else if has_char {
char_indices
} else {
bigram_indices
}
}
}
(true, false) => {
if roll < 0.5 {
bigram_indices
} else {
if other_indices.len() >= 2 {
other_indices
} else {
bigram_indices
}
}
}
(false, true) => {
if roll < 0.7 {
char_indices
} else {
if other_indices.len() >= 2 {
other_indices
} else {
char_indices
}
}
}
(false, false) => {
// Use other_indices if available, otherwise all words
if other_indices.len() >= 2 {
other_indices
} else {
char_indices
}
}
}
}
}
impl TextGenerator for PhoneticGenerator {
fn generate(
&mut self,
filter: &CharFilter,
focused: Option<char>,
focused_char: Option<char>,
focused_bigram: Option<[char; 2]>,
word_count: usize,
) -> String {
// keybr's approach: prefer real words when enough match the filter
// Collect matching words into owned Vec to avoid borrow conflict
let matching_words: Vec<String> = self
.dictionary
.find_matching(filter, focused)
.find_matching(filter, None)
.iter()
.map(|s| s.to_string())
.collect();
let use_real_words = matching_words.len() >= MIN_REAL_WORDS;
// Pre-categorize words into tiers for real-word mode
let bigram_str = focused_bigram.map(|b| format!("{}{}", b[0], b[1]));
let focus_char_lower = focused_char.filter(|ch| ch.is_ascii_lowercase());
let (bigram_indices, char_indices, other_indices) = if use_real_words {
let mut bi = Vec::new();
let mut ci = Vec::new();
let mut oi = Vec::new();
for (i, w) in matching_words.iter().enumerate() {
if bigram_str.as_ref().is_some_and(|b| w.contains(b.as_str())) {
bi.push(i);
} else if focus_char_lower.is_some_and(|ch| w.contains(ch)) {
ci.push(i);
} else {
oi.push(i);
}
}
(bi, ci, oi)
} else {
(vec![], vec![], vec![])
};
let mut words: Vec<String> = Vec::new();
let mut last_word = String::new();
let mut recent: Vec<String> = Vec::new();
for _ in 0..word_count {
if use_real_words {
// Pick a real word (avoid consecutive duplicates).
// If focused is set, bias sampling toward words containing that key.
let focus = focused.filter(|ch| ch.is_ascii_lowercase());
let focused_indices: Vec<usize> = if let Some(ch) = focus {
matching_words
.iter()
.enumerate()
.filter_map(|(i, w)| w.contains(ch).then_some(i))
.collect()
} else {
Vec::new()
};
let mut picked = None;
for _ in 0..6 {
let idx = if !focused_indices.is_empty() && self.rng.gen_bool(0.70) {
let j = self.rng.gen_range(0..focused_indices.len());
focused_indices[j]
} else {
self.rng.gen_range(0..matching_words.len())
};
let word = matching_words[idx].clone();
if word != last_word {
picked = Some(word);
break;
}
let word = self.pick_tiered_word(
&matching_words,
&bigram_indices,
&char_indices,
&other_indices,
&recent,
);
recent.push(word.clone());
if recent.len() > 4 {
recent.remove(0);
}
let word = match picked {
Some(w) => w,
None => self.generate_phonetic_word(filter, focused),
};
last_word.clone_from(&word);
words.push(word);
} else {
// Fall back to phonetic pseudo-words
let word = self.generate_phonetic_word(filter, focused);
let word = self.generate_phonetic_word(filter, focused_char, focused_bigram);
words.push(word);
}
}
@@ -272,7 +395,7 @@ mod tests {
Dictionary::load(),
SmallRng::seed_from_u64(42),
);
let focused_text = focused_gen.generate(&filter, Some('k'), 1200);
let focused_text = focused_gen.generate(&filter, Some('k'), None, 1200);
let focused_count = focused_text
.split_whitespace()
.filter(|w| w.contains('k'))
@@ -280,7 +403,7 @@ mod tests {
let mut baseline_gen =
PhoneticGenerator::new(table, Dictionary::load(), SmallRng::seed_from_u64(42));
let baseline_text = baseline_gen.generate(&filter, None, 1200);
let baseline_text = baseline_gen.generate(&filter, None, None, 1200);
let baseline_count = baseline_text
.split_whitespace()
.filter(|w| w.contains('k'))
@@ -291,4 +414,64 @@ mod tests {
"focused_count={focused_count}, baseline_count={baseline_count}"
);
}
#[test]
fn test_phonetic_bigram_focus_increases_bigram_words() {
let dictionary = Dictionary::load();
let table = TransitionTable::build_from_words(&dictionary.words_list());
let filter = CharFilter::new(('a'..='z').collect());
let mut bigram_gen = PhoneticGenerator::new(
table.clone(),
Dictionary::load(),
SmallRng::seed_from_u64(42),
);
let bigram_text = bigram_gen.generate(&filter, None, Some(['t', 'h']), 1200);
let bigram_count = bigram_text
.split_whitespace()
.filter(|w| w.contains("th"))
.count();
let mut baseline_gen =
PhoneticGenerator::new(table, Dictionary::load(), SmallRng::seed_from_u64(42));
let baseline_text = baseline_gen.generate(&filter, None, None, 1200);
let baseline_count = baseline_text
.split_whitespace()
.filter(|w| w.contains("th"))
.count();
assert!(
bigram_count > baseline_count,
"bigram_count={bigram_count}, baseline_count={baseline_count}"
);
}
#[test]
fn test_phonetic_dual_focus_no_excessive_repeats() {
let dictionary = Dictionary::load();
let table = TransitionTable::build_from_words(&dictionary.words_list());
let filter = CharFilter::new(('a'..='z').collect());
let mut generator =
PhoneticGenerator::new(table, Dictionary::load(), SmallRng::seed_from_u64(42));
let text = generator.generate(&filter, Some('k'), Some(['t', 'h']), 200);
let words: Vec<&str> = text.split_whitespace().collect();
// Check no word appears > 3 times consecutively
let mut max_consecutive = 1;
let mut current_run = 1;
for i in 1..words.len() {
if words[i] == words[i - 1] {
current_run += 1;
max_consecutive = max_consecutive.max(current_run);
} else {
current_run = 1;
}
}
assert!(
max_consecutive <= 3,
"Max consecutive repeats = {max_consecutive}, expected <= 3"
);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -8,7 +8,7 @@ use serde::{Serialize, de::DeserializeOwned};
use crate::config::Config;
use crate::store::schema::{
DrillHistoryData, ExportData, KeyStatsData, ProfileData, EXPORT_VERSION,
DrillHistoryData, EXPORT_VERSION, ExportData, KeyStatsData, ProfileData,
};
pub struct JsonStore {
@@ -136,9 +136,18 @@ impl JsonStore {
let files: Vec<(&str, String)> = vec![
("profile.json", serde_json::to_string_pretty(&data.profile)?),
("key_stats.json", serde_json::to_string_pretty(&data.key_stats)?),
("key_stats_ranked.json", serde_json::to_string_pretty(&data.ranked_key_stats)?),
("lesson_history.json", serde_json::to_string_pretty(&data.drill_history)?),
(
"key_stats.json",
serde_json::to_string_pretty(&data.key_stats)?,
),
(
"key_stats_ranked.json",
serde_json::to_string_pretty(&data.ranked_key_stats)?,
),
(
"lesson_history.json",
serde_json::to_string_pretty(&data.drill_history)?,
),
];
// Stage phase: write .tmp files
@@ -172,9 +181,7 @@ impl JsonStore {
let had_original = final_path.exists();
// Back up existing file if it exists
if had_original
&& let Err(e) = fs::rename(&final_path, &bak_path)
{
if had_original && let Err(e) = fs::rename(&final_path, &bak_path) {
// Rollback: restore already committed files
for (committed_final, committed_bak, committed_had) in &committed {
if *committed_had {
@@ -335,12 +342,19 @@ mod tests {
// Now create a store that points to a nonexistent subdir of the same tmpdir
// so that staging .tmp writes will fail
let bad_dir = _dir.path().join("nonexistent_subdir");
let bad_store = JsonStore { base_dir: bad_dir.clone() };
let bad_store = JsonStore {
base_dir: bad_dir.clone(),
};
let config = Config::default();
let export = make_test_export(&config);
let result = bad_store.import_all(&export);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Import failed during staging"));
assert!(
result
.unwrap_err()
.to_string()
.contains("Import failed during staging")
);
// Original file in the real store is unchanged
let after_content = fs::read_to_string(store.file_path("profile.json")).unwrap();
@@ -390,5 +404,4 @@ mod tests {
// Should have been cleaned up
assert!(!store.file_path("profile.json.bak").exists());
}
}

View File

@@ -10,11 +10,20 @@ use crate::ui::theme::Theme;
pub struct Dashboard<'a> {
pub result: &'a DrillResult,
pub theme: &'a Theme,
pub input_lock_remaining_ms: Option<u64>,
}
impl<'a> Dashboard<'a> {
pub fn new(result: &'a DrillResult, theme: &'a Theme) -> Self {
Self { result, theme }
pub fn new(
result: &'a DrillResult,
theme: &'a Theme,
input_lock_remaining_ms: Option<u64>,
) -> Self {
Self {
result,
theme,
input_lock_remaining_ms,
}
}
}
@@ -114,16 +123,31 @@ impl Widget for Dashboard<'_> {
]);
Paragraph::new(chars_line).render(layout[4], buf);
let help = Paragraph::new(Line::from(vec![
Span::styled(
" [c/Enter/Space] Continue ",
Style::default().fg(colors.accent()),
),
Span::styled("[r] Retry ", Style::default().fg(colors.accent())),
Span::styled("[q] Menu ", Style::default().fg(colors.accent())),
Span::styled("[s] Stats ", Style::default().fg(colors.accent())),
Span::styled("[x] Delete", Style::default().fg(colors.accent())),
]));
let help = if let Some(ms) = self.input_lock_remaining_ms {
Paragraph::new(Line::from(vec![
Span::styled(
" Input temporarily blocked ",
Style::default().fg(colors.warning()),
),
Span::styled(
format!("({ms}ms remaining)"),
Style::default()
.fg(colors.warning())
.add_modifier(Modifier::BOLD),
),
]))
} else {
Paragraph::new(Line::from(vec![
Span::styled(
" [c/Enter/Space] Continue ",
Style::default().fg(colors.accent()),
),
Span::styled("[r] Retry ", Style::default().fg(colors.accent())),
Span::styled("[q] Menu ", Style::default().fg(colors.accent())),
Span::styled("[s] Stats ", Style::default().fg(colors.accent())),
Span::styled("[x] Delete", Style::default().fg(colors.accent())),
]))
};
help.render(layout[6], buf);
}
}

View File

@@ -267,6 +267,22 @@ impl KeyboardDiagram<'_> {
}
let offsets: &[u16] = &[3, 4, 6];
let keyboard_width = letter_rows
.iter()
.enumerate()
.map(|(row_idx, row)| {
let offset = offsets.get(row_idx).copied().unwrap_or(0);
let row_end = offset + row.len() as u16 * key_width;
match row_idx {
0 => row_end + 3, // [B]
1 => row_end + 3, // [E]
2 => row_end + 3, // [S]
_ => row_end,
}
})
.max()
.unwrap_or(0);
let start_x = inner.x + inner.width.saturating_sub(keyboard_width) / 2;
for (row_idx, row) in letter_rows.iter().enumerate() {
let y = inner.y + row_idx as u16;
@@ -283,18 +299,18 @@ impl KeyboardDiagram<'_> {
let is_next = self.next_key == Some(TAB);
let is_sel = self.is_sentinel_selected(TAB);
let style = modifier_key_style(is_dep, is_next, is_sel, colors);
buf.set_string(inner.x, y, "[T]", style);
buf.set_string(start_x, y, "[T]", style);
}
2 => {
let is_dep = self.shift_held;
let style = modifier_key_style(is_dep, false, false, colors);
buf.set_string(inner.x, y, "[S]", style);
buf.set_string(start_x, y, "[S]", style);
}
_ => {}
}
for (col_idx, physical_key) in row.iter().enumerate() {
let x = inner.x + offset + col_idx as u16 * key_width;
let x = start_x + offset + col_idx as u16 * key_width;
if x + key_width > inner.x + inner.width {
break;
}
@@ -326,7 +342,7 @@ impl KeyboardDiagram<'_> {
}
// Render trailing modifier key
let row_end_x = inner.x + offset + row.len() as u16 * key_width;
let row_end_x = start_x + offset + row.len() as u16 * key_width;
match row_idx {
1 => {
if row_end_x + 3 <= inner.x + inner.width {
@@ -351,7 +367,7 @@ impl KeyboardDiagram<'_> {
// Backspace at end of first row
if inner.height >= 3 {
let y = inner.y;
let row_end_x = inner.x + offsets[0] + letter_rows[0].len() as u16 * key_width;
let row_end_x = start_x + offsets[0] + letter_rows[0].len() as u16 * key_width;
if row_end_x + 3 <= inner.x + inner.width {
let is_dep = self.depressed_keys.contains(&BACKSPACE);
let is_next = self.next_key == Some(BACKSPACE);
@@ -373,6 +389,24 @@ impl KeyboardDiagram<'_> {
}
let offsets: &[u16] = &[0, 5, 5, 6];
let keyboard_width = self
.model
.rows
.iter()
.enumerate()
.map(|(row_idx, row)| {
let offset = offsets.get(row_idx).copied().unwrap_or(0);
let row_end = offset + row.len() as u16 * key_width;
match row_idx {
0 => row_end + 6, // [Bksp]
2 => row_end + 7, // [Enter]
3 => row_end + 6, // [Shft]
_ => row_end,
}
})
.max()
.unwrap_or(0);
let start_x = inner.x + inner.width.saturating_sub(keyboard_width) / 2;
for (row_idx, row) in self.model.rows.iter().enumerate() {
let y = inner.y + row_idx as u16;
@@ -391,7 +425,7 @@ impl KeyboardDiagram<'_> {
let is_sel = self.is_sentinel_selected(TAB);
let style = modifier_key_style(is_dep, is_next, is_sel, colors);
let label = format!("[{}]", display::key_short_label(TAB));
buf.set_string(inner.x, y, &label, style);
buf.set_string(start_x, y, &label, style);
}
}
2 => {
@@ -401,10 +435,10 @@ impl KeyboardDiagram<'_> {
let style = Style::default()
.fg(readable_fg(bg, colors.warning()))
.bg(bg);
buf.set_string(inner.x, y, "[Cap]", style);
buf.set_string(start_x, y, "[Cap]", style);
} else {
let style = Style::default().fg(colors.text_pending()).bg(colors.bg());
buf.set_string(inner.x, y, "[ ]", style);
buf.set_string(start_x, y, "[ ]", style);
}
}
}
@@ -412,14 +446,14 @@ impl KeyboardDiagram<'_> {
if offset >= 6 {
let is_dep = self.shift_held;
let style = modifier_key_style(is_dep, false, false, colors);
buf.set_string(inner.x, y, "[Shft]", style);
buf.set_string(start_x, y, "[Shft]", style);
}
}
_ => {}
}
for (col_idx, physical_key) in row.iter().enumerate() {
let x = inner.x + offset + col_idx as u16 * key_width;
let x = start_x + offset + col_idx as u16 * key_width;
if x + key_width > inner.x + inner.width {
break;
}
@@ -451,7 +485,7 @@ impl KeyboardDiagram<'_> {
}
// Render trailing modifier keys
let after_x = inner.x + offset + row.len() as u16 * key_width;
let after_x = start_x + offset + row.len() as u16 * key_width;
match row_idx {
0 => {
if after_x + 6 <= inner.x + inner.width {
@@ -484,34 +518,13 @@ impl KeyboardDiagram<'_> {
}
}
// Compute full keyboard width from rendered rows (including trailing modifier keys),
// so the space bar centers relative to the keyboard, not the container.
let keyboard_width = self
.model
.rows
.iter()
.enumerate()
.map(|(row_idx, row)| {
let offset = offsets.get(row_idx).copied().unwrap_or(0);
let row_end = offset + row.len() as u16 * key_width;
match row_idx {
0 => row_end + 6, // [Bksp]
2 => row_end + 7, // [Enter]
3 => row_end + 6, // [Shft]
_ => row_end,
}
})
.max()
.unwrap_or(0)
.min(inner.width);
// Space bar row (row 4)
let space_y = inner.y + 4;
if space_y < inner.y + inner.height {
let space_name = display::key_display_name(SPACE);
let space_label = format!("[ {space_name} ]");
let space_width = space_label.len() as u16;
let space_x = inner.x + (keyboard_width.saturating_sub(space_width)) / 2;
let space_x = start_x + (keyboard_width.saturating_sub(space_width)) / 2;
if space_x + space_width <= inner.x + inner.width {
let is_dep = self.depressed_keys.contains(&SPACE);
let is_next = self.next_key == Some(SPACE);
@@ -527,6 +540,16 @@ impl KeyboardDiagram<'_> {
let letter_rows = self.model.letter_rows();
let key_width: u16 = 5;
let offsets: &[u16] = &[1, 3, 5];
let keyboard_width = letter_rows
.iter()
.enumerate()
.map(|(row_idx, row)| {
let offset = offsets.get(row_idx).copied().unwrap_or(0);
offset + row.len() as u16 * key_width
})
.max()
.unwrap_or(0);
let start_x = inner.x + inner.width.saturating_sub(keyboard_width) / 2;
if inner.height < 3 || inner.width < 30 {
return;
@@ -541,7 +564,7 @@ impl KeyboardDiagram<'_> {
let offset = offsets.get(row_idx).copied().unwrap_or(0);
for (col_idx, physical_key) in row.iter().enumerate() {
let x = inner.x + offset + col_idx as u16 * key_width;
let x = start_x + offset + col_idx as u16 * key_width;
if x + key_width > inner.x + inner.width {
break;
}

View File

@@ -118,8 +118,8 @@ impl Widget for SkillTreeWidget<'_> {
let notice_lines = footer_notice
.map(|text| wrapped_line_count(text, inner.width as usize))
.unwrap_or(0);
let show_notice =
footer_notice.is_some() && (inner.height as usize >= hint_lines.len() + notice_lines + 8);
let show_notice = footer_notice.is_some()
&& (inner.height as usize >= hint_lines.len() + notice_lines + 8);
let footer_needed = hint_lines.len() + if show_notice { notice_lines } else { 0 } + 1;
let footer_height = footer_needed
.min(inner.height.saturating_sub(5) as usize)
@@ -161,7 +161,10 @@ impl Widget for SkillTreeWidget<'_> {
}
}
footer_lines.extend(hint_lines.into_iter().map(|line| {
Line::from(Span::styled(line, Style::default().fg(colors.text_pending())))
Line::from(Span::styled(
line,
Style::default().fg(colors.text_pending()),
))
}));
let footer = Paragraph::new(footer_lines).wrap(Wrap { trim: false });
footer.render(layout[3], buf);

View File

@@ -6,12 +6,39 @@ use ratatui::widgets::{Block, Clear, Paragraph, Widget};
use std::collections::{BTreeSet, HashMap};
use crate::engine::key_stats::KeyStatsStore;
use crate::engine::ngram_stats::{AnomalyType, FocusSelection};
use crate::keyboard::display::{self, BACKSPACE, ENTER, MODIFIER_SENTINELS, SPACE, TAB};
use crate::keyboard::model::KeyboardModel;
use crate::session::result::DrillResult;
use crate::ui::components::activity_heatmap::ActivityHeatmap;
use crate::ui::theme::Theme;
// ---------------------------------------------------------------------------
// N-grams tab view models
// ---------------------------------------------------------------------------
/// One row of an anomaly table on the N-grams stats tab: a bigram
/// flagged as an error-rate or speed outlier, plus the statistics
/// backing that judgement.
pub struct AnomalyBigramRow {
    /// Display string for the two-character bigram.
    pub bigram: String,
    /// Deviation from the expected baseline, in percent.
    pub anomaly_pct: f64,
    /// Number of samples observed for this bigram.
    pub sample_count: usize,
    /// Number of errors recorded for this bigram.
    pub error_count: usize,
    /// Error rate as a fraction (0.0-1.0); field name indicates an
    /// exponential moving average. Rendered as a percentage.
    pub error_rate_ema: f64,
    /// Measured inter-key time for this bigram, in milliseconds.
    pub speed_ms: f64,
    /// Expected baseline: a fraction for error anomalies (shown as %),
    /// milliseconds for speed anomalies.
    pub expected_baseline: f64,
    /// Confirmed anomalies render in the error color; unconfirmed ones
    /// in the warning color.
    pub confirmed: bool,
}
/// Precomputed view model for the N-grams stats tab. Built by the
/// caller; `StatsDashboard` only renders it.
pub struct NgramTabData {
    /// Current focus selection (char and/or bigram) shown in the
    /// " Active Focus " box.
    pub focus: FocusSelection,
    /// Bigrams flagged as error-rate outliers.
    pub error_anomalies: Vec<AnomalyBigramRow>,
    /// Bigrams flagged as speed outliers.
    pub speed_anomalies: Vec<AnomalyBigramRow>,
    /// Total number of tracked bigrams (summary line).
    pub total_bigrams: usize,
    /// Total number of tracked trigrams (summary line).
    pub total_trigrams: usize,
    /// Inter-key delay above which a keystroke counts as a hesitation, in ms.
    pub hesitation_threshold_ms: f64,
    /// Latest trigram gain as a fraction; `None` until available (the
    /// summary line notes it is computed every 50 drills).
    pub latest_trigram_gain: Option<f64>,
    /// Human-readable label describing the stats scope.
    pub scope_label: String,
}
pub struct StatsDashboard<'a> {
pub history: &'a [DrillResult],
pub key_stats: &'a KeyStatsStore,
@@ -24,6 +51,7 @@ pub struct StatsDashboard<'a> {
pub history_selected: usize,
pub history_confirm_delete: bool,
pub keyboard_model: &'a KeyboardModel,
pub ngram_data: Option<&'a NgramTabData>,
}
impl<'a> StatsDashboard<'a> {
@@ -39,6 +67,7 @@ impl<'a> StatsDashboard<'a> {
history_selected: usize,
history_confirm_delete: bool,
keyboard_model: &'a KeyboardModel,
ngram_data: Option<&'a NgramTabData>,
) -> Self {
Self {
history,
@@ -52,6 +81,7 @@ impl<'a> StatsDashboard<'a> {
history_selected,
history_confirm_delete,
keyboard_model,
ngram_data,
}
}
}
@@ -92,6 +122,7 @@ impl Widget for StatsDashboard<'_> {
"[3] Activity",
"[4] Accuracy",
"[5] Timing",
"[6] N-grams",
];
let tab_spans: Vec<Span> = tabs
.iter()
@@ -114,9 +145,9 @@ impl Widget for StatsDashboard<'_> {
// Footer
let footer_text = if self.active_tab == 1 {
" [ESC] Back [Tab] Next tab [1-5] Switch tab [j/k] Navigate [x] Delete"
" [ESC] Back [Tab] Next tab [1-6] Switch tab [j/k] Navigate [x] Delete"
} else {
" [ESC] Back [Tab] Next tab [1-5] Switch tab"
" [ESC] Back [Tab] Next tab [1-6] Switch tab"
};
let footer = Paragraph::new(Line::from(Span::styled(
footer_text,
@@ -163,6 +194,7 @@ impl StatsDashboard<'_> {
2 => self.render_activity_tab(area, buf),
3 => self.render_accuracy_tab(area, buf),
4 => self.render_timing_tab(area, buf),
5 => self.render_ngram_tab(area, buf),
_ => {}
}
}
@@ -692,6 +724,17 @@ impl StatsDashboard<'_> {
let show_shifted = inner.height >= 10; // 4 base + 4 shifted + 1 mod row + 1 spare
let all_rows = &self.keyboard_model.rows;
let offsets: &[u16] = &[0, 2, 3, 4];
let kbd_width = all_rows
.iter()
.enumerate()
.map(|(i, row)| {
let off = offsets.get(i).copied().unwrap_or(0);
off + row.len() as u16 * key_step
})
.max()
.unwrap_or(inner.width)
.min(inner.width);
let keyboard_x = inner.x + inner.width.saturating_sub(kbd_width) / 2;
for (row_idx, row) in all_rows.iter().enumerate() {
let base_y = if show_shifted {
@@ -711,7 +754,7 @@ impl StatsDashboard<'_> {
let shifted_y = base_y - 1;
if shifted_y >= inner.y {
for (col_idx, physical_key) in row.iter().enumerate() {
let x = inner.x + offset + col_idx as u16 * key_step;
let x = keyboard_x + offset + col_idx as u16 * key_step;
if x + key_width > inner.x + inner.width {
break;
}
@@ -733,7 +776,7 @@ impl StatsDashboard<'_> {
// Base row
for (col_idx, physical_key) in row.iter().enumerate() {
let x = inner.x + offset + col_idx as u16 * key_step;
let x = keyboard_x + offset + col_idx as u16 * key_step;
if x + key_width > inner.x + inner.width {
break;
}
@@ -745,20 +788,8 @@ impl StatsDashboard<'_> {
let display = format_accuracy_cell(key, accuracy, key_width);
buf.set_string(x, base_y, &display, Style::default().fg(fg_color));
}
}
// Modifier key stats row below the keyboard, spread across keyboard width
let kbd_width = all_rows
.iter()
.enumerate()
.map(|(i, row)| {
let off = offsets.get(i).copied().unwrap_or(0);
off + row.len() as u16 * key_step
})
.max()
.unwrap_or(inner.width)
.min(inner.width);
let mod_y = if show_shifted {
inner.y + all_rows.len() as u16 * 2 + 1
} else {
@@ -783,7 +814,7 @@ impl StatsDashboard<'_> {
let accuracy = self.get_key_accuracy(key);
let fg_color = accuracy_color(accuracy, colors);
buf.set_string(
inner.x + positions[i],
keyboard_x + positions[i],
mod_y,
&labels[i],
Style::default().fg(fg_color),
@@ -848,6 +879,17 @@ impl StatsDashboard<'_> {
let show_shifted = inner.height >= 10; // 4 base + 4 shifted + 1 mod row + 1 spare
let all_rows = &self.keyboard_model.rows;
let offsets: &[u16] = &[0, 2, 3, 4];
let kbd_width = all_rows
.iter()
.enumerate()
.map(|(i, row)| {
let off = offsets.get(i).copied().unwrap_or(0);
off + row.len() as u16 * key_step
})
.max()
.unwrap_or(inner.width)
.min(inner.width);
let keyboard_x = inner.x + inner.width.saturating_sub(kbd_width) / 2;
for (row_idx, row) in all_rows.iter().enumerate() {
let base_y = if show_shifted {
@@ -866,7 +908,7 @@ impl StatsDashboard<'_> {
let shifted_y = base_y - 1;
if shifted_y >= inner.y {
for (col_idx, physical_key) in row.iter().enumerate() {
let x = inner.x + offset + col_idx as u16 * key_step;
let x = keyboard_x + offset + col_idx as u16 * key_step;
if x + key_width > inner.x + inner.width {
break;
}
@@ -886,7 +928,7 @@ impl StatsDashboard<'_> {
}
for (col_idx, physical_key) in row.iter().enumerate() {
let x = inner.x + offset + col_idx as u16 * key_step;
let x = keyboard_x + offset + col_idx as u16 * key_step;
if x + key_width > inner.x + inner.width {
break;
}
@@ -897,20 +939,8 @@ impl StatsDashboard<'_> {
let display = format_timing_cell(key, time_ms, key_width);
buf.set_string(x, base_y, &display, Style::default().fg(fg_color));
}
}
// Modifier key stats row below the keyboard, spread across keyboard width
let kbd_width = all_rows
.iter()
.enumerate()
.map(|(i, row)| {
let off = offsets.get(i).copied().unwrap_or(0);
off + row.len() as u16 * key_step
})
.max()
.unwrap_or(inner.width)
.min(inner.width);
let mod_y = if show_shifted {
inner.y + all_rows.len() as u16 * 2 + 1
} else {
@@ -935,7 +965,7 @@ impl StatsDashboard<'_> {
let time_ms = self.get_key_time_ms(key);
let fg_color = timing_color(time_ms, colors);
buf.set_string(
inner.x + positions[i],
keyboard_x + positions[i],
mod_y,
&labels[i],
Style::default().fg(fg_color),
@@ -1261,6 +1291,334 @@ impl StatsDashboard<'_> {
Paragraph::new(lines).render(inner, buf);
}
// --- N-grams tab ---
/// Render the N-grams tab: a focus box on top, one or two anomaly
/// panels in the middle, and a one-line summary at the bottom.
/// Falls back to a hint message when no n-gram data is available.
fn render_ngram_tab(&self, area: Rect, buf: &mut Buffer) {
    let colors = &self.theme.colors;
    let Some(data) = self.ngram_data else {
        // No adaptive-drill data yet: show a hint instead of the tab.
        Paragraph::new(Line::from(Span::styled(
            "Complete some adaptive drills to see n-gram data",
            Style::default().fg(colors.text_pending()),
        )))
        .render(area, buf);
        return;
    };

    let sections = Layout::default()
        .direction(Direction::Vertical)
        .constraints([
            Constraint::Length(4), // focus box
            Constraint::Min(5),    // lists
            Constraint::Length(2), // summary
        ])
        .split(area);

    self.render_ngram_focus(data, sections[0], buf);

    let lists_area = sections[1];
    if lists_area.width >= 60 {
        // Wide terminal: error and speed panels side by side.
        let halves = Layout::default()
            .direction(Direction::Horizontal)
            .constraints([Constraint::Percentage(50), Constraint::Percentage(50)])
            .split(lists_area);
        self.render_error_anomalies(data, halves[0], buf);
        self.render_speed_anomalies(data, halves[1], buf);
    } else if lists_area.height < 10 {
        // Narrow and short: only the error panel fits.
        self.render_error_anomalies(data, lists_area, buf);
    } else {
        // Narrow but tall: stack the two panels vertically.
        let top_height = lists_area.height / 2;
        let stacked = Layout::default()
            .direction(Direction::Vertical)
            .constraints([Constraint::Length(top_height), Constraint::Min(0)])
            .split(lists_area);
        self.render_error_anomalies(data, stacked[0], buf);
        self.render_speed_anomalies(data, stacked[1], buf);
    }

    self.render_ngram_summary(data, sections[2], buf);
}
/// Render the bordered " Active Focus " box. Shows one of four cases
/// depending on the focus selection: char + bigram, char only, bigram
/// only, or a hint when there is no focus data yet.
fn render_ngram_focus(&self, data: &NgramTabData, area: Rect, buf: &mut Buffer) {
    let colors = &self.theme.colors;
    let block = Block::bordered()
        .title(Line::from(Span::styled(
            " Active Focus ",
            Style::default()
                .fg(colors.accent())
                .add_modifier(Modifier::BOLD),
        )))
        .border_style(Style::default().fg(colors.accent()));
    let inner = block.inner(area);
    block.render(area, buf);
    // No room inside the border: nothing more to draw.
    if inner.height < 1 {
        return;
    }
    let mut lines = Vec::new();
    match (&data.focus.char_focus, &data.focus.bigram_focus) {
        (Some(ch), Some((key, anomaly_pct, anomaly_type))) => {
            // key.0 holds the bigram's two characters.
            let bigram_label = format!("\"{}{}\"", key.0[0], key.0[1]);
            // Line 1: both focuses
            lines.push(Line::from(vec![
                Span::styled(" Focus: ", Style::default().fg(colors.fg())),
                Span::styled(
                    format!("Char '{ch}'"),
                    Style::default()
                        .fg(colors.focused_key())
                        .add_modifier(Modifier::BOLD),
                ),
                Span::styled(" + ", Style::default().fg(colors.fg())),
                Span::styled(
                    format!("Bigram {bigram_label}"),
                    Style::default()
                        .fg(colors.focused_key())
                        .add_modifier(Modifier::BOLD),
                ),
            ]));
            // Line 2: details (only when a second row fits)
            if inner.height >= 2 {
                let type_label = match anomaly_type {
                    AnomalyType::Error => "error",
                    AnomalyType::Speed => "speed",
                };
                let detail = format!(
                    " Char '{ch}': weakest key | Bigram {bigram_label}: {type_label} anomaly {anomaly_pct:.0}%"
                );
                lines.push(Line::from(Span::styled(
                    detail,
                    Style::default().fg(colors.text_pending()),
                )));
            }
        }
        (Some(ch), None) => {
            // Char focus only: no confirmed bigram anomaly to show.
            lines.push(Line::from(vec![
                Span::styled(" Focus: ", Style::default().fg(colors.fg())),
                Span::styled(
                    format!("Char '{ch}'"),
                    Style::default()
                        .fg(colors.focused_key())
                        .add_modifier(Modifier::BOLD),
                ),
            ]));
            if inner.height >= 2 {
                lines.push(Line::from(Span::styled(
                    format!(" Char '{ch}': weakest key, no confirmed bigram anomalies"),
                    Style::default().fg(colors.text_pending()),
                )));
            }
        }
        (None, Some((key, anomaly_pct, anomaly_type))) => {
            // Bigram focus only: anomaly type and magnitude inline.
            let bigram_label = format!("\"{}{}\"", key.0[0], key.0[1]);
            let type_label = match anomaly_type {
                AnomalyType::Error => "error",
                AnomalyType::Speed => "speed",
            };
            lines.push(Line::from(vec![
                Span::styled(" Focus: ", Style::default().fg(colors.fg())),
                Span::styled(
                    format!("Bigram {bigram_label}"),
                    Style::default()
                        .fg(colors.focused_key())
                        .add_modifier(Modifier::BOLD),
                ),
                Span::styled(
                    format!(" ({type_label} anomaly: {anomaly_pct:.0}%)"),
                    Style::default().fg(colors.text_pending()),
                ),
            ]));
        }
        (None, None) => {
            // No focus data at all: prompt the user to drill first.
            lines.push(Line::from(Span::styled(
                " Complete some adaptive drills to see focus data",
                Style::default().fg(colors.text_pending()),
            )));
        }
    }
    Paragraph::new(lines).render(inner, buf);
}
/// Render one bordered anomaly table; shared by the error and speed
/// panels. `is_speed` selects the column set, and widths under 30
/// columns switch to abbreviated headers/rows. Confirmed rows are drawn
/// in the error color, unconfirmed ones in the warning color.
fn render_anomaly_panel(
    &self,
    title: &str,
    empty_msg: &str,
    rows: &[AnomalyBigramRow],
    is_speed: bool,
    area: Rect,
    buf: &mut Buffer,
) {
    let colors = &self.theme.colors;
    let block = Block::bordered()
        .title(Line::from(Span::styled(
            title.to_string(),
            Style::default()
                .fg(colors.accent())
                .add_modifier(Modifier::BOLD),
        )))
        .border_style(Style::default().fg(colors.accent()));
    let inner = block.inner(area);
    block.render(area, buf);
    // No interior space: only the border was drawn.
    if inner.height < 1 {
        return;
    }
    if rows.is_empty() {
        buf.set_string(
            inner.x,
            inner.y,
            empty_msg,
            Style::default().fg(colors.text_pending()),
        );
        return;
    }
    let narrow = inner.width < 30;
    // Error table columns: Bigram Errors Samples Rate Expect Anom%
    // Speed table columns: Bigram Speed Expect Samples Anom%
    let header = if narrow {
        if is_speed {
            " Bgrm Speed Expct Anom%"
        } else {
            " Bgrm Err Smp Rate Exp Anom%"
        }
    } else if is_speed {
        " Bigram  Speed  Expect  Samples  Anom%"
    } else {
        " Bigram  Errors  Samples  Rate  Expect  Anom%"
    };
    buf.set_string(
        inner.x,
        inner.y,
        header,
        Style::default()
            .fg(colors.accent())
            .add_modifier(Modifier::BOLD),
    );
    // One line per row under the header; clip to the panel height.
    let max_rows = (inner.height as usize).saturating_sub(1);
    for (i, row) in rows.iter().take(max_rows).enumerate() {
        let y = inner.y + 1 + i as u16;
        if y >= inner.y + inner.height {
            break;
        }
        let line = if narrow {
            if is_speed {
                format!(
                    " {:>4} {:>3.0}ms {:>3.0}ms {:>4.0}%",
                    row.bigram, row.speed_ms, row.expected_baseline, row.anomaly_pct,
                )
            } else {
                format!(
                    " {:>4} {:>3} {:>3} {:>3.0}% {:>2.0}% {:>4.0}%",
                    row.bigram,
                    row.error_count,
                    row.sample_count,
                    row.error_rate_ema * 100.0,
                    row.expected_baseline * 100.0,
                    row.anomaly_pct,
                )
            }
        } else if is_speed {
            format!(
                " {:>6} {:>4.0}ms {:>4.0}ms {:>5} {:>4.0}%",
                row.bigram,
                row.speed_ms,
                row.expected_baseline,
                row.sample_count,
                row.anomaly_pct,
            )
        } else {
            format!(
                " {:>6} {:>5} {:>5} {:>4.0}% {:>4.0}% {:>5.0}%",
                row.bigram,
                row.error_count,
                row.sample_count,
                row.error_rate_ema * 100.0,
                row.expected_baseline * 100.0,
                row.anomaly_pct,
            )
        };
        // Confirmed anomalies stand out in the error color.
        let color = if row.confirmed {
            colors.error()
        } else {
            colors.warning()
        };
        buf.set_string(inner.x, y, &line, Style::default().fg(color));
    }
}
/// Render the error-anomaly panel, with the row count in its title.
fn render_error_anomalies(&self, data: &NgramTabData, area: Rect, buf: &mut Buffer) {
    let count = data.error_anomalies.len();
    self.render_anomaly_panel(
        &format!(" Error Anomalies ({count}) "),
        " No error anomalies detected",
        &data.error_anomalies,
        false,
        area,
        buf,
    );
}
/// Render the speed-anomaly panel, with the row count in its title.
fn render_speed_anomalies(&self, data: &NgramTabData, area: Rect, buf: &mut Buffer) {
    let count = data.speed_anomalies.len();
    self.render_anomaly_panel(
        &format!(" Speed Anomalies ({count}) "),
        " No speed anomalies detected",
        &data.speed_anomalies,
        true,
        area,
        buf,
    );
}
/// Render the single-line summary under the anomaly panels: scope,
/// n-gram counts, hesitation threshold, and the latest trigram gain.
fn render_ngram_summary(&self, data: &NgramTabData, area: Rect, buf: &mut Buffer) {
    // One match decides both the gain text and the trailing note.
    let (gain_text, gain_note) = match data.latest_trigram_gain {
        Some(gain) => (format!("{:.1}%", gain * 100.0), ""),
        None => ("--".to_string(), " (computed every 50 drills)"),
    };
    let summary = format!(
        " Scope: {} | Bigrams: {} | Trigrams: {} | Hesitation: >{:.0}ms | Tri-gain: {}{}",
        data.scope_label,
        data.total_bigrams,
        data.total_trigrams,
        data.hesitation_threshold_ms,
        gain_text,
        gain_note,
    );
    let style = Style::default().fg(self.theme.colors.text_pending());
    buf.set_string(area.x, area.y, &summary, style);
}
}
fn accuracy_color(accuracy: f64, colors: &crate::ui::theme::ThemeColors) -> ratatui::style::Color {
@@ -1501,3 +1859,79 @@ fn format_duration(secs: f64) -> String {
format!("{s}s")
}
}
/// Compute the ngram tab panel layout for the given terminal area.
/// Returns `(wide, lists_area_height)` where:
/// - `wide` = true means side-by-side anomaly panels (width >= 60)
/// - `lists_area_height` = height available for the anomaly panels region
///
/// When `!wide && lists_area_height < 10`, only error anomalies should render.
#[cfg(test)]
fn ngram_panel_layout(area: Rect) -> (bool, u16) {
    // Mirror of the vertical split performed in `render_ngram_tab`.
    let sections = Layout::default()
        .direction(Direction::Vertical)
        .constraints([
            Constraint::Length(4), // focus box
            Constraint::Min(5),    // lists
            Constraint::Length(2), // summary
        ])
        .split(area);
    let lists = sections[1];
    (lists.width >= 60, lists.height)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn narrow_short_terminal_shows_only_error_panel() {
        // 50 cols × 15 rows: narrow (<60) so panels stack vertically.
        // lists area = 15 - 4 (focus) - 2 (summary) = 9 rows → < 10 → error only.
        let (wide, lists_height) = ngram_panel_layout(Rect::new(0, 0, 50, 15));
        assert!(!wide, "50 cols should be narrow layout");
        assert!(
            lists_height < 10,
            "lists_height={lists_height}, expected < 10 so only error panel renders"
        );
    }

    #[test]
    fn narrow_tall_terminal_stacks_both_panels() {
        // 50 cols × 30 rows: narrow (<60) so panels stack vertically.
        // lists area = 30 - 4 - 2 = 24 rows → >= 10 → both panels stacked.
        let (wide, lists_height) = ngram_panel_layout(Rect::new(0, 0, 50, 30));
        assert!(!wide, "50 cols should be narrow layout");
        assert!(
            lists_height >= 10,
            "lists_height={lists_height}, expected >= 10 so both panels stack vertically"
        );
    }

    #[test]
    fn wide_terminal_shows_side_by_side_panels() {
        // 80 cols × 24 rows: wide (>= 60) so panels render side by side.
        let (wide, _) = ngram_panel_layout(Rect::new(0, 0, 80, 24));
        assert!(
            wide,
            "80 cols should be wide layout with side-by-side panels"
        );
    }

    #[test]
    fn boundary_width_59_is_narrow() {
        // One column under the threshold.
        let (wide, _) = ngram_panel_layout(Rect::new(0, 0, 59, 24));
        assert!(!wide, "59 cols should be narrow");
    }

    #[test]
    fn boundary_width_60_is_wide() {
        // Exactly at the threshold.
        let (wide, _) = ngram_panel_layout(Rect::new(0, 0, 60, 24));
        assert!(wide, "60 cols should be wide");
    }
}

View File

@@ -103,7 +103,9 @@ fn contrast_ratio(a: ratatui::style::Color, b: ratatui::style::Color) -> f64 {
(hi + 0.05) / (lo + 0.05)
}
fn choose_cursor_colors(colors: &crate::ui::theme::ThemeColors) -> (ratatui::style::Color, ratatui::style::Color) {
fn choose_cursor_colors(
colors: &crate::ui::theme::ThemeColors,
) -> (ratatui::style::Color, ratatui::style::Color) {
use ratatui::style::Color;
let base_bg = colors.bg();
@@ -113,7 +115,13 @@ fn choose_cursor_colors(colors: &crate::ui::theme::ThemeColors) -> (ratatui::sty
if contrast_ratio(cursor_bg, base_bg) < 1.8 {
let mut best_bg = cursor_bg;
let mut best_ratio = contrast_ratio(cursor_bg, base_bg);
for candidate in [colors.accent(), colors.focused_key(), colors.warning(), Color::Black, Color::White] {
for candidate in [
colors.accent(),
colors.focused_key(),
colors.warning(),
Color::Black,
Color::White,
] {
let ratio = contrast_ratio(candidate, base_bg);
if ratio > best_ratio {
best_bg = candidate;

View File

@@ -91,8 +91,12 @@ pub fn centered_rect(percent_x: u16, percent_y: u16, area: Rect) -> Rect {
let target_w = requested_w.max(MIN_POPUP_WIDTH).min(area.width);
let target_h = requested_h.max(MIN_POPUP_HEIGHT).min(area.height);
let left = area.x.saturating_add((area.width.saturating_sub(target_w)) / 2);
let top = area.y.saturating_add((area.height.saturating_sub(target_h)) / 2);
let left = area
.x
.saturating_add((area.width.saturating_sub(target_w)) / 2);
let top = area
.y
.saturating_add((area.height.saturating_sub(target_h)) / 2);
Rect::new(left, top, target_w, target_h)
}