relay_cardinality/redis/
script.rs

1use relay_redis::{
2    AsyncRedisConnection, RedisScripts,
3    redis::{self, FromRedisValue, Script, ServerErrorKind},
4};
5
6use crate::Result;
7
/// Status indicating whether an entry/bucket is accepted or rejected by the cardinality limiter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Status {
    /// Item is rejected.
    Rejected,
    /// Item is accepted.
    Accepted,
}
16
17impl Status {
18    /// Returns `true` if the status is [`Status::Rejected`].
19    pub fn is_rejected(&self) -> bool {
20        matches!(self, Self::Rejected)
21    }
22}
23
24impl FromRedisValue for Status {
25    fn from_redis_value(v: redis::Value) -> Result<Self, redis::ParsingError> {
26        let accepted = bool::from_redis_value(v)?;
27        Ok(if accepted {
28            Self::Accepted
29        } else {
30            Self::Rejected
31        })
32    }
33}
34
35/// Result returned from [`CardinalityScript`].
36#[derive(Debug)]
37pub struct CardinalityScriptResult {
38    /// Cardinality of the limit.
39    pub cardinality: u32,
40    /// Status for each hash passed to the script.
41    pub statuses: Vec<Status>,
42}
43
44impl CardinalityScriptResult {
45    /// Validates the result against the amount of hashes originally supplied.
46    ///
47    /// This is not necessarily required but recommended.
48    pub fn validate(&self, num_hashes: usize) -> Result<()> {
49        if num_hashes == self.statuses.len() {
50            return Ok(());
51        }
52
53        Err(relay_redis::RedisError::Redis(redis::RedisError::from((
54            redis::ErrorKind::Server(ServerErrorKind::ResponseError),
55            "Script returned an invalid number of elements",
56            format!("Expected {num_hashes} results, got {}", self.statuses.len()),
57        )))
58        .into())
59    }
60}
61
62impl FromRedisValue for CardinalityScriptResult {
63    fn from_redis_value(v: redis::Value) -> Result<Self, redis::ParsingError> {
64        let seq = v.into_sequence().map_err(|v| {
65            format!("Expected a sequence from the cardinality script (value was: {v:?})")
66        })?;
67
68        let mut iter = seq.into_iter();
69
70        let cardinality = iter
71            .next()
72            .ok_or("Expected cardinality as the first result from the cardinality script")
73            .map_err(redis::ParsingError::from)
74            .and_then(FromRedisValue::from_redis_value)?;
75
76        let mut statuses = Vec::with_capacity(iter.len());
77        for value in iter {
78            statuses.push(Status::from_redis_value(value)?);
79        }
80
81        Ok(Self {
82            cardinality,
83            statuses,
84        })
85    }
86}
87
88/// Abstraction over the `cardinality.lua` lua Redis script.
89pub struct CardinalityScript(&'static Script);
90
91impl CardinalityScript {
92    /// Loads the script.
93    ///
94    /// This is somewhat costly and shouldn't be done often.
95    pub fn load() -> Self {
96        Self(RedisScripts::load_cardinality())
97    }
98
99    /// Creates a new pipeline to batch multiple script invocations.
100    pub fn pipe(&self) -> CardinalityScriptPipeline<'_> {
101        CardinalityScriptPipeline {
102            script: self,
103            pipe: redis::pipe(),
104        }
105    }
106
107    /// Makes sure the script is loaded in Redis.
108    async fn load_redis(&self, con: &mut AsyncRedisConnection) -> Result<()> {
109        self.0
110            .prepare_invoke()
111            .load_async(con)
112            .await
113            .map_err(relay_redis::RedisError::Redis)?;
114
115        Ok(())
116    }
117
118    /// Returns a [`redis::ScriptInvocation`] with all keys and arguments prepared.
119    fn prepare_invocation(
120        &self,
121        limit: u32,
122        expire: u64,
123        hashes: impl Iterator<Item = u32>,
124        keys: impl Iterator<Item = String>,
125    ) -> redis::ScriptInvocation<'_> {
126        let mut invocation = self.0.prepare_invoke();
127
128        for key in keys {
129            invocation.key(key);
130        }
131
132        invocation.arg(limit);
133        invocation.arg(expire);
134
135        for hash in hashes {
136            invocation.arg(&hash.to_le_bytes());
137        }
138
139        invocation
140    }
141}
142
143/// Pipeline to batch multiple [`CardinalityScript`] invocations.
144pub struct CardinalityScriptPipeline<'a> {
145    script: &'a CardinalityScript,
146    pipe: redis::Pipeline,
147}
148
149impl CardinalityScriptPipeline<'_> {
150    /// Adds another invocation of the script to the pipeline.
151    pub fn add_invocation(
152        &mut self,
153        limit: u32,
154        expire: u64,
155        hashes: impl Iterator<Item = u32>,
156        keys: impl Iterator<Item = String>,
157    ) -> &mut Self {
158        let invocation = self.script.prepare_invocation(limit, expire, hashes, keys);
159        self.pipe.invoke_script(&invocation);
160        self
161    }
162
163    /// Invokes the entire pipeline and returns the results.
164    ///
165    /// Returns one result for each script invocation.
166    pub async fn invoke(
167        &self,
168        con: &mut AsyncRedisConnection,
169    ) -> Result<Vec<CardinalityScriptResult>> {
170        match self.pipe.query_async(con).await {
171            Ok(result) => Ok(result),
172            Err(err) if err.kind() == redis::ErrorKind::Server(ServerErrorKind::NoScript) => {
173                relay_log::trace!("Redis script no loaded, loading it now");
174                self.script.load_redis(con).await?;
175                self.pipe
176                    .query_async(con)
177                    .await
178                    .map_err(relay_redis::RedisError::Redis)
179                    .map_err(Into::into)
180            }
181            Err(err) => Err(relay_redis::RedisError::Redis(err).into()),
182        }
183    }
184}
185
#[cfg(test)]
mod tests {
    use relay_redis::{AsyncRedisClient, RedisConfigOptions};
    use uuid::Uuid;

