First improvement pass
This commit is contained in:
@@ -4,29 +4,108 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TransitionTable {
|
||||
pub transitions: HashMap<(char, char), Vec<(char, f64)>>,
|
||||
pub order: usize,
|
||||
transitions: HashMap<Vec<char>, Vec<(char, f64)>>,
|
||||
}
|
||||
|
||||
impl TransitionTable {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(order: usize) -> Self {
|
||||
Self {
|
||||
order,
|
||||
transitions: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add(&mut self, prev: char, curr: char, next: char, weight: f64) {
|
||||
pub fn add(&mut self, prefix: &[char], next: char, weight: f64) {
|
||||
self.transitions
|
||||
.entry((prev, curr))
|
||||
.entry(prefix.to_vec())
|
||||
.or_default()
|
||||
.push((next, weight));
|
||||
}
|
||||
|
||||
pub fn get_next_probs(&self, prev: char, curr: char) -> Option<&Vec<(char, f64)>> {
|
||||
self.transitions.get(&(prev, curr))
|
||||
pub fn segment(&self, prefix: &[char]) -> Option<&Vec<(char, f64)>> {
|
||||
// Try exact prefix match first, then fall back to shorter prefixes
|
||||
let key_len = self.order - 1;
|
||||
let prefix = if prefix.len() >= key_len {
|
||||
&prefix[prefix.len() - key_len..]
|
||||
} else {
|
||||
prefix
|
||||
};
|
||||
|
||||
// Try progressively shorter prefixes for backoff
|
||||
for start in 0..prefix.len() {
|
||||
let key = prefix[start..].to_vec();
|
||||
if let Some(entries) = self.transitions.get(&key) {
|
||||
return Some(entries);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Build an order-4 transition table from a word frequency list.
|
||||
/// Words earlier in the list are higher frequency and get more weight.
|
||||
pub fn build_from_words(words: &[String]) -> Self {
|
||||
let mut table = Self::new(4);
|
||||
let prefix_len = 3; // order - 1
|
||||
|
||||
for (rank, word) in words.iter().enumerate() {
|
||||
if word.len() < 3 {
|
||||
continue;
|
||||
}
|
||||
if !word.chars().all(|c| c.is_ascii_lowercase()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Weight decreases with rank (frequency-based)
|
||||
let weight = 1.0 / (1.0 + (rank as f64 / 500.0));
|
||||
|
||||
// Add word start transitions (space prefix -> first chars)
|
||||
let chars: Vec<char> = word.chars().collect();
|
||||
|
||||
// Start of word: ' ' prefix
|
||||
for i in 0..chars.len() {
|
||||
let mut prefix = Vec::new();
|
||||
// Build prefix from space + preceding chars
|
||||
let start = if i >= prefix_len { i - prefix_len } else { 0 };
|
||||
if i < prefix_len {
|
||||
// Pad with spaces
|
||||
for _ in 0..(prefix_len - i) {
|
||||
prefix.push(' ');
|
||||
}
|
||||
}
|
||||
for j in start..i {
|
||||
prefix.push(chars[j]);
|
||||
}
|
||||
|
||||
let next = chars[i];
|
||||
table.add(&prefix, next, weight);
|
||||
}
|
||||
|
||||
// End of word: last chars -> space
|
||||
let end_start = if chars.len() >= prefix_len {
|
||||
chars.len() - prefix_len
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let mut end_prefix: Vec<char> = Vec::new();
|
||||
if chars.len() < prefix_len {
|
||||
for _ in 0..(prefix_len - chars.len()) {
|
||||
end_prefix.push(' ');
|
||||
}
|
||||
}
|
||||
for j in end_start..chars.len() {
|
||||
end_prefix.push(chars[j]);
|
||||
}
|
||||
table.add(&end_prefix, ' ', weight);
|
||||
}
|
||||
|
||||
table
|
||||
}
|
||||
|
||||
/// Legacy order-2 table for fallback
|
||||
#[allow(dead_code)]
|
||||
pub fn build_english() -> Self {
|
||||
let mut table = Self::new();
|
||||
let mut table = Self::new(4);
|
||||
|
||||
let common_patterns: &[(&str, f64)] = &[
|
||||
("the", 10.0), ("and", 8.0), ("ing", 7.0), ("tion", 6.0), ("ent", 5.0),
|
||||
@@ -40,25 +119,24 @@ impl TransitionTable {
|
||||
("ght", 2.0), ("whi", 2.0), ("who", 2.0), ("hen", 2.0), ("ter", 2.0),
|
||||
("man", 2.0), ("men", 2.0), ("ner", 2.0), ("per", 2.0), ("pre", 2.0),
|
||||
("ran", 2.0), ("lin", 2.0), ("kin", 2.0), ("din", 2.0), ("sin", 2.0),
|
||||
("out", 2.0), ("ind", 2.0), ("ith", 2.0), ("ber", 2.0), ("der", 2.0),
|
||||
("out", 2.0), ("ind", 2.0), ("ber", 2.0), ("der", 2.0),
|
||||
("end", 2.0), ("hin", 2.0), ("old", 2.0), ("ear", 2.0), ("ain", 2.0),
|
||||
("ant", 2.0), ("urn", 2.0), ("ell", 2.0), ("ill", 2.0), ("ade", 2.0),
|
||||
("igh", 2.0), ("ong", 2.0), ("ung", 2.0), ("ast", 2.0), ("ist", 2.0),
|
||||
("ong", 2.0), ("ung", 2.0), ("ast", 2.0), ("ist", 2.0),
|
||||
("ust", 2.0), ("ost", 2.0), ("ard", 2.0), ("ord", 2.0), ("art", 2.0),
|
||||
("ort", 2.0), ("ect", 2.0), ("act", 2.0), ("ack", 2.0), ("ick", 2.0),
|
||||
("ock", 2.0), ("uck", 2.0), ("ash", 2.0), ("ish", 2.0), ("ush", 2.0),
|
||||
("anc", 1.5), ("enc", 1.5), ("inc", 1.5), ("onc", 1.5), ("unc", 1.5),
|
||||
("unt", 1.5), ("int", 1.5), ("ont", 1.5), ("ent", 1.5), ("ment", 1.5),
|
||||
("ness", 1.5), ("less", 1.5), ("able", 1.5), ("ible", 1.5), ("ting", 1.5),
|
||||
("ring", 1.5), ("sing", 1.5), ("king", 1.5), ("ning", 1.5), ("ling", 1.5),
|
||||
("wing", 1.5), ("ding", 1.5), ("ping", 1.5), ("ging", 1.5), ("ving", 1.5),
|
||||
("bing", 1.5), ("ming", 1.5), ("fing", 1.0), ("hing", 1.0), ("cing", 1.0),
|
||||
];
|
||||
|
||||
for &(pattern, weight) in common_patterns {
|
||||
let chars: Vec<char> = pattern.chars().collect();
|
||||
for window in chars.windows(3) {
|
||||
table.add(window[0], window[1], window[2], weight);
|
||||
let prefix = vec![window[0], window[1]];
|
||||
table.add(&prefix, window[2], weight);
|
||||
}
|
||||
// Also add shorter prefix entries for the start of patterns
|
||||
if chars.len() >= 2 {
|
||||
table.add(&[' ', chars[0]], chars[1], weight * 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,20 +148,14 @@ impl TransitionTable {
|
||||
|
||||
for &c in &consonants {
|
||||
for &v in &vowels {
|
||||
table.add(' ', c, v, 1.0);
|
||||
table.add(v, c, 'e', 0.5);
|
||||
for &v2 in &vowels {
|
||||
table.add(c, v, v2.to_ascii_lowercase(), 0.3);
|
||||
}
|
||||
for &c2 in &consonants {
|
||||
table.add(v, c, c2, 0.2);
|
||||
}
|
||||
table.add(&[' ', c], v, 1.0);
|
||||
table.add(&[v, c], 'e', 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
for &v in &vowels {
|
||||
for &c in &consonants {
|
||||
table.add(' ', v, c, 0.5);
|
||||
table.add(&[' ', v], c, 0.5);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,6 +165,6 @@ impl TransitionTable {
|
||||
|
||||
impl Default for TransitionTable {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
Self::new(4)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user