rustls/
hash_hs.rs

1#[cfg(feature = "logging")]
2use crate::log::warn;
3use crate::msgs::codec::Codec;
4use crate::msgs::handshake::HandshakeMessagePayload;
5use crate::msgs::message::{Message, MessagePayload};
6use ring::digest;
7use std::mem;
8
9/// This deals with keeping a running hash of the handshake
10/// payloads.  This is computed by buffering initially.  Once
11/// we know what hash function we need to use we switch to
12/// incremental hashing.
13///
14/// For client auth, we also need to buffer all the messages.
15/// This is disabled in cases where client auth is not possible.
16pub struct HandshakeHash {
17    /// None before we know what hash function we're using
18    alg: Option<&'static digest::Algorithm>,
19
20    /// None before we know what hash function we're using
21    ctx: Option<digest::Context>,
22
23    /// true if we need to keep all messages
24    client_auth_enabled: bool,
25
26    /// buffer for pre-hashing stage and client-auth.
27    buffer: Vec<u8>,
28}
29
30impl HandshakeHash {
31    pub fn new() -> HandshakeHash {
32        HandshakeHash {
33            alg: None,
34            ctx: None,
35            client_auth_enabled: false,
36            buffer: Vec::new(),
37        }
38    }
39
40    /// We might be doing client auth, so need to keep a full
41    /// log of the handshake.
42    pub fn set_client_auth_enabled(&mut self) {
43        debug_assert!(self.ctx.is_none()); // or we might have already discarded messages
44        self.client_auth_enabled = true;
45    }
46
47    /// We decided not to do client auth after all, so discard
48    /// the transcript.
49    pub fn abandon_client_auth(&mut self) {
50        self.client_auth_enabled = false;
51        self.buffer.drain(..);
52    }
53
54    /// We now know what hash function the verify_data will use.
55    pub fn start_hash(&mut self, alg: &'static digest::Algorithm) -> bool {
56        match self.alg {
57            None => {}
58            Some(started) => {
59                if started != alg {
60                    // hash type is changing
61                    warn!("altered hash to HandshakeHash::start_hash");
62                    return false;
63                }
64
65                return true;
66            }
67        }
68        self.alg = Some(alg);
69        debug_assert!(self.ctx.is_none());
70
71        let mut ctx = digest::Context::new(alg);
72        ctx.update(&self.buffer);
73        self.ctx = Some(ctx);
74
75        // Discard buffer if we don't need it now.
76        if !self.client_auth_enabled {
77            self.buffer.drain(..);
78        }
79        true
80    }
81
82    /// Hash/buffer a handshake message.
83    pub fn add_message(&mut self, m: &Message) -> &mut HandshakeHash {
84        match m.payload {
85            MessagePayload::Handshake(ref hs) => {
86                let buf = hs.get_encoding();
87                self.update_raw(&buf);
88            }
89            _ => {}
90        };
91        self
92    }
93
94    /// Hash or buffer a byte slice.
95    fn update_raw(&mut self, buf: &[u8]) -> &mut Self {
96        if self.ctx.is_some() {
97            self.ctx.as_mut().unwrap().update(buf);
98        }
99
100        if self.ctx.is_none() || self.client_auth_enabled {
101            self.buffer.extend_from_slice(buf);
102        }
103
104        self
105    }
106
107    /// Get the hash value if we were to hash `extra` too,
108    /// using hash function `hash`.
109    pub fn get_hash_given(&self, hash: &'static digest::Algorithm, extra: &[u8]) -> Vec<u8> {
110        let mut ctx = if self.ctx.is_none() {
111            let mut ctx = digest::Context::new(hash);
112            ctx.update(&self.buffer);
113            ctx
114        } else {
115            self.ctx.as_ref().unwrap().clone()
116        };
117
118        ctx.update(extra);
119        let hash = ctx.finish();
120        let mut ret = Vec::new();
121        ret.extend_from_slice(hash.as_ref());
122        ret
123    }
124
125    /// Take the current hash value, and encapsulate it in a
126    /// 'handshake_hash' handshake message.  Start this hash
127    /// again, with that message at the front.
128    pub fn rollup_for_hrr(&mut self) {
129        let old_hash = self.ctx.take().unwrap().finish();
130        let old_handshake_hash_msg =
131            HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
132
133        self.ctx = Some(digest::Context::new(self.alg.unwrap()));
134        self.update_raw(&old_handshake_hash_msg.get_encoding());
135    }
136
137    /// Get the current hash value.
138    pub fn get_current_hash(&self) -> Vec<u8> {
139        let hash = self
140            .ctx
141            .as_ref()
142            .unwrap()
143            .clone()
144            .finish();
145        let mut ret = Vec::new();
146        ret.extend_from_slice(hash.as_ref());
147        ret
148    }
149
150    /// Takes this object's buffer containing all handshake messages
151    /// so far.  This method only works once; it resets the buffer
152    /// to empty.
153    pub fn take_handshake_buf(&mut self) -> Vec<u8> {
154        debug_assert!(self.client_auth_enabled);
155        mem::replace(&mut self.buffer, Vec::new())
156    }
157}
158
#[cfg(test)]
mod test {
    use super::HandshakeHash;
    use ring::digest;

