use std::collections::{HashMap, HashSet}; use std::io::{Error, Read, Write}; use url::{ParseError, Url}; use html5ever::interface::tree_builder::QuirksMode; use html5ever::tendril::{format_tendril, StrTendril, TendrilSink}; use html5ever::{ parse_document, parse_fragment, serialize, Attribute as HTML5everAttribute, LocalName, QualName, }; use crate::arena_dom::{Arena, Attribute, Node, NodeData, Ref, Sink, StyleAttribute}; use crate::css_at_rule::CssAtRule; use crate::css_parser::{parse_css_style_attribute, parse_css_stylesheet, CssRule, CssStyleRule}; use crate::css_property::CssProperty; pub struct Sanitizer<'arena> { arena: typed_arena::Arena>, config: &'arena SanitizerConfig, transformers: Vec<&'arena dyn Fn(Ref<'arena>, Arena<'arena>)>, } #[derive(Debug, Clone)] pub struct SanitizerConfig { pub allow_comments: bool, pub allow_doctype: bool, pub allowed_elements: HashSet, pub allowed_attributes: HashSet, pub allowed_attributes_per_element: HashMap>, pub add_attributes: HashMap, pub add_attributes_per_element: HashMap>, pub allowed_protocols: HashMap>>>, pub allowed_css_at_rules: HashSet, pub allowed_css_properties: HashSet, pub allowed_css_protocols: HashSet>, pub allow_css_comments: bool, pub remove_contents_when_unwrapped: HashSet, pub whitespace_around_unwrapped_content: HashMap>, } #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub enum Protocol<'a> { Scheme(&'a str), Relative, } #[derive(Debug, Clone)] pub struct ContentWhitespace<'a> { before: &'a str, after: &'a str, } impl<'a> ContentWhitespace<'a> { pub fn space_around() -> ContentWhitespace<'a> { ContentWhitespace { before: " ", after: " ", } } } impl<'arena> Sanitizer<'arena> { pub fn new( config: &'arena SanitizerConfig, transformers: Vec<&'arena dyn Fn(Ref<'arena>, Arena<'arena>)>, ) -> Sanitizer<'arena> { Sanitizer { arena: typed_arena::Arena::new(), config, transformers, } } pub fn sanitize_fragment( &'arena self, input: &mut impl Read, output: &mut impl Write, ) -> Result<(), Error> { let root = self.parse_fragment(input)?; self.traverse(root); serialize(output, root, Default::default()) } pub fn sanitize_document( &'arena self, input: &mut impl Read, output: &mut impl Write, ) -> Result<(), Error> { let root = self.parse_document(input)?; self.traverse(root); serialize(output, root, Default::default()) } fn parse_document(&'arena self, data: &mut impl Read) -> Result, Error> { let mut bytes = Vec::new(); data.read_to_end(&mut bytes)?; let sink = Sink { arena: &self.arena, document: self.arena.alloc(Node::new(NodeData::Document)), quirks_mode: QuirksMode::NoQuirks, }; Ok(parse_document(sink, Default::default()) .from_utf8() .one(&bytes[..])) } fn parse_fragment(&'arena self, data: &mut impl Read) -> Result, Error> { let mut bytes = Vec::new(); data.read_to_end(&mut bytes)?; let sink = Sink { arena: &self.arena, document: self.arena.alloc(Node::new(NodeData::Document)), quirks_mode: QuirksMode::NoQuirks, }; Ok(parse_fragment( sink, Default::default(), QualName::new(None, ns!(), local_name!("body")), vec![], ) .from_utf8() .one(&bytes[..])) } fn traverse(&'arena self, node: Ref<'arena>) { if self.should_unwrap_node(node) { let sibling = node.next_sibling.get(); if self.should_remove_contents_when_unwrapped(node) { node.detach(); } else if let Some(unwrapped_node) = node.unwrap() { self.add_unwrapped_content_whitespace(node, unwrapped_node); self.traverse(unwrapped_node); } if let Some(sibling) = sibling { self.traverse(sibling); } return; } self.remove_attributes(node); self.add_attributes(node); self.sanitize_attribute_protocols(node); self.sanitize_style_tag_css(node); self.sanitize_style_attribute_css(node); for transformer in self.transformers.iter() { transformer(node, &self.arena); } if let Some(child) = node.first_child.get() { self.traverse(child); } if let Some(sibling) = node.next_sibling.get() { self.traverse(sibling); } } fn should_unwrap_node(&self, node: Ref) -> bool { match node.data { NodeData::Document | NodeData::Text { .. } | NodeData::StyleSheet { .. } | NodeData::ProcessingInstruction { .. } => false, NodeData::Comment { .. } => !self.config.allow_comments, NodeData::Doctype { .. } => !self.config.allow_doctype, NodeData::Element { ref name, .. } => { !self.config.allowed_elements.contains(&name.local) } } } fn should_remove_contents_when_unwrapped(&self, node: Ref) -> bool { match node.data { NodeData::Document | NodeData::Doctype { .. } | NodeData::Text { .. } | NodeData::StyleSheet { .. } | NodeData::ProcessingInstruction { .. } | NodeData::Comment { .. } => false, NodeData::Element { ref name, .. } => self .config .remove_contents_when_unwrapped .contains(&name.local), } } fn remove_attributes(&self, node: Ref<'arena>) { if let NodeData::Element { ref attrs, ref name, .. } = node.data { let attrs = &mut attrs.borrow_mut(); let mut i = 0; let all_allowed = &self.config.allowed_attributes; let per_element_allowed = self.config.allowed_attributes_per_element.get(&name.local); while i != attrs.len() { if let Attribute::Text(attr) = &attrs[i] { if !all_allowed.contains(&attr.name.local) { if let Some(per_element_allowed) = per_element_allowed { if per_element_allowed.contains(&attr.name.local) { i += 1; continue; } } attrs.remove(i); continue; } } i += 1; } } } fn add_attributes(&self, node: Ref<'arena>) { if let NodeData::Element { ref attrs, ref name, .. } = node.data { let attrs = &mut attrs.borrow_mut(); let add_attributes = &self.config.add_attributes; let add_attributes_per_element = self.config.add_attributes_per_element.get(&name.local); for (name, &value) in add_attributes.iter() { attrs.push(Attribute::Text(HTML5everAttribute { name: QualName::new(None, ns!(), name.clone()), value: StrTendril::from(value), })); } if let Some(add_attributes_per_element) = add_attributes_per_element { for (name, &value) in add_attributes_per_element.iter() { attrs.push(Attribute::Text(HTML5everAttribute { name: QualName::new(None, ns!(), name.clone()), value: StrTendril::from(value), })); } } } } fn sanitize_attribute_protocols(&self, node: Ref<'arena>) { if let NodeData::Element { ref attrs, ref name, .. } = node.data { let attrs = &mut attrs.borrow_mut(); if let Some(protocols) = self.config.allowed_protocols.get(&name.local) { let mut i = 0; while i != attrs.len() { if let Attribute::Text(attr) = &attrs[i] { if let Some(allowed_protocols) = protocols.get(&attr.name.local) { match Url::parse(&attr.value) { Ok(url) => { if !allowed_protocols.contains(&Protocol::Scheme(url.scheme())) { attrs.remove(i); } else { i += 1; } } Err(ParseError::RelativeUrlWithoutBase) => { if !allowed_protocols.contains(&Protocol::Relative) { attrs.remove(i); } else { i += 1; } } Err(_) => { attrs.remove(i); } } } else { i += 1; } } else { i += 1; } } } } } fn sanitize_css_rules(&self, rules: Vec) -> Vec { rules .into_iter() .filter_map(|rule| match rule { CssRule::StyleRule(style_rule) => Some(CssRule::StyleRule(CssStyleRule { selectors: style_rule.selectors, declarations: style_rule .declarations .into_iter() .filter(|declaration| { self.config .allowed_css_properties .contains(&CssProperty::from(declaration.property.as_str())) }) .collect(), })), CssRule::AtRule(at_rule) => { if self .config .allowed_css_at_rules .contains(&CssAtRule::from(at_rule.name.as_str())) { Some(CssRule::AtRule(at_rule)) } else { None } } }) .collect() } fn sanitize_style_tag_css(&'arena self, node: Ref<'arena>) { if let NodeData::Element { ref name, .. } = node.data { if name.local == local_name!("style") { // TODO: is it okay to assume "); let mut output = vec![]; sanitizer .sanitize_fragment(&mut mock_data, &mut output) .unwrap(); assert_eq!( str::from_utf8(&output).unwrap(), "" ); } #[test] fn sanitize_css_protocols() { let mut sanitize_css_config = EMPTY_CONFIG.clone(); sanitize_css_config .allowed_elements .extend(vec![local_name!("html"), local_name!("style")]); sanitize_css_config.allowed_css_properties.extend(vec![ css_property!("background-image"), css_property!("content"), ]); sanitize_css_config .allowed_css_protocols .extend(vec![Protocol::Scheme("https")]); let sanitizer = Sanitizer::new(&sanitize_css_config, vec![]); let mut mock_data = MockRead::new( "", ); let mut output = vec![]; sanitizer .sanitize_fragment(&mut mock_data, &mut output) .unwrap(); assert_eq!( str::from_utf8(&output).unwrap(), "" ); } #[test] fn remove_doctype() { let mut disallow_doctype_config = EMPTY_CONFIG.clone(); disallow_doctype_config.allow_doctype = false; disallow_doctype_config .allowed_elements .extend(vec![local_name!("html"), local_name!("div")]); let sanitizer = Sanitizer::new(&disallow_doctype_config, vec![]); let mut mock_data = MockRead::new("
"); let mut output = vec![]; sanitizer .sanitize_document(&mut mock_data, &mut output) .unwrap(); assert_eq!(str::from_utf8(&output).unwrap(), "
"); } #[test] fn allow_doctype() { let mut allow_doctype_config = EMPTY_CONFIG.clone(); allow_doctype_config.allow_doctype = true; allow_doctype_config .allowed_elements .extend(vec![local_name!("html"), local_name!("div")]); let sanitizer = Sanitizer::new(&allow_doctype_config, vec![]); let mut mock_data = MockRead::new("
"); let mut output = vec![]; sanitizer .sanitize_document(&mut mock_data, &mut output) .unwrap(); assert_eq!( str::from_utf8(&output).unwrap(), "
" ); } #[test] fn add_unwrapped_content_whitespace() { let mut unwrapped_whitespace_config = EMPTY_CONFIG.clone(); unwrapped_whitespace_config .allowed_elements .extend(vec![local_name!("html"), local_name!("div")]); unwrapped_whitespace_config .whitespace_around_unwrapped_content .insert(local_name!("span"), ContentWhitespace::space_around()); let sanitizer = Sanitizer::new(&unwrapped_whitespace_config, vec![]); let mut mock_data = MockRead::new("
div-1content-1content-2div-2
"); let mut output = vec![]; sanitizer .sanitize_fragment(&mut mock_data, &mut output) .unwrap(); assert_eq!( str::from_utf8(&output).unwrap(), "
div-1 content-1 content-2 div-2
" ); } }