diff --git a/db/schema.sql b/db/schema.sql index 6b39ed2..524913f 100644 --- a/db/schema.sql +++ b/db/schema.sql @@ -2,6 +2,13 @@ create table person( person_id serial primary key, - person_email varchar(100), - person_password varchar(100) + person_email varchar(100) not null, + person_password varchar(100) not null +); + +-- Sessions +create table session( + session_id serial primary key, + session_person_id integer not null, + session_created_at timestamp with time zone not null ); diff --git a/src/controller/mod.rs b/src/controller/mod.rs index 8155cf0..9c22e6e 100644 --- a/src/controller/mod.rs +++ b/src/controller/mod.rs @@ -1,9 +1,15 @@ use maud::Markup; -use rocket::{form::Form, http::{CookieJar, Status}, response::Redirect}; +use rocket::{ + form::Form, + http::{CookieJar, Status}, + response::Redirect, +}; +use session::new_session; use utils::RS_SESSION_ID; use crate::db; +mod session; mod utils; #[get("/")] @@ -12,9 +18,10 @@ pub fn homepage() -> Markup { } #[get("/new")] -pub fn new_definition(_user: utils::User) -> (Status, Markup) { +pub fn new_definition(user: utils::Person) -> (Status, String) { log::info!("reached /new"); - todo!() + + (Status::Ok, format!("User: {}", user.person_id)) } #[get("/new", rank = 2)] @@ -56,9 +63,16 @@ pub async fn login(data: Form, cookies: &CookieJar<'_>) -> (Status, S Err(reason) => return (Status::InternalServerError, format!("{:?}", reason)), }; - if !re.is_empty() { + if re.len() == 1 { // TODO: generate a session id and assign - cookies.add_private((RS_SESSION_ID, "session-id")); + let person = &re[0]; + + let session_id = match new_session(person.person_id).await { + Ok(s) => s, + Err(error) => return (Status::Unauthorized, error), + }; + + cookies.add_private((RS_SESSION_ID, session_id.to_string())); (Status::Ok, "
".into()) } else { diff --git a/src/controller/session.rs b/src/controller/session.rs new file mode 100644 index 0000000..c5f01d8 --- /dev/null +++ b/src/controller/session.rs @@ -0,0 +1,64 @@ +use sqlx::types::chrono::Local; + +use crate::db; + +/// 30 minutes +const SESSION_LEN: i64 = 1000 * 60 * 30; + +/// Checks that the session signaled by session_id +/// is valid, and returns its person_id +pub async fn check_session(session_id: i32) -> Result { + let db = match db().await { + Ok(handle) => handle, + Err(e) => return Err(e), + }; + + // TODO: return the user info on the same trip + let result = sqlx::query!("select * from session where session_id = $1", session_id,) + .fetch_one(db) + .await; + + let result = match result { + Ok(r) => r, + Err(e) => return Err(format!("{:?}", e)), + }; + + let person_id = result.session_person_id; + let created_at = result.session_created_at.timestamp_millis(); + let current_time = Local::now().to_utc().timestamp_millis(); + + let time_difference = current_time - created_at; + + if time_difference < SESSION_LEN { + Ok(person_id) + } else { + // TODO: also remove all expired sessions + + Err("Expired".into()) + } +} + +/// Creates a new session for the person_id. Deletes any +/// previous session. +pub async fn new_session(person_id: i32) -> Result { + let db = match db().await { + Ok(handle) => handle, + Err(e) => return Err(e), + }; + + let now = Local::now().to_utc(); + + let result = sqlx::query!( + "insert into session (session_person_id, session_created_at) values + ($1, $2) returning session_id", + person_id, + now, + ) + .fetch_one(db) + .await; + + match result { + Ok(r) => Ok(r.session_id), + Err(reason) => return Err(format!("{:?}", reason)), + } +} diff --git a/src/controller/utils.rs b/src/controller/utils.rs index 504e42e..c23e6fa 100644 --- a/src/controller/utils.rs +++ b/src/controller/utils.rs @@ -3,26 +3,61 @@ use rocket::{ request::{FromRequest, Outcome, Request}, }; +use super::session::check_session; + /// Name of the header that stores the session ID of an user pub const RS_SESSION_ID: &str = "x-rs-session-id"; -pub struct User {} +pub struct Person { + pub person_id: i32, +} #[rocket::async_trait] -impl<'r> FromRequest<'r> for User { +impl<'r> FromRequest<'r> for Person { type Error = String; async fn from_request(req: &'r Request<'_>) -> Outcome { let session_cookie = req.cookies().get_private(RS_SESSION_ID); - let _session_cookie = match session_cookie { + let session_cookie = match session_cookie { Some(cookie) => cookie.value().to_string(), - None => return Outcome::Forward(Status::Unauthorized), + None => { + // remove cookie + req.cookies().remove_private(RS_SESSION_ID); + return Outcome::Forward(Status::Unauthorized); + } }; // Check if session cookie is valid // TODO + // If the session cookie is not valid, + // remove it and forward to login - Outcome::Success(User {}) + let session_id: i32 = match session_cookie.parse() { + Ok(v) => v, + Err(e) => { + // remove cookie + req.cookies().remove_private(RS_SESSION_ID); + log::error!( + "Error converting session_id to i32 ({}):\n{:?}", + session_cookie, + e + ); + return Outcome::Error(( + Status::Unauthorized, + "Invalid session id: not a number".into(), + )); + } + }; + + match check_session(session_id).await { + Ok(person_id) => Outcome::Success(Person { person_id }), + Err(reason) => { + // remove cookie + req.cookies().remove_private(RS_SESSION_ID); + log::info!("session check fail: {}", reason); + Outcome::Forward(Status::Unauthorized) + } + } } }