use std::{cell::OnceCell, collections::HashMap};
use itertools::Itertools;
use crate::{
    ir::{
        self,
        editor::action::Action,
        function::FunctionDefinitionIndex,
        quantity::Quantity,
        statement::{IRStatement, IsIRStatement},
        FunctionDefinition, RegisterName,
    },
    utility::data_type::Type,
};
use super::IsAnalyzer;
/// The `alloca`, `store` and `load` statements recorded for one register
/// (the key of the map held by [`MemoryUsage`]) inside a single function.
///
/// The two `*_group_by_basic_block` fields are lazily-built caches derived from
/// `store` / `load`. They are filled on first use and never refreshed, so mutating
/// the public `store` / `load` vectors after a query leaves the caches stale.
#[derive(Debug, Clone, Default)]
pub struct MemoryAccessInfo {
    /// Index of the `alloca` statement that defines this register, if one was seen.
    pub alloca: Option<FunctionDefinitionIndex>,
    /// Indices of `store` statements whose target is this register.
    pub store: Vec<FunctionDefinitionIndex>,
    /// Indices of `load` statements that read from this register.
    pub load: Vec<FunctionDefinitionIndex>,
    // Cache: basic block id (`.0` of an index) -> the `.1` parts of the stores in that block.
    store_group_by_basic_block: OnceCell<HashMap<usize, Vec<usize>>>,
    // Cache: same mapping as above, built from `load`.
    load_group_by_basic_block: OnceCell<HashMap<usize, Vec<usize>>>,
}
impl MemoryAccessInfo {
    /// Buckets statement indices by basic block: `index.0` (the basic block id) maps to
    /// every `index.1` of that block, in the order they appear in `indices`.
    ///
    /// Unlike an adjacency-based `group_by`, this does not require the indices of one
    /// block to be contiguous in `indices`. A block that occurs in several separate runs
    /// keeps all of its entries instead of only those of its last run.
    fn group_indices_by_basic_block(
        indices: &[FunctionDefinitionIndex],
    ) -> HashMap<usize, Vec<usize>> {
        let mut grouped: HashMap<usize, Vec<usize>> = HashMap::new();
        for index in indices {
            grouped.entry(index.0).or_default().push(index.1);
        }
        grouped
    }

    /// Stores grouped by basic block (computed once, then cached).
    fn store_group_by_basic_block(&self) -> &HashMap<usize, Vec<usize>> {
        self.store_group_by_basic_block
            .get_or_init(|| Self::group_indices_by_basic_block(&self.store))
    }

    /// Loads grouped by basic block (computed once, then cached).
    fn load_group_by_basic_block(&self) -> &HashMap<usize, Vec<usize>> {
        self.load_group_by_basic_block
            .get_or_init(|| Self::group_indices_by_basic_block(&self.load))
    }

    /// Returns the loads in the same basic block as `store` that come after it and
    /// before the next recorded store in that block.
    ///
    /// Only positions inside `store`'s own basic block are compared; loads in other
    /// blocks are never returned.
    ///
    /// # Panics
    ///
    /// Panics if no store was recorded in `store`'s basic block.
    pub fn loads_dorminated_by_store_in_block(
        &self,
        store: &FunctionDefinitionIndex,
    ) -> Vec<FunctionDefinitionIndex> {
        let (block, position) = (store.0, store.1);
        // The closest later store in the same block ends the range covered by `store`.
        // Taking the minimum (rather than the first match) does not rely on the
        // recorded positions being sorted.
        let next_store_position = self
            .store_group_by_basic_block()
            .get(&block)
            .expect("`store` must lie in a basic block with recorded stores")
            .iter()
            .copied()
            .filter(|&it| it > position)
            .min();
        let Some(loads) = self.load_group_by_basic_block().get(&block) else {
            return Vec::new();
        };
        loads
            .iter()
            .copied()
            .filter(|&it| it > position && next_store_position.map_or(true, |next| it < next))
            .map(|it| (block, it).into())
            .collect_vec()
    }
}
/// Analyzer that records, per register name, the `alloca` defining it and the
/// `store` / `load` statements that use it as their target / source.
///
/// The result is computed lazily on the first query and cached until
/// [`IsAnalyzer::on_action`] clears it.
/// NOTE(review): the cache does not record which function it was built from, so
/// querying with a different `FunctionDefinition` before it is cleared returns the
/// data of the first function.
#[derive(Debug, Default)]
pub struct MemoryUsage {
    // Lazily-built map: register name -> its recorded memory accesses.
    memory_access: OnceCell<HashMap<RegisterName, MemoryAccessInfo>>,
}
impl MemoryUsage {
    /// Creates an analyzer with an empty cache; nothing is computed until the first query.
    pub fn new() -> Self {
        Self {
            memory_access: OnceCell::new(),
        }
    }

