Framework for OAUTH

demo-mode
Wes Holland 1 year ago
parent 9c6587c367
commit 618e9bde4b

2
.gitignore vendored

@ -1 +1,3 @@
/target /target
/.env
/*.db

@ -3,6 +3,7 @@
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$"> <content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" /> <sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
<excludeFolder url="file://$MODULE_DIR$/target" /> <excludeFolder url="file://$MODULE_DIR$/target" />
</content> </content>
<orderEntry type="inheritedJdk" /> <orderEntry type="inheritedJdk" />

969
Cargo.lock generated

File diff suppressed because it is too large Load Diff

@ -9,11 +9,18 @@ askama = { version = "0.12.1", features = ["with-axum"] }
axum = "0.7.7" axum = "0.7.7"
axum-htmx = "0.6.0" axum-htmx = "0.6.0"
dotenvy = "0.15.7" dotenvy = "0.15.7"
sqlx = "0.8.2" oauth2 = "4.4.2"
sqlx = { version = "0.8.2", features = ["runtime-tokio", "sqlite"] }
time = "0.3.36"
tokio = { version = "1.41.0", features = ["full", "tracing"] } tokio = { version = "1.41.0", features = ["full", "tracing"] }
tower = { version = "0.5.1", features = ["util"] } tower = { version = "0.5.1", features = ["util"] }
tower-http = { version = "0.6.1", features = ["fs", "trace"] } tower-http = { version = "0.6.1", features = ["fs", "trace"] }
tower-sessions = "0.13.0" tower-sessions = "0.13.0"
tower-sessions-sqlx-store = "0.14.1" tower-sessions-sqlx-store = { version = "0.14.1", features = ["sqlite"] }
tracing = "0.1.40" tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
serde = { version = "1.0.213", features = ["derive"] }
reqwest = { version = "0.12.9", features = ["json"] }
[dev-dependencies]
httpc-test = "0.1.10"

@ -0,0 +1,10 @@
# Copy this to .env and change OAUTH Values
RUST_LOG=debug,tower_http=info
DATABASE_URI=inventory-app.db
OAUTH_CLIENT_ID=changeme
OAUTH_CLIENT_SECRET=changme
OAUTH_AUTH_URL=https://accounts.google.com/o/oauth2/auth
OAUTH_TOKEN_URL=https://accounts.google.com/o/oauth2/token
OAUTH_REVOKE_URL=https://accounts.google.com/o/oauth2/revoke
OAUTH_USER_INFO_URL=https://www.googleapis.com/oauth2/v1/userinfo
OAUTH_REDIRECT_URL=http://localhost:4206/auth/authorized

@ -0,0 +1,27 @@
use axum::extract::FromRef;
use oauth2::basic::BasicClient;
use sqlx::SqlitePool;
// App state. Pretty basic stuff. Gets passed around by the server to the handlers and whatnot
// Use in a handler with the state enum:
// async fn handler(State(my_app_state): State<AppState>)
#[derive(Clone)]
pub struct AppState {
pub db: SqlitePool,
pub oauth_client: BasicClient,
}
// Axum extractors for app state. These allow the handler to just use
// pieces of the App state
// async fn handler(State(my_db): State<SqlitePool>)
impl FromRef<AppState> for SqlitePool {
fn from_ref(input: &AppState) -> Self {
input.db.clone()
}
}
impl FromRef<AppState> for BasicClient {
fn from_ref(input: &AppState) -> Self {
input.oauth_client.clone()
}
}

@ -0,0 +1,14 @@
use sqlx::SqlitePool;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
pub async fn init(filename: &str) -> anyhow::Result<SqlitePool> {
let options = SqliteConnectOptions::new()
.filename(filename)
.create_if_missing(true);
let db = SqlitePool::connect_with(options).await?;
tracing::info!("Database connected {}", filename);
Ok(db)
}

@ -0,0 +1,26 @@
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
/// Application error that is a thin layer over anyhow::Error but has the distinction of having
/// the piping for converting from the error to an axum response
pub struct AppError(anyhow::Error);
// Convert app error into axum response. This is the default path for generic errors
impl IntoResponse for AppError {
fn into_response(self) -> Response {
tracing::error!("Unhandled error: {:#}", self.0);
(StatusCode::INTERNAL_SERVER_ERROR, "Unhandled internal error",).into_response()
}
}
// Quality-of-life helper that saves us from converting errors manually
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}

@ -7,38 +7,237 @@ use tower_http::{
services::{ServeDir, ServeFile}, services::{ServeDir, ServeFile},
trace::TraceLayer, trace::TraceLayer,
}; };
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tower_sessions::{session_store::ExpiredDeletion, Expiry, Session, SessionManagerLayer};
use anyhow::{Context,Result}; use tower_sessions_sqlx_store::{SqliteStore};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use anyhow::{anyhow, Context, Result};
use axum::extract::{FromRef, Query, State};
use axum::http::header::SET_COOKIE;
use axum::http::HeaderMap;
use axum::response::{IntoResponse, Redirect};
use oauth2::basic::BasicClient;
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
use oauth2::reqwest::async_http_client;
use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
use time::Duration;
use tower_sessions::cookie::SameSite;
use serde::{Deserialize, Serialize};
use tracing::info;
use crate::app_state::AppState;
use crate::error::AppError;
mod app_state;
mod db;
mod error;
//NOTE TO FUTURE ME: I'm leaving a bunch of notes about these things as part of the learning
// process. There is a lot of implementation details that are obscured by all these pieces, and it
// can be hard to tell how heavy a line is. Lots of comment in this file and some of the kind of
// "first breadcrumb" type files so you have a place to pick up the threads. Hopefully you remember
// this stuff, but in case you don't here's some notes
const CSRF_TOKEN: &str = "csrf_token";
const USER_SESSION: &str = "user";
#[tokio::main] #[tokio::main]
async fn main() -> Result<()>{ async fn main() -> Result<()>{
// Load local environment variables from a .env file
// Use the regular std::env::var to use
let env_status = match dotenvy::from_filename(".env") {
Ok(_) => { "found local .env file" }
Err(_) => { "no local .env file" }
};
// Create a subscriber that will turn tracing events into
// console logs. Use macros (tracing::info!, tracing::error!)
// to create events in the code
// Set default environment variable to set the level
// Example "RUST_LOG=debug,tower_http=warn"
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG) .with_env_filter(EnvFilter::from_default_env())
.compact()
.init(); .init();
/* tracing::info!("{}", env_status);
tracing_subscriber::registry()
.with( let db_file = std::env::var("DATABASE_URI")
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { .context("DATABASE_URI not set")?;
format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
}), let db = db::init(&db_file).await?;
)
.with(tracing_subscriber::fmt::layer()) let session_store = init_session_store(db.clone()).await?;
.init(); let session_layer = init_session_layer(session_store.clone()).await?;
/*TODO
let deletion_task = tokio::task::spawn(
session_store.clone()
.continuously_delete_expired(tokio::time::Duration::from_secs(60)),
);
*/ */
let oauth_client = init_oath_client()?;
let app_state = AppState { db, oauth_client };
let router = Router::new() let router = Router::new()
.route("/fail", get(fail))
.route_service("/", ServeFile::new("assets/index.html")) .route_service("/", ServeFile::new("assets/index.html"))
.nest_service("/js", ServeDir::new("assets/js")) .nest_service("/js", ServeDir::new("assets/js"))
.nest_service("/css", ServeDir::new("assets/css")); .nest_service("/css", ServeDir::new("assets/css"))
.route("/auth/oauth", get(auth_google))
.route("/auth/authorized", get(auth_authorized))
.route("/protected", get(protected))
.layer(session_layer)
.with_state(app_state);
let address = "0.0.0.0:4206"; let address = "0.0.0.0:4206";
let listener = tokio::net::TcpListener::bind(address) let listener = tokio::net::TcpListener::bind(address)
.await .await
.context("failed to bind")?; .context("failed to bind")?;
tracing::debug!("listening on {}", address); info!("listening on {}", address);
axum::serve(listener, router.layer(TraceLayer::new_for_http())) axum::serve(listener, router.layer(TraceLayer::new_for_http()))
.await.context("unable to serve") .await.context("unable to serve")
} }
async fn fail() -> Result<(), AppError> {
let val = always_fails()?;
Ok(val)
}
fn always_fails() -> Result<()> {
Err(anyhow!("I always fail"))
}
async fn auth_google(
session: Session,
State(oauth_client): State<BasicClient>,
) -> Result<impl IntoResponse, AppError> {
let (auth_url, csrf_token) = oauth_client
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("profile".to_string()))
.add_scope(Scope::new("email".to_string()))
.url();
session.insert(CSRF_TOKEN, csrf_token).await?;
Ok(Redirect::to(auth_url.as_ref()))
}
#[derive(Debug, Deserialize)]
struct AuthRequest {
code: String,
state: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct User {
id: String,
email: String,
name: String,
verified_email: bool,
picture: String,
}
async fn auth_authorized(
session: Session,
Query(query_auth): Query<AuthRequest>,
State(oauth_client): State<BasicClient>,
) -> Result<impl IntoResponse, AppError> {
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
.context("OAUTH_USER_INFO_URL not set")?;
let token = oauth_client
.exchange_code(AuthorizationCode::new(query_auth.code.clone()))
.request_async(async_http_client)
.await
.context("failed in sending request request to authorization server")?;
// Fetch user data
let client = reqwest::Client::new();
let user_data = client
.get(user_info_endpoint)
.bearer_auth(token.access_token().secret())
.send()
.await
.context("failed in sending request to target Url")?;
/*
let response_text = user_data.text().await?;
*/
let user_data = user_data
.json::<User>()
.await
.context("failed to deserialize response as JSON")?;
session.insert(USER_SESSION, user_data).await?;
Ok(Redirect::to("/protected"))
}
async fn protected(
session: Session,
) -> Result<impl IntoResponse, AppError> {
let user: Option<User> = session.get(USER_SESSION).await?;
if let Some(user) = user {
info!("Protected route: Logged in user {}", user.email);
}
else {
info!("Protected route: No user");
}
Ok(Redirect::to("/"))
}
async fn init_session_store(db: SqlitePool) -> Result<SqliteStore> {
let session_store = SqliteStore::new(db.clone());
session_store.migrate().await?;
Ok(session_store)
}
async fn init_session_layer(store: SqliteStore) -> Result<SessionManagerLayer<SqliteStore>> {
// This guy forms the session cookies
// The session manager layer is the glue between the session store
// and the handlers. The options basically define the options of
// the cookies given to the client
// Example cookie:
// SESSION=biglongsessionid; SameSite=Lax; Secure; HttpOnly; Path=/; Max-Age=3600
Ok(SessionManagerLayer::new(store)
.with_name("SESSION")
.with_same_site(SameSite::Lax)
.with_secure(true)
.with_http_only(true)
.with_path("/")
.with_expiry(Expiry::OnInactivity(Duration::seconds(3600))))
}
fn init_oath_client() -> Result<BasicClient> {
use std::env::var;
let client_id = var("OAUTH_CLIENT_ID").context("env OAUTH_CLIENT_ID not set")?;
let client_secret = var("OAUTH_CLIENT_SECRET").context("env OAUTH_CLIENT_SECRET not set")?;
let auth_url = var("OAUTH_AUTH_URL").context("env OAUTH_AUTH_URL not set")?;
let token_url = var("OAUTH_TOKEN_URL").context("env OAUTH_TOKEN_URL not set")?;
let revoke_url = var("OAUTH_REVOKE_URL").context("env OAUTH_REVOKE_URL not set")?;
let redirect_url = var("OAUTH_REDIRECT_URL").context("env OAUTH_REDIRECT_URL not set")?;
let client_id = ClientId::new(client_id);
let client_secret = ClientSecret::new(client_secret);
let auth_url = AuthUrl::new(auth_url)?;
let token_url = TokenUrl::new(token_url)?;
let revoke_url = RevocationUrl::new(revoke_url)?;
let redirect_url = RedirectUrl::new(redirect_url)?;
let client = BasicClient::new(
client_id,
Some(client_secret),
auth_url,
Some(token_url))
.set_redirect_uri(redirect_url)
.set_revocation_uri(revoke_url);
Ok(client)
}

@ -0,0 +1,11 @@
use anyhow::Result;
#[tokio::test]
async fn sanity_check() -> Result<()> {
let client = httpc_test::new_client("http://localhost:4206")?;
client.do_get("/fail").await?.print().await?;
Ok(())
}
Loading…
Cancel
Save

Powered by TurnKey Linux.