Prevent Base62Uuid decoding from panicing

An extra long encoded uuid could crash the server, now it's handled as
an error.
This commit is contained in:
Tyler Hallada 2023-10-17 00:59:01 -04:00
parent 7f86612899
commit 5f9d64f2d9

View File

@ -1,5 +1,6 @@
use std::fmt::{self, Display, Formatter};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
@ -7,7 +8,7 @@ const BASE62_CHARS: &[u8] = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmn
/// A wrapper around a UUID (from `uuid::Uuid`) that serializes to a Base62 string.
///
/// Database rows have a UUID primary key, but they are encoded in Base62 to be shorter and more
/// Database rows have a UUID primary key, but they are encoded in Base62 to be shorter and more
/// URL-friendly for the frontend.
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub struct Base62Uuid(
@ -44,9 +45,11 @@ impl Display for Base62Uuid {
}
}
impl From<&str> for Base62Uuid {
fn from(s: &str) -> Self {
Self(Uuid::from_u128(base62_decode(s)))
impl TryFrom<&str> for Base62Uuid {
type Error = anyhow::Error;
fn try_from(s: &str) -> Result<Self> {
Ok(Self(Uuid::from_u128(base62_decode(s)?)))
}
}
@ -68,7 +71,9 @@ where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Ok(Uuid::from_u128(base62_decode(&s)))
Ok(Uuid::from_u128(
base62_decode(&s).map_err(serde::de::Error::custom)?,
))
}
pub fn base62_encode(mut number: u128) -> String {
@ -82,18 +87,25 @@ pub fn base62_encode(mut number: u128) -> String {
}
encoded.reverse();
String::from_utf8(encoded).unwrap()
unsafe {
// Safety: all characters in `encoded` must come from BASE62_CHARS, and characters in
// BASE62_CHARS are valid UTF-8 (they are ASCII). Therefore, `encoded` must contain only
// valid UTF-8.
String::from_utf8_unchecked(encoded)
}
}
pub fn base62_decode(input: &str) -> u128 {
pub fn base62_decode(input: &str) -> Result<u128> {
let base = BASE62_CHARS.len() as u128;
let mut number = 0u128;
for &byte in input.as_bytes() {
number = number * base + (BASE62_CHARS.iter().position(|&ch| ch == byte).unwrap() as u128);
if let Some(value) = BASE62_CHARS.iter().position(|&ch| ch == byte) {
number = number.checked_mul(base).context("u128 overflow")? + value as u128;
}
}
number
Ok(number)
}
#[cfg(test)]
@ -113,9 +125,33 @@ mod tests {
for original_uuid in original_uuids.iter() {
let encoded = base62_encode(original_uuid.as_u128());
let decoded_uuid = Uuid::from_u128(base62_decode(&encoded));
let decoded = base62_decode(&encoded).unwrap();
let decoded_uuid = Uuid::from_u128(decoded);
assert_eq!(*original_uuid, decoded_uuid);
}
}
#[test]
fn errors_if_encoded_string_has_extra_bytes() {
let uuid = Uuid::new_v4();
let encoded = base62_encode(uuid.as_u128());
let encoded_plus_extra = format!("{}{}", encoded, "extra");
let decode_result = base62_decode(&encoded_plus_extra);
assert!(decode_result.is_err());
}
#[test]
fn ignores_invalid_chars_in_encoded_string() {
let uuid = Uuid::new_v4();
let encoded = base62_encode(uuid.as_u128());
let encoded_plus_invalid_chars = format!("!??{}", encoded);
let decoded = base62_decode(&encoded_plus_invalid_chars).unwrap();
let decoded_uuid = Uuid::from_u128(decoded);
assert_eq!(uuid, decoded_uuid);
}
}