base64/
decode.rs

1use crate::{tables, Config, PAD_BYTE};
2
3#[cfg(any(feature = "alloc", feature = "std", test))]
4use crate::STANDARD;
5#[cfg(any(feature = "alloc", feature = "std", test))]
6use alloc::vec::Vec;
7use core::fmt;
8#[cfg(any(feature = "std", test))]
9use std::error;
10
// The decode logic operates on chunks of 8 input symbols (bytes) that contain no padding.
const INPUT_CHUNK_LEN: usize = 8;
// 8 symbols * 6 bits = 48 bits, so each full input chunk decodes to 6 output bytes.
const DECODED_CHUNK_LEN: usize = 6;
// The fast path writes a whole u64 (8 bytes) per chunk, but a chunk only yields 6 bytes of
// output, so the last 2 bytes of any output u64 must not be counted as written to (but must be
// available in the slice).
const DECODED_CHUNK_SUFFIX: usize = 2;

// How many 8-byte input chunks the unrolled fast loop handles per iteration.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;
// Number of input bytes consumed by one iteration of the unrolled fast loop.
const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN;
// Output bytes touched by one iteration of the unrolled fast loop; includes the trailing
// 2 bytes for the final u64 write.
const DECODED_BLOCK_LEN: usize =
    CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX;
25
/// Errors that can occur while decoding.
///
/// Offsets carried by the variants are byte offsets into the input that was being decoded.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DecodeError {
    /// An invalid byte was found in the input. The offset and offending byte are provided,
    /// in that order.
    InvalidByte(usize, u8),
    /// The length of the input is invalid.
    /// A typical cause of this is stray trailing whitespace or other separator bytes.
    /// In the case where excess trailing bytes have produced an invalid length *and* the last byte
    /// is also an invalid base64 symbol (as would be the case for whitespace, etc), `InvalidByte`
    /// will be emitted instead of `InvalidLength` to make the issue easier to debug.
    InvalidLength,
    /// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded.
    /// This is indicative of corrupted or truncated Base64.
    /// Unlike `InvalidByte`, which reports symbols that aren't in the alphabet, this error is for
    /// symbols that are in the alphabet but represent nonsensical encodings.
    /// The offset of the symbol and the symbol itself are provided, in that order.
    InvalidLastSymbol(usize, u8),
}
43
44impl fmt::Display for DecodeError {
45    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
46        match *self {
47            DecodeError::InvalidByte(index, byte) => {
48                write!(f, "Invalid byte {}, offset {}.", byte, index)
49            }
50            DecodeError::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."),
51            DecodeError::InvalidLastSymbol(index, byte) => {
52                write!(f, "Invalid last symbol {}, offset {}.", byte, index)
53            }
54        }
55    }
56}
57
#[cfg(any(feature = "std", test))]
impl error::Error for DecodeError {
    /// Returns a fixed, lower-case label naming the kind of error, without any offsets or
    /// byte values.
    fn description(&self) -> &str {
        match self {
            DecodeError::InvalidLastSymbol(..) => "invalid last symbol",
            DecodeError::InvalidLength => "invalid length",
            DecodeError::InvalidByte(..) => "invalid byte",
        }
    }