    use super::*;

    impl CardinalityScript {
        /// Runs exactly one invocation through a fresh pipeline and returns its result.
        ///
        /// Panics if the pipeline does not return exactly one result.
        async fn invoke_one(
            &self,
            con: &mut AsyncRedisConnection,
            limit: u32,
            expire: u64,
            hashes: impl Iterator<Item = u32>,
            keys: impl Iterator<Item = String>,
        ) -> Result<CardinalityScriptResult> {
            let mut pipeline = self.pipe();
            pipeline.add_invocation(limit, expire, hashes, keys);
            let results = pipeline.invoke(con).await?;

            assert_eq!(results.len(), 1);
            Ok(results.into_iter().next().unwrap())
        }
    }

    /// Creates a single-connection client for the Redis at `RELAY_REDIS_URL`, or a local
    /// default when the variable is unset or unreadable.
    fn build_redis_client() -> AsyncRedisClient {
        let url = match std::env::var("RELAY_REDIS_URL") {
            Ok(url) => url,
            Err(_) => "redis://127.0.0.1:6379".to_owned(),
        };

        let opts = RedisConfigOptions {
            max_connections: 1,
            ..Default::default()
        };
        AsyncRedisClient::single("test", &url, &opts).unwrap()
    }

    /// Returns the names `"{prefix}-{key}"` for every entry in `keys`.
    fn keys(prefix: Uuid, keys: &[&str]) -> impl Iterator<Item = String> {
        let prefixed: Vec<String> = keys.iter().map(|key| format!("{prefix}-{key}")).collect();
        prefixed.into_iter()
    }

    /// Asserts that every key whose name starts with `"{prefix}-"` has a non-negative TTL.
    async fn assert_ttls(connection: &mut AsyncRedisConnection, prefix: Uuid) {
        let pattern = format!("{prefix}-*");
        let found: Vec<String> = redis::cmd("KEYS")
            .arg(pattern)
            .query_async(connection)
            .await
            .unwrap();

        for key in found {
            let ttl: i64 = redis::cmd("TTL")
                .arg(&key)
                .query_async(connection)
                .await
                .unwrap();

            assert!(ttl >= 0, "Key {key} has no TTL");
        }
    }

    #[tokio::test]
    async fn test_below_limit_perfect_cardinality_ttl() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();

        let script = CardinalityScript::load();

        let prefix = Uuid::new_v4();

        // Two overlapping key sets under the same prefix, invoked one after the other.
        for names in [&["a", "b", "c"], &["b", "c", "d"]] {
            script
                .invoke_one(&mut connection, 50, 3600, 0..30, keys(prefix, names))
                .await
                .unwrap();
        }

        assert_ttls(&mut connection, prefix).await;
    }

    #[tokio::test]
    async fn test_load_script() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();

        let script = CardinalityScript::load();
        let script_keys = keys(Uuid::new_v4(), &["a", "b", "c"]);

        // Flush the server's script cache so the invocation below has to go through the
        // reload path in `invoke`.
        redis::cmd("SCRIPT")
            .arg("FLUSH")
            .exec_async(&mut connection)
            .await
            .unwrap();

        script
            .invoke_one(&mut connection, 50, 3600, 0..30, script_keys)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_multiple_calls_in_pipeline() {
        let client = build_redis_client();
        let mut connection = client.get_connection().await.unwrap();

        let script = CardinalityScript::load();
        let second = keys(Uuid::new_v4(), &["a", "b", "c"]);
        let first = keys(Uuid::new_v4(), &["a", "b", "c"]);

        let mut pipeline = script.pipe();
        pipeline.add_invocation(50, 3600, 0..30, first);
        pipeline.add_invocation(50, 3600, 0..30, second);
        let results = pipeline.invoke(&mut connection).await.unwrap();

        assert_eq!(results.len(), 2);
    }
}