1use crate::error::TLSError;
2use crate::key;
3use crate::server;
4use crate::server::ClientHello;
5use crate::sign;
6use webpki;
7
8use std::collections;
9use std::sync::{Arc, Mutex};
10
11pub struct NoServerSessionStorage {}
13
14impl server::StoresServerSessions for NoServerSessionStorage {
15 fn put(&self, _id: Vec<u8>, _sec: Vec<u8>) -> bool {
16 false
17 }
18 fn get(&self, _id: &[u8]) -> Option<Vec<u8>> {
19 None
20 }
21 fn take(&self, _id: &[u8]) -> Option<Vec<u8>> {
22 None
23 }
24}
25
26pub struct ServerSessionMemoryCache {
30 cache: Mutex<collections::HashMap<Vec<u8>, Vec<u8>>>,
31 max_entries: usize,
32}
33
34impl ServerSessionMemoryCache {
35 pub fn new(size: usize) -> Arc<ServerSessionMemoryCache> {
38 debug_assert!(size > 0);
39 Arc::new(ServerSessionMemoryCache {
40 cache: Mutex::new(collections::HashMap::new()),
41 max_entries: size,
42 })
43 }
44
45 fn limit_size(&self) {
46 let mut cache = self.cache.lock().unwrap();
47 while cache.len() > self.max_entries {
48 let k = cache.keys().next().unwrap().clone();
49 cache.remove(&k);
50 }
51 }
52}
53
54impl server::StoresServerSessions for ServerSessionMemoryCache {
55 fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
56 self.cache
57 .lock()
58 .unwrap()
59 .insert(key, value);
60 self.limit_size();
61 true
62 }
63
64 fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
65 self.cache
66 .lock()
67 .unwrap()
68 .get(key)
69 .cloned()
70 }
71
72 fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
73 self.cache.lock().unwrap().remove(key)
74 }
75}
76
77pub struct NeverProducesTickets {}
79
80impl server::ProducesTickets for NeverProducesTickets {
81 fn enabled(&self) -> bool {
82 false
83 }
84 fn get_lifetime(&self) -> u32 {
85 0
86 }
87 fn encrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
88 None
89 }
90 fn decrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
91 None
92 }
93}
94
95pub struct FailResolveChain {}
97
98impl server::ResolvesServerCert for FailResolveChain {
99 fn resolve(&self, _client_hello: ClientHello) -> Option<sign::CertifiedKey> {
100 None
101 }
102}
103
104pub struct AlwaysResolvesChain(sign::CertifiedKey);
106
107impl AlwaysResolvesChain {
108 pub fn new(
111 chain: Vec<key::Certificate>,
112 priv_key: &key::PrivateKey,
113 ) -> Result<AlwaysResolvesChain, TLSError> {
114 let key = sign::any_supported_type(priv_key)
115 .map_err(|_| TLSError::General("invalid private key".into()))?;
116 Ok(AlwaysResolvesChain(sign::CertifiedKey::new(
117 chain,
118 Arc::new(key),
119 )))
120 }
121
122 pub fn new_with_extras(
127 chain: Vec<key::Certificate>,
128 priv_key: &key::PrivateKey,
129 ocsp: Vec<u8>,
130 scts: Vec<u8>,
131 ) -> Result<AlwaysResolvesChain, TLSError> {
132 let mut r = AlwaysResolvesChain::new(chain, priv_key)?;
133 if !ocsp.is_empty() {
134 r.0.ocsp = Some(ocsp);
135 }
136 if !scts.is_empty() {
137 r.0.sct_list = Some(scts);
138 }
139 Ok(r)
140 }
141}
142
143impl server::ResolvesServerCert for AlwaysResolvesChain {
144 fn resolve(&self, _client_hello: ClientHello) -> Option<sign::CertifiedKey> {
145 Some(self.0.clone())
146 }
147}
148
149pub struct ResolvesServerCertUsingSNI {
152 by_name: collections::HashMap<String, sign::CertifiedKey>,
153}
154
155impl ResolvesServerCertUsingSNI {
156 pub fn new() -> ResolvesServerCertUsingSNI {
158 ResolvesServerCertUsingSNI {
159 by_name: collections::HashMap::new(),
160 }
161 }
162
163 pub fn add(&mut self, name: &str, ck: sign::CertifiedKey) -> Result<(), TLSError> {
169 let checked_name = webpki::DNSNameRef::try_from_ascii_str(name)
170 .map_err(|_| TLSError::General("Bad DNS name".into()))?;
171
172 ck.cross_check_end_entity_cert(Some(checked_name))?;
173 self.by_name.insert(name.into(), ck);
174 Ok(())
175 }
176}
177
178impl server::ResolvesServerCert for ResolvesServerCertUsingSNI {
179 fn resolve(&self, client_hello: ClientHello) -> Option<sign::CertifiedKey> {
180 if let Some(name) = client_hello.server_name() {
181 self.by_name.get(name.into()).cloned()
182 } else {
183 None
185 }
186 }
187}
188
#[cfg(test)]
mod test {
    use super::*;
    use crate::server::ProducesTickets;
    use crate::server::ResolvesServerCert;
    use crate::StoresServerSessions;