    /// Decode errors are never caused by another error, so this is always `None`.
    fn cause(&self) -> Option<&dyn error::Error> {
        None
    }
}
72
/// Decode base64 text into a freshly allocated `Vec<u8>` using the standard configuration.
///
/// This is a convenience for `decode_config(input, base64::STANDARD)`.
///
/// # Example
///
/// ```rust
/// extern crate base64;
///
/// fn main() {
///     let bytes = base64::decode("aGVsbG8gd29ybGQ=").unwrap();
///     println!("{:?}", bytes);
/// }
/// ```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, DecodeError> {
    let standard_config = STANDARD;
    decode_config(input, standard_config)
}
91
/// Decode base64 text into a freshly allocated `Vec<u8>` using the supplied configuration.
///
/// The vector is pre-sized with 3 bytes of capacity for every 4 input bytes (rounded up), so
/// decoding does not need to grow it repeatedly.
///
/// # Panics
///
/// Panics if computing the capacity estimate overflows `usize`.
///
/// # Example
///
/// ```rust
/// extern crate base64;
///
/// fn main() {
///     let bytes = base64::decode_config("aGVsbG8gd29ybGR+Cg==", base64::STANDARD).unwrap();
///     println!("{:?}", bytes);
///
///     let bytes_url = base64::decode_config("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE).unwrap();
///     println!("{:?}", bytes_url);
/// }
/// ```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode_config<T: AsRef<[u8]>>(input: T, config: Config) -> Result<Vec<u8>, DecodeError> {
    let encoded_len = input.as_ref().len();
    // Round the input length up to a whole number of 4-symbol quads.
    let quad_count = encoded_len
        .checked_add(3)
        .expect("decoded length calculation overflow")
        / 4;
    let mut decoded = Vec::<u8>::with_capacity(quad_count * 3);

    decode_config_buf(input, config, &mut decoded)?;
    Ok(decoded)
}
121
/// Decode base64 text and append the decoded bytes to the supplied buffer.
///
/// Existing contents of `buffer` are preserved; decoded bytes are appended after them. Reusing
/// one buffer across calls avoids allocating a new `Vec` per decode.
///
/// # Errors
///
/// Returns a `DecodeError` if the input is not valid base64 for `config`. In that case `buffer`
/// is restored to the length it had on entry, so callers never observe scratch or partially
/// decoded bytes.
///
/// # Panics
///
/// Panics if the required buffer length overflows `usize`.
///
/// # Example
///
/// ```rust
/// extern crate base64;
///
/// fn main() {
///     let mut buffer = Vec::<u8>::new();
///     base64::decode_config_buf("aGVsbG8gd29ybGR+Cg==", base64::STANDARD, &mut buffer).unwrap();
///     println!("{:?}", buffer);
///
///     buffer.clear();
///
///     base64::decode_config_buf("aGVsbG8gaW50ZXJuZXR-Cg==", base64::URL_SAFE, &mut buffer)
///         .unwrap();
///     println!("{:?}", buffer);
/// }
/// ```
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode_config_buf<T: AsRef<[u8]>>(
    input: T,
    config: Config,
    buffer: &mut Vec<u8>,
) -> Result<(), DecodeError> {
    let input_bytes = input.as_ref();

    let starting_output_len = buffer.len();

    // Reserve room for the worst case: every (possibly partial) input chunk decoding to a full
    // output chunk, on top of whatever the buffer already holds.
    let num_chunks = num_chunks(input_bytes);
    let decoded_len_estimate = num_chunks
        .checked_mul(DECODED_CHUNK_LEN)
        .and_then(|p| p.checked_add(starting_output_len))
        .expect("Overflow when calculating output buffer length");
    buffer.resize(decoded_len_estimate, 0);

    let outcome = decode_helper(
        input_bytes,
        num_chunks,
        config,
        &mut buffer[starting_output_len..],
    );

    match outcome {
        Ok(bytes_written) => {
            // Drop the unused tail of the estimate.
            buffer.truncate(starting_output_len + bytes_written);
            Ok(())
        }
        Err(e) => {
            // Undo the resize so a failed decode leaves the caller's data untouched instead of
            // followed by zero padding and partially decoded bytes.
            buffer.truncate(starting_output_len);
            Err(e)
        }
    }
}
170
171/// Decode the input into the provided output slice.
172///
173/// This will not write any bytes past exactly what is decoded (no stray garbage bytes at the end).
174///
175/// If you don't know ahead of time what the decoded length should be, size your buffer with a
176/// conservative estimate for the decoded length of an input: 3 bytes of output for every 4 bytes of
177/// input, rounded up, or in other words `(input_len + 3) / 4 * 3`.
178///
179/// If the slice is not large enough, this will panic.
180pub fn decode_config_slice<T: AsRef<[u8]>>(
181    input: T,
182    config: Config,
183    output: &mut [u8],
184) -> Result<usize, DecodeError> {
185    let input_bytes = input.as_ref();
186
187    decode_helper(input_bytes, num_chunks(input_bytes), config, output)
188}
189
190/// Return the number of input chunks (including a possibly partial final chunk) in the input
191fn num_chunks(input: &[u8]) -> usize {
192    input
193        .len()
194        .checked_add(INPUT_CHUNK_LEN - 1)
195        .expect("Overflow when calculating number of chunks in input")
196        / INPUT_CHUNK_LEN
197}
198
/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs.
/// Returns the number of bytes written, or an error.
///
/// Decoding proceeds in four stages:
/// 1. an unrolled fast loop handling `CHUNKS_PER_FAST_LOOP_BLOCK` chunks per iteration,
/// 2. a fast loop handling one 8-byte chunk per iteration,
/// 3. a precise loop for complete chunks that were held back from the fast loops,
/// 4. a symbol-by-symbol pass over the final (possibly partial, possibly padded) chunk.
///
/// `input` is the encoded data, `num_chunks` is the number of 8-byte input chunks rounded up
/// (as returned by `num_chunks(input)`), `config` selects the alphabet and whether nonzero
/// trailing bits are tolerated, and `output` receives the decoded bytes.
///
/// # Errors
///
/// * `InvalidByte` for a symbol outside the alphabet, or for misplaced padding.
/// * `InvalidLength` when the input length modulo 8 is 1 or 5.
/// * `InvalidLastSymbol` when the final symbol has discarded bits set and
///   `config.decode_allow_trailing_bits` is false.
///
/// # Panics
///
/// Panics if `output` is too small, since the output slice is indexed directly.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
fn decode_helper(
    input: &[u8],
    num_chunks: usize,
    config: Config,
    output: &mut [u8],
) -> Result<usize, DecodeError> {
    let char_set = config.char_set;
    let decode_table = char_set.decode_table();

    let remainder_len = input.len() % INPUT_CHUNK_LEN;

    // Because the fast decode loop writes in groups of 8 bytes (unrolled to
    // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of
    // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
    // soon enough that there will always be 2 more bytes of valid data written after that loop.
    let trailing_bytes_to_skip = match remainder_len {
        // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
        // and the fast decode logic cannot handle padding
        0 => INPUT_CHUNK_LEN,
        // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
        1 | 5 => {
            // trailing whitespace is so common that it's worth it to check the last byte to
            // possibly return a better error message
            if let Some(b) = input.last() {
                if *b != PAD_BYTE && decode_table[*b as usize] == tables::INVALID_VALUE {
                    return Err(DecodeError::InvalidByte(input.len() - 1, *b));
                }
            }

            return Err(DecodeError::InvalidLength);
        }
        // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
        // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
        // previous chunk.
        2 => INPUT_CHUNK_LEN + 2,
        // If this is 3 unpadded chars, then it would actually decode to 2 bytes. However, if this
        // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
        // with an error, not panic from going past the bounds of the output slice, so we let it
        // use stage 3 + 4.
        3 => INPUT_CHUNK_LEN + 3,
        // This can also decode to one output byte because it may be 2 input chars + 2 padding
        // chars, which would decode to 1 byte.
        4 => INPUT_CHUNK_LEN + 4,
        // Everything else is a legal decode len (given that we don't require padding), and will
        // decode to at least 2 bytes of output.
        _ => remainder_len,
    };

    // rounded up to include partial chunks
    let mut remaining_chunks = num_chunks;

    let mut input_index = 0;
    let mut output_index = 0;

    {
        // Number of leading input bytes eligible for the fast loops; zero if the input is shorter
        // than the part that has to be held back.
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);

        // Fast loop, stage 1
        // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];

                // Each chunk's 8-byte write overlaps the next chunk's output by 2 bytes; the
                // following chunk overwrites those 2 garbage bytes.
                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;

                input_index += INPUT_BLOCK_LEN;
                // Only the meaningful bytes count; the 2-byte suffix is not valid output.
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }

        // Fast loop, stage 2 (aka still pretty fast loop)
        // 8 bytes at a time for whatever we didn't do in stage 1.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;

                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }

    // Stage 3
    // If input length was such that a chunk had to be deferred until after the fast loop
    // because decoding it would have produced 2 trailing bytes that wouldn't then be
    // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
    // trailing bytes.
    // However, we still need to avoid the last chunk (partial or complete) because it could
    // have padding, so we always do 1 fewer to avoid the last chunk.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;

        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }

    // always have one more (possibly partial) block of 8 input
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);

    // Stage 4
    // Finally, decode any leftovers that aren't a complete input block of 8 bytes.
    // Use a u64 as a stack-resident 8 byte buffer.
    let mut leftover_bits: u64 = 0;
    // Number of non-padding symbols accumulated into `leftover_bits` so far.
    let mut morsels_in_leftover = 0;
    let mut padding_bytes = 0;
    // Index (relative to `start_of_leftovers`) of the first padding byte; only meaningful once
    // `padding_bytes > 0`.
    let mut first_padding_index: usize = 0;
    let mut last_symbol = 0_u8;
    let start_of_leftovers = input_index;
    for (i, b) in input[start_of_leftovers..].iter().enumerate() {
        // '=' padding
        if *b == PAD_BYTE {
            // There can be bad padding in a few ways:
            // 1 - Padding with non-padding characters after it
            // 2 - Padding after zero or one non-padding characters before it
            //     in the current quad.
            // 3 - More than two characters of padding. If 3 or 4 padding chars
            //     are in the same quad, that implies it will be caught by #2.
            //     If it spreads from one quad to another, it will be caught by
            //     #2 in the second quad.

            if i % 4 < 2 {
                // Check for case #2.
                let bad_padding_index = start_of_leftovers
                    + if padding_bytes > 0 {
                        // If we've already seen padding, report the first padding index.
                        // This is to be consistent with the faster logic above: it will report an
                        // error on the first padding character (since it doesn't expect to see
                        // anything but actual encoded data).
                        first_padding_index
                    } else {
                        // haven't seen padding before, just use where we are now
                        i
                    };
                return Err(DecodeError::InvalidByte(bad_padding_index, *b));
            }

            if padding_bytes == 0 {
                first_padding_index = i;
            }

            padding_bytes += 1;
            continue;
        }

        // Check for case #1.
        // To make '=' handling consistent with the main loop, don't allow
        // non-suffix '=' in trailing chunk either. Report error as first
        // erroneous padding.
        if padding_bytes > 0 {
            return Err(DecodeError::InvalidByte(
                start_of_leftovers + first_padding_index,
                PAD_BYTE,
            ));
        }
        last_symbol = *b;

        // can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
        // To minimize shifts, pack the leftovers from left to right.
        let shift = 64 - (morsels_in_leftover + 1) * 6;
        // tables are all 256 elements, lookup with a u8 index always succeeds
        let morsel = decode_table[*b as usize];
        if morsel == tables::INVALID_VALUE {
            return Err(DecodeError::InvalidByte(start_of_leftovers + i, *b));
        }

        leftover_bits |= (morsel as u64) << shift;
        morsels_in_leftover += 1;
    }

    // Number of whole-byte bits the accumulated symbols yield: floor(morsels * 6 / 8) * 8.
    let leftover_bits_ready_to_append = match morsels_in_leftover {
        0 => 0,
        2 => 8,
        3 => 16,
        4 => 24,
        6 => 32,
        7 => 40,
        8 => 48,
        _ => unreachable!(
            "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths"
        ),
    };

    // if there are bits set outside the bits we care about, last symbol encodes trailing bits that
    // will not be included in the output
    let mask = !0 >> leftover_bits_ready_to_append;
    if !config.decode_allow_trailing_bits && (leftover_bits & mask) != 0 {
        // last morsel is at `morsels_in_leftover` - 1
        return Err(DecodeError::InvalidLastSymbol(
            start_of_leftovers + morsels_in_leftover - 1,
            last_symbol,
        ));
    }

    // Emit the complete bytes one at a time, starting from the most significant byte.
    let mut leftover_bits_appended_to_buf = 0;
    while leftover_bits_appended_to_buf < leftover_bits_ready_to_append {
        // `as` simply truncates the higher bits, which is what we want here
        let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8;
        output[output_index] = selected_bits;
        output_index += 1;

        leftover_bits_appended_to_buf += 8;
    }

    Ok(output_index)
}
448
/// Store `value` into the first 8 bytes of `output` in big-endian order.
///
/// Bytes after the first 8 are left untouched. Panics if `output` is shorter than 8 bytes.
#[inline]
fn write_u64(output: &mut [u8], value: u64) {
    let be_bytes = value.to_be_bytes();
    let destination = &mut output[..be_bytes.len()];
    destination.copy_from_slice(&be_bytes);
}
453
454/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the
455/// first 6 of those contain meaningful data.
456///
457/// `input` is the bytes to decode, of which the first 8 bytes will be processed.
458/// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
459/// accurately)
460/// `decode_table` is the lookup table for the particular base64 alphabet.
461/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded
462/// data.
463// yes, really inline (worth 30-50% speedup)
464#[inline(always)]
465fn decode_chunk(
466    input: &[u8],
467    index_at_start_of_input: usize,
468    decode_table: &[u8; 256],
469    output: &mut [u8],
470) -> Result<(), DecodeError> {
471    let mut accum: u64;
472
473    let morsel = decode_table[input[0] as usize];
474    if morsel == tables::INVALID_VALUE {
475        return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
476    }
477    accum = (morsel as u64) << 58;
478
479    let morsel = decode_table[input[1] as usize];
480    if morsel == tables::INVALID_VALUE {
481        return Err(DecodeError::InvalidByte(
482            index_at_start_of_input + 1,
483            input[1],
484        ));
485    }
486    accum |= (morsel as u64) << 52;
487
488    let morsel = decode_table[input[2] as usize];
489    if morsel == tables::INVALID_VALUE {
490        return Err(DecodeError::InvalidByte(
491            index_at_start_of_input + 2,
492            input[2],
493        ));
494    }
495    accum |= (morsel as u64) << 46;
496
497    let morsel = decode_table[input[3] as usize];
498    if morsel == tables::INVALID_VALUE {
499        return Err(DecodeError::InvalidByte(
500            index_at_start_of_input + 3,
501            input[3],
502        ));
503    }
504    accum |= (morsel as u64) << 40;
505
506    let morsel = decode_table[input[4] as usize];
507    if morsel == tables::INVALID_VALUE {
508        return Err(DecodeError::InvalidByte(
509            index_at_start_of_input + 4,
510            input[4],
511        ));
512    }
513    accum |= (morsel as u64) << 34;
514
515    let morsel = decode_table[input[5] as usize];
516    if morsel == tables::INVALID_VALUE {
517        return Err(DecodeError::InvalidByte(
518            index_at_start_of_input + 5,
519            input[5],
520        ));
521    }
522    accum |= (morsel as u64) << 28;
523
524    let morsel = decode_table[input[6] as usize];
525    if morsel == tables::INVALID_VALUE {
526        return Err(DecodeError::InvalidByte(
527            index_at_start_of_input + 6,
528            input[6],
529        ));
530    }
531    accum |= (morsel as u64) << 22;
532
533    let morsel = decode_table[input[7] as usize];
534    if morsel == tables::INVALID_VALUE {
535        return Err(DecodeError::InvalidByte(
536            index_at_start_of_input + 7,
537            input[7],
538        ));
539    }
540    accum |= (morsel as u64) << 16;
541
542    write_u64(output, accum);
543
544    Ok(())
545}
546
547/// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2
548/// trailing garbage bytes.
549#[inline]
550fn decode_chunk_precise(
551    input: &[u8],
552    index_at_start_of_input: usize,
553    decode_table: &[u8; 256],
554    output: &mut [u8],
555) -> Result<(), DecodeError> {
556    let mut tmp_buf = [0_u8; 8];
557
558    decode_chunk(
559        input,
560        index_at_start_of_input,
561        decode_table,
562        &mut tmp_buf[..],
563    )?;
564
565    output[0..6].copy_from_slice(&tmp_buf[0..6]);
566
567    Ok(())
568}
569
570#[cfg(test)]
571mod tests {
572    use super::*;
573    use crate::{
574        encode::encode_config_buf,
575        encode::encode_config_slice,
576        tests::{assert_encode_sanity, random_config},
577    };
578
579    use rand::{
580        distributions::{Distribution, Uniform},
581        FromEntropy, Rng,
582    };
583
584    #[test]
585    fn decode_chunk_precise_writes_only_6_bytes() {
586        let input = b"Zm9vYmFy"; // "foobar"
587        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
588        decode_chunk_precise(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
589        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
590    }
591
592    #[test]
593    fn decode_chunk_writes_8_bytes() {
594        let input = b"Zm9vYmFy"; // "foobar"
595        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
596        decode_chunk(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
597        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
598    }
599
    /// Randomized round-trip check: decoding into a `Vec` that already holds data must append
    /// the decoded bytes and leave the existing prefix intact.
    #[test]
    fn decode_into_nonempty_vec_doesnt_clobber_existing_prefix() {
        // Buffers are allocated once and reused across iterations.
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decoded_with_prefix = Vec::new();
        let mut decoded_without_prefix = Vec::new();
        let mut prefix = Vec::new();

        let prefix_len_range = Uniform::new(0, 1000);
        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decoded_with_prefix.clear();
            decoded_without_prefix.clear();
            prefix.clear();

            let input_len = input_len_range.sample(&mut rng);

            // random payload to encode
            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            let prefix_len = prefix_len_range.sample(&mut rng);

            // fill the buf with a prefix
            for _ in 0..prefix_len {
                prefix.push(rng.gen());
            }

            decoded_with_prefix.resize(prefix_len, 0);
            decoded_with_prefix.copy_from_slice(&prefix);

            // decode into the non-empty buf
            decode_config_buf(&encoded_data, config, &mut decoded_with_prefix).unwrap();
            // also decode into the empty buf
            decode_config_buf(&encoded_data, config, &mut decoded_without_prefix).unwrap();

            // the prefixed buffer grew by exactly the decoded length
            assert_eq!(
                prefix_len + decoded_without_prefix.len(),
                decoded_with_prefix.len()
            );
            assert_eq!(orig_data, decoded_without_prefix);

            // append plain decode onto prefix
            prefix.append(&mut decoded_without_prefix);

            assert_eq!(prefix, decoded_with_prefix);
        }
    }
657
    /// Randomized check: decoding into the middle of a larger slice must only touch the bytes
    /// that hold decoded output, leaving everything before and after unchanged.
    #[test]
    fn decode_into_slice_doesnt_clobber_existing_prefix_or_suffix() {
        // Buffers are allocated once and reused across iterations.
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();
        let mut decode_buf_copy: Vec<u8> = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();
            decode_buf_copy.clear();

            let input_len = input_len_range.sample(&mut rng);

            // random payload to encode
            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            // fill the buffer with random garbage, long enough to have some room before and after
            for _ in 0..5000 {
                decode_buf.push(rng.gen());
            }

            // keep a copy for later comparison
            decode_buf_copy.extend(decode_buf.iter());

            // decoded output starts this far into the garbage-filled buffer
            let offset = 1000;

            // decode into the non-empty buf
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[offset..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(
                orig_data,
                &decode_buf[offset..(offset + decode_bytes_written)]
            );
            // bytes before the decoded region are unchanged
            assert_eq!(&decode_buf_copy[0..offset], &decode_buf[0..offset]);
            // bytes after the decoded region are unchanged
            assert_eq!(
                &decode_buf_copy[offset + decode_bytes_written..],
                &decode_buf[offset + decode_bytes_written..]
            );
        }
    }
711
    /// Randomized check: an output slice exactly as long as the decoded data is sufficient,
    /// i.e. decoding needs no scratch space past the end of the real output.
    #[test]
    fn decode_into_slice_fits_in_precisely_sized_slice() {
        // Buffers are allocated once and reused across iterations.
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();

        let input_len_range = Uniform::new(0, 1000);

        let mut rng = rand::rngs::SmallRng::from_entropy();

        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();

            let input_len = input_len_range.sample(&mut rng);

            // random payload to encode
            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }

            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);

            // size the output to exactly the original (decoded) length
            decode_buf.resize(input_len, 0);

            // decode into the exactly-sized buf
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[..]).unwrap();

            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(orig_data, decode_buf);
        }
    }