    /// Returns the memory access info recorded for `variable_name` in `function`.
    ///
    /// # Panics
    ///
    /// Panics if `function` has no `alloca` defining `variable_name` and no
    /// `store` / `load` addressing it.
    fn memory_access_info(
        &self,
        function: &ir::FunctionDefinition,
        variable_name: &RegisterName,
    ) -> &MemoryAccessInfo {
        self.memory_access(function)
            .get(variable_name)
            .expect("no alloca/store/load was recorded for this register")
    }

    /// Iterates, in unspecified `HashMap` order, over every register with recorded accesses.
    fn memory_access_variables(
        &self,
        function: &ir::FunctionDefinition,
    ) -> impl Iterator<Item = &RegisterName> {
        self.memory_access(function).keys()
    }

    /// Maps every register defined by an `alloca` in `function` to that alloca's `alloc_type`.
    ///
    /// Some registers are only used as a store target or load source, with no `alloca`
    /// recorded for them (presumably pointer parameters or computed addresses; TODO
    /// confirm). They have no allocated type to report, so they are skipped rather
    /// than causing a panic.
    fn memory_access_variables_and_types(
        &self,
        function: &ir::FunctionDefinition,
    ) -> HashMap<RegisterName, Type> {
        self.memory_access(function)
            .iter()
            .filter_map(|(variable, info)| {
                let alloca_index = info.alloca.clone()?;
                let data_type = function[alloca_index].as_alloca().alloc_type.clone();
                Some((variable.clone(), data_type))
            })
            .collect()
    }

    /// Returns the cached access map, building it from `function` on first use.
    ///
    /// NOTE(review): the cache is not keyed by `function`; once filled, it is returned
    /// unchanged for any function until `on_action` clears it.
    fn memory_access(
        &self,
        function: &ir::FunctionDefinition,
    ) -> &HashMap<RegisterName, MemoryAccessInfo> {
        self.memory_access
            .get_or_init(|| self.init_memory_access(function))
    }

    /// Scans every statement of `function` once, recording:
    /// - for each `alloca`, the register it generates;
    /// - for each `store` / `load` whose address is a register, that register.
    ///
    /// Indices are pushed in the order `function_definition_index_enumerate` yields them.
    fn init_memory_access(
        &self,
        function: &ir::FunctionDefinition,
    ) -> HashMap<RegisterName, MemoryAccessInfo> {
        let mut memory_access: HashMap<RegisterName, MemoryAccessInfo> = HashMap::new();
        for (index, statement) in function.iter().function_definition_index_enumerate() {
            match statement {
                IRStatement::Alloca(_) => {
                    memory_access
                        .entry(statement.generate_register().unwrap().0)
                        .or_default()
                        .alloca = Some(index);
                }
                IRStatement::Store(store) => {
                    if let Quantity::RegisterName(local) = &store.target {
                        memory_access
                            .entry(local.clone())
                            .or_default()
                            .store
                            .push(index);
                    }
                }
                IRStatement::Load(load) => {
                    if let Quantity::RegisterName(local) = &load.from {
                        memory_access
                            .entry(local.clone())
                            .or_default()
                            .load
                            .push(index);
                    }
                }
                _ => (),
            }
        }
        memory_access
    }
}
/// A [`MemoryUsage`] paired with the function it answers queries about, so the
/// function does not have to be passed to every query.
pub struct BindedMemoryUsage<'item, 'bind: 'item> {
    // Function every query is evaluated against.
    bind_on: &'bind FunctionDefinition,
    // Analyzer holding the (possibly cached) results.
    item: &'item MemoryUsage,
}
impl<'item, 'bind: 'item> BindedMemoryUsage<'item, 'bind> {
    pub fn memory_access_info(&self, variable_name: &RegisterName) -> &MemoryAccessInfo {
        self.item.memory_access_info(self.bind_on, variable_name)
    }
    pub fn memory_access_variables(&self) -> impl Iterator<Item = &RegisterName> {
        self.item.memory_access_variables(self.bind_on)
    }
    pub fn memory_access_variables_and_types(&self) -> HashMap<RegisterName, Type> {
        self.item.memory_access_variables_and_types(self.bind_on)
    }
}
impl<'item, 'bind: 'item> IsAnalyzer<'item, 'bind> for MemoryUsage {
    fn on_action(&mut self, _action: &Action) {
        self.memory_access.take();
    }
    type Binded = BindedMemoryUsage<'item, 'bind>;
    fn bind(&'item self, content: &'bind ir::FunctionDefinition) -> Self::Binded {
        BindedMemoryUsage {
            bind_on: content,
            item: self,
        }
    }
}