From f13c7e5e702beae209858ed6d2fa82de41472ce8 Mon Sep 17 00:00:00 2001
From: Tyler Hallada
Date: Sun, 9 Jul 2023 21:18:19 -0400
Subject: [PATCH] Add an async actor FeedCrawler for fetching feed details

Currently, this allows the browser to subscribe to the response of the
asynchronous crawl after the user adds a new feed. Eventually I will also use
this in the main scheduled crawls. Right now, it only upserts feed metadata
based on the parsed feed.
---
 Cargo.toml                 |   2 +-
 src/actors/feed_crawler.rs | 160 +++++++++++++++++++++++++++++++++++++
 src/actors/mod.rs          |   1 +
 src/error.rs               |   9 ++-
 src/handlers/feed.rs       | 148 ++++++++++++++++++++++++++++------
 src/handlers/feeds.rs      |   2 +-
 src/handlers/log.rs        |   3 +-
 src/lib.rs                 |   1 +
 src/main.rs                |  18 +++--
 src/models/entry.rs        |  18 ++---
 src/models/feed.rs         |  43 +++++++++-
 src/partials/feed_link.rs  |  10 ++-
 src/state.rs               |  30 ++++++-
 src/uuid.rs                |  18 ++++-
 14 files changed, 405 insertions(+), 58 deletions(-)
 create mode 100644 src/actors/feed_crawler.rs
 create mode 100644 src/actors/mod.rs

diff --git a/Cargo.toml b/Cargo.toml
index 5fd6b73..8fcab9f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -41,7 +41,7 @@ tokio-stream = { version = "0.1", features = ["sync"] }
 tower = "0.4"
 tower-livereload = "0.8"
 tower-http = { version = "0.4", features = ["trace", "fs"] }
-tracing = { version = "0.1", features = ["valuable"] }
+tracing = { version = "0.1", features = ["valuable", "attributes"] }
 tracing-appender = "0.2"
 tracing-subscriber = { version = "0.3", features = ["env-filter"] }
 uuid = { version = "1.3", features = ["serde"] }
diff --git a/src/actors/feed_crawler.rs b/src/actors/feed_crawler.rs
new file mode 100644
index 0000000..e87e949
--- /dev/null
+++ b/src/actors/feed_crawler.rs
@@ -0,0 +1,160 @@
+use std::fmt::{self, Display, Formatter};
+
+use feed_rs::parser;
+use reqwest::Client;
+use sqlx::PgPool;
+use tokio::sync::{broadcast, mpsc};
+use tracing::{info, instrument};
+use url::Url;
+
+use crate::models::entry::Entry;
+use crate::models::feed::{upsert_feed, CreateFeed, Feed};
+
+/// The `FeedCrawler` actor fetches a feed url, parses it, and saves it to the database.
+///
+/// It receives `FeedCrawlerMessage` messages via the `receiver` channel. It communicates back to
+/// the sender of those messages via the `respond_to` channel on the `FeedCrawlerMessage`.
+///
+/// `FeedCrawler` should not be instantiated directly. Instead, use the `FeedCrawlerHandle`.
+struct FeedCrawler {
+    receiver: mpsc::Receiver<FeedCrawlerMessage>,
+    pool: PgPool,
+    client: Client,
+}
+
+#[derive(Debug)]
+enum FeedCrawlerMessage {
+    Crawl {
+        url: Url,
+        respond_to: broadcast::Sender<FeedCrawlerHandleMessage>,
+    },
+}
+
+impl Display for FeedCrawlerMessage {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        match self {
+            FeedCrawlerMessage::Crawl { url, .. } => write!(f, "Crawl({})", url),
+        }
+    }
+}
+
+/// An error type that enumerates possible failures during a crawl. It is cloneable and can be
+/// sent across threads (it does not reference the originating errors, which are usually not cloneable).
+#[derive(thiserror::Error, Debug, Clone)]
+pub enum FeedCrawlerError {
+    #[error("failed to fetch feed: {0}")]
+    FetchError(Url),
+    #[error("failed to parse feed: {0}")]
+    ParseError(Url),
+    #[error("failed to create feed: {0}")]
+    CreateFeedError(Url),
+}
+pub type FeedCrawlerResult<T, E = FeedCrawlerError> = ::std::result::Result<T, E>;
+
+impl FeedCrawler {
+    fn new(receiver: mpsc::Receiver<FeedCrawlerMessage>, pool: PgPool, client: Client) -> Self {
+        FeedCrawler {
+            receiver,
+            pool,
+            client,
+        }
+    }
+
+    #[instrument(skip_all, fields(url = %url))]
+    async fn crawl_feed(&self, url: Url) -> FeedCrawlerResult<Feed> {
+        let bytes = self
+            .client
+            .get(url.clone())
+            .send()
+            .await
+            .map_err(|_| FeedCrawlerError::FetchError(url.clone()))?
+            .bytes()
+            .await
+            .map_err(|_| FeedCrawlerError::FetchError(url.clone()))?;
+        info!("fetched feed");
+        let parsed_feed =
+            parser::parse(&bytes[..]).map_err(|_| FeedCrawlerError::ParseError(url.clone()))?;
+        info!("parsed feed");
+        let feed = upsert_feed(
+            &self.pool,
+            CreateFeed {
+                title: parsed_feed.title.map(|text| text.content),
+                url: url.to_string(),
+                feed_type: parsed_feed.feed_type.into(),
+                description: parsed_feed.description.map(|text| text.content),
+            },
+        )
+        .await
+        .map_err(|_| FeedCrawlerError::CreateFeedError(url.clone()))?;
+        info!(%feed.feed_id, "upserted feed");
+        Ok(feed)
+    }
+
+    #[instrument(skip_all, fields(msg = %msg))]
+    async fn handle_message(&mut self, msg: FeedCrawlerMessage) {
+        match msg {
+            FeedCrawlerMessage::Crawl { url, respond_to } => {
+                let result = self.crawl_feed(url).await;
+                // ignore the result since the initiator may have cancelled waiting for the
+                // response, and that is ok
+                let _ = respond_to.send(FeedCrawlerHandleMessage::Feed(result));
+            }
+        }
+    }
+
+    #[instrument(skip_all)]
+    async fn run(&mut self) {
+        info!("starting feed crawler");
+        while let Some(msg) = self.receiver.recv().await {
+            self.handle_message(msg).await;
+        }
+    }
+}
+
+/// The `FeedCrawlerHandle` is used to initialize and communicate with a `FeedCrawler` actor.
+///
+/// The `FeedCrawler` actor fetches a feed url, parses it, and saves it to the database. It runs as
+/// a separate asynchronous task from the main web server and communicates via channels.
+#[derive(Clone)]
+pub struct FeedCrawlerHandle {
+    sender: mpsc::Sender<FeedCrawlerMessage>,
+}
+
+/// The `FeedCrawlerHandleMessage` is the response to a `FeedCrawlerMessage` sent to the
+/// `FeedCrawlerHandle`.
+///
+/// `FeedCrawlerHandleMessage::Feed` contains the result of crawling a feed url.
+/// `FeedCrawlerHandleMessage::Entry` contains the result of crawling an entry url.
+#[derive(Clone)]
+pub enum FeedCrawlerHandleMessage {
+    Feed(FeedCrawlerResult<Feed>),
+    Entry(FeedCrawlerResult<Entry>),
+}
+
+impl FeedCrawlerHandle {
+    /// Creates an async actor task that will listen for messages on the `sender` channel.
+    pub fn new(pool: PgPool, client: Client) -> Self {
+        let (sender, receiver) = mpsc::channel(8);
+        let mut crawler = FeedCrawler::new(receiver, pool, client);
+        tokio::spawn(async move { crawler.run().await });
+
+        Self { sender }
+    }
+
+    /// Sends a `FeedCrawlerMessage::Crawl` message to the running `FeedCrawler` actor.
+    ///
+    /// Listen to the result of the crawl via the returned `broadcast::Receiver`.
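+    ///
+    /// A rough usage sketch (not part of the original patch; it assumes a `PgPool`, a
+    /// `reqwest::Client`, and an already-parsed `Url` are in scope):
+    ///
+    /// ```ignore
+    /// let handle = FeedCrawlerHandle::new(pool, client);
+    /// let mut receiver = handle.crawl(url).await;
+    /// while let Ok(msg) = receiver.recv().await {
+    ///     // each msg is a FeedCrawlerHandleMessage carrying the crawl result
+    /// }
+    /// ```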
+    pub async fn crawl(&self, url: Url) -> broadcast::Receiver<FeedCrawlerHandleMessage> {
+        let (sender, receiver) = broadcast::channel(8);
+        let msg = FeedCrawlerMessage::Crawl {
+            url,
+            respond_to: sender,
+        };
+
+        self.sender
+            .send(msg)
+            .await
+            .expect("feed crawler task has died");
+        receiver
+    }
+}
diff --git a/src/actors/mod.rs b/src/actors/mod.rs
new file mode 100644
index 0000000..6a8a931
--- /dev/null
+++ b/src/actors/mod.rs
@@ -0,0 +1 @@
+pub mod feed_crawler;
diff --git a/src/error.rs b/src/error.rs
index fd1dec7..e925442 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -1,8 +1,8 @@
 use axum::http::StatusCode;
 use axum::response::{IntoResponse, Response};
 use axum::Json;
-use tracing::error;
 use serde_with::DisplayFromStr;
+use tracing::error;
 use uuid::Uuid;
 use validator::ValidationErrors;
 
@@ -31,6 +31,9 @@ pub enum Error {
 
     #[error("referenced {0} not found")]
     RelationNotFound(&'static str),
+
+    #[error("an internal server error occurred")]
+    InternalServerError,
 }
 
 pub type Result<T, E = Error> = ::std::result::Result<T, E>;
@@ -72,7 +75,9 @@ impl Error {
 
         match self {
             NotFound(_, _) => StatusCode::NOT_FOUND,
-            Sqlx(_) | Anyhow(_) | Reqwest(_) => StatusCode::INTERNAL_SERVER_ERROR,
+            InternalServerError | Sqlx(_) | Anyhow(_) | Reqwest(_) => {
+                StatusCode::INTERNAL_SERVER_ERROR
+            }
             InvalidEntity(_) | RelationNotFound(_) => StatusCode::UNPROCESSABLE_ENTITY,
         }
     }
diff --git a/src/handlers/feed.rs b/src/handlers/feed.rs
index 74a6766..02ff8f2 100644
--- a/src/handlers/feed.rs
+++ b/src/handlers/feed.rs
@@ -1,6 +1,9 @@
+use std::time::Duration;
+
 use axum::extract::{Path, State};
 use axum::http::StatusCode;
-use axum::response::{IntoResponse, Response, Redirect};
+use axum::response::sse::{Event, KeepAlive};
+use axum::response::{IntoResponse, Redirect, Response, Sse};
 use axum::Form;
 use feed_rs::parser;
 use maud::html;
@@ -8,13 +11,19 @@ use reqwest::Client;
 use serde::Deserialize;
 use serde_with::{serde_as, NoneAsEmptyString};
 use sqlx::PgPool;
+use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
+use tokio_stream::wrappers::BroadcastStream;
+use tokio_stream::StreamExt;
+use url::Url;
 
+use crate::actors::feed_crawler::{FeedCrawlerHandle, FeedCrawlerHandleMessage};
 use crate::error::{Error, Result};
 use crate::models::entry::get_entries_for_feed;
-use crate::models::feed::{create_feed, get_feed, CreateFeed, delete_feed};
+use crate::models::feed::{create_feed, delete_feed, get_feed, CreateFeed, FeedType};
 use crate::partials::{entry_list::entry_list, feed_link::feed_link, layout::Layout};
-use crate::uuid::Base62Uuid;
+use crate::state::Crawls;
 use crate::turbo_stream::TurboStream;
+use crate::uuid::Base62Uuid;
 pub async fn get(
     Path(id): Path<Base62Uuid>,
     State(pool): State<PgPool>,
@@ -51,12 +60,16 @@ pub struct AddFeed {
 
 #[derive(thiserror::Error, Debug)]
 pub enum AddFeedError {
+    #[error("invalid feed url: {0}")]
+    InvalidUrl(String, #[source] url::ParseError),
     #[error("failed to fetch feed: {0}")]
     FetchError(String, #[source] reqwest::Error),
     #[error("failed to parse feed: {0}")]
     ParseError(String, #[source] parser::ParseFeedError),
     #[error("failed to create feed: {0}")]
     CreateFeedError(String, #[source] Error),
+    #[error("feed already exists: {0}")]
+    FeedAlreadyExists(String, #[source] Error),
 }
 pub type AddFeedResult<T, E = AddFeedError> = ::std::result::Result<T, E>;
 
@@ -65,7 +78,9 @@ impl AddFeedError {
     use AddFeedError::*;
 
         match self {
-            FetchError(..) | ParseError(..) => StatusCode::UNPROCESSABLE_ENTITY,
+            InvalidUrl(..) | FetchError(..) | ParseError(..) | FeedAlreadyExists(..) => {
+                StatusCode::UNPROCESSABLE_ENTITY
+            }
             CreateFeedError(..) => StatusCode::INTERNAL_SERVER_ERROR,
         }
     }
@@ -92,41 +107,56 @@ impl IntoResponse for AddFeedError {
 
 pub async fn post(
     State(pool): State<PgPool>,
+    State(crawls): State<Crawls>,
     Form(add_feed): Form<AddFeed>,
 ) -> AddFeedResult<Response> {
+    // TODO: store the client in axum state (as long as it can be used concurrently?)
     let client = Client::new();
-    let bytes = client
-        .get(&add_feed.url)
-        .send()
-        .await
-        .map_err(|err| AddFeedError::FetchError(add_feed.url.clone(), err))?
-        .bytes()
-        .await
-        .map_err(|err| AddFeedError::FetchError(add_feed.url.clone(), err))?;
-    let parsed_feed = parser::parse(&bytes[..])
-        .map_err(|err| AddFeedError::ParseError(add_feed.url.clone(), err))?;
+    let feed_crawler = FeedCrawlerHandle::new(pool.clone(), client.clone());
+
     let feed = create_feed(
         &pool,
         CreateFeed {
-            title: add_feed
-                .title
-                .map_or_else(|| parsed_feed.title.map(|text| text.content), Some),
+            title: add_feed.title,
             url: add_feed.url.clone(),
-            feed_type: parsed_feed.feed_type.into(),
-            description: add_feed
-                .description
-                .map_or_else(|| parsed_feed.description.map(|text| text.content), Some),
+            feed_type: FeedType::Rss, // eh, get rid of this
+            description: add_feed.description,
         },
     )
     .await
-    .map_err(|err| AddFeedError::CreateFeedError(add_feed.url.clone(), err))?;
+    .map_err(|err| {
+        if let Error::Sqlx(sqlx::error::Error::Database(db_error)) = &err {
+            if let Some(code) = db_error.code() {
+                if let Some(constraint) = db_error.constraint() {
+                    if code == "23505" && constraint == "feed_url_idx" {
+                        return AddFeedError::FeedAlreadyExists(add_feed.url.clone(), err);
+                    }
+                }
+            }
+        }
+        AddFeedError::CreateFeedError(add_feed.url.clone(), err)
+    })?;
+
+    let url: Url = Url::parse(&add_feed.url)
+        .map_err(|err| AddFeedError::InvalidUrl(add_feed.url.clone(), err))?;
+    let receiver = feed_crawler.crawl(url).await;
+    {
+        let mut crawls = crawls.lock().map_err(|_| {
+            AddFeedError::CreateFeedError(add_feed.url.clone(), Error::InternalServerError)
+        })?;
+        crawls.insert(feed.feed_id, receiver);
+    }
+
+    let feed_id = format!("feed-{}", Base62Uuid::from(feed.feed_id));
+    let feed_stream = format!("/feed/{}/stream", Base62Uuid::from(feed.feed_id));
     Ok((
         StatusCode::CREATED,
         TurboStream(
             html! {
+                turbo-stream-source src=(feed_stream) id="feed-stream" {}
                 turbo-stream action="append" target="feeds" {
                     template {
-                        li { (feed_link(&feed)) }
+                        li id=(feed_id) { (feed_link(&feed, true)) }
                     }
                 }
             }
@@ -136,10 +166,76 @@ pub async fn post(
         .into_response())
 }
 
-pub async fn delete(
-    State(pool): State<PgPool>,
+pub async fn stream(
     Path(id): Path<Base62Uuid>,
-) -> Result<Redirect> {
+    State(crawls): State<Crawls>,
+) -> Result<impl IntoResponse> {
+    let receiver = {
+        let mut crawls = crawls.lock().expect("crawls lock poisoned");
+        crawls.remove(&id.as_uuid())
+    }
+    .ok_or_else(|| Error::NotFound("feed stream", id.as_uuid()))?;
+
+    let stream = BroadcastStream::new(receiver);
+    let feed_id = format!("feed-{}", id);
+    let stream = stream.map(move |msg| match msg {
+        Ok(FeedCrawlerHandleMessage::Feed(Ok(feed))) => Ok::<Event, String>(
+            Event::default().data(
+                html! {
+                    turbo-stream action="remove" target="feed-stream" {}
+                    turbo-stream action="replace" target=(feed_id) {
+                        template {
+                            li id=(feed_id) { (feed_link(&feed, false)) }
+                        }
+                    }
+                }
+                .into_string(),
+            ),
+        ),
+        Ok(FeedCrawlerHandleMessage::Feed(Err(error))) => Ok(Event::default().data(
+            html! {
+                turbo-stream action="remove" target="feed-stream" {}
+                turbo-stream action="replace" target=(feed_id) {
+                    template {
+                        li id=(feed_id) { span class="error" { (error) } }
+                    }
+                }
+            }
+            .into_string(),
+        )),
+        // TODO: these Entry messages are not yet sent, need to handle them better
+        Ok(FeedCrawlerHandleMessage::Entry(Ok(_))) => Ok(Event::default().data(
+            html! {
+                turbo-stream action="remove" target="feed-stream" {}
+                turbo-stream action="replace" target=(feed_id) {
+                    template {
+                        li id=(feed_id) { "fetched entry" }
+                    }
+                }
+            }
+            .into_string(),
+        )),
+        Ok(FeedCrawlerHandleMessage::Entry(Err(error))) => Ok(Event::default().data(
+            html! {
+                turbo-stream action="remove" target="feed-stream" {}
+                turbo-stream action="replace" target=(feed_id) {
+                    template {
+                        li id=(feed_id) { span class="error" { (error) } }
+                    }
+                }
+            }
+            .into_string(),
+        )),
+        Err(BroadcastStreamRecvError::Lagged(_)) => Ok(Event::default()),
+    });
+    Ok(Sse::new(stream).keep_alive(
+        KeepAlive::new()
+            .interval(Duration::from_secs(15))
+            .text("keep-alive-text"),
+    ))
+}
+
+pub async fn delete(State(pool): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<Redirect> {
     delete_feed(&pool, id.as_uuid()).await?;
     Ok(Redirect::to("/feeds"))
 }
diff --git a/src/handlers/feeds.rs b/src/handlers/feeds.rs
index 23b9d53..133c314 100644
--- a/src/handlers/feeds.rs
+++ b/src/handlers/feeds.rs
@@ -14,7 +14,7 @@ pub async fn get(State(pool): State<PgPool>, layout: Layout) -> Result<Response> {
         div class="feeds" {
             ul id="feeds" {
                 @for feed in feeds {
-                    li { (feed_link(&feed)) }
+                    li { (feed_link(&feed, false)) }
                 }
             }
             div class="add-feed" {
diff --git a/src/handlers/log.rs b/src/handlers/log.rs
index c16442a..4aed645 100644
--- a/src/handlers/log.rs
+++ b/src/handlers/log.rs
@@ -4,6 +4,7 @@ use std::time::Duration;
 
 use ansi_to_html::convert_escaped;
 use axum::extract::State;
+use axum::response::sse::KeepAlive;
 use axum::response::{
     sse::{Event, Sse},
     Response,
@@ -44,7 +45,7 @@ pub async fn stream(
         ))
     });
     Sse::new(log_stream).keep_alive(
-        axum::response::sse::KeepAlive::new()
+        KeepAlive::new()
            .interval(Duration::from_secs(15))
            .text("keep-alive-text"),
     )
diff --git a/src/lib.rs b/src/lib.rs
index dc8405b..7de2322 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,4 @@
+pub mod actors;
 pub mod config;
 pub mod error;
 pub mod handlers;
diff --git a/src/main.rs b/src/main.rs
index 124cf1e..c6a2b4f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,9 @@
-use std::{path::Path, net::SocketAddr};
+use std::{
+    collections::HashMap,
+    net::SocketAddr,
+    path::Path,
+    sync::{Arc, Mutex},
+};
 
 use anyhow::Result;
 use axum::{
@@ -12,7 +17,7 @@ use notify::Watcher;
 use sqlx::postgres::PgPoolOptions;
 use tokio::sync::watch::channel;
 use tower::ServiceBuilder;
-use tower_http::{trace::TraceLayer, services::ServeDir};
+use tower_http::{services::ServeDir, trace::TraceLayer};
 use tower_livereload::LiveReloadLayer;
 use tracing::debug;
 
@@ -38,6 +43,8 @@ async fn main() -> Result<()> {
     let (log_sender, log_receiver) = channel::<Bytes>(Bytes::new());
     let _guards = init_tracing(&config, log_sender)?;
 
+    let crawls = Arc::new(Mutex::new(HashMap::new()));
+
     let pool = PgPoolOptions::new()
         .max_connections(config.database_max_connections)
         .connect(&config.database_url)
@@ -57,6 +64,7 @@ async fn main() -> Result<()> {
         .route("/feeds", get(handlers::feeds::get))
         .route("/feed", post(handlers::feed::post))
         .route("/feed/:id", get(handlers::feed::get))
+        .route("/feed/:id/stream", get(handlers::feed::stream))
         .route("/feed/:id/delete", post(handlers::feed::delete))
         .route("/entry/:id", get(handlers::entry::get))
         .route("/log", get(handlers::log::get))
@@ -66,6 +74,7 @@ async fn main() -> Result<()> {
             pool,
             config,
             log_receiver,
+            crawls,
         })
         .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
 
@@ -74,10 +83,7 @@ async fn main() -> Result<()> {
         let livereload = LiveReloadLayer::new();
         let reloader = livereload.reloader();
         let mut watcher = notify::recommended_watcher(move |_| reloader.reload())?;
-        watcher.watch(
-            Path::new("static"),
-            notify::RecursiveMode::Recursive,
-        )?;
+        watcher.watch(Path::new("static"), notify::RecursiveMode::Recursive)?;
         app = app.layer(livereload);
         serve(app, addr).await?;
     } else {
diff --git a/src/models/entry.rs b/src/models/entry.rs
index 90015d5..e0740b8 100644
--- a/src/models/entry.rs
+++ b/src/models/entry.rs
@@ -8,7 +8,7 @@ use crate::error::{Error, Result};
 
 const DEFAULT_ENTRIES_PAGE_SIZE: i64 = 50;
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct Entry {
     pub entry_id: Uuid,
     pub title: Option<String>,
@@ -51,10 +51,7 @@ pub struct GetEntriesOptions {
     pub limit: Option<i64>,
 }
 
-pub async fn get_entries(
-    pool: &PgPool,
-    options: GetEntriesOptions,
-) -> sqlx::Result<Vec<Entry>> {
+pub async fn get_entries(pool: &PgPool, options: GetEntriesOptions) -> sqlx::Result<Vec<Entry>> {
     if let Some(published_before) = options.published_before {
         sqlx::query_as!(
             Entry,
@@ -81,7 +78,6 @@ pub async fn get_entries(
         )
         .fetch_all(pool)
         .await
-
     }
 }
 
@@ -120,7 +116,6 @@ pub async fn get_entries_for_feed(
         )
         .fetch_all(pool)
         .await
-
     }
 }
 
@@ -266,8 +261,11 @@ pub async fn update_entry(pool: &PgPool, payload: Entry) -> Result<Entry> {
 }
 
 pub async fn delete_entry(pool: &PgPool, entry_id: Uuid) -> Result<()> {
-    sqlx::query!("update entry set deleted_at = now() where entry_id = $1", entry_id)
-        .execute(pool)
-        .await?;
+    sqlx::query!(
+        "update entry set deleted_at = now() where entry_id = $1",
+        entry_id
+    )
+    .execute(pool)
+    .await?;
     Ok(())
 }
diff --git a/src/models/feed.rs b/src/models/feed.rs
index 4772909..812979b 100644
--- a/src/models/feed.rs
+++ b/src/models/feed.rs
@@ -37,7 +37,7 @@ impl From<feed_rs::model::FeedType> for FeedType {
     }
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct Feed {
     pub feed_id: Uuid,
     pub title: Option<String>,
@@ -135,9 +135,44 @@ pub async fn create_feed(pool: &PgPool, payload: CreateFeed) -> Result<Feed> {
     .await?)
 }
 
+pub async fn upsert_feed(pool: &PgPool, payload: CreateFeed) -> Result<Feed> {
+    payload.validate()?;
+    Ok(sqlx::query_as!(
+        Feed,
+        r#"insert into feed (
+            title, url, type, description
+        ) values (
+            $1, $2, $3, $4
+        ) on conflict (url) do update set
+            title = excluded.title,
+            url = excluded.url,
+            type = excluded.type,
+            description = excluded.description
+        returning
+            feed_id,
+            title,
+            url,
+            type as "feed_type: FeedType",
+            description,
+            created_at,
+            updated_at,
+            deleted_at
+        "#,
+        payload.title,
+        payload.url,
+        payload.feed_type as FeedType,
+        payload.description
+    )
+    .fetch_one(pool)
+    .await?)
+}
+
 pub async fn delete_feed(pool: &PgPool, feed_id: Uuid) -> Result<()> {
-    sqlx::query!("update feed set deleted_at = now() where feed_id = $1", feed_id)
-        .execute(pool)
-        .await?;
+    sqlx::query!(
+        "update feed set deleted_at = now() where feed_id = $1",
+        feed_id
+    )
+    .execute(pool)
+    .await?;
     Ok(())
 }
diff --git a/src/partials/feed_link.rs b/src/partials/feed_link.rs
index dab1b86..b24cfe8 100644
--- a/src/partials/feed_link.rs
+++ b/src/partials/feed_link.rs
@@ -3,8 +3,14 @@ use maud::{html, Markup};
 use crate::models::feed::Feed;
 use crate::uuid::Base62Uuid;
 
-pub fn feed_link(feed: &Feed) -> Markup {
-    let title = feed.title.clone().unwrap_or_else(|| "Untitled Feed".to_string());
+pub fn feed_link(feed: &Feed, pending_crawl: bool) -> Markup {
+    let title = feed.title.clone().unwrap_or_else(|| {
+        if pending_crawl {
+            "Crawling feed...".to_string()
+        } else {
+            "Untitled Feed".to_string()
+        }
+    });
     let feed_url = format!("/feed/{}", Base62Uuid::from(feed.feed_id));
     html! {
         a href=(feed_url) { (title) }
diff --git a/src/state.rs b/src/state.rs
index 7cc5d94..b4db481 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -1,16 +1,34 @@
-use tokio::sync::watch::Receiver;
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+
+use tokio::sync::{broadcast, watch};
 
 use axum::extract::FromRef;
 use bytes::Bytes;
 use sqlx::PgPool;
+use uuid::Uuid;
 
+use crate::actors::feed_crawler::FeedCrawlerHandleMessage;
 use crate::config::Config;
 
+/// A map of feed IDs to a channel receiver for the active `FeedCrawler` running a crawl for that
+/// feed.
+///
+/// Currently, the only purpose of this is to keep track of active crawls so that axum handlers can
+/// subscribe to the result of the crawl via the receiver channel, which is then sent to end-users
+/// as a stream of server-sent events.
+///
+/// This map should only contain crawls that have just been created but not yet subscribed to.
+/// Entries are only added when a user adds a feed in the UI and entries are removed by the same
+/// user once a server-sent event connection is established.
+pub type Crawls = Arc<Mutex<HashMap<Uuid, broadcast::Receiver<FeedCrawlerHandleMessage>>>>;
+
 #[derive(Clone)]
 pub struct AppState {
     pub pool: PgPool,
     pub config: Config,
-    pub log_receiver: Receiver<Bytes>,
+    pub log_receiver: watch::Receiver<Bytes>,
+    pub crawls: Crawls,
 }
 
 impl FromRef<AppState> for PgPool {
@@ -25,8 +43,14 @@ impl FromRef<AppState> for Config {
     }
 }
 
-impl FromRef<AppState> for Receiver<Bytes> {
+impl FromRef<AppState> for watch::Receiver<Bytes> {
     fn from_ref(state: &AppState) -> Self {
         state.log_receiver.clone()
     }
 }
+
+impl FromRef<AppState> for Crawls {
+    fn from_ref(state: &AppState) -> Self {
+        state.crawls.clone()
+    }
+}
diff --git a/src/uuid.rs b/src/uuid.rs
index 91f1462..8453458 100644
--- a/src/uuid.rs
+++ b/src/uuid.rs
@@ -1,21 +1,35 @@
-use std::fmt::{Display, Formatter, self};
+use std::fmt::{self, Display, Formatter};
 
 use serde::{Deserialize, Serialize};
 use uuid::Uuid;
 
 const BASE62_CHARS: &[u8] = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
 
+/// A wrapper around a UUID (from `uuid::Uuid`) that serializes to a Base62 string.
+///
+/// Database rows have a UUID primary key, but they are encoded in Base62 to be shorter and more
+/// URL-friendly for the frontend.
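+///
+/// A rough usage sketch (not part of the original patch; it relies on the `From<Uuid>` and
+/// `Display` impls defined in this module):
+///
+/// ```ignore
+/// let id = Base62Uuid::from(Uuid::new_v4());
+/// let short = id.to_string(); // shorter, URL-safe Base62 form of the UUID
+/// let url = format!("/feed/{}", short);
+/// ```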
 #[derive(Debug, Serialize, Deserialize)]
 pub struct Base62Uuid(
     #[serde(deserialize_with = "uuid_from_base62_str")]
     #[serde(serialize_with = "uuid_to_base62_str")]
-    Uuid
+    Uuid,
 );
 
 impl Base62Uuid {
     pub fn as_uuid(&self) -> Uuid {
         self.0
     }
+
+    pub fn new() -> Self {
+        Self(Uuid::new_v4())
+    }
+}
+
+impl Default for Base62Uuid {
+    fn default() -> Self {
+        Self::new()
+    }
 }
 
 impl From<Uuid> for Base62Uuid {