diff --git a/src/actors/entry_crawler.rs b/src/actors/entry_crawler.rs
index f810f77..ff0ead5 100644
--- a/src/actors/entry_crawler.rs
+++ b/src/actors/entry_crawler.rs
@@ -1,20 +1,17 @@
 use std::fmt::{self, Display, Formatter};
 use std::fs;
 use std::path::Path;
-use std::sync::Arc;
 
 use bytes::Buf;
-use feed_rs::parser;
 use readability::extractor;
 use reqwest::Client;
 use sqlx::PgPool;
-use tokio::sync::{broadcast, mpsc, Mutex};
+use tokio::sync::{broadcast, mpsc};
 use tracing::{info, instrument};
 use url::Url;
 
-use crate::config::Config;
-use crate::models::entry::{update_entry, CreateEntry, Entry};
-use crate::models::feed::{upsert_feed, CreateFeed, Feed};
+use crate::domain_locks::DomainLocks;
+use crate::models::entry::Entry;
 
 /// The `EntryCrawler` actor fetches an entry url, extracts the content, and saves the content to
 /// the file system and any associated metadata to the database.
@@ -27,6 +24,7 @@ struct EntryCrawler {
     receiver: mpsc::Receiver<EntryCrawlerMessage>,
     pool: PgPool,
     client: Client,
+    domain_locks: DomainLocks,
     content_dir: String,
 }
 
@@ -68,12 +66,14 @@ impl EntryCrawler {
         receiver: mpsc::Receiver<EntryCrawlerMessage>,
         pool: PgPool,
         client: Client,
+        domain_locks: DomainLocks,
         content_dir: String,
     ) -> Self {
         EntryCrawler {
             receiver,
             pool,
             client,
+            domain_locks,
             content_dir,
         }
     }
@@ -84,17 +84,26 @@ impl EntryCrawler {
         let content_dir = Path::new(&self.content_dir);
         let url =
             Url::parse(&entry.url).map_err(|_| EntryCrawlerError::InvalidUrl(entry.url.clone()))?;
+        let domain = url
+            .domain()
+            .ok_or(EntryCrawlerError::InvalidUrl(entry.url.clone()))?;
         let bytes = self
-            .client
-            .get(url.clone())
-            .send()
-            .await
-            .map_err(|_| EntryCrawlerError::FetchError(entry.url.clone()))?
-            .bytes()
-            .await
-            .map_err(|_| EntryCrawlerError::FetchError(entry.url.clone()))?;
+            .domain_locks
+            .run_request(domain, async {
+                self.client
+                    .get(url.clone())
+                    .send()
+                    .await
+                    .map_err(|_| EntryCrawlerError::FetchError(entry.url.clone()))?
+                    .bytes()
+                    .await
+                    .map_err(|_| EntryCrawlerError::FetchError(entry.url.clone()))
+            })
+            .await?;
+        info!("fetched entry");
         let article = extractor::extract(&mut bytes.reader(), &url)
             .map_err(|_| EntryCrawlerError::ExtractError(entry.url.clone()))?;
+        info!("extracted content");
         let id = entry.entry_id;
         // TODO: update entry with scraped data
         // if let Some(date) = article.date {
@@ -109,6 +118,7 @@ impl EntryCrawler {
             .map_err(|_| EntryCrawlerError::SaveContentError(entry.url.clone()))?;
         fs::write(content_dir.join(format!("{}.txt", id)), article.text)
             .map_err(|_| EntryCrawlerError::SaveContentError(entry.url.clone()))?;
+        info!("saved content to filesystem");
 
         Ok(entry)
     }
@@ -153,9 +163,14 @@ pub enum EntryCrawlerHandleMessage {
 
 impl EntryCrawlerHandle {
     /// Creates an async actor task that will listen for messages on the `sender` channel.
-    pub fn new(pool: PgPool, client: Client, content_dir: String) -> Self {
+    pub fn new(
+        pool: PgPool,
+        client: Client,
+        domain_locks: DomainLocks,
+        content_dir: String,
+    ) -> Self {
         let (sender, receiver) = mpsc::channel(8);
-        let mut crawler = EntryCrawler::new(receiver, pool, client, content_dir);
+        let mut crawler = EntryCrawler::new(receiver, pool, client, domain_locks, content_dir);
         tokio::spawn(async move { crawler.run().await });
 
         Self { sender }
diff --git a/src/actors/feed_crawler.rs b/src/actors/feed_crawler.rs
index 610304e..a387969 100644
--- a/src/actors/feed_crawler.rs
+++ b/src/actors/feed_crawler.rs
@@ -10,6 +10,7 @@ use tracing::{info, info_span, instrument};
 use url::Url;
 
 use crate::actors::entry_crawler::EntryCrawlerHandle;
+use crate::domain_locks::DomainLocks;
 use crate::models::entry::{upsert_entries, CreateEntry, Entry};
 use crate::models::feed::{upsert_feed, CreateFeed, Feed};
 
@@ -23,6 +24,7 @@ struct FeedCrawler {
     receiver: mpsc::Receiver<FeedCrawlerMessage>,
     pool: PgPool,
     client: Client,
+    domain_locks: DomainLocks,
     content_dir: String,
 }
 
@@ -46,6 +48,8 @@ impl Display for FeedCrawlerMessage {
 /// across threads (does not reference the originating Errors which are usually not cloneable).
 #[derive(thiserror::Error, Debug, Clone)]
 pub enum FeedCrawlerError {
+    #[error("invalid feed url: {0}")]
+    InvalidUrl(Url),
     #[error("failed to fetch feed: {0}")]
     FetchError(Url),
     #[error("failed to parse feed: {0}")]
@@ -62,27 +66,36 @@ impl FeedCrawler {
         receiver: mpsc::Receiver<FeedCrawlerMessage>,
         pool: PgPool,
         client: Client,
+        domain_locks: DomainLocks,
         content_dir: String,
     ) -> Self {
         FeedCrawler {
             receiver,
             pool,
             client,
+            domain_locks,
             content_dir,
         }
     }
 
     #[instrument(skip_all, fields(url = %url))]
     async fn crawl_feed(&self, url: Url) -> FeedCrawlerResult<Feed> {
+        let domain = url
+            .domain()
+            .ok_or(FeedCrawlerError::InvalidUrl(url.clone()))?;
         let bytes = self
-            .client
-            .get(url.clone())
-            .send()
-            .await
-            .map_err(|_| FeedCrawlerError::FetchError(url.clone()))?
-            .bytes()
-            .await
-            .map_err(|_| FeedCrawlerError::FetchError(url.clone()))?;
+            .domain_locks
+            .run_request(domain, async {
+                self.client
+                    .get(url.clone())
+                    .send()
+                    .await
+                    .map_err(|_| FeedCrawlerError::FetchError(url.clone()))?
+                    .bytes()
+                    .await
+                    .map_err(|_| FeedCrawlerError::FetchError(url.clone()))
+            })
+            .await?;
         info!("fetched feed");
         let parsed_feed =
             parser::parse(&bytes[..]).map_err(|_| FeedCrawlerError::ParseError(url.clone()))?;
@@ -128,6 +141,7 @@ impl FeedCrawler {
             let entry_crawler = EntryCrawlerHandle::new(
                 self.pool.clone(),
                 self.client.clone(),
+                self.domain_locks.clone(),
                 self.content_dir.clone(),
             );
             // TODO: ignoring this receiver for the time being, pipe through events eventually
@@ -179,9 +193,14 @@ pub enum FeedCrawlerHandleMessage {
 
 impl FeedCrawlerHandle {
     /// Creates an async actor task that will listen for messages on the `sender` channel.
-    pub fn new(pool: PgPool, client: Client, content_dir: String) -> Self {
+    pub fn new(
+        pool: PgPool,
+        client: Client,
+        domain_locks: DomainLocks,
+        content_dir: String,
+    ) -> Self {
         let (sender, receiver) = mpsc::channel(8);
-        let mut crawler = FeedCrawler::new(receiver, pool, client, content_dir);
+        let mut crawler = FeedCrawler::new(receiver, pool, client, domain_locks, content_dir);
         tokio::spawn(async move { crawler.run().await });
 
         Self { sender }
diff --git a/src/domain_locks.rs b/src/domain_locks.rs
new file mode 100644
index 0000000..7ca3d31
--- /dev/null
+++ b/src/domain_locks.rs
@@ -0,0 +1,77 @@
+use std::collections::HashMap;
+use std::future::Future;
+use std::sync::Arc;
+
+use tokio::sync::Mutex;
+use tokio::time::{sleep, Duration, Instant};
+use tracing::debug;
+
+pub type DomainLocksMap = Arc<Mutex<HashMap<String, Arc<Mutex<Instant>>>>>;
+
+// TODO: make this configurable per domain and then load into a cache at startup
+// bonus points if I also make it changeable at runtime, for example, if a domain returns a 429,
+// then I can increase it and make sure it is saved back to the configuration for the next startup.
+pub const DOMAIN_LOCK_DURATION: Duration = Duration::from_secs(1);
+
+#[derive(Debug, Clone)]
+pub struct DomainLocks {
+    map: DomainLocksMap,
+}
+
+/// A mechanism to serialize multiple async tasks requesting a single domain. To prevent
+/// overloading servers with too many requests run in parallel at once, crawlnicle will only
+/// request a domain once a second. All async tasks that wish to scrape a feed or entry must use
+/// the `run_request` method on this struct to wait their turn.
+///
+/// Contains a map of domain names to a lock containing the timestamp of the last request to that
+/// domain.
+impl DomainLocks {
+    pub fn new() -> Self {
+        Self {
+            map: Arc::new(Mutex::new(HashMap::new())),
+        }
+    }
+
+    /// Run the passed function `f` while holding a lock that gives exclusive access to the passed
+    /// domain. If another task running `run_request` currently has the lock to the
+    /// `DomainLocksMap` or the lock to the domain passed, then this function will wait until that
+    /// other task is done. Once it has access to the lock, if it has been less than one second
+    /// since the last request to the domain, then this function will sleep until one second has
+    /// passed before calling `f`.
+    pub async fn run_request<F, T>(&self, domain: &str, f: F) -> T
+    where
+        F: Future<Output = T>,
+    {
+        let domain_last_request = {
+            let mut map = self.map.lock().await;
+            map.entry(domain.to_owned())
+                .or_insert_with(|| Arc::new(Mutex::new(Instant::now() - DOMAIN_LOCK_DURATION)))
+                .clone()
+        };
+
+        let mut domain_last_request = domain_last_request.lock().await;
+
+        let elapsed = domain_last_request.elapsed();
+        if elapsed < DOMAIN_LOCK_DURATION {
+            let sleep_duration = DOMAIN_LOCK_DURATION - elapsed;
+            debug!(
+                domain,
+                duration = format!("{} ms", sleep_duration.as_millis()),
+                "sleeping before requesting domain",
+            );
+            sleep(DOMAIN_LOCK_DURATION - elapsed).await;
+        }
+
+        let result = f.await;
+
+        *domain_last_request = Instant::now(); // Update the time of the last request.
+
+        result
+    }
+}
+
+impl Default for DomainLocks {
+    fn default() -> Self {
+        Self::new()
+    }
+}
diff --git a/src/handlers/feed.rs b/src/handlers/feed.rs
index 85d9a74..753009c 100644
--- a/src/handlers/feed.rs
+++ b/src/handlers/feed.rs
@@ -18,6 +18,7 @@ use url::Url;
 
 use crate::actors::feed_crawler::{FeedCrawlerHandle, FeedCrawlerHandleMessage};
 use crate::config::Config;
+use crate::domain_locks::DomainLocks;
 use crate::error::{Error, Result};
 use crate::models::entry::get_entries_for_feed;
 use crate::models::feed::{create_feed, delete_feed, get_feed, CreateFeed, FeedType};
@@ -109,13 +110,18 @@ impl IntoResponse for AddFeedError {
 pub async fn post(
     State(pool): State<PgPool>,
     State(crawls): State<Crawls>,
+    State(domain_locks): State<DomainLocks>,
     State(config): State<Config>,
     Form(add_feed): Form,
 ) -> AddFeedResult {
     // TODO: store the client in axum state (as long as it can be used concurrently?)
     let client = Client::new();
-    let feed_crawler =
-        FeedCrawlerHandle::new(pool.clone(), client.clone(), config.content_dir.clone());
+    let feed_crawler = FeedCrawlerHandle::new(
+        pool.clone(),
+        client.clone(),
+        domain_locks.clone(),
+        config.content_dir.clone(),
+    );
 
     let feed = create_feed(
         &pool,
diff --git a/src/lib.rs b/src/lib.rs
index 7de2322..99192bc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,6 @@
 pub mod actors;
 pub mod config;
+pub mod domain_locks;
 pub mod error;
 pub mod handlers;
 pub mod jobs;
diff --git a/src/main.rs b/src/main.rs
index c6a2b4f..6ed3eff 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -22,6 +22,7 @@ use tower_livereload::LiveReloadLayer;
 use tracing::debug;
 
 use lib::config::Config;
+use lib::domain_locks::DomainLocks;
 use lib::handlers;
 use lib::log::init_tracing;
 use lib::state::AppState;
@@ -44,6 +45,7 @@ async fn main() -> Result<()> {
     let _guards = init_tracing(&config, log_sender)?;
 
     let crawls = Arc::new(Mutex::new(HashMap::new()));
+    let domain_locks = DomainLocks::new();
 
     let pool = PgPoolOptions::new()
         .max_connections(config.database_max_connections)
@@ -75,6 +77,7 @@ async fn main() -> Result<()> {
             config,
             log_receiver,
             crawls,
+            domain_locks,
         })
         .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
 
diff --git a/src/state.rs b/src/state.rs
index b4db481..23e2aff 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -10,6 +10,7 @@ use uuid::Uuid;
 
 use crate::actors::feed_crawler::FeedCrawlerHandleMessage;
 use crate::config::Config;
+use crate::domain_locks::DomainLocks;
 
 /// A map of feed IDs to a channel receiver for the active `FeedCrawler` running a crawl for that
 /// feed.
@@ -29,6 +30,7 @@ pub struct AppState {
     pub config: Config,
     pub log_receiver: watch::Receiver,
     pub crawls: Crawls,
+    pub domain_locks: DomainLocks,
 }
 
 impl FromRef<AppState> for PgPool {
@@ -54,3 +56,9 @@ impl FromRef<AppState> for Crawls {
         state.crawls.clone()
     }
 }
+
+impl FromRef<AppState> for DomainLocks {
+    fn from_ref(state: &AppState) -> Self {
+        state.domain_locks.clone()
+    }
+}
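A note on usage: the sketch below shows how `DomainLocks::run_request` serializes requests to one domain while leaving other domains unaffected. It is a minimal, standalone example rather than part of the diff; the domain strings, the `tokio::main` harness, and the simulated request bodies are illustrative assumptions, and it reuses the crate path `lib::domain_locks` the way `src/main.rs` does.

use std::time::Duration;

use lib::domain_locks::DomainLocks;

#[tokio::main]
async fn main() {
    let domain_locks = DomainLocks::new();

    // Two tasks hit the same (hypothetical) domain. Whichever acquires the
    // per-domain lock second is delayed until DOMAIN_LOCK_DURATION has passed
    // since the first request finished.
    let first = {
        let locks = domain_locks.clone();
        tokio::spawn(async move {
            locks
                .run_request("example.com", async {
                    // Stand-in for a real reqwest call.
                    tokio::time::sleep(Duration::from_millis(50)).await;
                    "first"
                })
                .await
        })
    };
    let second = {
        let locks = domain_locks.clone();
        tokio::spawn(
            async move { locks.run_request("example.com", async { "second" }).await },
        )
    };

    // A different domain only contends briefly on the outer map lock, so it is
    // not throttled by the requests to example.com above.
    let other = domain_locks
        .run_request("example.org", async { "other" })
        .await;

    println!(
        "{} {} {}",
        first.await.unwrap(),
        second.await.unwrap(),
        other
    );
}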