1use relay_redis::{
2 AsyncRedisConnection, RedisScripts,
3 redis::{self, FromRedisValue, Script, ServerErrorKind},
4};
5
6use crate::Result;
7
8#[derive(Debug, Clone, Copy)]
10pub enum Status {
11 Rejected,
13 Accepted,
15}
16
17impl Status {
18 pub fn is_rejected(&self) -> bool {
20 matches!(self, Self::Rejected)
21 }
22}
23
24impl FromRedisValue for Status {
25 fn from_redis_value(v: redis::Value) -> Result<Self, redis::ParsingError> {
26 let accepted = bool::from_redis_value(v)?;
27 Ok(if accepted {
28 Self::Accepted
29 } else {
30 Self::Rejected
31 })
32 }
33}
34
35#[derive(Debug)]
37pub struct CardinalityScriptResult {
38 pub cardinality: u32,
40 pub statuses: Vec<Status>,
42}
43
44impl CardinalityScriptResult {
45 pub fn validate(&self, num_hashes: usize) -> Result<()> {
49 if num_hashes == self.statuses.len() {
50 return Ok(());
51 }
52
53 Err(relay_redis::RedisError::Redis(redis::RedisError::from((
54 redis::ErrorKind::Server(ServerErrorKind::ResponseError),
55 "Script returned an invalid number of elements",
56 format!("Expected {num_hashes} results, got {}", self.statuses.len()),
57 )))
58 .into())
59 }
60}
61
62impl FromRedisValue for CardinalityScriptResult {
63 fn from_redis_value(v: redis::Value) -> Result<Self, redis::ParsingError> {
64 let seq = v.into_sequence().map_err(|v| {
65 format!("Expected a sequence from the cardinality script (value was: {v:?})")
66 })?;
67
68 let mut iter = seq.into_iter();
69
70 let cardinality = iter
71 .next()
72 .ok_or("Expected cardinality as the first result from the cardinality script")
73 .map_err(redis::ParsingError::from)
74 .and_then(FromRedisValue::from_redis_value)?;
75
76 let mut statuses = Vec::with_capacity(iter.len());
77 for value in iter {
78 statuses.push(Status::from_redis_value(value)?);
79 }
80
81 Ok(Self {
82 cardinality,
83 statuses,
84 })
85 }
86}
87
88pub struct CardinalityScript(&'static Script);
90
91impl CardinalityScript {
92 pub fn load() -> Self {
96 Self(RedisScripts::load_cardinality())
97 }
98
99 pub fn pipe(&self) -> CardinalityScriptPipeline<'_> {
101 CardinalityScriptPipeline {
102 script: self,
103 pipe: redis::pipe(),
104 }
105 }
106
107 async fn load_redis(&self, con: &mut AsyncRedisConnection) -> Result<()> {
109 self.0
110 .prepare_invoke()
111 .load_async(con)
112 .await
113 .map_err(relay_redis::RedisError::Redis)?;
114
115 Ok(())
116 }
117
118 fn prepare_invocation(
120 &self,
121 limit: u32,
122 expire: u64,
123 hashes: impl Iterator<Item = u32>,
124 keys: impl Iterator<Item = String>,
125 ) -> redis::ScriptInvocation<'_> {
126 let mut invocation = self.0.prepare_invoke();
127
128 for key in keys {
129 invocation.key(key);
130 }
131
132 invocation.arg(limit);
133 invocation.arg(expire);
134
135 for hash in hashes {
136 invocation.arg(&hash.to_le_bytes());
137 }
138
139 invocation
140 }
141}
142
143pub struct CardinalityScriptPipeline<'a> {
145 script: &'a CardinalityScript,
146 pipe: redis::Pipeline,
147}
148
149impl CardinalityScriptPipeline<'_> {
150 pub fn add_invocation(
152 &mut self,
153 limit: u32,
154 expire: u64,
155 hashes: impl Iterator<Item = u32>,
156 keys: impl Iterator<Item = String>,
157 ) -> &mut Self {
158 let invocation = self.script.prepare_invocation(limit, expire, hashes, keys);
159 self.pipe.invoke_script(&invocation);
160 self
161 }
162
163 pub async fn invoke(
167 &self,
168 con: &mut AsyncRedisConnection,
169 ) -> Result<Vec<CardinalityScriptResult>> {
170 match self.pipe.query_async(con).await {
171 Ok(result) => Ok(result),
172 Err(err) if err.kind() == redis::ErrorKind::Server(ServerErrorKind::NoScript) => {
173 relay_log::trace!("Redis script no loaded, loading it now");
174 self.script.load_redis(con).await?;
175 self.pipe
176 .query_async(con)
177 .await
178 .map_err(relay_redis::RedisError::Redis)
179 .map_err(Into::into)
180 }
181 Err(err) => Err(relay_redis::RedisError::Redis(err).into()),
182 }
183 }
184}
185
#[cfg(test)]
mod tests {
    use relay_redis::{AsyncRedisClient, RedisConfigOptions};
    use uuid::Uuid;

