use itertools::Itertools;
use std::mem;
use wasm_encoder::{
    CodeSection, ConstExpr, ExportKind, ExportSection, FunctionSection, GlobalSection, GlobalType,
    MemorySection, MemoryType, Module, TypeSection, ValType,
};
use crate::ir::{
    analyzer::{BindedControlFlowGraph, BindedScc, ControlFlowGraph, IsAnalyzer},
    editor::Analyzer,
    statement::IRStatement,
    FunctionDefinition, IR,
};
use self::{
    control_flow::{CFSelector, ControlFlowElement},
    lowering::{lower_function_body, lower_function_type},
};
mod control_flow;
mod lowering;
/// Recursively wraps every non-trivial strongly connected component nested in `scc` into a
/// [`ControlFlowElement::Loop`].
///
/// For each non-trivial sub-SCC, the span of `current_result` that runs from the first to the
/// last top-level basic block belonging to it is cloned, folded recursively, and then replaced
/// in place by a single `Loop` element holding the folded span.
///
/// # Panics
///
/// Panics if a non-trivial sub-SCC has no basic block among the top-level elements of
/// `current_result`.
fn fold_loop(scc: &BindedScc, current_result: &mut Vec<ControlFlowElement>) {
    let Some(sub_sccs) = scc.top_level_sccs() else {
        return;
    };
    for sub_scc in sub_sccs {
        // Trivial components do not form a loop, so there is nothing to wrap.
        if sub_scc.is_trivial() {
            continue;
        }
        // Only top-level basic blocks are inspected; already nested elements never match.
        let in_sub_scc = |element: &ControlFlowElement| {
            matches!(element, ControlFlowElement::BasicBlock { id } if sub_scc.contains(*id))
        };
        let first = current_result.iter().position(in_sub_scc).unwrap();
        let last = current_result.iter().rposition(in_sub_scc).unwrap();
        let span = first..=last;
        let mut loop_body = current_result[span.clone()].to_vec();
        // Inner loops are folded before the span is replaced by the outer loop element.
        fold_loop(&sub_scc, &mut loop_body);
        current_result.splice(span, [ControlFlowElement::Loop { content: loop_body }]);
    }
}
/// Performs at most one if/else folding step on `content`.
///
/// Block ids are scanned in ascending order, looking for a block with exactly one predecessor
/// whose last statement is an [`IRStatement::Branch`]. For the first such block that is not
/// already located under the predecessor's `If` element, the predecessor is wrapped into a
/// [`ControlFlowElement::If`] (unless it already is the condition of one), and the block,
/// together with the following siblings chosen by `collect_to_move`, is moved into the
/// `on_success` or `on_failure` arm according to `branch_direction`.
///
/// Returns `true` when elements were moved and `false` when no block qualifies.
///
/// # Panics
///
/// Panics if the block or its predecessor cannot be found in `content`.
fn fold_if_else_once(
    content: &mut ControlFlowElement,
    control_flow_graph: BindedControlFlowGraph,
) -> bool {
    let block_count = control_flow_graph.bind_on.content.len();
    for block_id in 0..block_count {
        let predecessors = control_flow_graph.predecessor(block_id);
        if predecessors.len() != 1 {
            continue;
        }
        let parent_id = predecessors[0];
        let parent_ends_with_branch = matches!(
            control_flow_graph.bind_on[parent_id].content.last(),
            Some(IRStatement::Branch(_))
        );
        if !parent_ends_with_branch {
            continue;
        }

        let parent_selector = content.find_node(parent_id).unwrap();
        let block_selector = content.find_node(block_id).unwrap();
        let if_selector = if parent_selector.is_if_condition() {
            // The predecessor is already an `if` condition: step up to the `If` element itself.
            let mut enclosing_if = parent_selector.clone();
            enclosing_if.pop_back();
            if enclosing_if.is_parent_of(&block_selector) {
                // The block already sits under this `If`; nothing left to do for it.
                continue;
            }
            enclosing_if
        } else {
            // Turn the plain predecessor block into an `If` with two empty arms.
            let fresh_if = ControlFlowElement::If {
                condition: Box::new(ControlFlowElement::BasicBlock { id: parent_id }),
                on_success: Vec::new(),
                on_failure: Vec::new(),
            };
            content.replace(&parent_selector, fresh_if);
            parent_selector.clone()
        };

        let selectors_to_move = collect_to_move(content, &block_selector, &control_flow_graph);
        let moved_elements: Vec<_> = selectors_to_move
            .iter()
            .map(|selector| content[selector].clone())
            .collect();

        let if_element = &mut content[&if_selector];
        let condition_id = if_element.first_basic_block_id();
        let takes_success_arm = control_flow_graph.branch_direction(condition_id, block_id);
        let ControlFlowElement::If {
            on_success,
            on_failure,
            ..
        } = if_element
        else {
            unreachable!()
        };
        let target_arm = if takes_success_arm {
            on_success
        } else {
            on_failure
        };
        target_arm.extend(moved_elements);

        // Remove back to front so that earlier selectors stay valid while deleting.
        for selector in selectors_to_move.iter().rev() {
            content.remove(selector);
        }
        return true;
    }
    false
}
fn collect_to_move(
    root_element: &ControlFlowElement,
    first_to_move_node_selector: &CFSelector,
    control_flow_graph: &BindedControlFlowGraph<'_, '_>,
) -> Vec<CFSelector> {
    let first_to_move_element = &root_element[first_to_move_node_selector];
    let first_to_move_element_first_bb_id = first_to_move_element.first_basic_block_id();
    let move_to_if_condition_bb_id =
        control_flow_graph.predecessor(first_to_move_element_first_bb_id)[0];
    let mut to_move = vec![first_to_move_node_selector.clone()];
    let mut next = root_element.next_element_sibling(first_to_move_node_selector);
    while let Some(current_element_selector) = next {
        let current_node_id = root_element[¤t_element_selector].first_basic_block_id();
        if control_flow_graph.is_dominated_by(current_node_id, first_to_move_element_first_bb_id)
            && control_flow_graph.is_in_same_branch_side(
                move_to_if_condition_bb_id,
                first_to_move_element_first_bb_id,
                current_node_id,
            )
        {
            to_move.push(current_element_selector.clone());
        } else {
            break;
        }
        next = root_element.next_element_sibling(¤t_element_selector);
    }
    to_move
}
/// Applies [`fold_if_else_once`] to `content` repeatedly until it reports that nothing was
/// folded.
///
/// A fresh [`ControlFlowGraph`] is created and bound to `function_definition` for every round,
/// because the bound graph is handed to `fold_if_else_once` by value.
fn fold_if_else(function_definition: &FunctionDefinition, content: &mut ControlFlowElement) {
    let mut folded_something = true;
    while folded_something {
        let graph = ControlFlowGraph::new();
        folded_something = fold_if_else_once(content, graph.bind(function_definition));
    }
}
/// Builds the structured control-flow elements for `function_definition`.
///
/// Every basic block starts out as its own node, in block order, inside one root block. Loops
/// are folded first (via [`fold_loop`] on the top-level SCC), then if/else structures (via
/// [`fold_if_else`]). The content of the root block is returned.
fn fold(function_definition: &FunctionDefinition) -> Vec<ControlFlowElement> {
    let block_count = function_definition.content.len();
    let initial_nodes = (0..block_count)
        .map(ControlFlowElement::new_node)
        .collect::<Vec<_>>();
    let mut root = ControlFlowElement::new_block(initial_nodes);

    let analyzer = Analyzer::new();
    let bound_analyzer = analyzer.bind(function_definition);
    let graph = bound_analyzer.control_flow_graph();
    let top_scc = graph.top_level_scc();
    fold_loop(&top_scc, root.unwrap_content_mut());

    fold_if_else(function_definition, &mut root);
    // Hand the folded children out of the root without cloning them.
    mem::take(root.unwrap_content_mut())
}
/// Emits one function into the given type, function, export and code sections.
///
/// The current length of the type section is used both as the index of the newly added type
/// and as the exported function index. The function is exported under its header name, and its
/// body is lowered from `control_flow_root`.
///
/// NOTE(review): reusing the type index as the function index presumably relies on the type
/// and function sections growing in lockstep — confirm against callers.
fn generate_function(
    result: (
        &mut TypeSection,
        &mut FunctionSection,
        &mut ExportSection,
        &mut CodeSection,
    ),
    function_definition: &FunctionDefinition,
    control_flow_root: &ControlFlowElement,
) {
    let (types, functions, exports, codes) = result;
    let header = &function_definition.header;

    // Signature: one new type entry, referenced by the new function entry.
    let function_index = types.len();
    let (param_type, return_type) = lower_function_type(header);
    types.function(param_type, return_type);
    functions.function(function_index);
    exports.export(&header.name, ExportKind::Func, function_index);

    // Body: lowered against a control-flow graph bound to this function.
    let graph = ControlFlowGraph::new();
    let bound_graph = graph.bind(function_definition);
    let body = lower_function_body(function_definition, control_flow_root, &bound_graph);
    codes.function(&body);
}
/// Compiles every [`IR::FunctionDefinition`] found in `ir_content` into a single WebAssembly
/// [`Module`]; all other IR items are skipped.
///
/// Besides the functions, the module declares one mutable `i32` global initialised to `0` and
/// one memory with a minimum size of `1` and no maximum. Sections are appended in the order
/// type, function, memory, global, export, code.
pub fn compile(ir_content: &[IR]) -> Module {
    let mut types = TypeSection::new();
    let mut functions = FunctionSection::new();
    let mut exports = ExportSection::new();
    let mut codes = CodeSection::new();

    let mut global_section = GlobalSection::new();
    let global_type = GlobalType {
        val_type: ValType::I32,
        mutable: true,
        shared: false,
    };
    global_section.global(global_type, &ConstExpr::i32_const(0));

    let mut memory_section = MemorySection::new();
    let memory_type = MemoryType {
        minimum: 1,
        maximum: None,
        memory64: false,
        shared: false,
        page_size_log2: None,
    };
    memory_section.memory(memory_type);

    for ir_part in ir_content {
        let IR::FunctionDefinition(function) = ir_part else {
            continue;
        };
        let root = ControlFlowElement::new_block(fold(function));
        generate_function(
            (&mut types, &mut functions, &mut exports, &mut codes),
            function,
            &root,
        );
    }

    let mut module = Module::new();
    module
        .section(&types)
        .section(&functions)
        .section(&memory_section)
        .section(&global_section)
        .section(&exports)
        .section(&codes);
    module
}
#[cfg(test)]
mod tests {
    use std::{assert_matches::assert_matches, fs::File, io::Write, str::FromStr};
    use analyzer::Analyzer;
    use wasm_encoder::{
        ConstExpr, GlobalSection, GlobalType, MemorySection, MemoryType, Module, ValType,
    };
    use crate::{
        ir::{
            self,
            analyzer::{self, IsAnalyzer},
            editor::Editor,
            function::{basic_block::BasicBlock, test_util::*},
            optimize::pass::{FixIrreducible, IsPass, TopologicalSort},
            statement::Ret,
        },
        utility::data_type,
    };
    use super::*;
    /// Exercises `fold_if_else_once` step by step on three functions, checking the shape of
    /// the element tree through `CFSelector` paths after each group of folding steps.
    #[test]
    fn test_fold_if_else_once() {
        // Scenario 1: five blocks built with the `jump_block`/`branch_block`/`ret_block`
        // test helpers; block 1 is the only one created with `branch_block`.
        let function_definition = FunctionDefinition {
            header: ir::FunctionHeader {
                name: "f".to_string(),
                parameters: Vec::new(),
                return_type: data_type::Type::None,
            },
            content: vec![
                jump_block(0, 1),
                branch_block(1, 2, 3),
                jump_block(2, 4),
                jump_block(3, 4),
                ret_block(4),
            ],
        };
        // Normalise the function with the same two passes in every scenario.
        let mut editor = Editor::new(function_definition);
        let pass = FixIrreducible;
        pass.run(&mut editor);
        let pass = TopologicalSort;
        pass.run(&mut editor);
        let function_definition = editor.content;
        let analyzer = Analyzer::new();
        let binded = analyzer.bind(&function_definition);
        let control_flow_graph = binded.control_flow_graph();
        // Start from a flat block that holds one node per basic block.
        let current_result = (0..(function_definition.content.len()))
            .map(ControlFlowElement::new_node)
            .collect_vec();
        let mut content = ControlFlowElement::new_block(current_result);
        // First step: element 1 becomes an `If` whose success arm holds block 2, while the
        // failure arm is still empty.
        fold_if_else_once(&mut content, control_flow_graph);
        assert_matches!(
            content[&CFSelector::from_str("1").unwrap()],
            ControlFlowElement::If { .. }
        );
        assert_matches!(
            content[&CFSelector::from_str("1/success->0").unwrap()],
            ControlFlowElement::BasicBlock { id: 2 }
        );
        assert_eq!(
            content.get(&CFSelector::from_str("1/failure->0").unwrap()),
            None
        );
        // Second step: block 3 lands in the failure arm and block 4 follows the `If`.
        let control_flow_graph = binded.control_flow_graph();
        fold_if_else_once(&mut content, control_flow_graph);
        assert_eq!(
            content[&CFSelector::from_str("1/failure->0").unwrap()],
            ControlFlowElement::BasicBlock { id: 3 }
        );
        assert_eq!(
            content[&CFSelector::from_str("2").unwrap()],
            ControlFlowElement::BasicBlock { id: 4 }
        );
        // Scenario 2: eight blocks; each arm is expected to receive two blocks.
        let function_definition = FunctionDefinition {
            header: ir::FunctionHeader {
                name: "f".to_string(),
                parameters: Vec::new(),
                return_type: data_type::Type::None,
            },
            content: vec![
                jump_block(0, 1),
                branch_block(1, 2, 4),
                jump_block(2, 3),
                jump_block(4, 5),
                jump_block(3, 6),
                jump_block(5, 6),
                jump_block(6, 7),
                ret_block(7),
            ],
        };
        let mut editor = Editor::new(function_definition);
        let pass = FixIrreducible;
        pass.run(&mut editor);
        let pass = TopologicalSort;
        pass.run(&mut editor);
        let function_definition = editor.content;
        let analyzer = Analyzer::new();
        let binded = analyzer.bind(&function_definition);
        let control_flow_graph = binded.control_flow_graph();
        let current_result = (0..(function_definition.content.len()))
            .map(ControlFlowElement::new_node)
            .collect_vec();
        let mut content = ControlFlowElement::new_block(current_result);
        // First step: blocks 2 and 3 move into the success arm in one go.
        fold_if_else_once(&mut content, control_flow_graph);
        assert_matches!(
            content[&CFSelector::from_str("1").unwrap()],
            ControlFlowElement::If { .. }
        );
        assert_matches!(
            content[&CFSelector::from_str("1/success->0").unwrap()],
            ControlFlowElement::BasicBlock { id: 2 }
        );
        assert_matches!(
            content[&CFSelector::from_str("1/success->1").unwrap()],
            ControlFlowElement::BasicBlock { id: 3 }
        );
        assert_eq!(
            content.get(&CFSelector::from_str("1/failure->0").unwrap()),
            None
        );
        // Second step: blocks 4 and 5 move into the failure arm.
        let control_flow_graph = binded.control_flow_graph();
        fold_if_else_once(&mut content, control_flow_graph);
        assert_matches!(
            content[&CFSelector::from_str("1/failure->0").unwrap()],
            ControlFlowElement::BasicBlock { id: 4 }
        );
        assert_matches!(
            content[&CFSelector::from_str("1/failure->1").unwrap()],
            ControlFlowElement::BasicBlock { id: 5 }
        );
        // Scenario 3: a function with two branching blocks; the element tree below is built
        // by hand with a `Loop` already in place instead of starting from a flat block.
        let function_definition = FunctionDefinition {
            header: ir::FunctionHeader {
                name: "f".to_string(),
                parameters: Vec::new(),
                return_type: data_type::Type::None,
            },
            content: vec![
                branch_block(0, 1, 5),
                jump_block(1, 2),
                jump_block(2, 4),
                branch_block(4, 3, 7),
                jump_block(3, 2),
                ret_block(7),
                jump_block(5, 6),
                jump_block(6, 7),
            ],
        };
        let mut editor = Editor::new(function_definition);
        let pass = FixIrreducible;
        pass.run(&mut editor);
        let pass = TopologicalSort;
        pass.run(&mut editor);
        let function_definition = editor.content;
        let analyzer = Analyzer::new();
        let binded = analyzer.bind(&function_definition);
        let control_flow_graph = binded.control_flow_graph();
        let mut content = ControlFlowElement::new_block(vec![
            ControlFlowElement::new_node(0),
            ControlFlowElement::new_node(1),
            ControlFlowElement::Loop {
                content: vec![
                    ControlFlowElement::new_node(2),
                    ControlFlowElement::new_node(4),
                    ControlFlowElement::new_node(3),
                ],
            },
            ControlFlowElement::new_node(5),
            ControlFlowElement::new_node(6),
            ControlFlowElement::new_node(7),
        ]);
        // Three folding steps are applied before anything is checked.
        fold_if_else_once(&mut content, control_flow_graph);
        let control_flow_graph = binded.control_flow_graph();
        fold_if_else_once(&mut content, control_flow_graph);
        let control_flow_graph = binded.control_flow_graph();
        fold_if_else_once(&mut content, control_flow_graph);
        // Expected: an outer `If` on block 0 whose success arm holds block 1 and the loop,
        // and an inner `If` nested inside that loop.
        assert_matches!(
            &content[&CFSelector::from_str("0").unwrap()],
            ControlFlowElement::If { .. }
        );
        assert_matches!(
            &content[&CFSelector::from_str("0/if_condition").unwrap()],
            ControlFlowElement::BasicBlock { id: 0 }
        );
        assert_matches!(
            &content[&CFSelector::from_str("0/success->0").unwrap()],
            ControlFlowElement::BasicBlock { id: 1 }
        );
        assert_matches!(
            &content[&CFSelector::from_str("0/success->1").unwrap()],
            ControlFlowElement::Loop { .. }
        );
        assert_matches!(
            &content[&CFSelector::from_str("0/success->1/0").unwrap()],
            ControlFlowElement::BasicBlock { id: 2 }
        );
        assert_matches!(
            &content[&CFSelector::from_str("0/success->1/1").unwrap()],
            ControlFlowElement::If { .. }
        );
        assert_matches!(
            &content[&CFSelector::from_str("0/success->1/1/if_condition").unwrap()],
            ControlFlowElement::BasicBlock { id: 3 }
        );
        assert_matches!(
            &content[&CFSelector::from_str("0/success->1/1/success->0").unwrap()],
            ControlFlowElement::BasicBlock { id: 4 }
        );
    }
    /// Smoke test for `fold_loop`: builds a seventeen-block function, normalises it with
    /// `FixIrreducible` and `TopologicalSort`, folds its loops and prints the result.
    /// Nothing is asserted; the test only checks that folding does not panic.
    #[test]
    fn test_loop() {
        // One entry per basic block: (block name, the single statement it contains).
        let block_table = vec![
            ("bb0", jump("bb1")),
            ("bb1", branch("bb2", "bb8")),
            ("bb2", jump("bb3")),
            ("bb3", jump("bb4")),
            ("bb4", jump("bb5")),
            ("bb5", branch("bb6", "bb7")),
            ("bb6", jump("bb3")),
            ("bb7", branch("bb2", "bb15")),
            ("bb8", branch("bb9", "bb10")),
            ("bb9", jump("bb11")),
            ("bb11", jump("bb12")),
            ("bb12", jump("bb13")),
            ("bb10", jump("bb12")),
            ("bb13", jump("bb14")),
            ("bb14", branch("bb12", "bb15")),
            ("bb15", jump("bb16")),
            ("bb16", Ret { value: None }.into()),
        ];
        let function_definition = FunctionDefinition {
            header: ir::FunctionHeader {
                name: "f".to_string(),
                parameters: Vec::new(),
                return_type: data_type::Type::None,
            },
            content: block_table
                .into_iter()
                .map(|(name, terminator)| BasicBlock {
                    name: Some(name.to_string()),
                    content: vec![terminator],
                })
                .collect(),
        };

        let mut editor = Editor::new(function_definition);
        FixIrreducible.run(&mut editor);
        TopologicalSort.run(&mut editor);
        let function_definition = editor.content;

        let analyzer = Analyzer::new();
        let bound_analyzer = analyzer.bind(&function_definition);
        let graph = bound_analyzer.control_flow_graph();
        let top_scc = graph.top_level_scc();
        let block_count = function_definition.content.len();
        let mut current_result: Vec<_> = (0..block_count)
            .map(ControlFlowElement::new_node)
            .collect();
        fold_loop(&top_scc, &mut current_result);
        dbg!(current_result);
    }
    /// Smoke test for `fold`: runs the whole folding pipeline on a six-block function and
    /// prints the outcome. Nothing is asserted.
    #[test]
    fn test_fold_all() {
        let header = ir::FunctionHeader {
            name: "f".to_string(),
            parameters: Vec::new(),
            return_type: data_type::Type::None,
        };
        let blocks = vec![
            jump_block(0, 1),
            jump_block(1, 2),
            jump_block(2, 3),
            branch_block(3, 4, 1),
            branch_block(4, 1, 5),
            ret_block(5),
        ];
        let function_definition = FunctionDefinition {
            header,
            content: blocks,
        };
        let result = fold(&function_definition);
        dbg!(result);
    }
    /// Smoke test: parses a straight-line function, lowers it with `generate_function`,
    /// assembles a module around it and writes the bytes to `./test.wasm`.
    /// Nothing is asserted about the produced bytes.
    ///
    /// NOTE(review): the other `test_generate_function*` tests write to the same path.
    #[test]
    fn test_generate_function() {
        let source = "fn test_code(i32 %a, i32 %b) -> i32 {
  test_code_entry:
    %0 = add i32 %a, 2
    %2 = add i32 %b, 1
    %4 = add i32 %0, %2
    ret %4
}
";
        let function = ir::function::parse(source).unwrap().1;
        let root = ControlFlowElement::new_block(fold(&function));

