1use relay_redis::{
2 AsyncRedisConnection, RedisScripts,
3 redis::{self, FromRedisValue, Script},
4};
5
6use crate::Result;
7
/// Per-hash outcome reported by the cardinality script.
///
/// Decoded from a boolean script reply: `true` maps to [`Status::Accepted`],
/// `false` to [`Status::Rejected`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Status {
    /// The script answered `false` for this hash.
    Rejected,
    /// The script answered `true` for this hash.
    Accepted,
}
16
17impl Status {
18 pub fn is_rejected(&self) -> bool {
20 matches!(self, Self::Rejected)
21 }
22}
23
24impl FromRedisValue for Status {
25 fn from_redis_value(v: &redis::Value) -> redis::RedisResult<Self> {
26 let accepted = bool::from_redis_value(v)?;
27 Ok(if accepted {
28 Self::Accepted
29 } else {
30 Self::Rejected
31 })
32 }
33}
34
35#[derive(Debug)]
37pub struct CardinalityScriptResult {
38 pub cardinality: u32,
40 pub statuses: Vec<Status>,
42}
43
44impl CardinalityScriptResult {
45 pub fn validate(&self, num_hashes: usize) -> Result<()> {
49 if num_hashes == self.statuses.len() {
50 return Ok(());
51 }
52
53 Err(relay_redis::RedisError::Redis(redis::RedisError::from((
54 redis::ErrorKind::ResponseError,
55 "Script returned an invalid number of elements",
56 format!("Expected {num_hashes} results, got {}", self.statuses.len()),
57 )))
58 .into())
59 }
60}
61
62impl FromRedisValue for CardinalityScriptResult {
63 fn from_redis_value(v: &redis::Value) -> redis::RedisResult<Self> {
64 let Some(seq) = v.as_sequence() else {
65 return Err(redis::RedisError::from((
66 redis::ErrorKind::TypeError,
67 "Expected a sequence from the cardinality script",
68 format!("{v:?}"),
69 )));
70 };
71
72 let mut iter = seq.iter();
73
74 let cardinality = iter
75 .next()
76 .ok_or_else(|| {
77 redis::RedisError::from((
78 redis::ErrorKind::TypeError,
79 "Expected cardinality as the first result from the cardinality script",
80 ))
81 })
82 .and_then(FromRedisValue::from_redis_value)?;
83
84 let mut statuses = Vec::with_capacity(iter.len());
85 for value in iter {
86 statuses.push(Status::from_redis_value(value)?);
87 }
88
89 Ok(Self {
90 cardinality,
91 statuses,
92 })
93 }
94}
95
96pub struct CardinalityScript(&'static Script);
98
99impl CardinalityScript {
100 pub fn load() -> Self {
104 Self(RedisScripts::load_cardinality())
105 }
106
107 pub fn pipe(&self) -> CardinalityScriptPipeline<'_> {
109 CardinalityScriptPipeline {
110 script: self,
111 pipe: redis::pipe(),
112 }
113 }
114
115 async fn load_redis(&self, con: &mut AsyncRedisConnection) -> Result<()> {
117 self.0
118 .prepare_invoke()
119 .load_async(con)
120 .await
121 .map_err(relay_redis::RedisError::Redis)?;
122
123 Ok(())
124 }
125
126 fn prepare_invocation(
128 &self,
129 limit: u32,
130 expire: u64,
131 hashes: impl Iterator<Item = u32>,
132 keys: impl Iterator<Item = String>,
133 ) -> redis::ScriptInvocation<'_> {
134 let mut invocation = self.0.prepare_invoke();
135
136 for key in keys {
137 invocation.key(key);
138 }
139
140 invocation.arg(limit);
141 invocation.arg(expire);
142
143 for hash in hashes {
144 invocation.arg(&hash.to_le_bytes());
145 }
146
147 invocation
148 }
149}
150
151pub struct CardinalityScriptPipeline<'a> {
153 script: &'a CardinalityScript,
154 pipe: redis::Pipeline,
155}
156
157impl CardinalityScriptPipeline<'_> {
158 pub fn add_invocation(
160 &mut self,
161 limit: u32,
162 expire: u64,
163 hashes: impl Iterator<Item = u32>,
164 keys: impl Iterator<Item = String>,
165 ) -> &mut Self {
166 let invocation = self.script.prepare_invocation(limit, expire, hashes, keys);
167 self.pipe.invoke_script(&invocation);
168 self
169 }
170
171 pub async fn invoke(
175 &self,
176 con: &mut AsyncRedisConnection,
177 ) -> Result<Vec<CardinalityScriptResult>> {
178 match self.pipe.query_async(con).await {
179 Ok(result) => Ok(result),
180 Err(err) if err.kind() == redis::ErrorKind::NoScriptError => {
181 relay_log::trace!("Redis script no loaded, loading it now");
182 self.script.load_redis(con).await?;
183 self.pipe
184 .query_async(con)
185 .await
186 .map_err(relay_redis::RedisError::Redis)
187 .map_err(Into::into)
188 }
189 Err(err) => Err(relay_redis::RedisError::Redis(err).into()),
190 }
191 }
192}
193
#[cfg(test)]
mod tests {
    use relay_redis::{AsyncRedisClient, RedisConfigOptions};
    use uuid::Uuid;

