1use crate::rand;
2use crate::server::ProducesTickets;
3
4use ring::aead;
5use std::mem;
6use std::sync::{Arc, Mutex};
7use std::time;
8
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
pub fn timebase() -> u64 {
    let since_epoch = time::SystemTime::now()
        .duration_since(time::UNIX_EPOCH)
        .unwrap();
    since_epoch.as_secs()
}
17
18pub struct AEADTicketer {
23 alg: &'static aead::Algorithm,
24 key: aead::LessSafeKey,
25 lifetime: u32,
26}
27
28impl AEADTicketer {
29 pub fn new_custom(
34 alg: &'static aead::Algorithm,
35 key: &[u8],
36 lifetime_seconds: u32,
37 ) -> AEADTicketer {
38 let key = aead::UnboundKey::new(alg, key).unwrap();
39 AEADTicketer {
40 alg,
41 key: aead::LessSafeKey::new(key),
42 lifetime: lifetime_seconds,
43 }
44 }
45
46 pub fn new() -> AEADTicketer {
48 let mut key = [0u8; 32];
49 rand::fill_random(&mut key);
50 AEADTicketer::new_custom(&aead::CHACHA20_POLY1305, &key, 60 * 60 * 12)
51 }
52}
53
54impl ProducesTickets for AEADTicketer {
55 fn enabled(&self) -> bool {
56 true
57 }
58 fn get_lifetime(&self) -> u32 {
59 self.lifetime
60 }
61
62 fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
64 let mut nonce_buf = [0u8; 12];
66 rand::fill_random(&mut nonce_buf);
67 let nonce = ring::aead::Nonce::assume_unique_for_key(nonce_buf);
68 let aad = ring::aead::Aad::empty();
69
70 let mut ciphertext =
71 Vec::with_capacity(nonce_buf.len() + message.len() + self.key.algorithm().tag_len());
72 ciphertext.extend(&nonce_buf);
73 ciphertext.extend(message);
74 self.key
75 .seal_in_place_separate_tag(nonce, aad, &mut ciphertext[nonce_buf.len()..])
76 .map(|tag| {
77 ciphertext.extend(tag.as_ref());
78 ciphertext
79 })
80 .ok()
81 }
82
83 fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
85 let nonce_len = self.alg.nonce_len();
86 let tag_len = self.alg.tag_len();
87
88 if ciphertext.len() < nonce_len + tag_len {
89 return None;
90 }
91
92 let nonce =
93 ring::aead::Nonce::try_assume_unique_for_key(&ciphertext[0..nonce_len]).unwrap();
94 let aad = ring::aead::Aad::empty();
95
96 let mut out = Vec::new();
97 out.extend_from_slice(&ciphertext[nonce_len..]);
98
99 let plain_len = match self
100 .key
101 .open_in_place(nonce, aad, &mut out)
102 {
103 Ok(plaintext) => plaintext.len(),
104 Err(..) => {
105 return None;
106 }
107 };
108
109 out.truncate(plain_len);
110 Some(out)
111 }
112}
113
114struct TicketSwitcherState {
115 current: Box<dyn ProducesTickets>,
116 previous: Option<Box<dyn ProducesTickets>>,
117 next_switch_time: u64,
118}
119
120pub struct TicketSwitcher {
124 generator: fn() -> Box<dyn ProducesTickets>,
125 lifetime: u32,
126 state: Mutex<TicketSwitcherState>,
127}
128
129impl TicketSwitcher {
130 pub fn new(lifetime: u32, generator: fn() -> Box<dyn ProducesTickets>) -> TicketSwitcher {
135 TicketSwitcher {
136 generator,
137 lifetime,
138 state: Mutex::new(TicketSwitcherState {
139 current: generator(),
140 previous: None,
141 next_switch_time: timebase() + u64::from(lifetime),
142 }),
143 }
144 }
145
146 pub fn maybe_roll(&self) {
153 let mut state = self.state.lock().unwrap();
154 let now = timebase();
155
156 if now > state.next_switch_time {
157 state.previous = Some(mem::replace(&mut state.current, (self.generator)()));
158 state.next_switch_time = now + u64::from(self.lifetime);
159 }
160 }
161}
162
163impl ProducesTickets for TicketSwitcher {
164 fn get_lifetime(&self) -> u32 {
165 self.lifetime * 2
166 }
167
168 fn enabled(&self) -> bool {
169 true
170 }
171
172 fn encrypt(&self, message: &[u8]) -> Option<Vec<u8>> {
173 self.maybe_roll();
174
175 self.state
176 .lock()
177 .unwrap()
178 .current
179 .encrypt(message)
180 }
181
182 fn decrypt(&self, ciphertext: &[u8]) -> Option<Vec<u8>> {
183 self.maybe_roll();
184
185 let state = self.state.lock().unwrap();
186 let rc = state.current.decrypt(ciphertext);
187
188 if rc.is_none() && state.previous.is_some() {
189 state
190 .previous
191 .as_ref()
192 .unwrap()
193 .decrypt(ciphertext)
194 } else {
195 rc
196 }
197 }
198}
199
/// Namespace for `Ticketer::new`, which builds the default rotating ticket
/// producer. The struct itself holds no data.
pub struct Ticketer {}
202
203fn generate_inner() -> Box<dyn ProducesTickets> {
204 Box::new(AEADTicketer::new())
205}
206
207impl Ticketer {
208 pub fn new() -> Arc<dyn ProducesTickets> {
213 Arc::new(TicketSwitcher::new(6 * 60 * 60, generate_inner))
214 }
215}
216
/// Round-trips a ticket through the default ticketer. Also checks that
/// malformed and tampered inputs are rejected rather than accepted or panicking.
#[test]
fn basic_pairwise_test() {
    let t = Ticketer::new();
    assert!(t.enabled());

    let cipher = t.encrypt(b"hello world").unwrap();
    let plain = t.decrypt(&cipher).unwrap();
    assert_eq!(plain, b"hello world");

    // Too short to hold a nonce and tag: rejected before any crypto runs.
    assert!(t.decrypt(&[]).is_none());

    // A single flipped bit must fail AEAD authentication.
    let mut tampered = cipher;
    let last = tampered.len() - 1;
    tampered[last] ^= 0x01;
    assert!(t.decrypt(&tampered).is_none());
}