    use super::*;

    impl CardinalityScript {
        /// Runs a single invocation through a one-element pipeline and returns its result.
        async fn invoke_one(
            &self,
            con: &mut AsyncRedisConnection,
            limit: u32,
            expire: u64,
            hashes: impl Iterator<Item = u32>,
            keys: impl Iterator<Item = String>,
        ) -> Result<CardinalityScriptResult> {
            let mut results = self
                .pipe()
                .add_invocation(limit, expire, hashes, keys)
                .invoke(con)
                .await?;

            assert_eq!(results.len(), 1);
            Ok(results.pop().unwrap())
        }
    }

    /// Connects to `RELAY_REDIS_URL`, falling back to a local Redis instance.
    fn build_redis_client() -> AsyncRedisClient {
        let url = std::env::var("RELAY_REDIS_URL")
            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());

        let opts = RedisConfigOptions {
            max_connections: 1,
            ..Default::default()
        };
        AsyncRedisClient::single("test", &url, &opts).unwrap()
    }

    /// Builds unique key names `"{prefix}-{key}"`. They are collected first so the
    /// returned iterator does not borrow `keys`.
    fn keys(prefix: Uuid, keys: &[&str]) -> impl Iterator<Item = String> {
        keys.iter()
            .map(move |key| format!("{prefix}-{key}"))
            .collect::<Vec<_>>()
            .into_iter()
    }

    /// Asserts that every key under `prefix` has an expiry set.
    async fn assert_ttls(connection: &mut AsyncRedisConnection, prefix: Uuid) {
        let keys = redis::cmd("KEYS")
            .arg(format!("{prefix}-*"))
            .query_async::<Vec<String>>(connection)
            .await
            .unwrap();

        for key in keys {
            let ttl = redis::cmd("TTL")
                .arg(&key)
                .query_async::<i64>(connection)
                .await
                .unwrap();

            assert!(ttl >= 0, "Key {key} has no TTL");
        }
    }

    /// Asserts the script reported one status per hash and rejected none of them.
    fn assert_all_accepted(result: &CardinalityScriptResult, num_hashes: usize) {
        result.validate(num_hashes).unwrap();
        assert!(
            result.statuses.iter().all(|status| !status.is_rejected()),
            "unexpected rejection below the limit: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_below_limit_perfect_cardinality_ttl() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();

        let script = CardinalityScript::load();

        let prefix = Uuid::new_v4();
        let k1 = &["a", "b", "c"];
        let k2 = &["b", "c", "d"];

        let first = script
            .invoke_one(&mut connection, 50, 3600, 0..30, keys(prefix, k1))
            .await
            .unwrap();
        assert_all_accepted(&first, 30);

        let second = script
            .invoke_one(&mut connection, 50, 3600, 0..30, keys(prefix, k2))
            .await
            .unwrap();
        assert_all_accepted(&second, 30);

        assert_ttls(&mut connection, prefix).await;
    }

    #[tokio::test]
    async fn test_load_script() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();

        let script = CardinalityScript::load();
        let keys = keys(Uuid::new_v4(), &["a", "b", "c"]);

        redis::cmd("SCRIPT")
            .arg("FLUSH")
            .exec_async(&mut connection)
            .await
            .unwrap();
        let result = script
            .invoke_one(&mut connection, 50, 3600, 0..30, keys)
            .await
            .unwrap();
        assert_all_accepted(&result, 30);
    }

    #[tokio::test]
    async fn test_multiple_calls_in_pipeline() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();

        let script = CardinalityScript::load();
        let k1 = keys(Uuid::new_v4(), &["a", "b", "c"]);
        let k2 = keys(Uuid::new_v4(), &["a", "b", "c"]);

        let mut pipeline = script.pipe();
        let results = pipeline
            .add_invocation(50, 3600, 0..30, k1)
            .add_invocation(50, 3600, 0..30, k2)
            .invoke(&mut connection)
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        for result in &results {
            assert_all_accepted(result, 30);
        }
    }
}