use crate::{tables, Config};
#[cfg(any(feature = "alloc", feature = "std", test))]
use crate::STANDARD;
#[cfg(any(feature = "alloc", feature = "std", test))]
use alloc::vec::Vec;
#[cfg(any(feature = "std", test))]
use std::error;
use core::fmt;
// Decoding operates on chunks of 8 input symbols (6 bits each, 48 bits total).
const INPUT_CHUNK_LEN: usize = 8;
// Number of decoded bytes produced by one full input chunk (48 bits = 6 bytes).
const DECODED_CHUNK_LEN: usize = 6;
// `decode_chunk` writes a whole u64 (8 bytes) per chunk although only 6 of them are
// decoded data. The 2 extra bytes must be available in the output slice but are not
// counted as written.
const DECODED_CHUNK_SUFFIX: usize = 2;
// How many input chunks the unrolled fast loop in `decode_helper` handles per iteration.
const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;
// Input bytes consumed by one iteration of the unrolled fast loop.
const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN;
// Output bytes touched by one iteration of the unrolled fast loop, including the
// 2-byte suffix of the final u64 write.
const DECODED_BLOCK_LEN: usize =
    CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX;
/// Errors that can occur while decoding.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DecodeError {
    /// An invalid byte was found in the input. The offset of the byte in the input and
    /// the offending byte are provided, in that order.
    InvalidByte(usize, u8),
    /// The length of the input is invalid: the number of trailing symbols (input length
    /// modulo 8 equal to 1 or 5) would leave a lone 6-bit remainder that cannot form a byte.
    InvalidLength,
    /// The last symbol contains bits that would be discarded when decoding.
    /// Reported only when the config does not allow trailing bits.
    /// The offset of the symbol in the input and the symbol itself are provided,
    /// in that order.
    InvalidLastSymbol(usize, u8),
}
impl fmt::Display for DecodeError {
    /// Renders a human-readable message for the error; positional variants print the
    /// byte value first and its input offset second.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            DecodeError::InvalidLength => {
                f.write_str("Encoded text cannot have a 6-bit remainder.")
            }
            DecodeError::InvalidByte(offset, symbol) => {
                write!(f, "Invalid byte {}, offset {}.", symbol, offset)
            }
            DecodeError::InvalidLastSymbol(offset, symbol) => {
                write!(f, "Invalid last symbol {}, offset {}.", symbol, offset)
            }
        }
    }
}
#[cfg(any(feature = "std", test))]
impl error::Error for DecodeError {
    /// Short, static label for the kind of decode failure.
    fn description(&self) -> &str {
        match self {
            DecodeError::InvalidLength => "invalid length",
            DecodeError::InvalidByte(..) => "invalid byte",
            DecodeError::InvalidLastSymbol(..) => "invalid last symbol",
        }
    }

