diff --git a/Cargo.lock b/Cargo.lock index 5d4febe..b451388 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -160,6 +160,7 @@ checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" dependencies = [ "async-trait", "axum-core", + "axum-macros", "bytes", "futures-util", "http 1.1.0", @@ -207,6 +208,28 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-extra" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73c3220b188aea709cf1b6c5f9b01c3bd936bb08bd2b5184a12b35ac8131b1f9" +dependencies = [ + "axum", + "axum-core", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "serde", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-htmx" version = "0.6.0" @@ -218,6 +241,17 @@ dependencies = [ "http 1.1.0", ] +[[package]] +name = "axum-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -1093,6 +1127,7 @@ dependencies = [ "askama", "askama_axum", "axum", + "axum-extra", "axum-htmx", "dotenvy", "httpc-test", diff --git a/Cargo.toml b/Cargo.toml index d58fda8..e5a4636 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] anyhow = "1.0.91" askama = { version = "0.12.1", features = ["with-axum"] } -axum = "0.7.7" +axum = { version = "0.7.7", features = ["macros"] } axum-htmx = "0.6.0" dotenvy = "0.15.7" oauth2 = "4.4.2" @@ -22,6 +22,7 @@ tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } serde = { version = "1.0.213", features = ["derive"] } reqwest = { version = "0.12.9", features = ["json"] } askama_axum = "0.4.0" +axum-extra = "0.9.4" [dev-dependencies] httpc-test = "0.1.10" diff --git a/example.env b/example.env index 7b4007e..d0b388a 100644 --- a/example.env +++ b/example.env @@ -9,3 +9,5 @@ OAUTH_TOKEN_URL=https://accounts.google.com/o/oauth2/token OAUTH_REVOKE_URL=https://accounts.google.com/o/oauth2/revoke OAUTH_USER_INFO_URL=https://www.googleapis.com/oauth2/v1/userinfo OAUTH_REDIRECT_URL=http://localhost:4206/auth/authorized +AUTHORIZED_USERS=user1@somewhere.com;user2@somewhere.com +ROUTES_INCLUDE_ERROR_TESTS=no diff --git a/src/auth.rs b/src/auth.rs index f5c71f0..5d7aaed 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,9 +1,11 @@ use anyhow::{anyhow, Context}; -use axum::{async_trait, http}; +use askama::Template; +use axum::{async_trait, http, Router}; use axum::extract::{FromRef, FromRequestParts, Query, State}; use axum::http::{header, StatusCode}; use axum::http::request::Parts; use axum::response::{IntoResponse, Redirect, Response}; +use axum::routing::get; use oauth2::basic::BasicClient; use oauth2::reqwest::async_http_client; use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl}; @@ -11,9 +13,17 @@ use serde::{Deserialize, Serialize}; use sqlx::SqlitePool; use tower_sessions::Session; -use crate::error::AppError; -use crate::{CSRF_TOKEN, USER_SESSION}; +use crate::error::{AppError, AppForbiddenResponse}; +use crate::{auth, CSRF_TOKEN, USER_SESSION}; +use crate::error::QueryExtractor; +use crate::app_state::AppState; +pub fn routes() -> Router { + Router::new() + .route("/auth/login", get(auth_login)) + .route("/auth/logout", get(auth_logout)) + .route("/auth/authorized", get(auth_authorized)) +} pub fn init_client() -> anyhow::Result { use std::env::var; @@ -43,10 +53,15 @@ pub fn init_client() -> anyhow::Result { Ok(client) } -pub async fn auth_google( +pub async fn auth_login( session: Session, + user: Option, State(oauth_client): State, ) -> anyhow::Result { + if user.is_some() { + return Ok(Redirect::to("/")); + } + let (auth_url, csrf_token) = oauth_client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("profile".to_string())) @@ -61,7 +76,7 @@ pub async fn auth_google( pub async fn auth_authorized( session: Session, - Query(query_auth): Query, + QueryExtractor(query_auth): QueryExtractor, State(oauth_client): State, ) -> anyhow::Result { let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL") @@ -106,24 +121,29 @@ pub async fn auth_authorized( let is_authorized = valid_users.contains(&user_data.email.as_str()); - if is_authorized { - session.insert(USER_SESSION, user_data).await?; - - //TODO Redirect somewhere sane - Ok(Redirect::to("/protected").into_response()) - } - else { - Ok((http::StatusCode::UNAUTHORIZED, "Unauthorized").into_response()) + if !is_authorized { + return Ok(AppForbiddenResponse::new(&user_data.email, "application").into_response()) } + + session.insert(USER_SESSION, user_data).await?; + + Ok(Redirect::to("/").into_response()) } +#[derive(Template)] +#[template(path = "logged-out.html")] +struct LoggedOutTemplate; + pub async fn auth_logout( session: Session, + user: Option, ) -> anyhow::Result { + + if user.is_some() { + session.remove::(USER_SESSION).await?; + } - session.remove::(USER_SESSION).await?; - - Ok(Redirect::to("/")) + Ok(LoggedOutTemplate.into_response()) } #[derive(Debug, Deserialize)] @@ -141,30 +161,20 @@ pub struct User { pub picture: String, } -pub struct UserExtractError(http::StatusCode); +pub enum UserExtractError { + InternalServerError(anyhow::Error), + Unauthorized, +} impl IntoResponse for UserExtractError { fn into_response(self) -> Response { - match self.0 { - StatusCode::UNAUTHORIZED => { Redirect::temporary("/auth/login").into_response() } - StatusCode::FORBIDDEN => { StatusCode::FORBIDDEN.into_response() } - _ => StatusCode::INTERNAL_SERVER_ERROR.into_response() + match self { + UserExtractError::InternalServerError(err) => AppError::from(err).into_response(), + UserExtractError::Unauthorized => { Redirect::temporary("/auth/login").into_response() } } } } -impl From<(http::StatusCode, &'static str)> for UserExtractError { - fn from(value: (StatusCode, &'static str)) -> Self { - Self(value.0) - } -} - -impl From for UserExtractError { - fn from(_value: tower_sessions::session::Error) -> Self { - Self(StatusCode::INTERNAL_SERVER_ERROR) - } -} - #[async_trait] impl FromRequestParts for User where @@ -175,11 +185,11 @@ where async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let session = Session::from_request_parts(parts, state).await - .map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?; + .map_err(|_| UserExtractError::InternalServerError(anyhow!("session from parts failed")))?; - let user: User = session.get(USER_SESSION).await - .map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))? - .ok_or(UserExtractError(StatusCode::UNAUTHORIZED))?; + let user = session.get(USER_SESSION).await + .map_err(|e| UserExtractError::InternalServerError(anyhow::Error::from(e)))? + .ok_or(UserExtractError::Unauthorized)?; Ok(user) } diff --git a/src/error.rs b/src/error.rs index aaddeec..c6c6477 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,21 +1,59 @@ +use anyhow::anyhow; +use askama::Template; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; +use axum::Router; +use axum::routing::get; +use axum::extract::FromRequestParts; +use crate::app_state::AppState; +use crate::auth::User; + +/// These are just test routes. They shouldn't really be called directly +/// as they just return an error. But they are nice for testing +pub fn routes() -> Router { + use std::env::var; + + let mut router = Router::new(); + + let should_include = var("ROUTES_INCLUDE_ERROR_TESTS") + .unwrap_or("no".to_string()); + + if should_include.as_str() == "yes" { + router = router + .route("/error/forbidden", get(forbidden)) + .route("/error/unhandled", get(fail)); + + } + + router +} + +/// Handler that always responds with 404 Not Found +pub async fn not_found() -> impl IntoResponse { + AppNotFoundResponse.into_response() +} /// Application error that is a thin layer over anyhow::Error but has the distinction of having /// the piping for converting from the error to an axum response pub struct AppError(anyhow::Error); +/// A template for a bit nicer error page to the basic error that is just a bare string +#[derive(Template)] +#[template(path = "app-error.html")] +struct AppErrorTemplate; -// Convert app error into axum response. This is the default path for generic errors +/// Convert app error into axum response. This is the default path for generic errors impl IntoResponse for AppError { fn into_response(self) -> Response { tracing::error!("Unhandled error: {:#}", self.0); - (StatusCode::INTERNAL_SERVER_ERROR, "Unhandled internal error",).into_response() + let (mut parts, body) = AppErrorTemplate.into_response().into_parts(); + parts.status = StatusCode::INTERNAL_SERVER_ERROR; + (parts, body).into_response() } } -// Quality-of-life helper that saves us from converting errors manually +/// Quality-of-life helper that saves us from converting errors manually impl From for AppError where E: Into, @@ -23,4 +61,70 @@ where fn from(err: E) -> Self { Self(err.into()) } -} \ No newline at end of file +} + +/// A custom query extractor that yields an app error instead of +/// a query rejection error. This is automatically hooked up +/// because the query rejection is able to be Into'ed into an +/// anyhow::Error +#[derive(FromRequestParts)] +#[from_request(via(axum::extract::Query), rejection(AppError))] +pub struct QueryExtractor(pub T); + + +/// A response type for when returning a 403 (Forbidden) Response +pub struct AppForbiddenResponse { + email: String, + resource: String, +} + +#[derive(Template)] +#[template(path = "forbidden.html")] +struct ForbiddenTemplate; + +impl AppForbiddenResponse { + pub fn new(email: &str, resource: &str) -> Self { + Self { email: email.to_string(), resource: resource.to_string() } + } +} + +impl IntoResponse for AppForbiddenResponse { + fn into_response(self) -> Response { + tracing::error!("forbidden {} accessing {}", self.email, self.resource); + + let (mut parts, body) = ForbiddenTemplate.into_response().into_parts(); + parts.status = StatusCode::FORBIDDEN; + (parts, body).into_response() + } +} + + +/// A response type for when returning a 404 (Not found) Response +pub struct AppNotFoundResponse; + +#[derive(Template)] +#[template(path = "not-found.html")] +struct NotFoundTemplate; + +impl IntoResponse for AppNotFoundResponse { + fn into_response(self) -> Response { + let (mut parts, body) = NotFoundTemplate.into_response().into_parts(); + parts.status = StatusCode::NOT_FOUND; + (parts, body).into_response() + } +} + +/// Handler that always fails with a 500 Error +async fn fail() -> anyhow::Result<(), AppError> { + let val = always_fails()?; + Ok(val) +} + +fn always_fails() -> anyhow::Result<()> { + Err(anyhow!("I always fail")) +} + +/// Handler that always responds with 403 Forbidden +async fn forbidden(user: User) -> impl IntoResponse { + AppForbiddenResponse::new(&user.email, "test endpoint") +} diff --git a/src/main.rs b/src/main.rs index e8b16ac..21a014e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,13 @@ use crate::app_state::AppState; use crate::auth::User; -use crate::error::AppError; +use crate::error::{AppError, AppForbiddenResponse}; use anyhow::{anyhow, Context, Result}; -use askama::Template; +use askama_axum::Template; use axum::extract::{FromRef, Query, State}; use axum::http::header::SET_COOKIE; use axum::http::HeaderMap; use axum::response::{IntoResponse, Redirect}; -use axum::{ - extract::Request, handler::HandlerWithoutStateExt, http::StatusCode, routing::get, Router, -}; +use axum::{extract::Request, handler::HandlerWithoutStateExt, http, http::StatusCode, routing::get, Router}; use tokio::signal; use tokio::task::{AbortHandle, JoinHandle}; use tower_http::{ @@ -25,6 +23,7 @@ mod db; mod error; mod session; mod auth; +mod static_routes; //NOTE TO FUTURE ME: I'm leaving a bunch of notes about these things as part of the learning // process. There is a lot of implementation details that are obscured by all these pieces, and it // can be hard to tell how heavy a line is. Lots of comment in this file and some of the kind of @@ -70,32 +69,23 @@ async fn main() -> Result<()>{ // Session let (session_layer, session_task) = session::init().await?; + + let auth_routes = auth::routes(); + let static_routes = static_routes::routes(); - let auth_routes: Router = Router::new() - .route("/auth/login", get(auth::auth_google)) - .route("/auth/logout", get(auth::auth_logout)) - .route("/auth/authorized", get(auth::auth_authorized)); - - let static_routes = Router::new() - .nest_service("/css/pico.min.css", ServeFile::new("static/css/pico.min.css")) - .nest_service("/js/htmx.min.js", ServeFile::new("static/js/htmx.min.js")) - .nest_service("/favicon.ico", ServeFile::new("static/favicon.ico")); - - let test_routes: Router = Router::new() - .route("/fail", get(fail)) - .route("/usertest", get(index)) - .route("/protected", get(protected)); - + let error_routes: Router = error::routes(); + let app_routes: Router = Router::new() .route("/", get(index)); - + let router = Router::new() .merge(auth_routes) - .merge(test_routes) + .merge(error_routes) .merge(app_routes) .merge(static_routes) .layer(session_layer) .layer(TraceLayer::new_for_http()) + .fallback(error::not_found) .with_state(app_state); let address = "0.0.0.0:4206"; @@ -115,43 +105,17 @@ async fn main() -> Result<()>{ Ok(()) } -async fn fail() -> Result<(), AppError> { - let val = always_fails()?; - Ok(val) -} -fn always_fails() -> Result<()> { - Err(anyhow!("I always fail")) -} - -#[derive(Template)] // this will generate the code... -#[template(path = "index.html")] // using the template in this path, relative -// to the `templates` dir in the crate root -struct IndexTemplate<'a> { // the name of the struct can be anything - name: &'a str, // the field name should match the variable name - // in your template +#[derive(Template)] +#[template(path = "index.html")] +struct IndexTemplate<'a> { + name: &'a str, } async fn index(user: User) -> impl IntoResponse { IndexTemplate { name: user.name.as_str() }.into_response() - //format!("Hello {}", user.email) } -async fn protected( - session: Session, -) -> Result { - - let user: Option = session.get(USER_SESSION).await?; - - if let Some(user) = user { - info!("Protected route: Logged in user {}", user.email); - } - else { - info!("Protected route: No user"); - } - - Ok(Redirect::to("/")) -} async fn shutdown_signal(tasks: Vec>>) { diff --git a/src/static_routes.rs b/src/static_routes.rs new file mode 100644 index 0000000..3d0423d --- /dev/null +++ b/src/static_routes.rs @@ -0,0 +1,10 @@ +use axum::Router; +use tower_http::services::ServeFile; +use crate::app_state::AppState; + +pub fn routes() -> Router { + Router::new() + .nest_service("/css/pico.min.css", ServeFile::new("static/css/pico.min.css")) + .nest_service("/js/htmx.min.js", ServeFile::new("static/js/htmx.min.js")) + .nest_service("/favicon.ico", ServeFile::new("static/favicon.ico")) +} diff --git a/templates/app-error.html b/templates/app-error.html new file mode 100644 index 0000000..8199cd6 --- /dev/null +++ b/templates/app-error.html @@ -0,0 +1,10 @@ +{% extends "problem.html" %} + +{% block content %} + +

Error

+

+ Oops, something went wrong. Press the back button to try again. +

+ +{% endblock %} \ No newline at end of file diff --git a/templates/forbidden.html b/templates/forbidden.html new file mode 100644 index 0000000..83aff41 --- /dev/null +++ b/templates/forbidden.html @@ -0,0 +1,11 @@ +{% extends "problem.html" %} + +{% block content %} + +

Forbidden

+

+ You are forbidden from accessing this resource. Please contact your supervisor or logout + and log back in with a different user. +

+ +{% endblock %} \ No newline at end of file diff --git a/templates/logged-out.html b/templates/logged-out.html new file mode 100644 index 0000000..2e16e19 --- /dev/null +++ b/templates/logged-out.html @@ -0,0 +1,9 @@ +{% extends "main.html" %} + +{% block content %} + +

Logged out

+

You have been logged out

+

Log In

+ +{% endblock %} \ No newline at end of file diff --git a/templates/main.html b/templates/main.html index db0ac51..1aa3f05 100644 --- a/templates/main.html +++ b/templates/main.html @@ -4,13 +4,13 @@ - - + + Test Page
- {% block content %}

Placeholder content

{% endblock %} + {% block content %}

Content Missing

{% endblock %}
\ No newline at end of file diff --git a/templates/not-found.html b/templates/not-found.html new file mode 100644 index 0000000..834b5a9 --- /dev/null +++ b/templates/not-found.html @@ -0,0 +1,11 @@ +{% extends "problem.html" %} + +{% block content %} + +

Not Found

+

+ Sorry, we can't seem to find the page you're looking for. Please press back button or return + home. +

+ +{% endblock %} \ No newline at end of file diff --git a/templates/problem.html b/templates/problem.html new file mode 100644 index 0000000..9d956a3 --- /dev/null +++ b/templates/problem.html @@ -0,0 +1,16 @@ + + + + + + + + + Problem + + +
+ {% block content %}

Something went wrong

{% endblock %} +
+ + \ No newline at end of file