rustls/server/
handy.rs

1use crate::error::TLSError;
2use crate::key;
3use crate::server;
4use crate::server::ClientHello;
5use crate::sign;
6use webpki;
7
8use std::collections;
9use std::sync::{Arc, Mutex};
10
11/// Something which never stores sessions.
12pub struct NoServerSessionStorage {}
13
14impl server::StoresServerSessions for NoServerSessionStorage {
15    fn put(&self, _id: Vec<u8>, _sec: Vec<u8>) -> bool {
16        false
17    }
18    fn get(&self, _id: &[u8]) -> Option<Vec<u8>> {
19        None
20    }
21    fn take(&self, _id: &[u8]) -> Option<Vec<u8>> {
22        None
23    }
24}
25
26/// An implementor of `StoresServerSessions` that stores everything
27/// in memory.  If enforces a limit on the number of stored sessions
28/// to bound memory usage.
29pub struct ServerSessionMemoryCache {
30    cache: Mutex<collections::HashMap<Vec<u8>, Vec<u8>>>,
31    max_entries: usize,
32}
33
34impl ServerSessionMemoryCache {
35    /// Make a new ServerSessionMemoryCache.  `size` is the maximum
36    /// number of stored sessions.
37    pub fn new(size: usize) -> Arc<ServerSessionMemoryCache> {
38        debug_assert!(size > 0);
39        Arc::new(ServerSessionMemoryCache {
40            cache: Mutex::new(collections::HashMap::new()),
41            max_entries: size,
42        })
43    }
44
45    fn limit_size(&self) {
46        let mut cache = self.cache.lock().unwrap();
47        while cache.len() > self.max_entries {
48            let k = cache.keys().next().unwrap().clone();
49            cache.remove(&k);
50        }
51    }
52}
53
54impl server::StoresServerSessions for ServerSessionMemoryCache {
55    fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
56        self.cache
57            .lock()
58            .unwrap()
59            .insert(key, value);
60        self.limit_size();
61        true
62    }
63
64    fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
65        self.cache
66            .lock()
67            .unwrap()
68            .get(key)
69            .cloned()
70    }
71
72    fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
73        self.cache.lock().unwrap().remove(key)
74    }
75}
76
77/// Something which never produces tickets.
78pub struct NeverProducesTickets {}
79
80impl server::ProducesTickets for NeverProducesTickets {
81    fn enabled(&self) -> bool {
82        false
83    }
84    fn get_lifetime(&self) -> u32 {
85        0
86    }
87    fn encrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
88        None
89    }
90    fn decrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
91        None
92    }
93}
94
95/// Something which never resolves a certificate.
96pub struct FailResolveChain {}
97
98impl server::ResolvesServerCert for FailResolveChain {
99    fn resolve(&self, _client_hello: ClientHello) -> Option<sign::CertifiedKey> {
100        None
101    }
102}
103
104/// Something which always resolves to the same cert chain.
105pub struct AlwaysResolvesChain(sign::CertifiedKey);
106
107impl AlwaysResolvesChain {
108    /// Creates an `AlwaysResolvesChain`, auto-detecting the underlying private
109    /// key type and encoding.
110    pub fn new(
111        chain: Vec<key::Certificate>,
112        priv_key: &key::PrivateKey,
113    ) -> Result<AlwaysResolvesChain, TLSError> {
114        let key = sign::any_supported_type(priv_key)
115            .map_err(|_| TLSError::General("invalid private key".into()))?;
116        Ok(AlwaysResolvesChain(sign::CertifiedKey::new(
117            chain,
118            Arc::new(key),
119        )))
120    }
121
122    /// Creates an `AlwaysResolvesChain`, auto-detecting the underlying private
123    /// key type and encoding.
124    ///
125    /// If non-empty, the given OCSP response and SCTs are attached.
126    pub fn new_with_extras(
127        chain: Vec<key::Certificate>,
128        priv_key: &key::PrivateKey,
129        ocsp: Vec<u8>,
130        scts: Vec<u8>,
131    ) -> Result<AlwaysResolvesChain, TLSError> {
132        let mut r = AlwaysResolvesChain::new(chain, priv_key)?;
133        if !ocsp.is_empty() {
134            r.0.ocsp = Some(ocsp);
135        }
136        if !scts.is_empty() {
137            r.0.sct_list = Some(scts);
138        }
139        Ok(r)
140    }
141}
142
143impl server::ResolvesServerCert for AlwaysResolvesChain {
144    fn resolve(&self, _client_hello: ClientHello) -> Option<sign::CertifiedKey> {
145        Some(self.0.clone())
146    }
147}
148
149/// Something that resolves do different cert chains/keys based
150/// on client-supplied server name (via SNI).
151pub struct ResolvesServerCertUsingSNI {
152    by_name: collections::HashMap<String, sign::CertifiedKey>,
153}
154
155impl ResolvesServerCertUsingSNI {
156    /// Create a new and empty (ie, knows no certificates) resolver.
157    pub fn new() -> ResolvesServerCertUsingSNI {
158        ResolvesServerCertUsingSNI {
159            by_name: collections::HashMap::new(),
160        }
161    }
162
163    /// Add a new `sign::CertifiedKey` to be used for the given SNI `name`.
164    ///
165    /// This function fails if `name` is not a valid DNS name, or if
166    /// it's not valid for the supplied certificate, or if the certificate
167    /// chain is syntactically faulty.
168    pub fn add(&mut self, name: &str, ck: sign::CertifiedKey) -> Result<(), TLSError> {
169        let checked_name = webpki::DNSNameRef::try_from_ascii_str(name)
170            .map_err(|_| TLSError::General("Bad DNS name".into()))?;
171
172        ck.cross_check_end_entity_cert(Some(checked_name))?;
173        self.by_name.insert(name.into(), ck);
174        Ok(())
175    }
176}
177
178impl server::ResolvesServerCert for ResolvesServerCertUsingSNI {
179    fn resolve(&self, client_hello: ClientHello) -> Option<sign::CertifiedKey> {
180        if let Some(name) = client_hello.server_name() {
181            self.by_name.get(name.into()).cloned()
182        } else {
183            // This kind of resolver requires SNI
184            None
185        }
186    }
187}
188
#[cfg(test)]
mod test {
    use super::*;
    use crate::server::ProducesTickets;
    use crate::server::ResolvesServerCert;
    use crate::StoresServerSessions;

