|
|
|
@ -7,38 +7,237 @@ use tower_http::{
|
|
|
|
services::{ServeDir, ServeFile},
|
|
|
|
services::{ServeDir, ServeFile},
|
|
|
|
trace::TraceLayer,
|
|
|
|
trace::TraceLayer,
|
|
|
|
};
|
|
|
|
};
|
|
|
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
|
|
|
use tower_sessions::{session_store::ExpiredDeletion, Expiry, Session, SessionManagerLayer};
|
|
|
|
use anyhow::{Context,Result};
|
|
|
|
use tower_sessions_sqlx_store::{SqliteStore};
|
|
|
|
|
|
|
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
|
|
|
|
|
|
|
use anyhow::{anyhow, Context, Result};
|
|
|
|
|
|
|
|
use axum::extract::{FromRef, Query, State};
|
|
|
|
|
|
|
|
use axum::http::header::SET_COOKIE;
|
|
|
|
|
|
|
|
use axum::http::HeaderMap;
|
|
|
|
|
|
|
|
use axum::response::{IntoResponse, Redirect};
|
|
|
|
|
|
|
|
use oauth2::basic::BasicClient;
|
|
|
|
|
|
|
|
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
|
|
|
|
|
|
|
|
use oauth2::reqwest::async_http_client;
|
|
|
|
|
|
|
|
use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
|
|
|
|
|
|
|
|
use time::Duration;
|
|
|
|
|
|
|
|
use tower_sessions::cookie::SameSite;
|
|
|
|
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
|
|
|
|
use tracing::info;
|
|
|
|
|
|
|
|
use crate::app_state::AppState;
|
|
|
|
|
|
|
|
use crate::error::AppError;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mod app_state;
|
|
|
|
|
|
|
|
mod db;
|
|
|
|
|
|
|
|
mod error;
|
|
|
|
|
|
|
|
//NOTE TO FUTURE ME: I'm leaving a bunch of notes about these things as part of the learning
|
|
|
|
|
|
|
|
// process. There is a lot of implementation details that are obscured by all these pieces, and it
|
|
|
|
|
|
|
|
// can be hard to tell how heavy a line is. Lots of comment in this file and some of the kind of
|
|
|
|
|
|
|
|
// "first breadcrumb" type files so you have a place to pick up the threads. Hopefully you remember
|
|
|
|
|
|
|
|
// this stuff, but in case you don't here's some notes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const CSRF_TOKEN: &str = "csrf_token";
|
|
|
|
|
|
|
|
const USER_SESSION: &str = "user";
|
|
|
|
|
|
|
|
|
|
|
|
#[tokio::main]
|
|
|
|
#[tokio::main]
|
|
|
|
async fn main() -> Result<()>{
|
|
|
|
async fn main() -> Result<()>{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Load local environment variables from a .env file
|
|
|
|
|
|
|
|
// Use the regular std::env::var to use
|
|
|
|
|
|
|
|
let env_status = match dotenvy::from_filename(".env") {
|
|
|
|
|
|
|
|
Ok(_) => { "found local .env file" }
|
|
|
|
|
|
|
|
Err(_) => { "no local .env file" }
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Create a subscriber that will turn tracing events into
|
|
|
|
|
|
|
|
// console logs. Use macros (tracing::info!, tracing::error!)
|
|
|
|
|
|
|
|
// to create events in the code
|
|
|
|
|
|
|
|
// Set default environment variable to set the level
|
|
|
|
|
|
|
|
// Example "RUST_LOG=debug,tower_http=warn"
|
|
|
|
tracing_subscriber::fmt()
|
|
|
|
tracing_subscriber::fmt()
|
|
|
|
.with_max_level(tracing::Level::DEBUG)
|
|
|
|
.with_env_filter(EnvFilter::from_default_env())
|
|
|
|
|
|
|
|
.compact()
|
|
|
|
.init();
|
|
|
|
.init();
|
|
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
tracing::info!("{}", env_status);
|
|
|
|
tracing_subscriber::registry()
|
|
|
|
|
|
|
|
.with(
|
|
|
|
let db_file = std::env::var("DATABASE_URI")
|
|
|
|
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
|
|
|
|
.context("DATABASE_URI not set")?;
|
|
|
|
format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into()
|
|
|
|
|
|
|
|
}),
|
|
|
|
let db = db::init(&db_file).await?;
|
|
|
|
)
|
|
|
|
|
|
|
|
.with(tracing_subscriber::fmt::layer())
|
|
|
|
let session_store = init_session_store(db.clone()).await?;
|
|
|
|
.init();
|
|
|
|
let session_layer = init_session_layer(session_store.clone()).await?;
|
|
|
|
|
|
|
|
/*TODO
|
|
|
|
|
|
|
|
let deletion_task = tokio::task::spawn(
|
|
|
|
|
|
|
|
session_store.clone()
|
|
|
|
|
|
|
|
.continuously_delete_expired(tokio::time::Duration::from_secs(60)),
|
|
|
|
|
|
|
|
);
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let oauth_client = init_oath_client()?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let app_state = AppState { db, oauth_client };
|
|
|
|
|
|
|
|
|
|
|
|
let router = Router::new()
|
|
|
|
let router = Router::new()
|
|
|
|
|
|
|
|
.route("/fail", get(fail))
|
|
|
|
.route_service("/", ServeFile::new("assets/index.html"))
|
|
|
|
.route_service("/", ServeFile::new("assets/index.html"))
|
|
|
|
.nest_service("/js", ServeDir::new("assets/js"))
|
|
|
|
.nest_service("/js", ServeDir::new("assets/js"))
|
|
|
|
.nest_service("/css", ServeDir::new("assets/css"));
|
|
|
|
.nest_service("/css", ServeDir::new("assets/css"))
|
|
|
|
|
|
|
|
.route("/auth/oauth", get(auth_google))
|
|
|
|
|
|
|
|
.route("/auth/authorized", get(auth_authorized))
|
|
|
|
|
|
|
|
.route("/protected", get(protected))
|
|
|
|
|
|
|
|
.layer(session_layer)
|
|
|
|
|
|
|
|
.with_state(app_state);
|
|
|
|
|
|
|
|
|
|
|
|
let address = "0.0.0.0:4206";
|
|
|
|
let address = "0.0.0.0:4206";
|
|
|
|
let listener = tokio::net::TcpListener::bind(address)
|
|
|
|
let listener = tokio::net::TcpListener::bind(address)
|
|
|
|
.await
|
|
|
|
.await
|
|
|
|
.context("failed to bind")?;
|
|
|
|
.context("failed to bind")?;
|
|
|
|
|
|
|
|
|
|
|
|
tracing::debug!("listening on {}", address);
|
|
|
|
info!("listening on {}", address);
|
|
|
|
|
|
|
|
|
|
|
|
axum::serve(listener, router.layer(TraceLayer::new_for_http()))
|
|
|
|
axum::serve(listener, router.layer(TraceLayer::new_for_http()))
|
|
|
|
.await.context("unable to serve")
|
|
|
|
.await.context("unable to serve")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn fail() -> Result<(), AppError> {
|
|
|
|
|
|
|
|
let val = always_fails()?;
|
|
|
|
|
|
|
|
Ok(val)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn always_fails() -> Result<()> {
|
|
|
|
|
|
|
|
Err(anyhow!("I always fail"))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn auth_google(
|
|
|
|
|
|
|
|
session: Session,
|
|
|
|
|
|
|
|
State(oauth_client): State<BasicClient>,
|
|
|
|
|
|
|
|
) -> Result<impl IntoResponse, AppError> {
|
|
|
|
|
|
|
|
let (auth_url, csrf_token) = oauth_client
|
|
|
|
|
|
|
|
.authorize_url(CsrfToken::new_random)
|
|
|
|
|
|
|
|
.add_scope(Scope::new("profile".to_string()))
|
|
|
|
|
|
|
|
.add_scope(Scope::new("email".to_string()))
|
|
|
|
|
|
|
|
.url();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
session.insert(CSRF_TOKEN, csrf_token).await?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(Redirect::to(auth_url.as_ref()))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
|
|
|
|
|
|
struct AuthRequest {
|
|
|
|
|
|
|
|
code: String,
|
|
|
|
|
|
|
|
state: String,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
|
|
|
|
|
|
struct User {
|
|
|
|
|
|
|
|
id: String,
|
|
|
|
|
|
|
|
email: String,
|
|
|
|
|
|
|
|
name: String,
|
|
|
|
|
|
|
|
verified_email: bool,
|
|
|
|
|
|
|
|
picture: String,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn auth_authorized(
|
|
|
|
|
|
|
|
session: Session,
|
|
|
|
|
|
|
|
Query(query_auth): Query<AuthRequest>,
|
|
|
|
|
|
|
|
State(oauth_client): State<BasicClient>,
|
|
|
|
|
|
|
|
) -> Result<impl IntoResponse, AppError> {
|
|
|
|
|
|
|
|
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
|
|
|
|
|
|
|
|
.context("OAUTH_USER_INFO_URL not set")?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let token = oauth_client
|
|
|
|
|
|
|
|
.exchange_code(AuthorizationCode::new(query_auth.code.clone()))
|
|
|
|
|
|
|
|
.request_async(async_http_client)
|
|
|
|
|
|
|
|
.await
|
|
|
|
|
|
|
|
.context("failed in sending request request to authorization server")?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Fetch user data
|
|
|
|
|
|
|
|
let client = reqwest::Client::new();
|
|
|
|
|
|
|
|
let user_data = client
|
|
|
|
|
|
|
|
.get(user_info_endpoint)
|
|
|
|
|
|
|
|
.bearer_auth(token.access_token().secret())
|
|
|
|
|
|
|
|
.send()
|
|
|
|
|
|
|
|
.await
|
|
|
|
|
|
|
|
.context("failed in sending request to target Url")?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
|
|
|
let response_text = user_data.text().await?;
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let user_data = user_data
|
|
|
|
|
|
|
|
.json::<User>()
|
|
|
|
|
|
|
|
.await
|
|
|
|
|
|
|
|
.context("failed to deserialize response as JSON")?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
session.insert(USER_SESSION, user_data).await?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(Redirect::to("/protected"))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn protected(
|
|
|
|
|
|
|
|
session: Session,
|
|
|
|
|
|
|
|
) -> Result<impl IntoResponse, AppError> {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let user: Option<User> = session.get(USER_SESSION).await?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if let Some(user) = user {
|
|
|
|
|
|
|
|
info!("Protected route: Logged in user {}", user.email);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
else {
|
|
|
|
|
|
|
|
info!("Protected route: No user");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(Redirect::to("/"))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async fn init_session_store(db: SqlitePool) -> Result<SqliteStore> {
|
|
|
|
|
|
|
|
let session_store = SqliteStore::new(db.clone());
|
|
|
|
|
|
|
|
session_store.migrate().await?;
|
|
|
|
|
|
|
|
Ok(session_store)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
async fn init_session_layer(store: SqliteStore) -> Result<SessionManagerLayer<SqliteStore>> {
|
|
|
|
|
|
|
|
// This guy forms the session cookies
|
|
|
|
|
|
|
|
// The session manager layer is the glue between the session store
|
|
|
|
|
|
|
|
// and the handlers. The options basically define the options of
|
|
|
|
|
|
|
|
// the cookies given to the client
|
|
|
|
|
|
|
|
// Example cookie:
|
|
|
|
|
|
|
|
// SESSION=biglongsessionid; SameSite=Lax; Secure; HttpOnly; Path=/; Max-Age=3600
|
|
|
|
|
|
|
|
Ok(SessionManagerLayer::new(store)
|
|
|
|
|
|
|
|
.with_name("SESSION")
|
|
|
|
|
|
|
|
.with_same_site(SameSite::Lax)
|
|
|
|
|
|
|
|
.with_secure(true)
|
|
|
|
|
|
|
|
.with_http_only(true)
|
|
|
|
|
|
|
|
.with_path("/")
|
|
|
|
|
|
|
|
.with_expiry(Expiry::OnInactivity(Duration::seconds(3600))))
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn init_oath_client() -> Result<BasicClient> {
|
|
|
|
|
|
|
|
use std::env::var;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let client_id = var("OAUTH_CLIENT_ID").context("env OAUTH_CLIENT_ID not set")?;
|
|
|
|
|
|
|
|
let client_secret = var("OAUTH_CLIENT_SECRET").context("env OAUTH_CLIENT_SECRET not set")?;
|
|
|
|
|
|
|
|
let auth_url = var("OAUTH_AUTH_URL").context("env OAUTH_AUTH_URL not set")?;
|
|
|
|
|
|
|
|
let token_url = var("OAUTH_TOKEN_URL").context("env OAUTH_TOKEN_URL not set")?;
|
|
|
|
|
|
|
|
let revoke_url = var("OAUTH_REVOKE_URL").context("env OAUTH_REVOKE_URL not set")?;
|
|
|
|
|
|
|
|
let redirect_url = var("OAUTH_REDIRECT_URL").context("env OAUTH_REDIRECT_URL not set")?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let client_id = ClientId::new(client_id);
|
|
|
|
|
|
|
|
let client_secret = ClientSecret::new(client_secret);
|
|
|
|
|
|
|
|
let auth_url = AuthUrl::new(auth_url)?;
|
|
|
|
|
|
|
|
let token_url = TokenUrl::new(token_url)?;
|
|
|
|
|
|
|
|
let revoke_url = RevocationUrl::new(revoke_url)?;
|
|
|
|
|
|
|
|
let redirect_url = RedirectUrl::new(redirect_url)?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let client = BasicClient::new(
|
|
|
|
|
|
|
|
client_id,
|
|
|
|
|
|
|
|
Some(client_secret),
|
|
|
|
|
|
|
|
auth_url,
|
|
|
|
|
|
|
|
Some(token_url))
|
|
|
|
|
|
|
|
.set_redirect_uri(redirect_url)
|
|
|
|
|
|
|
|
.set_revocation_uri(revoke_url);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(client)
|
|
|
|
|
|
|
|
}
|
|
|
|
|