Prevent Base62Uuid decoding from panicing

An extra long encoded uuid could crash the server, now it's handled as
an error.
This commit is contained in:
Tyler Hallada 2023-10-17 00:59:01 -04:00
parent 7f86612899
commit 5f9d64f2d9

View File

@ -1,5 +1,6 @@
use std::fmt::{self, Display, Formatter}; use std::fmt::{self, Display, Formatter};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
@ -7,7 +8,7 @@ const BASE62_CHARS: &[u8] = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmn
/// A wrapper around a UUID (from `uuid::Uuid`) that serializes to a Base62 string. /// A wrapper around a UUID (from `uuid::Uuid`) that serializes to a Base62 string.
/// ///
/// Database rows have a UUID primary key, but they are encoded in Base62 to be shorter and more /// Database rows have a UUID primary key, but they are encoded in Base62 to be shorter and more
/// URL-friendly for the frontend. /// URL-friendly for the frontend.
#[derive(Debug, Serialize, Deserialize, Clone, Copy)] #[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub struct Base62Uuid( pub struct Base62Uuid(
@ -44,9 +45,11 @@ impl Display for Base62Uuid {
} }
} }
impl From<&str> for Base62Uuid { impl TryFrom<&str> for Base62Uuid {
fn from(s: &str) -> Self { type Error = anyhow::Error;
Self(Uuid::from_u128(base62_decode(s)))
fn try_from(s: &str) -> Result<Self> {
Ok(Self(Uuid::from_u128(base62_decode(s)?)))
} }
} }
@ -68,7 +71,9 @@ where
D: serde::Deserializer<'de>, D: serde::Deserializer<'de>,
{ {
let s = String::deserialize(deserializer)?; let s = String::deserialize(deserializer)?;
Ok(Uuid::from_u128(base62_decode(&s))) Ok(Uuid::from_u128(
base62_decode(&s).map_err(serde::de::Error::custom)?,
))
} }
pub fn base62_encode(mut number: u128) -> String { pub fn base62_encode(mut number: u128) -> String {
@ -82,18 +87,25 @@ pub fn base62_encode(mut number: u128) -> String {
} }
encoded.reverse(); encoded.reverse();
String::from_utf8(encoded).unwrap() unsafe {
// Safety: all characters in `encoded` must come from BASE62_CHARS, and characters in
// BASE62_CHARS are valid UTF-8 (they are ASCII). Therefore, `encoded` must contain only
// valid UTF-8.
String::from_utf8_unchecked(encoded)
}
} }
pub fn base62_decode(input: &str) -> u128 { pub fn base62_decode(input: &str) -> Result<u128> {
let base = BASE62_CHARS.len() as u128; let base = BASE62_CHARS.len() as u128;
let mut number = 0u128; let mut number = 0u128;
for &byte in input.as_bytes() { for &byte in input.as_bytes() {
number = number * base + (BASE62_CHARS.iter().position(|&ch| ch == byte).unwrap() as u128); if let Some(value) = BASE62_CHARS.iter().position(|&ch| ch == byte) {
number = number.checked_mul(base).context("u128 overflow")? + value as u128;
}
} }
number Ok(number)
} }
#[cfg(test)] #[cfg(test)]
@ -113,9 +125,33 @@ mod tests {
for original_uuid in original_uuids.iter() { for original_uuid in original_uuids.iter() {
let encoded = base62_encode(original_uuid.as_u128()); let encoded = base62_encode(original_uuid.as_u128());
let decoded_uuid = Uuid::from_u128(base62_decode(&encoded)); let decoded = base62_decode(&encoded).unwrap();
let decoded_uuid = Uuid::from_u128(decoded);
assert_eq!(*original_uuid, decoded_uuid); assert_eq!(*original_uuid, decoded_uuid);
} }
} }
#[test]
fn errors_if_encoded_string_has_extra_bytes() {
let uuid = Uuid::new_v4();
let encoded = base62_encode(uuid.as_u128());
let encoded_plus_extra = format!("{}{}", encoded, "extra");
let decode_result = base62_decode(&encoded_plus_extra);
assert!(decode_result.is_err());
}
#[test]
fn ignores_invalid_chars_in_encoded_string() {
let uuid = Uuid::new_v4();
let encoded = base62_encode(uuid.as_u128());
let encoded_plus_invalid_chars = format!("!??{}", encoded);
let decoded = base62_decode(&encoded_plus_invalid_chars).unwrap();
let decoded_uuid = Uuid::from_u128(decoded);
assert_eq!(uuid, decoded_uuid);
}
} }