    use super::*;

    impl CardinalityScript {
        /// Runs exactly one invocation through a fresh pipeline and returns its result.
        ///
        /// Panics if the pipeline does not yield exactly one result.
        async fn invoke_one(
            &self,
            con: &mut AsyncRedisConnection,
            limit: u32,
            expire: u64,
            hashes: impl Iterator<Item = u32>,
            keys: impl Iterator<Item = String>,
        ) -> Result<CardinalityScriptResult> {
            let mut pipeline = self.pipe();
            pipeline.add_invocation(limit, expire, hashes, keys);

            let mut results = pipeline.invoke(con).await?;
            assert_eq!(results.len(), 1);
            Ok(results.remove(0))
        }
    }

    /// Builds a single-connection client for `RELAY_REDIS_URL`, defaulting to a local server.
    fn build_redis_client() -> AsyncRedisClient {
        let url = std::env::var("RELAY_REDIS_URL")
            .unwrap_or_else(|_| String::from("redis://127.0.0.1:6379"));
        let opts = RedisConfigOptions {
            max_connections: 1,
            ..Default::default()
        };

        AsyncRedisClient::single(&url, &opts).unwrap()
    }

    /// Returns `keys` namespaced as `{prefix}-{key}`.
    fn keys(prefix: Uuid, keys: &[&str]) -> impl Iterator<Item = String> {
        // Collect eagerly so the returned iterator owns its strings.
        let namespaced: Vec<String> = keys.iter().map(|key| format!("{prefix}-{key}")).collect();
        namespaced.into_iter()
    }

    /// Asserts that every key matching `{prefix}-*` has an expiry.
    async fn assert_ttls(connection: &mut AsyncRedisConnection, prefix: Uuid) {
        let pattern = format!("{prefix}-*");
        let found: Vec<String> = redis::cmd("KEYS")
            .arg(pattern)
            .query_async(connection)
            .await
            .unwrap();

        for key in found {
            let ttl: i64 = redis::cmd("TTL")
                .arg(&key)
                .query_async(connection)
                .await
                .unwrap();
            // TTL replies with a negative value when the key has no expiry (or no longer exists).
            assert!(ttl >= 0, "Key {key} has no TTL");
        }
    }

    #[tokio::test]
    async fn test_below_limit_perfect_cardinality_ttl() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();
        let script = CardinalityScript::load();
        let prefix = Uuid::new_v4();

        // Two overlapping key sets, each invoked with 30 hashes against a limit of 50.
        for key_set in [&["a", "b", "c"], &["b", "c", "d"]] {
            script
                .invoke_one(&mut connection, 50, 3600, 0..30, keys(prefix, key_set))
                .await
                .unwrap();
        }

        assert_ttls(&mut connection, prefix).await;
    }

    #[tokio::test]
    async fn test_load_script() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();
        let script = CardinalityScript::load();
        let keys = keys(Uuid::new_v4(), &["a", "b", "c"]);

        // Empty the server's script cache so the invocation has to reload the script.
        redis::cmd("SCRIPT")
            .arg("FLUSH")
            .exec_async(&mut connection)
            .await
            .unwrap();

        script
            .invoke_one(&mut connection, 50, 3600, 0..30, keys)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_multiple_calls_in_pipeline() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();
        let script = CardinalityScript::load();

        let first = keys(Uuid::new_v4(), &["a", "b", "c"]);
        let second = keys(Uuid::new_v4(), &["a", "b", "c"]);

        let results = script
            .pipe()
            .add_invocation(50, 3600, 0..30, first)
            .add_invocation(50, 3600, 0..30, second)
            .invoke(&mut connection)
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
    }
}