User extractor and refactor

demo-mode
Wes Holland 1 year ago
parent 618e9bde4b
commit c9ece20bd8

@ -1,6 +1,7 @@
# Copy this to .env and change OAUTH Values
RUST_LOG=debug,tower_http=info
DATABASE_URI=inventory-app.db
SESSION_DATABASE_URI=session.db
OAUTH_CLIENT_ID=changeme
OAUTH_CLIENT_SECRET=changme
OAUTH_AUTH_URL=https://accounts.google.com/o/oauth2/auth

@ -0,0 +1,169 @@
use anyhow::Context;
use axum::{async_trait, http};
use axum::extract::{FromRef, FromRequestParts, Query, State};
use axum::http::{header, StatusCode};
use axum::http::request::Parts;
use axum::response::{IntoResponse, Redirect, Response};
use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client;
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
use serde::{Deserialize, Serialize};
use sqlx::SqlitePool;
use tower_sessions::Session;
use crate::error::AppError;
use crate::{CSRF_TOKEN, USER_SESSION};
pub fn init_client() -> anyhow::Result<BasicClient> {
use std::env::var;
let client_id = var("OAUTH_CLIENT_ID").context("env OAUTH_CLIENT_ID not set")?;
let client_secret = var("OAUTH_CLIENT_SECRET").context("env OAUTH_CLIENT_SECRET not set")?;
let auth_url = var("OAUTH_AUTH_URL").context("env OAUTH_AUTH_URL not set")?;
let token_url = var("OAUTH_TOKEN_URL").context("env OAUTH_TOKEN_URL not set")?;
let revoke_url = var("OAUTH_REVOKE_URL").context("env OAUTH_REVOKE_URL not set")?;
let redirect_url = var("OAUTH_REDIRECT_URL").context("env OAUTH_REDIRECT_URL not set")?;
let client_id = ClientId::new(client_id);
let client_secret = ClientSecret::new(client_secret);
let auth_url = AuthUrl::new(auth_url)?;
let token_url = TokenUrl::new(token_url)?;
let revoke_url = RevocationUrl::new(revoke_url)?;
let redirect_url = RedirectUrl::new(redirect_url)?;
let client = BasicClient::new(
client_id,
Some(client_secret),
auth_url,
Some(token_url))
.set_redirect_uri(redirect_url)
.set_revocation_uri(revoke_url);
Ok(client)
}
pub async fn auth_google(
session: Session,
State(oauth_client): State<BasicClient>,
) -> anyhow::Result<impl IntoResponse, AppError> {
let (auth_url, csrf_token) = oauth_client
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("profile".to_string()))
.add_scope(Scope::new("email".to_string()))
.url();
session.insert(CSRF_TOKEN, csrf_token).await?;
Ok(Redirect::to(auth_url.as_ref()))
}
pub async fn auth_authorized(
session: Session,
Query(query_auth): Query<AuthRequest>,
State(oauth_client): State<BasicClient>,
) -> anyhow::Result<impl IntoResponse, AppError> {
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
.context("OAUTH_USER_INFO_URL not set")?;
let token = oauth_client
.exchange_code(AuthorizationCode::new(query_auth.code.clone()))
.request_async(async_http_client)
.await
.context("failed in sending request request to authorization server")?;
// Fetch user data
let client = reqwest::Client::new();
let user_data = client
.get(user_info_endpoint)
.bearer_auth(token.access_token().secret())
.send()
.await
.context("failed in sending request to target Url")?;
let user_data = user_data
.json::<User>()
.await
.context("failed to deserialize response as JSON")?;
session.insert(USER_SESSION, user_data).await?;
//TODO Redirect somewhere sane
Ok(Redirect::to("/protected"))
}
pub async fn auth_logout(
session: Session,
) -> anyhow::Result<impl IntoResponse, AppError> {
session.remove::<User>(USER_SESSION).await?;
Ok(Redirect::to("/"))
}
#[derive(Debug, Deserialize)]
pub struct AuthRequest {
pub code: String,
pub state: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct User {
pub id: String,
pub email: String,
pub name: String,
pub verified_email: bool,
pub picture: String,
}
pub struct UserExtractError(http::StatusCode);
impl IntoResponse for UserExtractError {
fn into_response(self) -> Response {
match self.0 {
StatusCode::UNAUTHORIZED => { Redirect::temporary("/auth/login").into_response() }
StatusCode::FORBIDDEN => { StatusCode::FORBIDDEN.into_response() }
_ => StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
impl From<(http::StatusCode, &'static str)> for UserExtractError {
fn from(value: (StatusCode, &'static str)) -> Self {
Self(value.0)
}
}
impl From<tower_sessions::session::Error> for UserExtractError {
fn from(_value: tower_sessions::session::Error) -> Self {
Self(StatusCode::INTERNAL_SERVER_ERROR)
}
}
#[async_trait]
impl<S> FromRequestParts<S> for User
where
SqlitePool: FromRef<S>,
S: Send + Sync,
{
type Rejection = UserExtractError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let session = Session::from_request_parts(parts, state).await
.map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?;
let user: User = session.get(USER_SESSION).await
.map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?
.ok_or(UserExtractError(StatusCode::UNAUTHORIZED))?;
let db = SqlitePool::from_ref(state);
//TODO actual verification of users
if user.email != "whatswithwes@gmail.com" {
Err(UserExtractError(StatusCode::FORBIDDEN))?;
}
Ok(user)
}
}

@ -1,34 +1,29 @@
use crate::app_state::AppState;
use crate::auth::User;
use crate::error::AppError;
use anyhow::{anyhow, Context, Result};
use axum::extract::{FromRef, Query, State};
use axum::http::header::SET_COOKIE;
use axum::http::HeaderMap;
use axum::response::{IntoResponse, Redirect};
use axum::{
extract::Request, handler::HandlerWithoutStateExt, http::StatusCode, routing::get, Router,
};
use std::net::SocketAddr;
use tower::ServiceExt;
use tokio::signal;
use tokio::task::{AbortHandle, JoinHandle};
use tower_http::{
services::{ServeDir, ServeFile},
trace::TraceLayer,
};
use tower_sessions::{session_store::ExpiredDeletion, Expiry, Session, SessionManagerLayer};
use tower_sessions_sqlx_store::{SqliteStore};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use anyhow::{anyhow, Context, Result};
use axum::extract::{FromRef, Query, State};
use axum::http::header::SET_COOKIE;
use axum::http::HeaderMap;
use axum::response::{IntoResponse, Redirect};
use oauth2::basic::BasicClient;
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
use oauth2::reqwest::async_http_client;
use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
use time::Duration;
use tower_sessions::cookie::SameSite;
use serde::{Deserialize, Serialize};
use tower_sessions::Session;
use tracing::info;
use crate::app_state::AppState;
use crate::error::AppError;
use tracing_subscriber::EnvFilter;
mod app_state;
mod db;
mod error;
mod session;
mod auth;
//NOTE TO FUTURE ME: I'm leaving a bunch of notes about these things as part of the learning
// process. There is a lot of implementation details that are obscured by all these pieces, and it
// can be hard to tell how heavy a line is. Lots of comment in this file and some of the kind of
@ -58,35 +53,42 @@ async fn main() -> Result<()>{
.compact()
.init();
tracing::info!("{}", env_status);
info!("{}", env_status);
// Application database. What you would expect. Access
// through the application state
let db_file = std::env::var("DATABASE_URI")
.context("DATABASE_URI not set")?;
let db = db::init(&db_file).await?;
let session_store = init_session_store(db.clone()).await?;
let session_layer = init_session_layer(session_store.clone()).await?;
/*TODO
let deletion_task = tokio::task::spawn(
session_store.clone()
.continuously_delete_expired(tokio::time::Duration::from_secs(60)),
);
*/
let oauth_client = init_oath_client()?;
// OAUTH2 Client
let oauth_client = auth::init_client()?;
// Application state
let app_state = AppState { db, oauth_client };
let router = Router::new()
// Session
let (session_layer, session_task) = session::init().await?;
let auth_routes: Router<AppState> = Router::new()
.route("/auth/login", get(auth::auth_google))
.route("/auth/logout", get(auth::auth_logout))
.route("/auth/authorized", get(auth::auth_authorized));
let static_routes = Router::new()
.nest_service("/", ServeDir::new("static/"));
let test_routes: Router<AppState> = Router::new()
.route("/fail", get(fail))
.route_service("/", ServeFile::new("assets/index.html"))
.nest_service("/js", ServeDir::new("assets/js"))
.nest_service("/css", ServeDir::new("assets/css"))
.route("/auth/oauth", get(auth_google))
.route("/auth/authorized", get(auth_authorized))
.route("/protected", get(protected))
.route("/usertest", get(index))
.route("/protected", get(protected));
let router = Router::new()
.merge(auth_routes)
.merge(test_routes)
.fallback_service(static_routes)
.layer(session_layer)
.layer(TraceLayer::new_for_http())
.with_state(app_state);
let address = "0.0.0.0:4206";
@ -96,8 +98,14 @@ async fn main() -> Result<()>{
info!("listening on {}", address);
axum::serve(listener, router.layer(TraceLayer::new_for_http()))
.await.context("unable to serve")
let mut tasks = vec![];
tasks.push(session_task);
axum::serve(listener, router.into_make_service())
.with_graceful_shutdown(shutdown_signal(tasks))
.await.context("unable to serve")?;
Ok(())
}
async fn fail() -> Result<(), AppError> {
@ -109,72 +117,8 @@ fn always_fails() -> Result<()> {
Err(anyhow!("I always fail"))
}
async fn auth_google(
session: Session,
State(oauth_client): State<BasicClient>,
) -> Result<impl IntoResponse, AppError> {
let (auth_url, csrf_token) = oauth_client
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("profile".to_string()))
.add_scope(Scope::new("email".to_string()))
.url();
session.insert(CSRF_TOKEN, csrf_token).await?;
Ok(Redirect::to(auth_url.as_ref()))
}
#[derive(Debug, Deserialize)]
struct AuthRequest {
code: String,
state: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct User {
id: String,
email: String,
name: String,
verified_email: bool,
picture: String,
}
async fn auth_authorized(
session: Session,
Query(query_auth): Query<AuthRequest>,
State(oauth_client): State<BasicClient>,
) -> Result<impl IntoResponse, AppError> {
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
.context("OAUTH_USER_INFO_URL not set")?;
let token = oauth_client
.exchange_code(AuthorizationCode::new(query_auth.code.clone()))
.request_async(async_http_client)
.await
.context("failed in sending request request to authorization server")?;
// Fetch user data
let client = reqwest::Client::new();
let user_data = client
.get(user_info_endpoint)
.bearer_auth(token.access_token().secret())
.send()
.await
.context("failed in sending request to target Url")?;
/*
let response_text = user_data.text().await?;
*/
let user_data = user_data
.json::<User>()
.await
.context("failed to deserialize response as JSON")?;
session.insert(USER_SESSION, user_data).await?;
Ok(Redirect::to("/protected"))
async fn index(user: User) -> impl IntoResponse {
format!("Hello {}", user.email)
}
async fn protected(
@ -193,51 +137,42 @@ async fn protected(
Ok(Redirect::to("/"))
}
async fn init_session_store(db: SqlitePool) -> Result<SqliteStore> {
let session_store = SqliteStore::new(db.clone());
session_store.migrate().await?;
Ok(session_store)
}
async fn init_session_layer(store: SqliteStore) -> Result<SessionManagerLayer<SqliteStore>> {
// This guy forms the session cookies
// The session manager layer is the glue between the session store
// and the handlers. The options basically define the options of
// the cookies given to the client
// Example cookie:
// SESSION=biglongsessionid; SameSite=Lax; Secure; HttpOnly; Path=/; Max-Age=3600
Ok(SessionManagerLayer::new(store)
.with_name("SESSION")
.with_same_site(SameSite::Lax)
.with_secure(true)
.with_http_only(true)
.with_path("/")
.with_expiry(Expiry::OnInactivity(Duration::seconds(3600))))
}
async fn shutdown_signal(tasks: Vec<JoinHandle<Result<()>>>) {
fn init_oath_client() -> Result<BasicClient> {
use std::env::var;
let client_id = var("OAUTH_CLIENT_ID").context("env OAUTH_CLIENT_ID not set")?;
let client_secret = var("OAUTH_CLIENT_SECRET").context("env OAUTH_CLIENT_SECRET not set")?;
let auth_url = var("OAUTH_AUTH_URL").context("env OAUTH_AUTH_URL not set")?;
let token_url = var("OAUTH_TOKEN_URL").context("env OAUTH_TOKEN_URL not set")?;
let revoke_url = var("OAUTH_REVOKE_URL").context("env OAUTH_REVOKE_URL not set")?;
let redirect_url = var("OAUTH_REDIRECT_URL").context("env OAUTH_REDIRECT_URL not set")?;
let client_id = ClientId::new(client_id);
let client_secret = ClientSecret::new(client_secret);
let auth_url = AuthUrl::new(auth_url)?;
let token_url = TokenUrl::new(token_url)?;
let revoke_url = RevocationUrl::new(revoke_url)?;
let redirect_url = RedirectUrl::new(redirect_url)?;
let client = BasicClient::new(
client_id,
Some(client_secret),
auth_url,
Some(token_url))
.set_redirect_uri(redirect_url)
.set_revocation_uri(revoke_url);
Ok(client)
let abort_handles: Vec<AbortHandle> = tasks.iter().map(|h| h.abort_handle()).collect();
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
};
info!("Shutdown signalled");
for handle in abort_handles {
handle.abort();
}
for handle in tasks {
let _ = handle.await;
}
info!("All processes finished");
}

@ -0,0 +1,44 @@
use tower_sessions_sqlx_store::SqliteStore;
use tower_sessions::{ExpiredDeletion, Expiry, SessionManagerLayer};
use tower_sessions::cookie::SameSite;
use time::Duration;
use anyhow::{Context, Result};
use tokio::task::JoinHandle;
use crate::db;
pub async fn init() -> Result<(SessionManagerLayer<SqliteStore>, JoinHandle<Result<()>>)> {
// Session store is a session aware database backing for the session data
let session_db_location = std::env::var("SESSION_DATABASE_URI")
.context("SESSION_DATABASE_URI not set")?;
let session_db = db::init(&session_db_location).await?;
let session_store = SqliteStore::new(session_db);
session_store.migrate().await?;
// This guy forms the session cookies
// The session manager layer is the glue between the session store
// and the handlers. The options basically define the options of
// the cookies given to the client
// Example cookie:
// SESSION=biglongsessionid; SameSite=Lax; Secure; HttpOnly; Path=/; Max-Age=3600
let session_layer = SessionManagerLayer::new(session_store.clone())
.with_name("SESSION")
.with_same_site(SameSite::Lax)
.with_secure(true)
.with_http_only(true)
.with_path("/")
.with_expiry(Expiry::OnInactivity(Duration::seconds(3600)));
// We need to spawn a long-running task to clean up expired sessions
let task = tokio::task::spawn(deletion_task(session_store));
Ok((session_layer, task))
}
async fn deletion_task(session_store: SqliteStore) -> Result<()> {
session_store.clone()
.continuously_delete_expired(tokio::time::Duration::from_secs(60))
.await
.context("delete expired task failed")
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

@ -15,6 +15,7 @@
<article>
<h2>Card</h2>
</article>
<a href="/auth/logout">Logout</a>
</main>
</body>
</html>
Loading…
Cancel
Save

Powered by TurnKey Linux.