    /// Keys probed by the "no storage" lookup tests.
    const PROBE_KEYS: [&[u8]; 3] = [&[], &[0x01], &[0x02]];

    #[test]
    fn test_noserversessionstorage_drops_put() {
        let storage = NoServerSessionStorage {};
        assert!(!storage.put(vec![0x01], vec![0x02]));
    }

    #[test]
    fn test_noserversessionstorage_denies_gets() {
        let storage = NoServerSessionStorage {};
        // Even after an attempted store, nothing can be read back.
        storage.put(vec![0x01], vec![0x02]);
        for &key in PROBE_KEYS.iter() {
            assert_eq!(storage.get(key), None);
        }
    }

    #[test]
    fn test_noserversessionstorage_denies_takes() {
        let storage = NoServerSessionStorage {};
        for &key in PROBE_KEYS.iter() {
            assert_eq!(storage.take(key), None);
        }
    }

    #[test]
    fn test_serversessionmemorycache_accepts_put() {
        let cache = ServerSessionMemoryCache::new(4);
        assert!(cache.put(vec![0x01], vec![0x02]));
    }

    #[test]
    fn test_serversessionmemorycache_persists_put() {
        let cache = ServerSessionMemoryCache::new(4);
        assert!(cache.put(vec![0x01], vec![0x02]));
        // Reading twice checks that `get` does not consume the entry.
        for _ in 0..2 {
            assert_eq!(cache.get(&[0x01]), Some(vec![0x02]));
        }
    }

    #[test]
    fn test_serversessionmemorycache_overwrites_put() {
        let cache = ServerSessionMemoryCache::new(4);
        assert!(cache.put(vec![0x01], vec![0x02]));
        assert!(cache.put(vec![0x01], vec![0x04]));
        assert_eq!(cache.get(&[0x01]), Some(vec![0x04]));
    }

    #[test]
    fn test_serversessionmemorycache_drops_to_maintain_size_invariant() {
        let cache = ServerSessionMemoryCache::new(4);
        let entries: [(u8, u8); 5] = [
            (0x01, 0x02),
            (0x03, 0x04),
            (0x05, 0x06),
            (0x07, 0x08),
            (0x09, 0x0a),
        ];
        for &(k, v) in entries.iter() {
            assert!(cache.put(vec![k], vec![v]));
        }

        // Five puts into a four-entry cache must leave exactly four behind.
        let survivors = entries
            .iter()
            .filter(|&&(k, _)| cache.get(&[k]).is_some())
            .count();
        assert_eq!(survivors, 4);
    }

    #[test]
    fn test_neverproducestickets_does_nothing() {
        let npt = NeverProducesTickets {};
        assert!(!npt.enabled());
        assert_eq!(npt.get_lifetime(), 0);
        assert_eq!(npt.encrypt(&[]), None);
        assert_eq!(npt.decrypt(&[]), None);
    }

    #[test]
    fn test_failresolvechain_does_nothing() {
        let frc = FailResolveChain {};
        let hello = ClientHello::new(None, &[], None);
        assert!(frc.resolve(hello).is_none());
    }

    #[test]
    fn test_resolvesservercertusingsni_requires_sni() {
        let rscsni = ResolvesServerCertUsingSNI::new();
        let hello = ClientHello::new(None, &[], None);
        assert!(rscsni.resolve(hello).is_none());
    }

    #[test]
    fn test_resolvesservercertusingsni_handles_unknown_name() {
        let rscsni = ResolvesServerCertUsingSNI::new();
        let name = webpki::DNSNameRef::try_from_ascii_str("hello.com").unwrap();
        let hello = ClientHello::new(Some(name), &[], None);
        assert!(rscsni.resolve(hello).is_none());
    }
}