    #[test]
    fn test_noserversessionstorage_drops_put() {
        let c = NoServerSessionStorage {};
        assert!(!c.put(vec![0x01], vec![0x02]));
    }

    #[test]
    fn test_noserversessionstorage_denies_gets() {
        let c = NoServerSessionStorage {};
        c.put(vec![0x01], vec![0x02]);
        assert_eq!(c.get(&[]), None);
        assert_eq!(c.get(&[0x01]), None);
        assert_eq!(c.get(&[0x02]), None);
    }

    #[test]
    fn test_noserversessionstorage_denies_takes() {
        let c = NoServerSessionStorage {};
        assert_eq!(c.take(&[]), None);
        assert_eq!(c.take(&[0x01]), None);
        assert_eq!(c.take(&[0x02]), None);
    }

    #[test]
    fn test_serversessionmemorycache_accepts_put() {
        let c = ServerSessionMemoryCache::new(4);
        assert!(c.put(vec![0x01], vec![0x02]));
    }

    #[test]
    fn test_serversessionmemorycache_persists_put() {
        let c = ServerSessionMemoryCache::new(4);
        assert!(c.put(vec![0x01], vec![0x02]));
        // `get` must not consume the entry, so a second lookup still finds it.
        assert_eq!(c.get(&[0x01]), Some(vec![0x02]));
        assert_eq!(c.get(&[0x01]), Some(vec![0x02]));
    }

    #[test]
    fn test_serversessionmemorycache_overwrites_put() {
        let c = ServerSessionMemoryCache::new(4);
        assert!(c.put(vec![0x01], vec![0x02]));
        assert!(c.put(vec![0x01], vec![0x04]));
        assert_eq!(c.get(&[0x01]), Some(vec![0x04]));
    }

    #[test]
    fn test_serversessionmemorycache_drops_to_maintain_size_invariant() {
        let c = ServerSessionMemoryCache::new(4);
        // Five distinct keys into a four-entry cache: exactly one must be evicted.
        let keys = [0x01u8, 0x03, 0x05, 0x07, 0x09];
        for &k in keys.iter() {
            assert!(c.put(vec![k], vec![k + 1]));
        }

        let count = keys
            .iter()
            .filter(|&&k| c.get(&[k]).is_some())
            .count();
        assert_eq!(count, 4);
    }

    #[test]
    fn test_neverproducestickets_does_nothing() {
        let npt = NeverProducesTickets {};
        assert!(!npt.enabled());
        assert_eq!(npt.get_lifetime(), 0);
        assert_eq!(npt.encrypt(&[]), None);
        assert_eq!(npt.decrypt(&[]), None);
    }

    #[test]
    fn test_failresolvechain_does_nothing() {
        let frc = FailResolveChain {};
        assert!(frc
            .resolve(ClientHello::new(None, &[], None))
            .is_none());
    }

    #[test]
    fn test_resolvesservercertusingsni_requires_sni() {
        let rscsni = ResolvesServerCertUsingSNI::new();
        assert!(rscsni
            .resolve(ClientHello::new(None, &[], None))
            .is_none());
    }

    #[test]
    fn test_resolvesservercertusingsni_handles_unknown_name() {
        let rscsni = ResolvesServerCertUsingSNI::new();
        let name = webpki::DNSNameRef::try_from_ascii_str("hello.com").unwrap();
        assert!(rscsni
            .resolve(ClientHello::new(Some(name), &[], None))
            .is_none());
    }
}