1use crate::client;
2use crate::error::TLSError;
3use crate::key;
4use crate::msgs::enums::SignatureScheme;
5use crate::sign;
6
7use std::collections;
8use std::sync::{Arc, Mutex};
9
10pub struct NoClientSessionStorage {}
12
13impl client::StoresClientSessions for NoClientSessionStorage {
14 fn put(&self, _key: Vec<u8>, _value: Vec<u8>) -> bool {
15 false
16 }
17
18 fn get(&self, _key: &[u8]) -> Option<Vec<u8>> {
19 None
20 }
21}
22
23pub struct ClientSessionMemoryCache {
27 cache: Mutex<collections::HashMap<Vec<u8>, Vec<u8>>>,
28 max_entries: usize,
29}
30
31impl ClientSessionMemoryCache {
32 pub fn new(size: usize) -> Arc<ClientSessionMemoryCache> {
35 debug_assert!(size > 0);
36 Arc::new(ClientSessionMemoryCache {
37 cache: Mutex::new(collections::HashMap::new()),
38 max_entries: size,
39 })
40 }
41
42 fn limit_size(&self) {
43 let mut cache = self.cache.lock().unwrap();
44 while cache.len() > self.max_entries {
45 let k = cache.keys().next().unwrap().clone();
46 cache.remove(&k);
47 }
48 }
49}
50
51impl client::StoresClientSessions for ClientSessionMemoryCache {
52 fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
53 self.cache
54 .lock()
55 .unwrap()
56 .insert(key, value);
57 self.limit_size();
58 true
59 }
60
61 fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
62 self.cache
63 .lock()
64 .unwrap()
65 .get(key)
66 .cloned()
67 }
68}
69
70pub struct FailResolveClientCert {}
71
72impl client::ResolvesClientCert for FailResolveClientCert {
73 fn resolve(
74 &self,
75 _acceptable_issuers: &[&[u8]],
76 _sigschemes: &[SignatureScheme],
77 ) -> Option<sign::CertifiedKey> {
78 None
79 }
80
81 fn has_certs(&self) -> bool {
82 false
83 }
84}
85
86pub struct AlwaysResolvesClientCert(sign::CertifiedKey);
87
88impl AlwaysResolvesClientCert {
89 pub fn new(
90 chain: Vec<key::Certificate>,
91 priv_key: &key::PrivateKey,
92 ) -> Result<AlwaysResolvesClientCert, TLSError> {
93 let key = sign::any_supported_type(priv_key)
94 .map_err(|_| TLSError::General("invalid private key".into()))?;
95 Ok(AlwaysResolvesClientCert(sign::CertifiedKey::new(
96 chain,
97 Arc::new(key),
98 )))
99 }
100}
101
102impl client::ResolvesClientCert for AlwaysResolvesClientCert {
103 fn resolve(
104 &self,
105 _acceptable_issuers: &[&[u8]],
106 _sigschemes: &[SignatureScheme],
107 ) -> Option<sign::CertifiedKey> {
108 Some(self.0.clone())
109 }
110
111 fn has_certs(&self) -> bool {
112 true
113 }
114}
115
#[cfg(test)]
mod test {
    use super::*;
    use crate::StoresClientSessions;

    #[test]
    fn test_noclientsessionstorage_drops_put() {
        let storage = NoClientSessionStorage {};
        let accepted = storage.put(vec![0x01], vec![0x02]);
        assert!(!accepted);
    }

    #[test]
    fn test_noclientsessionstorage_denies_gets() {
        let storage = NoClientSessionStorage {};
        storage.put(vec![0x01], vec![0x02]);
        // Neither the stored key, the stored value, nor an empty key hits.
        let probes: [&[u8]; 3] = [&[], &[0x01], &[0x02]];
        for probe in probes.iter() {
            assert_eq!(storage.get(*probe), None);
        }
    }

    #[test]
    fn test_clientsessionmemorycache_accepts_put() {
        let cache = ClientSessionMemoryCache::new(4);
        assert!(cache.put(vec![0x01], vec![0x02]));
    }

    #[test]
    fn test_clientsessionmemorycache_persists_put() {
        let cache = ClientSessionMemoryCache::new(4);
        assert!(cache.put(vec![0x01], vec![0x02]));
        // Reading must not consume the entry.
        for _ in 0..2 {
            assert_eq!(cache.get(&[0x01]), Some(vec![0x02]));
        }
    }

    #[test]
    fn test_clientsessionmemorycache_overwrites_put() {
        let cache = ClientSessionMemoryCache::new(4);
        for value in [0x02u8, 0x04].iter() {
            assert!(cache.put(vec![0x01], vec![*value]));
        }
        assert_eq!(cache.get(&[0x01]), Some(vec![0x04]));
    }

    #[test]
    fn test_clientsessionmemorycache_drops_to_maintain_size_invariant() {
        let cache = ClientSessionMemoryCache::new(4);
        let entries: [(u8, u8); 5] = [
            (0x01, 0x02),
            (0x03, 0x04),
            (0x05, 0x06),
            (0x07, 0x08),
            (0x09, 0x0a),
        ];
        for &(k, v) in entries.iter() {
            assert!(cache.put(vec![k], vec![v]));
        }

        // One more entry than capacity was inserted, so exactly one is gone.
        let survivors = entries
            .iter()
            .filter(|&&(k, _)| cache.get(&[k]).is_some())
            .count();
        assert_eq!(survivors, 4);
    }
}