    /// First four bytes of SHA-256("helloworld").
    const HELLOWORLD_SHA256_PREFIX: [u8; 4] = [0x93, 0x6a, 0x18, 0x5c];

    /// Asserts that `h` is a SHA-256 output beginning with the prefix of
    /// SHA-256("helloworld").
    fn assert_helloworld(h: &[u8]) {
        assert_eq!(h.len(), 32);
        assert_eq!(&h[..4], &HELLOWORLD_SHA256_PREFIX[..]);
    }

    #[test]
    fn hashes_correctly() {
        let mut hh = HandshakeHash::new();
        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);
        hh.start_hash(&digest::SHA256);
        assert_eq!(hh.buffer.len(), 0);
        hh.update_raw(b"world");
        assert_helloworld(&hh.get_current_hash());
    }

    #[test]
    fn buffers_correctly() {
        let mut hh = HandshakeHash::new();
        hh.set_client_auth_enabled();
        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);
        hh.start_hash(&digest::SHA256);
        assert_eq!(hh.buffer.len(), 5);
        hh.update_raw(b"world");
        assert_eq!(hh.buffer.len(), 10);
        assert_helloworld(&hh.get_current_hash());
        let buf = hh.take_handshake_buf();
        assert_eq!(b"helloworld".to_vec(), buf);
    }

    #[test]
    fn abandon() {
        let mut hh = HandshakeHash::new();
        hh.set_client_auth_enabled();
        hh.update_raw(b"hello");
        assert_eq!(hh.buffer.len(), 5);
        hh.start_hash(&digest::SHA256);
        assert_eq!(hh.buffer.len(), 5);
        hh.abandon_client_auth();
        assert_eq!(hh.buffer.len(), 0);
        hh.update_raw(b"world");
        assert_eq!(hh.buffer.len(), 0);
        assert_helloworld(&hh.get_current_hash());
    }

    #[test]
    fn hash_given_matches_before_and_after_start() {
        let mut hh = HandshakeHash::new();
        hh.update_raw(b"hello");
        // Before start_hash: hashes the buffer plus `extra`.
        assert_helloworld(&hh.get_hash_given(&digest::SHA256, b"world"));
        hh.start_hash(&digest::SHA256);
        // After start_hash: clones the running context plus `extra`.
        assert_helloworld(&hh.get_hash_given(&digest::SHA256, b"world"));
        // get_hash_given must not have fed `extra` into the real context.
        hh.update_raw(b"world");
        assert_helloworld(&hh.get_current_hash());
    }

    #[test]
    fn start_hash_rejects_algorithm_change() {
        let mut hh = HandshakeHash::new();
        assert!(hh.start_hash(&digest::SHA256));
        // Restarting with the same algorithm is accepted and changes nothing.
        assert!(hh.start_hash(&digest::SHA256));
        // Switching algorithm is refused.
        assert!(!hh.start_hash(&digest::SHA384));
        hh.update_raw(b"helloworld");
        assert_helloworld(&hh.get_current_hash());
    }

    #[test]
    fn take_handshake_buf_resets_buffer() {
        let mut hh = HandshakeHash::new();
        hh.set_client_auth_enabled();
        hh.update_raw(b"hello");
        assert_eq!(hh.take_handshake_buf(), b"hello".to_vec());
        assert!(hh.take_handshake_buf().is_empty());
    }
}