1 Commit

cfc962c2cb  Adding tokio-console support for debugging tokio tasks
This isn't always that useful, and I'm not sure what performance impact
it has, so I'm going to keep this in a branch in case I need it in the
future.
2024-05-11 16:04:18 -04:00
42 changed files with 878 additions and 2130 deletions

.cargo/config.toml (new file, 3 changed lines)
View File

@@ -0,0 +1,3 @@
[build]
# needed so that tokio-console has the necessary data
rustflags = ["--cfg", "tokio_unstable"]

Cargo.lock (generated, 1866 changed lines)

File diff suppressed because it is too large.

View File

@@ -2,7 +2,7 @@
name = "crawlnicle"
version = "0.1.0"
edition = "2021"
default-run = "web"
default-run = "crawlnicle"
authors = ["Tyler Hallada <tyler@hallada.net>"]
[lib]
@@ -15,23 +15,19 @@ path = "src/lib.rs"
ammonia = "4"
ansi-to-html = "0.2"
anyhow = "1"
# apalis v0.6 fixes this issue: https://github.com/geofmureithi/apalis/issues/351
apalis = { version = "0.6.0-rc.8", features = ["retry"] }
apalis-cron = "0.6.0-rc.8"
apalis-redis = "0.6.0-rc.8"
async-trait = "0.1"
axum = { version = "0.7", features = ["form", "multipart", "query"] }
axum-client-ip = "0.6"
axum-extra = { version = "0.9", features = ["typed-header"] }
axum-login = "0.16"
axum-login = "0.15"
base64 = "0.22"
bytes = "1.4"
# TODO: replace chrono with time
chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4.4", features = ["derive", "env"] }
console-subscriber = "0.2"
dotenvy = "0.15"
feed-rs = "2.1"
fred = "9"
feed-rs = "1.3"
futures = "0.3"
headers = "0.4"
http = "1.0.0"
@@ -41,12 +37,10 @@ lettre = { version = "0.11", features = ["builder"] }
maud = { git = "https://github.com/vidhanio/maud", branch = "patch-1", features = [
"axum",
] }
# upgrading this to > 6 causes infinite reloads with tower-livereload
notify = "6"
once_cell = "1.18"
opml = "1.1"
password-auth = "1.0"
rand = { version = "0.8", features = ["small_rng"] }
readability = "0.3"
reqwest = { version = "0.12", features = ["json"] }
serde = { version = "1", features = ["derive"] }
@@ -60,21 +54,21 @@ sqlx = { version = "0.7", features = [
"uuid",
"ipnetwork",
] }
thiserror = "2"
thiserror = "1"
time = "0.3"
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1", features = ["sync"] }
tower = { version = "0.5", features = ["retry"] }
tower = "0.4"
tower-livereload = "0.9"
tower-http = { version = "0.6", features = ["trace", "fs"] }
tower-sessions = { version = "0.13", features = ["signed"] }
tower-sessions-redis-store = "0.14"
tower-http = { version = "0.5", features = ["trace", "fs"] }
tower-sessions = { version = "0.12", features = ["signed"] }
tower-sessions-redis-store = "0.12"
tracing = { version = "0.1", features = ["valuable", "attributes"] }
tracing-appender = "0.2"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.4", features = ["serde"] }
url = "2.4"
validator = { version = "0.19", features = ["derive"] }
validator = { version = "0.18", features = ["derive"] }
[profile.dev.package.sqlx-macros]
opt-level = 3

View File

@@ -56,7 +56,7 @@ Install these requirements to get started developing crawlnicle.
directory with the contents:
```env
RUST_LOG=crawlnicle=debug,cli=debug,web=debug,worker=debug,crawler=debug,lib=debug,tower_http=debug,sqlx=debug
RUST_LOG=crawlnicle=debug,cli=debug,lib=debug,tower_http=debug,sqlx=debug
HOST=127.0.0.1
PORT=3000
PUBLIC_URL=http://localhost:3000

Binary file not shown.

View File

@@ -1,5 +1,10 @@
import htmx from 'htmx.org';
// import assets so they get named with a content hash that busts caches
// import '../css/styles.css';
import './localTimeController';
declare global {
interface Window {
htmx: typeof htmx;
@@ -8,16 +13,5 @@ declare global {
window.htmx = htmx;
// Wait for htmx to be fully initialized before loading extensions
document.addEventListener(
'htmx:load',
() => {
import('htmx.org/dist/ext/sse');
},
{ once: true }
);
// import assets so they get named with a content hash that busts caches
// import '../css/styles.css';
import './localTimeController';
// eslint-disable-next-line import/first
import 'htmx.org/dist/ext/sse';

View File

@@ -2,21 +2,21 @@
"name": "crawlnicle-frontend",
"module": "js/index.ts",
"devDependencies": {
"@tailwindcss/forms": "^0.5.10",
"@tailwindcss/typography": "^0.5.16",
"@typescript-eslint/eslint-plugin": "^7.18.0",
"@typescript-eslint/parser": "^7.18.0",
"bun-types": "^1.2.2",
"eslint": "^9.20.0",
"@tailwindcss/forms": "^0.5.7",
"@tailwindcss/typography": "^0.5.13",
"@typescript-eslint/eslint-plugin": "^7.8.0",
"@typescript-eslint/parser": "^7.8.0",
"bun-types": "^1.1.8",
"eslint": "^9.2.0",
"eslint-config-prettier": "^9.1.0",
"eslint-config-standard-with-typescript": "latest",
"eslint-plugin-import": "^2.31.0",
"eslint-plugin-n": "^17.15.1",
"eslint-plugin-prettier": "^5.2.3",
"eslint-plugin-promise": "^6.6.0",
"prettier": "^3.4.2",
"tailwindcss": "^3.4.17",
"typescript": "^5.7.3"
"eslint-plugin-import": "^2.29.1",
"eslint-plugin-n": "^17.6.0",
"eslint-plugin-prettier": "^5.1.3",
"eslint-plugin-promise": "^6.1.1",
"prettier": "^3.2.5",
"tailwindcss": "^3.4.3",
"typescript": "^5.4.5"
},
"peerDependencies": {
"typescript": "^5.0.0"

View File

@@ -1,8 +1,8 @@
import plugin from 'tailwindcss/plugin';
const plugin = require('tailwindcss/plugin');
/** @type {import('tailwindcss').Config} */
export default {
content: ['../src/**/*.rs'],
content: ['./src/**/*.rs'],
theme: {
extend: {},
},

View File

@@ -7,22 +7,14 @@ install-frontend:
clean-frontend:
rm -rf ./static/js/* ./static/css/* ./static/img/*
[working-directory: 'frontend']
build-css:
bunx tailwindcss -i css/styles.css -o ../static/css/styles.css --minify
[working-directory: 'frontend']
build-dev-css:
bunx tailwindcss -i css/styles.css -o ../static/css/styles.css
build-frontend: clean-frontend build-css
build-frontend: clean-frontend
bunx tailwindcss -i frontend/css/styles.css -o static/css/styles.css --minify
bun build frontend/js/index.ts \
--outdir ./static \
--root ./frontend \
--entry-naming [dir]/[name]-[hash].[ext] \
--chunk-naming [dir]/[name]-[hash].[ext] \
--asset-naming [dir]/[name]-[hash].[ext] \
--sourcemap=linked \
--minify
mkdir -p static/img
cp frontend/img/* static/img/
@@ -30,14 +22,14 @@ build-frontend: clean-frontend build-css
touch ./static/css/manifest.txt # create empty manifest to be overwritten by build.rs
touch .frontend-built # trigger build.rs to run
build-dev-frontend: clean-frontend build-dev-css
build-dev-frontend: clean-frontend
bunx tailwindcss -i frontend/css/styles.css -o static/css/styles.css
bun build frontend/js/index.ts \
--outdir ./static \
--root ./frontend \
--entry-naming [dir]/[name]-[hash].[ext] \
--chunk-naming [dir]/[name]-[hash].[ext] \
--asset-naming [dir]/[name]-[hash].[ext] \
--sourcemap=linked
--asset-naming [dir]/[name]-[hash].[ext]
mkdir -p static/img
cp frontend/img/* static/img/
touch ./static/js/manifest.txt # create empty manifest needed so binary compiles

View File

@@ -1,5 +0,0 @@
{
"files": {
"excludeDirs": ["frontend"]
}
}

View File

@@ -11,7 +11,7 @@ use uuid::Uuid;
use crate::actors::crawl_scheduler::{CrawlSchedulerHandle, CrawlSchedulerHandleMessage};
use crate::error::Error;
use crate::models::feed::{CreateFeed, Feed};
use crate::models::feed::{Feed, CreateFeed};
use crate::state::Imports;
use crate::uuid::Base62Uuid;

View File

@@ -97,7 +97,7 @@ pub async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let db = PgPoolOptions::new()
let pool = PgPoolOptions::new()
.max_connections(env::var("DATABASE_MAX_CONNECTIONS")?.parse()?)
.connect(&env::var("DATABASE_URL")?)
.await?;
@@ -108,7 +108,7 @@ pub async fn main() -> Result<()> {
match cli.commands {
Commands::AddFeed(args) => {
let feed = Feed::create(
&db,
&pool,
CreateFeed {
title: args.title,
url: args.url,
@@ -119,12 +119,12 @@ pub async fn main() -> Result<()> {
info!("Created feed with id {}", Base62Uuid::from(feed.feed_id));
}
Commands::DeleteFeed(args) => {
Feed::delete(&db, args.id).await?;
Feed::delete(&pool, args.id).await?;
info!("Deleted feed with id {}", Base62Uuid::from(args.id));
}
Commands::AddEntry(args) => {
let entry = Entry::create(
&db,
&pool,
CreateEntry {
title: args.title,
url: args.url,
@@ -137,7 +137,7 @@ pub async fn main() -> Result<()> {
info!("Created entry with id {}", Base62Uuid::from(entry.entry_id));
}
Commands::DeleteEntry(args) => {
Entry::delete(&db, args.id).await?;
Entry::delete(&pool, args.id).await?;
info!("Deleted entry with id {}", Base62Uuid::from(args.id));
}
Commands::Crawl(CrawlFeed { id }) => {
@@ -147,7 +147,7 @@ pub async fn main() -> Result<()> {
// server is running, it will *not* serialize same-domain requests with it.
let domain_locks = DomainLocks::new();
let feed_crawler = FeedCrawlerHandle::new(
db.clone(),
pool.clone(),
client.clone(),
domain_locks.clone(),
env::var("CONTENT_DIR")?,

View File

@@ -1,118 +0,0 @@
use apalis::layers::retry::{RetryLayer, RetryPolicy};
use apalis::layers::tracing::TraceLayer;
use apalis::prelude::*;
use apalis_cron::{CronStream, Schedule};
use apalis_redis::RedisStorage;
use chrono::{DateTime, Utc};
use clap::Parser;
use dotenvy::dotenv;
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use std::str::FromStr;
use std::sync::Arc;
use thiserror::Error;
use tracing::{info, instrument};
use lib::config::Config;
use lib::jobs::{AsyncJob, CrawlFeedJob};
use lib::log::init_worker_tracing;
use lib::models::feed::{Feed, GetFeedsOptions};
#[derive(Default, Debug, Clone)]
struct Crawl(DateTime<Utc>);
impl From<DateTime<Utc>> for Crawl {
fn from(t: DateTime<Utc>) -> Self {
Crawl(t)
}
}
#[derive(Debug, Error)]
enum CrawlError {
#[error("error fetching feeds")]
FetchFeedsError(#[from] sqlx::Error),
#[error("error queueing crawl feed job")]
QueueJobError(String),
}
#[derive(Clone)]
struct State {
pool: PgPool,
apalis: RedisStorage<AsyncJob>,
}
#[instrument(skip_all)]
pub async fn crawl_fn(job: Crawl, state: Data<Arc<State>>) -> Result<(), CrawlError> {
tracing::info!(job = ?job, "crawl");
let mut apalis = (state.apalis).clone();
let mut options = GetFeedsOptions::default();
loop {
info!("fetching feeds before: {:?}", options.before);
// TODO: filter to feeds where:
// now >= feed.last_crawled_at + feed.crawl_interval_minutes
// may need more indices...
let feeds = match Feed::get_all(&state.pool, &options).await {
Err(err) => return Err(CrawlError::FetchFeedsError(err)),
Ok(feeds) if feeds.is_empty() => {
info!("no more feeds found");
break;
}
Ok(feeds) => feeds,
};
info!("found {} feeds", feeds.len());
options.before = feeds.last().map(|f| f.created_at);
for feed in feeds.into_iter() {
// self.spawn_crawler_loop(feed, respond_to.clone());
// TODO: implement uniqueness on jobs per feed for ~1 minute
apalis
.push(AsyncJob::CrawlFeed(CrawlFeedJob {
feed_id: feed.feed_id,
}))
.await
.map_err(|err| CrawlError::QueueJobError(err.to_string()))?;
}
}
Ok(())
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
dotenv().ok();
let config = Config::parse();
let _guard = init_worker_tracing()?;
let pool = PgPoolOptions::new()
.max_connections(config.database_max_connections)
.acquire_timeout(std::time::Duration::from_secs(3))
.connect(&config.database_url)
.await?;
// TODO: create connection from redis_pool for each job instead using a single connection
// See: https://github.com/geofmureithi/apalis/issues/290
let redis_conn = apalis_redis::connect(config.redis_url.clone()).await?;
let apalis_config = apalis_redis::Config::default();
let apalis_storage = RedisStorage::new_with_config(redis_conn, apalis_config);
let state = Arc::new(State {
pool,
apalis: apalis_storage.clone(),
});
let schedule = Schedule::from_str("0 * * * * *").unwrap();
let worker = WorkerBuilder::new("crawler")
.layer(RetryLayer::new(RetryPolicy::default()))
.layer(TraceLayer::new())
.data(state)
.backend(CronStream::new(schedule))
.build_fn(crawl_fn);
Monitor::<TokioExecutor>::new()
.register(worker)
.run()
.await
.unwrap();
Ok(())
}

View File

@@ -1,61 +0,0 @@
use anyhow::Result;
use apalis::layers::retry::RetryPolicy;
use apalis::layers::tracing::TraceLayer;
use apalis::prelude::*;
use apalis_redis::RedisStorage;
use clap::Parser;
use dotenvy::dotenv;
use fred::prelude::*;
use reqwest::Client;
use sqlx::postgres::PgPoolOptions;
use tower::retry::RetryLayer;
use lib::config::Config;
use lib::domain_request_limiter::DomainRequestLimiter;
use lib::jobs::{handle_async_job, AsyncJob};
use lib::log::init_worker_tracing;
use lib::USER_AGENT;
#[tokio::main]
async fn main() -> Result<()> {
dotenv().ok();
let config = Config::parse();
let _guard = init_worker_tracing()?;
// TODO: create connection from redis_pool for each job instead using a single connection
// See: https://github.com/geofmureithi/apalis/issues/290
let redis_conn = apalis_redis::connect(config.redis_url.clone()).await?;
let apalis_config = apalis_redis::Config::default();
let apalis_storage: RedisStorage<AsyncJob> =
RedisStorage::new_with_config(redis_conn, apalis_config);
let redis_config = RedisConfig::from_url(&config.redis_url)?;
let redis_pool = RedisPool::new(redis_config, None, None, None, 5)?;
redis_pool.init().await?;
let domain_request_limiter = DomainRequestLimiter::new(redis_pool.clone(), 10, 5, 100, 0.5);
let http_client = Client::builder().user_agent(USER_AGENT).build()?;
let db = PgPoolOptions::new()
.max_connections(config.database_max_connections)
.acquire_timeout(std::time::Duration::from_secs(3))
.connect(&config.database_url)
.await?;
Monitor::<TokioExecutor>::new()
.register_with_count(2, {
WorkerBuilder::new("worker")
.layer(RetryLayer::new(RetryPolicy::default()))
.layer(TraceLayer::new())
.data(http_client)
.data(db)
.data(domain_request_limiter)
.data(config)
.data(apalis_storage.clone())
.data(redis_pool)
.backend(apalis_storage)
.build_fn(handle_async_job)
})
.run()
.await
.unwrap();
Ok(())
}

View File

@@ -33,10 +33,10 @@ impl DomainLocks {
}
/// Run the passed function `f` while holding a lock that gives exclusive access to the passed
/// domain. If another task running `run_request` currently has the lock to the
/// domain. If another task running `run_request` currently has the lock to the
/// `DomainLocksMap` or the lock to the domain passed, then this function will wait until that
/// other task is done. Once it has access to the lock, if it has been less than one second
/// since the last request to the domain, then this function will sleep until one second has
/// other task is done. Once it has access to the lock, if it has been less than one second
/// since the last request to the domain, then this function will sleep until one second has
/// passed before calling `f`.
pub async fn run_request<F, T>(&self, domain: &str, f: F) -> T
where

View File

@@ -1,123 +0,0 @@
use anyhow::{anyhow, Result};
use fred::{clients::RedisPool, interfaces::KeysInterface, prelude::*};
use rand::{rngs::SmallRng, Rng, SeedableRng};
use std::{sync::Arc, time::Duration};
use tokio::{sync::Mutex, time::sleep};
/// A Redis-based rate limiter for domain-specific requests with jittered retry delay.
///
/// This limiter uses a fixed window algorithm with a 1-second window and applies
/// jitter to the retry delay to help prevent synchronized retries in distributed systems.
/// It uses fred's RedisPool for efficient connection management.
///
/// Limitations:
/// 1. Fixed window: The limit resets every second, potentially allowing short traffic bursts
/// at window boundaries.
/// 2. No token bucket: Doesn't accumulate unused capacity from quiet periods.
/// 3. Potential overcounting: In distributed systems, there's a small chance of overcounting
/// near window ends due to race conditions.
/// 4. Redis dependency: Rate limiting fails open if Redis is unavailable.
/// 5. Blocking: The acquire method will block until a request is allowed or max_retries is reached.
///
/// Usage example:
/// ```
/// use fred::prelude::*;
///
/// #[tokio::main]
/// async fn main() -> Result<()> {
/// let config = RedisConfig::default();
/// let pool = RedisPool::new(config, None, None, None, 5)?;
/// pool.connect();
/// pool.wait_for_connect().await?;
///
/// let limiter = DomainRequestLimiter::new(pool, 10, 5, 100, 0.5);
/// let domain = "example.com";
///
/// for _ in 0..15 {
/// match limiter.acquire(domain).await {
/// Ok(()) => println!("Request allowed"),
/// Err(_) => println!("Max retries reached, request denied"),
/// }
/// }
///
/// Ok(())
/// }
/// ```
#[derive(Debug, Clone)]
pub struct DomainRequestLimiter {
redis_pool: RedisPool,
requests_per_second: u32,
max_retries: u32,
base_retry_delay_ms: u64,
jitter_factor: f64,
// TODO: I think I can get rid of this if I instantiate a DomainRequestLimiter per-worker, but
// I'm not sure how to do that in apalis (then I could just use thread_rng)
rng: Arc<Mutex<SmallRng>>,
}
impl DomainRequestLimiter {
/// Create a new DomainRequestLimiter.
///
/// # Arguments
/// * `redis_pool` - A fred RedisPool.
/// * `requests_per_second` - Maximum allowed requests per second per domain.
/// * `max_retries` - Maximum number of retries before giving up.
/// * `base_retry_delay_ms` - Base delay between retries in milliseconds.
/// * `jitter_factor` - Factor to determine the maximum jitter (0.0 to 1.0).
pub fn new(
redis_pool: RedisPool,
requests_per_second: u32,
max_retries: u32,
base_retry_delay_ms: u64,
jitter_factor: f64,
) -> Self {
Self {
redis_pool,
requests_per_second,
max_retries,
base_retry_delay_ms,
jitter_factor: jitter_factor.clamp(0.0, 1.0),
rng: Arc::new(Mutex::new(SmallRng::from_entropy())),
}
}
/// Attempt to acquire permission for a request, retrying if necessary.
///
/// This method will attempt to acquire permission up to max_retries times,
/// sleeping for a jittered delay between each attempt.
///
/// # Arguments
/// * `domain` - The domain for which to check the rate limit.
///
/// # Returns
/// Ok(()) if permission is granted, or an error if max retries are exceeded.
pub async fn acquire(&self, domain: &str) -> Result<()> {
for attempt in 0..=self.max_retries {
if self.try_acquire(domain).await? {
return Ok(());
}
if attempt < self.max_retries {
let mut rng = self.rng.lock().await;
let jitter =
rng.gen::<f64>() * self.jitter_factor * self.base_retry_delay_ms as f64;
let delay = self.base_retry_delay_ms + jitter as u64;
sleep(Duration::from_millis(delay)).await;
}
}
Err(anyhow!(
"Max retries exceeded for domain: {:?}, request denied",
domain
))
}
async fn try_acquire(&self, domain: &str) -> Result<bool, RedisError> {
let key = format!("rate_limit:{}", domain);
let count: u32 = self.redis_pool.incr(&key).await?;
if count == 1 {
self.redis_pool.expire(&key, 1).await?;
}
Ok(count <= self.requests_per_second)
}
}
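
A small worked sketch of the jittered retry delay computed in `acquire` above, pulled out as a standalone function so the arithmetic is easy to check. The function name and the `main` driver are illustrative only; the concrete numbers (base delay 100 ms, jitter factor 0.5) match the values passed to `DomainRequestLimiter::new` elsewhere in this diff.

```rust
// Sketch: the same jitter math as `DomainRequestLimiter::acquire`, using the
// rand 0.8 crate from the dependency list. Names here are illustrative.
use rand::Rng;

fn jittered_delay_ms(base_retry_delay_ms: u64, jitter_factor: f64, rng: &mut impl Rng) -> u64 {
    // rng.gen::<f64>() is uniform in [0.0, 1.0), so the jitter term lands in
    // [0, jitter_factor * base_retry_delay_ms) milliseconds.
    let jitter = rng.gen::<f64>() * jitter_factor * base_retry_delay_ms as f64;
    base_retry_delay_ms + jitter as u64
}

fn main() {
    let mut rng = rand::thread_rng();
    // With base 100 ms and factor 0.5, each delay falls in [100, 150) ms.
    for attempt in 0..3 {
        println!("attempt {attempt}: sleeping {} ms", jittered_delay_ms(100, 0.5, &mut rng));
    }
}
```

The real implementation draws from a shared `SmallRng` behind a tokio `Mutex` so the limiter stays cloneable across workers, as noted in the TODO above.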

View File

@@ -27,7 +27,7 @@ pub enum Error {
#[error("validation error in request body")]
InvalidEntity(#[from] ValidationErrors),
#[error("error with file upload")]
#[error("error with file upload: (0)")]
Upload(#[from] MultipartError),
#[error("no file uploaded")]
@@ -49,7 +49,7 @@ pub enum Error {
Unauthorized,
#[error("bad request: {0}")]
BadRequest(&'static str),
BadRequest(&'static str)
}
pub type Result<T, E = Error> = ::std::result::Result<T, E>;

View File

@@ -13,9 +13,9 @@ use crate::partials::entry_list::entry_list;
pub async fn get(
Query(options): Query<GetEntriesOptions>,
accept: Option<TypedHeader<Accept>>,
State(db): State<PgPool>,
State(pool): State<PgPool>,
) -> Result<impl IntoResponse, impl IntoResponse> {
let entries = Entry::get_all(&db, &options).await.map_err(Error::from)?;
let entries = Entry::get_all(&pool, &options).await.map_err(Error::from)?;
if let Some(TypedHeader(accept)) = accept {
if accept == Accept::ApplicationJson {
return Ok::<ApiResponse<Vec<Entry>>, Error>(ApiResponse::Json(entries));

View File

@@ -9,15 +9,15 @@ use crate::models::entry::{CreateEntry, Entry};
use crate::uuid::Base62Uuid;
pub async fn get(
State(db): State<PgPool>,
State(pool): State<PgPool>,
Path(id): Path<Base62Uuid>,
) -> Result<Json<Entry>, Error> {
Ok(Json(Entry::get(&db, id.as_uuid()).await?))
Ok(Json(Entry::get(&pool, id.as_uuid()).await?))
}
pub async fn post(
State(db): State<PgPool>,
State(pool): State<PgPool>,
Json(payload): Json<CreateEntry>,
) -> Result<Json<Entry>, Error> {
Ok(Json(Entry::create(&db, payload).await?))
Ok(Json(Entry::create(&pool, payload).await?))
}

View File

@@ -8,17 +8,17 @@ use crate::error::{Error, Result};
use crate::models::feed::{CreateFeed, Feed};
use crate::uuid::Base62Uuid;
pub async fn get(State(db): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<Json<Feed>> {
Ok(Json(Feed::get(&db, id.as_uuid()).await?))
pub async fn get(State(pool): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<Json<Feed>> {
Ok(Json(Feed::get(&pool, id.as_uuid()).await?))
}
pub async fn post(
State(db): State<PgPool>,
State(pool): State<PgPool>,
Json(payload): Json<CreateFeed>,
) -> Result<Json<Feed>, Error> {
Ok(Json(Feed::create(&db, payload).await?))
Ok(Json(Feed::create(&pool, payload).await?))
}
pub async fn delete(State(db): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<()> {
Feed::delete(&db, id.as_uuid()).await
pub async fn delete(State(pool): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<()> {
Feed::delete(&pool, id.as_uuid()).await
}

View File

@@ -13,9 +13,9 @@ use crate::partials::feed_list::feed_list;
pub async fn get(
Query(options): Query<GetFeedsOptions>,
accept: Option<TypedHeader<Accept>>,
State(db): State<PgPool>,
State(pool): State<PgPool>,
) -> Result<impl IntoResponse, impl IntoResponse> {
let feeds = Feed::get_all(&db, &options).await.map_err(Error::from)?;
let feeds = Feed::get_all(&pool, &options).await.map_err(Error::from)?;
if let Some(TypedHeader(accept)) = accept {
if accept == Accept::ApplicationJson {
return Ok::<ApiResponse<Vec<Feed>>, Error>(ApiResponse::Json(feeds));

View File

@@ -70,7 +70,7 @@ pub fn confirm_email_page(
}
pub async fn get(
State(db): State<PgPool>,
State(pool): State<PgPool>,
auth: AuthSession,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
@@ -78,7 +78,7 @@ pub async fn get(
) -> Result<Response> {
if let Some(token_id) = query.token_id {
info!(token_id = %token_id.as_uuid(), "get with token_id");
let token = match UserEmailVerificationToken::get(&db, token_id.as_uuid()).await {
let token = match UserEmailVerificationToken::get(&pool, token_id.as_uuid()).await {
Ok(token) => token,
Err(err) => {
if let Error::NotFoundUuid(_, _) = err {
@@ -112,8 +112,8 @@ pub async fn get(
}))
} else {
info!(token_id = %token.token_id, "token valid, verifying email");
User::verify_email(&db, token.user_id).await?;
UserEmailVerificationToken::delete(&db, token.token_id).await?;
User::verify_email(&pool, token.user_id).await?;
UserEmailVerificationToken::delete(&pool, token.token_id).await?;
Ok(layout
.with_subtitle("confirm email")
.targeted(hx_target)
@@ -152,7 +152,7 @@ pub struct ConfirmEmail {
}
pub async fn post(
State(db): State<PgPool>,
State(pool): State<PgPool>,
State(mailer): State<SmtpTransport>,
State(config): State<Config>,
hx_target: Option<TypedHeader<HXTarget>>,
@@ -161,11 +161,11 @@ pub async fn post(
) -> Result<Response> {
if let Some(token_id) = confirm_email.token {
info!(%token_id, "posted with token_id");
let token = UserEmailVerificationToken::get(&db, token_id).await?;
let user = User::get(&db, token.user_id).await?;
let token = UserEmailVerificationToken::get(&pool, token_id).await?;
let user = User::get(&pool, token.user_id).await?;
if !user.email_verified {
info!(user_id = %user.user_id, "user exists, resending confirmation email");
send_confirmation_email(db, mailer, config, user);
send_confirmation_email(pool, mailer, config, user);
} else {
warn!(user_id = %user.user_id, "confirm email submitted for already verified user, skip resend");
}
@@ -184,10 +184,10 @@ pub async fn post(
}));
}
if let Some(email) = confirm_email.email {
if let Ok(user) = User::get_by_email(&db, email).await {
if let Ok(user) = User::get_by_email(&pool, email).await {
if !user.email_verified {
info!(user_id = %user.user_id, "user exists, resending confirmation email");
send_confirmation_email(db, mailer, config, user);
send_confirmation_email(pool, mailer, config, user);
} else {
warn!(user_id = %user.user_id, "confirm email submitted for already verified user, skip resend");
}

View File

@@ -8,8 +8,8 @@ use crate::partials::entry_list::entry_list;
pub async fn get(
Query(options): Query<GetEntriesOptions>,
State(db): State<PgPool>,
State(pool): State<PgPool>,
) -> Result<Markup> {
let entries = Entry::get_all(&db, &options).await?;
let entries = Entry::get_all(&pool, &options).await?;
Ok(entry_list(entries, &options, false))
}

View File

@@ -16,12 +16,12 @@ use crate::uuid::Base62Uuid;
pub async fn get(
Path(id): Path<Base62Uuid>,
State(db): State<PgPool>,
State(pool): State<PgPool>,
State(config): State<Config>,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
) -> Result<Response> {
let entry = Entry::get(&db, id.as_uuid()).await?;
let entry = Entry::get(&pool, id.as_uuid()).await?;
let content_dir = std::path::Path::new(&config.content_dir);
let content_path = content_dir.join(format!("{}.html", entry.entry_id));
let title = entry.title.unwrap_or_else(|| "Untitled Entry".to_string());

View File

@@ -28,17 +28,17 @@ use crate::uuid::Base62Uuid;
pub async fn get(
Path(id): Path<Base62Uuid>,
State(db): State<PgPool>,
State(pool): State<PgPool>,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
) -> Result<Response> {
let feed = Feed::get(&db, id.as_uuid()).await?;
let feed = Feed::get(&pool, id.as_uuid()).await?;
let options = GetEntriesOptions {
feed_id: Some(feed.feed_id),
..Default::default()
};
let title = feed.title.unwrap_or_else(|| "Untitled Feed".to_string());
let entries = Entry::get_all(&db, &options).await?;
let entries = Entry::get_all(&pool, &options).await?;
let delete_url = format!("/feed/{}/delete", id);
Ok(layout.with_subtitle(&title).targeted(hx_target).render(html! {
header class="mb-4 flex flex-row items-center gap-4" {
@@ -115,13 +115,13 @@ impl IntoResponse for AddFeedError {
}
pub async fn post(
State(db): State<PgPool>,
State(pool): State<PgPool>,
State(crawls): State<Crawls>,
State(crawl_scheduler): State<CrawlSchedulerHandle>,
Form(add_feed): Form<AddFeed>,
) -> AddFeedResult<Response> {
let feed = Feed::create(
&db,
&pool,
CreateFeed {
title: add_feed.title,
url: add_feed.url.clone(),
@@ -233,7 +233,7 @@ pub async fn stream(
))
}
pub async fn delete(State(db): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<Redirect> {
Feed::delete(&db, id.as_uuid()).await?;
pub async fn delete(State(pool): State<PgPool>, Path(id): Path<Base62Uuid>) -> Result<Redirect> {
Feed::delete(&pool, id.as_uuid()).await?;
Ok(Redirect::to("/feeds"))
}

View File

@@ -13,12 +13,12 @@ use crate::partials::layout::Layout;
use crate::partials::opml_import_form::opml_import_form;
pub async fn get(
State(db): State<PgPool>,
State(pool): State<PgPool>,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
) -> Result<Response> {
let options = GetFeedsOptions::default();
let feeds = Feed::get_all(&db, &options).await?;
let feeds = Feed::get_all(&pool, &options).await?;
Ok(layout
.with_subtitle("feeds")
.targeted(hx_target)

View File

@@ -82,7 +82,7 @@ pub async fn get(
}
pub async fn post(
State(db): State<PgPool>,
State(pool): State<PgPool>,
State(mailer): State<SmtpTransport>,
State(config): State<Config>,
SecureClientIp(ip): SecureClientIp,
@@ -91,7 +91,7 @@ pub async fn post(
layout: Layout,
Form(forgot_password): Form<ForgotPassword>,
) -> Result<Response> {
let user: User = match User::get_by_email(&db, forgot_password.email.clone()).await {
let user: User = match User::get_by_email(&pool, forgot_password.email.clone()).await {
Ok(user) => user,
Err(err) => {
if let Error::NotFoundString(_, _) = err {
@@ -105,7 +105,7 @@ pub async fn post(
if user.email_verified {
info!(user_id = %user.user_id, "user exists with verified email, sending password reset email");
send_forgot_password_email(
db,
pool,
mailer,
config,
user,

View File

@@ -10,12 +10,12 @@ use crate::models::entry::Entry;
use crate::partials::{entry_list::entry_list, layout::Layout};
pub async fn get(
State(db): State<PgPool>,
State(pool): State<PgPool>,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
) -> Result<Response> {
let options = Default::default();
let entries = Entry::get_all(&db, &options).await?;
let entries = Entry::get_all(&pool, &options).await?;
Ok(layout.targeted(hx_target).render(html! {
ul class="list-none flex flex-col gap-4" {
(entry_list(entries, &options, true))

View File

@@ -59,7 +59,7 @@ pub async fn get(hx_target: Option<TypedHeader<HXTarget>>, layout: Layout) -> Re
}
pub async fn post(
State(db): State<PgPool>,
State(pool): State<PgPool>,
State(mailer): State<SmtpTransport>,
State(config): State<Config>,
mut auth: AuthSession,
@@ -80,7 +80,7 @@ pub async fn post(
));
}
let user = match User::create(
&db,
&pool,
CreateUser {
email: register.email.clone(),
password: register.password.clone(),
@@ -144,7 +144,7 @@ pub async fn post(
}
};
send_confirmation_email(db, mailer, config, user.clone());
send_confirmation_email(pool, mailer, config, user.clone());
auth.login(&user)
.await

View File

@@ -126,14 +126,14 @@ pub fn reset_password_page(
}
pub async fn get(
State(db): State<PgPool>,
State(pool): State<PgPool>,
hx_target: Option<TypedHeader<HXTarget>>,
layout: Layout,
query: Query<ResetPasswordQuery>,
) -> Result<Response> {
if let Some(token_id) = query.token_id {
info!(token_id = %token_id.as_uuid(), "get with token_id");
let token = match UserPasswordResetToken::get(&db, token_id.as_uuid()).await {
let token = match UserPasswordResetToken::get(&pool, token_id.as_uuid()).await {
Ok(token) => token,
Err(err) => {
if let Error::NotFoundUuid(_, _) = err {
@@ -158,7 +158,7 @@ pub async fn get(
}))
} else {
info!(token_id = %token.token_id, "token valid, showing reset password form");
let user = User::get(&db, token.user_id).await?;
let user = User::get(&pool, token.user_id).await?;
Ok(reset_password_page(ResetPasswordPageProps {
hx_target,
layout,
@@ -181,7 +181,7 @@ pub async fn get(
}
pub async fn post(
State(db): State<PgPool>,
State(pool): State<PgPool>,
State(mailer): State<SmtpTransport>,
State(config): State<Config>,
SecureClientIp(ip): SecureClientIp,
@@ -203,7 +203,7 @@ pub async fn post(
..Default::default()
}));
}
let token = match UserPasswordResetToken::get(&db, reset_password.token).await {
let token = match UserPasswordResetToken::get(&pool, reset_password.token).await {
Ok(token) => token,
Err(err) => {
if let Error::NotFoundUuid(_, _) = err {
@@ -241,7 +241,7 @@ pub async fn post(
..Default::default()
}));
}
let user = match User::get(&db, token.user_id).await {
let user = match User::get(&pool, token.user_id).await {
Ok(user) => user,
Err(err) => {
if let Error::NotFoundString(_, _) = err {
@@ -266,7 +266,7 @@ pub async fn post(
}
};
info!(user_id = %user.user_id, "user exists with verified email, resetting password");
let mut tx = db.begin().await?;
let mut tx = pool.begin().await?;
UserPasswordResetToken::delete(tx.as_mut(), reset_password.token).await?;
let user = match user
.update_password(

View File

@@ -1,68 +0,0 @@
use std::fs;
use std::path::Path;
use ammonia::clean;
use anyhow::{anyhow, Result};
use apalis::prelude::*;
use bytes::Buf;
use fred::prelude::*;
use readability::extractor;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tracing::{info, instrument};
use url::Url;
use uuid::Uuid;
use crate::config::Config;
use crate::domain_request_limiter::DomainRequestLimiter;
use crate::models::entry::Entry;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct CrawlEntryJob {
pub entry_id: Uuid,
}
#[instrument(skip_all, fields(entry_id = %entry_id))]
pub async fn crawl_entry(
CrawlEntryJob { entry_id }: CrawlEntryJob,
http_client: Data<Client>,
db: Data<PgPool>,
domain_request_limiter: Data<DomainRequestLimiter>,
config: Data<Config>,
redis: Data<RedisPool>,
) -> Result<()> {
let entry = Entry::get(&*db, entry_id).await?;
info!("got entry from db");
let content_dir = Path::new(&*config.content_dir);
let url = Url::parse(&entry.url)?;
let domain = url
.domain()
.ok_or(anyhow!("invalid url: {:?}", entry.url.clone()))?;
info!(url=%url, "starting fetch");
domain_request_limiter.acquire(domain).await?;
let bytes = http_client.get(url.clone()).send().await?.bytes().await?;
info!(url=%url, "fetched entry");
let article = extractor::extract(&mut bytes.reader(), &url)?;
info!("extracted content");
let id = entry.entry_id;
// TODO: update entry with scraped data
// if let Some(date) = article.date {
// // prefer scraped date over rss feed date
// let mut updated_entry = entry.clone();
// updated_entry.published_at = date;
// entry = update_entry(&self.pool, updated_entry)
// .await
// .map_err(|_| EntryCrawlerError::CreateEntryError(entry.url.clone()))?;
// };
let content = clean(&article.content);
info!("sanitized content");
fs::write(content_dir.join(format!("{}.html", id)), content)?;
fs::write(content_dir.join(format!("{}.txt", id)), article.text)?;
info!("saved content to filesystem");
redis
.next()
.publish("entries", entry_id.to_string())
.await?;
Ok(())
}

View File

@@ -1,189 +0,0 @@
use std::cmp::Ordering;
use anyhow::{anyhow, Result};
use apalis::prelude::*;
use apalis_redis::RedisStorage;
use chrono::{Duration, Utc};
use feed_rs::parser;
use fred::prelude::*;
use http::{header, HeaderMap, StatusCode};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tracing::{info, info_span, instrument, warn};
use url::Url;
use uuid::Uuid;
use crate::domain_request_limiter::DomainRequestLimiter;
use crate::jobs::{AsyncJob, CrawlEntryJob};
use crate::models::entry::{CreateEntry, Entry};
use crate::models::feed::{Feed, MAX_CRAWL_INTERVAL_MINUTES, MIN_CRAWL_INTERVAL_MINUTES};
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct CrawlFeedJob {
pub feed_id: Uuid,
}
#[instrument(skip_all, fields(feed_id = %feed_id))]
pub async fn crawl_feed(
CrawlFeedJob { feed_id }: CrawlFeedJob,
http_client: Data<Client>,
db: Data<PgPool>,
domain_request_limiter: Data<DomainRequestLimiter>,
apalis: Data<RedisStorage<AsyncJob>>,
redis: Data<RedisPool>,
) -> Result<()> {
let mut feed = Feed::get(&*db, feed_id).await?;
info!("got feed from db");
let url = Url::parse(&feed.url)?;
let domain = url
.domain()
.ok_or(anyhow!("invalid url: {:?}", feed.url.clone()))?;
let mut headers = HeaderMap::new();
if let Some(etag) = &feed.etag_header {
if let Ok(etag) = etag.parse() {
headers.insert(header::IF_NONE_MATCH, etag);
} else {
warn!(%etag, "failed to parse saved etag header");
}
}
if let Some(last_modified) = &feed.last_modified_header {
if let Ok(last_modified) = last_modified.parse() {
headers.insert(header::IF_MODIFIED_SINCE, last_modified);
} else {
warn!(
%last_modified,
"failed to parse saved last_modified header",
);
}
}
info!(url=%url, "starting fetch");
domain_request_limiter.acquire(domain).await?;
let resp = http_client.get(url.clone()).headers(headers).send().await?;
let headers = resp.headers();
if let Some(etag) = headers.get(header::ETAG) {
if let Ok(etag) = etag.to_str() {
feed.etag_header = Some(etag.to_string());
} else {
warn!(?etag, "failed to convert response etag header to string");
}
}
if let Some(last_modified) = headers.get(header::LAST_MODIFIED) {
if let Ok(last_modified) = last_modified.to_str() {
feed.last_modified_header = Some(last_modified.to_string());
} else {
warn!(
?last_modified,
"failed to convert response last_modified header to string",
);
}
}
info!(url=%url, "fetched feed");
if resp.status() == StatusCode::NOT_MODIFIED {
info!("feed returned not modified status");
feed.last_crawled_at = Some(Utc::now());
feed.last_crawl_error = None;
feed.save(&*db).await?;
info!("updated feed in db");
return Ok(());
} else if !resp.status().is_success() {
warn!("feed returned non-successful status");
feed.last_crawled_at = Some(Utc::now());
feed.last_crawl_error = resp.status().canonical_reason().map(|s| s.to_string());
feed.save(&*db).await?;
info!("updated feed in db");
return Ok(());
}
let bytes = resp.bytes().await?;
let parsed_feed = parser::parse(&bytes[..])?;
info!("parsed feed");
feed.url = url.to_string();
feed.feed_type = parsed_feed.feed_type.into();
feed.last_crawled_at = Some(Utc::now());
feed.last_crawl_error = None;
if let Some(title) = parsed_feed.title {
feed.title = Some(title.content);
}
if let Some(description) = parsed_feed.description {
feed.description = Some(description.content);
}
let last_entry_published_at = parsed_feed.entries.iter().filter_map(|e| e.published).max();
if let Some(prev_last_entry_published_at) = feed.last_entry_published_at {
if let Some(published_at) = last_entry_published_at {
let time_since_last_entry = if published_at == prev_last_entry_published_at {
// No new entry since last crawl, compare current time to last publish instead
Utc::now() - prev_last_entry_published_at
} else {
// Compare new entry publish time to previous publish time
published_at - prev_last_entry_published_at
};
match time_since_last_entry.cmp(&Duration::minutes(feed.crawl_interval_minutes.into()))
{
Ordering::Greater => {
feed.crawl_interval_minutes = i32::max(
(feed.crawl_interval_minutes as f32 * 1.2).ceil() as i32,
MAX_CRAWL_INTERVAL_MINUTES,
);
info!(
interval = feed.crawl_interval_minutes,
"increased crawl interval"
);
}
Ordering::Less => {
feed.crawl_interval_minutes = i32::max(
(feed.crawl_interval_minutes as f32 / 1.2).ceil() as i32,
MIN_CRAWL_INTERVAL_MINUTES,
);
info!(
interval = feed.crawl_interval_minutes,
"decreased crawl interval"
);
}
Ordering::Equal => {}
}
}
}
feed.last_entry_published_at = last_entry_published_at;
let feed = feed.save(&*db).await?;
info!("updated feed in db");
let mut payload = Vec::with_capacity(parsed_feed.entries.len());
for entry in parsed_feed.entries {
let entry_span = info_span!("entry", id = entry.id);
let _entry_span_guard = entry_span.enter();
if let Some(link) = entry.links.first() {
// if no scraped or feed date is available, fallback to the current time
let published_at = entry.published.unwrap_or_else(Utc::now);
let entry = CreateEntry {
title: entry.title.map(|t| t.content),
url: link.href.clone(),
description: entry.summary.map(|s| s.content),
feed_id: feed.feed_id,
published_at,
};
payload.push(entry);
} else {
warn!("skipping feed entry with no links");
}
}
let entries = Entry::bulk_upsert(&*db, payload).await?;
let (new, updated) = entries
.into_iter()
.partition::<Vec<_>, _>(|entry| entry.updated_at.is_none());
info!(new = new.len(), updated = updated.len(), "saved entries");
for entry in new {
(*apalis)
.clone() // TODO: clone bad?
.push(AsyncJob::CrawlEntry(CrawlEntryJob {
entry_id: entry.entry_id,
}))
.await?;
}
redis.next().publish("feeds", feed_id.to_string()).await?;
Ok(())
}

View File

@@ -1,105 +0,0 @@
use std::io::Cursor;
use anyhow::{anyhow, Context, Result};
use apalis::prelude::*;
use apalis_redis::RedisStorage;
use fred::prelude::*;
use opml::OPML;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use tracing::{error, instrument, warn};
use uuid::Uuid;
use crate::error::Error;
use crate::jobs::crawl_feed::CrawlFeedJob;
use crate::jobs::AsyncJob;
use crate::models::feed::{CreateFeed, Feed};
use crate::uuid::Base62Uuid;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ImportOpmlJob {
pub import_id: Uuid,
pub file_name: Option<String>,
pub bytes: Vec<u8>,
}
// TODO: send messages over redis channel
/// `ImporterOpmlMessage::Import` contains the result of importing the OPML file.
// #[allow(clippy::large_enum_variant)]
// #[derive(Debug, Clone)]
// pub enum ImporterOpmlMessage {
// Import(ImporterResult<()>),
// CreateFeedError(String),
// AlreadyImported(String),
// CrawlScheduler(CrawlSchedulerHandleMessage),
// }
#[instrument(skip_all, fields(import_id = %import_id))]
pub async fn import_opml(
ImportOpmlJob {
import_id,
file_name,
bytes,
}: ImportOpmlJob,
db: Data<PgPool>,
apalis: Data<RedisStorage<AsyncJob>>,
redis: Data<RedisPool>,
) -> Result<()> {
let document = OPML::from_reader(&mut Cursor::new(bytes)).with_context(|| {
format!(
"Failed to read OPML file for import {} from file {}",
Base62Uuid::from(import_id),
file_name
.map(|n| n.to_string())
.unwrap_or_else(|| "unknown".to_string())
)
})?;
for url in gather_feed_urls(document.body.outlines) {
let feed = Feed::create(
&*db,
CreateFeed {
url: url.clone(),
..Default::default()
},
)
.await;
match feed {
Ok(feed) => {
(*apalis)
.clone()
.push(AsyncJob::CrawlFeed(CrawlFeedJob {
feed_id: feed.feed_id,
}))
.await?;
}
Err(Error::Sqlx(sqlx::error::Error::Database(err))) => {
if err.is_unique_violation() {
// let _ = respond_to.send(ImporterHandleMessage::AlreadyImported(url));
warn!("Feed {} already imported", url);
}
}
Err(err) => {
// let _ = respond_to.send(ImporterHandleMessage::CreateFeedError(url));
error!("Failed to create feed for {}", url);
return Err(anyhow!(err));
}
}
}
redis
.next()
.publish("imports", import_id.to_string())
.await?;
Ok(())
}
fn gather_feed_urls(outlines: Vec<opml::Outline>) -> Vec<String> {
let mut urls = Vec::new();
for outline in outlines.into_iter() {
if let Some(url) = outline.xml_url {
urls.push(url);
}
urls.append(&mut gather_feed_urls(outline.outlines));
}
urls
}

View File

@@ -1,70 +0,0 @@
use apalis::prelude::*;
use apalis_redis::RedisStorage;
use fred::prelude::*;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use thiserror::Error;
use tracing::{error, info, instrument};
mod crawl_entry;
mod crawl_feed;
mod import_opml;
pub use crawl_entry::CrawlEntryJob;
pub use crawl_feed::CrawlFeedJob;
pub use import_opml::ImportOpmlJob;
use crate::{config::Config, domain_request_limiter::DomainRequestLimiter};
#[derive(Debug, Deserialize, Serialize, Clone)]
pub enum AsyncJob {
HelloWorld(String),
CrawlFeed(CrawlFeedJob),
CrawlEntry(CrawlEntryJob),
ImportOpml(ImportOpmlJob),
}
#[derive(Debug, Error)]
pub enum AsyncJobError {
#[error("error executing job")]
JobError(#[from] anyhow::Error),
}
#[instrument(skip_all, fields(worker_id = ?worker_id, task_id = ?task_id))]
pub async fn handle_async_job(
job: AsyncJob,
worker_id: Data<WorkerId>,
task_id: Data<TaskId>,
http_client: Data<Client>,
db: Data<PgPool>,
domain_request_limiter: Data<DomainRequestLimiter>,
config: Data<Config>,
apalis: Data<RedisStorage<AsyncJob>>,
redis: Data<RedisPool>,
) -> Result<(), AsyncJobError> {
let result = match job {
AsyncJob::HelloWorld(name) => {
info!("Hello, {}!", name);
Ok(())
}
AsyncJob::CrawlFeed(job) => {
crawl_feed::crawl_feed(job, http_client, db, domain_request_limiter, apalis, redis)
.await
}
AsyncJob::CrawlEntry(job) => {
crawl_entry::crawl_entry(job, http_client, db, domain_request_limiter, config, redis)
.await
}
AsyncJob::ImportOpml(job) => import_opml::import_opml(job, db, apalis, redis).await,
};
match result {
Ok(_) => info!("Job completed successfully"),
Err(err) => {
error!("Job failed: {err:?}");
return Err(AsyncJobError::JobError(err));
}
};
Ok(())
}

View File

@@ -3,12 +3,10 @@ pub mod api_response;
pub mod auth;
pub mod config;
pub mod domain_locks;
pub mod domain_request_limiter;
pub mod error;
pub mod handlers;
pub mod headers;
pub mod htmx;
pub mod jobs;
pub mod log;
pub mod mailers;
pub mod models;

View File

@@ -5,7 +5,10 @@ use anyhow::Result;
use bytes::Bytes;
use once_cell::sync::Lazy;
use tokio::sync::watch::Sender;
use tracing::level_filters::LevelFilter;
use tracing::Level;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::filter::Targets;
use tracing_subscriber::prelude::*;
use tracing_subscriber::EnvFilter;
@@ -75,8 +78,8 @@ pub fn init_tracing(
config: &Config,
log_sender: Sender<Bytes>,
) -> Result<(WorkerGuard, WorkerGuard)> {
let console_layer = console_subscriber::spawn();
let stdout_layer = tracing_subscriber::fmt::layer().pretty();
let filter_layer = EnvFilter::from_default_env();
let file_appender = tracing_appender::rolling::hourly("./logs", "log");
let (file_writer, file_writer_guard) = tracing_appender::non_blocking(file_appender);
let mem_writer = LimitedInMemoryBuffer::new(&MEM_LOG, log_sender, config.max_mem_log_size);
@@ -84,24 +87,33 @@ pub fn init_tracing(
let file_writer_layer = tracing_subscriber::fmt::layer().with_writer(file_writer);
let mem_writer_layer = tracing_subscriber::fmt::layer().with_writer(mem_writer);
tracing_subscriber::registry()
.with(filter_layer)
.with(stdout_layer)
.with(file_writer_layer)
.with(mem_writer_layer)
.with(
EnvFilter::from_default_env()
.add_directive("tokio=trace".parse()?)
.add_directive("runtime=trace".parse()?),
)
.with(console_layer)
.with(
stdout_layer.with_filter(
EnvFilter::from_default_env()
.add_directive("tokio=off".parse()?)
.add_directive("runtime=off".parse()?),
),
)
.with(
file_writer_layer.with_filter(
EnvFilter::from_default_env()
.add_directive("tokio=off".parse()?)
.add_directive("runtime=off".parse()?),
),
)
.with(
mem_writer_layer.with_filter(
EnvFilter::from_default_env()
.add_directive("tokio=off".parse()?)
.add_directive("runtime=off".parse()?),
),
)
.init();
Ok((file_writer_guard, mem_writer_guard))
}
pub fn init_worker_tracing() -> Result<WorkerGuard> {
let stdout_layer = tracing_subscriber::fmt::layer().pretty();
let filter_layer = EnvFilter::from_default_env();
let file_appender = tracing_appender::rolling::hourly("./logs", "log");
let (file_writer, file_writer_guard) = tracing_appender::non_blocking(file_appender);
let file_writer_layer = tracing_subscriber::fmt::layer().with_writer(file_writer);
tracing_subscriber::registry()
.with(filter_layer)
.with(stdout_layer)
.with(file_writer_layer)
.init();
Ok(file_writer_guard)
}

View File

@@ -17,7 +17,7 @@ use crate::uuid::Base62Uuid;
// TODO: put in config
const USER_EMAIL_VERIFICATION_TOKEN_EXPIRATION: Duration = Duration::from_secs(24 * 60 * 60);
pub fn send_confirmation_email(db: PgPool, mailer: SmtpTransport, config: Config, user: User) {
pub fn send_confirmation_email(pool: PgPool, mailer: SmtpTransport, config: Config, user: User) {
tokio::spawn(async move {
let user_email_address = match user.email.parse() {
Ok(address) => address,
@@ -28,7 +28,7 @@ pub fn send_confirmation_email(db: PgPool, mailer: SmtpTransport, config: Config
};
let mailbox = Mailbox::new(user.name.clone(), user_email_address);
let token = match UserEmailVerificationToken::create(
&db,
&pool,
CreateUserEmailVerificationToken {
user_id: user.user_id,
expires_at: Utc::now() + USER_EMAIL_VERIFICATION_TOKEN_EXPIRATION,
@@ -42,10 +42,11 @@ pub fn send_confirmation_email(db: PgPool, mailer: SmtpTransport, config: Config
return;
}
};
let mut confirm_link = config.public_url.clone();
let mut confirm_link = config
.public_url
.clone();
confirm_link.set_path("confirm-email");
confirm_link
.query_pairs_mut()
confirm_link.query_pairs_mut()
.append_pair("token_id", &Base62Uuid::from(token.token_id).to_string());
let confirm_link = confirm_link.as_str();

View File

@@ -18,7 +18,7 @@ use crate::uuid::Base62Uuid;
const PASSWORD_RESET_TOKEN_EXPIRATION: Duration = Duration::from_secs(24 * 60 * 60);
pub fn send_forgot_password_email(
db: PgPool,
pool: PgPool,
mailer: SmtpTransport,
config: Config,
user: User,
@@ -35,7 +35,7 @@ pub fn send_forgot_password_email(
};
let mailbox = Mailbox::new(user.name.clone(), user_email_address);
let token = match UserPasswordResetToken::create(
&db,
&pool,
CreatePasswordResetToken {
token_id: Uuid::new_v4(), // cryptographically-secure random uuid
user_id: user.user_id,

View File

@@ -1,8 +1,6 @@
use std::{collections::HashMap, net::SocketAddr, path::Path, sync::Arc};
use anyhow::Result;
use apalis::prelude::*;
use apalis_redis::RedisStorage;
use axum::{
routing::{get, post},
Router,
@@ -16,7 +14,6 @@ use base64::prelude::*;
use bytes::Bytes;
use clap::Parser;
use dotenvy::dotenv;
use fred::prelude::*;
use lettre::transport::smtp::authentication::Credentials;
use lettre::SmtpTransport;
use notify::Watcher;
@@ -29,20 +26,12 @@ use tower::ServiceBuilder;
use tower_http::{services::ServeDir, trace::TraceLayer};
use tower_livereload::LiveReloadLayer;
use tower_sessions::cookie::Key;
use tower_sessions_redis_store::{
fred::{
interfaces::ClientLike as TowerSessionsRedisClientLike,
prelude::{RedisConfig as TowerSessionsRedisConfig, RedisPool as TowerSessionsRedisPool},
},
RedisStore,
};
use tower_sessions_redis_store::{fred::prelude::*, RedisStore};
use tracing::debug;
use lib::config::Config;
use lib::domain_locks::DomainLocks;
use lib::domain_request_limiter::DomainRequestLimiter;
use lib::handlers;
use lib::jobs::AsyncJob;
use lib::log::init_tracing;
use lib::state::AppState;
use lib::USER_AGENT;
@@ -74,7 +63,7 @@ async fn main() -> Result<()> {
let domain_locks = DomainLocks::new();
let client = Client::builder().user_agent(USER_AGENT).build()?;
let db = PgPoolOptions::new()
let pool = PgPoolOptions::new()
.max_connections(config.database_max_connections)
.acquire_timeout(std::time::Duration::from_secs(3))
.connect(&config.database_url)
@@ -83,20 +72,8 @@ async fn main() -> Result<()> {
let redis_config = RedisConfig::from_url(&config.redis_url)?;
let redis_pool = RedisPool::new(redis_config, None, None, None, config.redis_pool_size)?;
redis_pool.init().await?;
let domain_request_limiter = DomainRequestLimiter::new(redis_pool.clone(), 10, 5, 100, 0.5);
// TODO: is it possible to use the same fred RedisPool that the web app uses?
let sessions_redis_config = TowerSessionsRedisConfig::from_url(&config.redis_url)?;
let sessions_redis_pool = TowerSessionsRedisPool::new(
sessions_redis_config,
None,
None,
None,
config.redis_pool_size,
)?;
sessions_redis_pool.init().await?;
let session_store = RedisStore::new(sessions_redis_pool);
let session_store = RedisStore::new(redis_pool);
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(!cfg!(debug_assertions))
.with_expiry(Expiry::OnInactivity(Duration::days(
@@ -104,7 +81,7 @@ async fn main() -> Result<()> {
)))
.with_signed(Key::from(&BASE64_STANDARD.decode(&config.session_secret)?));
let backend = Backend::new(db.clone());
let backend = Backend::new(pool.clone());
let auth_layer = AuthManagerLayerBuilder::new(backend, session_layer).build();
let smtp_creds = Credentials::new(config.smtp_user.clone(), config.smtp_password.clone());
@@ -114,28 +91,17 @@ async fn main() -> Result<()> {
.credentials(smtp_creds)
.build();
sqlx::migrate!().run(&db).await?;
// TODO: use redis_pool from above instead of making a new connection
// See: https://github.com/geofmureithi/apalis/issues/290
let redis_conn = apalis_redis::connect(config.redis_url.clone()).await?;
let apalis_config = apalis_redis::Config::default();
let mut apalis: RedisStorage<AsyncJob> =
RedisStorage::new_with_config(redis_conn, apalis_config);
apalis
.push(AsyncJob::HelloWorld("hello".to_string()))
.await?;
sqlx::migrate!().run(&pool).await?;
let crawl_scheduler = CrawlSchedulerHandle::new(
db.clone(),
pool.clone(),
client.clone(),
domain_locks.clone(),
config.content_dir.clone(),
crawls.clone(),
);
// let _ = crawl_scheduler.bootstrap().await;
let importer = ImporterHandle::new(db.clone(), crawl_scheduler.clone(), imports.clone());
let _ = crawl_scheduler.bootstrap().await;
let importer = ImporterHandle::new(pool.clone(), crawl_scheduler.clone(), imports.clone());
let ip_source_extension = config.ip_source.0.clone().into_extension();
@@ -174,19 +140,16 @@ async fn main() -> Result<()> {
.route("/reset-password", post(handlers::reset_password::post))
.nest_service("/static", ServeDir::new("static"))
.with_state(AppState {
db,
pool,
config,
log_receiver,
crawls,
domain_locks,
domain_request_limiter,
client,
crawl_scheduler,
importer,
imports,
mailer,
apalis,
redis: redis_pool,
})
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
.layer(auth_layer)

View File

@@ -32,7 +32,7 @@ impl UserPasswordResetToken {
}
pub async fn get(
db: impl Executor<'_, Database = Postgres>,
pool: impl Executor<'_, Database = Postgres>,
token_id: Uuid,
) -> Result<UserPasswordResetToken> {
sqlx::query_as!(
@@ -43,7 +43,7 @@ impl UserPasswordResetToken {
where token_id = $1"#,
token_id
)
.fetch_one(db)
.fetch_one(pool)
.await
.map_err(|error| {
if let sqlx::error::Error::RowNotFound = error {
@@ -54,7 +54,7 @@ impl UserPasswordResetToken {
}
pub async fn create(
db: impl Executor<'_, Database = Postgres>,
pool: impl Executor<'_, Database = Postgres>,
payload: CreatePasswordResetToken,
) -> Result<UserPasswordResetToken> {
Ok(sqlx::query_as!(
@@ -70,17 +70,20 @@ impl UserPasswordResetToken {
payload.request_ip,
payload.expires_at
)
.fetch_one(db)
.fetch_one(pool)
.await?)
}
pub async fn delete(db: impl Executor<'_, Database = Postgres>, token_id: Uuid) -> Result<()> {
pub async fn delete(
pool: impl Executor<'_, Database = Postgres>,
token_id: Uuid,
) -> Result<()> {
sqlx::query!(
r#"delete from user_password_reset_token
where token_id = $1"#,
token_id
)
.execute(db)
.execute(pool)
.await?;
Ok(())
}

View File

@@ -1,22 +1,18 @@
use std::collections::HashMap;
use std::sync::Arc;
use apalis_redis::RedisStorage;
use axum::extract::FromRef;
use bytes::Bytes;
use fred::clients::RedisPool;
use lettre::SmtpTransport;
use reqwest::Client;
use sqlx::PgPool;
use tokio::sync::{broadcast, watch, Mutex};
use uuid::Uuid;
use crate::actors::crawl_scheduler::{CrawlSchedulerHandle, CrawlSchedulerHandleMessage};
use crate::actors::importer::{ImporterHandle, ImporterHandleMessage};
use crate::actors::crawl_scheduler::{CrawlSchedulerHandle, CrawlSchedulerHandleMessage};
use crate::config::Config;
use crate::domain_locks::DomainLocks;
use crate::domain_request_limiter::DomainRequestLimiter;
use crate::jobs::AsyncJob;
/// A map of feed IDs to a channel receiver for the active `CrawlScheduler` running a feed crawl
/// for that feed.
@@ -32,35 +28,32 @@ pub type Crawls = Arc<Mutex<HashMap<Uuid, broadcast::Receiver<CrawlSchedulerHand
/// A map of unique import IDs to a channel receiver for the active `Importer` running that import.
///
/// Same as the `Crawls` map, the only purpose of this is to keep track of active imports so that
/// axum handlers can subscribe to the result of the import via the receiver channel which are then
/// Same as the `Crawls` map, the only purpose of this is to keep track of active imports so that
/// axum handlers can subscribe to the result of the import via the receiver channel which are then
/// sent to end-users as a stream of server-sent events.
///
/// This map should only contain imports that have just been created but not yet subscribed to.
/// Entries are only added when a user uploads an OPML to import and entries are removed by
/// Entries are only added when a user uploads an OPML to import and entries are removed by
/// the same user once a server-sent event connection is established.
pub type Imports = Arc<Mutex<HashMap<Uuid, broadcast::Receiver<ImporterHandleMessage>>>>;
#[derive(Clone)]
pub struct AppState {
pub db: PgPool,
pub pool: PgPool,
pub config: Config,
pub log_receiver: watch::Receiver<Bytes>,
pub crawls: Crawls,
pub domain_locks: DomainLocks,
pub domain_request_limiter: DomainRequestLimiter,
pub client: Client,
pub crawl_scheduler: CrawlSchedulerHandle,
pub importer: ImporterHandle,
pub imports: Imports,
pub mailer: SmtpTransport,
pub apalis: RedisStorage<AsyncJob>,
pub redis: RedisPool,
}
impl FromRef<AppState> for PgPool {
fn from_ref(state: &AppState) -> Self {
state.db.clone()
state.pool.clone()
}
}
@@ -88,12 +81,6 @@ impl FromRef<AppState> for DomainLocks {
}
}
impl FromRef<AppState> for DomainRequestLimiter {
fn from_ref(state: &AppState) -> Self {
state.domain_request_limiter.clone()
}
}
impl FromRef<AppState> for Client {
fn from_ref(state: &AppState) -> Self {
state.client.clone()
@@ -123,15 +110,3 @@ impl FromRef<AppState> for SmtpTransport {
state.mailer.clone()
}
}
impl FromRef<AppState> for RedisStorage<AsyncJob> {
fn from_ref(state: &AppState) -> Self {
state.apalis.clone()
}
}
impl FromRef<AppState> for RedisPool {
fn from_ref(state: &AppState) -> Self {
state.redis.clone()
}
}

tailwind.config.js (new file, 2 changed lines)
View File

@@ -0,0 +1,2 @@
const config = require('./frontend/tailwind.config.js');
export default config;