747
748    #[test]
749    fn detect_invalid_last_symbol_two_bytes() {
750        let decode =
751            |input, forgiving| decode_config(input, STANDARD.decode_allow_trailing_bits(forgiving));
752
753        // example from https://github.com/marshallpierce/rust-base64/issues/75
754        assert!(decode("iYU=", false).is_ok());
755        // trailing 01
756        assert_eq!(
757            Err(DecodeError::InvalidLastSymbol(2, b'V')),
758            decode("iYV=", false)
759        );
760        assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
761        // trailing 10
762        assert_eq!(
763            Err(DecodeError::InvalidLastSymbol(2, b'W')),
764            decode("iYW=", false)
765        );
766        assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
767        // trailing 11
768        assert_eq!(
769            Err(DecodeError::InvalidLastSymbol(2, b'X')),
770            decode("iYX=", false)
771        );
772        assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
773
774        // also works when there are 2 quads in the last block
775        assert_eq!(
776            Err(DecodeError::InvalidLastSymbol(6, b'X')),
777            decode("AAAAiYX=", false)
778        );
779        assert_eq!(Ok(vec![0, 0, 0, 137, 133]), decode("AAAAiYX=", true));
780    }
781
782    #[test]
783    fn detect_invalid_last_symbol_one_byte() {
784        // 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol
785
786        assert!(decode("/w==").is_ok());
787        // trailing 01
788        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'x')), decode("/x=="));
789        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'z')), decode("/z=="));
790        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'0')), decode("/0=="));
791        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'9')), decode("/9=="));
792        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'+')), decode("/+=="));
793        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'/')), decode("//=="));
794
795        // also works when there are 2 quads in the last block
796        assert_eq!(
797            Err(DecodeError::InvalidLastSymbol(5, b'x')),
798            decode("AAAA/x==")
799        );
800    }
801
802    #[test]
803    fn detect_invalid_last_symbol_every_possible_three_symbols() {
804        let mut base64_to_bytes = ::std::collections::HashMap::new();
805
806        let mut bytes = [0_u8; 2];
807        for b1 in 0_u16..256 {
808            bytes[0] = b1 as u8;
809            for b2 in 0_u16..256 {
810                bytes[1] = b2 as u8;
811                let mut b64 = vec![0_u8; 4];
812                assert_eq!(4, encode_config_slice(&bytes, STANDARD, &mut b64[..]));
813                let mut v = ::std::vec::Vec::with_capacity(2);
814                v.extend_from_slice(&bytes[..]);
815
816                assert!(base64_to_bytes.insert(b64, v).is_none());
817            }
818        }
819
820        // every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol
821
822        let mut symbols = [0_u8; 4];
823        for &s1 in STANDARD.char_set.encode_table().iter() {
824            symbols[0] = s1;
825            for &s2 in STANDARD.char_set.encode_table().iter() {
826                symbols[1] = s2;
827                for &s3 in STANDARD.char_set.encode_table().iter() {
828                    symbols[2] = s3;
829                    symbols[3] = PAD_BYTE;
830
831                    match base64_to_bytes.get(&symbols[..]) {
832                        Some(bytes) => {
833                            assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
834                        }
835                        None => assert_eq!(
836                            Err(DecodeError::InvalidLastSymbol(2, s3)),
837                            decode_config(&symbols[..], STANDARD)
838                        ),
839                    }
840                }
841            }
842        }
843    }
844
845    #[test]
846    fn detect_invalid_last_symbol_every_possible_two_symbols() {
847        let mut base64_to_bytes = ::std::collections::HashMap::new();
848
849        for b in 0_u16..256 {
850            let mut b64 = vec![0_u8; 4];
851            assert_eq!(4, encode_config_slice(&[b as u8], STANDARD, &mut b64[..]));
852            let mut v = ::std::vec::Vec::with_capacity(1);
853            v.push(b as u8);
854
855            assert!(base64_to_bytes.insert(b64, v).is_none());
856        }
857
858        // every possible combination of symbols must either decode to 1 byte or get InvalidLastSymbol
859
860        let mut symbols = [0_u8; 4];
861        for &s1 in STANDARD.char_set.encode_table().iter() {
862            symbols[0] = s1;
863            for &s2 in STANDARD.char_set.encode_table().iter() {
864                symbols[1] = s2;
865                symbols[2] = PAD_BYTE;
866                symbols[3] = PAD_BYTE;
867
868                match base64_to_bytes.get(&symbols[..]) {
869                    Some(bytes) => {
870                        assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
871                    }
872                    None => assert_eq!(
873                        Err(DecodeError::InvalidLastSymbol(1, s2)),
874                        decode_config(&symbols[..], STANDARD)
875                    ),
876                }
877            }
878        }
879    }
880
881    #[test]
882    fn decode_config_estimation_works_for_various_lengths() {
883        for num_prefix_quads in 0..100 {
884            for suffix in &["AA", "AAA", "AAAA"] {
885                let mut prefix = "AAAA".repeat(num_prefix_quads);
886                prefix.push_str(suffix);
887                // make sure no overflow (and thus a panic) occurs
888                let res = decode_config(prefix, STANDARD);
889                assert!(res.is_ok());
890            }
891        }
892    }
893}