From a49dbefe83ee2e756308d4e762d04252309b0046 Mon Sep 17 00:00:00 2001 From: alyssa Date: Sat, 9 Aug 2025 14:50:57 +0000 Subject: [PATCH] fix(api): automatically reload ratelimit script on redis server restart --- crates/api/src/middleware/ratelimit.rs | 43 +++++++++++++++----------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/crates/api/src/middleware/ratelimit.rs b/crates/api/src/middleware/ratelimit.rs index f4a63f7e..1638ecc9 100644 --- a/crates/api/src/middleware/ratelimit.rs +++ b/crates/api/src/middleware/ratelimit.rs @@ -45,21 +45,6 @@ pub fn ratelimiter(f: F) -> FromFnLayer, T> { tokio::spawn(async move { handle }); - let rscript = r.clone(); - tokio::spawn(async move { - if let Ok(()) = rscript.wait_for_connect().await { - match rscript - .script_load::(LUA_SCRIPT.to_string()) - .await - { - Ok(_) => info!("connected to redis for request rate limiting"), - Err(error) => error!(?error, "could not load redis script"), - } - } else { - error!("could not wait for connection to load redis script!"); - } - }); - r }); @@ -152,12 +137,34 @@ pub async fn do_request_ratelimited( let period = 1; // seconds let cost = 1; // todo: update this for group member endpoints + let script_exists: Vec = + match redis.script_exists(vec![LUA_SCRIPT_SHA.to_string()]).await { + Ok(exists) => exists, + Err(error) => { + error!(?error, "failed to check ratelimit script"); + return json_err( + StatusCode::INTERNAL_SERVER_ERROR, + r#"{"message": "500: internal server error", "code": 0}"#.to_string(), + ); + } + }; + + if script_exists[0] != 1 { + match redis + .script_load::(LUA_SCRIPT.to_string()) + .await + { + Ok(_) => info!("successfully loaded ratelimit script to redis"), + Err(error) => { + error!(?error, "could not load redis script") + } + } + } + // local rate_limit_key = KEYS[1] // local rate = ARGV[1] // local period = ARGV[2] // return {remaining, tostring(retry_after), reset_after} - - // todo: check if error is script not found and reload script let resp = redis .evalsha::<(i32, String, u64), String, Vec, Vec>( LUA_SCRIPT_SHA.to_string(), @@ -219,7 +226,7 @@ pub async fn do_request_ratelimited( return response; } Err(error) => { - tracing::error!(?error, "error getting ratelimit info"); + error!(?error, "error getting ratelimit info"); return json_err( StatusCode::INTERNAL_SERVER_ERROR, r#"{"message": "500: internal server error", "code": 0}"#.to_string(),