parent
618e9bde4b
commit
c9ece20bd8
@ -0,0 +1,169 @@
|
||||
use anyhow::Context;
|
||||
use axum::{async_trait, http};
|
||||
use axum::extract::{FromRef, FromRequestParts, Query, State};
|
||||
use axum::http::{header, StatusCode};
|
||||
use axum::http::request::Parts;
|
||||
use axum::response::{IntoResponse, Redirect, Response};
|
||||
use oauth2::basic::BasicClient;
|
||||
use oauth2::reqwest::async_http_client;
|
||||
use oauth2::{AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, RevocationUrl, Scope, TokenResponse, TokenUrl};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::SqlitePool;
|
||||
use tower_sessions::Session;
|
||||
|
||||
use crate::error::AppError;
|
||||
use crate::{CSRF_TOKEN, USER_SESSION};
|
||||
|
||||
|
||||
pub fn init_client() -> anyhow::Result<BasicClient> {
|
||||
use std::env::var;
|
||||
|
||||
let client_id = var("OAUTH_CLIENT_ID").context("env OAUTH_CLIENT_ID not set")?;
|
||||
let client_secret = var("OAUTH_CLIENT_SECRET").context("env OAUTH_CLIENT_SECRET not set")?;
|
||||
let auth_url = var("OAUTH_AUTH_URL").context("env OAUTH_AUTH_URL not set")?;
|
||||
let token_url = var("OAUTH_TOKEN_URL").context("env OAUTH_TOKEN_URL not set")?;
|
||||
let revoke_url = var("OAUTH_REVOKE_URL").context("env OAUTH_REVOKE_URL not set")?;
|
||||
let redirect_url = var("OAUTH_REDIRECT_URL").context("env OAUTH_REDIRECT_URL not set")?;
|
||||
|
||||
let client_id = ClientId::new(client_id);
|
||||
let client_secret = ClientSecret::new(client_secret);
|
||||
let auth_url = AuthUrl::new(auth_url)?;
|
||||
let token_url = TokenUrl::new(token_url)?;
|
||||
let revoke_url = RevocationUrl::new(revoke_url)?;
|
||||
let redirect_url = RedirectUrl::new(redirect_url)?;
|
||||
|
||||
let client = BasicClient::new(
|
||||
client_id,
|
||||
Some(client_secret),
|
||||
auth_url,
|
||||
Some(token_url))
|
||||
.set_redirect_uri(redirect_url)
|
||||
.set_revocation_uri(revoke_url);
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
pub async fn auth_google(
|
||||
session: Session,
|
||||
State(oauth_client): State<BasicClient>,
|
||||
) -> anyhow::Result<impl IntoResponse, AppError> {
|
||||
let (auth_url, csrf_token) = oauth_client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scope(Scope::new("profile".to_string()))
|
||||
.add_scope(Scope::new("email".to_string()))
|
||||
.url();
|
||||
|
||||
session.insert(CSRF_TOKEN, csrf_token).await?;
|
||||
|
||||
Ok(Redirect::to(auth_url.as_ref()))
|
||||
}
|
||||
|
||||
|
||||
pub async fn auth_authorized(
|
||||
session: Session,
|
||||
Query(query_auth): Query<AuthRequest>,
|
||||
State(oauth_client): State<BasicClient>,
|
||||
) -> anyhow::Result<impl IntoResponse, AppError> {
|
||||
let user_info_endpoint = std::env::var("OAUTH_USER_INFO_URL")
|
||||
.context("OAUTH_USER_INFO_URL not set")?;
|
||||
|
||||
let token = oauth_client
|
||||
.exchange_code(AuthorizationCode::new(query_auth.code.clone()))
|
||||
.request_async(async_http_client)
|
||||
.await
|
||||
.context("failed in sending request request to authorization server")?;
|
||||
|
||||
// Fetch user data
|
||||
let client = reqwest::Client::new();
|
||||
let user_data = client
|
||||
.get(user_info_endpoint)
|
||||
.bearer_auth(token.access_token().secret())
|
||||
.send()
|
||||
.await
|
||||
.context("failed in sending request to target Url")?;
|
||||
|
||||
let user_data = user_data
|
||||
.json::<User>()
|
||||
.await
|
||||
.context("failed to deserialize response as JSON")?;
|
||||
|
||||
session.insert(USER_SESSION, user_data).await?;
|
||||
|
||||
//TODO Redirect somewhere sane
|
||||
Ok(Redirect::to("/protected"))
|
||||
}
|
||||
|
||||
pub async fn auth_logout(
|
||||
session: Session,
|
||||
) -> anyhow::Result<impl IntoResponse, AppError> {
|
||||
|
||||
session.remove::<User>(USER_SESSION).await?;
|
||||
|
||||
Ok(Redirect::to("/"))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AuthRequest {
|
||||
pub code: String,
|
||||
pub state: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct User {
|
||||
pub id: String,
|
||||
pub email: String,
|
||||
pub name: String,
|
||||
pub verified_email: bool,
|
||||
pub picture: String,
|
||||
}
|
||||
|
||||
pub struct UserExtractError(http::StatusCode);
|
||||
|
||||
impl IntoResponse for UserExtractError {
|
||||
fn into_response(self) -> Response {
|
||||
match self.0 {
|
||||
StatusCode::UNAUTHORIZED => { Redirect::temporary("/auth/login").into_response() }
|
||||
StatusCode::FORBIDDEN => { StatusCode::FORBIDDEN.into_response() }
|
||||
_ => StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(http::StatusCode, &'static str)> for UserExtractError {
|
||||
fn from(value: (StatusCode, &'static str)) -> Self {
|
||||
Self(value.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tower_sessions::session::Error> for UserExtractError {
|
||||
fn from(_value: tower_sessions::session::Error) -> Self {
|
||||
Self(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S> FromRequestParts<S> for User
|
||||
where
|
||||
SqlitePool: FromRef<S>,
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = UserExtractError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let session = Session::from_request_parts(parts, state).await
|
||||
.map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?;
|
||||
|
||||
let user: User = session.get(USER_SESSION).await
|
||||
.map_err(|_| UserExtractError(StatusCode::INTERNAL_SERVER_ERROR))?
|
||||
.ok_or(UserExtractError(StatusCode::UNAUTHORIZED))?;
|
||||
|
||||
let db = SqlitePool::from_ref(state);
|
||||
|
||||
//TODO actual verification of users
|
||||
if user.email != "whatswithwes@gmail.com" {
|
||||
Err(UserExtractError(StatusCode::FORBIDDEN))?;
|
||||
}
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,44 @@
|
||||
use tower_sessions_sqlx_store::SqliteStore;
|
||||
use tower_sessions::{ExpiredDeletion, Expiry, SessionManagerLayer};
|
||||
use tower_sessions::cookie::SameSite;
|
||||
use time::Duration;
|
||||
use anyhow::{Context, Result};
|
||||
use tokio::task::JoinHandle;
|
||||
use crate::db;
|
||||
|
||||
pub async fn init() -> Result<(SessionManagerLayer<SqliteStore>, JoinHandle<Result<()>>)> {
|
||||
|
||||
// Session store is a session aware database backing for the session data
|
||||
let session_db_location = std::env::var("SESSION_DATABASE_URI")
|
||||
.context("SESSION_DATABASE_URI not set")?;
|
||||
let session_db = db::init(&session_db_location).await?;
|
||||
let session_store = SqliteStore::new(session_db);
|
||||
session_store.migrate().await?;
|
||||
|
||||
// This guy forms the session cookies
|
||||
// The session manager layer is the glue between the session store
|
||||
// and the handlers. The options basically define the options of
|
||||
// the cookies given to the client
|
||||
// Example cookie:
|
||||
// SESSION=biglongsessionid; SameSite=Lax; Secure; HttpOnly; Path=/; Max-Age=3600
|
||||
let session_layer = SessionManagerLayer::new(session_store.clone())
|
||||
.with_name("SESSION")
|
||||
.with_same_site(SameSite::Lax)
|
||||
.with_secure(true)
|
||||
.with_http_only(true)
|
||||
.with_path("/")
|
||||
.with_expiry(Expiry::OnInactivity(Duration::seconds(3600)));
|
||||
|
||||
|
||||
// We need to spawn a long-running task to clean up expired sessions
|
||||
let task = tokio::task::spawn(deletion_task(session_store));
|
||||
|
||||
Ok((session_layer, task))
|
||||
}
|
||||
|
||||
async fn deletion_task(session_store: SqliteStore) -> Result<()> {
|
||||
session_store.clone()
|
||||
.continuously_delete_expired(tokio::time::Duration::from_secs(60))
|
||||
.await
|
||||
.context("delete expired task failed")
|
||||
}
|
||||
|
After Width: | Height: | Size: 15 KiB |
Loading…
Reference in new issue