//! Weighted character-transition table with backoff to shorter prefixes.
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct TransitionTable {
|
|
pub order: usize,
|
|
transitions: HashMap<Vec<char>, Vec<(char, f64)>>,
|
|
}
impl TransitionTable {
|
|
pub fn new(order: usize) -> Self {
|
|
Self {
|
|
order,
|
|
transitions: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
pub fn add(&mut self, prefix: &[char], next: char, weight: f64) {
|
|
self.transitions
|
|
.entry(prefix.to_vec())
|
|
.or_default()
|
|
.push((next, weight));
|
|
}
|
|
|
|
pub fn segment(&self, prefix: &[char]) -> Option<&Vec<(char, f64)>> {
|
|
// Try exact prefix match first, then fall back to shorter prefixes
|
|
let key_len = self.order - 1;
|
|
let prefix = if prefix.len() >= key_len {
|
|
&prefix[prefix.len() - key_len..]
|
|
} else {
|
|
prefix
|
|
};
|
|
|
|
// Try progressively shorter prefixes for backoff
|
|
for start in 0..prefix.len() {
|
|
let key = prefix[start..].to_vec();
|
|
if let Some(entries) = self.transitions.get(&key) {
|
|
return Some(entries);
|
|
}
|
|
}
|
|
None
|
|
}
|
|
|
|
/// Build an order-4 transition table from a word frequency list.
|
|
/// Words earlier in the list are higher frequency and get more weight.
|
|
pub fn build_from_words(words: &[String]) -> Self {
|
|
let mut table = Self::new(4);
|
|
let prefix_len = 3; // order - 1
|
|
|
|
for (rank, word) in words.iter().enumerate() {
|
|
if word.len() < 3 {
|
|
continue;
|
|
}
|
|
if !word.chars().all(|c| c.is_ascii_lowercase()) {
|
|
continue;
|
|
}
|
|
|
|
// Weight decreases with rank (frequency-based)
|
|
let weight = 1.0 / (1.0 + (rank as f64 / 500.0));
|
|
|
|
// Add word start transitions (space prefix -> first chars)
|
|
let chars: Vec<char> = word.chars().collect();
|
|
|
|
// Start of word: ' ' prefix
|
|
for i in 0..chars.len() {
|
|
let mut prefix = Vec::new();
|
|
// Build prefix from space + preceding chars
|
|
let start = if i >= prefix_len { i - prefix_len } else { 0 };
|
|
if i < prefix_len {
|
|
// Pad with spaces
|
|
for _ in 0..(prefix_len - i) {
|
|
prefix.push(' ');
|
|
}
|
|
}
|
|
for j in start..i {
|
|
prefix.push(chars[j]);
|
|
}
|
|
|
|
let next = chars[i];
|
|
table.add(&prefix, next, weight);
|
|
}
|
|
|
|
// End of word: last chars -> space
|
|
let end_start = if chars.len() >= prefix_len {
|
|
chars.len() - prefix_len
|
|
} else {
|
|
0
|
|
};
|
|
let mut end_prefix: Vec<char> = Vec::new();
|
|
if chars.len() < prefix_len {
|
|
for _ in 0..(prefix_len - chars.len()) {
|
|
end_prefix.push(' ');
|
|
}
|
|
}
|
|
for j in end_start..chars.len() {
|
|
end_prefix.push(chars[j]);
|
|
}
|
|
table.add(&end_prefix, ' ', weight);
|
|
}
|
|
|
|
table
|
|
}
|
|
|
|
/// Legacy order-2 table for fallback
|
|
#[allow(dead_code)]
|
|
pub fn build_english() -> Self {
|
|
let mut table = Self::new(4);
|
|
|
|
let common_patterns: &[(&str, f64)] = &[
|
|
("the", 10.0),
|
|
("and", 8.0),
|
|
("ing", 7.0),
|
|
("tion", 6.0),
|
|
("ent", 5.0),
|
|
("ion", 5.0),
|
|
("her", 4.0),
|
|
("for", 4.0),
|
|
("are", 4.0),
|
|
("his", 4.0),
|
|
("hat", 3.0),
|
|
("tha", 3.0),
|
|
("ere", 3.0),
|
|
("ate", 3.0),
|
|
("ith", 3.0),
|
|
("ver", 3.0),
|
|
("all", 3.0),
|
|
("not", 3.0),
|
|
("ess", 3.0),
|
|
("est", 3.0),
|
|
("rea", 3.0),
|
|
("sta", 3.0),
|
|
("ted", 3.0),
|
|
("com", 3.0),
|
|
("con", 3.0),
|
|
("oun", 2.5),
|
|
("pro", 2.5),
|
|
("oth", 2.5),
|
|
("igh", 2.5),
|
|
("ore", 2.5),
|
|
("our", 2.5),
|
|
("ine", 2.5),
|
|
("ove", 2.5),
|
|
("ome", 2.5),
|
|
("use", 2.5),
|
|
("ble", 2.0),
|
|
("ful", 2.0),
|
|
("ous", 2.0),
|
|
("str", 2.0),
|
|
("tri", 2.0),
|
|
("ght", 2.0),
|
|
("whi", 2.0),
|
|
("who", 2.0),
|
|
("hen", 2.0),
|
|
("ter", 2.0),
|
|
("man", 2.0),
|
|
("men", 2.0),
|
|
("ner", 2.0),
|
|
("per", 2.0),
|
|
("pre", 2.0),
|
|
("ran", 2.0),
|
|
("lin", 2.0),
|
|
("kin", 2.0),
|
|
("din", 2.0),
|
|
("sin", 2.0),
|
|
("out", 2.0),
|
|
("ind", 2.0),
|
|
("ber", 2.0),
|
|
("der", 2.0),
|
|
("end", 2.0),
|
|
("hin", 2.0),
|
|
("old", 2.0),
|
|
("ear", 2.0),
|
|
("ain", 2.0),
|
|
("ant", 2.0),
|
|
("urn", 2.0),
|
|
("ell", 2.0),
|
|
("ill", 2.0),
|
|
("ade", 2.0),
|
|
("ong", 2.0),
|
|
("ung", 2.0),
|
|
("ast", 2.0),
|
|
("ist", 2.0),
|
|
("ust", 2.0),
|
|
("ost", 2.0),
|
|
("ard", 2.0),
|
|
("ord", 2.0),
|
|
("art", 2.0),
|
|
("ort", 2.0),
|
|
("ect", 2.0),
|
|
("act", 2.0),
|
|
("ack", 2.0),
|
|
("ick", 2.0),
|
|
("ock", 2.0),
|
|
("uck", 2.0),
|
|
("ash", 2.0),
|
|
("ish", 2.0),
|
|
("ush", 2.0),
|
|
];
|
|
|
|
for &(pattern, weight) in common_patterns {
|
|
let chars: Vec<char> = pattern.chars().collect();
|
|
for window in chars.windows(3) {
|
|
let prefix = vec![window[0], window[1]];
|
|
table.add(&prefix, window[2], weight);
|
|
}
|
|
// Also add shorter prefix entries for the start of patterns
|
|
if chars.len() >= 2 {
|
|
table.add(&[' ', chars[0]], chars[1], weight * 0.5);
|
|
}
|
|
}
|
|
|
|
let vowels = ['a', 'e', 'i', 'o', 'u'];
|
|
let consonants = [
|
|
'b', 'c', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'r', 's', 't', 'v', 'w',
|
|
'x', 'y', 'z',
|
|
];
|
|
|
|
for &c in &consonants {
|
|
for &v in &vowels {
|
|
table.add(&[' ', c], v, 1.0);
|
|
table.add(&[v, c], 'e', 0.5);
|
|
}
|
|
}
|
|
|
|
for &v in &vowels {
|
|
for &c in &consonants {
|
|
table.add(&[' ', v], c, 0.5);
|
|
}
|
|
}
|
|
|
|
table
|
|
}
|
|
}
impl Default for TransitionTable {
|
|
fn default() -> Self {
|
|
Self::new(4)
|
|
}
|
|
}