parent
618e9bde4b
commit
c9ece20bd8
@ -0,0 +1,169 @@
|
|||||||
|
use anyhow::Context;
|
||||||
|
use axum::{async_trait, http};
|
||||||
|
use axum::extract::{FromRef, FromRequestParts, Query, State};
|
||||||
|
use axum::http::{header, StatusCode};
|
||||||
|
use axum::http::request::Parts;
|
||||||
|
use axum::response::{IntoResponse, Redirect, Response};
|
||||||
|
use oauth2::basic::BasicClient;
|
||||||
|
use oauth2::reqwest::async_http_client;
|
||||||
|
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use sqlx::SqlitePool;
|
||||||
|
use tower_sessions::Session;
|
||||||
|
|
||||||
|
use crate::error::AppError;
|
||||||
|
use crate::{CSRF_TOKEN, USER_SESSION};
|
||||||
|
|
||||||
|
|
||||||
|
pub fn init_client() -> anyhow::Result<BasicClient> {
|
||||||
|
use std::env::var;
|
||||||
|
|
||||||
|
let client_id = var("OAUTH_CLIENT_ID").context("env OAUTH_CLIENT_ID not set")?;
|
||||||
|
let client_secret = var("OAUTH_CLIENT_SECRET").context("env OAUTH_CLIENT_SECRET not set")?;
|
||||||
|
let auth_url = var("OAUTH_AUTH_URL").context("env OAUTH_AUTH_URL not set")?;
|
||||||
|
let token_url = var("OAUTH_TOKEN_URL").context("env OAUTH_TOKEN_URL not set")?;
|
||||||
|
let revoke_url = var("OAUTH_REVOKE_URL").context("env OAUTH_REVOKE_URL not set")?;
|
||||||
|
let redirect_url = var("OAUTH_REDIRECT_URL").context("env OAUTH_REDIRECT_URL not set")?;
|
||||||
|
|
||||||
|
let client_id = ClientId::new(client_id);
|
||||||
|
let client_secret = ClientSecret::new(client_secret);
|
||||||
|
let auth_url = AuthUrl::new(auth_url)?;
|
||||||
|
let token_url = TokenUrl::new(token_url)?;
|
||||||
|
let revoke_url = RevocationUrl::new(revoke_url)?;
|
||||||
|
let redirect_url = RedirectUrl::new(redirect_url)?;
|
||||||
|
|
||||||
|
let client = BasicClient::new(
|
||||||
|
client_id,
|
||||||
|
Some(client_secret),
|
||||||
|
auth_url,
|
||||||
|
Some(token_url))
|
||||||
|
.set_redirect_uri(redirect_url)
|
||||||
|
.set_revocation_uri(revoke_url);
|
||||||
|
|
||||||
|
Ok(client)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn auth_google(
|
||||||
|
session: Session,
|
||||||
|
State(oauth_client): State<BasicClient>,
|
||||||
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
||||||
|
let (auth_url, csrf_token) = oauth_client
|
||||||
|
.authorize_url(CsrfToken::new_random)
|
||||||
|
.add_scope(Scope::new("profile".to_string()))
|
||||||
|
.add_scope(Scope::new("email".to_string()))
|
||||||
|
.url();
|
||||||
|
|
||||||
|
session.insert(CSRF_TOKEN, csrf_token).await?;
|
||||||
|
|
||||||
|
Ok(Redirect::to(auth_url.as_ref()))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
pub async fn auth_authorized(
|
||||||
|
session: Session,
|
||||||
|
Query(query_auth): Query<AuthRequest>,
|
||||||
|
State(oauth_client): State<BasicClient>,
|
||||||
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
||||||
|
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
|
||||||
|
.context("OAUTH_USER_INFO_URL not set")?;
|
||||||
|
|
||||||
|
let token = oauth_client
|
||||||
|
.exchange_code(AuthorizationCode::new(query_auth.code.clone()))
|
||||||
|
.request_async(async_http_client)
|
||||||
|
.await
|
||||||
|
.context("failed in sending request request to authorization server")?;
|
||||||
|
|
||||||
|
// Fetch user data
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let user_data = client
|
||||||
|
.get(user_info_endpoint)
|
||||||
|
.bearer_auth(token.access_token().secret())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.context("failed in sending request to target Url")?;
|
||||||
|
|
||||||
|
let user_data = user_data
|
||||||
|
.json::<User>()
|
||||||
|
.await
|
||||||
|
.context("failed to deserialize response as JSON")?;
|
||||||
|
|
||||||
|
session.insert(USER_SESSION, user_data).await?;
|
||||||
|
|
||||||
|
//TODO Redirect somewhere sane
|
||||||
|
Ok(Redirect::to("/protected"))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn auth_logout(
|
||||||
|
session: Session,
|
||||||
|
) -> anyhow::Result<impl IntoResponse, AppError> {
|
||||||
|
|
||||||
|
session.remove::<User>(USER_SESSION).await?;
|
||||||
|
|
||||||
|
Ok(Redirect::to("/"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct AuthRequest {
|
||||||
|
pub code: String,
|
||||||
|
pub state: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct User {
|
||||||
|
pub id: String,
|
||||||
|
pub email: String,
|
||||||
|
pub name: String,
|
||||||
|
pub verified_email: bool,
|
||||||
|
pub picture: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct UserExtractError(http::StatusCode);
|
||||||
|
|
||||||
|
impl IntoResponse for UserExtractError {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
match self.0 {
|
||||||
|
StatusCode::UNAUTHORIZED => { Redirect::temporary("/auth/login").into_response() }
|
||||||
|
StatusCode::FORBIDDEN => { StatusCode::FORBIDDEN.into_response() }
|
||||||
|
_ => StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<(http::StatusCode, &'static str)> for UserExtractError {
|
||||||
|
fn from(value: (StatusCode, &'static str)) -> Self {
|
||||||
|
Self(value.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<tower_sessions::session::Error> for UserExtractError {
|
||||||
|
fn from(_value: tower_sessions::session::Error) -> Self {
|
||||||
|
Self(StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S> FromRequestParts<S> for User
|
||||||
|
where
|
||||||
|
SqlitePool: FromRef<S>,
|
||||||
|
S: Send + Sync,
|
||||||
|
{
|
||||||
|
type Rejection = UserExtractError;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
let session = Session::from_request_parts(parts, state).await
|
||||||
|
.map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?;
|
||||||
|
|
||||||
|
let user: User = session.get(USER_SESSION).await
|
||||||
|
.map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||||
|
.ok_or(UserExtractError(StatusCode::UNAUTHORIZED))?;
|
||||||
|
|
||||||
|
let db = SqlitePool::from_ref(state);
|
||||||
|
|
||||||
|
//TODO actual verification of users
|
||||||
|
if user.email != "whatswithwes@gmail.com" {
|
||||||
|
Err(UserExtractError(StatusCode::FORBIDDEN))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(user)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
use tower_sessions_sqlx_store::SqliteStore;
|
||||||
|
use tower_sessions::{ExpiredDeletion, Expiry, SessionManagerLayer};
|
||||||
|
use tower_sessions::cookie::SameSite;
|
||||||
|
use time::Duration;
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
use crate::db;
|
||||||
|
|
||||||
|
pub async fn init() -> Result<(SessionManagerLayer<SqliteStore>, JoinHandle<Result<()>>)> {
|
||||||
|
|
||||||
|
// Session store is a session aware database backing for the session data
|
||||||
|
let session_db_location = std::env::var("SESSION_DATABASE_URI")
|
||||||
|
.context("SESSION_DATABASE_URI not set")?;
|
||||||
|
let session_db = db::init(&session_db_location).await?;
|
||||||
|
let session_store = SqliteStore::new(session_db);
|
||||||
|
session_store.migrate().await?;
|
||||||
|
|
||||||
|
// This guy forms the session cookies
|
||||||
|
// The session manager layer is the glue between the session store
|
||||||
|
// and the handlers. The options basically define the options of
|
||||||
|
// the cookies given to the client
|
||||||
|
// Example cookie:
|
||||||
|
// SESSION=biglongsessionid; SameSite=Lax; Secure; HttpOnly; Path=/; Max-Age=3600
|
||||||
|
let session_layer = SessionManagerLayer::new(session_store.clone())
|
||||||
|
.with_name("SESSION")
|
||||||
|
.with_same_site(SameSite::Lax)
|
||||||
|
.with_secure(true)
|
||||||
|
.with_http_only(true)
|
||||||
|
.with_path("/")
|
||||||
|
.with_expiry(Expiry::OnInactivity(Duration::seconds(3600)));
|
||||||
|
|
||||||
|
|
||||||
|
// We need to spawn a long-running task to clean up expired sessions
|
||||||
|
let task = tokio::task::spawn(deletion_task(session_store));
|
||||||
|
|
||||||
|
Ok((session_layer, task))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn deletion_task(session_store: SqliteStore) -> Result<()> {
|
||||||
|
session_store.clone()
|
||||||
|
.continuously_delete_expired(tokio::time::Duration::from_secs(60))
|
||||||
|
.await
|
||||||
|
.context("delete expired task failed")
|
||||||
|
}
|
||||||
|
After Width: | Height: | Size: 15 KiB |
Loading…
Reference in new issue