use std::collections::{HashMap, HashSet}; use std::io::{Error, Read, Write}; use url::{ParseError, Url}; use html5ever::interface::tree_builder::QuirksMode; use html5ever::tendril::{StrTendril, TendrilSink}; use html5ever::{parse_document, parse_fragment, serialize, Attribute, LocalName, QualName}; use crate::arena_dom::{Arena, Node, NodeData, Ref, Sink}; use crate::css_at_rule::CssAtRule; use crate::css_parser::{parse_css_style_attribute, parse_css_stylesheet, CssRule}; use crate::css_property::CssProperty; use crate::css_token_parser::parse_and_serialize; pub struct Sanitizer<'arena> { arena: typed_arena::Arena>, config: &'arena SanitizerConfig, transformers: Vec<&'arena dyn Fn(Ref<'arena>, Arena<'arena>)>, } #[derive(Debug, Clone)] pub struct SanitizerConfig { pub allow_comments: bool, pub allowed_elements: HashSet, pub allowed_attributes: HashSet, pub allowed_attributes_per_element: HashMap>, pub add_attributes: HashMap, pub add_attributes_per_element: HashMap>, pub allowed_protocols: HashMap>>>, pub allowed_css_at_rules: HashSet, pub allowed_css_properties: HashSet, pub allow_css_comments: bool, pub remove_contents_when_unwrapped: HashSet, } #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub enum Protocol<'a> { Scheme(&'a str), Relative, } impl<'arena> Sanitizer<'arena> { pub fn new( config: &'arena SanitizerConfig, transformers: Vec<&'arena dyn Fn(Ref<'arena>, Arena<'arena>)>, ) -> Sanitizer<'arena> { Sanitizer { arena: typed_arena::Arena::new(), config, transformers, } } pub fn sanitize_fragment( &'arena self, input: &mut impl Read, output: &mut impl Write, ) -> Result<(), Error> { let root = self.parse_fragment(input)?; self.traverse(root); serialize(output, root, Default::default()) } pub fn sanitize_document( &'arena self, input: &mut impl Read, output: &mut impl Write, ) -> Result<(), Error> { let root = self.parse_document(input)?; self.traverse(root); serialize(output, root, Default::default()) } fn parse_document(&'arena self, data: &mut impl Read) -> Result, Error> { let mut bytes = Vec::new(); data.read_to_end(&mut bytes)?; let sink = Sink { arena: &self.arena, document: self.arena.alloc(Node::new(NodeData::Document)), quirks_mode: QuirksMode::NoQuirks, }; Ok(parse_document(sink, Default::default()) .from_utf8() .one(&bytes[..])) } fn parse_fragment(&'arena self, data: &mut impl Read) -> Result, Error> { let mut bytes = Vec::new(); data.read_to_end(&mut bytes)?; let sink = Sink { arena: &self.arena, document: self.arena.alloc(Node::new(NodeData::Document)), quirks_mode: QuirksMode::NoQuirks, }; Ok(parse_fragment( sink, Default::default(), QualName::new(None, ns!(), local_name!("body")), vec![], ) .from_utf8() .one(&bytes[..])) } fn traverse(&'arena self, node: Ref<'arena>) { println!("{}", &node); if self.should_unwrap_node(node) { let sibling = node.next_sibling.get(); println!("unwrapping node"); if self.should_remove_contents_when_unwrapped(node) { println!("detaching node"); node.detach(); println!("post-detach: {}", &node); } else if let Some(unwrapped_node) = node.unwrap() { println!("traversing unwrapped node"); self.traverse(unwrapped_node); } if let Some(sibling) = sibling { println!("traversing sibling"); self.traverse(sibling); } return; } println!("TRANSFORMING: {}", &node); self.remove_attributes(node); self.add_attributes(node); self.sanitize_attribute_protocols(node); self.sanitize_style_tag_css(node); self.sanitize_style_attribute_css(node); // self.serialize_css_test(node); for transformer in self.transformers.iter() { transformer(node, &self.arena); } if let Some(child) = node.first_child.get() { println!("traversing child"); self.traverse(child); } if let Some(sibling) = node.next_sibling.get() { println!("traversing sibling"); self.traverse(sibling); } } fn should_unwrap_node(&self, node: Ref) -> bool { match node.data { NodeData::Document | NodeData::Doctype { .. } | NodeData::Text { .. } | NodeData::ProcessingInstruction { .. } => false, NodeData::Comment { .. } => !self.config.allow_comments, NodeData::Element { ref name, .. } => { !self.config.allowed_elements.contains(&name.local) } } } fn should_remove_contents_when_unwrapped(&self, node: Ref) -> bool { match node.data { NodeData::Document | NodeData::Doctype { .. } | NodeData::Text { .. } | NodeData::ProcessingInstruction { .. } | NodeData::Comment { .. } => false, NodeData::Element { ref name, .. } => self .config .remove_contents_when_unwrapped .contains(&name.local), } } fn remove_attributes(&self, node: Ref<'arena>) { if let NodeData::Element { ref attrs, ref name, .. } = node.data { let attrs = &mut attrs.borrow_mut(); let mut i = 0; let all_allowed = &self.config.allowed_attributes; let per_element_allowed = self.config.allowed_attributes_per_element.get(&name.local); while i != attrs.len() { if !all_allowed.contains(&attrs[i].name.local) { if let Some(per_element_allowed) = per_element_allowed { if per_element_allowed.contains(&attrs[i].name.local) { i += 1; continue; } } attrs.remove(i); continue; } i += 1; } } } fn add_attributes(&self, node: Ref<'arena>) { if let NodeData::Element { ref attrs, ref name, .. } = node.data { let attrs = &mut attrs.borrow_mut(); let add_attributes = &self.config.add_attributes; let add_attributes_per_element = self.config.add_attributes_per_element.get(&name.local); for (name, &value) in add_attributes.iter() { attrs.push(Attribute { name: QualName::new(None, ns!(), name.clone()), value: StrTendril::from(value), }); } if let Some(add_attributes_per_element) = add_attributes_per_element { for (name, &value) in add_attributes_per_element.iter() { attrs.push(Attribute { name: QualName::new(None, ns!(), name.clone()), value: StrTendril::from(value), }); } } } } fn sanitize_attribute_protocols(&self, node: Ref<'arena>) { if let NodeData::Element { ref attrs, ref name, .. } = node.data { let attrs = &mut attrs.borrow_mut(); if let Some(protocols) = self.config.allowed_protocols.get(&name.local) { let mut i = 0; while i != attrs.len() { if let Some(allowed_protocols) = protocols.get(&attrs[i].name.local) { match Url::parse(&attrs[i].value) { Ok(url) => { if !allowed_protocols.contains(&Protocol::Scheme(url.scheme())) { attrs.remove(i); } else { i += 1; } } Err(ParseError::RelativeUrlWithoutBase) => { if !allowed_protocols.contains(&Protocol::Relative) { attrs.remove(i); } else { i += 1; } } Err(_) => { attrs.remove(i); } } } else { i += 1; } } } } } fn serialize_sanitized_css_rules(&self, rules: Vec) -> String { let mut sanitized_css = String::new(); for rule in rules { match rule { CssRule::StyleRule(style_rule) => { sanitized_css += &style_rule.selectors; sanitized_css += "{"; for declaration in style_rule.declarations.into_iter() { let declaration_string = &declaration.to_string(); if self .config .allowed_css_properties .contains(&CssProperty::from(declaration.property)) { sanitized_css += declaration_string; } } sanitized_css += " }"; } CssRule::AtRule(at_rule) => { if self .config .allowed_css_at_rules .contains(&CssAtRule::from(at_rule.name.clone())) { sanitized_css += &format!("@{}", &at_rule.name); sanitized_css += &at_rule.prelude; if let Some(block) = at_rule.block { sanitized_css += "{"; sanitized_css += &self.serialize_sanitized_css_rules(block); sanitized_css += " }"; } else { sanitized_css += "; "; } } } } } sanitized_css } fn sanitize_style_tag_css(&self, node: Ref<'arena>) { if let NodeData::Text { ref contents } = node.data { // TODO: seems rather expensive to lookup the parent on every Text node. Better // solution would be to pass some sort of context from the parent that marks that this // Text node is inside a