From 4db7c19409e93b7db0bebdbaab7875e403ed9c92 Mon Sep 17 00:00:00 2001 From: Araozu Date: Wed, 6 Dec 2023 09:50:42 -0500 Subject: [PATCH] [BE] Fixes #21: Attempt to connect to DB on each query, if no connection is present --- backend/src/controller/person/mod.rs | 2 +- backend/src/main.rs | 39 +++++++++++++++++++++------- backend/src/model/course.rs | 4 +-- backend/src/model/custom_label.rs | 6 ++--- backend/src/model/person.rs | 6 ++--- backend/src/model/register.rs | 10 +++---- 6 files changed, 44 insertions(+), 23 deletions(-) diff --git a/backend/src/controller/person/mod.rs b/backend/src/controller/person/mod.rs index 1e5ccb7..4ffe230 100644 --- a/backend/src/controller/person/mod.rs +++ b/backend/src/controller/person/mod.rs @@ -10,7 +10,7 @@ use crate::{db, model::person::Person}; #[get("/person/")] pub async fn get_by_dni(dni: i32) -> (Status, Json>) { - let db = db(); + let db = db().await; info!("get person with dni {}", dni); /* diff --git a/backend/src/main.rs b/backend/src/main.rs index 1e8cd33..7c2130e 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,5 +1,6 @@ use cors::Cors; use once_cell::sync::OnceCell; +use rocket::tokio; use sqlx::mysql::MySqlPoolOptions; use sqlx::{MySql, Pool}; use std::env; @@ -17,17 +18,29 @@ pub mod json_result; static DB: OnceCell> = OnceCell::new(); /// Returns a global reference to the database pool -/// This MUST be called after the DB pool has been initialized, -/// otherwise it will panic -pub fn db() -> &'static Pool { - DB.get().expect("DB not initialized") +/// +/// If the database pool has not been initialized, this function will attempt to initialize it +/// up to 3 times +/// +/// If the database pool fails to initialize, this function will panic. +pub async fn db() -> &'static Pool { + let attempts = 3; + + for _ in 0..attempts { + match DB.get() { + Some(db) => return db, + None => { + log::info!("DB not initialized, initializing from db()"); + init_db().await; + } + } + }; + + log::error!("Failed to initialize DB after {} attempts", attempts); + panic!("DB not initialized"); } -#[launch] -async fn rocket() -> _ { - dotenvy::dotenv().expect("Failed to load .env file"); - env_logger::init(); - +pub async fn init_db() { /* Init DB and set it as a global variable */ @@ -41,6 +54,14 @@ async fn rocket() -> _ { Ok(pool) => DB.set(pool).expect("Failed to set DB pool"), Err(e) => log::error!("Error connecting to DB: {}", e), } +} + +#[launch] +async fn rocket() -> _ { + dotenvy::dotenv().expect("Failed to load .env file"); + env_logger::init(); + + init_db().await; /* Init Rocket */ rocket::build().attach(Cors {}).mount( diff --git a/backend/src/model/course.rs b/backend/src/model/course.rs index ca84a54..eaa2df0 100644 --- a/backend/src/model/course.rs +++ b/backend/src/model/course.rs @@ -28,7 +28,7 @@ pub struct Course { impl Course { pub async fn get_all() -> Result, sqlx::Error> { - let db = db(); + let db = db().await; let results = sqlx::query!("SELECT * FROM course") .fetch_all(db) @@ -47,7 +47,7 @@ impl Course { } pub async fn get_course_name(course_id: i32) -> Option { - let db = db(); + let db = db().await; let res = sqlx::query!( "SELECT course_name FROM course WHERE course_id = ?", diff --git a/backend/src/model/custom_label.rs b/backend/src/model/custom_label.rs index 2fc6533..eb4aec5 100644 --- a/backend/src/model/custom_label.rs +++ b/backend/src/model/custom_label.rs @@ -9,7 +9,7 @@ pub struct CustomLabel { impl CustomLabel { pub async fn get_all() -> Result, sqlx::Error> { - let db = db(); + let db = db().await; let result = sqlx::query_as::<_, CustomLabel>( r#" @@ -24,7 +24,7 @@ impl CustomLabel { } pub async fn get_id_by_value(value: &String) -> Result { - let db = db(); + let db = db().await; let result = sqlx::query!( "SELECT custom_label_id FROM custom_label WHERE custom_label_value = ?", @@ -43,7 +43,7 @@ impl CustomLabel { } pub async fn create(value: &String) -> Result { - let db = db(); + let db = db().await; sqlx::query!( "INSERT INTO custom_label (custom_label_value) VALUES (?)", diff --git a/backend/src/model/person.rs b/backend/src/model/person.rs index 416681e..aea54f9 100644 --- a/backend/src/model/person.rs +++ b/backend/src/model/person.rs @@ -32,7 +32,7 @@ pub struct Person { impl Person { pub async fn get_by_dni(dni: i32) -> Result { - let db = db(); + let db = db().await; let result = sqlx::query_as!(Person, "SELECT * FROM person WHERE person_dni = ?", dni) .fetch_one(db) @@ -58,7 +58,7 @@ pub struct PersonCreate { impl PersonCreate { pub async fn create(&self) -> Result<(), sqlx::Error> { - let db = db(); + let db = db().await; sqlx::query!( "INSERT INTO person (person_dni, person_names, person_paternal_surname, person_maternal_surname) VALUES (?, ?, ?, ?)", @@ -84,7 +84,7 @@ pub struct PersonLink { impl PersonLink { /// Links a person to a user in the online classroom pub async fn insert(&self) -> Result<(), String> { - let db = db(); + let db = db().await; let res = sqlx::query!( "UPDATE person SET person_classroom_id = ?, person_classroom_username = ? WHERE person_id = ?", diff --git a/backend/src/model/register.rs b/backend/src/model/register.rs index 0f78cdb..b528043 100644 --- a/backend/src/model/register.rs +++ b/backend/src/model/register.rs @@ -32,7 +32,7 @@ pub struct RegisterCreate { impl RegisterCreate { /// Registers a new certificate pub async fn create(&self) -> Result<(), sqlx::Error> { - let db = db(); + let db = db().await; // Get custom_label_id from db based of self.custom_label let custom_label_id = { @@ -82,7 +82,7 @@ impl RegisterCreate { } async fn get_next_register_code(course_id: i32) -> Result { - let db = db(); + let db = db().await; let course_name = Course::get_course_name(course_id).await; @@ -140,7 +140,7 @@ pub struct Register { impl Register { pub async fn get_by_dni(dni: String) -> Result, sqlx::Error> { - let db = db(); + let db = db().await; let res = sqlx::query!( "SELECT * FROM register @@ -166,7 +166,7 @@ impl Register { } pub async fn delete(register_id: i32) -> Result<(), sqlx::Error> { - let db = db(); + let db = db().await; let _ = sqlx::query!("DELETE FROM register WHERE register_id = ?", register_id) .execute(db) @@ -179,7 +179,7 @@ impl Register { register_id_list: String, person_dni_list: String, ) -> Result, sqlx::Error> { - let db = db(); + let db = db().await; let sql = format!( "