mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-15 18:20:11 +00:00
feat(gateway): add option to use source address in gateway awaiter
This commit is contained in:
parent
795e4f43b4
commit
5fcee4eb29
2 changed files with 30 additions and 10 deletions
|
|
@ -1,5 +1,5 @@
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Path, State},
|
extract::{ConnectInfo, Path, State},
|
||||||
http::StatusCode,
|
http::StatusCode,
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
routing::{delete, get, post},
|
routing::{delete, get, post},
|
||||||
|
|
@ -17,7 +17,7 @@ use crate::{
|
||||||
},
|
},
|
||||||
event_awaiter::{AwaitEventRequest, EventAwaiter},
|
event_awaiter::{AwaitEventRequest, EventAwaiter},
|
||||||
};
|
};
|
||||||
use std::sync::Arc;
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
|
|
||||||
fn status_code(code: StatusCode, body: String) -> Response {
|
fn status_code(code: StatusCode, body: String) -> Response {
|
||||||
(code, body).into_response()
|
(code, body).into_response()
|
||||||
|
|
@ -197,12 +197,13 @@ pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeCon
|
||||||
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
|
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
|
||||||
}))
|
}))
|
||||||
|
|
||||||
.route("/await_event", post(|body: String| async move {
|
.route("/await_event", post(|ConnectInfo(addr): ConnectInfo<SocketAddr>, body: String| async move {
|
||||||
info!("got request: {body}");
|
info!("got request: {body} from: {addr}");
|
||||||
let Ok(req) = serde_json::from_str::<AwaitEventRequest>(&body) else {
|
let Ok(req) = serde_json::from_str::<AwaitEventRequest>(&body) else {
|
||||||
return status_code(StatusCode::BAD_REQUEST, "".to_string());
|
return status_code(StatusCode::BAD_REQUEST, "".to_string());
|
||||||
};
|
};
|
||||||
awaiter.handle_request(req).await;
|
|
||||||
|
awaiter.handle_request(req, addr).await;
|
||||||
status_code(StatusCode::NO_CONTENT, "".to_string())
|
status_code(StatusCode::NO_CONTENT, "".to_string())
|
||||||
}))
|
}))
|
||||||
.route("/clear_awaiter", post(|| async move {
|
.route("/clear_awaiter", post(|| async move {
|
||||||
|
|
@ -216,7 +217,7 @@ pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeCon
|
||||||
let addr: &str = libpk::config.discord.as_ref().expect("missing discord config").cache_api_addr.as_ref();
|
let addr: &str = libpk::config.discord.as_ref().expect("missing discord config").cache_api_addr.as_ref();
|
||||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||||
info!("listening on {}", addr);
|
info!("listening on {}", addr);
|
||||||
axum::serve(listener, app).await?;
|
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::{hash_map::Entry, HashMap},
|
collections::{hash_map::Entry, HashMap},
|
||||||
|
net::{IpAddr, SocketAddr},
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -149,7 +150,7 @@ impl EventAwaiter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_request(&self, req: AwaitEventRequest) {
|
pub async fn handle_request(&self, req: AwaitEventRequest, addr: SocketAddr) {
|
||||||
match req {
|
match req {
|
||||||
AwaitEventRequest::Reaction {
|
AwaitEventRequest::Reaction {
|
||||||
message_id,
|
message_id,
|
||||||
|
|
@ -167,7 +168,7 @@ impl EventAwaiter {
|
||||||
.unwrap_or(DEFAULT_TIMEOUT),
|
.unwrap_or(DEFAULT_TIMEOUT),
|
||||||
)
|
)
|
||||||
.expect("invalid time"),
|
.expect("invalid time"),
|
||||||
target,
|
target_or_addr(target, addr),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
@ -188,7 +189,7 @@ impl EventAwaiter {
|
||||||
.unwrap_or(DEFAULT_TIMEOUT),
|
.unwrap_or(DEFAULT_TIMEOUT),
|
||||||
)
|
)
|
||||||
.expect("invalid time"),
|
.expect("invalid time"),
|
||||||
target,
|
target_or_addr(target, addr),
|
||||||
options,
|
options,
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
|
|
@ -208,7 +209,7 @@ impl EventAwaiter {
|
||||||
.unwrap_or(DEFAULT_TIMEOUT),
|
.unwrap_or(DEFAULT_TIMEOUT),
|
||||||
)
|
)
|
||||||
.expect("invalid time"),
|
.expect("invalid time"),
|
||||||
target,
|
target_or_addr(target, addr),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
@ -221,3 +222,21 @@ impl EventAwaiter {
|
||||||
self.interactions.write().await.clear();
|
self.interactions.write().await.clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn target_or_addr(target: String, addr: SocketAddr) -> String {
|
||||||
|
if target == "source-addr" {
|
||||||
|
let ip_str = match addr.ip() {
|
||||||
|
IpAddr::V4(v4) => v4.to_string(),
|
||||||
|
IpAddr::V6(v6) => {
|
||||||
|
if let Some(v4) = v6.to_ipv4_mapped() {
|
||||||
|
v4.to_string()
|
||||||
|
} else {
|
||||||
|
format!("[{v6}]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
format!("http://{ip_str}:5002/events")
|
||||||
|
} else {
|
||||||
|
target
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue