// relay_cardinality/redis/script.rs

1use relay_redis::{
2    AsyncRedisConnection, RedisScripts,
3    redis::{self, FromRedisValue, Script},
4};
5
6use crate::Result;
7
/// Whether an entry/bucket is accepted or rejected by the cardinality limiter.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare statuses directly
/// instead of going through [`Status::is_rejected`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Status {
    /// Item is rejected.
    Rejected,
    /// Item is accepted.
    Accepted,
}
16
17impl Status {
18    /// Returns `true` if the status is [`Status::Rejected`].
19    pub fn is_rejected(&self) -> bool {
20        matches!(self, Self::Rejected)
21    }
22}
23
24impl FromRedisValue for Status {
25    fn from_redis_value(v: &redis::Value) -> redis::RedisResult<Self> {
26        let accepted = bool::from_redis_value(v)?;
27        Ok(if accepted {
28            Self::Accepted
29        } else {
30            Self::Rejected
31        })
32    }
33}
34
35/// Result returned from [`CardinalityScript`].
36#[derive(Debug)]
37pub struct CardinalityScriptResult {
38    /// Cardinality of the limit.
39    pub cardinality: u32,
40    /// Status for each hash passed to the script.
41    pub statuses: Vec<Status>,
42}
43
44impl CardinalityScriptResult {
45    /// Validates the result against the amount of hashes originally supplied.
46    ///
47    /// This is not necessarily required but recommended.
48    pub fn validate(&self, num_hashes: usize) -> Result<()> {
49        if num_hashes == self.statuses.len() {
50            return Ok(());
51        }
52
53        Err(relay_redis::RedisError::Redis(redis::RedisError::from((
54            redis::ErrorKind::ResponseError,
55            "Script returned an invalid number of elements",
56            format!("Expected {num_hashes} results, got {}", self.statuses.len()),
57        )))
58        .into())
59    }
60}
61
62impl FromRedisValue for CardinalityScriptResult {
63    fn from_redis_value(v: &redis::Value) -> redis::RedisResult<Self> {
64        let Some(seq) = v.as_sequence() else {
65            return Err(redis::RedisError::from((
66                redis::ErrorKind::TypeError,
67                "Expected a sequence from the cardinality script",
68                format!("{v:?}"),
69            )));
70        };
71
72        let mut iter = seq.iter();
73
74        let cardinality = iter
75            .next()
76            .ok_or_else(|| {
77                redis::RedisError::from((
78                    redis::ErrorKind::TypeError,
79                    "Expected cardinality as the first result from the cardinality script",
80                ))
81            })
82            .and_then(FromRedisValue::from_redis_value)?;
83
84        let mut statuses = Vec::with_capacity(iter.len());
85        for value in iter {
86            statuses.push(Status::from_redis_value(value)?);
87        }
88
89        Ok(Self {
90            cardinality,
91            statuses,
92        })
93    }
94}
95
96/// Abstraction over the `cardinality.lua` lua Redis script.
97pub struct CardinalityScript(&'static Script);
98
99impl CardinalityScript {
100    /// Loads the script.
101    ///
102    /// This is somewhat costly and shouldn't be done often.
103    pub fn load() -> Self {
104        Self(RedisScripts::load_cardinality())
105    }
106
107    /// Creates a new pipeline to batch multiple script invocations.
108    pub fn pipe(&self) -> CardinalityScriptPipeline<'_> {
109        CardinalityScriptPipeline {
110            script: self,
111            pipe: redis::pipe(),
112        }
113    }
114
115    /// Makes sure the script is loaded in Redis.
116    async fn load_redis(&self, con: &mut AsyncRedisConnection) -> Result<()> {
117        self.0
118            .prepare_invoke()
119            .load_async(con)
120            .await
121            .map_err(relay_redis::RedisError::Redis)?;
122
123        Ok(())
124    }
125
126    /// Returns a [`redis::ScriptInvocation`] with all keys and arguments prepared.
127    fn prepare_invocation(
128        &self,
129        limit: u32,
130        expire: u64,
131        hashes: impl Iterator<Item = u32>,
132        keys: impl Iterator<Item = String>,
133    ) -> redis::ScriptInvocation<'_> {
134        let mut invocation = self.0.prepare_invoke();
135
136        for key in keys {
137            invocation.key(key);
138        }
139
140        invocation.arg(limit);
141        invocation.arg(expire);
142
143        for hash in hashes {
144            invocation.arg(&hash.to_le_bytes());
145        }
146
147        invocation
148    }
149}
150
151/// Pipeline to batch multiple [`CardinalityScript`] invocations.
152pub struct CardinalityScriptPipeline<'a> {
153    script: &'a CardinalityScript,
154    pipe: redis::Pipeline,
155}
156
157impl CardinalityScriptPipeline<'_> {
158    /// Adds another invocation of the script to the pipeline.
159    pub fn add_invocation(
160        &mut self,
161        limit: u32,
162        expire: u64,
163        hashes: impl Iterator<Item = u32>,
164        keys: impl Iterator<Item = String>,
165    ) -> &mut Self {
166        let invocation = self.script.prepare_invocation(limit, expire, hashes, keys);
167        self.pipe.invoke_script(&invocation);
168        self
169    }
170
171    /// Invokes the entire pipeline and returns the results.
172    ///
173    /// Returns one result for each script invocation.
174    pub async fn invoke(
175        &self,
176        con: &mut AsyncRedisConnection,
177    ) -> Result<Vec<CardinalityScriptResult>> {
178        match self.pipe.query_async(con).await {
179            Ok(result) => Ok(result),
180            Err(err) if err.kind() == redis::ErrorKind::NoScriptError => {
181                relay_log::trace!("Redis script no loaded, loading it now");
182                self.script.load_redis(con).await?;
183                self.pipe
184                    .query_async(con)
185                    .await
186                    .map_err(relay_redis::RedisError::Redis)
187                    .map_err(Into::into)
188            }
189            Err(err) => Err(relay_redis::RedisError::Redis(err).into()),
190        }
191    }
192}
193
#[cfg(test)]
mod tests {
    use relay_redis::{AsyncRedisClient, RedisConfigOptions};
    use uuid::Uuid;

    use super::*;

    impl CardinalityScript {
        /// Runs one invocation of the script through a single-entry pipeline.
        async fn invoke_one(
            &self,
            con: &mut AsyncRedisConnection,
            limit: u32,
            expire: u64,
            hashes: impl Iterator<Item = u32>,
            keys: impl Iterator<Item = String>,
        ) -> Result<CardinalityScriptResult> {
            let mut results = self
                .pipe()
                .add_invocation(limit, expire, hashes, keys)
                .invoke(con)
                .await?;

            assert_eq!(results.len(), 1);
            Ok(results.pop().unwrap())
        }
    }

    /// Builds a single-connection client for `RELAY_REDIS_URL`, defaulting to a local Redis.
    fn build_redis_client() -> AsyncRedisClient {
        let url = std::env::var("RELAY_REDIS_URL")
            .unwrap_or_else(|_| "redis://127.0.0.1:6379".to_owned());

        let opts = RedisConfigOptions {
            max_connections: 1,
            ..Default::default()
        };
        AsyncRedisClient::single(&url, &opts).unwrap()
    }

    /// Namespaces `keys` with `prefix` so tests do not share Redis keys.
    fn keys(prefix: Uuid, keys: &[&str]) -> impl Iterator<Item = String> {
        keys.iter()
            .map(move |key| format!("{prefix}-{key}"))
            .collect::<Vec<_>>()
            .into_iter()
    }

    /// Asserts that every key starting with `prefix` has an expiry set.
    async fn assert_ttls(connection: &mut AsyncRedisConnection, prefix: Uuid) {
        let keys = redis::cmd("KEYS")
            .arg(format!("{prefix}-*"))
            .query_async::<Vec<String>>(connection)
            .await
            .unwrap();

        for key in keys {
            let ttl = redis::cmd("TTL")
                .arg(&key)
                .query_async::<i64>(connection)
                .await
                .unwrap();

            assert!(ttl >= 0, "Key {key} has no TTL");
        }
    }

    #[tokio::test]
    async fn test_below_limit_perfect_cardinality_ttl() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();

        let script = CardinalityScript::load();

        let prefix = Uuid::new_v4();
        let k1 = &["a", "b", "c"];
        let k2 = &["b", "c", "d"];

        let first = script
            .invoke_one(&mut connection, 50, 3600, 0..30, keys(prefix, k1))
            .await
            .unwrap();
        // One status per supplied hash.
        first.validate(30).unwrap();

        let second = script
            .invoke_one(&mut connection, 50, 3600, 0..30, keys(prefix, k2))
            .await
            .unwrap();
        second.validate(30).unwrap();

        assert_ttls(&mut connection, prefix).await;
    }

    #[tokio::test]
    async fn test_load_script() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();

        let script = CardinalityScript::load();
        let keys = keys(Uuid::new_v4(), &["a", "b", "c"]);

        // Drop all cached scripts so the invocation must take the reload path.
        redis::cmd("SCRIPT")
            .arg("FLUSH")
            .exec_async(&mut connection)
            .await
            .unwrap();
        let result = script
            .invoke_one(&mut connection, 50, 3600, 0..30, keys)
            .await
            .unwrap();
        result.validate(30).unwrap();
    }

    #[tokio::test]
    async fn test_multiple_calls_in_pipeline() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();

        let script = CardinalityScript::load();
        let k1 = keys(Uuid::new_v4(), &["a", "b", "c"]);
        let k2 = keys(Uuid::new_v4(), &["a", "b", "c"]);

        let mut pipeline = script.pipe();
        let results = pipeline
            .add_invocation(50, 3600, 0..30, k1)
            .add_invocation(50, 3600, 0..30, k2)
            .invoke(&mut connection)
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        for result in &results {
            result.validate(30).unwrap();
        }
    }
}