        let (mut types, mut functions) = (TypeSection::new(), FunctionSection::new());
        let (mut exports, mut codes) = (ExportSection::new(), CodeSection::new());
        generate_function(
            (&mut types, &mut functions, &mut exports, &mut codes),
            &function,
            &root,
        );

        let mut global_section = GlobalSection::new();
        let global_type = GlobalType {
            val_type: ValType::I32,
            mutable: true,
            shared: false,
        };
        global_section.global(global_type, &ConstExpr::i32_const(0));

        let mut memory_section = MemorySection::new();
        let memory_type = MemoryType {
            minimum: 1,
            maximum: None,
            memory64: false,
            shared: false,
            page_size_log2: None,
        };
        memory_section.memory(memory_type);

        let mut module = Module::new();
        module
            .section(&types)
            .section(&functions)
            .section(&memory_section)
            .section(&global_section)
            .section(&exports)
            .section(&codes);
        let bytes = module.finish();
        File::create("./test.wasm")
            .unwrap()
            .write_all(&bytes)
            .unwrap();
    }
    /// Smoke test: parses a function containing a conditional branch, lowers it with
    /// `generate_function`, assembles a module around it and writes the bytes to
    /// `./test.wasm`. Nothing is asserted about the produced bytes.
    ///
    /// NOTE(review): the other `test_generate_function*` tests write to the same path.
    #[test]
    fn test_generate_function2() {
        let source = "fn test_condition(i32 %a, i32 %b) -> i32 {
  test_condition_entry:
    %a_0_addr = alloca i32
    store i32 %a, address %a_0_addr
    %b_0_addr = alloca i32
    store i32 %b, address %b_0_addr
    %1 = load i32 %a_0_addr
    %2 = load i32 %b_0_addr
    %0 = slt i32 %1, %2
    bne %0, 0, if_0_success, if_0_fail
  if_0_success:
    %3 = load i32 %a_0_addr
    ret %3
  if_0_fail:
    %4 = load i32 %b_0_addr
    ret %4
  if_0_end:
    ret 0
}
";
        let function = ir::function::parse(source).unwrap().1;
        let root = ControlFlowElement::new_block(fold(&function));

        let (mut types, mut functions) = (TypeSection::new(), FunctionSection::new());
        let (mut exports, mut codes) = (ExportSection::new(), CodeSection::new());
        generate_function(
            (&mut types, &mut functions, &mut exports, &mut codes),
            &function,
            &root,
        );

        let mut global_section = GlobalSection::new();
        let global_type = GlobalType {
            val_type: ValType::I32,
            mutable: true,
            shared: false,
        };
        global_section.global(global_type, &ConstExpr::i32_const(0));

        let mut memory_section = MemorySection::new();
        let memory_type = MemoryType {
            minimum: 1,
            maximum: None,
            memory64: false,
            shared: false,
            page_size_log2: None,
        };
        memory_section.memory(memory_type);

        let mut module = Module::new();
        module
            .section(&types)
            .section(&functions)
            .section(&memory_section)
            .section(&global_section)
            .section(&exports)
            .section(&codes);
        let bytes = module.finish();
        File::create("./test.wasm")
            .unwrap()
            .write_all(&bytes)
            .unwrap();
    }
    /// Smoke test: parses a function containing a loop, lowers it with `generate_function`,
    /// assembles a module around it and writes the bytes to `./test.wasm`.
    /// Nothing is asserted about the produced bytes.
    ///
    /// NOTE(review): the other `test_generate_function*` tests write to the same path.
    #[test]
    fn test_generate_function3() {
        let source = "fn test_condition(i32 %a, i32 %b) -> i32 {
  test_condition_entry:
    %a_0_addr = alloca i32
    store i32 %a, address %a_0_addr
    %b_0_addr = alloca i32
    store i32 %b, address %b_0_addr
    %result_0_addr = alloca i32
    store i32 0, address %result_0_addr
    %i_0_addr = alloca i32
    %0 = load i32 %a_0_addr
    store i32 %0, address %i_0_addr
    j loop_0_condition
  loop_0_condition:
    %2 = load i32 %i_0_addr
    %3 = load i32 %b_0_addr
    %1 = slt i32 %2, %3
    bne %1, 0, loop_0_success, loop_0_fail
  loop_0_success:
    %5 = load i32 %result_0_addr
    %6 = load i32 %i_0_addr
    %4 = add i32 %5, %6
    store i32 %4, address %result_0_addr
    %8 = load i32 %i_0_addr
    %7 = add i32 %8, 1
    store i32 %7, address %i_0_addr
    j loop_0_condition
  loop_0_fail:
    %9 = load i32 %result_0_addr
    ret %9
}
";
        let function = ir::function::parse(source).unwrap().1;
        let root = ControlFlowElement::new_block(fold(&function));

        let (mut types, mut functions) = (TypeSection::new(), FunctionSection::new());
        let (mut exports, mut codes) = (ExportSection::new(), CodeSection::new());
        generate_function(
            (&mut types, &mut functions, &mut exports, &mut codes),
            &function,
            &root,
        );

        let mut global_section = GlobalSection::new();
        let global_type = GlobalType {
            val_type: ValType::I32,
            mutable: true,
            shared: false,
        };
        global_section.global(global_type, &ConstExpr::i32_const(0));

        let mut memory_section = MemorySection::new();
        let memory_type = MemoryType {
            minimum: 1,
            maximum: None,
            memory64: false,
            shared: false,
            page_size_log2: None,
        };
        memory_section.memory(memory_type);

        let mut module = Module::new();
        module
            .section(&types)
            .section(&functions)
            .section(&memory_section)
            .section(&global_section)
            .section(&exports)
            .section(&codes);
        let bytes = module.finish();
        File::create("./test.wasm")
            .unwrap()
            .write_all(&bytes)
            .unwrap();
    }
}