1#[cfg(feature = "logging")]
2use crate::log::warn;
3use crate::msgs::codec::Codec;
4use crate::msgs::handshake::HandshakeMessagePayload;
5use crate::msgs::message::{Message, MessagePayload};
6use ring::digest;
7use std::mem;
8
9pub struct HandshakeHash {
17 alg: Option<&'static digest::Algorithm>,
19
20 ctx: Option<digest::Context>,
22
23 client_auth_enabled: bool,
25
26 buffer: Vec<u8>,
28}
29
30impl HandshakeHash {
31 pub fn new() -> HandshakeHash {
32 HandshakeHash {
33 alg: None,
34 ctx: None,
35 client_auth_enabled: false,
36 buffer: Vec::new(),
37 }
38 }
39
40 pub fn set_client_auth_enabled(&mut self) {
43 debug_assert!(self.ctx.is_none()); self.client_auth_enabled = true;
45 }
46
47 pub fn abandon_client_auth(&mut self) {
50 self.client_auth_enabled = false;
51 self.buffer.drain(..);
52 }
53
54 pub fn start_hash(&mut self, alg: &'static digest::Algorithm) -> bool {
56 match self.alg {
57 None => {}
58 Some(started) => {
59 if started != alg {
60 warn!("altered hash to HandshakeHash::start_hash");
62 return false;
63 }
64
65 return true;
66 }
67 }
68 self.alg = Some(alg);
69 debug_assert!(self.ctx.is_none());
70
71 let mut ctx = digest::Context::new(alg);
72 ctx.update(&self.buffer);
73 self.ctx = Some(ctx);
74
75 if !self.client_auth_enabled {
77 self.buffer.drain(..);
78 }
79 true
80 }
81
82 pub fn add_message(&mut self, m: &Message) -> &mut HandshakeHash {
84 match m.payload {
85 MessagePayload::Handshake(ref hs) => {
86 let buf = hs.get_encoding();
87 self.update_raw(&buf);
88 }
89 _ => {}
90 };
91 self
92 }
93
94 fn update_raw(&mut self, buf: &[u8]) -> &mut Self {
96 if self.ctx.is_some() {
97 self.ctx.as_mut().unwrap().update(buf);
98 }
99
100 if self.ctx.is_none() || self.client_auth_enabled {
101 self.buffer.extend_from_slice(buf);
102 }
103
104 self
105 }
106
107 pub fn get_hash_given(&self, hash: &'static digest::Algorithm, extra: &[u8]) -> Vec<u8> {
110 let mut ctx = if self.ctx.is_none() {
111 let mut ctx = digest::Context::new(hash);
112 ctx.update(&self.buffer);
113 ctx
114 } else {
115 self.ctx.as_ref().unwrap().clone()
116 };
117
118 ctx.update(extra);
119 let hash = ctx.finish();
120 let mut ret = Vec::new();
121 ret.extend_from_slice(hash.as_ref());
122 ret
123 }
124
125 pub fn rollup_for_hrr(&mut self) {
129 let old_hash = self.ctx.take().unwrap().finish();
130 let old_handshake_hash_msg =
131 HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
132
133 self.ctx = Some(digest::Context::new(self.alg.unwrap()));
134 self.update_raw(&old_handshake_hash_msg.get_encoding());
135 }
136
137 pub fn get_current_hash(&self) -> Vec<u8> {
139 let hash = self
140 .ctx
141 .as_ref()
142 .unwrap()
143 .clone()
144 .finish();
145 let mut ret = Vec::new();
146 ret.extend_from_slice(hash.as_ref());
147 ret
148 }
149
150 pub fn take_handshake_buf(&mut self) -> Vec<u8> {
154 debug_assert!(self.client_auth_enabled);
155 mem::replace(&mut self.buffer, Vec::new())
156 }
157}
158
#[cfg(test)]
mod test {
    use super::HandshakeHash;
    use ring::digest;

    /// First four bytes of SHA-256 over `b"helloworld"`.
    const HELLOWORLD_SHA256_PREFIX: [u8; 4] = [0x93, 0x6a, 0x18, 0x5c];

    /// Checks that the current hash starts with the expected
    /// `b"helloworld"` digest prefix.
    fn assert_helloworld_hash(hh: &HandshakeHash) {
        let current = hh.get_current_hash();
        assert_eq!(&current[..4], &HELLOWORLD_SHA256_PREFIX[..]);
    }

    /// Without client auth, the buffer is dropped once hashing starts.
    #[test]
    fn hashes_correctly() {
        let mut hh = HandshakeHash::new();

        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);

        hh.start_hash(&digest::SHA256);
        assert_eq!(hh.buffer.len(), 0);

        hh.update_raw(b"world");
        assert_helloworld_hash(&hh);
    }

    /// With client auth enabled, the buffer keeps every byte and can be
    /// taken out afterwards.
    #[test]
    fn buffers_correctly() {
        let mut hh = HandshakeHash::new();
        hh.set_client_auth_enabled();

        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);

        hh.start_hash(&digest::SHA256);
        assert_eq!(hh.buffer.len(), 5);

        hh.update_raw(b"world");
        assert_eq!(hh.buffer.len(), 10);
        assert_helloworld_hash(&hh);

        let taken = hh.take_handshake_buf();
        assert_eq!(b"helloworld".to_vec(), taken);
    }

    /// Abandoning client auth after hashing has started empties the buffer
    /// and stops further buffering, without disturbing the hash.
    #[test]
    fn abandon() {
        let mut hh = HandshakeHash::new();
        hh.set_client_auth_enabled();

        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);

        hh.start_hash(&digest::SHA256);
        assert_eq!(hh.buffer.len(), 5);

        hh.abandon_client_auth();
        assert_eq!(hh.buffer.len(), 0);

        hh.update_raw(b"world");
        assert_eq!(hh.buffer.len(), 0);
        assert_helloworld_hash(&hh);
    }
}