1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use relay_auth::{PublicKey, RelayId};
7use relay_config::{Config, RelayInfo};
8use relay_system::{
9 Addr, BroadcastChannel, BroadcastResponse, BroadcastSender, FromMessage, Interface, Service,
10};
11use serde::{Deserialize, Serialize};
12use tokio::sync::mpsc;
13
14use crate::services::upstream::{Method, RequestPriority, SendQuery, UpstreamQuery, UpstreamRelay};
15use crate::utils::{RetryBackoff, SleepHandle};
16
17#[derive(Debug)]
22pub struct GetRelay {
23 pub relay_id: RelayId,
27}
28
29pub type GetRelayResult = Option<RelayInfo>;
33
34#[derive(Debug)]
36pub struct RelayCache(GetRelay, BroadcastSender<GetRelayResult>);
37
38impl Interface for RelayCache {}
39
40impl FromMessage<GetRelay> for RelayCache {
41 type Response = BroadcastResponse<GetRelayResult>;
42
43 fn from_message(message: GetRelay, sender: BroadcastSender<GetRelayResult>) -> Self {
44 Self(message, sender)
45 }
46}
47
48#[derive(Debug, Deserialize)]
50#[serde(rename_all = "camelCase")]
51pub struct PublicKeysResultCompatibility {
52 #[serde(default, rename = "public_keys")]
54 pub public_keys: HashMap<RelayId, Option<PublicKey>>,
55
56 #[serde(default)]
61 pub relays: HashMap<RelayId, Option<RelayInfo>>,
62}
63
64#[derive(Debug, Serialize, Deserialize)]
70pub struct GetRelaysResponse {
71 pub relays: HashMap<RelayId, Option<RelayInfo>>,
76}
77
78impl From<PublicKeysResultCompatibility> for GetRelaysResponse {
79 fn from(relays_info: PublicKeysResultCompatibility) -> Self {
80 let relays = if relays_info.relays.is_empty() && !relays_info.public_keys.is_empty() {
81 relays_info
82 .public_keys
83 .into_iter()
84 .map(|(id, pk)| (id, pk.map(RelayInfo::new)))
85 .collect()
86 } else {
87 relays_info.relays
88 };
89 Self { relays }
90 }
91}
92
93#[derive(Debug, Deserialize, Serialize)]
95pub struct GetRelays {
96 pub relay_ids: Vec<RelayId>,
98}
99
100impl UpstreamQuery for GetRelays {
101 type Response = PublicKeysResultCompatibility;
102
103 fn method(&self) -> Method {
104 Method::POST
105 }
106
107 fn path(&self) -> Cow<'static, str> {
108 Cow::Borrowed("/api/0/relays/publickeys/")
109 }
110
111 fn priority() -> RequestPriority {
112 RequestPriority::High
113 }
114
115 fn retry() -> bool {
116 false
117 }
118
119 fn route(&self) -> &'static str {
120 "public_keys"
121 }
122}
123
124#[derive(Debug)]
126enum RelayState {
127 Exists {
128 relay: RelayInfo,
129 checked_at: Instant,
130 },
131 DoesNotExist {
132 checked_at: Instant,
133 },
134}
135
136impl RelayState {
137 fn is_valid_cache(&self, config: &Config) -> bool {
139 match *self {
140 RelayState::Exists { checked_at, .. } => {
141 checked_at.elapsed() < config.relay_cache_expiry()
142 }
143 RelayState::DoesNotExist { checked_at } => {
144 checked_at.elapsed() < config.cache_miss_expiry()
145 }
146 }
147 }
148
149 fn as_option(&self) -> Option<&RelayInfo> {
153 match *self {
154 RelayState::Exists { ref relay, .. } => Some(relay),
155 _ => None,
156 }
157 }
158
159 fn from_option(option: Option<RelayInfo>) -> Self {
161 match option {
162 Some(relay) => RelayState::Exists {
163 relay,
164 checked_at: Instant::now(),
165 },
166 None => RelayState::DoesNotExist {
167 checked_at: Instant::now(),
168 },
169 }
170 }
171}
172
173type FetchResult = Result<GetRelaysResponse, HashMap<RelayId, BroadcastChannel<GetRelayResult>>>;
178
179#[derive(Debug)]
181pub struct RelayCacheService {
182 static_relays: HashMap<RelayId, RelayInfo>,
183 relays: HashMap<RelayId, RelayState>,
184 channels: HashMap<RelayId, BroadcastChannel<GetRelayResult>>,
185 fetch_channel: (mpsc::Sender<FetchResult>, mpsc::Receiver<FetchResult>),
186 backoff: RetryBackoff,
187 delay: SleepHandle,
188 config: Arc<Config>,
189 upstream_relay: Addr<UpstreamRelay>,
190}
191
192impl RelayCacheService {
193 pub fn new(config: Arc<Config>, upstream_relay: Addr<UpstreamRelay>) -> Self {
195 Self {
196 static_relays: config.static_relays().clone(),
197 relays: HashMap::new(),
198 channels: HashMap::new(),
199 fetch_channel: mpsc::channel(1),
200 backoff: RetryBackoff::new(config.http_max_retry_interval()),
201 delay: SleepHandle::idle(),
202 config,
203 upstream_relay,
204 }
205 }
206
207 fn fetch_tx(&self) -> mpsc::Sender<FetchResult> {
209 let (ref tx, _) = self.fetch_channel;
210 tx.clone()
211 }
212
213 fn next_backoff(&mut self) -> Duration {
218 self.config.downstream_relays_batch_interval() + self.backoff.next_backoff()
219 }
220
221 fn schedule_fetch(&mut self) {
223 let backoff = self.next_backoff();
224 self.delay.set(backoff);
225 }
226
227 fn fetch_relays(&mut self) {
232 let channels = std::mem::take(&mut self.channels);
233 relay_log::debug!(
234 "updating public keys for {} relays (attempt {})",
235 channels.len(),
236 self.backoff.attempt(),
237 );
238
239 let fetch_tx = self.fetch_tx();
240 let upstream_relay = self.upstream_relay.clone();
241 relay_system::spawn!(async move {
242 let request = GetRelays {
243 relay_ids: channels.keys().cloned().collect(),
244 };
245
246 let query_result = match upstream_relay.send(SendQuery(request)).await {
247 Ok(inner) => inner,
248 Err(_send_error) => return,
250 };
251
252 let fetch_result = match query_result {
253 Ok(response) => {
254 let response = GetRelaysResponse::from(response);
255
256 for (id, channel) in channels {
257 relay_log::debug!("relay {id} public key updated");
258 let info = response.relays.get(&id).unwrap_or(&None);
259 channel.send(info.clone());
260 }
261
262 Ok(response)
263 }
264 Err(error) => {
265 relay_log::error!(
266 error = &error as &dyn std::error::Error,
267 "error fetching public keys"
268 );
269 Err(channels)
270 }
271 };
272
273 fetch_tx.send(fetch_result).await.ok();
274 });
275 }
276
277 fn handle_fetch_result(&mut self, result: FetchResult) {
279 match result {
280 Ok(response) => {
281 self.backoff.reset();
282
283 for (id, info) in response.relays {
284 self.relays.insert(id, RelayState::from_option(info));
285 }
286 }
287 Err(channels) => {
288 self.channels.extend(channels);
289 }
290 }
291
292 if !self.channels.is_empty() {
293 self.schedule_fetch();
294 }
295 }
296
297 fn get_or_fetch(&mut self, message: GetRelay, sender: BroadcastSender<GetRelayResult>) {
302 let relay_id = message.relay_id;
303
304 if let Some(key) = self.static_relays.get(&relay_id) {
306 sender.send(Some(key.clone()));
307 return;
308 }
309
310 if let Some(key) = self.relays.get(&relay_id) {
311 if key.is_valid_cache(&self.config) {
312 sender.send(key.as_option().cloned());
313 return;
314 }
315 }
316
317 if self.config.credentials().is_none() {
318 relay_log::error!(
319 "no credentials configured. relay {relay_id} cannot send requests to this relay",
320 );
321 sender.send(None);
322 return;
323 }
324
325 relay_log::debug!("relay {relay_id} public key requested");
326 self.channels.entry(relay_id).or_default().attach(sender);
327
328 if !self.backoff.started() {
329 self.schedule_fetch();
330 }
331 }
332}
333
334impl Service for RelayCacheService {
335 type Interface = RelayCache;
336
337 async fn run(mut self, mut rx: relay_system::Receiver<Self::Interface>) {
338 relay_log::info!("key cache started");
339
340 loop {
341 tokio::select! {
342 biased;
344
345 Some(result) = self.fetch_channel.1.recv() => self.handle_fetch_result(result),
346 () = &mut self.delay => self.fetch_relays(),
347 Some(message) = rx.recv() => self.get_or_fetch(message.0, message.1),
348 else => break,
349 }
350 }
351
352 relay_log::info!("key cache stopped");
353 }
354}