diff --git a/docs/plans/2026-02-17-code-drill-feature-parity-plan.md b/docs/plans/2026-02-17-code-drill-feature-parity-plan.md new file mode 100644 index 0000000..3cb1f1f --- /dev/null +++ b/docs/plans/2026-02-17-code-drill-feature-parity-plan.md @@ -0,0 +1,757 @@ +# Code Drill Feature Parity Plan + +## Context + +The code drill feature is significantly less developed than the passage drill. The passage drill has a full onboarding flow, lazy downloads with progress bars, configurable network/cache settings, and rich content from Project Gutenberg. The code drill only has 4 hardcoded languages with ~20-30 built-in snippets each, a basic language selection screen, and a partially-implemented synchronous GitHub fetch that blocks the UI thread. There's also a completely dead `github_code.rs` file that's never used. + +This plan is split into three delivery phases: +1. **Phase 1**: Feature parity with passage drill (onboarding, downloads, progress bar, config) +2. **Phase 2**: Language expansion and extraction improvements +3. 
**Phase 3**: Custom repo support + +## Current Code Drill Analysis + +### What exists: +- **`generator/code_syntax.rs`**: `CodeSyntaxGenerator` with built-in snippets for 4 languages (rust, python, javascript, go), a `try_fetch_code()` that synchronously fetches from hardcoded GitHub URLs (blocking UI), `extract_code_snippets()` for parsing functions from source +- **`generator/code_patterns.rs`**: Post-processor that inserts code-like expressions into adaptive drill text (unrelated to code drill mode) +- **`generator/github_code.rs`**: **Dead code** - `GitHubCodeGenerator` struct with `#[allow(dead_code)]`, never referenced outside its own file +- **Config**: Only `code_language: String` - no download/network/onboarding settings +- **Screens**: `CodeLanguageSelect` only - no intro, no download progress +- **Languages**: rust, python, javascript, go, "all" + +### What passage drill has that code drill doesn't: +- Onboarding intro screen (`PassageIntro`) with config for downloads/dir/limits +- `passage_onboarding_done` flag (shows intro only on first use) +- `passage_downloads_enabled` toggle +- `passage_download_dir` configurable path +- `passage_paragraphs_per_book` content limit +- Lazy download: on drill start, downloads one book if not cached +- Background download thread with atomic progress reporting +- Download progress screen (`PassageDownloadProgress`) with byte-level progress bar +- Fallback to built-in content when downloads off + +### Built-in snippet whitespace review: +- **Rust**: 4-space indent - idiomatic +- **Python**: 4-space indent - idiomatic +- **JavaScript**: 4-space indent - idiomatic +- **Go**: `\t` tab indent - idiomatic + +All whitespace is correct. The escaped string format (`\n`, `\t`, `\"`) is hard to read. Converting to raw strings (`r#"..."#`) improves maintainability. + +--- + +## Phase 1: Feature Parity with Passage Drill + +Goal: Give code drill the same onboarding, download, caching, and config infrastructure as passage drill. 
Keep the existing 4 languages. No language expansion yet. + +### Step 1.1: Delete dead code + +- Delete `src/generator/github_code.rs` entirely +- Remove `pub mod github_code;` from `src/generator/mod.rs` + +### Step 1.2: Convert built-in snippets to raw strings + +**File**: `src/generator/code_syntax.rs` + +Convert all 4 language snippet arrays from escaped strings to `r#"..."#` raw strings. Example: + +Before: `"fn main() {\n println!(\"hello\");\n}"` +After: +```rust +r#"fn main() { + println!("hello"); +}"# +``` + +Go snippets: `\t` becomes actual tab characters inside raw strings (correct for Go). + +Keep all existing snippets at their current count (~20-30 per language). Do NOT reduce them -- since downloads default to off, these are the primary content source for new users. + +Validation: run `cargo test` after conversion. Add a focused test that asserts a sample snippet's char content matches expectations (catches any accidental whitespace changes). + +### Step 1.3: Add config fields for code drill + +**File**: `src/config.rs` + +Add fields mirroring passage drill config: + +```rust +#[serde(default = "default_code_downloads_enabled")] +pub code_downloads_enabled: bool, // default: false +#[serde(default = "default_code_download_dir")] +pub code_download_dir: String, // default: dirs::data_dir()/keydr/code/ +#[serde(default = "default_code_snippets_per_repo")] +pub code_snippets_per_repo: usize, // default: 50 +#[serde(default = "default_code_onboarding_done")] +pub code_onboarding_done: bool, // default: false +``` + +`code_download_dir` default uses `dirs::data_dir()` (same pattern as `default_passage_download_dir`) for cross-platform portability. + +`code_snippets_per_repo` is a **download-time extraction cap**: when fetching from a repo, extract at most this many snippets and write them to cache. The generator reads whatever is in the cache without re-filtering. + +Update `Default` impl. Add `default_*` functions. 
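As a rough sketch of the `default_*` helpers (function names follow the plan's `default_code_*` convention; the `dirs::data_dir()` call is replaced by a std-only stand-in here so the sketch stays self-contained):

```rust
// Sketch of the serde default helpers for the new code drill config fields.
// NOTE: the real default_code_download_dir would use dirs::data_dir(); the
// XDG_DATA_HOME fallback below is a hypothetical std-only stand-in.
fn default_code_downloads_enabled() -> bool {
    false
}

fn default_code_onboarding_done() -> bool {
    false
}

fn default_code_snippets_per_repo() -> usize {
    50
}

fn default_code_download_dir() -> String {
    std::env::var("XDG_DATA_HOME")
        .map(|d| format!("{}/keydr/code", d))
        .unwrap_or_else(|_| String::from("keydr/code"))
}
```

Because each field carries `#[serde(default = "...")]`, configs written before this change deserialize cleanly with these values.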
+ +**Config normalization**: After deserialization in `App::new()` (not `Config::load()`, to avoid coupling config to generator internals), validate `code_language` against `code_language_options()`. If invalid (e.g., old/renamed key), reset to `"rust"`. + +**Old cache migration**: The old `DiskCache("code_cache")` entries (in `~/.local/share/keydr/code_cache/`) are simply ignored. They used a different key format (`{lang}_snippets`) and location. No migration or cleanup needed -- they'll be naturally superseded by the new cache in `code_download_dir`. + +### Step 1.4: Define language data structures + +**File**: `src/generator/code_syntax.rs` + +Add structures for the language registry. Phase 1 only populates the 4 existing languages + "all": + +```rust +pub struct CodeLanguage { + pub key: &'static str, // filesystem-safe identifier (e.g. "rust", "bash") + pub display_name: &'static str, // UI label (e.g. "Rust", "Shell/Bash") + pub extensions: &'static [&'static str], // e.g. &[".rs"], &[".py", ".pyi"] + pub repos: &'static [CodeRepo], + pub has_builtin: bool, +} + +pub struct CodeRepo { + pub key: &'static str, // filesystem-safe identifier for cache naming + pub urls: &'static [&'static str], // raw.githubusercontent.com file URLs to fetch +} + +pub const CODE_LANGUAGES: &[CodeLanguage] = &[ + CodeLanguage { + key: "rust", + display_name: "Rust", + extensions: &[".rs"], + repos: &[ + CodeRepo { + key: "tokio", + urls: &[ + "https://raw.githubusercontent.com/tokio-rs/tokio/master/tokio/src/sync/mutex.rs", + "https://raw.githubusercontent.com/tokio-rs/tokio/master/tokio/src/net/tcp/stream.rs", + ], + }, + CodeRepo { + key: "serde", + urls: &[ + "https://raw.githubusercontent.com/serde-rs/serde/master/serde/src/ser/mod.rs", + ], + }, + ], + has_builtin: true, + }, + // ... 
python, javascript, go with similar structure
    // Move existing hardcoded URLs from try_fetch_code() into these repo definitions
];
```

Helper functions:
```rust
pub fn code_language_options() -> Vec<(&'static str, String)>
// Returns [("rust", "Rust"), ("python", "Python"), ..., ("all", "All (random)")]

pub fn language_by_key(key: &str) -> Option<&'static CodeLanguage>

pub fn is_language_cached(cache_dir: &str, key: &str) -> bool
// Checks if any {key}_*.txt files exist in cache_dir AND have non-empty content (>0 bytes)
// Uses direct filesystem scanning (NOT DiskCache -- DiskCache has no list/glob API)
```

### Step 1.5: Generalize download job struct

**File**: `src/app.rs`

Rename `PassageDownloadJob` to `DownloadJob`. It is already content-agnostic (just `Arc<AtomicU64>`/`Arc<AtomicBool>` progress atomics and a thread handle). Update all passage references to use the renamed type. No behavior change.

### Step 1.6: Add code drill app state

**File**: `src/app.rs`

Add `CodeDownloadCompleteAction` enum (parallels `PassageDownloadCompleteAction`):
```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CodeDownloadCompleteAction {
    StartCodeDrill,
    ReturnToSettings,
}
```

Add screen variants:
```rust
CodeIntro,            // Onboarding screen for code drill
CodeDownloadProgress, // Download progress for code files
```

Add app fields:
```rust
pub code_intro_selected: usize,
pub code_intro_downloads_enabled: bool,
pub code_intro_download_dir: String,
pub code_intro_snippets_per_repo: usize,
pub code_intro_downloading: bool,
pub code_intro_download_total: usize,
pub code_intro_downloaded: usize,
pub code_intro_current_repo: String,
pub code_intro_download_bytes: u64,
pub code_intro_download_bytes_total: u64,
pub code_download_queue: Vec<usize>, // repo indices within current language's repos array
pub code_drill_language_override: Option<String>,
pub code_download_action: CodeDownloadCompleteAction,
code_download_job: Option<DownloadJob>,
```

### Step 1.7: Remove blocking fetch from generator

**File**: `src/generator/code_syntax.rs`

Remove `try_fetch_code()` from `CodeSyntaxGenerator`. All network I/O moves to the app layer with background threads.

Update constructor:
```rust
pub fn new(rng: SmallRng, language: &str, cache_dir: &str) -> Self
```

Update `load_cached_snippets()`: scan `cache_dir` for files matching `{language}_*.txt`, read each, split on `---SNIPPET---` delimiter. This replaces the `DiskCache("code_cache")` approach with direct filesystem reads (since `DiskCache` has no listing/glob API and the cache dir is now user-configurable).

### Step 1.8: Add download function

**File**: `src/generator/code_syntax.rs`

```rust
pub fn download_code_repo_to_cache_with_progress<F>(
    cache_dir: &str,
    language_key: &str,
    repo: &CodeRepo,
    snippets_limit: usize,
    on_progress: F,
) -> bool
where
    F: FnMut(u64, Option<u64>),
```

This function:
1. Creates `cache_dir` if needed (`fs::create_dir_all`)
2. Fetches each URL in `repo.urls` using `fetch_url_bytes_with_progress` (already exists in `cache.rs`)
3. Runs `extract_code_snippets()` on each fetched file
4. Combines all snippets, truncates to `snippets_limit`
5. Writes to `{cache_dir}/{language_key}_{repo.key}.txt` with `---SNIPPET---` delimiter
6. Returns `true` on success

**Error handling**: If any individual URL fails (404, timeout, network error), skip it and continue with others. If zero snippets extracted from all URLs, return `false`. The app layer treats `false` as "skip this repo, continue queue" (same as passage drill's failure behavior).

### Step 1.9: Implement code drill flow methods

**File**: `src/app.rs`

**`go_to_code_intro()`**: Initialize intro screen state (downloads toggle, dir, snippets limit from config). Set `code_download_action = CodeDownloadCompleteAction::StartCodeDrill`. Set screen to `CodeIntro`. 
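The `---SNIPPET---` cache file format shared by Steps 1.7 and 1.8 can be sketched as a pair of helpers (`join_snippets`/`split_snippets` are illustrative names, not planned APIs):

```rust
const SNIPPET_DELIM: &str = "---SNIPPET---";

// Write side (Step 1.8, item 5): join extracted snippets with the delimiter
// on its own line before writing the cache file.
fn join_snippets(snippets: &[String]) -> String {
    snippets.join(&format!("\n{}\n", SNIPPET_DELIM))
}

// Read side (Step 1.7): split a cache file's contents back into snippets,
// trimming delimiter-adjacent whitespace and dropping empty entries.
fn split_snippets(contents: &str) -> Vec<String> {
    contents
        .split(SNIPPET_DELIM)
        .map(|s| s.trim().to_string())
        .filter(|s| !s.is_empty())
        .collect()
}
```

The round trip is lossless for trimmed snippets, which is all the generator needs from the cache.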
+ +**`start_code_drill()`**: Lazy download logic with explicit language resolution: + +```rust +pub fn start_code_drill(&mut self) { + // Step 1: Resolve concrete language (never download with "all" selected) + if self.code_drill_language_override.is_none() { + let chosen = if self.config.code_language == "all" { + // Pick from languages with built-in OR cached content only + // Never pick a network-only language that isn't cached + let available = languages_with_content(&self.config.code_download_dir); + if available.is_empty() { + "rust".to_string() // ultimate fallback + } else { + let idx = self.rng.gen_range(0..available.len()); + available[idx].to_string() + } + } else { + self.config.code_language.clone() + }; + self.code_drill_language_override = Some(chosen); + } + + let chosen = self.code_drill_language_override.clone().unwrap(); + + // Step 2: Check if we need to download + if self.config.code_downloads_enabled + && !is_language_cached(&self.config.code_download_dir, &chosen) + { + if let Some(lang) = language_by_key(&chosen) { + if !lang.repos.is_empty() { + // Pick one random repo to download + let repo_idx = self.rng.gen_range(0..lang.repos.len()); + self.code_download_queue = vec![repo_idx]; + self.code_intro_download_total = 1; + self.code_intro_downloaded = 0; + self.code_intro_downloading = true; + self.code_intro_current_repo = format!("{}", lang.repos[repo_idx].key); + self.code_download_action = CodeDownloadCompleteAction::StartCodeDrill; + self.code_download_job = None; + self.screen = AppScreen::CodeDownloadProgress; + return; + } + } + // Language has no repos or unknown: fall through to built-in + } + + // Step 3: If language has no built-in AND no cache AND downloads off → fallback + if !is_language_cached(&self.config.code_download_dir, &chosen) { + if let Some(lang) = language_by_key(&chosen) { + if !lang.has_builtin { + // Network-only language with no cache: fall back to "rust" + self.code_drill_language_override = 
Some("rust".to_string()); + } + } + } + + // Step 4: Start the drill + self.drill_mode = DrillMode::Code; + self.drill_scope = DrillScope::Global; + self.start_drill(); +} +``` + +Key behavior: `"all"` only selects from `languages_with_content()` (built-in OR cached). This prevents the dead-end loop of repeatedly picking uncached network-only languages and forcing download screens. In Phase 2, once network-only languages get cached via manual download, they are automatically included in `"all"` selection. + +**`languages_with_content(cache_dir: &str) -> Vec<&'static str>`**: Returns language keys that have either `has_builtin: true` or non-empty cache files in `cache_dir`. + +**`process_code_download_tick()`**, **`spawn_code_download_job()`**: Same pattern as passage equivalents, using `download_code_repo_to_cache_with_progress` and `DownloadJob`. + +**`start_code_downloads_from_settings()`**: Mirror `start_passage_downloads_from_settings()` with `CodeDownloadCompleteAction::ReturnToSettings`. + +### Step 1.10: Update code language select flow + +**File**: `src/main.rs` + +Update `handle_code_language_key()` and `render_code_language_select()`: +- Still shows the same 4+1 languages for now (Phase 2 expands this) +- Wire Enter to `confirm_code_language_and_continue()`: + +```rust +fn confirm_code_language_and_continue(app: &mut App, langs: &[&str]) { + if app.code_language_selected >= langs.len() { return; } + app.config.code_language = langs[app.code_language_selected].to_string(); + let _ = app.config.save(); + if app.config.code_onboarding_done { + app.start_code_drill(); + } else { + app.go_to_code_intro(); + } +} +``` + +### Step 1.11: Add event handlers and renderers + +**File**: `src/main.rs` + +Add to screen dispatch in `handle_key()` and `render()`: + +**`handle_code_intro_key()`**: Same field navigation as `handle_passage_intro_key()` but operates on `code_intro_*` fields. 4 fields: +1. Enable network downloads (toggle) +2. 
Download directory (editable text) +3. Snippets per repo (numeric, adjustable) +4. Start code drill (confirm button) + +On confirm: save config fields, set `code_onboarding_done = true`, call `start_code_drill()`. + +**`handle_code_download_progress_key()`**: Esc/q to cancel. On cancel: +1. Clear `code_download_queue` +2. Set `code_intro_downloading = false` +3. If a `code_download_job` is in-flight, detach it (set to `None` without joining -- the thread will finish and write to cache, which is harmless; the `Arc` atomics keep the thread safe) +4. Reset `code_drill_language_override` to `None` +5. Go to menu + +This matches the existing passage download cancel behavior (passage also does not join/abort in-flight threads on Esc). + +**`render_code_intro()`**: Mirror `render_passage_intro()` layout. Title: "Code Downloads Setup". Explanatory text: "Configure code source settings before your first code drill." / "Downloads are lazy: code is fetched only when first needed." + +**`render_code_download_progress()`**: Mirror `render_passage_download_progress()`. Title: "Downloading Code Source". Show repo name, byte progress bar. 
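The `DownloadJob` polling pattern that the tick handler relies on can be sketched as follows (`spawn_fake_download` is a hypothetical stand-in for the real worker, which would call `fetch_url_bytes_with_progress` and bump the counters per chunk):

```rust
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::thread;

// Worker thread reports byte progress through shared atomics; the UI tick
// handler polls these without blocking, and Esc can simply drop the handle.
struct DownloadJob {
    downloaded_bytes: Arc<AtomicU64>,
    total_bytes: Arc<AtomicU64>,
    done: Arc<AtomicBool>,
    handle: Option<thread::JoinHandle<()>>,
}

fn spawn_fake_download(total: u64) -> DownloadJob {
    let downloaded_bytes = Arc::new(AtomicU64::new(0));
    let total_bytes = Arc::new(AtomicU64::new(total));
    let done = Arc::new(AtomicBool::new(false));
    let (d, f) = (downloaded_bytes.clone(), done.clone());
    let handle = thread::spawn(move || {
        // Simulated fetch loop standing in for the real chunked download.
        for _ in 0..4 {
            d.fetch_add(total / 4, Ordering::Relaxed);
        }
        f.store(true, Ordering::Relaxed);
    });
    DownloadJob { downloaded_bytes, total_bytes, done, handle: Some(handle) }
}
```

Because cancellation just detaches the handle, the worker finishing late and writing to cache is harmless, exactly as described for Esc above.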
+ +Update tick handler: +```rust +if (app.screen == AppScreen::CodeIntro + || app.screen == AppScreen::CodeDownloadProgress) + && app.code_intro_downloading +{ + app.process_code_download_tick(); +} +``` + +### Step 1.12: Update generate_text for Code mode + +**File**: `src/app.rs` + +Update `DrillMode::Code` in `generate_text()`: + +```rust +DrillMode::Code => { + let filter = CharFilter::new(('a'..='z').collect()); + let lang = self.code_drill_language_override + .clone() + .unwrap_or_else(|| self.config.code_language.clone()); + let rng = SmallRng::from_rng(&mut self.rng).unwrap(); + let mut generator = CodeSyntaxGenerator::new( + rng, &lang, &self.config.code_download_dir, + ); + self.code_drill_language_override = None; + let text = generator.generate(&filter, None, word_count); + (text, Some(generator.last_source().to_string())) +} +``` + +### Step 1.13: Settings integration + +**Files**: `src/main.rs`, `src/app.rs` + +Add settings rows after existing code language field (index 3): +- Index 4: Code Downloads: On/Off +- Index 5: Code Download Dir: editable path +- Index 6: Code Snippets per Repo: numeric +- Index 7: Download Code Now: action button + +Shift existing passage settings indices up by 4. Update `settings_cycle_forward`/`settings_cycle_backward` and max `settings_selected` bound. + +**"Download Code Now" behavior**: Downloads all uncached curated repos for the currently selected `code_language` only. If `code_language == "all"`, downloads all uncached repos for all curated languages. Does NOT include custom repos. Mirrors passage behavior where "Download Passages Now" downloads all uncached books. + +**`start_code_downloads()`**: Queues all uncached repos for the currently selected language. Used by intro screen "confirm" flow when downloads are enabled. + +### Phase 1 Verification + +1. `cargo build` -- compiles +2. 
`cargo test` -- all existing tests pass, plus new tests: + - `test_languages_with_content_includes_builtin` -- verifies built-in languages appear in `languages_with_content()` even with empty cache dir + - `test_languages_with_content_excludes_uncached_network_only` -- verifies network-only languages without cache are not returned + - `test_config_serde_defaults` -- verifies new config fields deserialize with correct defaults from empty/old configs + - `test_raw_string_snippets_preserved` -- spot-check that raw string conversion didn't alter snippet content +3. `cargo build --no-default-features` -- compiles, network features gated +4. Manual tests: + - Menu → Code Drill → language select → first time shows CodeIntro + - CodeIntro with downloads off → confirms → starts drill with built-in snippets + - CodeIntro with downloads on → confirms → shows CodeDownloadProgress → downloads repo → starts drill with downloaded content + - Subsequent code drills skip onboarding + - "all" language mode only picks from languages with content (never triggers download) + - Settings shows code drill fields, values persist on restart + - Passage drill flow completely unchanged + - Esc during download progress → returns to menu, no crash + +--- + +## Phase 2: Language Expansion and Extraction Improvements + +Goal: Add 8 more built-in languages and ~18 network-only languages, improve snippet extraction. 
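Each new built-in set keeps the Step 1.2 raw-string shape; for example, a hypothetical Ruby entry (name and content illustrative) with its 2-space indent preserved literally:

```rust
// Hypothetical built-in snippet array in the Step 1.2 raw-string style.
// r#"..."# preserves the 2-space Ruby indentation and inner quotes verbatim,
// with no \n/\t/\" escapes to maintain.
const RUBY_SNIPPETS: &[&str] = &[
    r#"def greet(name)
  "Hello, #{name}!"
end"#,
];
```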
+ +### Step 2.1: Add 8 built-in language snippet sets + +**File**: `src/generator/code_syntax.rs` + +Add ~10-15 raw-string snippets each for: **typescript, java, c, cpp, ruby, swift, bash, lua** + +Language keys: `typescript`/`ts`, `java`, `c`, `cpp`, `ruby`, `swift`, `bash` (display: "Shell/Bash"), `lua` + +All with idiomatic whitespace: +- TypeScript: 4-space indent +- Java: 4-space indent +- C: 4-space indent +- C++: 4-space indent +- Ruby: 2-space indent +- Swift: 4-space indent +- Bash: 2-space indent (common convention) +- Lua: 2-space indent + +Update `get_snippets()` match to include all 12 languages. + +### Step 2.2: Expand language registry to ~30 languages + +**File**: `src/generator/code_syntax.rs` + +Add ~18 network-only entries to `CODE_LANGUAGES` with curated repos: + +kotlin, scala, haskell, elixir, clojure, perl, php, r, dart, zig, nim, ocaml, erlang, julia, objective-c, groovy, csharp, fsharp + +Each gets 2-3 repos with specific raw.githubusercontent.com file URLs. **Exclude SQL and CSS** -- their syntax is too different from procedural code for function-level extraction to work well. + +This is a significant data curation subtask: for each language, identify 2-3 well-known repos with permissive licenses (MIT/Apache/BSD), select 2-5 representative source files per repo with functions/methods to extract. + +**Acceptance threshold**: Each language must yield at least 10 extractable snippets from its curated repos (verified by running `extract_code_snippets` against fetched files). Languages that fall below this threshold should be dropped from the registry rather than shipped with poor content. + +### Step 2.3: Improve snippet extraction + +**File**: `src/generator/code_syntax.rs` + +Add a `func_start_patterns` field to `CodeLanguage`: + +```rust +pub struct CodeLanguage { + // ... existing fields ... 
+ pub block_style: BlockStyle, +} + +pub enum BlockStyle { + Braces(&'static [&'static str]), // fn/def/func patterns, brace-delimited (C, Java, Go, etc.) + Indentation(&'static [&'static str]), // def/class patterns, indentation-delimited (Python) + EndDelimited(&'static [&'static str]), // def/class patterns, closed by `end` keyword (Ruby, Lua, Elixir) +} +``` + +Update `extract_code_snippets()` to accept `BlockStyle`: +- `Braces`: current behavior with configurable start patterns (C, Java, Go, JS, etc.) +- `Indentation`: track indent level changes to find block boundaries (Python only) +- `EndDelimited`: scan for matching `end` keyword at same indent level to close blocks (Ruby, Lua, Elixir) + +Language-specific patterns: +- Java: `["public ", "private ", "protected ", "static ", "class ", "interface "]` +- Ruby: `["def ", "class ", "module "]` (EndDelimited style -- uses `end` keyword to close blocks) +- C/C++: `["int ", "void ", "char ", "float ", "double ", "struct ", "class ", "template"]` +- Swift: `["func ", "class ", "struct ", "enum ", "protocol "]` +- Bash: `["function ", "() {"]` (Braces style, simple) +- etc. + +### Step 2.4: Make language select scrollable + +**File**: `src/main.rs` + +With 30+ languages, the selection screen needs scrolling. Add `code_language_scroll: usize` to `App`. Show a viewport of ~15 items. Add keybindings: +- Up/Down: navigate +- PageUp/PageDown: jump 10 items +- Home/End or `g`/`G`: jump to top/bottom +- `/`: type-to-filter (optional, nice-to-have) + +Mark each language as "(built-in)" or "(download required)" in the list. + +### Phase 2 Verification + +1. `cargo build && cargo test` +2. Manual: verify all 12 built-in languages produce readable snippets with correct indentation +3. Manual: select a network-only language → triggers download → produces good snippets +4. Manual: scrollable language list works, indicators are accurate +5. 
Verify each built-in language's snippet whitespace is idiomatic + +--- + +## Phase 3: Custom Repo Support + +Goal: Let users specify their own GitHub repos to train on. + +### Step 3.1: Design custom repo fetch strategy + +Custom repos require solving problems that curated repos don't have: +- **Branch discovery**: Use GitHub API `GET /repos/{owner}/{repo}` to find `default_branch`. Requires `User-Agent` header (GitHub rejects requests without it; use `"keydr/{version}"`). Optionally support a `GITHUB_TOKEN` env var for authenticated requests (raises rate limit from 60 to 5000 req/hour). +- **File discovery**: Use GitHub API `GET /repos/{owner}/{repo}/git/trees/{branch}?recursive=1` to list all files, filter by language extensions. Same `User-Agent` and optional auth headers. If the response has `"truncated": true` (repos with >100k files), reject with a user-facing error: "Repository is too large for automatic file discovery. Please use a smaller repo or fork with fewer files." +- **Rate limiting**: Cache the tree response to disk. On 403/429 responses, show error: "GitHub API rate limit reached. Try again later or set GITHUB_TOKEN env var for higher limits." +- **File selection**: From matching files, randomly select 3-5 files to download via raw.githubusercontent.com (no API needed for file content) +- **Language detection**: Match file extensions against `CodeLanguage.extensions` field. If ambiguous or no match, prompt user. +- **All API requests**: Set `Accept: application/vnd.github.v3+json` header, timeout 10s. 
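The file-discovery filtering step above can be sketched as a small helper (`filter_by_extensions` is an illustrative name; input paths would come from the GitHub tree response):

```rust
// Given file paths from the tree listing, keep those matching one of a
// language's extensions (CodeLanguage.extensions, e.g. &[".py", ".pyi"]).
fn filter_by_extensions<'a>(paths: &[&'a str], extensions: &[&str]) -> Vec<&'a str> {
    paths
        .iter()
        .copied()
        .filter(|p| extensions.iter().any(|ext| p.ends_with(ext)))
        .collect()
}
```

Random file selection and language auto-detection both reduce to this match against the registry's extension lists.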
### Step 3.2: Add config field and validation

**File**: `src/config.rs`

```rust
#[serde(default)]
pub code_custom_repos: Vec<String>, // Format: "owner/repo" or "owner/repo@language"
```

Parse function:
```rust
pub fn parse_custom_repo(input: &str) -> Option<String> {
    // Accepts: "owner/repo", "owner/repo@language", "https://github.com/owner/repo"
    // Validates: owner and repo contain only valid GitHub chars
    // Returns the normalized entry, or None on invalid input
}
```

### Step 3.3: Settings UI for custom repos

Add a settings section showing current custom repos as a scrollable list. Keybindings:
- `a`: add new repo (enters text input mode)
- `d`/`x`: delete selected repo
- Up/Down: navigate list

### Step 3.4: Code language select "Add custom repo" option

At the bottom of the language select list, add an "[ + Add custom repo ]" option. Selecting it enters a text input mode for `owner/repo`. On confirm:
1. Validate format
2. Add to `code_custom_repos` config
3. Auto-detect language from repo (via API tree listing file extensions)
4. If language ambiguous, show a small picker
5. Queue download of that repo

### Step 3.5: Integrate custom repos into download flow

When `start_code_drill()` runs for a language, include matching custom repos in the download candidates alongside curated repos.

### Phase 3 Verification

1. Add a custom repo → appears in settings list
2. Start drill → custom repo snippets appear
3. Invalid repo format → shows error, doesn't save
4. GitHub rate limit → shows informative error
5. 
Remove custom repo → removed from config and future drills + +--- + +## Critical Files Summary + +| File | Phase | Changes | +|------|-------|---------| +| `src/generator/github_code.rs` | 1 | Delete | +| `src/generator/mod.rs` | 1 | Remove github_code module | +| `src/generator/code_syntax.rs` | 1, 2 | Raw strings, new constructor, remove blocking fetch, language registry, download fn, new snippet sets, improved extraction | +| `src/config.rs` | 1, 3 | New code drill config fields, validation | +| `src/app.rs` | 1 | DownloadJob rename, new screens/state/flow methods, CodeDownloadCompleteAction | +| `src/main.rs` | 1, 2 | New handlers/renderers, updated settings, scrollable language list | +| `src/generator/cache.rs` | 1 | No changes (reuse existing `fetch_url_bytes_with_progress`) | + +## Existing Code to Reuse + +- `generator::cache::fetch_url_bytes_with_progress` -- already handles progress callbacks, used for passage downloads +- `generator::cache::DiskCache` -- NOT reused for code cache (no listing API); use direct `fs::read_dir` + `fs::read_to_string` instead +- `PassageDownloadJob` pattern (atomics + thread) -- generalized into `DownloadJob` +- `passage::extract_paragraphs` pattern -- referenced for extraction design but not directly reused +- `passage::download_book_to_cache_with_progress` -- structural template for `download_code_repo_to_cache_with_progress` + +--- + +## Phase 2.5: Improve Snippet Extraction Quality + +### Context + +After Phase 2, the verification test (`test_verify_repo_urls`) shows many languages producing far fewer than 100 snippets. Root causes: +1. **Per-file cap of 50** in `extract_code_snippets()` (line 1869) limits output even from large source files +2. **Keyword-only matching** — extraction only starts when a line begins with a recognized keyword (e.g. `fn `, `def `, `class `). Many valid code blocks (anonymous functions, method chains, match arms, closures, etc.) are missed. +3. 
**Narrow keyword lists** — some languages are missing patterns for common constructs (e.g. `macro_rules!` in Rust, `@interface` in Objective-C) +4. **`code_snippets_per_repo` default of 50** caps total output per download + +### Goal + +Get every language to produce 100+ snippets from its curated repos, without sacrificing snippet quality. Do this by: +1. Widening keyword patterns to capture more language constructs +2. Adding a structural fallback that extracts well-formed code blocks by structure when keywords alone don't find enough +3. Raising the per-file and per-repo snippet caps + +### Step 2.5.1: Raise snippet caps + +**File**: `src/generator/code_syntax.rs` + +Change `snippets.truncate(50)` → `snippets.truncate(200)` in `extract_code_snippets()`. + +**File**: `src/config.rs` + +Change `default_code_snippets_per_repo()` → `200`. + +### Step 2.5.2: Widen keyword patterns + +**File**: `src/generator/code_syntax.rs` + +Add missing start patterns to existing languages. These are patterns that should have been there from the start — they represent common, well-defined constructs that produce good typing drill snippets: + +| Language | Add patterns | +|----------|-------------| +| Rust | `"macro_rules! "`, `"mod "`, `"const "`, `"static "`, `"type "` | +| Python | `"async def "` is already there. 
Add `"@"` (decorators start blocks) | +| JavaScript | `"class "`, `"const "`, `"let "`, `"export "` | +| Go | No changes needed (already has `"func "`, `"type "`) | +| TypeScript | `"class "`, `"const "`, `"let "`, `"export "`, `"interface "` | +| Java | `"abstract "`, `"final "`, `"@"` (annotations start blocks) | +| C | `"typedef "`, `"#define "`, `"enum "` | +| C++ | `"namespace "`, `"typedef "`, `"#define "`, `"enum "`, `"constexpr "`, `"auto "` | +| Ruby | Add `"attr_"`, `"scope "`, `"describe "`, `"it "` | +| Swift | `"var "`, `"let "`, `"init("`, `"deinit "`, `"extension "`, `"typealias "` | +| Bash | `"if "`, `"for "`, `"while "`, `"case "` | +| Kotlin | `"override fun "` already there. Add `"val "`, `"var "`, `"enum "`, `"annotation "`, `"typealias "` | +| Scala | `"val "`, `"var "`, `"type "`, `"implicit "`, `"given "`, `"extension "` | +| PHP | `"class "`, `"interface "`, `"trait "`, `"enum "` | +| Dart | Add `"Widget "`, `"get "`, `"set "`, `"enum "`, `"typedef "`, `"extension "` | +| Elixir | `"defmacro "`, `"defstruct"`, `"defprotocol "`, `"defimpl "` | +| Zig | `"test "`, `"var "` | +| Haskell | Already broad. No changes. | +| Objective-C | `"@interface "`, `"@implementation "`, `"@protocol "`, `"typedef "` | +| Others | Review on a case-by-case basis during implementation | + +### Step 2.5.3: Add structural fallback extraction + +**File**: `src/generator/code_syntax.rs` + +When keyword-based extraction yields fewer than 20 snippets from a file, run a second pass that extracts code blocks purely by structure. This captures anonymous functions, nested blocks, and other constructs that don't start with recognized keywords. 
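For brace-delimited languages, such a structural pass can be sketched as a brace-depth scanner (illustrative only: quality filtering reduced to the 3-30 line bound, no noise rejection):

```rust
// Sketch of a structural fallback for Braces languages: capture every block
// whose opening line raises brace depth from 0, until depth returns to 0.
fn structural_extract_braces(source: &str) -> Vec<String> {
    let mut snippets = Vec::new();
    let mut current: Vec<&str> = Vec::new();
    let mut depth: i32 = 0;
    for line in source.lines() {
        let opens = line.matches('{').count() as i32;
        let closes = line.matches('}').count() as i32;
        let starts_block = depth == 0 && opens > closes;
        if starts_block {
            current.clear();
        }
        if depth > 0 || starts_block {
            current.push(line);
        }
        // Clamp at 0 so stray closing braces can't corrupt later tracking.
        depth = (depth + opens - closes).max(0);
        if depth == 0 && !current.is_empty() {
            // Size filter from the plan: keep 3-30 line blocks.
            if (3..=30).contains(&current.len()) {
                snippets.push(current.join("\n"));
            }
            current.clear();
        }
    }
    snippets
}
```

Nested braces are handled implicitly, since depth only returns to 0 at the outer block's close.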
#### Design

Add a `structural_fallback: bool` field to each `BlockStyle` variant:

```rust
pub enum BlockStyle {
    Braces {
        patterns: &'static [&'static str],
        structural_fallback: bool,
    },
    Indentation {
        patterns: &'static [&'static str],
        structural_fallback: bool,
    },
    EndDelimited {
        patterns: &'static [&'static str],
        structural_fallback: bool,
    },
}
```

Set `structural_fallback: true` for all languages. This can be disabled per-language if it produces poor results.

Update `extract_code_snippets()`:

```rust
pub fn extract_code_snippets(source: &str, block_style: &BlockStyle) -> Vec<String> {
    let mut snippets = keyword_extract(source, block_style);

    if snippets.len() < 20 && has_structural_fallback(block_style) {
        let structural = structural_extract(source, block_style);
        // Add structural snippets that don't overlap with keyword ones
        for s in structural {
            if !snippets.contains(&s) {
                snippets.push(s);
            }
        }
    }

    snippets.truncate(200);
    snippets
}
```

#### Structural extraction for Braces languages

`structural_extract_braces(source)`:
1. Scan for lines containing `{` where brace depth transitions from 0→1 or 1→2
2. Capture from that line until depth returns to its starting level
3. Apply the same quality filters: 3-30 lines, 20+ non-whitespace chars, ≤800 bytes
4. Skip noise blocks: reject snippets where first non-blank line is only `{`, or where the block is just imports/use statements

#### Structural extraction for Indentation languages

`structural_extract_indent(source)`:
1. Scan for non-blank lines at indentation level 0 (top-level) that are followed by indented lines
2. Capture the top-level line + all subsequent lines with greater indentation
3. Apply same quality filters
4. Skip noise: reject if all body lines are `import`/`from`/`use`/`#include` statements

#### Structural extraction for EndDelimited languages

`structural_extract_end(source)`:
1. 
Scan for lines at top-level indentation followed by indented body ending with `end` +2. Same quality filters and noise rejection + +#### Noise filtering + +A snippet is "noise" and should be rejected if: +- First meaningful line (after stripping comments) is just `{` or `}` +- Body consists entirely of `import`, `use`, `from`, `require`, `include`, or blank lines +- It's a single-statement block (only 1 non-blank body line after the opening) + +### Step 2.5.4: Add more source URLs for low-count languages + +After implementing the extraction improvements, re-run `test_verify_repo_urls` to identify languages still under 100 snippets. For those, add 1-2 more source file URLs from the same or new repos to increase raw material. + +This step is intentionally deferred until after extraction improvements, since better extraction may push many languages over the 100 threshold without needing more URLs. + +### Phase 2.5 Verification + +1. `cargo test` — all existing tests pass +2. Run `cargo test test_verify_repo_urls -- --ignored --nocapture` — verify all 30 languages produce 50+ snippets (ideally 100+) +3. Spot-check structural fallback snippets for 3-4 languages — verify they contain real code, not just import blocks or noise +4. `cargo build --no-default-features` — compiles without network features +5. 
Verify no change to built-in snippet behavior (built-in snippets don't go through extraction) diff --git a/src/app.rs b/src/app.rs index 68f4a7b..95aa990 100644 --- a/src/app.rs +++ b/src/app.rs @@ -16,7 +16,11 @@ use crate::engine::skill_tree::{BranchId, BranchStatus, DrillScope, SkillTree}; use crate::generator::TextGenerator; use crate::generator::capitalize; use crate::generator::code_patterns; -use crate::generator::code_syntax::CodeSyntaxGenerator; +use crate::generator::code_syntax::{ + CodeSyntaxGenerator, build_code_download_queue, code_language_options, + download_code_repo_to_cache_with_progress, is_language_cached, language_by_key, + languages_with_content, +}; use crate::generator::dictionary::Dictionary; use crate::generator::numbers; use crate::generator::passage::{ @@ -48,6 +52,8 @@ pub enum AppScreen { PassageBookSelect, PassageIntro, PassageDownloadProgress, + CodeIntro, + CodeDownloadProgress, } #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -63,7 +69,13 @@ pub enum PassageDownloadCompleteAction { ReturnToSettings, } -struct PassageDownloadJob { +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum CodeDownloadCompleteAction { + StartCodeDrill, + ReturnToSettings, +} + +struct DownloadJob { downloaded_bytes: Arc, total_bytes: Arc, done: Arc, @@ -112,6 +124,7 @@ pub struct App { pub skill_tree_detail_scroll: usize, pub drill_source_info: Option, pub code_language_selected: usize, + pub code_language_scroll: usize, pub passage_book_selected: usize, pub passage_intro_selected: usize, pub passage_intro_downloads_enabled: bool, @@ -126,18 +139,37 @@ pub struct App { pub passage_download_queue: Vec, pub passage_drill_selection_override: Option, pub passage_download_action: PassageDownloadCompleteAction, + pub code_intro_selected: usize, + pub code_intro_downloads_enabled: bool, + pub code_intro_download_dir: String, + pub code_intro_snippets_per_repo: usize, + pub code_intro_downloading: bool, + pub code_intro_download_total: usize, + pub 
code_intro_downloaded: usize, + pub code_intro_current_repo: String, + pub code_intro_download_bytes: u64, + pub code_intro_download_bytes_total: u64, + pub code_download_queue: Vec<(String, usize)>, + pub code_drill_language_override: Option, + pub code_download_attempted: bool, + pub code_download_action: CodeDownloadCompleteAction, pub shift_held: bool, pub keyboard_model: KeyboardModel, rng: SmallRng, transition_table: TransitionTable, #[allow(dead_code)] dictionary: Dictionary, - passage_download_job: Option, + passage_download_job: Option, + code_download_job: Option, } impl App { pub fn new() -> Self { - let config = Config::load().unwrap_or_default(); + let mut config = Config::load().unwrap_or_default(); + + // Normalize code_language: reset to default if not a valid option + let valid_keys: Vec<&str> = code_language_options().iter().map(|(k, _)| *k).collect(); + config.normalize_code_language(&valid_keys); let loaded_theme = Theme::load(&config.theme).unwrap_or_default(); let theme: &'static Theme = Box::leak(Box::new(loaded_theme)); let menu = Menu::new(theme); @@ -183,6 +215,9 @@ impl App { let intro_downloads_enabled = config.passage_downloads_enabled; let intro_download_dir = config.passage_download_dir.clone(); let intro_paragraph_limit = config.passage_paragraphs_per_book; + let code_intro_downloads_enabled = config.code_downloads_enabled; + let code_intro_download_dir = config.code_download_dir.clone(); + let code_intro_snippets_per_repo = config.code_snippets_per_repo; let mut app = Self { screen: AppScreen::Menu, @@ -211,6 +246,7 @@ impl App { skill_tree_detail_scroll: 0, drill_source_info: None, code_language_selected: 0, + code_language_scroll: 0, passage_book_selected: 0, passage_intro_selected: 0, passage_intro_downloads_enabled: intro_downloads_enabled, @@ -225,12 +261,27 @@ impl App { passage_download_queue: Vec::new(), passage_drill_selection_override: None, passage_download_action: PassageDownloadCompleteAction::StartPassageDrill, + 
code_intro_selected: 0, + code_intro_downloads_enabled, + code_intro_download_dir, + code_intro_snippets_per_repo, + code_intro_downloading: false, + code_intro_download_total: 0, + code_intro_downloaded: 0, + code_intro_current_repo: String::new(), + code_intro_download_bytes: 0, + code_intro_download_bytes_total: 0, + code_download_queue: Vec::new(), + code_drill_language_override: None, + code_download_attempted: false, + code_download_action: CodeDownloadCompleteAction::StartCodeDrill, shift_held: false, keyboard_model, rng: SmallRng::from_entropy(), transition_table, dictionary, passage_download_job: None, + code_download_job: None, }; app.start_drill(); app @@ -368,15 +419,17 @@ impl App { } DrillMode::Code => { let filter = CharFilter::new(('a'..='z').collect()); - let lang = if self.config.code_language == "all" { - let langs = ["rust", "python", "javascript", "go"]; - let idx = self.rng.gen_range(0..langs.len()); - langs[idx].to_string() - } else { - self.config.code_language.clone() - }; + let lang = self + .code_drill_language_override + .clone() + .unwrap_or_else(|| self.config.code_language.clone()); let rng = SmallRng::from_rng(&mut self.rng).unwrap(); - let mut generator = CodeSyntaxGenerator::new(rng, &lang); + let mut generator = CodeSyntaxGenerator::new( + rng, + &lang, + &self.config.code_download_dir, + ); + self.code_drill_language_override = None; let text = generator.generate(&filter, None, word_count); (text, Some(generator.last_source().to_string())) } @@ -648,11 +701,13 @@ impl App { } pub fn go_to_code_language_select(&mut self) { - let langs = ["rust", "python", "javascript", "go", "all"]; - self.code_language_selected = langs + let options = code_language_options(); + self.code_language_selected = options .iter() - .position(|&l| l == self.config.code_language) + .position(|(k, _)| *k == self.config.code_language) .unwrap_or(0); + // Center the selected item in the viewport (rough estimate of 15 visible rows) + self.code_language_scroll 
= self.code_language_selected.saturating_sub(7); self.screen = AppScreen::CodeLanguageSelect; } @@ -689,6 +744,215 @@ impl App { self.screen = AppScreen::PassageIntro; } + pub fn go_to_code_intro(&mut self) { + self.code_intro_selected = 0; + self.code_intro_downloads_enabled = self.config.code_downloads_enabled; + self.code_intro_download_dir = self.config.code_download_dir.clone(); + self.code_intro_snippets_per_repo = self.config.code_snippets_per_repo; + self.code_intro_downloading = false; + self.code_intro_download_total = 0; + self.code_intro_downloaded = 0; + self.code_intro_current_repo.clear(); + self.code_intro_download_bytes = 0; + self.code_intro_download_bytes_total = 0; + self.code_download_queue.clear(); + self.code_download_job = None; + self.code_download_action = CodeDownloadCompleteAction::StartCodeDrill; + self.code_download_attempted = false; + self.screen = AppScreen::CodeIntro; + } + + pub fn start_code_drill(&mut self) { + // Step 1: Resolve concrete language (never download with "all" selected) + if self.code_drill_language_override.is_none() { + let chosen = if self.config.code_language == "all" { + let available = languages_with_content(&self.config.code_download_dir); + if available.is_empty() { + "rust".to_string() + } else { + let idx = self.rng.gen_range(0..available.len()); + available[idx].to_string() + } + } else { + self.config.code_language.clone() + }; + self.code_drill_language_override = Some(chosen); + } + + let chosen = self.code_drill_language_override.clone().unwrap(); + + // Step 2: Check if we need to download (only if not already attempted) + if self.config.code_downloads_enabled + && !self.code_download_attempted + && !is_language_cached(&self.config.code_download_dir, &chosen) + { + if let Some(lang) = language_by_key(&chosen) { + if !lang.repos.is_empty() { + let repo_idx = self.rng.gen_range(0..lang.repos.len()); + self.code_download_queue = vec![(chosen.clone(), repo_idx)]; + self.code_intro_download_total = 1; + 
self.code_intro_downloaded = 0; + self.code_intro_downloading = true; + self.code_intro_current_repo = lang.repos[repo_idx].key.to_string(); + self.code_download_action = CodeDownloadCompleteAction::StartCodeDrill; + self.code_download_job = None; + self.code_download_attempted = true; + self.screen = AppScreen::CodeDownloadProgress; + return; + } + } + } + + // Step 3: If language has no built-in AND no cache → fallback + if !is_language_cached(&self.config.code_download_dir, &chosen) { + if let Some(lang) = language_by_key(&chosen) { + if !lang.has_builtin { + self.code_drill_language_override = Some("rust".to_string()); + } + } + } + + // Step 4: Start the drill + self.code_download_attempted = false; + self.drill_mode = DrillMode::Code; + self.drill_scope = DrillScope::Global; + self.start_drill(); + } + + pub fn start_code_downloads(&mut self) { + let queue = build_code_download_queue( + &self.config.code_language, + &self.code_intro_download_dir, + ); + + self.code_intro_download_total = queue.len(); + self.code_download_queue = queue; + self.code_intro_downloaded = 0; + self.code_intro_downloading = self.code_intro_download_total > 0; + self.code_intro_download_bytes = 0; + self.code_intro_download_bytes_total = 0; + self.code_download_job = None; + } + + pub fn start_code_downloads_from_settings(&mut self) { + self.go_to_code_intro(); + self.code_download_action = CodeDownloadCompleteAction::ReturnToSettings; + self.start_code_downloads(); + if !self.code_intro_downloading { + self.go_to_settings(); + } + } + + pub fn process_code_download_tick(&mut self) { + if !self.code_intro_downloading { + return; + } + + if self.code_download_job.is_none() { + let Some((lang_key, repo_idx)) = self.code_download_queue.pop() else { + self.code_intro_downloading = false; + self.code_intro_current_repo.clear(); + match self.code_download_action { + CodeDownloadCompleteAction::StartCodeDrill => self.start_code_drill(), + CodeDownloadCompleteAction::ReturnToSettings => 
self.go_to_settings(), + } + return; + }; + + self.spawn_code_download_job(&lang_key, repo_idx); + return; + } + + let mut finished = false; + if let Some(job) = self.code_download_job.as_mut() { + self.code_intro_download_bytes = job.downloaded_bytes.load(Ordering::Relaxed); + self.code_intro_download_bytes_total = job.total_bytes.load(Ordering::Relaxed); + finished = job.done.load(Ordering::Relaxed); + } + + if !finished { + return; + } + + if let Some(mut job) = self.code_download_job.take() { + if let Some(handle) = job.handle.take() { + let _ = handle.join(); + } + self.code_intro_downloaded = self.code_intro_downloaded.saturating_add(1); + } + + if self.code_intro_downloaded >= self.code_intro_download_total { + self.code_intro_downloading = false; + self.code_intro_current_repo.clear(); + self.code_intro_download_bytes = 0; + self.code_intro_download_bytes_total = 0; + match self.code_download_action { + CodeDownloadCompleteAction::StartCodeDrill => self.start_code_drill(), + CodeDownloadCompleteAction::ReturnToSettings => self.go_to_settings(), + } + } + } + + fn spawn_code_download_job(&mut self, language_key: &str, repo_idx: usize) { + let Some(lang) = language_by_key(language_key) else { + return; + }; + let Some(repo) = lang.repos.get(repo_idx) else { + return; + }; + + self.code_intro_current_repo = repo.key.to_string(); + self.code_intro_download_bytes = 0; + self.code_intro_download_bytes_total = 0; + + let downloaded_bytes = Arc::new(AtomicU64::new(0)); + let total_bytes = Arc::new(AtomicU64::new(0)); + let done = Arc::new(AtomicBool::new(false)); + let success = Arc::new(AtomicBool::new(false)); + + let dl_clone = Arc::clone(&downloaded_bytes); + let total_clone = Arc::clone(&total_bytes); + let done_clone = Arc::clone(&done); + let success_clone = Arc::clone(&success); + + let cache_dir = self.code_intro_download_dir.clone(); + let lang_key = language_key.to_string(); + let snippets_limit = self.code_intro_snippets_per_repo; + + // Get static 
references for thread + let repo_ref: &'static crate::generator::code_syntax::CodeRepo = + &lang.repos[repo_idx]; + let block_style_ref: &'static crate::generator::code_syntax::BlockStyle = + &lang.block_style; + + let handle = thread::spawn(move || { + let ok = download_code_repo_to_cache_with_progress( + &cache_dir, + &lang_key, + repo_ref, + block_style_ref, + snippets_limit, + |downloaded, total| { + dl_clone.store(downloaded, Ordering::Relaxed); + if let Some(total) = total { + total_clone.store(total, Ordering::Relaxed); + } + }, + ); + + success_clone.store(ok, Ordering::Relaxed); + done_clone.store(true, Ordering::Relaxed); + }); + + self.code_download_job = Some(DownloadJob { + downloaded_bytes, + total_bytes, + done, + success, + handle: Some(handle), + }); + } + pub fn start_passage_drill(&mut self) { // Lazy source selection: choose a specific source for this drill and // download exactly one missing book when needed. @@ -765,6 +1029,14 @@ impl App { self.passage_download_job = None; } + pub fn cancel_code_download(&mut self) { + self.code_download_queue.clear(); + self.code_intro_downloading = false; + self.code_download_job = None; + self.code_drill_language_override = None; + self.code_download_attempted = false; + } + pub fn start_passage_downloads_from_settings(&mut self) { self.go_to_passage_intro(); self.passage_download_action = PassageDownloadCompleteAction::ReturnToSettings; @@ -867,7 +1139,7 @@ impl App { done_clone.store(true, Ordering::Relaxed); }); - self.passage_download_job = Some(PassageDownloadJob { + self.passage_download_job = Some(DownloadJob { downloaded_bytes, total_bytes, done, @@ -900,21 +1172,37 @@ impl App { self.config.word_count = (self.config.word_count + 5).min(100); } 3 => { - let langs = ["rust", "python", "javascript", "go", "all"]; - let idx = langs + let options = code_language_options(); + let keys: Vec<&str> = options.iter().map(|(k, _)| *k).collect(); + let idx = keys .iter() .position(|&l| l == 
self.config.code_language) .unwrap_or(0); - let next = (idx + 1) % langs.len(); - self.config.code_language = langs[next].to_string(); + let next = (idx + 1) % keys.len(); + self.config.code_language = keys[next].to_string(); } 4 => { - self.config.passage_downloads_enabled = !self.config.passage_downloads_enabled; + self.config.code_downloads_enabled = !self.config.code_downloads_enabled; } 5 => { // Editable text field handled directly in key handler. } 6 => { + self.config.code_snippets_per_repo = + match self.config.code_snippets_per_repo { + 0 => 1, + n if n >= 200 => 0, + n => n + 10, + }; + } + // 7 = Download Code Now (action button) + 8 => { + self.config.passage_downloads_enabled = !self.config.passage_downloads_enabled; + } + 9 => { + // Passage download dir - editable text field handled directly in key handler. + } + 10 => { self.config.passage_paragraphs_per_book = match self.config.passage_paragraphs_per_book { 0 => 1, @@ -950,21 +1238,37 @@ impl App { self.config.word_count = self.config.word_count.saturating_sub(5).max(5); } 3 => { - let langs = ["rust", "python", "javascript", "go", "all"]; - let idx = langs + let options = code_language_options(); + let keys: Vec<&str> = options.iter().map(|(k, _)| *k).collect(); + let idx = keys .iter() .position(|&l| l == self.config.code_language) .unwrap_or(0); - let next = if idx == 0 { langs.len() - 1 } else { idx - 1 }; - self.config.code_language = langs[next].to_string(); + let next = if idx == 0 { keys.len() - 1 } else { idx - 1 }; + self.config.code_language = keys[next].to_string(); } 4 => { - self.config.passage_downloads_enabled = !self.config.passage_downloads_enabled; + self.config.code_downloads_enabled = !self.config.code_downloads_enabled; } 5 => { // Editable text field handled directly in key handler. 
} 6 => { + self.config.code_snippets_per_repo = + match self.config.code_snippets_per_repo { + 0 => 200, + 1 => 0, + n => n.saturating_sub(10).max(1), + }; + } + // 7 = Download Code Now (action button) + 8 => { + self.config.passage_downloads_enabled = !self.config.passage_downloads_enabled; + } + 9 => { + // Passage download dir - editable text field handled directly in key handler. + } + 10 => { self.config.passage_paragraphs_per_book = match self.config.passage_paragraphs_per_book { 0 => 500, diff --git a/src/config.rs b/src/config.rs index 4d87264..c1fe858 100644 --- a/src/config.rs +++ b/src/config.rs @@ -26,6 +26,14 @@ pub struct Config { pub passage_paragraphs_per_book: usize, #[serde(default = "default_passage_onboarding_done")] pub passage_onboarding_done: bool, + #[serde(default = "default_code_downloads_enabled")] + pub code_downloads_enabled: bool, + #[serde(default = "default_code_download_dir")] + pub code_download_dir: String, + #[serde(default = "default_code_snippets_per_repo")] + pub code_snippets_per_repo: usize, + #[serde(default = "default_code_onboarding_done")] + pub code_onboarding_done: bool, } fn default_target_wpm() -> u32 { @@ -63,6 +71,23 @@ fn default_passage_paragraphs_per_book() -> usize { fn default_passage_onboarding_done() -> bool { false } +fn default_code_downloads_enabled() -> bool { + false +} +fn default_code_download_dir() -> String { + dirs::data_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join("keydr") + .join("code") + .to_string_lossy() + .to_string() +} +fn default_code_snippets_per_repo() -> usize { + 200 +} +fn default_code_onboarding_done() -> bool { + false +} impl Default for Config { fn default() -> Self { @@ -77,6 +102,10 @@ impl Default for Config { passage_download_dir: default_passage_download_dir(), passage_paragraphs_per_book: default_passage_paragraphs_per_book(), passage_onboarding_done: default_passage_onboarding_done(), + code_downloads_enabled: default_code_downloads_enabled(), + 
code_download_dir: default_code_download_dir(), + code_snippets_per_repo: default_code_snippets_per_repo(), + code_onboarding_done: default_code_onboarding_done(), } } } @@ -114,4 +143,97 @@ impl Config { pub fn target_cpm(&self) -> f64 { self.target_wpm as f64 * 5.0 } + + /// Validate `code_language` against known options, resetting to default if invalid. + /// Call after deserialization to handle stale/renamed keys from old configs. + pub fn normalize_code_language(&mut self, valid_keys: &[&str]) { + // Backwards compatibility: old "shell" key is now "bash". + if self.code_language == "shell" { + self.code_language = "bash".to_string(); + } + if !valid_keys.contains(&self.code_language.as_str()) { + self.code_language = default_code_language(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_serde_defaults_from_empty() { + // Simulates loading an old config file with no code drill fields + let config: Config = toml::from_str("").unwrap(); + assert_eq!(config.code_downloads_enabled, false); + assert_eq!(config.code_snippets_per_repo, 200); + assert_eq!(config.code_onboarding_done, false); + assert!(!config.code_download_dir.is_empty()); + assert!(config.code_download_dir.contains("code")); + } + + #[test] + fn test_config_serde_defaults_from_old_fields_only() { + // Simulates a config file that only has pre-existing fields + let toml_str = r#" +target_wpm = 60 +theme = "monokai" +code_language = "go" +"#; + let config: Config = toml::from_str(toml_str).unwrap(); + assert_eq!(config.target_wpm, 60); + assert_eq!(config.theme, "monokai"); + assert_eq!(config.code_language, "go"); + // New fields should have defaults + assert_eq!(config.code_downloads_enabled, false); + assert_eq!(config.code_snippets_per_repo, 200); + assert_eq!(config.code_onboarding_done, false); + } + + #[test] + fn test_config_serde_roundtrip() { + let config = Config::default(); + let serialized = toml::to_string_pretty(&config).unwrap(); + let deserialized: 
Config = toml::from_str(&serialized).unwrap(); + assert_eq!(config.code_downloads_enabled, deserialized.code_downloads_enabled); + assert_eq!(config.code_download_dir, deserialized.code_download_dir); + assert_eq!(config.code_snippets_per_repo, deserialized.code_snippets_per_repo); + assert_eq!(config.code_onboarding_done, deserialized.code_onboarding_done); + } + + #[test] + fn test_normalize_code_language_valid_key_unchanged() { + let mut config = Config::default(); + config.code_language = "python".to_string(); + let valid_keys = vec!["rust", "python", "javascript", "go", "all"]; + config.normalize_code_language(&valid_keys); + assert_eq!(config.code_language, "python"); + } + + #[test] + fn test_normalize_code_language_invalid_key_resets() { + let mut config = Config::default(); + config.code_language = "haskell".to_string(); + let valid_keys = vec!["rust", "python", "javascript", "go", "all"]; + config.normalize_code_language(&valid_keys); + assert_eq!(config.code_language, "rust"); + } + + #[test] + fn test_normalize_code_language_empty_string_resets() { + let mut config = Config::default(); + config.code_language = String::new(); + let valid_keys = vec!["rust", "python", "javascript", "go", "all"]; + config.normalize_code_language(&valid_keys); + assert_eq!(config.code_language, "rust"); + } + + #[test] + fn test_normalize_code_language_shell_maps_to_bash() { + let mut config = Config::default(); + config.code_language = "shell".to_string(); + let valid_keys = vec!["rust", "python", "javascript", "go", "bash", "all"]; + config.normalize_code_language(&valid_keys); + assert_eq!(config.code_language, "bash"); + } } diff --git a/src/generator/cache.rs b/src/generator/cache.rs index e4c0709..384e4d6 100644 --- a/src/generator/cache.rs +++ b/src/generator/cache.rs @@ -3,10 +3,12 @@ use std::fs; use std::io::Read; use std::path::PathBuf; +#[allow(dead_code)] pub struct DiskCache { base_dir: PathBuf, } +#[allow(dead_code)] impl DiskCache { pub fn new(subdir: &str) 
-> Option { let base = dirs::data_dir()?.join("keydr").join(subdir); @@ -37,6 +39,7 @@ impl DiskCache { } } +#[allow(dead_code)] #[cfg(feature = "network")] pub fn fetch_url(url: &str) -> Option { let client = reqwest::blocking::Client::builder() @@ -51,6 +54,7 @@ pub fn fetch_url(url: &str) -> Option { } } +#[allow(dead_code)] #[cfg(not(feature = "network"))] pub fn fetch_url(_url: &str) -> Option { None diff --git a/src/generator/code_syntax.rs b/src/generator/code_syntax.rs index 12883a8..8cafe8e 100644 --- a/src/generator/code_syntax.rs +++ b/src/generator/code_syntax.rs @@ -1,9 +1,786 @@ +use std::fs; + use rand::Rng; use rand::rngs::SmallRng; use crate::engine::filter::CharFilter; use crate::generator::TextGenerator; -use crate::generator::cache::{DiskCache, fetch_url}; +use crate::generator::cache::fetch_url_bytes_with_progress; + +pub enum BlockStyle { + Braces(&'static [&'static str]), + Indentation(&'static [&'static str]), + EndDelimited(&'static [&'static str]), +} + +pub struct CodeLanguage { + pub key: &'static str, + pub display_name: &'static str, + #[allow(dead_code)] + pub extensions: &'static [&'static str], + pub repos: &'static [CodeRepo], + pub has_builtin: bool, + pub block_style: BlockStyle, +} + +pub struct CodeRepo { + pub key: &'static str, + pub urls: &'static [&'static str], +} + +pub const CODE_LANGUAGES: &[CodeLanguage] = &[ + // === Built-in languages (has_builtin: true) === + CodeLanguage { + key: "rust", + display_name: "Rust", + extensions: &[".rs"], + repos: &[ + CodeRepo { + key: "tokio", + urls: &[ + "https://raw.githubusercontent.com/tokio-rs/tokio/master/tokio/src/sync/mutex.rs", + "https://raw.githubusercontent.com/tokio-rs/tokio/master/tokio/src/net/tcp/stream.rs", + ], + }, + CodeRepo { + key: "ripgrep", + urls: &[ + "https://raw.githubusercontent.com/BurntSushi/ripgrep/master/crates/regex/src/config.rs", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::Braces(&[ + "fn ", "pub fn ", "async fn ", "pub async fn 
", "impl ", "trait ", "struct ", "enum ", + "macro_rules! ", "mod ", "const ", "static ", "type ", "pub struct ", "pub enum ", + "pub trait ", "pub mod ", "pub const ", "pub static ", "pub type ", + ]), + }, + CodeLanguage { + key: "python", + display_name: "Python", + extensions: &[".py", ".pyi"], + repos: &[ + CodeRepo { + key: "cpython", + urls: &[ + "https://raw.githubusercontent.com/python/cpython/main/Lib/json/encoder.py", + "https://raw.githubusercontent.com/python/cpython/main/Lib/pathlib/__init__.py", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::Indentation(&["def ", "class ", "async def ", "@"]), + }, + CodeLanguage { + key: "javascript", + display_name: "JavaScript", + extensions: &[".js", ".mjs"], + repos: &[ + CodeRepo { + key: "node-stdlib", + urls: &[ + "https://raw.githubusercontent.com/nodejs/node/main/lib/path.js", + "https://raw.githubusercontent.com/nodejs/node/main/lib/url.js", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::Braces(&[ + "function ", + "async function ", + "const ", + "class ", + "export function ", + "export default function ", + "let ", + "export ", + ]), + }, + CodeLanguage { + key: "go", + display_name: "Go", + extensions: &[".go"], + repos: &[ + CodeRepo { + key: "go-stdlib", + urls: &[ + "https://raw.githubusercontent.com/golang/go/master/src/fmt/print.go", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::Braces(&["func ", "type "]), + }, + CodeLanguage { + key: "typescript", + display_name: "TypeScript", + extensions: &[".ts", ".tsx"], + repos: &[ + CodeRepo { + key: "ts-node", + urls: &[ + "https://raw.githubusercontent.com/TypeStrong/ts-node/main/src/index.ts", + ], + }, + CodeRepo { + key: "deno-std", + urls: &[ + "https://raw.githubusercontent.com/denoland/std/main/path/posix/normalize.ts", + "https://raw.githubusercontent.com/denoland/std/main/fs/walk.ts", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::Braces(&[ + "function ", + "export function ", + 
"async function ", + "const ", + "class ", + "interface ", + "type ", + "export default function ", + "let ", + "export ", + ]), + }, + CodeLanguage { + key: "java", + display_name: "Java", + extensions: &[".java"], + repos: &[ + CodeRepo { + key: "guava", + urls: &[ + "https://raw.githubusercontent.com/google/guava/master/guava/src/com/google/common/collect/ImmutableList.java", + "https://raw.githubusercontent.com/google/guava/master/guava/src/com/google/common/base/Preconditions.java", + ], + }, + CodeRepo { + key: "gson", + urls: &[ + "https://raw.githubusercontent.com/google/gson/main/gson/src/main/java/com/google/gson/Gson.java", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::Braces(&[ + "public ", + "private ", + "protected ", + "static ", + "class ", + "interface ", + "void ", + "int ", + "String ", + "boolean ", + "@", + "abstract ", + "final ", + ]), + }, + CodeLanguage { + key: "c", + display_name: "C", + extensions: &[".c", ".h"], + repos: &[ + CodeRepo { + key: "redis", + urls: &[ + "https://raw.githubusercontent.com/redis/redis/unstable/src/server.c", + "https://raw.githubusercontent.com/redis/redis/unstable/src/networking.c", + ], + }, + CodeRepo { + key: "jq", + urls: &[ + "https://raw.githubusercontent.com/jqlang/jq/master/src/builtin.c", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::Braces(&[ + "int ", + "void ", + "char ", + "float ", + "double ", + "struct ", + "unsigned ", + "static ", + "const ", + "typedef ", + "#define ", + "enum ", + ]), + }, + CodeLanguage { + key: "cpp", + display_name: "C++", + extensions: &[".cpp", ".hpp", ".cc", ".cxx"], + repos: &[ + CodeRepo { + key: "json", + urls: &[ + "https://raw.githubusercontent.com/nlohmann/json/develop/include/nlohmann/json.hpp", + ], + }, + CodeRepo { + key: "fmt", + urls: &[ + "https://raw.githubusercontent.com/fmtlib/fmt/master/include/fmt/format.h", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::Braces(&[ + "int ", + "void ", + "char ", + 
"auto ", + "class ", + "struct ", + "template", + "namespace ", + "virtual ", + "static ", + "const ", + "typedef ", + "#define ", + "enum ", + "constexpr ", + ]), + }, + CodeLanguage { + key: "ruby", + display_name: "Ruby", + extensions: &[".rb"], + repos: &[ + CodeRepo { + key: "rake", + urls: &[ + "https://raw.githubusercontent.com/ruby/rake/master/lib/rake/task.rb", + "https://raw.githubusercontent.com/ruby/rake/master/lib/rake/application.rb", + ], + }, + CodeRepo { + key: "sinatra", + urls: &[ + "https://raw.githubusercontent.com/sinatra/sinatra/main/lib/sinatra/base.rb", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::EndDelimited(&[ + "def ", "class ", "module ", "attr_", "scope ", "describe ", "it ", + ]), + }, + CodeLanguage { + key: "swift", + display_name: "Swift", + extensions: &[".swift"], + repos: &[ + CodeRepo { + key: "swift-algorithms", + urls: &[ + "https://raw.githubusercontent.com/apple/swift-algorithms/main/Sources/Algorithms/Chunked.swift", + "https://raw.githubusercontent.com/apple/swift-algorithms/main/Sources/Algorithms/Combinations.swift", + ], + }, + CodeRepo { + key: "swift-nio", + urls: &[ + "https://raw.githubusercontent.com/apple/swift-nio/main/Sources/NIOCore/Channel.swift", + "https://raw.githubusercontent.com/apple/swift-nio/main/Sources/NIOCore/EventLoop.swift", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::Braces(&[ + "func ", + "class ", + "struct ", + "enum ", + "protocol ", + "var ", + "let ", + "init(", + "deinit ", + "extension ", + "typealias ", + ]), + }, + CodeLanguage { + key: "bash", + display_name: "Bash", + extensions: &[".sh", ".bash"], + repos: &[ + CodeRepo { + key: "nvm", + urls: &[ + "https://raw.githubusercontent.com/nvm-sh/nvm/master/nvm.sh", + ], + }, + CodeRepo { + key: "oh-my-zsh", + urls: &[ + "https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/lib/functions.zsh", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::Braces(&["function ", "if ", "for ", "while 
", "case "]), + }, + CodeLanguage { + key: "lua", + display_name: "Lua", + extensions: &[".lua"], + repos: &[ + CodeRepo { + key: "kong", + urls: &[ + "https://raw.githubusercontent.com/Kong/kong/master/kong/init.lua", + ], + }, + CodeRepo { + key: "luarocks", + urls: &[ + "https://raw.githubusercontent.com/luarocks/luarocks/master/src/luarocks/core/cfg.lua", + ], + }, + ], + has_builtin: true, + block_style: BlockStyle::EndDelimited(&["function ", "local function "]), + }, + // === Network-only languages (has_builtin: false) === + CodeLanguage { + key: "kotlin", + display_name: "Kotlin", + extensions: &[".kt", ".kts"], + repos: &[ + CodeRepo { + key: "kotlinx-coroutines", + urls: &[ + "https://raw.githubusercontent.com/Kotlin/kotlinx.coroutines/master/kotlinx-coroutines-core/common/src/flow/Builders.kt", + "https://raw.githubusercontent.com/Kotlin/kotlinx.coroutines/master/kotlinx-coroutines-core/common/src/channels/Channel.kt", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Braces(&[ + "fun ", "class ", "object ", "interface ", "suspend fun ", + "public ", "private ", "internal ", "override fun ", "open ", + "data class ", "sealed ", "abstract ", + "val ", "var ", "enum ", "annotation ", "typealias ", + ]), + }, + CodeLanguage { + key: "scala", + display_name: "Scala", + extensions: &[".scala"], + repos: &[ + CodeRepo { + key: "scala-stdlib", + urls: &[ + "https://raw.githubusercontent.com/scala/scala/2.13.x/src/library/scala/collection/immutable/List.scala", + "https://raw.githubusercontent.com/scala/scala/2.13.x/src/library/scala/collection/mutable/HashMap.scala", + "https://raw.githubusercontent.com/scala/scala/2.13.x/src/library/scala/Option.scala", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Braces(&[ + "def ", "class ", "object ", "trait ", "case class ", + "val ", "var ", "type ", "implicit ", "given ", "extension ", + ]), + }, + CodeLanguage { + key: "csharp", + display_name: "C#", + extensions: &[".cs"], + repos: &[ + 
CodeRepo { + key: "aspnetcore", + urls: &[ + "https://raw.githubusercontent.com/dotnet/aspnetcore/main/src/Http/Http.Abstractions/src/HttpContext.cs", + ], + }, + CodeRepo { + key: "roslyn", + urls: &[ + "https://raw.githubusercontent.com/dotnet/roslyn/main/src/Compilers/CSharp/Portable/Syntax/SyntaxFactory.cs", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Braces(&[ + "public ", + "private ", + "protected ", + "internal ", + "static ", + "class ", + "interface ", + "void ", + "async ", + ]), + }, + CodeLanguage { + key: "php", + display_name: "PHP", + extensions: &[".php"], + repos: &[ + CodeRepo { + key: "wordpress", + urls: &[ + "https://raw.githubusercontent.com/WordPress/WordPress/master/wp-includes/formatting.php", + ], + }, + CodeRepo { + key: "symfony", + urls: &[ + "https://raw.githubusercontent.com/symfony/symfony/7.2/src/Symfony/Component/HttpFoundation/Request.php", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Braces(&[ + "function ", + "public function ", + "private function ", + "protected function ", + "class ", + "interface ", + "trait ", + "enum ", + ]), + }, + CodeLanguage { + key: "dart", + display_name: "Dart", + extensions: &[".dart"], + repos: &[ + CodeRepo { + key: "flutter", + urls: &[ + "https://raw.githubusercontent.com/flutter/flutter/master/packages/flutter/lib/src/widgets/framework.dart", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Braces(&[ + "void ", "Future ", "Future<", "class ", "int ", "String ", "bool ", "static ", "factory ", + "Widget ", "get ", "set ", "enum ", "typedef ", "extension ", + ]), + }, + CodeLanguage { + key: "elixir", + display_name: "Elixir", + extensions: &[".ex", ".exs"], + repos: &[ + CodeRepo { + key: "phoenix", + urls: &[ + "https://raw.githubusercontent.com/phoenixframework/phoenix/main/lib/phoenix/router.ex", + ], + }, + CodeRepo { + key: "elixir-lang", + urls: &[ + "https://raw.githubusercontent.com/elixir-lang/elixir/main/lib/elixir/lib/enum.ex", + 
], + }, + ], + has_builtin: false, + block_style: BlockStyle::EndDelimited(&[ + "def ", "defp ", "defmodule ", + "defmacro ", "defstruct", "defprotocol ", "defimpl ", + ]), + }, + CodeLanguage { + key: "perl", + display_name: "Perl", + extensions: &[".pl", ".pm"], + repos: &[ + CodeRepo { + key: "mojolicious", + urls: &[ + "https://raw.githubusercontent.com/mojolicious/mojo/main/lib/Mojolicious.pm", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Braces(&["sub "]), + }, + CodeLanguage { + key: "zig", + display_name: "Zig", + extensions: &[".zig"], + repos: &[ + CodeRepo { + key: "zig-stdlib", + urls: &[ + "https://raw.githubusercontent.com/ziglang/zig/master/lib/std/mem.zig", + "https://raw.githubusercontent.com/ziglang/zig/master/lib/std/fmt.zig", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Braces(&["pub fn ", "fn ", "const ", "pub const ", "test ", "var "]), + }, + CodeLanguage { + key: "julia", + display_name: "Julia", + extensions: &[".jl"], + repos: &[ + CodeRepo { + key: "julia-stdlib", + urls: &[ + "https://raw.githubusercontent.com/JuliaLang/julia/master/base/array.jl", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::EndDelimited(&["function ", "macro "]), + }, + CodeLanguage { + key: "nim", + display_name: "Nim", + extensions: &[".nim"], + repos: &[ + CodeRepo { + key: "nim-stdlib", + urls: &[ + "https://raw.githubusercontent.com/nim-lang/Nim/devel/lib/pure/strutils.nim", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Indentation(&["proc ", "func ", "method ", "type "]), + }, + CodeLanguage { + key: "ocaml", + display_name: "OCaml", + extensions: &[".ml", ".mli"], + repos: &[ + CodeRepo { + key: "ocaml-stdlib", + urls: &[ + "https://raw.githubusercontent.com/ocaml/ocaml/trunk/stdlib/list.ml", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Indentation(&["let ", "type ", "module "]), + }, + CodeLanguage { + key: "haskell", + display_name: "Haskell", + extensions: &[".hs"], 
+ repos: &[ + CodeRepo { + key: "aeson", + urls: &[ + "https://raw.githubusercontent.com/haskell/aeson/master/src/Data/Aeson/Types/Internal.hs", + ], + }, + CodeRepo { + key: "xmonad", + urls: &[ + "https://raw.githubusercontent.com/xmonad/xmonad/master/src/XMonad/Operations.hs", + ], + }, + ], + has_builtin: false, + // Haskell: top-level declarations are indented blocks + block_style: BlockStyle::Indentation(&[ + "data ", "type ", "class ", "instance ", "newtype ", "module ", + ]), + }, + CodeLanguage { + key: "clojure", + display_name: "Clojure", + extensions: &[".clj", ".cljs"], + repos: &[ + CodeRepo { + key: "clojure-core", + urls: &[ + "https://raw.githubusercontent.com/clojure/clojure/master/src/clj/clojure/core.clj", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Indentation(&["(defn ", "(defn- ", "(defmacro "]), + }, + CodeLanguage { + key: "r", + display_name: "R", + extensions: &[".r", ".R"], + repos: &[ + CodeRepo { + key: "shiny", + urls: &[ + "https://raw.githubusercontent.com/rstudio/shiny/main/R/bootstrap.R", + "https://raw.githubusercontent.com/rstudio/shiny/main/R/input-text.R", + ], + }, + ], + has_builtin: false, + // R functions are defined as `name <- function(...)`. Since our extractor only + // supports `starts_with`, we match roxygen doc blocks that precede functions. + block_style: BlockStyle::Braces(&["#' "]), + }, + CodeLanguage { + key: "erlang", + display_name: "Erlang", + extensions: &[".erl"], + repos: &[ + CodeRepo { + key: "cowboy", + urls: &[ + "https://raw.githubusercontent.com/ninenines/cowboy/master/src/cowboy_req.erl", + "https://raw.githubusercontent.com/ninenines/cowboy/master/src/cowboy_http.erl", + ], + }, + ], + has_builtin: false, + // Erlang: -spec and -record use braces for types/fields. + // Erlang functions themselves don't use braces (they end with `.`), + // so extraction is limited to type specs and records. 
+ block_style: BlockStyle::Braces(&[ + "-spec ", "-record(", "-type ", "-callback ", + ]), + }, + CodeLanguage { + key: "groovy", + display_name: "Groovy", + extensions: &[".groovy"], + repos: &[ + CodeRepo { + key: "nextflow", + urls: &[ + "https://raw.githubusercontent.com/nextflow-io/nextflow/master/modules/nextflow/src/main/groovy/nextflow/processor/TaskProcessor.groovy", + "https://raw.githubusercontent.com/nextflow-io/nextflow/master/modules/nextflow/src/main/groovy/nextflow/Session.groovy", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Braces(&["def ", "void ", "static ", "public ", "private "]), + }, + CodeLanguage { + key: "fsharp", + display_name: "F#", + extensions: &[".fs", ".fsx"], + repos: &[ + CodeRepo { + key: "fsharp-compiler", + urls: &[ + "https://raw.githubusercontent.com/dotnet/fsharp/main/src/Compiler/Utilities/lib.fs", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Indentation(&["let ", "member ", "type ", "module "]), + }, + CodeLanguage { + key: "objective-c", + display_name: "Objective-C", + extensions: &[".m", ".h"], + repos: &[ + CodeRepo { + key: "afnetworking", + urls: &[ + "https://raw.githubusercontent.com/AFNetworking/AFNetworking/master/AFNetworking/AFURLSessionManager.m", + ], + }, + ], + has_builtin: false, + block_style: BlockStyle::Braces(&[ + "- (", "+ (", "- (void)", "- (id)", "- (BOOL)", + "@interface ", "@implementation ", "@protocol ", "typedef ", + ]), + }, +]; + +/// Returns list of (key, display_name) for language selection UI. +pub fn code_language_options() -> Vec<(&'static str, String)> { + let mut options: Vec<(&'static str, String)> = CODE_LANGUAGES + .iter() + .map(|lang| (lang.key, lang.display_name.to_string())) + .collect(); + options.sort_by_key(|(_, display)| display.to_lowercase()); + options.insert(0, ("all", "All (random)".to_string())); + options +} + +/// Look up a language by its key. 
+pub fn language_by_key(key: &str) -> Option<&'static CodeLanguage> { + CODE_LANGUAGES.iter().find(|lang| lang.key == key) +} + +/// Check if any cached snippet files exist for a language. +pub fn is_language_cached(cache_dir: &str, key: &str) -> bool { + let dir = std::path::Path::new(cache_dir); + if !dir.is_dir() { + return false; + } + let prefix = format!("{}_", key); + if let Ok(entries) = fs::read_dir(dir) { + for entry in entries.flatten() { + let name = entry.file_name(); + let name = name.to_string_lossy(); + if name.starts_with(&prefix) && name.ends_with(".txt") { + if let Ok(meta) = entry.metadata() { + if meta.len() > 0 { + return true; + } + } + } + } + } + false +} + +/// Returns language keys that have either built-in snippets or cached content. +pub fn languages_with_content(cache_dir: &str) -> Vec<&'static str> { + CODE_LANGUAGES + .iter() + .filter(|lang| lang.has_builtin || is_language_cached(cache_dir, lang.key)) + .map(|lang| lang.key) + .collect() +} + +/// Build a download queue of `(language_key, repo_index)` pairs for uncached repos. +/// When `lang_key` is `"all"`, queues all uncached repos across all languages. 
+pub fn build_code_download_queue(lang_key: &str, cache_dir: &str) -> Vec<(String, usize)> { + let languages_to_download: Vec<&str> = if lang_key == "all" { + CODE_LANGUAGES.iter().map(|l| l.key).collect() + } else if language_by_key(lang_key).is_some() { + vec![lang_key] + } else { + vec![] + }; + + let mut queue: Vec<(String, usize)> = Vec::new(); + for lk in &languages_to_download { + if let Some(lang) = language_by_key(lk) { + for (repo_idx, repo) in lang.repos.iter().enumerate() { + let cache_path = std::path::Path::new(cache_dir) + .join(format!("{}_{}.txt", lang.key, repo.key)); + if !cache_path.exists() + || std::fs::metadata(&cache_path) + .map(|m| m.len() == 0) + .unwrap_or(true) + { + queue.push((lang.key.to_string(), repo_idx)); + } + } + } + } + queue +} pub struct CodeSyntaxGenerator { rng: SmallRng, @@ -13,14 +790,14 @@ pub struct CodeSyntaxGenerator { } impl CodeSyntaxGenerator { - pub fn new(rng: SmallRng, language: &str) -> Self { + pub fn new(rng: SmallRng, language: &str, cache_dir: &str) -> Self { let mut generator = Self { rng, language: language.to_string(), fetched_snippets: Vec::new(), last_source: "Built-in snippets".to_string(), }; - generator.load_cached_snippets(); + generator.load_cached_snippets(cache_dir); generator } @@ -28,152 +805,325 @@ impl CodeSyntaxGenerator { &self.last_source } - fn load_cached_snippets(&mut self) { - if let Some(cache) = DiskCache::new("code_cache") { - let key = format!("{}_snippets", self.language); - if let Some(content) = cache.get(&key) { - self.fetched_snippets = content - .split("\n---SNIPPET---\n") - .filter(|s| !s.trim().is_empty()) - .map(|s| s.to_string()) - .collect(); - } - } - } - - fn try_fetch_code(&mut self) { - let urls = match self.language.as_str() { - "rust" => vec![ - "https://raw.githubusercontent.com/tokio-rs/tokio/master/tokio/src/sync/mutex.rs", - "https://raw.githubusercontent.com/serde-rs/serde/master/serde/src/ser/mod.rs", - ], - "python" => vec![ - 
"https://raw.githubusercontent.com/python/cpython/main/Lib/json/encoder.py",
-                "https://raw.githubusercontent.com/python/cpython/main/Lib/pathlib/__init__.py",
-            ],
-            "javascript" | "js" => vec![
-                "https://raw.githubusercontent.com/lodash/lodash/main/src/chunk.ts",
-                "https://raw.githubusercontent.com/expressjs/express/master/lib/router/index.js",
-            ],
-            "go" => vec!["https://raw.githubusercontent.com/golang/go/master/src/fmt/print.go"],
-            _ => vec![],
-        };
-
-        let cache = match DiskCache::new("code_cache") {
-            Some(c) => c,
-            None => return,
-        };
-
-        let key = format!("{}_snippets", self.language);
-        if cache.get(&key).is_some() {
+    fn load_cached_snippets(&mut self, cache_dir: &str) {
+        let dir = std::path::Path::new(cache_dir);
+        if !dir.is_dir() {
             return;
         }
-
-        let mut all_snippets = Vec::new();
-        for url in urls {
-            if let Some(content) = fetch_url(url) {
-                let snippets = extract_code_snippets(&content);
-                all_snippets.extend(snippets);
+        let prefix = format!("{}_", self.language);
+        if let Ok(entries) = fs::read_dir(dir) {
+            for entry in entries.flatten() {
+                let name = entry.file_name();
+                let name_str = name.to_string_lossy();
+                if name_str.starts_with(&prefix) && name_str.ends_with(".txt") {
+                    if let Ok(content) = fs::read_to_string(entry.path()) {
+                        let snippets: Vec<String> = content
+                            .split("\n---SNIPPET---\n")
+                            .filter(|s| !s.trim().is_empty())
+                            .map(|s| s.to_string())
+                            .collect();
+                        self.fetched_snippets.extend(snippets);
+                    }
+                }
             }
         }
-
-        if !all_snippets.is_empty() {
-            let combined = all_snippets.join("\n---SNIPPET---\n");
-            cache.put(&key, &combined);
-            self.fetched_snippets = all_snippets;
-        }
     }

     fn rust_snippets() -> Vec<&'static str> {
         vec![
-            "fn main() {\n    println!(\"hello\");\n}",
-            "let mut x = 0;\nx += 1;",
-            "for i in 0..10 {\n    println!(\"{}\", i);\n}",
-            "if x > 0 {\n    return true;\n}",
-            "match val {\n    Some(x) => x,\n    None => 0,\n}",
-            "struct Point {\n    x: f64,\n    y: f64,\n}",
-            "impl Point {\n    fn new(x: f64, y: f64) -> Self {\n        Self { x, y }\n    }\n}",
+            r#"fn main() {
+    println!("hello");
+}"#,
+            r#"let mut x = 0;
+x += 1;"#,
+            r#"for i in 0..10 {
+    println!("{}", i);
+}"#,
+            r#"if x > 0 {
+    return true;
+}"#,
+            r#"match val {
+    Some(x) => x,
+    None => 0,
+}"#,
+            r#"struct Point {
+    x: f64,
+    y: f64,
+}"#,
+            r#"impl Point {
+    fn new(x: f64, y: f64) -> Self {
+        Self { x, y }
+    }
+}"#,
             "let v: Vec<i32> = vec![1, 2, 3];",
-            "fn add(a: i32, b: i32) -> i32 {\n    a + b\n}",
+            r#"fn add(a: i32, b: i32) -> i32 {
+    a + b
+}"#,
             "use std::collections::HashMap;",
-            "pub fn process(input: &str) -> Result<String, Error> {\n    Ok(input.to_string())\n}",
-            "let result = items\n    .iter()\n    .filter(|x| x > &0)\n    .map(|x| x * 2)\n    .collect::<Vec<_>>();",
-            "enum Color {\n    Red,\n    Green,\n    Blue,\n}",
-            "trait Display {\n    fn show(&self) -> String;\n}",
-            "while let Some(item) = stack.pop() {\n    process(item);\n}",
-            "#[derive(Debug, Clone)]\nstruct Config {\n    name: String,\n    value: i32,\n}",
-            "let handle = std::thread::spawn(|| {\n    println!(\"thread\");\n});",
-            "let mut map = HashMap::new();\nmap.insert(\"key\", 42);",
-            "fn factorial(n: u64) -> u64 {\n    if n <= 1 {\n        1\n    } else {\n        n * factorial(n - 1)\n    }\n}",
-            "impl Iterator for Counter {\n    type Item = u32;\n\n    fn next(&mut self) -> Option<Self::Item> {\n        None\n    }\n}",
-            "async fn fetch(url: &str) -> Result<String, reqwest::Error> {\n    let body = reqwest::get(url)\n        .await?\n        .text()\n        .await?;\n    Ok(body)\n}",
-            "let closure = |x: i32, y: i32| -> i32 {\n    x + y\n};",
-            "#[cfg(test)]\nmod tests {\n    use super::*;\n\n    #[test]\n    fn it_works() {\n        assert_eq!(2 + 2, 4);\n    }\n}",
-            "pub struct Builder {\n    name: Option<String>,\n}\n\nimpl Builder {\n    pub fn name(mut self, n: &str) -> Self {\n        self.name = Some(n.into());\n        self\n    }\n}",
-            "use std::sync::{Arc, Mutex};\nlet data = Arc::new(Mutex::new(vec![1, 2, 3]));",
-            "if let Ok(value) = \"42\".parse::<i32>() {\n    println!(\"parsed: {}\", value);\n}",
-            "fn longest<'a>(x: &'a str, y: &'a str) -> &'a str {\n    if x.len() > y.len() {\n        x\n    } else {\n        y\n    }\n}",
+            r#"pub fn process(input: &str) -> Result<String, Error> {
+    Ok(input.to_string())
+}"#,
+            r#"let result = items
+    .iter()
+    .filter(|x| x > &0)
+    .map(|x| x * 2)
+    .collect::<Vec<_>>();"#,
+            r#"enum Color {
+    Red,
+    Green,
+    Blue,
+}"#,
+            r#"trait Display {
+    fn show(&self) -> String;
+}"#,
+            r#"while let Some(item) = stack.pop() {
+    process(item);
+}"#,
+            r#"#[derive(Debug, Clone)]
+struct Config {
+    name: String,
+    value: i32,
+}"#,
+            r#"let handle = std::thread::spawn(|| {
+    println!("thread");
+});"#,
+            r#"let mut map = HashMap::new();
+map.insert("key", 42);"#,
+            r#"fn factorial(n: u64) -> u64 {
+    if n <= 1 {
+        1
+    } else {
+        n * factorial(n - 1)
+    }
+}"#,
+            r#"impl Iterator for Counter {
+    type Item = u32;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        None
+    }
+}"#,
+            r#"async fn fetch(url: &str) -> Result<String, reqwest::Error> {
+    let body = reqwest::get(url)
+        .await?
+        .text()
+        .await?;
+    Ok(body)
+}"#,
+            r#"let closure = |x: i32, y: i32| -> i32 {
+    x + y
+};"#,
+            r#"#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn it_works() {
+        assert_eq!(2 + 2, 4);
+    }
+}"#,
+            r#"pub struct Builder {
+    name: Option<String>,
+}
+
+impl Builder {
+    pub fn name(mut self, n: &str) -> Self {
+        self.name = Some(n.into());
+        self
+    }
+}"#,
+            r#"use std::sync::{Arc, Mutex};
+let data = Arc::new(Mutex::new(vec![1, 2, 3]));"#,
+            r#"if let Ok(value) = "42".parse::<i32>() {
+    println!("parsed: {}", value);
+}"#,
+            r#"fn longest<'a>(x: &'a str, y: &'a str) -> &'a str {
+    if x.len() > y.len() {
+        x
+    } else {
+        y
+    }
+}"#,
             "type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;",
-            "macro_rules! vec_of_strings {\n    ($($x:expr),*) => {\n        vec![$($x.to_string()),*]\n    };\n}",
-            "let (tx, rx) = std::sync::mpsc::channel();\ntx.send(42).unwrap();",
+            r#"macro_rules!
vec_of_strings { + ($($x:expr),*) => { + vec![$($x.to_string()),*] + }; +}"#, + r#"let (tx, rx) = std::sync::mpsc::channel(); +tx.send(42).unwrap();"#, ] } fn python_snippets() -> Vec<&'static str> { vec![ - "def main():\n print(\"hello\")", - "for i in range(10):\n print(i)", - "if x > 0:\n return True", - "class Point:\n def __init__(self, x, y):\n self.x = x\n self.y = y", - "import os\npath = os.path.join(\"a\", \"b\")", - "result = [\n x * 2\n for x in items\n if x > 0\n]", - "with open(\"file.txt\") as f:\n data = f.read()", - "def add(a: int, b: int) -> int:\n return a + b", - "try:\n result = process(data)\nexcept ValueError as e:\n print(e)", + r#"def main(): + print("hello")"#, + r#"for i in range(10): + print(i)"#, + r#"if x > 0: + return True"#, + r#"class Point: + def __init__(self, x, y): + self.x = x + self.y = y"#, + r#"import os +path = os.path.join("a", "b")"#, + r#"result = [ + x * 2 + for x in items + if x > 0 +]"#, + r#"with open("file.txt") as f: + data = f.read()"#, + r#"def add(a: int, b: int) -> int: + return a + b"#, + r#"try: + result = process(data) +except ValueError as e: + print(e)"#, "from collections import defaultdict", "lambda x: x * 2 + 1", - "dict_comp = {\n k: v\n for k, v in pairs.items()\n}", - "async def fetch(url):\n async with aiohttp.ClientSession() as session:\n return await session.get(url)", - "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)", - "@property\ndef name(self):\n return self._name", - "from dataclasses import dataclass\n\n@dataclass\nclass Config:\n name: str\n value: int = 0", + r#"dict_comp = { + k: v + for k, v in pairs.items() +}"#, + r#"async def fetch(url): + async with aiohttp.ClientSession() as session: + return await session.get(url)"#, + r#"def fibonacci(n): + if n <= 1: + return n + return fibonacci(n-1) + fibonacci(n-2)"#, + r#"@property +def name(self): + return self._name"#, + r#"from dataclasses import dataclass + +@dataclass +class Config: + name: str + 
value: int = 0"#, "yield from range(10)", - "sorted(\n items,\n key=lambda x: x.name,\n reverse=True,\n)", + r#"sorted( + items, + key=lambda x: x.name, + reverse=True, +)"#, "from typing import Optional, List, Dict", - "with contextlib.suppress(FileNotFoundError):\n os.remove(\"temp.txt\")", - "class Meta(type):\n def __new__(cls, name, bases, attrs):\n return super().__new__(\n cls, name, bases, attrs\n )", - "from functools import lru_cache\n\n@lru_cache(maxsize=128)\ndef expensive(n):\n return sum(range(n))", - "from pathlib import Path\nfiles = list(Path(\".\").glob(\"**/*.py\"))", - "assert isinstance(result, dict), \\\n f\"Expected dict, got {type(result)}\"", - "values = {*set_a, *set_b}\nmerged = {**dict_a, **dict_b}", + r#"with contextlib.suppress(FileNotFoundError): + os.remove("temp.txt")"#, + r#"class Meta(type): + def __new__(cls, name, bases, attrs): + return super().__new__( + cls, name, bases, attrs + )"#, + r#"from functools import lru_cache + +@lru_cache(maxsize=128) +def expensive(n): + return sum(range(n))"#, + r#"from pathlib import Path +files = list(Path(".").glob("**/*.py"))"#, + r#"assert isinstance(result, dict), \ + f"Expected dict, got {type(result)}""#, + r#"values = {*set_a, *set_b} +merged = {**dict_a, **dict_b}"#, ] } fn javascript_snippets() -> Vec<&'static str> { vec![ - "const x = 42;\nconsole.log(x);", - "function add(a, b) {\n return a + b;\n}", - "const arr = [1, 2, 3].map(\n x => x * 2\n);", - "if (x > 0) {\n return true;\n}", - "for (let i = 0; i < 10; i++) {\n console.log(i);\n}", - "class Point {\n constructor(x, y) {\n this.x = x;\n this.y = y;\n }\n}", + r#"const x = 42; +console.log(x);"#, + r#"function add(a, b) { + return a + b; +}"#, + r#"const arr = [1, 2, 3].map( + x => x * 2 +);"#, + r#"if (x > 0) { + return true; +}"#, + r#"for (let i = 0; i < 10; i++) { + console.log(i); +}"#, + r#"class Point { + constructor(x, y) { + this.x = x; + this.y = y; + } +}"#, "const { name, age } = person;", - "async function 
fetch(url) {\n const res = await get(url);\n return res.json();\n}", - "const obj = {\n ...defaults,\n ...overrides,\n};", - "try {\n parse(data);\n} catch (e) {\n console.error(e);\n}", - "export default function handler(req, res) {\n res.send(\"ok\");\n}", - "const result = items\n .filter(x => x > 0)\n .reduce((a, b) => a + b, 0);", - "const promise = new Promise(\n (resolve, reject) => {\n setTimeout(resolve, 1000);\n }\n);", + r#"async function fetch(url) { + const res = await get(url); + return res.json(); +}"#, + r#"const obj = { + ...defaults, + ...overrides, +};"#, + r#"try { + parse(data); +} catch (e) { + console.error(e); +}"#, + r#"export default function handler(req, res) { + res.send("ok"); +}"#, + r#"const result = items + .filter(x => x > 0) + .reduce((a, b) => a + b, 0);"#, + r#"const promise = new Promise( + (resolve, reject) => { + setTimeout(resolve, 1000); + } +);"#, "const [first, ...rest] = array;", - "class EventEmitter {\n constructor() {\n this.listeners = new Map();\n }\n}", - "const proxy = new Proxy(target, {\n get(obj, prop) {\n return obj[prop];\n }\n});", - "for await (const chunk of stream) {\n process(chunk);\n}", - "const memoize = (fn) => {\n const cache = new Map();\n return (...args) => {\n return cache.get(args) ?? 
fn(...args);\n    };\n};",
-            "import { useState, useEffect } from 'react';\nconst [state, setState] = useState(null);",
-            "const pipe = (...fns) => (x) =>\n    fns.reduce((v, f) => f(v), x);",
-            "Object.entries(obj).forEach(\n    ([key, value]) => {\n        console.log(key, value);\n    }\n);",
-            "const debounce = (fn, ms) => {\n    let timer;\n    return (...args) => {\n        clearTimeout(timer);\n        timer = setTimeout(\n            () => fn(...args),\n            ms\n        );\n    };\n};",
-            "const observable = new Observable(\n    subscriber => {\n        subscriber.next(1);\n        subscriber.complete();\n    }\n);",
+            r#"class EventEmitter {
+    constructor() {
+        this.listeners = new Map();
+    }
+}"#,
+            r#"const proxy = new Proxy(target, {
+    get(obj, prop) {
+        return obj[prop];
+    }
+});"#,
+            r#"for await (const chunk of stream) {
+    process(chunk);
+}"#,
+            r#"const memoize = (fn) => {
+    const cache = new Map();
+    return (...args) => {
+        return cache.get(args) ?? fn(...args);
+    };
+};"#,
+            r#"import { useState, useEffect } from 'react';
+const [state, setState] = useState(null);"#,
+            r#"const pipe = (...fns) => (x) =>
+    fns.reduce((v, f) => f(v), x);"#,
+            r#"Object.entries(obj).forEach(
+    ([key, value]) => {
+        console.log(key, value);
+    }
+);"#,
+            r#"const debounce = (fn, ms) => {
+    let timer;
+    return (...args) => {
+        clearTimeout(timer);
+        timer = setTimeout(
+            () => fn(...args),
+            ms
+        );
+    };
+};"#,
+            r#"const observable = new Observable(
+    subscriber => {
+        subscriber.next(1);
+        subscriber.complete();
+    }
+);"#,
         ]
     }
@@ -202,12 +1152,492 @@ impl CodeSyntaxGenerator {
         ]
     }
+
+    fn typescript_snippets() -> Vec<&'static str> {
+        vec![
+            r#"interface User {
+    id: number;
+    name: string;
+    email: string;
+}"#,
+            r#"type Result<T> = {
+    data: T;
+    error: string | null;
+};"#,
+            r#"function identity<T>(arg: T): T {
+    return arg;
+}"#,
+            r#"export async function fetchData<T>(
+    url: string
+): Promise<T> {
+    const res = await fetch(url);
+    return res.json() as T;
+}"#,
+            r#"class Stack<T> {
+    private items: T[] = [];
+
+    push(item: T): void {
+        this.items.push(item);
+    }
+
+    pop(): T | undefined {
+        return this.items.pop();
+    }
+}"#,
+            r#"const enum Direction {
+    Up = "UP",
+    Down = "DOWN",
+    Left = "LEFT",
+    Right = "RIGHT",
+}"#,
+            r#"type EventHandler<T> = (
+    event: T
+) => void;"#,
+            r#"export function createStore<S>(
+    initialState: S
+) {
+    let state = initialState;
+    return {
+        getState: () => state,
+        setState: (next: S) => {
+            state = next;
+        },
+    };
+}"#,
+            r#"interface Repository<T> {
+    findById(id: string): Promise<T | null>;
+    save(entity: T): Promise<void>;
+    delete(id: string): Promise<boolean>;
+}"#,
+            r#"const guard = (
+    value: unknown
+): value is string => {
+    return typeof value === "string";
+};"#,
+            r#"type DeepPartial<T> = {
+    [P in keyof T]?: T[P] extends object
+        ? DeepPartial<T[P]>
+        : T[P];
+};"#,
+            r#"export function debounce<T extends (...args: any[]) => void>(
+    fn: T,
+    delay: number
+): T {
+    let timer: ReturnType<typeof setTimeout>;
+    return ((...args: any[]) => {
+        clearTimeout(timer);
+        timer = setTimeout(() => fn(...args), delay);
+    }) as T;
+}"#,
+        ]
+    }
+
+    fn java_snippets() -> Vec<&'static str> {
+        vec![
+            r#"public class Main {
+    public static void main(String[] args) {
+        System.out.println("hello");
+    }
+}"#,
+            r#"public int add(int a, int b) {
+    return a + b;
+}"#,
+            r#"public class Stack<T> {
+    private List<T> items = new ArrayList<>();
+
+    public void push(T item) {
+        items.add(item);
+    }
+
+    public T pop() {
+        return items.remove(items.size() - 1);
+    }
+}"#,
+            r#"public interface Repository<T> {
+    Optional<T> findById(String id);
+    void save(T entity);
+    boolean delete(String id);
+}"#,
+            r#"List<String> result = items.stream()
+    .filter(s -> s.length() > 3)
+    .map(String::toUpperCase)
+    .collect(Collectors.toList());"#,
+            r#"try {
+    BufferedReader reader = new BufferedReader(
+        new FileReader("data.txt")
+    );
+    String line = reader.readLine();
+} catch (IOException e) {
+    e.printStackTrace();
+}"#,
+            r#"@Override
+public boolean equals(Object obj) {
+    if (this == obj) return true;
+    if (!(obj instanceof Point)) return false;
+    Point other = (Point) obj;
+    return x == other.x && y == other.y;
+}"#,
+            r#"public static <T extends Comparable<T>> T max(
+    T a, T b
+) {
+    return a.compareTo(b) >= 0 ? a : b;
+}"#,
+            r#"Map<String, Integer> counts = new HashMap<>();
+for (String word : words) {
+    counts.merge(word, 1, Integer::sum);
+}"#,
+            r#"public record Point(double x, double y) {
+    public double distance() {
+        return Math.sqrt(x * x + y * y);
+    }
+}"#,
+            r#"CompletableFuture<String> future =
+    CompletableFuture.supplyAsync(() -> {
+        return fetchData();
+    }).thenApply(data -> {
+        return process(data);
+    });"#,
+            r#"private final Lock lock = new ReentrantLock();
+
+public void update(String value) {
+    lock.lock();
+    try {
+        this.data = value;
+    } finally {
+        lock.unlock();
+    }
+}"#,
+        ]
+    }
+
+    fn c_snippets() -> Vec<&'static str> {
+        vec![
+            r#"int main(int argc, char *argv[]) {
+    printf("hello\n");
+    return 0;
+}"#,
+            r#"struct Point {
+    double x;
+    double y;
+};"#,
+            r#"int *create_array(int size) {
+    int *arr = malloc(size * sizeof(int));
+    if (arr == NULL) {
+        return NULL;
+    }
+    memset(arr, 0, size * sizeof(int));
+    return arr;
+}"#,
+            r#"void swap(int *a, int *b) {
+    int temp = *a;
+    *a = *b;
+    *b = temp;
+}"#,
+            r#"typedef struct Node {
+    int data;
+    struct Node *next;
+} Node;"#,
+            r#"char *str_dup(const char *src) {
+    size_t len = strlen(src) + 1;
+    char *dst = malloc(len);
+    if (dst != NULL) {
+        memcpy(dst, src, len);
+    }
+    return dst;
+}"#,
+            r#"void free_list(Node *head) {
+    Node *current = head;
+    while (current != NULL) {
+        Node *next = current->next;
+        free(current);
+        current = next;
+    }
+}"#,
+            r#"int binary_search(
+    int *arr, int size, int target
+) {
+    int low = 0, high = size - 1;
+    while (low <= high) {
+        int mid = low + (high - low) / 2;
+        if (arr[mid] == target) return mid;
+        if (arr[mid] < target) low = mid + 1;
+        else high = mid - 1;
+    }
+    return -1;
+}"#,
+            r#"static int compare(
+    const void *a, const void *b
+) {
+    return (*(int *)a - *(int *)b);
+}"#,
+            r#"FILE *fp = fopen("data.txt", "r");
+if (fp == NULL) {
+    perror("fopen");
+    return 1;
+}
+fclose(fp);"#,
+            r#"#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))"#,
+            r#"void print_array(
+    const int *arr, size_t len
+) {
+    for (size_t i = 0; i < len; i++) {
+        printf("%d ", arr[i]);
+    }
+    printf("\n");
+}"#,
+        ]
+    }
+
+    fn cpp_snippets() -> Vec<&'static str> {
+        vec![
+            r#"class Vector {
+public:
+    Vector(double x, double y)
+        : x_(x), y_(y) {}
+
+    double length() const {
+        return std::sqrt(x_ * x_ + y_ * y_);
+    }
+
+private:
+    double x_, y_;
+};"#,
+            r#"template <typename T>
+T max_value(T a, T b) {
+    return (a > b) ? a : b;
+}"#,
+            r#"auto ptr = std::make_unique<Widget>();
+ptr->update();
+auto shared = std::make_shared<Widget>();"#,
+            r#"std::vector<int> nums = {3, 1, 4, 1, 5};
+std::sort(nums.begin(), nums.end());
+auto it = std::find(
+    nums.begin(), nums.end(), 4
+);"#,
+            r#"class Shape {
+public:
+    virtual double area() const = 0;
+    virtual ~Shape() = default;
+};"#,
+            r#"template <typename Container>
+void print_all(const Container& c) {
+    for (const auto& item : c) {
+        std::cout << item << " ";
+    }
+    std::cout << std::endl;
+}"#,
+            r#"std::map<std::string, int> counts;
+for (const auto& word : words) {
+    counts[word]++;
+}"#,
+            r#"namespace utils {
+    std::string trim(const std::string& s) {
+        auto start = s.find_first_not_of(" \t");
+        auto end = s.find_last_not_of(" \t");
+        return s.substr(start, end - start + 1);
+    }
+}"#,
+            r#"auto future = std::async(
+    std::launch::async,
+    []() { return compute(); }
+);
+auto result = future.get();"#,
+            r#"class Singleton {
+public:
+    static Singleton& instance() {
+        static Singleton s;
+        return s;
+    }
+    Singleton(const Singleton&) = delete;
+    Singleton& operator=(const Singleton&) = delete;
+};"#,
+            r#"try {
+    auto data = parse(input);
+} catch (const std::exception& e) {
+    std::cerr << e.what() << std::endl;
+}"#,
+            r#"template <typename T>
+class Stack {
+    std::vector<T> data_;
+public:
+    void push(const T& val) {
+        data_.push_back(val);
+    }
+    T pop() {
+        T top = data_.back();
+        data_.pop_back();
+        return top;
+    }
+};"#,
+        ]
+    }
+
+    fn ruby_snippets() ->
Vec<&'static str> {
+        vec![
+            "class Animal\n  attr_reader :name, :age\n\n  def initialize(name, age)\n    @name = name\n    @age = age\n  end\nend",
+            "def fibonacci(n)\n  return n if n <= 1\n  fibonacci(n - 1) + fibonacci(n - 2)\nend",
+            "numbers = [1, 2, 3, 4, 5]\nresult = numbers\n  .select { |n| n.even? }\n  .map { |n| n * 2 }",
+            "class Stack\n  def initialize\n    @data = []\n  end\n\n  def push(item)\n    @data.push(item)\n  end\n\n  def pop\n    @data.pop\n  end\nend",
+            "File.open(\"data.txt\", \"r\") do |f|\n  f.each_line do |line|\n    puts line.strip\n  end\nend",
+            "module Serializable\n  def to_json\n    instance_variables.each_with_object({}) do |var, hash|\n      hash[var.to_s.delete(\"@\")] = instance_variable_get(var)\n    end.to_json\n  end\nend",
+            "begin\n  result = parse(data)\nrescue ArgumentError => e\n  puts \"Error: #{e.message}\"\nensure\n  cleanup\nend",
+            "double = ->(x) { x * 2 }\ntriple = proc { |x| x * 3 }\nputs double.call(5)\nputs triple.call(5)",
+            "class Config\n  def self.load(path)\n    YAML.load_file(path)\n  end\n\n  def self.defaults\n    { timeout: 30, retries: 3 }\n  end\nend",
+            "hash = { name: \"Alice\", age: 30 }\nhash.each do |key, value|\n  puts \"#{key}: #{value}\"\nend",
+            "def with_retry(attempts: 3)\n  attempts.times do |i|\n    begin\n      return yield\n    rescue StandardError => e\n      raise if i == attempts - 1\n    end\n  end\nend",
+            "class Logger\n  def initialize(output = $stdout)\n    @output = output\n  end\n\n  def info(msg)\n    @output.puts \"[INFO] #{msg}\"\n  end\n\n  def error(msg)\n    @output.puts \"[ERROR] #{msg}\"\n  end\nend",
+        ]
+    }
+
+    fn swift_snippets() -> Vec<&'static str> {
+        vec![
+            r#"struct Point {
+    var x: Double
+    var y: Double
+
+    func distance(to other: Point) -> Double {
+        let dx = x - other.x
+        let dy = y - other.y
+        return (dx * dx + dy * dy).squareRoot()
+    }
+}"#,
+            r#"enum Result<T> {
+    case success(T)
+    case failure(Error)
+}"#,
+            r#"func fetchData(
+    from url: URL,
+    completion: @escaping (Data?) -> Void
+) {
+    URLSession.shared.dataTask(with: url) {
+        data, _, _ in
+        completion(data)
+    }.resume()
+}"#,
+            r#"protocol Drawable {
+    func draw()
+    var bounds: CGRect { get }
+}"#,
+            r#"class ViewModel: ObservableObject {
+    @Published var items: [String] = []
+
+    func loadItems() {
+        items = ["one", "two", "three"]
+    }
+}"#,
+            r#"guard let value = optionalValue else {
+    return nil
+}
+let result = process(value)"#,
+            r#"let numbers = [1, 2, 3, 4, 5]
+let doubled = numbers
+    .filter { $0 > 2 }
+    .map { $0 * 2 }"#,
+            r#"extension Array where Element: Comparable {
+    func sorted() -> [Element] {
+        return self.sorted(by: <)
+    }
+}"#,
+            r#"struct Config: Codable {
+    let name: String
+    let timeout: Int
+    let retries: Int
+
+    static let defaults = Config(
+        name: "default",
+        timeout: 30,
+        retries: 3
+    )
+}"#,
+            r#"func retry<T>(
+    attempts: Int,
+    task: () throws -> T
+) rethrows -> T {
+    for i in 0..<attempts - 1 {
+        _ = i
+        if let value = try? task() {
+            return value
+        }
+    }
+    return try task()
+}"#,
+            r#"class Cache<Key: Hashable, Value> {
+    private var storage: [Key: Value] = [:]
+
+    func get(_ key: Key) -> Value? {
+        return storage[key]
+    }
+
+    func set(_ key: Key, value: Value) {
+        storage[key] = value
+    }
+}"#,
+            r#"enum NetworkError: Error {
+    case badURL
+    case timeout
+    case serverError(Int)
+
+    var description: String {
+        switch self {
+        case .badURL: return "Invalid URL"
+        case .timeout: return "Request timed out"
+        case .serverError(let code):
+            return "Server error: \(code)"
+        }
+    }
+}"#,
+        ]
+    }
+
+    fn bash_snippets() -> Vec<&'static str> {
+        vec![
+            "#!/bin/bash\nset -euo pipefail\nIFS=$'\\n\\t'",
+            "function log() {\n    local level=\"$1\"\n    local msg=\"$2\"\n    echo \"[$level] $(date '+%Y-%m-%d %H:%M:%S') $msg\"\n}",
+            "for file in *.txt; do\n    if [ -f \"$file\" ]; then\n        wc -l \"$file\"\n    fi\ndone",
+            "count=0\nwhile read -r line; do\n    count=$((count + 1))\n    echo \"$count: $line\"\ndone < input.txt",
+            "function check_deps() {\n    local deps=(\"git\" \"curl\" \"jq\")\n    for cmd in \"${deps[@]}\"; do\n        if !
command -v \"$cmd\" &>/dev/null; then\n echo \"Missing: $cmd\"\n exit 1\n fi\n done\n}", + "case \"$1\" in\n start)\n echo \"Starting...\"\n ;;\n stop)\n echo \"Stopping...\"\n ;;\n *)\n echo \"Usage: $0 {start|stop}\"\n exit 1\n ;;\nesac", + "readonly CONFIG_DIR=\"${HOME}/.config/myapp\"\nreadonly DATA_DIR=\"${HOME}/.local/share/myapp\"\nmkdir -p \"$CONFIG_DIR\" \"$DATA_DIR\"", + "function cleanup() {\n rm -rf \"$TMPDIR\"\n echo \"Cleaned up temp files\"\n}\ntrap cleanup EXIT", + "if [ -z \"${API_KEY:-}\" ]; then\n echo \"Error: API_KEY not set\" >&2\n exit 1\nfi", + "function retry() {\n local attempts=\"$1\"\n shift\n local count=0\n until \"$@\"; do\n count=$((count + 1))\n if [ \"$count\" -ge \"$attempts\" ]; then\n return 1\n fi\n sleep 1\n done\n}", + "declare -A colors\ncolors[red]=\"#ff0000\"\ncolors[green]=\"#00ff00\"\ncolors[blue]=\"#0000ff\"\nfor key in \"${!colors[@]}\"; do\n echo \"$key: ${colors[$key]}\"\ndone", + "find . -name \"*.log\" -mtime +7 -print0 |\n xargs -0 rm -f\necho \"Old log files removed\"", + ] + } + + fn lua_snippets() -> Vec<&'static str> { + vec![ + "local function greet(name)\n print(\"Hello, \" .. 
name)\nend", + "local config = {\n host = \"localhost\",\n port = 8080,\n debug = false,\n}", + "function factorial(n)\n if n <= 1 then\n return 1\n end\n return n * factorial(n - 1)\nend", + "local mt = {\n __index = function(t, k)\n return rawget(t, k) or 0\n end,\n __tostring = function(t)\n return table.concat(t, \", \")\n end,\n}", + "local function map(tbl, fn)\n local result = {}\n for i, v in ipairs(tbl) do\n result[i] = fn(v)\n end\n return result\nend", + "local function read_file(path)\n local f = io.open(path, \"r\")\n if not f then\n return nil, \"cannot open file\"\n end\n local content = f:read(\"*a\")\n f:close()\n return content\nend", + "local Class = {}\nClass.__index = Class\n\nfunction Class:new(name)\n local instance = setmetatable({}, self)\n instance.name = name\n return instance\nend", + "local function filter(tbl, pred)\n local result = {}\n for _, v in ipairs(tbl) do\n if pred(v) then\n table.insert(result, v)\n end\n end\n return result\nend", + "local function memoize(fn)\n local cache = {}\n return function(...)\n local key = table.concat({...}, \",\")\n if cache[key] == nil then\n cache[key] = fn(...)\n end\n return cache[key]\n end\nend", + "for i = 1, 10 do\n if i % 2 == 0 then\n print(i .. \" is even\")\n else\n print(i .. 
\" is odd\")\n end\nend", + "local function merge(a, b)\n local result = {}\n for k, v in pairs(a) do\n result[k] = v\n end\n for k, v in pairs(b) do\n result[k] = v\n end\n return result\nend", + "local function try_catch(fn, handler)\n local ok, err = pcall(fn)\n if not ok then\n handler(err)\n end\nend", + ] + } + fn get_snippets(&self) -> Vec<&'static str> { match self.language.as_str() { "rust" => Self::rust_snippets(), "python" => Self::python_snippets(), "javascript" | "js" => Self::javascript_snippets(), "go" => Self::go_snippets(), + "typescript" | "ts" => Self::typescript_snippets(), + "java" => Self::java_snippets(), + "c" => Self::c_snippets(), + "cpp" | "c++" => Self::cpp_snippets(), + "ruby" => Self::ruby_snippets(), + "swift" => Self::swift_snippets(), + "bash" => Self::bash_snippets(), + "lua" => Self::lua_snippets(), _ => Self::rust_snippets(), } } @@ -220,35 +1650,48 @@ impl TextGenerator for CodeSyntaxGenerator { _focused: Option, word_count: usize, ) -> String { - // Try to fetch from GitHub on first use - if self.fetched_snippets.is_empty() { - self.try_fetch_code(); - } - let embedded = self.get_snippets(); - let mut result = Vec::new(); - let target_words = word_count; - let mut current_words = 0; - let mut used_fetched = false; + let target_words = word_count.max(1); + let mut candidates: Vec<(bool, usize)> = Vec::new(); // (is_fetched, idx) + let min_units = (target_words / 3).max(4); - let total_available = embedded.len() + self.fetched_snippets.len(); - - while current_words < target_words { - let idx = self.rng.gen_range(0..total_available.max(1)); - - let snippet = if idx < embedded.len() { - embedded[idx] - } else if !self.fetched_snippets.is_empty() { - let f_idx = (idx - embedded.len()) % self.fetched_snippets.len(); - used_fetched = true; - &self.fetched_snippets[f_idx] - } else { - embedded[idx % embedded.len()] - }; - - current_words += snippet.split_whitespace().count(); - result.push(snippet.to_string()); + for (i, snippet) in 
embedded.iter().enumerate() { + if approx_token_count(snippet) >= min_units { + candidates.push((false, i)); + } } + for (i, snippet) in self.fetched_snippets.iter().enumerate() { + if approx_token_count(snippet) >= min_units { + candidates.push((true, i)); + } + } + + // If everything is short, fall back to all snippets. + if candidates.is_empty() { + for (i, _) in embedded.iter().enumerate() { + candidates.push((false, i)); + } + for (i, _) in self.fetched_snippets.iter().enumerate() { + candidates.push((true, i)); + } + } + if candidates.is_empty() { + return String::new(); + } + + let pick = self.rng.gen_range(0..candidates.len()); + let (is_fetched, idx) = candidates[pick]; + let used_fetched = is_fetched; + + let selected = if is_fetched { + self.fetched_snippets + .get(idx) + .map(|s| s.as_str()) + .unwrap_or_else(|| embedded[0]) + } else { + embedded.get(idx).copied().unwrap_or(embedded[0]) + }; + let text = fit_snippet_to_target(selected, target_words); self.last_source = if used_fetched { format!("GitHub source cache ({})", self.language) @@ -256,62 +1699,855 @@ impl TextGenerator for CodeSyntaxGenerator { format!("Built-in snippets ({})", self.language) }; - result.join("\n\n") + text } } +fn approx_token_count(text: &str) -> usize { + text.split_whitespace().count() +} + +fn fit_snippet_to_target(snippet: &str, target_units: usize) -> String { + let max_units = target_units.saturating_mul(3).saturating_div(2).max(target_units); + if approx_token_count(snippet) <= max_units { + return snippet.to_string(); + } + + let mut out_lines: Vec<&str> = Vec::new(); + let mut units = 0usize; + for line in snippet.lines() { + out_lines.push(line); + units = units.saturating_add(approx_token_count(line)); + if units >= target_units && out_lines.len() >= 2 { + break; + } + } + + if out_lines.is_empty() { + snippet.to_string() + } else { + out_lines.join("\n") + } +} + +/// Download code from a repo and save extracted snippets to cache. 
+pub fn download_code_repo_to_cache_with_progress<F>( + cache_dir: &str, + language_key: &str, + repo: &CodeRepo, + block_style: &BlockStyle, + snippets_limit: usize, + mut on_progress: F, +) -> bool +where + F: FnMut(u64, Option<u64>), +{ + if let Err(_) = fs::create_dir_all(cache_dir) { + return false; + } + + let mut all_snippets = Vec::new(); + + for url in repo.urls { + let bytes = fetch_url_bytes_with_progress(url, &mut on_progress); + if let Some(bytes) = bytes { + if let Ok(content) = String::from_utf8(bytes) { + let snippets = extract_code_snippets(&content, block_style); + all_snippets.extend(snippets); + } + } + } + + if all_snippets.is_empty() { + return false; + } + + all_snippets.truncate(snippets_limit); + + let cache_path = std::path::Path::new(cache_dir) + .join(format!("{}_{}.txt", language_key, repo.key)); + let combined = all_snippets.join("\n---SNIPPET---\n"); + fs::write(cache_path, combined).is_ok() +} + /// Extract function-length snippets from raw source code, preserving whitespace. -fn extract_code_snippets(source: &str) -> Vec<String> { - let mut snippets = Vec::new(); +/// Uses the given `BlockStyle` to determine how to find and delimit code blocks. +/// When keyword-based extraction yields fewer than 20 snippets, runs a structural +/// fallback pass to capture blocks by structure (brace depth, indentation, etc.).
+pub fn extract_code_snippets(source: &str, block_style: &BlockStyle) -> Vec<String> { let lines: Vec<&str> = source.lines().collect(); - let mut i = 0; - while i < lines.len() { - // Look for function/method starts - let line = lines[i].trim(); - let is_func_start = line.starts_with("fn ") - || line.starts_with("pub fn ") - || line.starts_with("def ") - || line.starts_with("func ") - || line.starts_with("function ") - || line.starts_with("async fn ") - || line.starts_with("pub async fn "); + let mut snippets = keyword_extract(&lines, block_style); - if is_func_start { + if snippets.len() < 20 { + let structural = structural_extract(&lines, block_style); + for s in structural { + if !snippets.contains(&s) { + snippets.push(s); + } + } + } + + snippets.truncate(200); + snippets +} + +/// Check if a snippet is "noise" (import-only, single-statement body, etc.) +fn is_noise_snippet(snippet: &str) -> bool { + let meaningful_lines: Vec<&str> = snippet + .lines() + .filter(|l| { + let t = l.trim(); + !t.is_empty() && !t.starts_with("//") && !t.starts_with('#') && !t.starts_with("/*") + && !t.starts_with('*') && !t.starts_with("*/") + }) + .collect(); + + if meaningful_lines.is_empty() { + return true; + } + + // Reject if first meaningful line is just `{` or `}` + let first = meaningful_lines[0].trim(); + if first == "{" || first == "}" { + return true; + } + + // Reject if body consists entirely of import/use/require/include statements + let import_prefixes = [ + "import ", "from ", "use ", "require", "#include", "using ", + "package ", "module ", "extern crate ", + ]; + let body_lines: Vec<&str> = meaningful_lines.iter().skip(1).copied().collect(); + if !body_lines.is_empty() + && body_lines.iter().all(|l| { + let t = l.trim(); + import_prefixes.iter().any(|p| t.starts_with(p)) || t == "{" || t == "}" + }) + { + return true; + } + + // Reject single-statement body (only 1 non-blank body line after opening) + let non_blank_body: Vec<&str> = snippet + .lines() + .skip(1)
.filter(|l| !l.trim().is_empty() && l.trim() != "}" && l.trim() != "end") + .collect(); + if non_blank_body.len() <= 1 && snippet.lines().count() <= 3 { + return true; + } + + false +} + +/// Validate a candidate snippet for quality. +fn is_valid_snippet(snippet: &str) -> bool { + let line_count = snippet.lines().count(); + if line_count < 3 || line_count > 30 { + return false; + } + let char_count = snippet.chars().filter(|c| !c.is_whitespace()).count(); + if char_count < 20 || snippet.len() > 800 { + return false; + } + if !snippet.contains('\n') { + return false; + } + !is_noise_snippet(snippet) +} + +/// Keyword-based extraction (the original algorithm). +fn keyword_extract(lines: &[&str], block_style: &BlockStyle) -> Vec<String> { + let mut snippets = Vec::new(); + let mut i = 0; + + while i < lines.len() { + let trimmed = lines[i].trim(); + + match block_style { + BlockStyle::Braces(patterns) => { + if patterns.iter().any(|p| trimmed.starts_with(p)) { + let mut snippet_lines = Vec::new(); + let mut depth = 0i32; + let mut j = i; + + while j < lines.len() && snippet_lines.len() < 30 { + let l = lines[j]; + snippet_lines.push(l); + depth += l.chars().filter(|&c| c == '{').count() as i32; + depth -= l.chars().filter(|&c| c == '}').count() as i32; + if depth <= 0 && j > i { + break; + } + j += 1; + } + + let snippet = snippet_lines.join("\n"); + if is_valid_snippet(&snippet) { + snippets.push(snippet); + } + i = j + 1; + } else { + i += 1; + } + } + BlockStyle::Indentation(patterns) => { + if patterns.iter().any(|p| trimmed.starts_with(p)) { + let base_indent = lines[i].len() - lines[i].trim_start().len(); + let mut snippet_lines = vec![lines[i]]; + let mut j = i + 1; + + while j < lines.len() && snippet_lines.len() < 30 { + let l = lines[j]; + if l.trim().is_empty() { + snippet_lines.push(l); + j += 1; + continue; + } + let indent = l.len() - l.trim_start().len(); + if indent > base_indent { + snippet_lines.push(l); + j += 1; + } else { + break; + } + } + + while
snippet_lines.last().map_or(false, |l| l.trim().is_empty()) { + snippet_lines.pop(); + } + + let snippet = snippet_lines.join("\n"); + if is_valid_snippet(&snippet) { + snippets.push(snippet); + } + i = j; + } else { + i += 1; + } + } + BlockStyle::EndDelimited(patterns) => { + if patterns.iter().any(|p| trimmed.starts_with(p)) { + let base_indent = lines[i].len() - lines[i].trim_start().len(); + let mut snippet_lines = vec![lines[i]]; + let mut j = i + 1; + + while j < lines.len() && snippet_lines.len() < 30 { + let l = lines[j]; + snippet_lines.push(l); + let l_trimmed = l.trim(); + let l_indent = l.len() - l.trim_start().len(); + if l_trimmed == "end" && l_indent <= base_indent { + break; + } + j += 1; + } + + let snippet = snippet_lines.join("\n"); + if is_valid_snippet(&snippet) { + snippets.push(snippet); + } + i = j + 1; + } else { + i += 1; + } + } + } + } + + snippets +} + +/// Structural fallback: extract code blocks by structure when keywords don't +/// find enough. Captures anonymous functions, nested blocks, and other constructs. +fn structural_extract(lines: &[&str], block_style: &BlockStyle) -> Vec<String> { + match block_style { + BlockStyle::Braces(_) => structural_extract_braces(lines), + BlockStyle::Indentation(_) => structural_extract_indent(lines), + BlockStyle::EndDelimited(_) => structural_extract_end(lines), + } +} + +/// Structural extraction for brace-delimited languages. +/// Scans for lines containing `{` where brace depth transitions from low levels, +/// captures until depth returns.
+fn structural_extract_braces(lines: &[&str]) -> Vec<String> { + let mut snippets = Vec::new(); + let mut global_depth = 0i32; + let mut i = 0; + + while i < lines.len() { + let l = lines[i]; + let opens = l.chars().filter(|&c| c == '{').count() as i32; + let closes = l.chars().filter(|&c| c == '}').count() as i32; + let new_depth = global_depth + opens - closes; + + // Detect transition from depth 0→1 or 1→2 (entering a new block) + if opens > 0 && (global_depth == 0 || global_depth == 1) && new_depth > global_depth { + let start_depth = global_depth; let mut snippet_lines = Vec::new(); - let mut depth = 0i32; + let mut depth = global_depth; let mut j = i; while j < lines.len() && snippet_lines.len() < 30 { - let l = lines[j]; - snippet_lines.push(l); - - depth += l.chars().filter(|&c| c == '{' || c == '(').count() as i32; - depth -= l.chars().filter(|&c| c == '}' || c == ')').count() as i32; - - if depth <= 0 && j > i { + let sl = lines[j]; + snippet_lines.push(sl); + depth += sl.chars().filter(|&c| c == '{').count() as i32; + depth -= sl.chars().filter(|&c| c == '}').count() as i32; + if depth <= start_depth && j > i { break; } j += 1; } - if snippet_lines.len() >= 3 && snippet_lines.len() <= 30 { - // Preserve original newlines and indentation - let snippet = snippet_lines.join("\n"); - let char_count = snippet.chars().filter(|c| !c.is_whitespace()).count(); - // Require at least one newline (reject single-line snippets) - let has_newline = snippet.contains('\n'); - if char_count >= 20 && snippet.len() <= 800 && has_newline { - snippets.push(snippet); - } + let snippet = snippet_lines.join("\n"); + if is_valid_snippet(&snippet) { + snippets.push(snippet); } + // Continue from after the block + global_depth = depth; + i = j + 1; + } else { + global_depth = new_depth; + i += 1; + } + } + snippets +} + +/// Structural extraction for indentation-based languages. +/// Captures top-level non-blank lines followed by indented blocks.
+fn structural_extract_indent(lines: &[&str]) -> Vec<String> { + let mut snippets = Vec::new(); + let mut i = 0; + + while i < lines.len() { + let l = lines[i]; + if l.trim().is_empty() { + i += 1; + continue; + } + + let base_indent = l.len() - l.trim_start().len(); + // Only consider top-level or near-top-level lines (indent 0 or 4) + if base_indent > 4 { + i += 1; + continue; + } + + // Check if next non-blank line is indented more + let mut has_body = false; + let mut peek = i + 1; + while peek < lines.len() { + if lines[peek].trim().is_empty() { + peek += 1; + continue; + } + let peek_indent = lines[peek].len() - lines[peek].trim_start().len(); + has_body = peek_indent > base_indent; + break; + } + + if !has_body { + i += 1; + continue; + } + + let mut snippet_lines = vec![lines[i]]; + let mut j = i + 1; + + while j < lines.len() && snippet_lines.len() < 30 { + let sl = lines[j]; + if sl.trim().is_empty() { + snippet_lines.push(sl); + j += 1; + continue; + } + let indent = sl.len() - sl.trim_start().len(); + if indent > base_indent { + snippet_lines.push(sl); + j += 1; + } else { + break; + } + } + + while snippet_lines.last().map_or(false, |sl| sl.trim().is_empty()) { + snippet_lines.pop(); + } + + let snippet = snippet_lines.join("\n"); + if is_valid_snippet(&snippet) { + snippets.push(snippet); + } + i = j; + } + + snippets +} + +/// Structural extraction for end-delimited languages (Ruby, Lua, Elixir). +/// Captures top-level lines followed by body ending with `end`.
+fn structural_extract_end(lines: &[&str]) -> Vec<String> { + let mut snippets = Vec::new(); + let mut i = 0; + + while i < lines.len() { + let l = lines[i]; + if l.trim().is_empty() { + i += 1; + continue; + } + + let base_indent = l.len() - l.trim_start().len(); + // Only consider top-level or near-top-level lines + if base_indent > 4 { + i += 1; + continue; + } + + // Look ahead for a matching `end` at same or lesser indent + let mut snippet_lines = vec![lines[i]]; + let mut j = i + 1; + let mut found_end = false; + + while j < lines.len() && snippet_lines.len() < 30 { + let sl = lines[j]; + snippet_lines.push(sl); + let sl_trimmed = sl.trim(); + let sl_indent = sl.len() - sl.trim_start().len(); + if sl_trimmed == "end" && sl_indent <= base_indent { + found_end = true; + break; + } + j += 1; + } + + if found_end { + let snippet = snippet_lines.join("\n"); + if is_valid_snippet(&snippet) { + snippets.push(snippet); + } i = j + 1; } else { i += 1; } } - snippets.truncate(50); snippets } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_raw_string_snippets_preserved() { + // Verify rust snippet content is correct after raw string conversion + let snippets = CodeSyntaxGenerator::rust_snippets(); + let main_snippet = snippets[0]; + assert!(main_snippet.contains("fn main()")); + assert!(main_snippet.contains("println!")); + assert!(main_snippet.contains('\n')); + assert_eq!(main_snippet.matches('{').count(), 1); + assert_eq!(main_snippet.matches('}').count(), 1); + + // Verify Python indentation preserved + let py_snippets = CodeSyntaxGenerator::python_snippets(); + let class_snippet = py_snippets[3]; // class Point + assert!(class_snippet.contains("class Point:")); + assert!(class_snippet.contains(" def __init__")); + assert!(class_snippet.contains(" self.x = x")); + + // Verify Go tabs preserved + let go_snippets = CodeSyntaxGenerator::go_snippets(); + let main_go = go_snippets[0]; + assert!(main_go.contains('\t')); +
assert!(main_go.contains("fmt.Println")); + + // Verify JavaScript content + let js_snippets = CodeSyntaxGenerator::javascript_snippets(); + assert!(js_snippets[1].contains("function add")); + } + + #[test] + fn test_snippet_counts_unchanged() { + assert_eq!(CodeSyntaxGenerator::rust_snippets().len(), 30); + assert_eq!(CodeSyntaxGenerator::python_snippets().len(), 25); + assert_eq!(CodeSyntaxGenerator::javascript_snippets().len(), 23); + assert_eq!(CodeSyntaxGenerator::go_snippets().len(), 20); + assert_eq!(CodeSyntaxGenerator::typescript_snippets().len(), 12); + assert_eq!(CodeSyntaxGenerator::java_snippets().len(), 12); + assert_eq!(CodeSyntaxGenerator::c_snippets().len(), 12); + assert_eq!(CodeSyntaxGenerator::cpp_snippets().len(), 12); + assert_eq!(CodeSyntaxGenerator::ruby_snippets().len(), 12); + assert_eq!(CodeSyntaxGenerator::swift_snippets().len(), 12); + assert_eq!(CodeSyntaxGenerator::bash_snippets().len(), 12); + assert_eq!(CodeSyntaxGenerator::lua_snippets().len(), 12); + } + + #[test] + fn test_languages_with_content_includes_builtin() { + let langs = languages_with_content("/nonexistent/path"); + assert!(langs.contains(&"rust")); + assert!(langs.contains(&"python")); + assert!(langs.contains(&"javascript")); + assert!(langs.contains(&"go")); + assert!(langs.contains(&"typescript")); + assert!(langs.contains(&"java")); + assert!(langs.contains(&"c")); + assert!(langs.contains(&"cpp")); + assert!(langs.contains(&"ruby")); + assert!(langs.contains(&"swift")); + assert!(langs.contains(&"bash")); + assert!(langs.contains(&"lua")); + // Network-only languages should NOT appear without cache + assert!(!langs.contains(&"kotlin")); + assert!(!langs.contains(&"scala")); + } + + #[test] + fn test_code_language_options() { + let options = code_language_options(); + assert!(options.iter().any(|(k, _)| *k == "rust")); + assert!(options.iter().any(|(k, _)| *k == "all")); + assert_eq!(options.first().unwrap().0, "all"); + assert_eq!(options.first().unwrap().1, "All 
(random)"); + } + + #[test] + fn test_code_language_options_sorted_after_all() { + let options = code_language_options(); + assert!(!options.is_empty()); + assert_eq!(options[0].0, "all"); + for i in 1..options.len().saturating_sub(1) { + let a = options[i].1.to_lowercase(); + let b = options[i + 1].1.to_lowercase(); + assert!( + a <= b, + "Language options are not sorted at index {i}: '{}' > '{}'", + options[i].1, + options[i + 1].1 + ); + } + } + + #[test] + fn test_language_by_key() { + assert!(language_by_key("rust").is_some()); + assert_eq!(language_by_key("rust").unwrap().display_name, "Rust"); + assert!(language_by_key("nonexistent").is_none()); + } + + #[test] + fn test_is_language_cached_empty_dir() { + assert!(!is_language_cached("/nonexistent/path", "rust")); + } + + #[test] + fn test_config_code_language_options_valid() { + let options = code_language_options(); + let keys: Vec<&str> = options.iter().map(|(k, _)| *k).collect(); + // All CODE_LANGUAGES keys should appear + for lang in CODE_LANGUAGES { + assert!(keys.contains(&lang.key), "Missing key: {}", lang.key); + } + } + + #[test] + fn test_build_download_queue_single_language() { + // With a nonexistent cache dir, all repos should be queued + let queue = build_code_download_queue("rust", "/nonexistent/cache/dir"); + let rust_lang = language_by_key("rust").unwrap(); + assert_eq!(queue.len(), rust_lang.repos.len()); + for (lang_key, _) in &queue { + assert_eq!(lang_key, "rust"); + } + } + + #[test] + fn test_build_download_queue_all_languages() { + let queue = build_code_download_queue("all", "/nonexistent/cache/dir"); + // Should include repos from every language + let total_repos: usize = CODE_LANGUAGES.iter().map(|l| l.repos.len()).sum(); + assert_eq!(queue.len(), total_repos); + // Should include items from multiple languages + let unique_langs: std::collections::HashSet<&str> = + queue.iter().map(|(k, _)| k.as_str()).collect(); + assert!(unique_langs.len() > 1); + } + + #[test] + fn 
test_build_download_queue_invalid_language() { + let queue = build_code_download_queue("nonexistent_lang", "/nonexistent/cache/dir"); + assert!(queue.is_empty()); + } + + #[test] + fn test_build_download_queue_skips_cached() { + // Create a temp dir with a cached file for one rust repo + let tmp = std::env::temp_dir().join("keydr_test_queue_cache"); + let _ = fs::create_dir_all(&tmp); + let rust_lang = language_by_key("rust").unwrap(); + let first_repo = &rust_lang.repos[0]; + let cache_file = tmp.join(format!("rust_{}.txt", first_repo.key)); + fs::write(&cache_file, "some cached content").unwrap(); + + let queue = build_code_download_queue("rust", tmp.to_str().unwrap()); + // Should NOT include the cached repo + assert!( + !queue.iter().any(|(_, idx)| *idx == 0), + "Cached repo should be skipped" + ); + // Should still include other uncached repos + assert_eq!(queue.len(), rust_lang.repos.len() - 1); + + // Cleanup + let _ = fs::remove_dir_all(&tmp); + } + + #[test] + fn test_build_download_queue_empty_cache_file_not_skipped() { + // An empty cache file should still be queued (treated as uncached) + let tmp = std::env::temp_dir().join("keydr_test_queue_empty"); + let _ = fs::create_dir_all(&tmp); + let rust_lang = language_by_key("rust").unwrap(); + let first_repo = &rust_lang.repos[0]; + let cache_file = tmp.join(format!("rust_{}.txt", first_repo.key)); + fs::write(&cache_file, "").unwrap(); + + let queue = build_code_download_queue("rust", tmp.to_str().unwrap()); + // Empty file should still be in queue + assert_eq!(queue.len(), rust_lang.repos.len()); + + let _ = fs::remove_dir_all(&tmp); + } + + #[test] + fn test_extract_braces_style() { + let source = r#"fn hello() { + println!("hello"); + println!("world"); +} + +fn other() { + let x = 1; + let y = 2; +} +"#; + let style = BlockStyle::Braces(&["fn "]); + let snippets = extract_code_snippets(source, &style); + assert_eq!(snippets.len(), 2); + assert!(snippets[0].contains("hello")); + 
assert!(snippets[1].contains("other")); + } + + #[test] + fn test_extract_indentation_style() { + let source = r#"def greet(name): + msg = "Hello, " + name + print(msg) + return msg + +x = 42 + +def add(a, b): + result = a + b + return result +"#; + let style = BlockStyle::Indentation(&["def "]); + let snippets = extract_code_snippets(source, &style); + assert_eq!(snippets.len(), 2); + assert!(snippets[0].contains("greet")); + assert!(snippets[1].contains("add")); + } + + #[test] + fn test_extract_end_delimited_style() { + let source = r#"def fibonacci(n) + return n if n <= 1 + fibonacci(n - 1) + fibonacci(n - 2) +end + +def hello + puts "hello" + puts "world" +end +"#; + let style = BlockStyle::EndDelimited(&["def "]); + let snippets = extract_code_snippets(source, &style); + assert_eq!(snippets.len(), 2); + assert!(snippets[0].contains("fibonacci")); + assert!(snippets[1].contains("hello")); + } + + #[test] + fn test_extract_rejects_short_snippets() { + let source = r#"fn a() { + x +} +"#; + let style = BlockStyle::Braces(&["fn "]); + let snippets = extract_code_snippets(source, &style); + // 3 lines but < 20 non-whitespace chars + assert_eq!(snippets.len(), 0); + } + + #[test] + fn test_extract_indentation_with_blank_lines() { + let source = r#"def complex(): + x = 1 + + y = 2 + + return x + y + 42 + 100 + +z = 99 +"#; + let style = BlockStyle::Indentation(&["def "]); + let snippets = extract_code_snippets(source, &style); + assert_eq!(snippets.len(), 1); + assert!(snippets[0].contains("x = 1")); + assert!(snippets[0].contains("y = 2")); + assert!(snippets[0].contains("return")); + } + + #[test] + fn test_total_language_count() { + // 12 built-in + 18 network-only = 30 + assert_eq!(CODE_LANGUAGES.len(), 30); + let builtin_count = CODE_LANGUAGES.iter().filter(|l| l.has_builtin).count(); + assert_eq!(builtin_count, 12); + let network_count = CODE_LANGUAGES.iter().filter(|l| !l.has_builtin).count(); + assert_eq!(network_count, 18); + } + + #[test] + fn 
test_fit_snippet_to_target_trims_large_snippet() { + let snippet = "line one words here\nline two words here\nline three words here\nline four words here\nline five words here"; + let fitted = fit_snippet_to_target(snippet, 6); + assert!(approx_token_count(&fitted) <= 9); // 1.5x target + assert!(fitted.lines().count() >= 2); + } + + /// Fetches every repo URL for all languages, runs extraction, and prints + /// a summary with example snippets. Run with: + /// cargo test --features network test_verify_repo_urls -- --ignored --nocapture + #[test] + #[ignore] + fn test_verify_repo_urls() { + use crate::generator::cache::fetch_url; + + let mut total_ok = 0usize; + let mut total_fail = 0usize; + let mut langs_with_no_snippets: Vec<&str> = Vec::new(); + + for lang in CODE_LANGUAGES { + println!("\n{}", "=".repeat(60)); + println!("Language: {} ({})", lang.display_name, lang.key); + println!(" Built-in: {}", lang.has_builtin); + println!(" Repos: {}", lang.repos.len()); + + let mut lang_total_snippets = 0usize; + + for repo in lang.repos { + println!("\n Repo: {}", repo.key); + + for url in repo.urls { + let short_url = if url.len() > 80 { + format!("{}...", &url[..77]) + } else { + url.to_string() + }; + + match fetch_url(url) { + Some(content) => { + let lines = content.lines().count(); + let bytes = content.len(); + println!(" OK {short_url}"); + println!(" ({lines} lines, {bytes} bytes)"); + total_ok += 1; + + let snippets = + extract_code_snippets(&content, &lang.block_style); + println!(" Extracted {} snippets", snippets.len()); + lang_total_snippets += snippets.len(); + + // Show first 2 snippets (truncated) + for (si, snippet) in snippets.iter().take(2).enumerate() { + let preview: String = snippet + .lines() + .take(5) + .collect::<Vec<_>>() + .join("\n"); + let suffix = if snippet.lines().count() > 5 { + "\n ..."
+ } else { + "" + }; + let indented: String = preview + .lines() + .map(|l| format!(" {l}")) + .collect::<Vec<_>>() + .join("\n"); + println!( + " --- snippet {} ---\n{}{}", + si + 1, indented, suffix, + ); + } + } + None => { + println!(" FAIL {short_url}"); + total_fail += 1; + } + } + } + } + + println!( + "\n TOTAL for {}: {} snippets", + lang.key, lang_total_snippets + ); + if lang_total_snippets == 0 && !lang.repos.is_empty() { + langs_with_no_snippets.push(lang.key); + } + } + + println!("\n{}", "=".repeat(60)); + println!("SUMMARY"); + println!(" URLs fetched OK: {total_ok}"); + println!(" URLs failed: {total_fail}"); + println!( + " Languages with 0 extracted snippets: {:?}", + langs_with_no_snippets + ); + + if total_fail > 0 { + println!("\nWARNING: {total_fail} URL(s) failed to fetch"); + } + if !langs_with_no_snippets.is_empty() { + println!( + "\nWARNING: {} language(s) produced 0 snippets from downloads", + langs_with_no_snippets.len() + ); + } + } + + #[test] + fn test_all_languages_have_extraction_patterns() { + for lang in CODE_LANGUAGES { + let pattern_count = match &lang.block_style { + BlockStyle::Braces(pats) => pats.len(), + BlockStyle::Indentation(pats) => pats.len(), + BlockStyle::EndDelimited(pats) => pats.len(), + }; + assert!( + pattern_count > 0, + "Language '{}' has empty extraction patterns — downloads will never yield snippets", + lang.key + ); + } + } +} diff --git a/src/generator/github_code.rs b/src/generator/github_code.rs deleted file mode 100644 index a0dc3ac..0000000 --- a/src/generator/github_code.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::engine::filter::CharFilter; -use crate::generator::TextGenerator; - -#[allow(dead_code)] -pub struct GitHubCodeGenerator { - cached_snippets: Vec<String>, - current_idx: usize, -} - -impl GitHubCodeGenerator { - #[allow(dead_code)] - pub fn new() -> Self { - Self { - cached_snippets: Vec::new(), - current_idx: 0, - } - } -} - -impl Default for GitHubCodeGenerator { - fn default() -> Self {
Self::new() - } -} - -impl TextGenerator for GitHubCodeGenerator { - fn generate( - &mut self, - _filter: &CharFilter, - _focused: Option<char>, - _word_count: usize, - ) -> String { - if self.cached_snippets.is_empty() { - return "// GitHub code fetching not yet configured. Use settings to add a repository." - .to_string(); - } - let snippet = self.cached_snippets[self.current_idx % self.cached_snippets.len()].clone(); - self.current_idx += 1; - snippet - } -} diff --git a/src/generator/mod.rs b/src/generator/mod.rs index 6f2d97b..6700699 100644 --- a/src/generator/mod.rs +++ b/src/generator/mod.rs @@ -3,7 +3,6 @@ pub mod capitalize; pub mod code_patterns; pub mod code_syntax; pub mod dictionary; -pub mod github_code; pub mod numbers; pub mod passage; pub mod phonetic; diff --git a/src/main.rs b/src/main.rs index 0b9f623..137ccc0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,12 +26,13 @@ use ratatui::backend::CrosstermBackend; use ratatui::layout::{Constraint, Direction, Layout, Rect}; use ratatui::style::{Modifier, Style}; use ratatui::text::{Line, Span}; -use ratatui::widgets::{Block, Paragraph, Widget}; +use ratatui::widgets::{Block, Paragraph, Widget, Wrap}; use app::{App, AppScreen, DrillMode}; use engine::skill_tree::DrillScope; use event::{AppEvent, EventHandler}; -use generator::passage::passage_options; +use generator::code_syntax::{code_language_options, is_language_cached, language_by_key}; +use generator::passage::{is_book_cached, passage_options}; use ui::components::dashboard::Dashboard; use ui::components::keyboard_diagram::KeyboardDiagram; use ui::components::skill_tree::{SkillTreeWidget, detail_line_count, selectable_branches}; @@ -123,6 +124,12 @@ fn run_app( { app.process_passage_download_tick(); } + if (app.screen == AppScreen::CodeIntro + || app.screen == AppScreen::CodeDownloadProgress) + && app.code_intro_downloading + { + app.process_code_download_tick(); + } // Fallback: clear depressed keys after 150ms if no Release event received if let
Some(last) = app.last_key_time { if last.elapsed() > Duration::from_millis(150) && !app.depressed_keys.is_empty() @@ -184,6 +191,8 @@ fn handle_key(app: &mut App, key: KeyEvent) { AppScreen::PassageBookSelect => handle_passage_book_key(app, key), AppScreen::PassageIntro => handle_passage_intro_key(app, key), AppScreen::PassageDownloadProgress => handle_passage_download_progress_key(app, key), + AppScreen::CodeIntro => handle_code_intro_key(app, key), + AppScreen::CodeDownloadProgress => handle_code_download_progress_key(app, key), } } @@ -196,10 +205,18 @@ fn handle_menu_key(app: &mut App, key: KeyEvent) { app.start_drill(); } KeyCode::Char('2') => { - app.go_to_code_language_select(); + if app.config.code_onboarding_done { + app.go_to_code_language_select(); + } else { + app.go_to_code_intro(); + } } KeyCode::Char('3') => { - app.go_to_passage_book_select(); + if app.config.passage_onboarding_done { + app.go_to_passage_book_select(); + } else { + app.go_to_passage_intro(); + } } KeyCode::Char('t') => app.go_to_skill_tree(), KeyCode::Char('s') => app.go_to_stats(), @@ -213,10 +230,18 @@ fn handle_menu_key(app: &mut App, key: KeyEvent) { app.start_drill(); } 1 => { - app.go_to_code_language_select(); + if app.config.code_onboarding_done { + app.go_to_code_language_select(); + } else { + app.go_to_code_intro(); + } } 2 => { - app.go_to_passage_book_select(); + if app.config.passage_onboarding_done { + app.go_to_passage_book_select(); + } else { + app.go_to_passage_intro(); + } } 3 => app.go_to_skill_tree(), 4 => app.go_to_stats(), @@ -342,16 +367,26 @@ fn handle_stats_key(app: &mut App, key: KeyEvent) { } fn handle_settings_key(app: &mut App, key: KeyEvent) { + const MAX_SETTINGS: usize = 11; + if app.settings_editing_download_dir { match key.code { KeyCode::Esc => { app.settings_editing_download_dir = false; } KeyCode::Backspace => { - app.config.passage_download_dir.pop(); + if app.settings_selected == 5 { + app.config.code_download_dir.pop(); + } else if 
app.settings_selected == 9 { + app.config.passage_download_dir.pop(); + } } KeyCode::Char(ch) if !key.modifiers.contains(KeyModifiers::CONTROL) => { - app.config.passage_download_dir.push(ch); + if app.settings_selected == 5 { + app.config.code_download_dir.push(ch); + } else if app.settings_selected == 9 { + app.config.passage_download_dir.push(ch); + } } _ => {} } @@ -369,29 +404,29 @@ fn handle_settings_key(app: &mut App, key: KeyEvent) { } } KeyCode::Down | KeyCode::Char('j') => { - if app.settings_selected < 7 { + if app.settings_selected < MAX_SETTINGS { app.settings_selected += 1; } } KeyCode::Enter => { - if app.settings_selected == 5 { - app.settings_editing_download_dir = true; - } else if app.settings_selected == 7 { - app.start_passage_downloads_from_settings(); - } else { - app.settings_cycle_forward(); + match app.settings_selected { + 5 | 9 => app.settings_editing_download_dir = true, + 7 => app.start_code_downloads_from_settings(), + 11 => app.start_passage_downloads_from_settings(), + _ => app.settings_cycle_forward(), } } KeyCode::Right | KeyCode::Char('l') => { - if app.settings_selected < 5 { - app.settings_cycle_forward(); - } else if app.settings_selected == 6 { - app.settings_cycle_forward(); + // Allow cycling for non-text, non-button fields + match app.settings_selected { + 5 | 7 | 9 | 11 => {} // text fields or action buttons + _ => app.settings_cycle_forward(), } } KeyCode::Left | KeyCode::Char('h') => { - if app.settings_selected < 5 || app.settings_selected == 6 { - app.settings_cycle_backward(); + match app.settings_selected { + 5 | 7 | 9 | 11 => {} // text fields or action buttons + _ => app.settings_cycle_backward(), } } _ => {} @@ -399,7 +434,11 @@ fn handle_settings_key(app: &mut App, key: KeyEvent) { } fn handle_code_language_key(app: &mut App, key: KeyEvent) { - const LANGS: &[&str] = &["rust", "python", "javascript", "go", "all"]; + let options = code_language_options(); + let len = options.len(); + if len == 0 { + return; + } 
match key.code { KeyCode::Esc | KeyCode::Char('q') => app.go_to_menu(), @@ -407,44 +446,68 @@ fn handle_code_language_key(app: &mut App, key: KeyEvent) { app.code_language_selected = app.code_language_selected.saturating_sub(1); } KeyCode::Down | KeyCode::Char('j') => { - if app.code_language_selected + 1 < LANGS.len() { + if app.code_language_selected + 1 < len { app.code_language_selected += 1; } } - KeyCode::Char('1') => { + KeyCode::PageUp => { + app.code_language_selected = app.code_language_selected.saturating_sub(10); + } + KeyCode::PageDown => { + app.code_language_selected = (app.code_language_selected + 10).min(len - 1); + } + KeyCode::Home | KeyCode::Char('g') => { app.code_language_selected = 0; - start_code_drill(app, LANGS); } - KeyCode::Char('2') => { - app.code_language_selected = 1; - start_code_drill(app, LANGS); - } - KeyCode::Char('3') => { - app.code_language_selected = 2; - start_code_drill(app, LANGS); - } - KeyCode::Char('4') => { - app.code_language_selected = 3; - start_code_drill(app, LANGS); - } - KeyCode::Char('5') => { - app.code_language_selected = 4; - start_code_drill(app, LANGS); + KeyCode::End | KeyCode::Char('G') => { + app.code_language_selected = len - 1; } KeyCode::Enter => { - start_code_drill(app, LANGS); + if app.code_language_selected >= options.len() { + return; + } + let key = options[app.code_language_selected].0; + if !is_code_language_disabled(app, key) { + confirm_code_language_and_continue(app, &options); + } } _ => {} } + + // Adjust scroll to keep selected item visible. + // Use a rough viewport estimate; render will use exact terminal size. 
+    let viewport = 15usize;
+    if app.code_language_selected < app.code_language_scroll {
+        app.code_language_scroll = app.code_language_selected;
+    } else if app.code_language_selected >= app.code_language_scroll + viewport {
+        app.code_language_scroll = app.code_language_selected + 1 - viewport;
+    }
 }
 
-fn start_code_drill(app: &mut App, langs: &[&str]) {
-    if app.code_language_selected < langs.len() {
-        app.config.code_language = langs[app.code_language_selected].to_string();
-        let _ = app.config.save();
-        app.drill_mode = DrillMode::Code;
-        app.drill_scope = DrillScope::Global;
-        app.start_drill();
+fn code_language_requires_download(app: &App, key: &str) -> bool {
+    if key == "all" {
+        return false;
+    }
+    let Some(lang) = language_by_key(key) else {
+        return false;
+    };
+    !lang.has_builtin && !is_language_cached(&app.config.code_download_dir, key)
+}
+
+fn is_code_language_disabled(app: &App, key: &str) -> bool {
+    !app.config.code_downloads_enabled && code_language_requires_download(app, key)
+}
+
+fn confirm_code_language_and_continue(app: &mut App, options: &[(&str, String)]) {
+    if app.code_language_selected >= options.len() {
+        return;
+    }
+    app.config.code_language = options[app.code_language_selected].0.to_string();
+    let _ = app.config.save();
+    if app.config.code_onboarding_done {
+        app.start_code_drill();
+    } else {
+        app.go_to_code_intro();
     }
 }
 
@@ -464,16 +527,32 @@ fn handle_passage_book_key(app: &mut App, key: KeyEvent) {
             let idx = (ch as usize).saturating_sub('1' as usize);
             if idx < options.len() {
                 app.passage_book_selected = idx;
-                confirm_passage_book_and_continue(app, &options);
+                let key = options[idx].0;
+                if !is_passage_option_disabled(app, key) {
+                    confirm_passage_book_and_continue(app, &options);
+                }
             }
         }
         KeyCode::Enter => {
-            confirm_passage_book_and_continue(app, &options);
+            if app.passage_book_selected < options.len() {
+                let key = options[app.passage_book_selected].0;
+                if !is_passage_option_disabled(app, key) {
+                    confirm_passage_book_and_continue(app, &options);
+                }
+            }
         }
         _ => {}
     }
 }
 
+fn passage_option_requires_download(app: &App, key: &str) -> bool {
+    key != "all" && key != "builtin" && !is_book_cached(&app.config.passage_download_dir, key)
+}
+
+fn is_passage_option_disabled(app: &App, key: &str) -> bool {
+    !app.config.passage_downloads_enabled && passage_option_requires_download(app, key)
+}
+
 fn confirm_passage_book_and_continue(app: &mut App, options: &[(&'static str, String)]) {
     if app.passage_book_selected >= options.len() {
         return;
@@ -564,7 +643,7 @@ fn handle_passage_intro_key(app: &mut App, key: KeyEvent) {
             app.config.passage_paragraphs_per_book = app.passage_intro_paragraph_limit;
             app.config.passage_onboarding_done = true;
             let _ = app.config.save();
-            app.start_passage_drill();
+            app.go_to_passage_book_select();
         }
         _ => {}
     }
@@ -577,6 +656,98 @@
     }
 }
 
+fn handle_code_intro_key(app: &mut App, key: KeyEvent) {
+    const INTRO_FIELDS: usize = 4;
+
+    if app.code_intro_downloading {
+        return;
+    }
+
+    match key.code {
+        KeyCode::Esc | KeyCode::Char('q') => app.go_to_menu(),
+        KeyCode::Up | KeyCode::Char('k') => {
+            app.code_intro_selected = app.code_intro_selected.saturating_sub(1);
+        }
+        KeyCode::Down | KeyCode::Char('j') => {
+            if app.code_intro_selected + 1 < INTRO_FIELDS {
+                app.code_intro_selected += 1;
+            }
+        }
+        KeyCode::Left | KeyCode::Char('h') => match app.code_intro_selected {
+            0 => app.code_intro_downloads_enabled = !app.code_intro_downloads_enabled,
+            2 => {
+                app.code_intro_snippets_per_repo = match app.code_intro_snippets_per_repo {
+                    0 => 200,
+                    1 => 0,
+                    n => n.saturating_sub(10).max(1),
+                };
+            }
+            _ => {}
+        },
+        KeyCode::Right | KeyCode::Char('l') => match app.code_intro_selected {
+            0 => app.code_intro_downloads_enabled = !app.code_intro_downloads_enabled,
+            2 => {
+                app.code_intro_snippets_per_repo = match app.code_intro_snippets_per_repo {
+                    0 => 1,
+                    n if n >= 200 => 0,
+                    n => n + 10,
+                };
+            }
+            _ => {}
+        },
+        KeyCode::Backspace => match app.code_intro_selected {
+            1 => {
+                app.code_intro_download_dir.pop();
+            }
+            2 => {
+                app.code_intro_snippets_per_repo /= 10;
+            }
+            _ => {}
+        },
+        KeyCode::Char(ch) => match app.code_intro_selected {
+            1 if !key.modifiers.contains(KeyModifiers::CONTROL) => {
+                app.code_intro_download_dir.push(ch);
+            }
+            2 if ch.is_ascii_digit() => {
+                let digit = (ch as u8 - b'0') as usize;
+                app.code_intro_snippets_per_repo = app
+                    .code_intro_snippets_per_repo
+                    .saturating_mul(10)
+                    .saturating_add(digit)
+                    .min(10_000);
+            }
+            _ => {}
+        },
+        KeyCode::Enter => {
+            if app.code_intro_selected == 0 {
+                app.code_intro_downloads_enabled = !app.code_intro_downloads_enabled;
+                return;
+            }
+            if app.code_intro_selected != 3 {
+                return;
+            }
+
+            app.config.code_downloads_enabled = app.code_intro_downloads_enabled;
+            app.config.code_download_dir = app.code_intro_download_dir.clone();
+            app.config.code_snippets_per_repo = app.code_intro_snippets_per_repo;
+            app.config.code_onboarding_done = true;
+            let _ = app.config.save();
+            app.go_to_code_language_select();
+        }
+        _ => {}
+    }
+}
+
+fn handle_code_download_progress_key(app: &mut App, key: KeyEvent) {
+    match key.code {
+        KeyCode::Esc | KeyCode::Char('q') => {
+            app.cancel_code_download();
+            app.go_to_menu();
+        }
+        _ => {}
+    }
+}
+
 fn handle_skill_tree_key(app: &mut App, key: KeyEvent) {
     const DETAIL_SCROLL_STEP: usize = 10;
     let max_scroll = skill_tree_detail_max_scroll(app);
@@ -684,6 +855,8 @@ fn render(frame: &mut ratatui::Frame, app: &App) {
         AppScreen::PassageBookSelect => render_passage_book_select(frame, app),
         AppScreen::PassageIntro => render_passage_intro(frame, app),
         AppScreen::PassageDownloadProgress => render_passage_download_progress(frame, app),
+        AppScreen::CodeIntro => render_code_intro(frame, app),
+        AppScreen::CodeDownloadProgress => render_code_download_progress(frame, app),
     }
 }
 
@@ -966,21 +1139,51 @@
     let inner = block.inner(centered);
     block.render(centered, frame.buffer_mut());
 
-    let available_themes = ui::theme::Theme::available_themes();
-    let languages_all = ["rust", "python", "javascript", "go", "all"];
-    let current_lang = &app.config.code_language;
-
-    let fields: Vec<(String, String)> = vec![
+    let fields: Vec<(String, String, bool)> = vec![
         (
             "Target WPM".to_string(),
             format!("{}", app.config.target_wpm),
+            false,
         ),
-        ("Theme".to_string(), app.config.theme.clone()),
+        ("Theme".to_string(), app.config.theme.clone(), false),
         (
             "Word Count".to_string(),
             format!("{}", app.config.word_count),
+            false,
+        ),
+        (
+            "Code Language".to_string(),
+            app.config.code_language.clone(),
+            false,
+        ),
+        (
+            "Code Downloads".to_string(),
+            if app.config.code_downloads_enabled {
+                "On".to_string()
+            } else {
+                "Off".to_string()
+            },
+            false,
+        ),
+        (
+            "Code Download Dir".to_string(),
+            app.config.code_download_dir.clone(),
+            true, // path field
+        ),
+        (
+            "Snippets per Repo".to_string(),
+            if app.config.code_snippets_per_repo == 0 {
+                "Unlimited".to_string()
+            } else {
+                format!("{}", app.config.code_snippets_per_repo)
+            },
+            false,
+        ),
+        (
+            "Download Code Now".to_string(),
+            "Run downloader".to_string(),
+            false,
         ),
-        ("Code Language".to_string(), current_lang.clone()),
         (
             "Passage Downloads".to_string(),
             if app.config.passage_downloads_enabled {
@@ -988,10 +1191,12 @@
             } else {
                 "Off".to_string()
             },
+            false,
         ),
         (
             "Passage Download Dir".to_string(),
             app.config.passage_download_dir.clone(),
+            true, // path field
         ),
         (
             "Paragraphs per Book".to_string(),
@@ -1000,10 +1205,12 @@
             } else {
                 format!("{}", app.config.passage_paragraphs_per_book)
             },
+            false,
         ),
         (
             "Download Passages Now".to_string(),
             "Run downloader".to_string(),
+            false,
         ),
     ];
 
@@ -1033,12 +1240,13 @@
         )
         .split(layout[1]);
 
-    for (i, (label, value)) in fields.iter().enumerate() {
+    for (i, (label, value, is_path)) in fields.iter().enumerate() {
         let is_selected = i == app.settings_selected;
         let indicator = if is_selected { " > " } else { "   " };
         let label_text = format!("{indicator}{label}:");
 
-        let value_text = if i == 7 {
+        let is_button = i == 7 || i == 11; // Download Code Now, Download Passages Now
+        let value_text = if is_button {
             format!(" [ {value} ]")
         } else {
             format!(" < {value} >")
@@ -1062,7 +1270,7 @@
             colors.text_pending()
         });
 
-        let lines = if i == 5 {
+        let lines = if *is_path {
            let path_line = if app.settings_editing_download_dir && is_selected {
                format!(" {value}_")
            } else {
@@ -1088,23 +1296,74 @@
         Paragraph::new(lines).render(field_layout[i], frame.buffer_mut());
     }
 
-    let _ = (available_themes, languages_all);
+    let footer_hints: Vec<&str> = if app.settings_editing_download_dir {
+        vec!["Editing path:", "[Type/Backspace] Modify", "[ESC] Done editing"]
+    } else {
+        vec![
+            "[ESC] Save & back",
+            "[Enter/arrows] Change value",
+            "[Enter on path] Edit dir",
+        ]
+    };
+    let footer_lines: Vec<Line> = pack_hint_lines(&footer_hints, layout[3].width as usize)
+        .into_iter()
+        .map(|line| Line::from(Span::styled(line, Style::default().fg(colors.accent()))))
+        .collect();
+    Paragraph::new(footer_lines)
+        .wrap(Wrap { trim: false })
+        .render(layout[3], frame.buffer_mut());
+}
-
-    let footer = Paragraph::new(Line::from(Span::styled(
-        if app.settings_editing_download_dir {
-            " Editing path: [Type/Backspace] Modify [ESC] Done editing"
+
+fn wrapped_line_count(text: &str, width: usize) -> usize {
+    if width == 0 {
+        return 0;
+    }
+    let chars = text.chars().count().max(1);
+    chars.div_ceil(width)
+}
+
+fn pack_hint_lines(hints: &[&str], width: usize) -> Vec<String> {
+    if width == 0 || hints.is_empty() {
+        return Vec::new();
+    }
+
+    let prefix = " ";
+    let separator = "  ";
+    let mut out: Vec<String> = Vec::new();
+    let mut current = prefix.to_string();
+    let mut has_hint = false;
+
+    for hint in hints {
+        if hint.is_empty() {
+            continue;
+        }
+        let candidate = if has_hint {
+            format!("{current}{separator}{hint}")
         } else {
-            " [ESC] Save & back [Enter/arrows] Change value [Enter on path] Edit dir"
-        },
-        Style::default().fg(colors.accent()),
-    )));
-    footer.render(layout[3], frame.buffer_mut());
+            format!("{current}{hint}")
+        };
+        if candidate.chars().count() <= width {
+            current = candidate;
+            has_hint = true;
+        } else {
+            if has_hint {
+                out.push(current);
+            }
+            current = format!("{prefix}{hint}");
+            has_hint = true;
+        }
+    }
+
+    if has_hint {
+        out.push(current);
+    }
+    out
+}
 
 fn render_code_language_select(frame: &mut ratatui::Frame, app: &App) {
     let area = frame.area();
     let colors = &app.theme.colors;
-    let centered = ui::layout::centered_rect(40, 50, area);
+    let centered = ui::layout::centered_rect(50, 70, area);
 
     let block = Block::bordered()
         .title(" Select Code Language ")
@@ -1113,20 +1372,83 @@
     let inner = block.inner(centered);
     block.render(centered, frame.buffer_mut());
 
-    let langs = ["Rust", "Python", "JavaScript", "Go", "All (random)"];
-    let lang_keys = ["rust", "python", "javascript", "go", "all"];
+    let options = code_language_options();
+    let cache_dir = &app.config.code_download_dir;
+    let footer_hints = [
+        "[Up/Down/PgUp/PgDn] Navigate",
+        "[Enter] Confirm",
+        "[ESC] Back",
+    ];
+    let disabled_notice =
+        " Some languages are disabled: enable network downloads in intro/settings.";
+    let has_disabled = !app.config.code_downloads_enabled
+        && options
+            .iter()
+            .any(|(key, _)| is_code_language_disabled(app, key));
+    let width = inner.width as usize;
+    let hint_lines_vec = pack_hint_lines(&footer_hints, width);
+    let hint_lines = hint_lines_vec.len();
+    let notice_lines = wrapped_line_count(disabled_notice, width);
+    let total_height = inner.height as usize;
+    let show_notice = has_disabled && total_height >= hint_lines + notice_lines + 3;
+    let desired_footer_height = hint_lines + if show_notice { notice_lines } else { 0 };
+    let footer_height = desired_footer_height.min(total_height.saturating_sub(1)) as u16;
+    let (list_area, footer_area) = if footer_height > 0 {
+        let chunks = Layout::default()
+            .direction(Direction::Vertical)
+            .constraints([Constraint::Min(1), Constraint::Length(footer_height)])
+            .split(inner);
+        (chunks[0], Some(chunks[1]))
+    } else {
+        (inner, None)
+    };
+
+    let viewport_height = (list_area.height as usize).saturating_sub(2).max(1);
+    let scroll = app.code_language_scroll;
 
     let mut lines: Vec<Line> = Vec::new();
-    lines.push(Line::from(""));
-    for (i, &lang) in langs.iter().enumerate() {
+    // Show scroll indicator at top if scrolled down
+    if scroll > 0 {
+        lines.push(Line::from(Span::styled(
+            format!(" ... {} more above ...", scroll),
+            Style::default().fg(colors.text_pending()),
+        )));
+    } else {
+        lines.push(Line::from(""));
+    }
+
+    let visible_end = (scroll + viewport_height).min(options.len());
+
+    for i in scroll..visible_end {
+        let (key, display) = &options[i];
         let is_selected = i == app.code_language_selected;
-        let is_current = lang_keys[i] == app.config.code_language;
+        let is_current = *key == app.config.code_language;
+        let is_disabled = is_code_language_disabled(app, key);
         let indicator = if is_selected { " > " } else { "   " };
         let current_marker = if is_current { " (current)" } else { "" };
 
-        let style = if is_selected {
+        // Determine availability label
+        let availability = if *key == "all" {
+            String::new()
+        } else if let Some(lang) = language_by_key(key) {
+            if lang.has_builtin {
+                " (built-in)".to_string()
+            } else if is_language_cached(cache_dir, key) {
+                " (cached)".to_string()
+            } else if is_disabled {
+                " (disabled: download required)".to_string()
+            } else {
+                " (download required)".to_string()
+            }
+        } else {
+            String::new()
+        };
+
+        let style = if is_disabled {
+            Style::default().fg(colors.text_pending())
+        } else if is_selected {
             Style::default()
                 .fg(colors.accent())
                 .add_modifier(Modifier::BOLD)
@@ -1135,18 +1457,38 @@
         };
 
         lines.push(Line::from(Span::styled(
-            format!("{indicator}[{}] {lang}{current_marker}", i + 1),
+            format!("{indicator}{display}{current_marker}{availability}"),
             style,
         )));
     }
 
-    lines.push(Line::from(""));
-    lines.push(Line::from(Span::styled(
-        " [1-5] Select [Enter] Confirm [ESC] Back",
-        Style::default().fg(colors.text_pending()),
-    )));
+    // Show scroll indicator at bottom if more items below
+    if visible_end < options.len() {
+        lines.push(Line::from(Span::styled(
+            format!(" ... {} more below ...", options.len() - visible_end),
+            Style::default().fg(colors.text_pending()),
+        )));
+    } else {
+        lines.push(Line::from(""));
+    }
 
-    Paragraph::new(lines).render(inner, frame.buffer_mut());
+    Paragraph::new(lines).render(list_area, frame.buffer_mut());
+
+    if let Some(footer) = footer_area {
+        let mut footer_lines: Vec<Line> = hint_lines_vec
+            .iter()
+            .map(|line| Line::from(Span::styled(line.clone(), Style::default().fg(colors.text_pending()))))
+            .collect();
+        if show_notice {
+            footer_lines.push(Line::from(Span::styled(
+                disabled_notice,
+                Style::default().fg(colors.text_pending()),
+            )));
+        }
+        Paragraph::new(footer_lines)
+            .wrap(Wrap { trim: false })
+            .render(footer, frame.buffer_mut());
+    }
 }
 
 fn render_passage_book_select(frame: &mut ratatui::Frame, app: &App) {
@@ -1162,11 +1504,55 @@
     block.render(centered, frame.buffer_mut());
 
     let options = passage_options();
-    let mut lines: Vec<Line> = vec![Line::from("")];
-    for (i, (_, label)) in options.iter().enumerate() {
+    let footer_hints = ["[Up/Down] Navigate", "[Enter] Confirm", "[ESC] Back"];
+    let disabled_notice =
+        " Some sources are disabled: enable network downloads in intro/settings.";
+    let has_disabled = !app.config.passage_downloads_enabled
+        && options
+            .iter()
+            .any(|(key, _)| is_passage_option_disabled(app, key));
+    let width = inner.width as usize;
+    let hint_lines_vec = pack_hint_lines(&footer_hints, width);
+    let hint_lines = hint_lines_vec.len();
+    let notice_lines = wrapped_line_count(disabled_notice, width);
+    let total_height = inner.height as usize;
+    let show_notice = has_disabled && total_height >= hint_lines + notice_lines + 3;
+    let desired_footer_height = hint_lines + if show_notice { notice_lines } else { 0 };
+    let footer_height = desired_footer_height.min(total_height.saturating_sub(1)) as u16;
+    let (list_area, footer_area) = if footer_height > 0 {
+        let chunks = Layout::default()
+            .direction(Direction::Vertical)
+            .constraints([Constraint::Min(1), Constraint::Length(footer_height)])
+            .split(inner);
+        (chunks[0], Some(chunks[1]))
+    } else {
+        (inner, None)
+    };
+
+    let viewport_height = list_area.height as usize;
+    let start = app
+        .passage_book_selected
+        .saturating_sub(viewport_height.saturating_sub(1));
+    let end = (start + viewport_height).min(options.len());
+    let mut lines: Vec<Line> = vec![];
+    for (i, (key, label)) in options.iter().enumerate().skip(start).take(end - start) {
         let is_selected = i == app.passage_book_selected;
+        let is_disabled = is_passage_option_disabled(app, key);
         let indicator = if is_selected { " > " } else { "   " };
-        let style = if is_selected {
+        let availability = if *key == "all" {
+            String::new()
+        } else if *key == "builtin" {
+            " (built-in)".to_string()
+        } else if is_book_cached(&app.config.passage_download_dir, key) {
+            " (cached)".to_string()
+        } else if is_disabled {
+            " (disabled: download required)".to_string()
+        } else {
+            " (download required)".to_string()
+        };
+        let style = if is_disabled {
+            Style::default().fg(colors.text_pending())
+        } else if is_selected {
             Style::default()
                 .fg(colors.accent())
                 .add_modifier(Modifier::BOLD)
@@ -1174,16 +1560,28 @@
             Style::default().fg(colors.fg())
         };
         lines.push(Line::from(Span::styled(
-            format!("{indicator}[{}] {label}", i + 1),
+            format!("{indicator}[{}] {label}{availability}", i + 1),
             style,
         )));
     }
-    lines.push(Line::from(""));
-    lines.push(Line::from(Span::styled(
-        " [Up/Down] Navigate [Enter] Confirm [ESC] Back",
-        Style::default().fg(colors.text_pending()),
-    )));
-    Paragraph::new(lines).render(inner, frame.buffer_mut());
+
+    Paragraph::new(lines).render(list_area, frame.buffer_mut());
+
+    if let Some(footer) = footer_area {
+        let mut footer_lines: Vec<Line> = hint_lines_vec
+            .iter()
+            .map(|line| Line::from(Span::styled(line.clone(), Style::default().fg(colors.text_pending()))))
+            .collect();
+        if show_notice {
+            footer_lines.push(Line::from(Span::styled(
+                disabled_notice,
+                Style::default().fg(colors.text_pending()),
+            )));
+        }
+        Paragraph::new(footer_lines)
+            .wrap(Wrap { trim: false })
+            .render(footer, frame.buffer_mut());
+    }
 }
 
 fn render_passage_intro(frame: &mut ratatui::Frame, app: &App) {
@@ -1312,18 +1710,47 @@
             Style::default().fg(colors.text_pending()),
         )));
     }
-    } else {
-        lines.push(Line::from(Span::styled(
-            " [Up/Down] Navigate [Left/Right] Adjust [Type/Backspace] Edit [Enter] Confirm",
-            Style::default().fg(colors.text_pending()),
-        )));
-        lines.push(Line::from(Span::styled(
-            " [ESC] Cancel",
-            Style::default().fg(colors.text_pending()),
-        )));
     }
-
-    Paragraph::new(lines).render(inner, frame.buffer_mut());
+    let hint_lines = if app.passage_intro_downloading {
+        Vec::new()
+    } else {
+        pack_hint_lines(
+            &[
+                "[Up/Down] Navigate",
+                "[Left/Right] Adjust",
+                "[Type/Backspace] Edit",
+                "[Enter] Confirm",
+                "[ESC] Cancel",
+            ],
+            inner.width as usize,
+        )
+    };
+    let footer_height = if hint_lines.is_empty() {
+        0
+    } else {
+        (hint_lines.len() + 1) as u16 // add spacer line above hints
+    };
+    let (content_area, footer_area) = if footer_height > 0 && footer_height < inner.height {
+        let chunks = Layout::default()
+            .direction(Direction::Vertical)
+            .constraints([Constraint::Min(1), Constraint::Length(footer_height)])
+            .split(inner);
+        (chunks[0], Some(chunks[1]))
+    } else {
+        (inner, None)
+    };
+    Paragraph::new(lines).render(content_area, frame.buffer_mut());
+    if let Some(footer) = footer_area {
+        let mut footer_lines = vec![Line::from("")];
+        footer_lines.extend(
+            hint_lines
+                .into_iter()
+                .map(|hint| Line::from(Span::styled(hint, Style::default().fg(colors.text_pending())))),
+        );
+        Paragraph::new(footer_lines)
+            .wrap(Wrap { trim: false })
+            .render(footer, frame.buffer_mut());
+    }
 }
 
 fn render_passage_download_progress(frame: &mut ratatui::Frame, app: &App) {
@@ -1381,6 +1808,235 @@
     Paragraph::new(lines).render(inner, frame.buffer_mut());
 }
 
+fn render_code_intro(frame: &mut ratatui::Frame, app: &App) {
+    let area = frame.area();
+    let colors = &app.theme.colors;
+    let centered = ui::layout::centered_rect(75, 80, area);
+
+    let block = Block::bordered()
+        .title(" Code Downloads Setup ")
+        .border_style(Style::default().fg(colors.accent()))
+        .style(Style::default().bg(colors.bg()));
+    let inner = block.inner(centered);
+    block.render(centered, frame.buffer_mut());
+
+    let snippets_value = if app.code_intro_snippets_per_repo == 0 {
+        "unlimited".to_string()
+    } else {
+        app.code_intro_snippets_per_repo.to_string()
+    };
+
+    let fields = vec![
+        (
+            "Enable network downloads",
+            if app.code_intro_downloads_enabled {
+                "On".to_string()
+            } else {
+                "Off".to_string()
+            },
+        ),
+        ("Download directory", app.code_intro_download_dir.clone()),
+        ("Snippets per repo (0 = unlimited)", snippets_value),
+        ("Start code drill", "Confirm".to_string()),
+    ];
+
+    let mut lines = vec![
+        Line::from(Span::styled(
+            "Configure code source settings before your first code drill.",
+            Style::default()
+                .fg(colors.fg())
+                .add_modifier(Modifier::BOLD),
+        )),
+        Line::from(Span::styled(
+            "Downloads are lazy: code is fetched only when first needed.",
+            Style::default().fg(colors.text_pending()),
+        )),
+        Line::from(Span::styled(
+            "If you exit without confirming, this dialog will appear again next time.",
+            Style::default().fg(colors.text_pending()),
+        )),
+        Line::from(""),
+    ];
+
+    for (i, (label, value)) in fields.iter().enumerate() {
+        let is_selected = i == app.code_intro_selected;
+        let indicator = if is_selected { " > " } else { "   " };
+        let label_style = if is_selected {
+            Style::default()
+                .fg(colors.accent())
+                .add_modifier(Modifier::BOLD)
+        } else {
+            Style::default().fg(colors.fg())
+        };
+        let value_style = if is_selected {
+            Style::default().fg(colors.focused_key())
+        } else {
+            Style::default().fg(colors.text_pending())
+        };
+
+        lines.push(Line::from(Span::styled(
+            format!("{indicator}{label}"),
+            label_style,
+        )));
+        if i == 1 {
+            lines.push(Line::from(Span::styled(format!(" {value}"), value_style)));
+        } else if i == 3 {
+            lines.push(Line::from(Span::styled(
+                format!(" [{value}]"),
+                value_style,
+            )));
+        } else {
+            lines.push(Line::from(Span::styled(
+                format!(" < {value} >"),
+                value_style,
+            )));
+        }
+        lines.push(Line::from(""));
+    }
+
+    if app.code_intro_downloading {
+        let total_repos = app.code_intro_download_total.max(1);
+        let done_repos = app.code_intro_downloaded.min(total_repos);
+        let total_bytes = app.code_intro_download_bytes_total;
+        let done_bytes = app
+            .code_intro_download_bytes
+            .min(total_bytes.max(app.code_intro_download_bytes));
+        let width = 30usize;
+        let fill = if total_bytes > 0 {
+            ((done_bytes as usize).saturating_mul(width)) / (total_bytes as usize)
+        } else {
+            0
+        };
+        let bar = format!(
+            "{}{}",
+            "=".repeat(fill),
+            " ".repeat(width.saturating_sub(fill))
+        );
+        let progress_text = if total_bytes > 0 {
+            format!(" Downloading: [{bar}] {done_bytes}/{total_bytes} bytes")
+        } else {
+            format!(" Downloading: {done_bytes} bytes")
+        };
+        lines.push(Line::from(Span::styled(
+            progress_text,
+            Style::default()
+                .fg(colors.accent())
+                .add_modifier(Modifier::BOLD),
+        )));
+        if !app.code_intro_current_repo.is_empty() {
+            lines.push(Line::from(Span::styled(
+                format!(
+                    " Current: {} (repo {}/{})",
+                    app.code_intro_current_repo,
+                    done_repos.saturating_add(1).min(total_repos),
+                    total_repos
+                ),
+                Style::default().fg(colors.text_pending()),
+            )));
+        }
+    }
+
+    let hint_lines = if app.code_intro_downloading {
+        Vec::new()
+    } else {
+        pack_hint_lines(
+            &[
+                "[Up/Down] Navigate",
+                "[Left/Right] Adjust",
+                "[Type/Backspace] Edit",
+                "[Enter] Confirm",
+                "[ESC] Cancel",
+            ],
+            inner.width as usize,
+        )
+    };
+    let footer_height = if hint_lines.is_empty() {
+        0
+    } else {
+        (hint_lines.len() + 1) as u16 // add spacer line above hints
+    };
+    let (content_area, footer_area) = if footer_height > 0 && footer_height < inner.height {
+        let chunks = Layout::default()
+            .direction(Direction::Vertical)
+            .constraints([Constraint::Min(1), Constraint::Length(footer_height)])
+            .split(inner);
+        (chunks[0], Some(chunks[1]))
+    } else {
+        (inner, None)
+    };
+    Paragraph::new(lines).render(content_area, frame.buffer_mut());
+    if let Some(footer) = footer_area {
+        let mut footer_lines = vec![Line::from("")];
+        footer_lines.extend(
+            hint_lines
+                .into_iter()
+                .map(|hint| Line::from(Span::styled(hint, Style::default().fg(colors.text_pending())))),
+        );
+        Paragraph::new(footer_lines)
+            .wrap(Wrap { trim: false })
+            .render(footer, frame.buffer_mut());
+    }
+}
+
+fn render_code_download_progress(frame: &mut ratatui::Frame, app: &App) {
+    let area = frame.area();
+    let colors = &app.theme.colors;
+    let centered = ui::layout::centered_rect(60, 35, area);
+
+    let block = Block::bordered()
+        .title(" Downloading Code Source ")
+        .border_style(Style::default().fg(colors.accent()))
+        .style(Style::default().bg(colors.bg()));
+    let inner = block.inner(centered);
+    block.render(centered, frame.buffer_mut());
+
+    let total_bytes = app.code_intro_download_bytes_total;
+    let done_bytes = app
+        .code_intro_download_bytes
+        .min(total_bytes.max(app.code_intro_download_bytes));
+    let width = 36usize;
+    let fill = if total_bytes > 0 {
+        ((done_bytes as usize).saturating_mul(width)) / (total_bytes as usize)
+    } else {
+        0
+    };
+    let bar = format!(
+        "{}{}",
+        "=".repeat(fill),
+        " ".repeat(width.saturating_sub(fill))
+    );
+
+    let repo_name = if app.code_intro_current_repo.is_empty() {
+        "Preparing download...".to_string()
+    } else {
+        app.code_intro_current_repo.clone()
+    };
+
+    let lines = vec![
+        Line::from(Span::styled(
+            format!(" Repo: {repo_name}"),
+            Style::default()
+                .fg(colors.fg())
+                .add_modifier(Modifier::BOLD),
+        )),
+        Line::from(""),
+        Line::from(Span::styled(
+            if total_bytes > 0 {
+                format!(" [{bar}] {done_bytes}/{total_bytes} bytes")
+            } else {
+                format!(" Downloaded: {done_bytes} bytes")
+            },
+            Style::default().fg(colors.accent()),
+        )),
+        Line::from(""),
+        Line::from(Span::styled(
+            " [ESC] Cancel",
+            Style::default().fg(colors.text_pending()),
+        )),
+    ];
+
+    Paragraph::new(lines).render(inner, frame.buffer_mut());
+}
+
 fn render_skill_tree(frame: &mut ratatui::Frame, app: &App) {
     let area = frame.area();
     let centered = ui::layout::centered_rect(70, 90, area);
diff --git a/src/session/drill.rs b/src/session/drill.rs
index 1c849a1..7cb4a95 100644
--- a/src/session/drill.rs
+++ b/src/session/drill.rs
@@ -218,6 +218,51 @@ mod tests {
         assert_eq!(drill.typo_count(), 1);
     }
 
+    #[test]
+    fn test_tab_counts_as_four_spaces() {
+        let mut drill = DrillState::new("    pass");
+        let start = drill.cursor;
+        input::process_char(&mut drill, '\t');
+        assert_eq!(drill.cursor, start + 4);
+        assert_eq!(drill.typo_count(), 0);
+    }
+
+    #[test]
+    fn test_tab_counts_as_two_spaces() {
+        let mut drill = DrillState::new("  echo");
+        let start = drill.cursor;
+        input::process_char(&mut drill, '\t');
+        assert_eq!(drill.cursor, start + 2);
+        assert_eq!(drill.typo_count(), 0);
+    }
+
+    #[test]
+    fn test_tab_not_accepted_for_non_four_space_prefix() {
+        let mut drill = DrillState::new("abc def");
+        for ch in "abc".chars() {
+            input::process_char(&mut drill, ch);
+        }
+        let start = drill.cursor;
+        input::process_char(&mut drill, '\t');
+        // Falls back to synthetic incorrect span behavior.
+        assert!(drill.cursor > start);
+        assert!(drill.typo_count() >= 1);
+    }
+
+    #[test]
+    fn test_correct_enter_auto_indents_next_line() {
+        let mut drill = DrillState::new("if x:\n    pass");
+        for ch in "if x:".chars() {
+            input::process_char(&mut drill, ch);
+        }
+        // Correct newline should also consume the 4-space indent.
+        input::process_char(&mut drill, '\n');
+        let expected_cursor = "if x:\n    ".chars().count();
+        assert_eq!(drill.cursor, expected_cursor);
+        assert_eq!(drill.typo_count(), 0);
+        assert_eq!(drill.accuracy(), 100.0);
+    }
+
     #[test]
     fn test_nested_synthetic_spans_collapse_to_single_error() {
         let mut drill = DrillState::new("abcd\nefgh");
diff --git a/src/session/input.rs b/src/session/input.rs
index 1505758..582cf00 100644
--- a/src/session/input.rs
+++ b/src/session/input.rs
@@ -27,7 +27,13 @@ pub fn process_char(drill: &mut DrillState, ch: char) -> Option<KeystrokeEvent>
     }
 
     let expected = drill.target[drill.cursor];
-    let correct = ch == expected;
+    let tab_indent_len = if ch == '\t' {
+        tab_indent_completion_len(drill)
+    } else {
+        0
+    };
+    let tab_as_indent = tab_indent_len > 0;
+    let correct = ch == expected || tab_as_indent;
 
     let event = KeystrokeEvent {
         expected,
@@ -36,9 +42,16 @@ pub fn process_char(drill: &mut DrillState, ch: char) -> Option<KeystrokeEvent>
         correct,
     };
 
-    if correct {
+    if tab_as_indent {
+        apply_tab_indent(drill, tab_indent_len);
+    } else if correct {
         drill.input.push(CharStatus::Correct);
         drill.cursor += 1;
+        // IDE-like behavior: when Enter is correctly typed, auto-consume
+        // indentation whitespace on the next line.
+        if ch == '\n' {
+            apply_auto_indent_after_newline(drill);
+        }
     } else if ch == '\n' {
         apply_newline_span(drill, ch);
     } else if ch == '\t' {
@@ -56,6 +69,63 @@ pub fn process_char(drill: &mut DrillState, ch: char) -> Option<KeystrokeEvent>
     Some(event)
 }
 
+fn tab_indent_completion_len(drill: &DrillState) -> usize {
+    if drill.cursor >= drill.target.len() {
+        return 0;
+    }
+
+    // Only treat Tab as indentation if cursor is in leading whitespace
+    // for the current line.
+    let line_start = drill.target[..drill.cursor]
+        .iter()
+        .rposition(|&c| c == '\n')
+        .map(|idx| idx + 1)
+        .unwrap_or(0);
+    if drill.target[line_start..drill.cursor]
+        .iter()
+        .any(|&c| c != ' ' && c != '\t')
+    {
+        return 0;
+    }
+
+    let line_end = drill.target[drill.cursor..]
+        .iter()
+        .position(|&c| c == '\n')
+        .map(|offset| drill.cursor + offset)
+        .unwrap_or(drill.target.len());
+
+    let mut end = drill.cursor;
+    while end < line_end {
+        let c = drill.target[end];
+        if c == ' ' || c == '\t' {
+            end += 1;
+        } else {
+            break;
+        }
+    }
+
+    end.saturating_sub(drill.cursor)
+}
+
+fn apply_tab_indent(drill: &mut DrillState, len: usize) {
+    for _ in 0..len {
+        drill.input.push(CharStatus::Correct);
+    }
+    drill.cursor = drill.cursor.saturating_add(len);
+}
+
+fn apply_auto_indent_after_newline(drill: &mut DrillState) {
+    while drill.cursor < drill.target.len() {
+        let c = drill.target[drill.cursor];
+        if c == ' ' || c == '\t' {
+            drill.input.push(CharStatus::Correct);
+            drill.cursor += 1;
+        } else {
+            break;
+        }
+    }
+}
+
 pub fn process_backspace(drill: &mut DrillState) {
     if drill.cursor == 0 {
         return;
diff --git a/src/ui/layout.rs b/src/ui/layout.rs
index 6ccddc5..c4f3a15 100644
--- a/src/ui/layout.rs
+++ b/src/ui/layout.rs
@@ -82,21 +82,17 @@ impl AppLayout {
 }
 
 pub fn centered_rect(percent_x: u16, percent_y: u16, area: Rect) -> Rect {
-    let vertical = Layout::default()
-        .direction(Direction::Vertical)
-        .constraints([
-            Constraint::Percentage((100 - percent_y) / 2),
-            Constraint::Percentage(percent_y),
-            Constraint::Percentage((100 - percent_y) / 2),
-        ])
-        .split(area);
+    const MIN_POPUP_WIDTH: u16 = 72;
+    const MIN_POPUP_HEIGHT: u16 = 18;
 
-    Layout::default()
-        .direction(Direction::Horizontal)
-        .constraints([
-            Constraint::Percentage((100 - percent_x) / 2),
-            Constraint::Percentage(percent_x),
-            Constraint::Percentage((100 - percent_x) / 2),
-        ])
-        .split(vertical[1])[1]
+    let requested_w = area.width.saturating_mul(percent_x.min(100)) / 100;
+    let requested_h = area.height.saturating_mul(percent_y.min(100)) / 100;
+
+    let target_w = requested_w.max(MIN_POPUP_WIDTH).min(area.width);
+    let target_h = requested_h.max(MIN_POPUP_HEIGHT).min(area.height);
+
+    let left = area.x.saturating_add((area.width.saturating_sub(target_w)) / 2);
+    let top = area.y.saturating_add((area.height.saturating_sub(target_h)) / 2);
+
+    Rect::new(left, top, target_w, target_h)
 }