relay_pattern/
typed.rs

1use core::fmt;
2use std::fmt::Debug;
3use std::marker::PhantomData;
4use std::ops::Deref;
5
6use crate::{Error, Pattern, Patterns, PatternsBuilderConfigured};
7
/// Compile time configuration for a [`TypedPattern`].
///
/// Every option has a default, so implementors only override what they need.
pub trait PatternConfig {
    /// Configures the pattern to match case insensitive.
    ///
    /// Defaults to `false`, i.e. case sensitive matching.
    const CASE_INSENSITIVE: bool = false;
    /// Configures the maximum allowed complexity of the pattern.
    ///
    /// Defaults to [`u64::MAX`].
    const MAX_COMPLEXITY: u64 = u64::MAX;
}

/// The default pattern.
///
/// Equivalent to [`Pattern::new`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DefaultPatternConfig;

impl PatternConfig for DefaultPatternConfig {}

/// The default pattern but with case insensitive matching.
///
/// See: [`crate::PatternBuilder::case_insensitive`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CaseInsensitive;

impl PatternConfig for CaseInsensitive {
    const CASE_INSENSITIVE: bool = true;
}
31
32/// A [`Pattern`] with compile time encoded [`PatternConfig`].
33///
34/// Encoding the pattern configuration allows context dependent serialization
35/// and usage of patterns and ensures a consistent usage of configuration options
36/// throught the code.
37///
38/// Often repeated configuration can be grouped into custom and importable configurations.
39///
40/// ```
41/// struct MetricConfig;
42///
43/// impl relay_pattern::PatternConfig for MetricConfig {
44///     const CASE_INSENSITIVE: bool = false;
45///     // More configuration ...
46/// }
47///
48/// type MetricPattern = relay_pattern::TypedPattern<MetricConfig>;
49///
50/// let pattern = MetricPattern::new("[cd]:foo/bar").unwrap();
51/// assert!(pattern.is_match("c:foo/bar"));
52/// ```
53#[derive(Debug)]
54pub struct TypedPattern<C = DefaultPatternConfig> {
55    pattern: Pattern,
56    _phantom: PhantomData<C>,
57}
58
59impl<C: PatternConfig> TypedPattern<C> {
60    /// Creates a new [`TypedPattern`] using the provided pattern and config `C`.
61    ///
62    /// ```
63    /// use relay_pattern::{Pattern, TypedPattern, CaseInsensitive};
64    ///
65    /// let pattern = TypedPattern::<CaseInsensitive>::new("foo*").unwrap();
66    /// assert!(pattern.is_match("FOOBAR"));
67    ///
68    /// // Equivalent to:
69    /// let pattern = Pattern::builder("foo*").case_insensitive(true).build().unwrap();
70    /// assert!(pattern.is_match("FOOBAR"));
71    /// ```
72    pub fn new(pattern: &str) -> Result<Self, Error> {
73        Pattern::builder(pattern)
74            .case_insensitive(C::CASE_INSENSITIVE)
75            .max_complexity(C::MAX_COMPLEXITY)
76            .build()
77            .map(|pattern| Self {
78                pattern,
79                _phantom: PhantomData,
80            })
81    }
82}
83
/// Deserializes a [`TypedPattern`] from a string and compiles it with config `C`.
///
/// A compilation failure is reported as a custom deserialization error.
#[cfg(feature = "serde")]
impl<'de, C: PatternConfig> serde::Deserialize<'de> for TypedPattern<C> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let raw = <String as serde::Deserialize<'de>>::deserialize(deserializer)?;
        match Self::new(raw.as_str()) {
            Ok(pattern) => Ok(pattern),
            Err(err) => Err(<D::Error as serde::de::Error>::custom(err)),
        }
    }
}
94
/// Serializes a [`TypedPattern`] as a string, using the `Display` output of the
/// wrapped [`Pattern`]. The config `C` is not part of the output.
#[cfg(feature = "serde")]
impl<C> serde::Serialize for TypedPattern<C> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let inner: &Pattern = &self.pattern;
        serializer.collect_str(inner)
    }
}
104
105impl<C> From<TypedPattern<C>> for Pattern {
106    fn from(value: TypedPattern<C>) -> Self {
107        value.pattern
108    }
109}
110
111impl<C> AsRef<Pattern> for TypedPattern<C> {
112    fn as_ref(&self) -> &Pattern {
113        &self.pattern
114    }
115}
116
117impl<C> Deref for TypedPattern<C> {
118    type Target = Pattern;
119
120    fn deref(&self) -> &Self::Target {
121        &self.pattern
122    }
123}
124
125/// [`Patterns`] with a compile time configured [`PatternConfig`].
126pub struct TypedPatterns<C = DefaultPatternConfig> {
127    patterns: Patterns,
128    raw: Vec<String>,
129    _phantom: PhantomData<C>,
130}
131
132impl<C: PatternConfig> TypedPatterns<C> {
133    pub fn builder() -> TypedPatternsBuilder<C> {
134        let builder = Patterns::builder()
135            .case_insensitive(C::CASE_INSENSITIVE)
136            .patterns();
137
138        TypedPatternsBuilder {
139            builder,
140            raw: Vec::new(),
141            _phantom: PhantomData,
142        }
143    }
144}
145
146impl<C: PatternConfig> Default for TypedPatterns<C> {
147    fn default() -> Self {
148        Self::builder().build()
149    }
150}
151
152impl<C> PartialEq for TypedPatterns<C> {
153    fn eq(&self, other: &Self) -> bool {
154        self.raw.eq(&other.raw)
155    }
156}
157
158impl<C: PatternConfig> From<String> for TypedPatterns<C> {
159    fn from(value: String) -> Self {
160        [value].into_iter().collect()
161    }
162}
163
164impl<C: PatternConfig> From<Vec<String>> for TypedPatterns<C> {
165    fn from(value: Vec<String>) -> Self {
166        value.into_iter().collect()
167    }
168}
169
170impl<C: PatternConfig, const N: usize> From<[String; N]> for TypedPatterns<C> {
171    fn from(value: [String; N]) -> Self {
172        value.into_iter().collect()
173    }
174}
175
176/// Creates [`Patterns`] from an iterator of strings.
177///
178/// Invalid patterns are ignored.
179impl<C: PatternConfig> FromIterator<String> for TypedPatterns<C> {
180    fn from_iter<T: IntoIterator<Item = String>>(iter: T) -> Self {
181        let mut builder = Self::builder();
182        for pattern in iter.into_iter() {
183            let _err = builder.add(pattern);
184            #[cfg(debug_assertions)]
185            _err.expect("all patterns should be valid patterns");
186        }
187        builder.build()
188    }
189}
190
191impl<C> From<TypedPatterns<C>> for Patterns {
192    fn from(value: TypedPatterns<C>) -> Self {
193        value.patterns
194    }
195}
196
197impl<C> AsRef<Patterns> for TypedPatterns<C> {
198    fn as_ref(&self) -> &Patterns {
199        &self.patterns
200    }
201}
202
203impl<C> Deref for TypedPatterns<C> {
204    type Target = Patterns;
205
206    fn deref(&self) -> &Self::Target {
207        &self.patterns
208    }
209}
210
211impl<C> Clone for TypedPatterns<C> {
212    fn clone(&self) -> Self {
213        Self {
214            patterns: self.patterns.clone(),
215            raw: self.raw.clone(),
216            _phantom: PhantomData,
217        }
218    }
219}
220
221impl<C> fmt::Debug for TypedPatterns<C> {
222    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223        self.raw.fmt(f)
224    }
225}
226
/// Deserializes [`TypedPatterns`] from a sequence of pattern strings.
///
/// Invalid patterns are skipped in builds without `debug_assertions`.
/// With `debug_assertions` enabled, an invalid pattern causes a panic instead.
#[cfg(feature = "serde")]
impl<'de, C: PatternConfig> serde::Deserialize<'de> for TypedPatterns<C> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Feeds every sequence element into a builder configured by `C`.
        struct PatternsVisitor<C>(PhantomData<C>);

        impl<'v, C: PatternConfig> serde::de::Visitor<'v> for PatternsVisitor<C> {
            type Value = TypedPatterns<C>;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str("a sequence of patterns")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'v>,
            {
                let mut collected = TypedPatterns::<C>::builder();

                while let Some(raw) = seq.next_element::<String>()? {
                    // Best effort: an invalid pattern is dropped (release builds only).
                    let _err = collected.add(raw);
                    #[cfg(debug_assertions)]
                    _err.expect("all patterns should be valid patterns");
                }

                Ok(collected.build())
            }
        }

        let visitor = PatternsVisitor::<C>(PhantomData);
        deserializer.deserialize_seq(visitor)
    }
}
265
/// Serializes [`TypedPatterns`] as the sequence of raw pattern strings.
#[cfg(feature = "serde")]
impl<C> serde::Serialize for TypedPatterns<C> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.collect_seq(&self.raw)
    }
}
275
276pub struct TypedPatternsBuilder<C> {
277    builder: PatternsBuilderConfigured,
278    raw: Vec<String>,
279    _phantom: PhantomData<C>,
280}
281
282impl<C: PatternConfig> TypedPatternsBuilder<C> {
283    /// Adds a pattern to the builder.
284    pub fn add(&mut self, pattern: String) -> Result<&mut Self, Error> {
285        self.builder.add(&pattern)?;
286        self.raw.push(pattern);
287        Ok(self)
288    }
289
290    /// Builds a [`TypedPatterns`] from the contained patterns.
291    pub fn build(self) -> TypedPatterns<C> {
292        TypedPatterns {
293            patterns: self.builder.build(),
294            raw: self.raw,
295            _phantom: PhantomData,
296        }
297    }
298
299    /// Builds a [`TypedPatterns`] from the contained patterns and clears the builder.
300    pub fn take(&mut self) -> TypedPatterns<C> {
301        TypedPatterns {
302            patterns: self.builder.take(),
303            raw: std::mem::take(&mut self.raw),
304            _phantom: PhantomData,
305        }
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds every entry of `items` to a fresh builder and extracts the result via `take`.
    fn collect_patterns<C: PatternConfig>(items: &[&str]) -> TypedPatterns<C> {
        let mut builder = TypedPatterns::<C>::builder();
        for item in items {
            builder.add((*item).to_owned()).unwrap();
        }
        builder.take()
    }

    #[test]
    fn test_default() {
        let pattern = TypedPattern::<DefaultPatternConfig>::new("*[rt]x").unwrap();
        for input in ["f/o_rx", "f/o_tx", "F/o_tx"] {
            assert!(pattern.is_match(input), "{}", input);
        }
        // Case sensitive: a different case in the literal suffix must not match.
        for input in ["f/o_Tx", "f/o_rX"] {
            assert!(!pattern.is_match(input), "{}", input);
        }
    }

    #[test]
    fn test_case_insensitive() {
        let pattern = TypedPattern::<CaseInsensitive>::new("*[rt]x").unwrap();
        // Case insensitive: mixed case in the literal suffix still matches.
        for input in ["f/o_Tx", "f/o_rX"] {
            assert!(pattern.is_match(input), "{}", input);
        }
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_deserialize() {
        let pattern =
            serde_json::from_str::<TypedPattern<CaseInsensitive>>(r#""*[rt]x""#).unwrap();
        assert!(pattern.is_match("foobar_rx"));
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_deserialize_err() {
        let result = serde_json::from_str::<TypedPattern<CaseInsensitive>>(r#""[invalid""#);
        assert!(result.is_err());
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_deserialize_complexity() {
        struct Test;
        impl PatternConfig for Test {
            const MAX_COMPLEXITY: u64 = 2;
        }

        let within_limit = serde_json::from_str::<TypedPattern<Test>>(r#""{foo,bar}""#);
        assert!(within_limit.is_ok());

        let over_limit = serde_json::from_str::<TypedPattern<Test>>(r#""{foo,bar,baz}""#);
        assert!(over_limit.is_err());
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_serialize() {
        let sensitive = TypedPattern::<DefaultPatternConfig>::new("*[rt]x").unwrap();
        let insensitive = TypedPattern::<CaseInsensitive>::new("*[rt]x").unwrap();
        // The configuration is not part of the serialized representation.
        assert_eq!(serde_json::to_string(&sensitive).unwrap(), r#""*[rt]x""#);
        assert_eq!(serde_json::to_string(&insensitive).unwrap(), r#""*[rt]x""#);
    }

    #[test]
    fn test_patterns_default() {
        let patterns: TypedPatterns = collect_patterns(&["*[rt]x", "foobar"]);
        assert!(patterns.is_match("f/o_rx"));
        assert!(patterns.is_match("foobar"));
        assert!(!patterns.is_match("Foobar"));
    }

    #[test]
    fn test_patterns_case_insensitive() {
        let patterns: TypedPatterns<CaseInsensitive> =
            collect_patterns(&["*[rt]x", "foobar"]);
        for input in ["f/o_rx", "f/o_Rx", "foobar", "Foobar"] {
            assert!(patterns.is_match(input), "{}", input);
        }
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_patterns_deserialize() {
        let patterns: TypedPatterns<CaseInsensitive> =
            serde_json::from_str(r#"["*[rt]x","foobar"]"#).unwrap();
        assert!(patterns.is_match("foobar_rx"));
        assert!(patterns.is_match("FOOBAR"));
    }

    #[test]
    #[cfg(all(feature = "serde", not(debug_assertions)))]
    fn test_patterns_deserialize_err() {
        let patterns: TypedPatterns<CaseInsensitive> =
            serde_json::from_str(r#"["[invalid","foobar"]"#).unwrap();
        assert!(patterns.is_match("foobar"));
        assert!(patterns.is_match("FOOBAR"));

        // The invalid element is dropped from the serialized output as well.
        assert_eq!(serde_json::to_string(&patterns).unwrap(), r#"["foobar"]"#);
    }

    #[test]
    #[cfg(feature = "serde")]
    fn test_patterns_serialize() {
        let expected = r#"["*[rt]x","foobar"]"#;

        let sensitive: TypedPatterns = collect_patterns(&["*[rt]x", "foobar"]);
        assert_eq!(serde_json::to_string(&sensitive).unwrap(), expected);

        let insensitive: TypedPatterns<CaseInsensitive> =
            collect_patterns(&["*[rt]x", "foobar"]);
        assert_eq!(serde_json::to_string(&insensitive).unwrap(), expected);
    }
}