use serde::{Deserialize, Serialize};

use serde_json;
use std::cell::Cell;
use std::convert::TryFrom;
use std::convert::{From, Into};
use std::fmt;
use std::ops::{BitAnd, BitOr, Deref, DerefMut, Index, Range};
use std::string::String;
/// A fixed-width sequence of `N` bits stored as an array of `bool`s.
///
/// The `u8` conversions and the `fmt::Binary` impl in this file treat index 0
/// as the least-significant bit. The inner array is public and is
/// (de)serialized through the `serde_arrays` helper module.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Eq)]
pub struct BitArr<const N: usize>(#[serde(with = "serde_arrays")] pub [bool; N]);

// Width-changing operations on BitArr.
impl<const N: usize> BitArr<N> {
    /// Widens this `N`-bit value to `M` bits.
    ///
    /// The low `N` positions of the result are copied from `self`. Every
    /// position above them is filled with the most-significant bit of `self`
    /// when `b` is `true` (sign extension), or with `false` when `b` is
    /// `false` (zero extension). A zero-width value has no sign bit, so it
    /// always extends with `false`.
    ///
    /// # Panics
    ///
    /// Panics if `M < N`: the result cannot hold all bits of `self`.
    pub fn sign_zero_ext<const M: usize>(&self, b: bool) -> BitArr<M> {
        assert!(
            M >= N,
            "cannot extend a {}-bit BitArr to only {} bits",
            N,
            M
        );
        // `last()` instead of `self.0[N - 1]`: `N - 1` underflows when N == 0.
        let fill = b && self.0.last().copied().unwrap_or(false);
        let mut res = [fill; M];
        res[..N].copy_from_slice(&self.0);

        BitArr(res)
    }

    /// Sign-extends to `M` bits: the new upper bits repeat the top bit of `self`.
    ///
    /// # Panics
    ///
    /// Panics if `M < N`.
    pub fn sign_ext<const M: usize>(&self) -> BitArr<M> {
        self.sign_zero_ext(true)
    }

    /// Zero-extends to `M` bits: the new upper bits are all `false`.
    ///
    /// # Panics
    ///
    /// Panics if `M < N`.
    pub fn zero_ext<const M: usize>(&self) -> BitArr<M> {
        self.sign_zero_ext(false)
    }
}

impl<const N: usize> DerefMut for BitArr<N> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

impl<const N: usize> Deref for BitArr<N> {
    type Target = [bool; N];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Renders the bits as a run of `'1'`/`'0'` characters, highest index first,
/// so index 0 ends up as the right-most digit.
///
/// The digits are written as-is: width, fill and `#` flags given to the
/// formatter are not applied.
impl<const N: usize> fmt::Binary for BitArr<N> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let digits: String = self
            .0
            .iter()
            .rev()
            .map(|&bit| if bit { '1' } else { '0' })
            .collect();
        f.write_str(&digits)
    }
}

/// Builds a `BitArr<N>` from a byte: bit `i` of `u` becomes element `i`.
///
/// When `N > 8`, the elements at index 8 and above stay `false`.
///
/// # Panics
///
/// Panics when `u` has a set bit at a position `>= N`, i.e. the value does
/// not fit in `N` bits.
impl<const N: usize> From<u8> for BitArr<N> {
    fn from(u: u8) -> Self {
        let mut bits = [false; N];
        (0..8usize)
            .filter(|&pos| (u >> pos) & 1 == 1)
            // Plain indexing on purpose: a set bit beyond N must panic.
            .for_each(|pos| bits[pos] = true);
        BitArr(bits)
    }
}

/// Packs a `BitArr<N>` into a byte: element `i` becomes bit `i` of the result.
///
/// When `N < 8`, the remaining upper bits of the byte are 0.
///
/// # Panics
///
/// Panics if `N > 8`.
impl<const N: usize> From<BitArr<N>> for u8 {
    fn from(bs: BitArr<N>) -> Self {
        assert!(N <= 8);
        bs.0
            .iter()
            .enumerate()
            .filter(|&(_, &set)| set)
            .fold(0u8, |acc, (pos, _)| acc | (1u8 << pos))
    }
}

// operations on bit array

/// Element-wise AND of two equally sized bit arrays, producing a new `BitArr`.
impl<const N: usize> BitAnd for &BitArr<N> {
    type Output = BitArr<N>;

    // `rhs` is the right-hand operand of `a & b`.
    fn bitand(self, rhs: Self) -> Self::Output {
        let mut out = [false; N];
        for (dst, (&lhs_bit, &rhs_bit)) in out.iter_mut().zip(self.0.iter().zip(rhs.0.iter())) {
            *dst = lhs_bit & rhs_bit;
        }
        BitArr(out)
    }
}

/// Element-wise OR of two equally sized bit arrays, producing a new `BitArr`.
impl<const N: usize> BitOr for &BitArr<N> {
    type Output = BitArr<N>;

    // `rhs` is the right-hand operand of `a | b`.
    fn bitor(self, rhs: Self) -> Self::Output {
        let mut out = [false; N];
        for (dst, (&lhs_bit, &rhs_bit)) in out.iter_mut().zip(self.0.iter().zip(rhs.0.iter())) {
            *dst = lhs_bit | rhs_bit;
        }
        BitArr(out)
    }
}

// Unit tests. Gated on `cfg(test)` so the module and its glob import are not
// compiled into regular builds.
#[cfg(test)]
mod test {
    // `super` is the module defined by this file, so the tests do not depend
    // on the path under which the file is mounted in the crate.
    use super::*;

    /// Widening the 2-bit value `0b10` to 4 bits: zero extension keeps the
    /// upper bits clear, sign extension repeats the set top bit.
    #[test]
    fn extensions() {
        let a: BitArr<2> = 0b10u8.into();
        let ze: BitArr<4> = a.zero_ext();
        let se: BitArr<4> = a.sign_ext();
        assert_eq!(ze, BitArr::<4>::from(0b0010u8));
        assert_eq!(se, BitArr::<4>::from(0b1110u8));
    }
}