8 Commits
main ... apalis

Author SHA1 Message Date
6bd661765e WIP frontend isn't completely broken now
There is still more work left to integrate apalis and fully update it.

These changes are mostly for fixing the frontend I broke by eagerly
updating everything.
2025-02-10 00:55:39 -05:00
7a8f7dc415 Add import_opml job 2024-09-22 14:43:51 -04:00
e41085425a Upgrade apalis, add fred pool to state, start publishing in jobs 2024-09-22 13:49:24 -04:00
6912ef9017 Add crawl_entry job 2024-08-27 21:54:14 -04:00
65eac1975c Move feed fetching to crawl_feed job, DomainRequestLimiter
`DomainRequestLimiter` is a distributed version of `DomainLocks` based
on redis.
2024-08-26 01:12:18 -04:00
9c75a88c69 Start of a crawl_feed job 2024-08-25 22:24:02 -04:00
a3450e202a Working apalis cron and worker with 0.6.0-rc.5
Also renamed `pool` variables throughout codebase to `db` for clarity.
2024-08-21 01:21:45 -04:00
764d3f23b8 WIP add apalis & split up main process 2024-07-27 13:55:08 -04:00
41 changed files with 2163 additions and 652 deletions

1713 Cargo.lock (generated)

File diff suppressed because it is too large.


@@ -2,7 +2,7 @@
name = "crawlnicle"
version = "0.1.0"
edition = "2021"
default-run = "crawlnicle"
default-run = "web"
authors = ["Tyler Hallada <tyler@hallada.net>"]
[lib]
@@ -15,18 +15,23 @@ path = "src/lib.rs"
ammonia = "4"
ansi-to-html = "0.2"
anyhow = "1"
# apalis v0.6 fixes this issue: https://github.com/geofmureithi/apalis/issues/351
apalis = { version = "0.6.0-rc.8", features = ["retry"] }
apalis-cron = "0.6.0-rc.8"
apalis-redis = "0.6.0-rc.8"
async-trait = "0.1"
axum = { version = "0.7", features = ["form", "multipart", "query"] }
axum-client-ip = "0.6"
axum-extra = { version = "0.9", features = ["typed-header"] }
axum-login = "0.15"
axum-login = "0.16"
base64 = "0.22"
bytes = "1.4"
# TODO: replace chrono with time
chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4.4", features = ["derive", "env"] }
dotenvy = "0.15"
feed-rs = "1.3"
feed-rs = "2.1"
fred = "9"
futures = "0.3"
headers = "0.4"
http = "1.0.0"
@@ -36,10 +41,12 @@ lettre = { version = "0.11", features = ["builder"] }
maud = { git = "https://github.com/vidhanio/maud", branch = "patch-1", features = [
"axum",
] }
# upgrading this to > 6 causes infinite reloads with tower-livereload
notify = "6"
once_cell = "1.18"
opml = "1.1"
password-auth = "1.0"
rand = { version = "0.8", features = ["small_rng"] }
readability = "0.3"
reqwest = { version = "0.12", features = ["json"] }
serde = { version = "1", features = ["derive"] }
@@ -53,21 +60,21 @@ sqlx = { version = "0.7", features = [
"uuid",
"ipnetwork",
] }
thiserror = "1"
thiserror = "2"
time = "0.3"
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1", features = ["sync"] }
tower = "0.4"
tower = { version = "0.5", features = ["retry"] }
tower-livereload = "0.9"
tower-http = { version = "0.5", features = ["trace", "fs"] }
tower-sessions = { version = "0.12", features = ["signed"] }
tower-sessions-redis-store = "0.12"
tower-http = { version = "0.6", features = ["trace", "fs"] }
tower-sessions = { version = "0.13", features = ["signed"] }
tower-sessions-redis-store = "0.14"
tracing = { version = "0.1", features = ["valuable", "attributes"] }
tracing-appender = "0.2"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.4", features = ["serde"] }
url = "2.4"
validator = { version = "0.18", features = ["derive"] }
validator = { version = "0.19", features = ["derive"] }
[profile.dev.package.sqlx-macros]
opt-level = 3


@@ -56,7 +56,7 @@ Install these requirements to get started developing crawlnicle.
directory with the contents:
```env
RUST_LOG=crawlnicle=debug,cli=debug,lib=debug,tower_http=debug,sqlx=debug
RUST_LOG=crawlnicle=debug,cli=debug,web=debug,worker=debug,crawler=debug,lib=debug,tower_http=debug,sqlx=debug
HOST=127.0.0.1
PORT=3000
PUBLIC_URL=http://localhost:3000

Binary file not shown.


@@ -1,10 +1,5 @@
import htmx from 'htmx.org';
// import assets so they get named with a content hash that busts caches
// import '../css/styles.css';
import './localTimeController';
declare global {
interface Window {
htmx: typeof htmx;
@@ -13,5 +8,16 @@ declare global {
window.htmx = htmx;
// eslint-disable-next-line import/first
import 'htmx.org/dist/ext/sse';
// Wait for htmx to be fully initialized before loading extensions
document.addEventListener(
'htmx:load',
() => {
import('htmx.org/dist/ext/sse');
},
{ once: true }
);
// import assets so they get named with a content hash that busts caches
// import '../css/styles.css';
import './localTimeController';


@@ -2,21 +2,21 @@
"name": "crawlnicle-frontend",
"module": "js/index.ts",
"devDependencies": {
"@tailwindcss/forms": "^0.5.7",
"@tailwindcss/typography": "^0.5.13",
"@typescript-eslint/eslint-plugin": "^7.8.0",
"@typescript-eslint/parser": "^7.8.0",
"bun-types": "^1.1.8",
"eslint": "^9.2.0",
"@tailwindcss/forms": "^0.5.10",
"@tailwindcss/typography": "^0.5.16",
"@typescript-eslint/eslint-plugin": "^7.18.0",
"@typescript-eslint/parser": "^7.18.0",
"bun-types": "^1.2.2",
"eslint": "^9.20.0",
"eslint-config-prettier": "^9.1.0",
"eslint-config-standard-with-typescript": "latest",
"eslint-plugin-import": "^2.29.1",
"eslint-plugin-n": "^17.6.0",
"eslint-plugin-prettier": "^5.1.3",
"eslint-plugin-promise": "^6.1.1",
"prettier": "^3.2.5",
"tailwindcss": "^3.4.3",
"typescript": "^5.4.5"
"eslint-plugin-import": "^2.31.0",
"eslint-plugin-n": "^17.15.1",
"eslint-plugin-prettier": "^5.2.3",
"eslint-plugin-promise": "^6.6.0",
"prettier": "^3.4.2",
"tailwindcss": "^3.4.17",
"typescript": "^5.7.3"
},
"peerDependencies": {
"typescript": "^5.0.0"


@@ -1,8 +1,8 @@
const plugin = require('tailwindcss/plugin');
import plugin from 'tailwindcss/plugin';
/** @type {import('tailwindcss').Config} */
export default {
content: ['./src/**/*.rs'],
content: ['../src/**/*.rs'],
theme: {
extend: {},
},


@@ -7,14 +7,22 @@ install-frontend:
clean-frontend:
rm -rf ./static/js/* ./static/css/* ./static/img/*
build-frontend: clean-frontend
bunx tailwindcss -i frontend/css/styles.css -o static/css/styles.css --minify
[working-directory: 'frontend']
build-css:
bunx tailwindcss -i css/styles.css -o ../static/css/styles.css --minify
[working-directory: 'frontend']
build-dev-css:
bunx tailwindcss -i css/styles.css -o ../static/css/styles.css
build-frontend: clean-frontend build-css
bun build frontend/js/index.ts \
--outdir ./static \
--root ./frontend \
--entry-naming [dir]/[name]-[hash].[ext] \
--chunk-naming [dir]/[name]-[hash].[ext] \
--asset-naming [dir]/[name]-[hash].[ext] \
--sourcemap=linked \
--minify
mkdir -p static/img
cp frontend/img/* static/img/
@@ -22,14 +30,14 @@ build-frontend: clean-frontend
touch ./static/css/manifest.txt # create empty manifest to be overwritten by build.rs
touch .frontend-built # trigger build.rs to run
build-dev-frontend: clean-frontend
bunx tailwindcss -i frontend/css/styles.css -o static/css/styles.css
build-dev-frontend: clean-frontend build-dev-css
bun build frontend/js/index.ts \
--outdir ./static \
--root ./frontend \
--entry-naming [dir]/[name]-[hash].[ext] \
--chunk-naming [dir]/[name]-[hash].[ext] \
--asset-naming [dir]/[name]-[hash].[ext]
--asset-naming [dir]/[name]-[hash].[ext] \
--sourcemap=linked
mkdir -p static/img
cp frontend/img/* static/img/
touch ./static/js/manifest.txt # create empty manifest needed so binary compiles

5 rust-analyzer.json (new file)

@@ -0,0 +1,5 @@
{
"files": {
"excludeDirs": ["frontend"]
}
}


@@ -11,7 +11,7 @@ use uuid::Uuid;
use crate::actors::crawl_scheduler::{CrawlSchedulerHandle, CrawlSchedulerHandleMessage};
use crate::error::Error;
use crate::models::feed::{Feed, CreateFeed};
use crate::models::feed::{CreateFeed, Feed};
use crate::state::Imports;
use crate::uuid::Base62Uuid;


@@ -97,7 +97,7 @@ pub async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let pool = PgPoolOptions::new()
let db = PgPoolOptions::new()
.max_connections(env::var("DATABASE_MAX_CONNECTIONS")?.parse()?)
.connect(&env::var("DATABASE_URL")?)
.await?;
@@ -108,7 +108,7 @@ pub async fn main() -> Result<()> {
match cli.commands {
Commands::AddFeed(args) => {
let feed = Feed::create(
&pool,
&db,
CreateFeed {
title: args.title,
url: args.url,
@@ -119,12 +119,12 @@ pub async fn main() -> Result<()> {
info!("Created feed with id {}", Base62Uuid::from(feed.feed_id));
}
Commands::DeleteFeed(args) => {
Feed::delete(&pool, args.id).await?;
Feed::delete(&db, args.id).await?;
info!("Deleted feed with id {}", Base62Uuid::from(args.id));
}
Commands::AddEntry(args) => {
let entry = Entry::create(
&pool,
&db,
CreateEntry {
title: args.title,
url: args.url,
@@ -137,7 +137,7 @@ pub async fn main() -> Result<()> {
info!("Created entry with id {}", Base62Uuid::from(entry.entry_id));
}
Commands::DeleteEntry(args) => {
Entry::delete(&pool, args.id).await?;
Entry::delete(&db, args.id).await?;
info!("Deleted entry with id {}", Base62Uuid::from(args.id));
}
Commands::Crawl(CrawlFeed { id }) => {
@@ -147,7 +147,7 @@ pub async fn main() -> Result<()> {
// server is running, it will *not* serialize same-domain requests with it.
let domain_locks = DomainLocks::new();
let feed_crawler = FeedCrawlerHandle::new(
pool.clone(),
db.clone(),
client.clone(),
domain_locks.clone(),
env::var("CONTENT_DIR")?,

118 src/bin/crawler.rs (new file)

@@ -0,0 +1,118 @@
use apalis::layers::retry::{RetryLayer, RetryPolicy};
use apalis::layers::tracing::TraceLayer;
use apalis::prelude::*;
use apalis_cron::{CronStream, Schedule};
use apalis_redis::RedisStorage;
use chrono::{DateTime, Utc};
use clap::Parser;
use dotenvy::dotenv;
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use std::str::FromStr;
use std::sync::Arc;
use thiserror::Error;
use tracing::{info, instrument};
use lib::config::Config;
use lib::jobs::{AsyncJob, CrawlFeedJob};
use lib::log::init_worker_tracing;
use lib::models::feed::{Feed, GetFeedsOptions};
#[derive(Default, Debug, Clone)]
struct Crawl(DateTime<Utc>);
impl From<DateTime<Utc>> for Crawl {
fn from(t: DateTime<Utc>) -> Self {
Crawl(t)
}
}
#[derive(Debug, Error)]
enum CrawlError {
#[error("error fetching feeds")]
FetchFeedsError(#[from] sqlx::Error),
#[error("error queueing crawl feed job")]
QueueJobError(String),
}
#[derive(Clone)]
struct State {
pool: PgPool,
apalis: RedisStorage<AsyncJob>,
}
#[instrument(skip_all)]
pub async fn crawl_fn(job: Crawl, state: Data<Arc<State>>) -> Result<(), CrawlError> {
tracing::info!(job = ?job, "crawl");
let mut apalis = (state.apalis).clone();
let mut options = GetFeedsOptions::default();
loop {
info!("fetching feeds before: {:?}", options.before);
// TODO: filter to feeds where:
// now >= feed.last_crawled_at + feed.crawl_interval_minutes
// may need more indices...
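// A possible shape for that query (hypothetical, untested sketch; assumes the
// table is named `feed` and that pagination stays keyed on `created_at` like
// `Feed::get_all`):
//
//   select * from feed
//   where last_crawled_at is null
//      or now() >= last_crawled_at + crawl_interval_minutes * interval '1 minute'
//   order by created_at desc
//   limit $1;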
let feeds = match Feed::get_all(&state.pool, &options).await {
Err(err) => return Err(CrawlError::FetchFeedsError(err)),
Ok(feeds) if feeds.is_empty() => {
info!("no more feeds found");
break;
}
Ok(feeds) => feeds,
};
info!("found {} feeds", feeds.len());
options.before = feeds.last().map(|f| f.created_at);
for feed in feeds.into_iter() {
// self.spawn_crawler_loop(feed, respond_to.clone());
// TODO: implement uniqueness on jobs per feed for ~1 minute
apalis
.push(AsyncJob::CrawlFeed(CrawlFeedJob {
feed_id: feed.feed_id,
}))
.await
.map_err(|err| CrawlError::QueueJobError(err.to_string()))?;
}
}
Ok(())
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
dotenv().ok();
let config = Config::parse();
let _guard = init_worker_tracing()?;
let pool = PgPoolOptions::new()
.max_connections(config.database_max_connections)
.acquire_timeout(std::time::Duration::from_secs(3))
.connect(&config.database_url)
.await?;
// TODO: create a connection from redis_pool for each job instead of using a single connection
// See: https://github.com/geofmureithi/apalis/issues/290
let redis_conn = apalis_redis::connect(config.redis_url.clone()).await?;
let apalis_config = apalis_redis::Config::default();
let apalis_storage = RedisStorage::new_with_config(redis_conn, apalis_config);
let state = Arc::new(State {
pool,
apalis: apalis_storage.clone(),
});
let schedule = Schedule::from_str("0 * * * * *").unwrap();
let worker = WorkerBuilder::new("crawler")
.layer(RetryLayer::new(RetryPolicy::default()))
.layer(TraceLayer::new())
.data(state)
.backend(CronStream::new(schedule))
.build_fn(crawl_fn);
Monitor::<TokioExecutor>::new()
.register(worker)
.run()
.await
.unwrap();
Ok(())
}


@@ -1,6 +1,8 @@
use std::{collections::HashMap, net::SocketAddr, path::Path, sync::Arc};
use anyhow::Result;
use apalis::prelude::*;
use apalis_redis::RedisStorage;
use axum::{
routing::{get, post},
Router,
@@ -14,6 +16,7 @@ use base64::prelude::*;
use bytes::Bytes;
use clap::Parser;
use dotenvy::dotenv;
use fred::prelude::*;
use lettre::transport::smtp::authentication::Credentials;
use lettre::SmtpTransport;
use notify::Watcher;
@@ -26,12 +29,20 @@ use tower::ServiceBuilder;
use tower_http::{services::ServeDir, trace::TraceLayer};
use tower_livereload::LiveReloadLayer;
use tower_sessions::cookie::Key;
use tower_sessions_redis_store::{fred::prelude::*, RedisStore};
use tower_sessions_redis_store::{
fred::{
interfaces::ClientLike as TowerSessionsRedisClientLike,
prelude::{RedisConfig as TowerSessionsRedisConfig, RedisPool as TowerSessionsRedisPool},
},
RedisStore,
};
use tracing::debug;
use lib::config::Config;
use lib::domain_locks::DomainLocks;
use lib::domain_request_limiter::DomainRequestLimiter;
use lib::handlers;
use lib::jobs::AsyncJob;
use lib::log::init_tracing;
use lib::state::AppState;
use lib::USER_AGENT;
@@ -63,7 +74,7 @@ async fn main() -> Result<()> {
let domain_locks = DomainLocks::new();
let client = Client::builder().user_agent(USER_AGENT).build()?;
let pool = PgPoolOptions::new()
let db = PgPoolOptions::new()
.max_connections(config.database_max_connections)
.acquire_timeout(std::time::Duration::from_secs(3))
.connect(&config.database_url)
@@ -72,8 +83,20 @@ async fn main() -> Result<()> {
let redis_config = RedisConfig::from_url(&config.redis_url)?;
let redis_pool = RedisPool::new(redis_config, None, None, None, config.redis_pool_size)?;
redis_pool.init().await?;
let domain_request_limiter = DomainRequestLimiter::new(redis_pool.clone(), 10, 5, 100, 0.5);
let session_store = RedisStore::new(redis_pool);
// TODO: is it possible to use the same fred RedisPool that the web app uses?
let sessions_redis_config = TowerSessionsRedisConfig::from_url(&config.redis_url)?;
let sessions_redis_pool = TowerSessionsRedisPool::new(
sessions_redis_config,
None,
None,
None,
config.redis_pool_size,
)?;
sessions_redis_pool.init().await?;
let session_store = RedisStore::new(sessions_redis_pool);
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(!cfg!(debug_assertions))
.with_expiry(Expiry::OnInactivity(Duration::days(
@@ -81,7 +104,7 @@ async fn main() -> Result<()> {
)))
.with_signed(Key::from(&BASE64_STANDARD.decode(&config.session_secret)?));
let backend = Backend::new(pool.clone());
let backend = Backend::new(db.clone());
let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build();
let smtp_creds = Credentials::new(config.smtp_user.clone(), config.smtp_password.clone());
@@ -91,17 +114,28 @@ async fn main() -> Result<()> {
.credentials(smtp_creds)
.build();
sqlx::migrate!().run(&pool).await?;
sqlx::migrate!().run(&db).await?;
// TODO: use redis_pool from above instead of making a new connection
// See: https://github.com/geofmureithi/apalis/issues/290
let redis_conn = apalis_redis::connect(config.redis_url.clone()).await?;
let apalis_config = apalis_redis::Config::default();
let mut apalis: RedisStorage<AsyncJob> =
RedisStorage::new_with_config(redis_conn, apalis_config);
apalis
.push(AsyncJob::HelloWorld("hello".to_string()))
.await?;
let crawl_scheduler = CrawlSchedulerHandle::new(
pool.clone(),
db.clone(),
client.clone(),
domain_locks.clone(),
config.content_dir.clone(),
crawls.clone(),
);
let _ = crawl_scheduler.bootstrap().await;
let importer = ImporterHandle::new(pool.clone(), crawl_scheduler.clone(), imports.clone());
// let _ = crawl_scheduler.bootstrap().await;
let importer = ImporterHandle::new(db.clone(), crawl_scheduler.clone(), imports.clone());
let ip_source_extension = config.ip_source.0.clone().into_extension();
@@ -140,16 +174,19 @@ async fn main() -> Result<()> {
.route("/reset-password", post(handlers::reset_password::post))
.nest_service("/static", ServeDir::new("static"))
.with_state(AppState {
pool,
db,
config,
log_receiver,
crawls,
domain_locks,
domain_request_limiter,
client,
crawl_scheduler,
importer,
imports,
mailer,
apalis,
redis: redis_pool,
})
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
.layer(auth_layer)

61 src/bin/worker.rs (new file)

@@ -0,0 +1,61 @@
use anyhow::Result;
use apalis::layers::retry::RetryPolicy;
use apalis::layers::tracing::TraceLayer;
use apalis::prelude::*;
use apalis_redis::RedisStorage;
use clap::Parser;
use dotenvy::dotenv;
use fred::prelude::*;
use reqwest::Client;
use sqlx::postgres::PgPoolOptions;
use tower::retry::RetryLayer;
use lib::config::Config;
use lib::domain_request_limiter::DomainRequestLimiter;
use lib::jobs::{handle_async_job, AsyncJob};
use lib::log::init_worker_tracing;
use lib::USER_AGENT;
#[tokio::main]
async fn main() -> Result<()> {
dotenv().ok();
let config = Config::parse();
let _guard = init_worker_tracing()?;
// TODO: create a connection from redis_pool for each job instead of using a single connection
// See: https://github.com/geofmureithi/apalis/issues/290
let redis_conn = apalis_redis::connect(config.redis_url.clone()).await?;
let apalis_config = apalis_redis::Config::default();
let apalis_storage: RedisStorage<AsyncJob> =
RedisStorage::new_with_config(redis_conn, apalis_config);
let redis_config = RedisConfig::from_url(&config.redis_url)?;
let redis_pool = RedisPool::new(redis_config, None, None, None, 5)?;
redis_pool.init().await?;
let domain_request_limiter = DomainRequestLimiter::new(redis_pool.clone(), 10, 5, 100, 0.5);
let http_client = Client::builder().user_agent(USER_AGENT).build()?;
let db = PgPoolOptions::new()
.max_connections(config.database_max_connections)
.acquire_timeout(std::time::Duration::from_secs(3))
.connect(&config.database_url)
.await?;
Monitor::<TokioExecutor>::new()
.register_with_count(2, {
WorkerBuilder::new("worker")
.layer(RetryLayer::new(RetryPolicy::default()))
.layer(TraceLayer::new())
.data(http_client)
.data(db)
.data(domain_request_limiter)
.data(config)
.data(apalis_storage.clone())
.data(redis_pool)
.backend(apalis_storage)
.build_fn(handle_async_job)
})
.run()
.await
.unwrap();
Ok(())
}


@@ -33,10 +33,10 @@ impl DomainLocks {
}
/// Run the passed function `f` while holding a lock that gives exclusive access to the passed
/// domain. If another task running `run_request` currently has the lock to the
/// domain. If another task running `run_request` currently has the lock to the
/// `DomainLocksMap` or the lock to the domain passed, then this function will wait until that
/// other task is done. Once it has access to the lock, if it has been less than one second
/// since the last request to the domain, then this function will sleep until one second has
/// other task is done. Once it has access to the lock, if it has been less than one second
/// since the last request to the domain, then this function will sleep until one second has
/// passed before calling `f`.
pub async fn run_request<F, T>(&self, domain: &str, f: F) -> T
where


@@ -0,0 +1,123 @@
use anyhow::{anyhow, Result};
use fred::{clients::RedisPool, interfaces::KeysInterface, prelude::*};
use rand::{rngs::SmallRng, Rng, SeedableRng};
use std::{sync::Arc, time::Duration};
use tokio::{sync::Mutex, time::sleep};
/// A Redis-based rate limiter for domain-specific requests with jittered retry delay.
///
/// This limiter uses a fixed window algorithm with a 1-second window and applies
/// jitter to the retry delay to help prevent synchronized retries in distributed systems.
/// It uses fred's RedisPool for efficient connection management.
///
/// Limitations:
/// 1. Fixed window: The limit resets every second, potentially allowing short traffic bursts
/// at window boundaries.
/// 2. No token bucket: Doesn't accumulate unused capacity from quiet periods.
/// 3. Potential overcounting: In distributed systems, there's a small chance of overcounting
/// near window ends due to race conditions.
/// 4. Redis dependency: Rate limiting fails open if Redis is unavailable.
/// 5. Blocking: The acquire method will block until a request is allowed or max_retries is reached.
///
/// Usage example:
/// ```
/// use fred::prelude::*;
///
/// #[tokio::main]
/// async fn main() -> Result<()> {
/// let config = RedisConfig::default();
/// let pool = RedisPool::new(config, None, None, None, 5)?;
/// pool.connect();
/// pool.wait_for_connect().await?;
///
/// let limiter = DomainRequestLimiter::new(pool, 10, 5, 100, 0.5);
/// let domain = "example.com";
///
/// for _ in 0..15 {
/// match limiter.acquire(domain).await {
/// Ok(()) => println!("Request allowed"),
/// Err(_) => println!("Max retries reached, request denied"),
/// }
/// }
///
/// Ok(())
/// }
/// ```
#[derive(Debug, Clone)]
pub struct DomainRequestLimiter {
redis_pool: RedisPool,
requests_per_second: u32,
max_retries: u32,
base_retry_delay_ms: u64,
jitter_factor: f64,
// TODO: I think I can get rid of this if I instantiate a DomainRequestLimiter per-worker, but
// I'm not sure how to do that in apalis (then I could just use thread_rng)
rng: Arc<Mutex<SmallRng>>,
}
impl DomainRequestLimiter {
/// Create a new DomainRequestLimiter.
///
/// # Arguments
/// * `redis_pool` - A fred RedisPool.
/// * `requests_per_second` - Maximum allowed requests per second per domain.
/// * `max_retries` - Maximum number of retries before giving up.
/// * `base_retry_delay_ms` - Base delay between retries in milliseconds.
/// * `jitter_factor` - Factor to determine the maximum jitter (0.0 to 1.0).
pub fn new(
redis_pool: RedisPool,
requests_per_second: u32,
max_retries: u32,
base_retry_delay_ms: u64,
jitter_factor: f64,
) -> Self {
Self {
redis_pool,
requests_per_second,
max_retries,
base_retry_delay_ms,
jitter_factor: jitter_factor.clamp(0.0, 1.0),
rng: Arc::new(Mutex::new(SmallRng::from_entropy())),
}
}
/// Attempt to acquire permission for a request, retrying if necessary.
///
/// This method will attempt to acquire permission up to max_retries times,
/// sleeping for a jittered delay between each attempt.
///
/// # Arguments
/// * `domain` - The domain for which to check the rate limit.
///
/// # Returns
/// Ok(()) if permission is granted, or an error if max retries are exceeded.
pub async fn acquire(&self, domain: &str) -> Result<()> {
for attempt in 0..=self.max_retries {
if self.try_acquire(domain).await? {
return Ok(());
}
if attempt < self.max_retries {
let mut rng = self.rng.lock().await;
let jitter =
rng.gen::<f64>() * self.jitter_factor * self.base_retry_delay_ms as f64;
let delay = self.base_retry_delay_ms + jitter as u64;
sleep(Duration::from_millis(delay)).await;
}
}
Err(anyhow!(
"Max retries exceeded for domain: {:?}, request denied",
domain
))
}
async fn try_acquire(&self, domain: &str) -> Result<bool, RedisError> {
let key = format!("rate_limit:{}", domain);
let count: u32 = self.redis_pool.incr(&key).await?;
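// the first increment in a window creates the key; give it a 1-second TTL so the window resets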
if count == 1 {
self.redis_pool.expire(&key, 1).await?;
}
Ok(count <= self.requests_per_second)
}
}


@@ -27,7 +27,7 @@ pub enum Error {
#[error("validation error in request body")]
InvalidEntity(#[from] ValidationErrors),
#[error("error with file upload: (0)")]
#[error("error with file upload")]
Upload(#[from] MultipartError),
#[error("no file uploaded")]
@@ -49,7 +49,7 @@ pub enum Error {
Unauthorized,
#[error("bad request: {0}")]
BadRequest(&'static str)
BadRequest(&'static str),
}
pub type Result<T, E = Error> = ::std::result::Result<T, E>;


@@ -13,9 +13,9 @@ use crate::partials::entry_list::entry_list;
pub async fn get(
Query(options): Query<GetEntriesOptions>,
accept: Option<TypedHeader<Accept>>,
State(pool): State<PgPool>,
State(db): State<PgPool>,
) -> Result<impl IntoResponse, impl IntoResponse> {
let entries = Entry::get_all(&pool, &options).await.map_err(Error::from)?;
let entries = Entry::get_all(&db, &options).await.map_err(Error::from)?;
if let Some(TypedHeader(accept)) = accept {
if accept == Accept::ApplicationJson {
return Ok::<ApiResponse<Vec<Entry>>, Error>(ApiResponse::Json(entries));


@@ -9,15 +9,15 @@ use crate::models::entry::{CreateEntry, Entry};
use crate::uuid::Base62Uuid;
pub async fn get(
State(pool): State<PgPool>,
State(db): State<PgPool>,
Path(id): Path<Base62Uuid>,
) -> Result<Json<Entry>, Error> {
Ok(Json(Entry::get(&pool, id.as_uuid()).await?))
Ok(Json(Entry::get(&db, id.as_uuid()).await?))
}
pub async fn post(
State(pool): State<PgPool>,
State(db): State<PgPool>,
Json(payload): Json<CreateEntry>,
) -> Result<Json<Entry>, Error> {
Ok(Json(Entry::create(&pool, payload).await?))
Ok(Json(Entry::create(&db, payload).await?))
}


@@ -8,17 +8,17 @@ use crate::error::{Error, Result};
use crate::models::feed::{CreateFeed, Feed};
use crate::uuid::Base62Uuid;
pub async fn get(State(pool): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<Json<Feed>> {
Ok(Json(Feed::get(&pool, id.as_uuid()).await?))
pub async fn get(State(db): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<Json<Feed>> {
Ok(Json(Feed::get(&db, id.as_uuid()).await?))
}
pub async fn post(
State(pool): State<PgPool>,
State(db): State<PgPool>,
Json(payload): Json<CreateFeed>,
) -> Result<Json<Feed>, Error> {
Ok(Json(Feed::create(&pool, payload).await?))
Ok(Json(Feed::create(&db, payload).await?))
}
pub async fn delete(State(pool): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<()> {
Feed::delete(&pool, id.as_uuid()).await
pub async fn delete(State(db): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<()> {
Feed::delete(&db, id.as_uuid()).await
}


@@ -13,9 +13,9 @@ use crate::partials::feed_list::feed_list;
pub async fn get(
Query(options): Query<GetFeedsOptions>,
accept: Option<TypedHeader<Accept>>,
State(pool): State<PgPool>,
State(db): State<PgPool>,
) -> Result<impl IntoResponse, impl IntoResponse> {
let feeds = Feed::get_all(&pool, &options).await.map_err(Error::from)?;
let feeds = Feed::get_all(&db, &options).await.map_err(Error::from)?;
if let Some(TypedHeader(accept)) = accept {
if accept == Accept::ApplicationJson {
return Ok::<ApiResponse<Vec<Feed>>, Error>(ApiResponse::Json(feeds));


@@ -70,7 +70,7 @@ pub fn confirm_email_page(
}
pub async fn get(
State(pool): State<PgPool>,
State(db): State<PgPool>,
auth: AuthSession,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
@@ -78,7 +78,7 @@ pub async fn get(
) -> Result<Response> {
if let Some(token_id) = query.token_id {
info!(token_id = %token_id.as_uuid(), "get with token_id");
let token = match UserEmailVerificationToken::get(&pool, token_id.as_uuid()).await {
let token = match UserEmailVerificationToken::get(&db, token_id.as_uuid()).await {
Ok(token) => token,
Err(err) => {
if let Error::NotFoundUuid(_, _) = err {
@@ -112,8 +112,8 @@ pub async fn get(
}))
} else {
info!(token_id = %token.token_id, "token valid, verifying email");
User::verify_email(&pool, token.user_id).await?;
UserEmailVerificationToken::delete(&pool, token.token_id).await?;
User::verify_email(&db, token.user_id).await?;
UserEmailVerificationToken::delete(&db, token.token_id).await?;
Ok(layout
.with_subtitle("confirm email")
.targeted(hx_target)
@@ -152,7 +152,7 @@ pub struct ConfirmEmail {
}
pub async fn post(
State(pool): State<PgPool>,
State(db): State<PgPool>,
State(mailer): State<SmtpTransport>,
State(config): State<Config>,
hx_target: Option<TypedHeader<HXTarget>>,
@@ -161,11 +161,11 @@ pub async fn post(
) -> Result<Response> {
if let Some(token_id) = confirm_email.token {
info!(%token_id, "posted with token_id");
let token = UserEmailVerificationToken::get(&pool, token_id).await?;
let user = User::get(&pool, token.user_id).await?;
let token = UserEmailVerificationToken::get(&db, token_id).await?;
let user = User::get(&db, token.user_id).await?;
if !user.email_verified {
info!(user_id = %user.user_id, "user exists, resending confirmation email");
send_confirmation_email(pool, mailer, config, user);
send_confirmation_email(db, mailer, config, user);
} else {
warn!(user_id = %user.user_id, "confirm email submitted for already verified user, skip resend");
}
@@ -184,10 +184,10 @@ pub async fn post(
}));
}
if let Some(email) = confirm_email.email {
if let Ok(user) = User::get_by_email(&pool, email).await {
if let Ok(user) = User::get_by_email(&db, email).await {
if !user.email_verified {
info!(user_id = %user.user_id, "user exists, resending confirmation email");
send_confirmation_email(pool, mailer, config, user);
send_confirmation_email(db, mailer, config, user);
} else {
warn!(user_id = %user.user_id, "confirm email submitted for already verified user, skip resend");
}


@@ -8,8 +8,8 @@ use crate::partials::entry_list::entry_list;
pub async fn get(
Query(options): Query<GetEntriesOptions>,
State(pool): State<PgPool>,
State(db): State<PgPool>,
) -> Result<Markup> {
let entries = Entry::get_all(&pool, &options).await?;
let entries = Entry::get_all(&db, &options).await?;
Ok(entry_list(entries, &options, false))
}


@@ -16,12 +16,12 @@ use crate::uuid::Base62Uuid;
pub async fn get(
Path(id): Path<Base62Uuid>,
State(pool): State<PgPool>,
State(db): State<PgPool>,
State(config): State<Config>,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
) -> Result<Response> {
let entry = Entry::get(&pool, id.as_uuid()).await?;
let entry = Entry::get(&db, id.as_uuid()).await?;
let content_dir = std::path::Path::new(&config.content_dir);
let content_path = content_dir.join(format!("{}.html", entry.entry_id));
let title = entry.title.unwrap_or_else(|| "Untitled Entry".to_string());


@@ -28,17 +28,17 @@ use crate::uuid::Base62Uuid;
pub async fn get(
Path(id): Path<Base62Uuid>,
State(pool): State<PgPool>,
State(db): State<PgPool>,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
) -> Result<Response> {
let feed = Feed::get(&pool, id.as_uuid()).await?;
let feed = Feed::get(&db, id.as_uuid()).await?;
let options = GetEntriesOptions {
feed_id: Some(feed.feed_id),
..Default::default()
};
let title = feed.title.unwrap_or_else(|| "Untitled Feed".to_string());
let entries = Entry::get_all(&pool, &options).await?;
let entries = Entry::get_all(&db, &options).await?;
let delete_url = format!("/feed/{}/delete", id);
Ok(layout.with_subtitle(&title).targeted(hx_target).render(html! {
header class="mb-4 flex flex-row items-center gap-4" {
@@ -115,13 +115,13 @@ impl IntoResponse for AddFeedError {
}
pub async fn post(
State(pool): State<PgPool>,
State(db): State<PgPool>,
State(crawls): State<Crawls>,
State(crawl_scheduler): State<CrawlSchedulerHandle>,
Form(add_feed): Form<AddFeed>,
) -> AddFeedResult<Response> {
let feed = Feed::create(
&pool,
&db,
CreateFeed {
title: add_feed.title,
url: add_feed.url.clone(),
@@ -233,7 +233,7 @@ pub async fn stream(
))
}
pub async fn delete(State(pool): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<Redirect> {
Feed::delete(&pool, id.as_uuid()).await?;
pub async fn delete(State(db): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<Redirect> {
Feed::delete(&db, id.as_uuid()).await?;
Ok(Redirect::to("/feeds"))
}


@@ -13,12 +13,12 @@ use crate::partials::layout::Layout;
use crate::partials::opml_import_form::opml_import_form;
pub async fn get(
State(pool): State<PgPool>,
State(db): State<PgPool>,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
) -> Result<Response> {
let options = GetFeedsOptions::default();
let feeds = Feed::get_all(&pool, &options).await?;
let feeds = Feed::get_all(&db, &options).await?;
Ok(layout
.with_subtitle("feeds")
.targeted(hx_target)


@@ -82,7 +82,7 @@ pub async fn get(
}
pub async fn post(
State(pool): State<PgPool>,
State(db): State<PgPool>,
State(mailer): State<SmtpTransport>,
State(config): State<Config>,
SecureClientIp(ip): SecureClientIp,
@@ -91,7 +91,7 @@ pub async fn post(
layout: Layout,
Form(forgot_password): Form<ForgotPassword>,
) -> Result<Response> {
let user: User = match User::get_by_email(&pool, forgot_password.email.clone()).await {
let user: User = match User::get_by_email(&db, forgot_password.email.clone()).await {
Ok(user) => user,
Err(err) => {
if let Error::NotFoundString(_, _) = err {
@@ -105,7 +105,7 @@ pub async fn post(
if user.email_verified {
info!(user_id = %user.user_id, "user exists with verified email, sending password reset email");
send_forgot_password_email(
pool,
db,
mailer,
config,
user,


@@ -10,12 +10,12 @@ use crate::models::entry::Entry;
use crate::partials::{entry_list::entry_list, layout::Layout};
pub async fn get(
State(pool): State<PgPool>,
State(db): State<PgPool>,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
) -> Result<Response> {
let options = Default::default();
let entries = Entry::get_all(&pool, &options).await?;
let entries = Entry::get_all(&db, &options).await?;
Ok(layout.targeted(hx_target).render(html! {
ul class="list-none flex flex-col gap-4" {
(entry_list(entries, &options, true))


@@ -59,7 +59,7 @@ pub async fn get(hx_target: Option<TypedHeader<HXTarget>>, layout: Layout) -> Re
}
pub async fn post(
State(pool): State<PgPool>,
State(db): State<PgPool>,
State(mailer): State<SmtpTransport>,
State(config): State<Config>,
mut auth: AuthSession,
@@ -80,7 +80,7 @@ pub async fn post(
));
}
let user = match User::create(
&pool,
&db,
CreateUser {
email: register.email.clone(),
password: register.password.clone(),
@@ -144,7 +144,7 @@ pub async fn post(
}
};
send_confirmation_email(pool, mailer, config, user.clone());
send_confirmation_email(db, mailer, config, user.clone());
auth.login(&user)
.await


@@ -126,14 +126,14 @@ pub fn reset_password_page(
}
pub async fn get(
State(pool): State<PgPool>,
State(db): State<PgPool>,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
query: Query<ResetPasswordQuery>,
) -> Result<Response> {
if let Some(token_id) = query.token_id {
info!(token_id = %token_id.as_uuid(), "get with token_id");
let token = match UserPasswordResetToken::get(&pool, token_id.as_uuid()).await {
let token = match UserPasswordResetToken::get(&db, token_id.as_uuid()).await {
Ok(token) => token,
Err(err) => {
if let Error::NotFoundUuid(_, _) = err {
@@ -158,7 +158,7 @@ pub async fn get(
}))
} else {
info!(token_id = %token.token_id, "token valid, showing reset password form");
let user = User::get(&pool, token.user_id).await?;
let user = User::get(&db, token.user_id).await?;
Ok(reset_password_page(ResetPasswordPageProps {
hx_target,
layout,
@@ -181,7 +181,7 @@ pub async fn get(
}
pub async fn post(
State(pool): State<PgPool>,
State(db): State<PgPool>,
State(mailer): State<SmtpTransport>,
State(config): State<Config>,
SecureClientIp(ip): SecureClientIp,
@@ -203,7 +203,7 @@ pub async fn post(
..Default::default()
}));
}
let token = match UserPasswordResetToken::get(&pool, reset_password.token).await {
let token = match UserPasswordResetToken::get(&db, reset_password.token).await {
Ok(token) => token,
Err(err) => {
if let Error::NotFoundUuid(_, _) = err {
@@ -241,7 +241,7 @@ pub async fn post(
..Default::default()
}));
}
let user = match User::get(&pool, token.user_id).await {
let user = match User::get(&db, token.user_id).await {
Ok(user) => user,
Err(err) => {
if let Error::NotFoundString(_, _) = err {
@@ -266,7 +266,7 @@ pub async fn post(
}
};
info!(user_id = %user.user_id, "user exists with verified email, resetting password");
let mut tx = pool.begin().await?;
let mut tx = db.begin().await?;
UserPasswordResetToken::delete(tx.as_mut(), reset_password.token).await?;
let user = match user
.update_password(

68 src/jobs/crawl_entry.rs (new file)

@@ -0,0 +1,68 @@
use std::fs;
use std::path::Path;
use ammonia::clean;
use anyhow::{anyhow, Result};
use apalis::prelude::*;
use bytes::Buf;
use fred::prelude::*;
use readability::extractor;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tracing::{info, instrument};
use url::Url;
use uuid::Uuid;
use crate::config::Config;
use crate::domain_request_limiter::DomainRequestLimiter;
use crate::models::entry::Entry;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct CrawlEntryJob {
pub entry_id: Uuid,
}
#[instrument(skip_all, fields(entry_id = %entry_id))]
pub async fn crawl_entry(
CrawlEntryJob { entry_id }: CrawlEntryJob,
http_client: Data<Client>,
db: Data<PgPool>,
domain_request_limiter: Data<DomainRequestLimiter>,
config: Data<Config>,
redis: Data<RedisPool>,
) -> Result<()> {
let entry = Entry::get(&*db, entry_id).await?;
info!("got entry from db");
let content_dir = Path::new(&*config.content_dir);
let url = Url::parse(&entry.url)?;
let domain = url
.domain()
.ok_or(anyhow!("invalid url: {:?}", entry.url.clone()))?;
info!(url=%url, "starting fetch");
domain_request_limiter.acquire(domain).await?;
let bytes = http_client.get(url.clone()).send().await?.bytes().await?;
info!(url=%url, "fetched entry");
let article = extractor::extract(&mut bytes.reader(), &url)?;
info!("extracted content");
let id = entry.entry_id;
// TODO: update entry with scraped data
// if let Some(date) = article.date {
// // prefer scraped date over rss feed date
// let mut updated_entry = entry.clone();
// updated_entry.published_at = date;
// entry = update_entry(&self.pool, updated_entry)
// .await
// .map_err(|_| EntryCrawlerError::CreateEntryError(entry.url.clone()))?;
// };
let content = clean(&article.content);
info!("sanitized content");
fs::write(content_dir.join(format!("{}.html", id)), content)?;
fs::write(content_dir.join(format!("{}.txt", id)), article.text)?;
info!("saved content to filesystem");
redis
.next()
.publish("entries", entry_id.to_string())
.await?;
Ok(())
}

189 src/jobs/crawl_feed.rs (new file)

@@ -0,0 +1,189 @@
use std::cmp::Ordering;
use anyhow::{anyhow, Result};
use apalis::prelude::*;
use apalis_redis::RedisStorage;
use chrono::{Duration, Utc};
use feed_rs::parser;
use fred::prelude::*;
use http::{header, HeaderMap, StatusCode};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tracing::{info, info_span, instrument, warn};
use url::Url;
use uuid::Uuid;
use crate::domain_request_limiter::DomainRequestLimiter;
use crate::jobs::{AsyncJob, CrawlEntryJob};
use crate::models::entry::{CreateEntry, Entry};
use crate::models::feed::{Feed, MAX_CRAWL_INTERVAL_MINUTES, MIN_CRAWL_INTERVAL_MINUTES};
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct CrawlFeedJob {
pub feed_id: Uuid,
}
#[instrument(skip_all, fields(feed_id = %feed_id))]
pub async fn crawl_feed(
CrawlFeedJob { feed_id }: CrawlFeedJob,
http_client: Data<Client>,
db: Data<PgPool>,
domain_request_limiter: Data<DomainRequestLimiter>,
apalis: Data<RedisStorage<AsyncJob>>,
redis: Data<RedisPool>,
) -> Result<()> {
let mut feed = Feed::get(&*db, feed_id).await?;
info!("got feed from db");
let url = Url::parse(&feed.url)?;
let domain = url
.domain()
.ok_or(anyhow!("invalid url: {:?}", feed.url.clone()))?;
let mut headers = HeaderMap::new();
if let Some(etag) = &feed.etag_header {
if let Ok(etag) = etag.parse() {
headers.insert(header::IF_NONE_MATCH, etag);
} else {
warn!(%etag, "failed to parse saved etag header");
}
}
if let Some(last_modified) = &feed.last_modified_header {
if let Ok(last_modified) = last_modified.parse() {
headers.insert(header::IF_MODIFIED_SINCE, last_modified);
} else {
warn!(
%last_modified,
"failed to parse saved last_modified header",
);
}
}
info!(url=%url, "starting fetch");
domain_request_limiter.acquire(domain).await?;
let resp = http_client.get(url.clone()).headers(headers).send().await?;
let headers = resp.headers();
if let Some(etag) = headers.get(header::ETAG) {
if let Ok(etag) = etag.to_str() {
feed.etag_header = Some(etag.to_string());
} else {
warn!(?etag, "failed to convert response etag header to string");
}
}
if let Some(last_modified) = headers.get(header::LAST_MODIFIED) {
if let Ok(last_modified) = last_modified.to_str() {
feed.last_modified_header = Some(last_modified.to_string());
} else {
warn!(
?last_modified,
"failed to convert response last_modified header to string",
);
}
}
info!(url=%url, "fetched feed");
if resp.status() == StatusCode::NOT_MODIFIED {
info!("feed returned not modified status");
feed.last_crawled_at = Some(Utc::now());
feed.last_crawl_error = None;
feed.save(&*db).await?;
info!("updated feed in db");
return Ok(());
} else if !resp.status().is_success() {
warn!("feed returned non-successful status");
feed.last_crawled_at = Some(Utc::now());
feed.last_crawl_error = resp.status().canonical_reason().map(|s| s.to_string());
feed.save(&*db).await?;
info!("updated feed in db");
return Ok(());
}
let bytes = resp.bytes().await?;
let parsed_feed = parser::parse(&bytes[..])?;
info!("parsed feed");
feed.url = url.to_string();
feed.feed_type = parsed_feed.feed_type.into();
feed.last_crawled_at = Some(Utc::now());
feed.last_crawl_error = None;
if let Some(title) = parsed_feed.title {
feed.title = Some(title.content);
}
if let Some(description) = parsed_feed.description {
feed.description = Some(description.content);
}
let last_entry_published_at = parsed_feed.entries.iter().filter_map(|e| e.published).max();
if let Some(prev_last_entry_published_at) = feed.last_entry_published_at {
if let Some(published_at) = last_entry_published_at {
let time_since_last_entry = if published_at == prev_last_entry_published_at {
// No new entry since last crawl, compare current time to last publish instead
Utc::now() - prev_last_entry_published_at
} else {
// Compare new entry publish time to previous publish time
published_at - prev_last_entry_published_at
};
match time_since_last_entry.cmp(&Duration::minutes(feed.crawl_interval_minutes.into()))
{
Ordering::Greater => {
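// new entries are arriving slower than the crawl interval: grow it by 20%, capped at the max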
feed.crawl_interval_minutes = i32::min(
(feed.crawl_interval_minutes as f32 * 1.2).ceil() as i32,
MAX_CRAWL_INTERVAL_MINUTES,
);
info!(
interval = feed.crawl_interval_minutes,
"increased crawl interval"
);
}
Ordering::Less => {
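// new entries are arriving faster than the crawl interval: shrink it by 20%, floored at the min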
feed.crawl_interval_minutes = i32::max(
(feed.crawl_interval_minutes as f32 / 1.2).ceil() as i32,
MIN_CRAWL_INTERVAL_MINUTES,
);
info!(
interval = feed.crawl_interval_minutes,
"decreased crawl interval"
);
}
Ordering::Equal => {}
}
}
}
feed.last_entry_published_at = last_entry_published_at;
let feed = feed.save(&*db).await?;
info!("updated feed in db");
let mut payload = Vec::with_capacity(parsed_feed.entries.len());
for entry in parsed_feed.entries {
let entry_span = info_span!("entry", id = entry.id);
let _entry_span_guard = entry_span.enter();
if let Some(link) = entry.links.first() {
// if no scraped or feed date is available, fall back to the current time
let published_at = entry.published.unwrap_or_else(Utc::now);
let entry = CreateEntry {
title: entry.title.map(|t| t.content),
url: link.href.clone(),
description: entry.summary.map(|s| s.content),
feed_id: feed.feed_id,
published_at,
};
payload.push(entry);
} else {
warn!("skipping feed entry with no links");
}
}
let entries = Entry::bulk_upsert(&*db, payload).await?;
let (new, updated) = entries
.into_iter()
.partition::<Vec<_>, _>(|entry| entry.updated_at.is_none());
info!(new = new.len(), updated = updated.len(), "saved entries");
for entry in new {
(*apalis)
.clone() // TODO: clone bad?
.push(AsyncJob::CrawlEntry(CrawlEntryJob {
entry_id: entry.entry_id,
}))
.await?;
}
redis.next().publish("feeds", feed_id.to_string()).await?;
Ok(())
}

105 src/jobs/import_opml.rs (new file)

@@ -0,0 +1,105 @@
use std::io::Cursor;
use anyhow::{anyhow, Context, Result};
use apalis::prelude::*;
use apalis_redis::RedisStorage;
use fred::prelude::*;
use opml::OPML;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tracing::{error, instrument, warn};
use uuid::Uuid;
use crate::error::Error;
use crate::jobs::crawl_feed::CrawlFeedJob;
use crate::jobs::AsyncJob;
use crate::models::feed::{CreateFeed, Feed};
use crate::uuid::Base62Uuid;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ImportOpmlJob {
pub import_id: Uuid,
pub file_name: Option<String>,
pub bytes: Vec<u8>,
}
// TODO: send messages over redis channel
/// `ImporterOpmlMessage::Import` contains the result of importing the OPML file.
// #[allow(clippy::large_enum_variant)]
// #[derive(Debug, Clone)]
// pub enum ImporterOpmlMessage {
// Import(ImporterResult<()>),
// CreateFeedError(String),
// AlreadyImported(String),
// CrawlScheduler(CrawlSchedulerHandleMessage),
// }
#[instrument(skip_all, fields(import_id = %import_id))]
pub async fn import_opml(
ImportOpmlJob {
import_id,
file_name,
bytes,
}: ImportOpmlJob,
db: Data<PgPool>,
apalis: Data<RedisStorage<AsyncJob>>,
redis: Data<RedisPool>,
) -> Result<()> {
let document = OPML::from_reader(&mut Cursor::new(bytes)).with_context(|| {
format!(
"Failed to read OPML file for import {} from file {}",
Base62Uuid::from(import_id),
file_name
.map(|n| n.to_string())
.unwrap_or_else(|| "unknown".to_string())
)
})?;
for url in gather_feed_urls(document.body.outlines) {
let feed = Feed::create(
&*db,
CreateFeed {
url: url.clone(),
..Default::default()
},
)
.await;
match feed {
Ok(feed) => {
(*apalis)
.clone()
.push(AsyncJob::CrawlFeed(CrawlFeedJob {
feed_id: feed.feed_id,
}))
.await?;
}
Err(Error::Sqlx(sqlx::error::Error::Database(err))) => {
if err.is_unique_violation() {
// let _ = respond_to.send(ImporterHandleMessage::AlreadyImported(url));
warn!("Feed {} already imported", url);
}
}
Err(err) => {
// let _ = respond_to.send(ImporterHandleMessage::CreateFeedError(url));
error!("Failed to create feed for {}", url);
return Err(anyhow!(err));
}
}
}
redis
.next()
.publish("imports", import_id.to_string())
.await?;
Ok(())
}
fn gather_feed_urls(outlines: Vec<opml::Outline>) -> Vec<String> {
let mut urls = Vec::new();
for outline in outlines.into_iter() {
if let Some(url) = outline.xml_url {
urls.push(url);
}
urls.append(&mut gather_feed_urls(outline.outlines));
}
urls
}

70 src/jobs/mod.rs (new file)

@@ -0,0 +1,70 @@
use apalis::prelude::*;
use apalis_redis::RedisStorage;
use fred::prelude::*;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use thiserror::Error;
use tracing::{error, info, instrument};
mod crawl_entry;
mod crawl_feed;
mod import_opml;
pub use crawl_entry::CrawlEntryJob;
pub use crawl_feed::CrawlFeedJob;
pub use import_opml::ImportOpmlJob;
use crate::{config::Config, domain_request_limiter::DomainRequestLimiter};
#[derive(Debug, Deserialize, Serialize, Clone)]
pub enum AsyncJob {
HelloWorld(String),
CrawlFeed(CrawlFeedJob),
CrawlEntry(CrawlEntryJob),
ImportOpml(ImportOpmlJob),
}
#[derive(Debug, Error)]
pub enum AsyncJobError {
#[error("error executing job")]
JobError(#[from] anyhow::Error),
}
#[instrument(skip_all, fields(worker_id = ?worker_id, task_id = ?task_id))]
pub async fn handle_async_job(
job: AsyncJob,
worker_id: Data<WorkerId>,
task_id: Data<TaskId>,
http_client: Data<Client>,
db: Data<PgPool>,
domain_request_limiter: Data<DomainRequestLimiter>,
config: Data<Config>,
apalis: Data<RedisStorage<AsyncJob>>,
redis: Data<RedisPool>,
) -> Result<(), AsyncJobError> {
let result = match job {
AsyncJob::HelloWorld(name) => {
info!("Hello, {}!", name);
Ok(())
}
AsyncJob::CrawlFeed(job) => {
crawl_feed::crawl_feed(job, http_client, db, domain_request_limiter, apalis, redis)
.await
}
AsyncJob::CrawlEntry(job) => {
crawl_entry::crawl_entry(job, http_client, db, domain_request_limiter, config, redis)
.await
}
AsyncJob::ImportOpml(job) => import_opml::import_opml(job, db, apalis, redis).await,
};
match result {
Ok(_) => info!("Job completed successfully"),
Err(err) => {
error!("Job failed: {err:?}");
return Err(AsyncJobError::JobError(err));
}
};
Ok(())
}


@@ -3,10 +3,12 @@ pub mod api_response;
pub mod auth;
pub mod config;
pub mod domain_locks;
pub mod domain_request_limiter;
pub mod error;
pub mod handlers;
pub mod headers;
pub mod htmx;
pub mod jobs;
pub mod log;
pub mod mailers;
pub mod models;


@@ -91,3 +91,17 @@ pub fn init_tracing(
.init();
Ok((file_writer_guard, mem_writer_guard))
}
pub fn init_worker_tracing() -> Result<WorkerGuard> {
let stdout_layer = tracing_subscriber::fmt::layer().pretty();
let filter_layer = EnvFilter::from_default_env();
let file_appender = tracing_appender::rolling::hourly("./logs", "log");
let (file_writer, file_writer_guard) = tracing_appender::non_blocking(file_appender);
let file_writer_layer = tracing_subscriber::fmt::layer().with_writer(file_writer);
tracing_subscriber::registry()
.with(filter_layer)
.with(stdout_layer)
.with(file_writer_layer)
.init();
Ok(file_writer_guard)
}


@@ -17,7 +17,7 @@ use crate::uuid::Base62Uuid;
// TODO: put in config
const USER_EMAIL_VERIFICATION_TOKEN_EXPIRATION: Duration = Duration::from_secs(24 * 60 * 60);
pub fn send_confirmation_email(pool: PgPool, mailer: SmtpTransport, config: Config, user: User) {
pub fn send_confirmation_email(db: PgPool, mailer: SmtpTransport, config: Config, user: User) {
tokio::spawn(async move {
let user_email_address = match user.email.parse() {
Ok(address) => address,
@@ -28,7 +28,7 @@ pub fn send_confirmation_email(pool: PgPool, mailer: SmtpTransport, config: Conf
};
let mailbox = Mailbox::new(user.name.clone(), user_email_address);
let token = match UserEmailVerificationToken::create(
&pool,
&db,
CreateUserEmailVerificationToken {
user_id: user.user_id,
expires_at: Utc::now() + USER_EMAIL_VERIFICATION_TOKEN_EXPIRATION,
@@ -42,11 +42,10 @@ pub fn send_confirmation_email(pool: PgPool, mailer: SmtpTransport, config: Conf
return;
}
};
let mut confirm_link = config
.public_url
.clone();
let mut confirm_link = config.public_url.clone();
confirm_link.set_path("confirm-email");
confirm_link.query_pairs_mut()
confirm_link
.query_pairs_mut()
.append_pair("token_id", &Base62Uuid::from(token.token_id).to_string());
let confirm_link = confirm_link.as_str();


@@ -18,7 +18,7 @@ use crate::uuid::Base62Uuid;
const PASSWORD_RESET_TOKEN_EXPIRATION: Duration = Duration::from_secs(24 * 60 * 60);
pub fn send_forgot_password_email(
pool: PgPool,
db: PgPool,
mailer: SmtpTransport,
config: Config,
user: User,
@@ -35,7 +35,7 @@ pub fn send_forgot_password_email(
};
let mailbox = Mailbox::new(user.name.clone(), user_email_address);
let token = match UserPasswordResetToken::create(
&pool,
&db,
CreatePasswordResetToken {
token_id: Uuid::new_v4(), // cryptographically-secure random uuid
user_id: user.user_id,


@@ -32,7 +32,7 @@ impl UserPasswordResetToken {
}
pub async fn get(
pool: impl Executor<'_, Database = Postgres>,
db: impl Executor<'_, Database = Postgres>,
token_id: Uuid,
) -> Result<UserPasswordResetToken> {
sqlx::query_as!(
@@ -43,7 +43,7 @@ impl UserPasswordResetToken {
where token_id = $1"#,
token_id
)
.fetch_one(pool)
.fetch_one(db)
.await
.map_err(|error| {
if let sqlx::error::Error::RowNotFound = error {
@@ -54,7 +54,7 @@ impl UserPasswordResetToken {
}
pub async fn create(
pool: impl Executor<'_, Database = Postgres>,
db: impl Executor<'_, Database = Postgres>,
payload: CreatePasswordResetToken,
) -> Result<UserPasswordResetToken> {
Ok(sqlx::query_as!(
@@ -70,20 +70,17 @@ impl UserPasswordResetToken {
payload.request_ip,
payload.expires_at
)
.fetch_one(pool)
.fetch_one(db)
.await?)
}
pub async fn delete(
pool: impl Executor<'_, Database = Postgres>,
token_id: Uuid,
) -> Result<()> {
pub async fn delete(db: impl Executor<'_, Database = Postgres>, token_id: Uuid) -> Result<()> {
sqlx::query!(
r#"delete from user_password_reset_token
where token_id = $1"#,
token_id
)
.execute(pool)
.execute(db)
.await?;
Ok(())
}


@@ -1,18 +1,22 @@
use std::collections::HashMap;
use std::sync::Arc;
use apalis_redis::RedisStorage;
use axum::extract::FromRef;
use bytes::Bytes;
use fred::clients::RedisPool;
use lettre::SmtpTransport;
use reqwest::Client;
use sqlx::PgPool;
use tokio::sync::{broadcast, watch, Mutex};
use uuid::Uuid;
use crate::actors::importer::{ImporterHandle, ImporterHandleMessage};
use crate::actors::crawl_scheduler::{CrawlSchedulerHandle, CrawlSchedulerHandleMessage};
use crate::actors::importer::{ImporterHandle, ImporterHandleMessage};
use crate::config::Config;
use crate::domain_locks::DomainLocks;
use crate::domain_request_limiter::DomainRequestLimiter;
use crate::jobs::AsyncJob;
/// A map of feed IDs to a channel receiver for the active `CrawlScheduler` running a feed crawl
/// for that feed.
@@ -28,32 +32,35 @@ pub type Crawls = Arc<Mutex<HashMap<Uuid, broadcast::Receiver<CrawlSchedulerHand
/// A map of unique import IDs to a channel receiver for the active `Importer` running that import.
///
/// Same as the `Crawls` map, the only purpose of this is to keep track of active imports so that
/// axum handlers can subscribe to the result of the import via the receiver channel which are then
/// Same as the `Crawls` map, the only purpose of this is to keep track of active imports so that
/// axum handlers can subscribe to the result of the import via the receiver channel which are then
/// sent to end-users as a stream of server-sent events.
///
/// This map should only contain imports that have just been created but not yet subscribed to.
/// Entries are only added when a user uploads an OPML to import and entries are removed by
/// Entries are only added when a user uploads an OPML to import and entries are removed by
/// the same user once a server-sent event connection is established.
pub type Imports = Arc<Mutex<HashMap<Uuid, broadcast::Receiver<ImporterHandleMessage>>>>;
#[derive(Clone)]
pub struct AppState {
pub pool: PgPool,
pub db: PgPool,
pub config: Config,
pub log_receiver: watch::Receiver<Bytes>,
pub crawls: Crawls,
pub domain_locks: DomainLocks,
pub domain_request_limiter: DomainRequestLimiter,
pub client: Client,
pub crawl_scheduler: CrawlSchedulerHandle,
pub importer: ImporterHandle,
pub imports: Imports,
pub mailer: SmtpTransport,
pub apalis: RedisStorage<AsyncJob>,
pub redis: RedisPool,
}
impl FromRef<AppState> for PgPool {
fn from_ref(state: &AppState) -> Self {
state.pool.clone()
state.db.clone()
}
}
@@ -81,6 +88,12 @@ impl FromRef<AppState> for DomainLocks {
}
}
impl FromRef<AppState> for DomainRequestLimiter {
fn from_ref(state: &AppState) -> Self {
state.domain_request_limiter.clone()
}
}
impl FromRef<AppState> for Client {
fn from_ref(state: &AppState) -> Self {
state.client.clone()
@@ -110,3 +123,15 @@ impl FromRef<AppState> for SmtpTransport {
state.mailer.clone()
}
}
impl FromRef<AppState> for RedisStorage<AsyncJob> {
fn from_ref(state: &AppState) -> Self {
state.apalis.clone()
}
}
impl FromRef<AppState> for RedisPool {
fn from_ref(state: &AppState) -> Self {
state.redis.clone()
}
}
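These `FromRef` impls make the apalis job storage (and the fred Redis pool) extractable from `AppState` in axum handlers. A minimal sketch of what enqueuing an `AsyncJob::ImportOpml` from a handler could look like (hypothetical handler; the function name, response codes, and placeholder job fields are assumptions, not part of this diff):

```rust
use apalis::prelude::*;
use apalis_redis::RedisStorage;
use axum::{extract::State, http::StatusCode};
use uuid::Uuid;

use lib::jobs::{AsyncJob, ImportOpmlJob};

// Hypothetical handler: assumes it is registered on a Router built with `.with_state(AppState)`,
// so the `FromRef<AppState> for RedisStorage<AsyncJob>` impl above makes this extractor work.
async fn enqueue_opml_import(
    State(mut apalis): State<RedisStorage<AsyncJob>>,
) -> Result<StatusCode, StatusCode> {
    // In the real handler the OPML bytes would come from a multipart upload.
    let job = ImportOpmlJob {
        import_id: Uuid::new_v4(),
        file_name: None,
        bytes: Vec::new(),
    };
    apalis
        .push(AsyncJob::ImportOpml(job))
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    Ok(StatusCode::ACCEPTED)
}
```

This mirrors the `AsyncJob::HelloWorld` push already done in the web binary's `main`.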


@@ -1,2 +0,0 @@
const config = require('./frontend/tailwind.config.js');
export default config;