    /// Decode errors never wrap another error, so there is no underlying cause.
    fn cause(&self) -> Option<&dyn error::Error> {
        None
    }
}
/// Decodes `input` using the `STANDARD` config and returns the decoded bytes in a
/// freshly allocated `Vec<u8>`.
///
/// This is a convenience wrapper around `decode_config`.
///
/// # Errors
///
/// Returns whatever `DecodeError` `decode_config` reports for the input.
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode<T>(input: &T) -> Result<Vec<u8>, DecodeError>
where
    T: ?Sized + AsRef<[u8]>,
{
    decode_config(input, STANDARD)
}
/// Decodes `input` using the given `config` and returns the decoded bytes in a
/// freshly allocated `Vec<u8>`.
///
/// The output buffer starts empty: `decode_config_buf` grows it to the required size
/// in a single `resize`, so no capacity guess is made here. (Decoded data is about
/// 3/4 the size of the encoded input; a hint of `len * 4 / 3` would over-allocate and
/// its multiplication could overflow for very large inputs.)
///
/// # Errors
///
/// Returns the `DecodeError` reported by `decode_config_buf` if the input cannot be
/// decoded.
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode_config<T: ?Sized + AsRef<[u8]>>(
    input: &T,
    config: Config,
) -> Result<Vec<u8>, DecodeError> {
    let mut buffer = Vec::<u8>::new();
    decode_config_buf(input, config, &mut buffer).map(|_| buffer)
}
/// Decodes `input` using the given `config`, appending the decoded bytes to `buffer`.
///
/// Existing contents of `buffer` are preserved; decoded data is written after them.
/// The buffer is temporarily grown to an upper-bound estimate (6 bytes per 8-symbol
/// chunk, partial chunks rounded up) and then truncated to the number of bytes
/// actually written.
///
/// # Errors
///
/// Returns the `DecodeError` reported by `decode_helper`. In that case `buffer` is
/// restored to the length it had on entry, so callers never observe partially decoded
/// data or zero filler after a failed decode.
///
/// # Panics
///
/// Panics if the estimated output length overflows `usize`.
#[cfg(any(feature = "alloc", feature = "std", test))]
pub fn decode_config_buf<T: ?Sized + AsRef<[u8]>>(
    input: &T,
    config: Config,
    buffer: &mut Vec<u8>,
) -> Result<(), DecodeError> {
    let input_bytes = input.as_ref();
    let starting_output_len = buffer.len();

    // Upper bound on the output size: every (possibly partial) chunk yields at most
    // DECODED_CHUNK_LEN bytes.
    let chunk_count = num_chunks(input_bytes);
    let decoded_len_estimate = chunk_count
        .checked_mul(DECODED_CHUNK_LEN)
        .and_then(|p| p.checked_add(starting_output_len))
        .expect("Overflow when calculating output buffer length");
    buffer.resize(decoded_len_estimate, 0);

    let outcome = decode_helper(
        input_bytes,
        chunk_count,
        config,
        &mut buffer[starting_output_len..],
    );

    match outcome {
        Ok(bytes_written) => {
            // Drop the unused tail of the estimate.
            buffer.truncate(starting_output_len + bytes_written);
            Ok(())
        }
        Err(err) => {
            // Undo the speculative growth so the caller's data is left untouched.
            buffer.truncate(starting_output_len);
            Err(err)
        }
    }
}
/// Decodes `input` using the given `config`, writing the decoded bytes to the start
/// of `output`.
///
/// Returns the number of bytes written, as reported by `decode_helper`.
///
/// # Errors
///
/// Returns the `DecodeError` reported by `decode_helper`.
///
/// # Panics
///
/// `decode_helper` indexes into `output` without growing it, so an `output` slice
/// that is too small for the decoded data causes a panic.
pub fn decode_config_slice<T>(
    input: &T,
    config: Config,
    output: &mut [u8],
) -> Result<usize, DecodeError>
where
    T: ?Sized + AsRef<[u8]>,
{
    let encoded = input.as_ref();
    let chunk_count = num_chunks(encoded);
    decode_helper(encoded, chunk_count, config, output)
}
/// Returns the number of `INPUT_CHUNK_LEN`-sized chunks in `input`, counting a
/// trailing partial chunk as a whole one (i.e. the length divided by the chunk size,
/// rounded up).
///
/// # Panics
///
/// Panics if rounding the length up overflows `usize`.
fn num_chunks(input: &[u8]) -> usize {
    let rounded_up_len = input
        .len()
        .checked_add(INPUT_CHUNK_LEN - 1)
        .expect("Overflow when calculating number of chunks in input");
    rounded_up_len / INPUT_CHUNK_LEN
}
/// Core decoding routine shared by the `Vec`-based and slice-based entry points.
///
/// Decodes `input` into `output` in four stages: an unrolled fast loop over blocks of
/// 4 chunks, a fast loop over single chunks, a precise loop for chunks that must not
/// spill extra bytes into `output`, and finally a symbol-by-symbol pass over the last
/// (possibly partial, possibly padded) chunk.
///
/// * `input` - the encoded bytes.
/// * `num_chunks` - number of 8-symbol chunks in `input`, rounded up; the callers in
///   this file pass `num_chunks(input)`.
/// * `config` - supplies the character set's decode table and whether trailing bits in
///   the last symbol are tolerated.
/// * `output` - destination slice. It is indexed directly and never grown, so it
///   panics if it is too short for the decoded data.
///
/// Returns the number of bytes written to `output`.
///
/// # Errors
///
/// * `DecodeError::InvalidLength` if `input.len() % 8` is 1 or 5.
/// * `DecodeError::InvalidByte` for a symbol not in the decode table, or for misplaced
///   padding (`=`).
/// * `DecodeError::InvalidLastSymbol` if the last symbol carries bits that would be
///   discarded and `config.decode_allow_trailing_bits` is false.
#[inline]
fn decode_helper(
    input: &[u8],
    num_chunks: usize,
    config: Config,
    output: &mut [u8],
) -> Result<usize, DecodeError> {
    let char_set = config.char_set;
    let decode_table = char_set.decode_table();
    let remainder_len = input.len() % INPUT_CHUNK_LEN;
    // The fast loops write 8 bytes per chunk, of which only 6 are decoded data. They
    // must therefore stop early enough that at least 2 more bytes of real output are
    // written afterwards, overwriting the 2 spilled bytes. `trailing_bytes_to_skip` is
    // how many bytes at the end of the input the fast loops must leave alone.
    let trailing_bytes_to_skip = match remainder_len {
        // Input is a whole number of chunks: skip the last chunk because it may
        // contain padding, which the fast loops cannot handle.
        0 => INPUT_CHUNK_LEN,
        // 1 or 5 trailing symbols leave a lone 6-bit remainder, which cannot form a byte.
        1 | 5 => return Err(DecodeError::InvalidLength),
        // 2 trailing symbols decode to a single byte, which is not enough to overwrite
        // the 2 spilled bytes, so the previous chunk is kept out of the fast loops too.
        2 => INPUT_CHUNK_LEN + 2,
        // 3 trailing symbols may be 2 symbols plus 1 (erroneous) padding byte, i.e. only
        // 1 byte of output; keep the previous chunk out of the fast loops so that case
        // ends in an error from the final stage rather than an out-of-bounds write.
        3 => INPUT_CHUNK_LEN + 3,
        // 4 trailing symbols may be 2 symbols plus 2 padding bytes, i.e. only 1 byte of
        // output, so the previous chunk is again kept out of the fast loops.
        4 => INPUT_CHUNK_LEN + 4,
        // 6 or 7 trailing symbols: a valid tail of this length yields at least 2 bytes,
        // so only the tail itself needs to be skipped.
        _ => remainder_len,
    };
    // Counts down as chunks are consumed; starts at the rounded-up chunk count.
    let mut remaining_chunks = num_chunks;
    let mut input_index = 0;
    let mut output_index = 0;
    {
        let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);
        // Stage 1: fast loop, manually unrolled to CHUNKS_PER_FAST_LOOP_BLOCK chunks per
        // iteration so the input/output slices are bounds-checked once per block.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
            while input_index <= max_start_index {
                let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
                let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];
                decode_chunk(
                    &input_slice[0..],
                    input_index,
                    decode_table,
                    &mut output_slice[0..],
                )?;
                decode_chunk(
                    &input_slice[8..],
                    input_index + 8,
                    decode_table,
                    &mut output_slice[6..],
                )?;
                decode_chunk(
                    &input_slice[16..],
                    input_index + 16,
                    decode_table,
                    &mut output_slice[12..],
                )?;
                decode_chunk(
                    &input_slice[24..],
                    input_index + 24,
                    decode_table,
                    &mut output_slice[18..],
                )?;
                input_index += INPUT_BLOCK_LEN;
                // Advance by the decoded data only; the 2-byte suffix is overwritten later.
                output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
                remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
            }
        }
        // Stage 2: still-fast loop, one chunk (8 symbols) at a time for whatever
        // stage 1 did not cover. Each call still writes 8 bytes into `output`.
        if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
            while input_index < max_start_index {
                decode_chunk(
                    &input[input_index..(input_index + INPUT_CHUNK_LEN)],
                    input_index,
                    decode_table,
                    &mut output
                        [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
                )?;
                output_index += DECODED_CHUNK_LEN;
                input_index += INPUT_CHUNK_LEN;
                remaining_chunks -= 1;
            }
        }
    }
    // Stage 3: chunks that were held back from the fast loops because their 2 spilled
    // bytes might not be overwritten are decoded here with `decode_chunk_precise`,
    // which writes exactly 6 bytes. The range starts at 1 so the last chunk (complete
    // or partial) is always left for stage 4, since it may contain padding.
    for _ in 1..remaining_chunks {
        decode_chunk_precise(
            &input[input_index..],
            input_index,
            decode_table,
            &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
        )?;
        input_index += INPUT_CHUNK_LEN;
        output_index += DECODED_CHUNK_LEN;
    }
    // Unless the input is empty, one last (possibly partial) chunk of 2..=8 bytes remains.
    debug_assert!(input.len() - input_index > 1 || input.is_empty());
    debug_assert!(input.len() - input_index <= 8);
    // Stage 4: decode the leftovers symbol by symbol, handling padding.
    // A u64 serves as a stack-resident 8-byte buffer for the decoded bits.
    let mut leftover_bits: u64 = 0;
    let mut morsels_in_leftover = 0;
    let mut padding_bytes = 0;
    let mut first_padding_index: usize = 0;
    let mut last_symbol = 0_u8;
    let start_of_leftovers = input_index;
    for (i, b) in input[start_of_leftovers..].iter().enumerate() {
        // 0x3D is '=', the padding byte.
        if *b == 0x3D {
            // Padding can be wrong in a few ways:
            // 1 - padding followed by non-padding symbols;
            // 2 - padding after only zero or one non-padding symbols in the current
            //     group of 4;
            // 3 - more than two padding bytes. If 3 or 4 padding bytes fall in the same
            //     group of 4, one of them sits at position 0 or 1 and is caught by #2;
            //     if they straddle two groups, #2 catches it in the second group.
            if i % 4 < 2 {
                // Case #2.
                let bad_padding_index = start_of_leftovers
                    + if padding_bytes > 0 {
                        // Padding was already seen: report the first padding byte, which
                        // matches what the fast loops would report (they reject the
                        // first '=' they meet, as it is not in the decode table
                        // -- NOTE(review): assumes '=' maps to INVALID_VALUE; confirm).
                        first_padding_index
                    } else {
                        // No earlier padding: the current position is the culprit.
                        i
                    };
                return Err(DecodeError::InvalidByte(bad_padding_index, *b));
            }
            if padding_bytes == 0 {
                first_padding_index = i;
            }
            padding_bytes += 1;
            continue;
        }
        // Case #1: a non-padding symbol after padding. Padding is only allowed as a
        // suffix, so report the first padding byte as the invalid one.
        if padding_bytes > 0 {
            return Err(DecodeError::InvalidByte(
                start_of_leftovers + first_padding_index,
                0x3D,
            ));
        }
        last_symbol = *b;
        // Up to 8 * 6 = 48 bits of the u64 are used. Morsels are packed from the most
        // significant end downwards so the bytes can later be peeled off from the top.
        let shift = 64 - (morsels_in_leftover + 1) * 6;
        // The table has 256 entries, so indexing with a u8 value cannot go out of bounds.
        let morsel = decode_table[*b as usize];
        if morsel == tables::INVALID_VALUE {
            return Err(DecodeError::InvalidByte(start_of_leftovers + i, *b));
        }
        leftover_bits |= (morsel as u64) << shift;
        morsels_in_leftover += 1;
    }
    // Number of whole-byte bits available from the collected morsels. Counts of 1 and 5
    // cannot occur: unpadded they are rejected as InvalidLength above, and padded they
    // put a '=' at position 1 of a group of 4, which is rejected in the loop.
    let leftover_bits_ready_to_append = match morsels_in_leftover {
        0 => 0,
        2 => 8,
        3 => 16,
        4 => 24,
        6 => 32,
        7 => 40,
        8 => 48,
        _ => unreachable!(
            "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths"
        ),
    };
    // Any bit set below the whole-byte region means the last symbol encodes bits that
    // will not appear in the output.
    let mask = !0 >> leftover_bits_ready_to_append;
    if !config.decode_allow_trailing_bits && (leftover_bits & mask) != 0 {
        // The last morsel sits at index `morsels_in_leftover - 1` within the leftovers.
        return Err(DecodeError::InvalidLastSymbol(
            start_of_leftovers + morsels_in_leftover - 1,
            last_symbol,
        ));
    }
    let mut leftover_bits_appended_to_buf = 0;
    while leftover_bits_appended_to_buf < leftover_bits_ready_to_append {
        // Shift the next byte down to the low 8 bits; `as u8` truncates the rest.
        let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8;
        output[output_index] = selected_bits;
        output_index += 1;
        leftover_bits_appended_to_buf += 8;
    }
    Ok(output_index)
}
/// Writes `value` into the first 8 bytes of `output` in big-endian byte order.
///
/// # Panics
///
/// Panics if `output` is shorter than 8 bytes.
#[inline]
fn write_u64(output: &mut [u8], value: u64) {
    let big_endian = value.to_be_bytes();
    let destination = &mut output[..8];
    destination.copy_from_slice(&big_endian);
}
/// Decodes 8 input symbols into 6 bytes, written to `output` as one big-endian u64.
///
/// The 8 six-bit values are packed into the top 48 bits of the u64, so `output` must
/// have room for 8 bytes: the first 6 receive decoded data and the last 2 are set to 0.
///
/// * `input` - at least 8 symbols; only the first 8 are read.
/// * `index_at_start_of_input` - offset of `input[0]` in the complete input, used only
///   to report error positions.
/// * `decode_table` - maps each byte to its 6-bit value or `tables::INVALID_VALUE`.
///
/// # Errors
///
/// Returns `DecodeError::InvalidByte` with the absolute offset and value of the first
/// symbol that the table marks as invalid; `output` is not written in that case.
#[inline(always)]
fn decode_chunk(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    let mut accum: u64 = 0;
    // Constant trip count of 8 (one chunk). Symbols are indexed one at a time so that
    // an invalid symbol is reported before any later position is touched.
    for position in 0..8 {
        let symbol = input[position];
        let morsel = decode_table[symbol as usize];
        if morsel == tables::INVALID_VALUE {
            return Err(DecodeError::InvalidByte(
                index_at_start_of_input + position,
                symbol,
            ));
        }
        // First symbol lands in bits 63..58, the next in 57..52, and so on down to
        // bits 21..16 for the eighth.
        accum |= u64::from(morsel) << (58 - 6 * position);
    }
    write_u64(output, accum);
    Ok(())
}
/// Decodes 8 input symbols into exactly 6 bytes of `output`.
///
/// Unlike `decode_chunk`, nothing beyond the first 6 bytes of `output` is modified:
/// the chunk is decoded into a local 8-byte scratch buffer and only the 6 data bytes
/// are copied out.
///
/// # Errors
///
/// Propagates the error from `decode_chunk`; `output` is left untouched in that case.
#[inline]
fn decode_chunk_precise(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    let mut scratch = [0_u8; 8];
    decode_chunk(input, index_at_start_of_input, decode_table, &mut scratch)?;
    // Keep the 6 decoded bytes, drop the 2-byte suffix of the u64 write.
    let (payload, _suffix) = scratch.split_at(6);
    output[..6].copy_from_slice(payload);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        encode::encode_config_buf,
        encode::encode_config_slice,
        tests::{assert_encode_sanity, random_config},
    };
    use rand::{
        distributions::{Distribution, Uniform},
        FromEntropy, Rng,
    };

    // `decode_chunk_precise` must leave bytes 6 and 7 of the output untouched.
    #[test]
    fn decode_chunk_precise_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk_precise(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
    }

    // `decode_chunk` writes a full u64, so bytes 6 and 7 of the output are zeroed.
    #[test]
    fn decode_chunk_writes_8_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
        decode_chunk(&input[..], 0, tables::STANDARD_DECODE, &mut output).unwrap();
        assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
    }

    // Decoding into a Vec that already holds data must append, not overwrite.
    #[test]
    fn decode_into_nonempty_vec_doesnt_clobber_existing_prefix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decoded_with_prefix = Vec::new();
        let mut decoded_without_prefix = Vec::new();
        let mut prefix = Vec::new();
        let prefix_len_range = Uniform::new(0, 1000);
        let input_len_range = Uniform::new(0, 1000);
        let mut rng = rand::rngs::SmallRng::from_entropy();
        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decoded_with_prefix.clear();
            decoded_without_prefix.clear();
            prefix.clear();
            let input_len = input_len_range.sample(&mut rng);
            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }
            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);
            let prefix_len = prefix_len_range.sample(&mut rng);
            // fill the buffer with a random prefix
            for _ in 0..prefix_len {
                prefix.push(rng.gen());
            }
            decoded_with_prefix.resize(prefix_len, 0);
            decoded_with_prefix.copy_from_slice(&prefix);
            // decode into the non-empty buffer
            decode_config_buf(&encoded_data, config, &mut decoded_with_prefix).unwrap();
            // also decode into an empty buffer for comparison
            decode_config_buf(&encoded_data, config, &mut decoded_without_prefix).unwrap();
            assert_eq!(
                prefix_len + decoded_without_prefix.len(),
                decoded_with_prefix.len()
            );
            assert_eq!(orig_data, decoded_without_prefix);
            // prefix followed by the plain decode must equal the decode-with-prefix
            prefix.append(&mut decoded_without_prefix);
            assert_eq!(prefix, decoded_with_prefix);
        }
    }

    // Decoding into the middle of a slice must leave the bytes before the start
    // offset and after the decoded data unchanged.
    #[test]
    fn decode_into_slice_doesnt_clobber_existing_prefix_or_suffix() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();
        let mut decode_buf_copy: Vec<u8> = Vec::new();
        let input_len_range = Uniform::new(0, 1000);
        let mut rng = rand::rngs::SmallRng::from_entropy();
        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();
            decode_buf_copy.clear();
            let input_len = input_len_range.sample(&mut rng);
            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }
            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);
            // fill the buffer with random garbage
            for _ in 0..5000 {
                decode_buf.push(rng.gen());
            }
            // keep a copy for later comparison
            decode_buf_copy.extend(decode_buf.iter());
            let offset = 1000;
            // decode into the middle of the buffer
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[offset..]).unwrap();
            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(
                orig_data,
                &decode_buf[offset..(offset + decode_bytes_written)]
            );
            assert_eq!(&decode_buf_copy[0..offset], &decode_buf[0..offset]);
            assert_eq!(
                &decode_buf_copy[offset + decode_bytes_written..],
                &decode_buf[offset + decode_bytes_written..]
            );
        }
    }

    // An output slice of exactly the decoded length must be enough (no spill past it).
    #[test]
    fn decode_into_slice_fits_in_precisely_sized_slice() {
        let mut orig_data = Vec::new();
        let mut encoded_data = String::new();
        let mut decode_buf = Vec::new();
        let input_len_range = Uniform::new(0, 1000);
        let mut rng = rand::rngs::SmallRng::from_entropy();
        for _ in 0..10_000 {
            orig_data.clear();
            encoded_data.clear();
            decode_buf.clear();
            let input_len = input_len_range.sample(&mut rng);
            for _ in 0..input_len {
                orig_data.push(rng.gen());
            }
            let config = random_config(&mut rng);
            encode_config_buf(&orig_data, config, &mut encoded_data);
            assert_encode_sanity(&encoded_data, config, input_len);
            decode_buf.resize(input_len, 0);
            // decode into the precisely sized buffer
            let decode_bytes_written =
                decode_config_slice(&encoded_data, config, &mut decode_buf[..]).unwrap();
            assert_eq!(orig_data.len(), decode_bytes_written);
            assert_eq!(orig_data, decode_buf);
        }
    }

    // Three symbols plus one padding byte encode two bytes; the third symbol may only
    // contribute its top 4 bits.
    #[test]
    fn detect_invalid_last_symbol_two_bytes() {
        let decode =
            |input, forgiving| decode_config(input, STANDARD.decode_allow_trailing_bits(forgiving));
        // 'U' has no trailing bits set, so it is a valid last symbol
        assert!(decode("iYU=", false).is_ok());
        // 'V' sets the lowest trailing bit
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'V')),
            decode("iYV=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYV=", true));
        // 'W' sets the second-lowest trailing bit
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'W')),
            decode("iYW=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYW=", true));
        // 'X' sets both trailing bits
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(2, b'X')),
            decode("iYX=", false)
        );
        assert_eq!(Ok(vec![137, 133]), decode("iYX=", true));
        // also detected when the bad symbol is not in the first group of 4
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(6, b'X')),
            decode("AAAAiYX=", false)
        );
        assert_eq!(Ok(vec![0, 0, 0, 137, 133]), decode("AAAAiYX=", true));
    }

    // Two symbols plus two padding bytes encode one byte; the second symbol may only
    // contribute its top 2 bits.
    #[test]
    fn detect_invalid_last_symbol_one_byte() {
        // 'w' has no trailing bits set, so it is a valid last symbol
        assert!(decode("/w==").is_ok());
        // every other symbol tried here sets at least one trailing bit
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'x')), decode("/x=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'z')), decode("/z=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'0')), decode("/0=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'9')), decode("/9=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'+')), decode("/+=="));
        assert_eq!(Err(DecodeError::InvalidLastSymbol(1, b'/')), decode("//=="));
        // also detected when the bad symbol is not in the first group of 4
        assert_eq!(
            Err(DecodeError::InvalidLastSymbol(5, b'x')),
            decode("AAAA/x==")
        );
    }

    // Exhaustive check: every 3-symbol + '=' input either is the canonical encoding of
    // some 2-byte value or is rejected because of its last symbol.
    #[test]
    fn detect_invalid_last_symbol_every_possible_three_symbols() {
        // map every canonical 2-byte encoding to the bytes it encodes
        let mut base64_to_bytes = ::std::collections::HashMap::new();
        let mut bytes = [0_u8; 2];
        for b1 in 0_u16..256 {
            bytes[0] = b1 as u8;
            for b2 in 0_u16..256 {
                bytes[1] = b2 as u8;
                let mut b64 = vec![0_u8; 4];
                assert_eq!(4, encode_config_slice(&bytes, STANDARD, &mut b64[..]));
                let mut v = ::std::vec::Vec::with_capacity(2);
                v.extend_from_slice(&bytes[..]);
                assert!(base64_to_bytes.insert(b64, v).is_none());
            }
        }
        // every possible combination of three symbols followed by one padding byte
        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                for &s3 in STANDARD.char_set.encode_table().iter() {
                    symbols[2] = s3;
                    symbols[3] = b'=';
                    match base64_to_bytes.get(&symbols[..]) {
                        Some(bytes) => {
                            assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                        }
                        None => assert_eq!(
                            Err(DecodeError::InvalidLastSymbol(2, s3)),
                            decode_config(&symbols[..], STANDARD)
                        ),
                    }
                }
            }
        }
    }

    // Exhaustive check: every 2-symbol + "==" input either is the canonical encoding of
    // some 1-byte value or is rejected because of its last symbol.
    #[test]
    fn detect_invalid_last_symbol_every_possible_two_symbols() {
        // map every canonical 1-byte encoding to the byte it encodes
        let mut base64_to_bytes = ::std::collections::HashMap::new();
        for b in 0_u16..256 {
            let mut b64 = vec![0_u8; 4];
            assert_eq!(4, encode_config_slice(&[b as u8], STANDARD, &mut b64[..]));
            let mut v = ::std::vec::Vec::with_capacity(1);
            v.push(b as u8);
            assert!(base64_to_bytes.insert(b64, v).is_none());
        }
        // every possible combination of two symbols followed by two padding bytes
        let mut symbols = [0_u8; 4];
        for &s1 in STANDARD.char_set.encode_table().iter() {
            symbols[0] = s1;
            for &s2 in STANDARD.char_set.encode_table().iter() {
                symbols[1] = s2;
                symbols[2] = b'=';
                symbols[3] = b'=';
                match base64_to_bytes.get(&symbols[..]) {
                    Some(bytes) => {
                        assert_eq!(Ok(bytes.to_vec()), decode_config(&symbols, STANDARD))
                    }
                    None => assert_eq!(
                        Err(DecodeError::InvalidLastSymbol(1, s2)),
                        decode_config(&symbols[..], STANDARD)
                    ),
                }
            }
        }
    }
}