diff --git a/example.env b/example.env index ab7eb20..7b4007e 100644 --- a/example.env +++ b/example.env @@ -1,6 +1,7 @@ # Copy this to .env and change OAUTH Values RUST_LOG=debug,tower_http=info DATABASE_URI=inventory-app.db +SESSION_DATABASE_URI=session.db OAUTH_CLIENT_ID=changeme OAUTH_CLIENT_SECRET=changme OAUTH_AUTH_URL=https://accounts.google.com/o/oauth2/auth diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..31f3cc0 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,169 @@ +use anyhow::Context; +use axum::{async_trait, http}; +use axum::extract::{FromRef, FromRequestParts, Query, State}; +use axum::http::{header, StatusCode}; +use axum::http::request::Parts; +use axum::response::{IntoResponse, Redirect, Response}; +use oauth2::basic::BasicClient; +use oauth2::reqwest::async_http_client; +use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl}; +use serde::{Deserialize, Serialize}; +use sqlx::SqlitePool; +use tower_sessions::Session; + +use crate::error::AppError; +use crate::{CSRF_TOKEN, USER_SESSION}; + + +pub fn init_client() -> anyhow::Result { + use std::env::var; + + let client_id = var("OAUTH_CLIENT_ID").context("env OAUTH_CLIENT_ID not set")?; + let client_secret = var("OAUTH_CLIENT_SECRET").context("env OAUTH_CLIENT_SECRET not set")?; + let auth_url = var("OAUTH_AUTH_URL").context("env OAUTH_AUTH_URL not set")?; + let token_url = var("OAUTH_TOKEN_URL").context("env OAUTH_TOKEN_URL not set")?; + let revoke_url = var("OAUTH_REVOKE_URL").context("env OAUTH_REVOKE_URL not set")?; + let redirect_url = var("OAUTH_REDIRECT_URL").context("env OAUTH_REDIRECT_URL not set")?; + + let client_id = ClientId::new(client_id); + let client_secret = ClientSecret::new(client_secret); + let auth_url = AuthUrl::new(auth_url)?; + let token_url = TokenUrl::new(token_url)?; + let revoke_url = RevocationUrl::new(revoke_url)?; + let redirect_url = RedirectUrl::new(redirect_url)?; + + let client = BasicClient::new( + client_id, + Some(client_secret), + auth_url, + Some(token_url)) + .set_redirect_uri(redirect_url) + .set_revocation_uri(revoke_url); + + Ok(client) +} + +pub async fn auth_google( + session: Session, + State(oauth_client): State, +) -> anyhow::Result { + let (auth_url, csrf_token) = oauth_client + .authorize_url(CsrfToken::new_random) + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .url(); + + session.insert(CSRF_TOKEN, csrf_token).await?; + + Ok(Redirect::to(auth_url.as_ref())) +} + + +pub async fn auth_authorized( + session: Session, + Query(query_auth): Query, + State(oauth_client): State, +) -> anyhow::Result { + let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL") + .context("OAUTH_USER_INFO_URL not set")?; + + let token = oauth_client + .exchange_code(AuthorizationCode::new(query_auth.code.clone())) + .request_async(async_http_client) + .await + .context("failed in sending request request to authorization server")?; + + // Fetch user data + let client = reqwest::Client::new(); + let user_data = client + .get(user_info_endpoint) + .bearer_auth(token.access_token().secret()) + .send() + .await + .context("failed in sending request to target Url")?; + + let user_data = user_data + .json::() + .await + .context("failed to deserialize response as JSON")?; + + session.insert(USER_SESSION, user_data).await?; + + //TODO Redirect somewhere sane + Ok(Redirect::to("/protected")) +} + +pub async fn auth_logout( + session: Session, +) -> anyhow::Result { + + session.remove::(USER_SESSION).await?; + + Ok(Redirect::to("/")) +} + +#[derive(Debug, Deserialize)] +pub struct AuthRequest { + pub code: String, + pub state: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct User { + pub id: String, + pub email: String, + pub name: String, + pub verified_email: bool, + pub picture: String, +} + +pub struct UserExtractError(http::StatusCode); + +impl IntoResponse for UserExtractError { + fn into_response(self) -> Response { + match self.0 { + StatusCode::UNAUTHORIZED => { Redirect::temporary("/auth/login").into_response() } + StatusCode::FORBIDDEN => { StatusCode::FORBIDDEN.into_response() } + _ => StatusCode::INTERNAL_SERVER_ERROR.into_response() + } + } +} + +impl From<(http::StatusCode, &'static str)> for UserExtractError { + fn from(value: (StatusCode, &'static str)) -> Self { + Self(value.0) + } +} + +impl From for UserExtractError { + fn from(_value: tower_sessions::session::Error) -> Self { + Self(StatusCode::INTERNAL_SERVER_ERROR) + } +} + +#[async_trait] +impl FromRequestParts for User +where + SqlitePool: FromRef, + S: Send + Sync, +{ + type Rejection = UserExtractError; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let session = Session::from_request_parts(parts, state).await + .map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?; + + let user: User = session.get(USER_SESSION).await + .map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))? + .ok_or(UserExtractError(StatusCode::UNAUTHORIZED))?; + + let db = SqlitePool::from_ref(state); + + //TODO actual verification of users + if user.email != "whatswithwes@gmail.com" { + Err(UserExtractError(StatusCode::FORBIDDEN))?; + } + + Ok(user) + } +} diff --git a/src/main.rs b/src/main.rs index 57da483..68a219d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,34 +1,29 @@ +use crate::app_state::AppState; +use crate::auth::User; +use crate::error::AppError; +use anyhow::{anyhow, Context, Result}; +use axum::extract::{FromRef, Query, State}; +use axum::http::header::SET_COOKIE; +use axum::http::HeaderMap; +use axum::response::{IntoResponse, Redirect}; use axum::{ extract::Request, handler::HandlerWithoutStateExt, http::StatusCode, routing::get, Router, }; -use std::net::SocketAddr; -use tower::ServiceExt; +use tokio::signal; +use tokio::task::{AbortHandle, JoinHandle}; use tower_http::{ services::{ServeDir, ServeFile}, trace::TraceLayer, }; -use tower_sessions::{session_store::ExpiredDeletion, Expiry, Session, SessionManagerLayer}; -use tower_sessions_sqlx_store::{SqliteStore}; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; -use anyhow::{anyhow, Context, Result}; -use axum::extract::{FromRef, Query, State}; -use axum::http::header::SET_COOKIE; -use axum::http::HeaderMap; -use axum::response::{IntoResponse, Redirect}; -use oauth2::basic::BasicClient; -use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl}; -use oauth2::reqwest::async_http_client; -use sqlx::sqlite::{SqlitePool, SqlitePoolOptions}; -use time::Duration; -use tower_sessions::cookie::SameSite; -use serde::{Deserialize, Serialize}; +use tower_sessions::Session; use tracing::info; -use crate::app_state::AppState; -use crate::error::AppError; +use tracing_subscriber::EnvFilter; mod app_state; mod db; mod error; +mod session; +mod auth; //NOTE TO FUTURE ME: I'm leaving a bunch of notes about these things as part of the learning // process. There is a lot of implementation details that are obscured by all these pieces, and it // can be hard to tell how heavy a line is. Lots of comment in this file and some of the kind of @@ -58,35 +53,42 @@ async fn main() -> Result<()>{ .compact() .init(); - tracing::info!("{}", env_status); + info!("{}", env_status); + // Application database. What you would expect. Access + // through the application state let db_file = std::env::var("DATABASE_URI") .context("DATABASE_URI not set")?; - let db = db::init(&db_file).await?; - let session_store = init_session_store(db.clone()).await?; - let session_layer = init_session_layer(session_store.clone()).await?; - /*TODO - let deletion_task = tokio::task::spawn( - session_store.clone() - .continuously_delete_expired(tokio::time::Duration::from_secs(60)), - ); - */ - - let oauth_client = init_oath_client()?; + // OAUTH2 Client + let oauth_client = auth::init_client()?; + // Application state let app_state = AppState { db, oauth_client }; - let router = Router::new() + // Session + let (session_layer, session_task) = session::init().await?; + + let auth_routes: Router = Router::new() + .route("/auth/login", get(auth::auth_google)) + .route("/auth/logout", get(auth::auth_logout)) + .route("/auth/authorized", get(auth::auth_authorized)); + + let static_routes = Router::new() + .nest_service("/", ServeDir::new("static/")); + + let test_routes: Router = Router::new() .route("/fail", get(fail)) - .route_service("/", ServeFile::new("assets/index.html")) - .nest_service("/js", ServeDir::new("assets/js")) - .nest_service("/css", ServeDir::new("assets/css")) - .route("/auth/oauth", get(auth_google)) - .route("/auth/authorized", get(auth_authorized)) - .route("/protected", get(protected)) + .route("/usertest", get(index)) + .route("/protected", get(protected)); + + let router = Router::new() + .merge(auth_routes) + .merge(test_routes) + .fallback_service(static_routes) .layer(session_layer) + .layer(TraceLayer::new_for_http()) .with_state(app_state); let address = "0.0.0.0:4206"; @@ -96,8 +98,14 @@ async fn main() -> Result<()>{ info!("listening on {}", address); - axum::serve(listener, router.layer(TraceLayer::new_for_http())) - .await.context("unable to serve") + let mut tasks = vec![]; + tasks.push(session_task); + + axum::serve(listener, router.into_make_service()) + .with_graceful_shutdown(shutdown_signal(tasks)) + .await.context("unable to serve")?; + + Ok(()) } async fn fail() -> Result<(), AppError> { @@ -109,72 +117,8 @@ fn always_fails() -> Result<()> { Err(anyhow!("I always fail")) } - -async fn auth_google( - session: Session, - State(oauth_client): State, -) -> Result { - let (auth_url, csrf_token) = oauth_client - .authorize_url(CsrfToken::new_random) - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())) - .url(); - - session.insert(CSRF_TOKEN, csrf_token).await?; - - Ok(Redirect::to(auth_url.as_ref())) -} - -#[derive(Debug, Deserialize)] -struct AuthRequest { - code: String, - state: String, -} - -#[derive(Debug, Serialize, Deserialize)] -struct User { - id: String, - email: String, - name: String, - verified_email: bool, - picture: String, -} - -async fn auth_authorized( - session: Session, - Query(query_auth): Query, - State(oauth_client): State, -) -> Result { - let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL") - .context("OAUTH_USER_INFO_URL not set")?; - - let token = oauth_client - .exchange_code(AuthorizationCode::new(query_auth.code.clone())) - .request_async(async_http_client) - .await - .context("failed in sending request request to authorization server")?; - - // Fetch user data - let client = reqwest::Client::new(); - let user_data = client - .get(user_info_endpoint) - .bearer_auth(token.access_token().secret()) - .send() - .await - .context("failed in sending request to target Url")?; - - /* - let response_text = user_data.text().await?; - */ - - let user_data = user_data - .json::() - .await - .context("failed to deserialize response as JSON")?; - - session.insert(USER_SESSION, user_data).await?; - - Ok(Redirect::to("/protected")) +async fn index(user: User) -> impl IntoResponse { + format!("Hello {}", user.email) } async fn protected( @@ -193,51 +137,42 @@ async fn protected( Ok(Redirect::to("/")) } -async fn init_session_store(db: SqlitePool) -> Result { - let session_store = SqliteStore::new(db.clone()); - session_store.migrate().await?; - Ok(session_store) -} -async fn init_session_layer(store: SqliteStore) -> Result> { - // This guy forms the session cookies - // The session manager layer is the glue between the session store - // and the handlers. The options basically define the options of - // the cookies given to the client - // Example cookie: - // SESSION=biglongsessionid; SameSite=Lax; Secure; HttpOnly; Path=/; Max-Age=3600 - Ok(SessionManagerLayer::new(store) - .with_name("SESSION") - .with_same_site(SameSite::Lax) - .with_secure(true) - .with_http_only(true) - .with_path("/") - .with_expiry(Expiry::OnInactivity(Duration::seconds(3600)))) -} +async fn shutdown_signal(tasks: Vec>>) { -fn init_oath_client() -> Result { - use std::env::var; - - let client_id = var("OAUTH_CLIENT_ID").context("env OAUTH_CLIENT_ID not set")?; - let client_secret = var("OAUTH_CLIENT_SECRET").context("env OAUTH_CLIENT_SECRET not set")?; - let auth_url = var("OAUTH_AUTH_URL").context("env OAUTH_AUTH_URL not set")?; - let token_url = var("OAUTH_TOKEN_URL").context("env OAUTH_TOKEN_URL not set")?; - let revoke_url = var("OAUTH_REVOKE_URL").context("env OAUTH_REVOKE_URL not set")?; - let redirect_url = var("OAUTH_REDIRECT_URL").context("env OAUTH_REDIRECT_URL not set")?; - - let client_id = ClientId::new(client_id); - let client_secret = ClientSecret::new(client_secret); - let auth_url = AuthUrl::new(auth_url)?; - let token_url = TokenUrl::new(token_url)?; - let revoke_url = RevocationUrl::new(revoke_url)?; - let redirect_url = RedirectUrl::new(redirect_url)?; - - let client = BasicClient::new( - client_id, - Some(client_secret), - auth_url, - Some(token_url)) - .set_redirect_uri(redirect_url) - .set_revocation_uri(revoke_url); - - Ok(client) + let abort_handles: Vec = tasks.iter().map(|h| h.abort_handle()).collect(); + + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + }; + + info!("Shutdown signalled"); + + for handle in abort_handles { + handle.abort(); + } + + for handle in tasks { + let _ = handle.await; + } + + info!("All processes finished"); } + diff --git a/src/session.rs b/src/session.rs new file mode 100644 index 0000000..425673e --- /dev/null +++ b/src/session.rs @@ -0,0 +1,44 @@ +use tower_sessions_sqlx_store::SqliteStore; +use tower_sessions::{ExpiredDeletion, Expiry, SessionManagerLayer}; +use tower_sessions::cookie::SameSite; +use time::Duration; +use anyhow::{Context, Result}; +use tokio::task::JoinHandle; +use crate::db; + +pub async fn init() -> Result<(SessionManagerLayer, JoinHandle>)> { + + // Session store is a session aware database backing for the session data + let session_db_location = std::env::var("SESSION_DATABASE_URI") + .context("SESSION_DATABASE_URI not set")?; + let session_db = db::init(&session_db_location).await?; + let session_store = SqliteStore::new(session_db); + session_store.migrate().await?; + + // This guy forms the session cookies + // The session manager layer is the glue between the session store + // and the handlers. The options basically define the options of + // the cookies given to the client + // Example cookie: + // SESSION=biglongsessionid; SameSite=Lax; Secure; HttpOnly; Path=/; Max-Age=3600 + let session_layer = SessionManagerLayer::new(session_store.clone()) + .with_name("SESSION") + .with_same_site(SameSite::Lax) + .with_secure(true) + .with_http_only(true) + .with_path("/") + .with_expiry(Expiry::OnInactivity(Duration::seconds(3600))); + + + // We need to spawn a long-running task to clean up expired sessions + let task = tokio::task::spawn(deletion_task(session_store)); + + Ok((session_layer, task)) +} + +async fn deletion_task(session_store: SqliteStore) -> Result<()> { + session_store.clone() + .continuously_delete_expired(tokio::time::Duration::from_secs(60)) + .await + .context("delete expired task failed") +} diff --git a/assets/css/pico.min.css b/static/css/pico.min.css similarity index 100% rename from assets/css/pico.min.css rename to static/css/pico.min.css diff --git a/static/favicon.ico b/static/favicon.ico new file mode 100644 index 0000000..e5142bf Binary files /dev/null and b/static/favicon.ico differ diff --git a/assets/index.html b/static/index.html similarity index 91% rename from assets/index.html rename to static/index.html index b3cc0bb..0efbd1d 100644 --- a/assets/index.html +++ b/static/index.html @@ -15,6 +15,7 @@

Card

+ Logout \ No newline at end of file diff --git a/assets/js/htmx.min.js b/static/js/htmx.min.js similarity index 100% rename from assets/js/htmx.min.js rename to static/js/htmx.min.js