You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
112 lines
3.8 KiB
112 lines
3.8 KiB
use crate::db;
|
|
use crate::error::AppError;
|
|
use anyhow::{anyhow, Context, Result};
|
|
use askama_axum::{IntoResponse, Response};
|
|
use axum::async_trait;
|
|
use axum::extract::FromRequestParts;
|
|
use axum::http::request::Parts;
|
|
use axum::response::Redirect;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::result;
|
|
use chrono::FixedOffset;
|
|
use tokio::task::JoinHandle;
|
|
use tower_sessions::cookie::SameSite;
|
|
use tower_sessions::{ExpiredDeletion, Expiry, Session, SessionManagerLayer};
|
|
use tower_sessions::cookie::time::Duration;
|
|
use tower_sessions_sqlx_store::SqliteStore;
|
|
|
|
pub async fn init() -> Result<(SessionManagerLayer<SqliteStore>, JoinHandle<Result<()>>)> {
|
|
|
|
// Session store is a session aware database backing for the session data
|
|
let session_db_location = std::env::var("SESSION_DATABASE_URL")
|
|
.context("SESSION_DATABASE_URL not set")?;
|
|
let session_db = db::connect_db(&session_db_location).await?;
|
|
let session_store = SqliteStore::new(session_db);
|
|
session_store.migrate().await?;
|
|
|
|
// This guy form the session cookies
|
|
// The session manager layer is the glue between the session store
|
|
// and the handlers. The options basically define the options of
|
|
// the cookies given to the client
|
|
// Example cookie:
|
|
// SESSION=biglongsessionid; SameSite=Lax; Secure; HttpOnly; Path=/; Max-Age=3600
|
|
let session_layer = SessionManagerLayer::new(session_store.clone())
|
|
.with_name("SESSION")
|
|
.with_same_site(SameSite::Lax)
|
|
.with_secure(true)
|
|
.with_http_only(true)
|
|
.with_path("/")
|
|
.with_expiry(Expiry::OnInactivity(Duration::seconds(3600)));
|
|
|
|
|
|
// We need to spawn a long-running task to clean up expired sessions
|
|
let task = tokio::task::spawn(deletion_task(session_store));
|
|
|
|
Ok((session_layer, task))
|
|
}
|
|
|
|
async fn deletion_task(session_store: SqliteStore) -> Result<()> {
|
|
session_store.clone()
|
|
.continuously_delete_expired(tokio::time::Duration::from_secs(60))
|
|
.await
|
|
.context("delete expired task failed")
|
|
}
|
|
|
|
pub const USER_SESSION: &str = "user";
|
|
|
|
/// User information that will be return from the OAUTH authority
|
|
#[derive(Debug, Serialize, Deserialize)]
|
|
pub struct SessionUser {
|
|
pub id: i64,
|
|
pub role: i64,
|
|
pub oauth_id: String,
|
|
pub email: String,
|
|
pub name: String,
|
|
pub verified_email: bool,
|
|
pub picture: String,
|
|
pub tz_offset: i32,
|
|
}
|
|
|
|
/// A custom error for the User extractor
|
|
pub enum UserExtractError {
|
|
InternalServerError(anyhow::Error),
|
|
Unauthorized,
|
|
}
|
|
|
|
impl IntoResponse for UserExtractError {
|
|
fn into_response(self) -> Response {
|
|
match self {
|
|
UserExtractError::InternalServerError(err) => AppError::from(err).into_response(),
|
|
UserExtractError::Unauthorized => { Redirect::temporary("/auth/login").into_response() }
|
|
}
|
|
}
|
|
}
|
|
|
|
/// The user extractor is used to pull out the user data from the session. This can be used
|
|
/// as a guard to ensure that a user session exists. Basically an authentication
|
|
/// (but not authorization) guard
|
|
#[async_trait]
|
|
impl<S> FromRequestParts<S> for SessionUser
|
|
where
|
|
S: Send + Sync,
|
|
{
|
|
type Rejection = UserExtractError;
|
|
|
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> result::Result<Self, Self::Rejection> {
|
|
let session = Session::from_request_parts(parts, state).await
|
|
.map_err(|_| UserExtractError::InternalServerError(anyhow!("session from parts failed")))?;
|
|
|
|
let user = session.get(USER_SESSION).await
|
|
.map_err(|e| UserExtractError::InternalServerError(anyhow::Error::from(e)))?
|
|
.ok_or(UserExtractError::Unauthorized)?;
|
|
|
|
Ok(user)
|
|
}
|
|
}
|
|
|
|
impl SessionUser {
|
|
pub fn get_timezone(&self) -> Result<FixedOffset> {
|
|
FixedOffset::east_opt(self.tz_offset)
|
|
.ok_or(anyhow::anyhow!("Invalid timezone"))
|
|
}
|
|
} |