use std::sync::Arc; use tokio::sync::Mutex; use axum::{ extract::{Form, Path, Query, State}, http::{header, HeaderMap, StatusCode}, response::IntoResponse, routing::{get, post}, Json, Router, }; use rusqlite::Connection; use serde::Deserialize; use crate::db::{self, AnnouncementFlag}; type Db = Arc>; pub async fn serve(addr: &str, db: Db) { let app = Router::new() .route("/", get(management_page)) .route("/api/announcements", get(list_announcements).post(create_announcement)) .route("/api/announcements/:id/delete", post(delete_announcement)) .with_state(db); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); println!("Management server listening on {addr}"); axum::serve(listener, app).await.unwrap(); } // ── Auth ────────────────────────────────────────────────────────────────────── fn check_auth(headers: &HeaderMap) -> bool { let expected_user = std::env::var("MANAGEMENT_USERNAME").unwrap_or_else(|_| "admin".to_string()); let expected_pass = std::env::var("MANAGEMENT_PASSWORD").unwrap_or_else(|_| "admin".to_string()); let expected = format!("Basic {}", base64_encode(&format!("{expected_user}:{expected_pass}"))); headers .get(header::AUTHORIZATION) .and_then(|v| v.to_str().ok()) .is_some_and(|v| v == expected) } fn unauthorized() -> Response { ( StatusCode::UNAUTHORIZED, [(header::WWW_AUTHENTICATE, "Basic realm=\"Management\"")], "Unauthorized", ) .into_response() } use axum::response::Response; // ── Query params ────────────────────────────────────────────────────────────── #[derive(Deserialize, Default)] struct ListParams { #[serde(default)] page: u32, #[serde(default = "default_page_size")] page_size: u32, #[serde(default)] filter_date: Option, } const fn default_page_size() -> u32 { 20 } // ── Handlers ────────────────────────────────────────────────────────────────── async fn management_page( headers: HeaderMap, State(db): State, Query(params): Query, ) -> impl IntoResponse { if !check_auth(&headers) { return unauthorized(); } let page = params.page.max(1); let page_size = params.page_size.clamp(1, 100); let filter_date = params.filter_date.filter(|s| !s.is_empty()); let conn = db.lock().await; let announcements = db::list_all(&conn, page, page_size, filter_date.as_deref()).unwrap_or_default(); let total = db::count_all(&conn, filter_date.as_deref()).unwrap_or(0); drop(conn); let total_pages = (total as u32).div_ceil(page_size).max(1); let rows: String = if announcements.is_empty() { let colspan = if filter_date.is_some() { "8" } else { "7" }; format!( r#"No announcements yet."#, colspan ) } else { announcements .iter() .map(|a| { let flags_display: String = a .flags .iter() .map(|f| { format!(r#"{}"#, escape_html(f.display_name())) }) .collect::>() .join(" "); format!( include_str!("../../templates/announcement_row.html"), id = a.id, author = escape_html(&a.author), text = escape_html(&a.text_content), start = a.start_date, end = a.end_date, flags_display = flags_display, ) }) .collect() }; let flags_checkboxes: String = AnnouncementFlag::all() .iter() .map(|f| { let value = f.to_string().to_lowercase(); format!( r#""#, escape_html(&value), escape_html(f.display_name()), ) }) .collect::>() .join(""); let filter_date_value = filter_date.unwrap_or_default(); let filter_param = if filter_date_value.is_empty() { String::new() } else { format!("&filter_date={}", urlencode(&filter_date_value)) }; let pagination = render_pagination(page, total_pages, total as u64, &filter_param); let html = format!( include_str!("../../templates/management_page.html"), total_count = total, filter_date = escape_html(&filter_date_value), flags_checkboxes = flags_checkboxes, rows = rows, pagination = pagination, ); ([(header::CONTENT_TYPE, "text/html; charset=utf-8")], html).into_response() } fn render_pagination(current: u32, total: u32, count: u64, filter_param: &str) -> String { if total <= 1 { return String::new(); } let mut html = String::from( r#"Page "#, ); html.push_str(¤t.to_string()); html.push_str(r#" of "#); html.push_str(&total.to_string()); html.push_str(r#" ("#); html.push_str(&count.to_string()); html.push_str(r#" total)
"#); // prev if current > 1 { html.push_str(&format!( r#"« Prev"#, current - 1, filter_param )); } else { html.push_str(r#"« Prev"#); } // page numbers - show a window around current page let window: u32 = 2; let start_page = if current > window + 1 { current - window } else { 1 }; let end_page = if current + window < total { current + window } else { total }; if start_page > 1 { html.push_str(&format!( r#"1"#, filter_param )); if start_page > 2 { html.push_str(r#"..."#); } } for p in start_page..=end_page { if p == current { html.push_str(&format!(r#"{p}"#)); } else { html.push_str(&format!( r#"{p}"#, filter_param )); } } if end_page < total { if end_page < total - 1 { html.push_str(r#"..."#); } html.push_str(&format!( r#"{total}"#, filter_param )); } // next if current < total { html.push_str(&format!( r#"Next »"#, current + 1, filter_param )); } else { html.push_str(r#"Next »"#); } html.push_str("
"); html } #[derive(Deserialize)] struct CreateInput { author: String, text_content: String, start_date: String, end_date: String, #[serde(default)] flags: Vec, } async fn create_announcement( headers: HeaderMap, State(db): State, Form(input): Form, ) -> impl IntoResponse { if !check_auth(&headers) { return unauthorized(); } let flags: Vec = input .flags .iter() .filter_map(|s| s.parse::().ok()) .collect(); let conn = db.lock().await; if let Err(e) = db::create(&conn, &input.author, &input.text_content, &input.start_date, &input.end_date, &flags) { return ( StatusCode::INTERNAL_SERVER_ERROR, [("location", "/")], format!("Failed to create: {e}"), ) .into_response(); } (StatusCode::FOUND, [("location", "/")], "").into_response() } async fn list_announcements( headers: HeaderMap, State(db): State, Query(params): Query, ) -> impl IntoResponse { if !check_auth(&headers) { return unauthorized(); } let page = params.page.max(1); let page_size = params.page_size.clamp(1, 100); let filter_date = params.filter_date.filter(|s| !s.is_empty()); let conn = db.lock().await; let list = db::list_all(&conn, page, page_size, filter_date.as_deref()); let total = db::count_all(&conn, filter_date.as_deref()).unwrap_or(0); drop(conn); match list { Ok(items) => ( StatusCode::OK, Json(serde_json::json!({ "items": items, "page": page, "page_size": page_size, "total": total, })), ) .into_response(), Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({ "error": e.to_string() })), ) .into_response(), } } async fn delete_announcement( headers: HeaderMap, State(db): State, Path(id): Path, ) -> impl IntoResponse { if !check_auth(&headers) { return unauthorized(); } let conn = db.lock().await; db::delete(&conn, id).ok(); (StatusCode::FOUND, [("location", "/")], "").into_response() } // ── Helpers ─────────────────────────────────────────────────────────────────── fn escape_html(s: &str) -> String { s.replace('&', "&") .replace('<', "<") .replace('>', ">") .replace('"', """) } fn urlencode(s: &str) -> String { let mut out = String::with_capacity(s.len()); for byte in s.bytes() { match byte { b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { out.push(byte as char); } _ => { out.push_str(&format!("%{:02X}", byte)); } } } out } fn base64_encode(input: &str) -> String { const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; let bytes = input.as_bytes(); let mut out = String::new(); for chunk in bytes.chunks(3) { let b0 = chunk[0] as u32; let b1 = chunk.get(1).copied().unwrap_or(0) as u32; let b2 = chunk.get(2).copied().unwrap_or(0) as u32; let triple = (b0 << 16) | (b1 << 8) | b2; out.push(CHARS[((triple >> 18) & 0x3F) as usize] as char); out.push(CHARS[((triple >> 12) & 0x3F) as usize] as char); if chunk.len() > 1 { out.push(CHARS[((triple >> 6) & 0x3F) as usize] as char); } else { out.push('='); } if chunk.len() > 2 { out.push(CHARS[(triple & 0x3F) as usize] as char); } else { out.push('='); } } out }