import * as d3 from 'd3';

/**
 * Buckets labelled rows into confusion-matrix cells using a per-subgroup decision threshold.
 *
 * A row is predicted positive when `row[model] >= selected_DT[row[split_by]]`. If the row's
 * subgroup has no entry in `selected_DT`, the comparison is against `undefined`, which is
 * always false, so the row is predicted negative.
 *
 * @param {string} model - key of the row field holding the model's score.
 * @param {Object[]} selected_data - rows to classify; each row's `y` field is parsed as a base-10 integer label.
 * @param {Object} selected_DT - map from subgroup value to decision threshold.
 * @param {string} split_by - key of the row field that selects the threshold in `selected_DT`.
 * @returns {{TP: Object[], FN: Object[], FP: Object[], TN: Object[]}} the original row objects
 *   per cell. Rows whose label is neither 0 nor 1 are not placed in any cell. The key order
 *   (TP, FN, FP, TN) is kept stable because callers such as Matrix() lay cells out in
 *   Object.entries order.
 */
export function get_confusionMatrix(model, selected_data, selected_DT, split_by) { // Latest version
    const tp = [];
    const fp = [];
    const tn = [];
    const fn = [];
    selected_data.forEach((element) => {
        const y = Number.parseInt(element["y"], 10);
        const threshold = selected_DT[element[split_by]];
        const y_pred = element[model] >= threshold ? 1 : 0;
        if (y === 1) {
            (y_pred === 1 ? tp : fn).push(element);
        } else if (y === 0) {
            (y_pred === 1 ? fp : tn).push(element);
        }
    });
    return { "TP": tp, "FN": fn, "FP": fp, "TN": tn };
}

/**
 * Buckets labelled rows into confusion-matrix cells using a single decision threshold.
 *
 * @param {string} model - key of the row field holding the model's score.
 * @param {Object[]} selected_data - rows to classify; each row's `y` field is parsed as a base-10 integer label.
 * @param {number} selected_DT - threshold; a score `>= selected_DT` is predicted positive.
 * @returns {{TP: Object[], FN: Object[], FP: Object[], TN: Object[]}} the original row objects
 *   per cell. Rows whose label is neither 0 nor 1 are not placed in any cell. The key order
 *   (TP, FN, FP, TN) is kept stable because callers such as Matrix() lay cells out in
 *   Object.entries order.
 */
export function get_confusionMatrix2(model, selected_data, selected_DT) {
    const tp = [];
    const fp = [];
    const tn = [];
    const fn = [];
    selected_data.forEach((element) => {
        const y = Number.parseInt(element["y"], 10);
        const y_pred = element[model] >= selected_DT ? 1 : 0;
        if (y === 1) {
            (y_pred === 1 ? tp : fn).push(element);
        } else if (y === 0) {
            (y_pred === 1 ? fp : tn).push(element);
        }
    });
    return { "TP": tp, "FN": fn, "FP": fp, "TN": tn };
}


/**
 * Renders the aggregated confusion-matrix view into the existing `.parent_container` SVG, and
 * the model legend into the existing `.legend` SVG.
 *
 * Quadrant layout: top row TPs | FPs, bottom row FNs | TNs. Inside each quadrant every model
 * gets an equal-width column, ordered by descending count. The coloured bar in a column is
 * scaled relative to the largest count in that quadrant.
 *
 * @param {Object} data - keyed by 'TPs', 'FPs', 'FNs', 'TNs'. Each value is an array of
 *   [modelName, rows] pairs, where rows.length is the count shown. The arrays are not modified.
 * @param {string[]} models - model names listed in the legend (also keys of the colour map).
 */
export function Creatematrix(data, models) { // Creates visual matrix
    const model_color = { "Random Forest": '#66c2a5', "Logistic Regression": '#fc8d62', "K Nearest Neighbor": '#8da0cb', "Support Vector Classifier": '#e78ac3', "Decision Tree": '#a6d854', "Naive Bayes": '#ffd92f' };
    const margin = { left: 62, right: 0, top: 25, bottom: 0 };
    const parent_width = 680;
    const parent_height = 200;
    const matrix_container_width = parent_width - margin.left;
    const matrix_container_height = parent_height - margin.top;
    const quadrant_width = matrix_container_width / 2;
    const quadrant_height = matrix_container_height / 2;

    const parent_container = d3.select(".parent_container").attr("width", parent_width).attr('height', parent_height);
    // Column and row captions.
    parent_container.selectAll('.top_pos_neg').data(["Positive", "Negative"]).join('text')
        .attr('class', "top_pos_neg")
        .attr('x', (d, i) => (i * matrix_container_width / 2) + (matrix_container_width / 3))
        .attr('y', 15)
        .text(d => d)
        .attr('font-size', 14);
    parent_container.selectAll('.left_pos_neg').data(["Positive", "Negative"]).join('text')
        .attr('class', "left_pos_neg")
        .attr('x', 0)
        .attr('y', (d, i) => (i * matrix_container_height / 2) + matrix_container_height / 2.4)
        .text(d => d)
        .attr('font-size', 14);

    // The container's height must be its own height, not its width.
    const matrix_container = parent_container.selectAll('.matrix_container').data([0]).join('svg')
        .attr('x', margin.left)
        .attr('y', margin.top)
        .attr('class', 'matrix_container')
        .attr('width', matrix_container_width)
        .attr('height', matrix_container_height);

    matrix_container.selectAll('.svg_quadrants').data(['TPs', 'FPs', 'FNs', 'TNs']).join('svg')
        .attr('class', d => 'svg_quadrants ' + d)
        .attr('x', (d, i) => (i % 2) * (quadrant_width))
        .attr('y', (d, i) => i > 1 ? quadrant_height : 0)
        .each(function (quadrant) {
            // Sort a copy, largest count first, so the caller's arrays keep their order.
            const model_data = [...data[quadrant]].sort((a, b) => b[1].length - a[1].length);
            const model_width = quadrant_width / model_data.length; // equal width per model
            const model_height = quadrant_height;
            const model_max = model_data.length > 0 ? model_data[0][1].length : 0;
            // With an all-zero quadrant, draw zero-width bars instead of dividing by 0 (NaN widths).
            const scale_denominator = model_max || 1;

            d3.select(this).selectAll('.model_svg').data(model_data).join('svg')
                .attr('class', "model_svg")
                .attr('x', (d, i) => i * model_width)
                .attr('width', model_width)
                .attr('height', model_height)
                .each(function (model_entry) {
                    const model_svg = d3.select(this);
                    // Coloured bar, length proportional to this model's count.
                    model_svg.selectAll('.model_bground').data([model_entry]).join('rect')
                        .attr('class', "model_bground")
                        .attr('x', 0)
                        .attr('width', d => (model_width * d[1].length) / scale_denominator)
                        .attr('height', model_height)
                        .attr('fill', d => model_color[d[0]]);
                    // Model name.
                    model_svg.selectAll('.title_text').data([model_entry[0]]).join('text')
                        .attr('class', 'title_text')
                        .text(d => d)
                        .attr('y', quadrant_height / 2.1)
                        .attr('x', '50%')
                        .attr('font-size', 11)
                        .attr('text-anchor', 'middle');
                    // Count for this quadrant.
                    model_svg.selectAll('.number_text').data([model_entry[1]]).join('text')
                        .attr('class', 'number_text')
                        .text(d => d.length)
                        .attr('y', quadrant_height / 1.55)
                        .attr('x', '50%')
                        .attr('font-size', 11)
                        .attr('text-anchor', 'middle');
                    // Thin border around the model's column.
                    model_svg.selectAll('.model_border').data([0]).join('rect')
                        .attr('class', "model_border")
                        .attr('width', model_width - 1)
                        .attr('height', model_height - 1)
                        .attr('fill', 'none')
                        .attr('stroke', 'grey')
                        .attr('stroke-width', 0.1);
                });
        });

    //------------ Legend: one colour swatch and label per model
    const l_width = 18;
    const l_height = 45;
    const legend_container = d3.select('.legend').attr('width', 140).attr('height', l_height * models.length);
    legend_container.selectAll('rect').data(models).join('rect')
        .attr('width', l_width)
        .attr('height', l_height)
        .attr('y', (d, i) => i * l_height)
        .attr('fill', d => model_color[d]);
    legend_container.selectAll('text').data(models).join('text')
        .attr('x', l_width + 2)
        .attr('y', (d, i) => i * l_height + l_height / 1.5)
        .text(d => d)
        .attr('font-size', 12);

    //------------ Outer border
    matrix_container.selectAll('.bord_rect').data([0]).join('rect')
        .attr('class', 'bord_rect')
        .attr('x', 0)
        .attr('y', 0)
        .attr('width', matrix_container_width)
        .attr('height', matrix_container_height)
        .attr('stroke-width', 4)
        .attr('stroke', 'gray')
        .attr('fill', 'none');
    // Horizontal and vertical dividers between quadrants
    matrix_container.selectAll('.horizontal_line').data([0]).join('line')
        .attr('class', 'horizontal_line')
        .attr('x1', 0)
        .attr('x2', matrix_container_width)
        .attr('y1', matrix_container_height / 2)
        .attr('y2', matrix_container_height / 2)
        .attr('stroke', 'gray')
        .attr('stroke-width', 2);
    matrix_container.selectAll('.vertical_line').data([0]).join('line')
        .attr('class', 'vertical_line')
        .attr('x1', matrix_container_width / 2 + 1)
        .attr('x2', matrix_container_width / 2 + 1)
        .attr('y1', 0)
        .attr('y2', matrix_container_height)
        .attr('stroke', 'gray')
        .attr('stroke-width', 2);
}





/**
 * Renders the per-subgroup confusion-matrix view into the existing `.parent_container` SVG, and
 * the model legend into the existing `.legend` SVG.
 *
 * Quadrant layout: top row TPs | FPs, bottom row FNs | TNs. Each quadrant has one equal-width
 * column per model, ordered by descending count. Each column is split into one horizontal band
 * per subgroup, also ordered by descending count. A band's bar width is scaled to the model's
 * total relative to the quadrant maximum. Its height is scaled to the subgroup's count relative
 * to the largest subgroup of that model.
 *
 * @param {Object} data - keyed by 'TPs', 'FPs', 'FNs', 'TNs'. Each value is an array of
 *   [modelName, rows] pairs. The arrays are not modified.
 * @param {string[]} models - model names listed in the legend (also keys of the colour map).
 * @param {string} split_by - row field whose value selects the subgroup.
 * @param {Array} subgroups - subgroup values to draw; the chart grows 60px per subgroup.
 */
export function CreatematrixCat(data, models, split_by, subgroups) { // Creates visual matrix for categories
    const model_color = { "Random Forest": '#66c2a5', "Logistic Regression": '#fc8d62', "K Nearest Neighbor": '#8da0cb', "Support Vector Classifier": '#e78ac3', "Decision Tree": '#a6d854', "Naive Bayes": '#ffd92f' };
    const margin = { left: 60, right: 0, top: 25, bottom: 0 };
    const parent_width = 740;
    // At least 200px tall, 60px per subgroup when there are more.
    const parent_height = Math.max(200, 60 * subgroups.length);
    const matrix_container_width = parent_width - margin.left;
    const matrix_container_height = parent_height - margin.top;
    const quadrant_width = matrix_container_width / 2;
    const quadrant_height = matrix_container_height / 2;

    const parent_container = d3.select(".parent_container").attr("width", parent_width).attr('height', parent_height);
    // Column and row captions.
    parent_container.selectAll('.top_pos_neg').data(["Positive", "Negative"]).join('text')
        .attr('class', "top_pos_neg")
        .attr('x', (d, i) => (i * matrix_container_width / 2) + (matrix_container_width / 3))
        .attr('y', 15)
        .text(d => d)
        .attr('font-size', 15);
    parent_container.selectAll('.left_pos_neg').data(["Positive", "Negative"]).join('text')
        .attr('class', "left_pos_neg")
        .attr('x', 0)
        .attr('y', (d, i) => (i * matrix_container_height / 2) + matrix_container_height / 2.4)
        .text(d => d)
        .attr('font-size', 15);

    // Height must be the container's height. Using its width clipped the matrix once
    // 60 * subgroups.length exceeded the width (12 or more subgroups).
    const matrix_container = parent_container.selectAll('.matrix_container').data([0]).join('svg')
        .attr('x', margin.left)
        .attr('y', margin.top)
        .attr('class', 'matrix_container')
        .attr('width', matrix_container_width)
        .attr('height', matrix_container_height);

    matrix_container.selectAll('.svg_quadrants').data(['TPs', 'FPs', 'FNs', 'TNs']).join('svg')
        .attr('class', d => 'svg_quadrants ' + d)
        .attr('x', (d, i) => (i % 2) * (quadrant_width))
        .attr('y', (d, i) => i > 1 ? quadrant_height : 0)
        .each(function (quadrant) {
            // Sort a copy, largest count first, so the caller's arrays keep their order.
            const model_data = [...data[quadrant]].sort((a, b) => b[1].length - a[1].length);
            const model_width = quadrant_width / model_data.length; // equal width per model
            const model_height = quadrant_height;
            const model_max = model_data.length > 0 ? model_data[0][1].length : 0;

            d3.select(this).selectAll('.model_svg').data(model_data).join('svg')
                .attr('class', "model_svg")
                .attr('x', (d, i) => i * model_width)
                .attr('width', model_width)
                .attr('height', model_height)
                .each(function (model_entry) {
                    const [model_name, model_rows] = model_entry;
                    // One [subgroup, rows] pair per requested subgroup, largest first.
                    // Loose == is kept on purpose: subgroup values may be strings while row
                    // values are numbers (or vice versa).
                    const subgroup_data = subgroups
                        // eslint-disable-next-line eqeqeq
                        .map(subgroup => [subgroup, model_rows.filter(item => item[split_by] == subgroup)])
                        .sort((a, b) => b[1].length - a[1].length);
                    const subgroup_height = model_height / subgroup_data.length;
                    const subgroup_max = subgroup_data.length > 0 ? subgroup_data[0][1].length : 0;
                    // Fall back to 1 when a maximum is 0 to avoid NaN sizes (0 / 0).
                    const bar_width = (model_width * model_rows.length) / (model_max || 1);
                    const bar_fill = model_color[model_name];

                    d3.select(this).selectAll('.subgroup_svg').data(subgroup_data).join('svg')
                        .attr('class', "subgroup_svg")
                        .attr('x', 0)
                        .attr('y', (d, i) => i * model_height / subgroup_data.length)
                        .attr('height', subgroup_height)
                        .attr('width', model_width)
                        .each(function (subgroup_entry) {
                            const subgroup_svg = d3.select(this);
                            // Clamp at 0: an empty subgroup would otherwise get height -0.2,
                            // which is invalid for an SVG rect.
                            const band_height = Math.max(0, (subgroup_height * subgroup_entry[1].length) / (subgroup_max || 1) - .2);
                            subgroup_svg.selectAll('.subgroup_bground').data([0]).join('rect')
                                .attr('class', "subgroup_bground")
                                .attr('x', 0)
                                .attr('width', bar_width)
                                .attr('height', band_height)
                                .attr('fill', bar_fill);
                            // Separator line along the top edge of the band.
                            subgroup_svg.selectAll('.subgroup_border').data([0]).join('line')
                                .attr('class', "subgroup_border")
                                .attr('x1', 0)
                                .attr('x2', model_width)
                                .attr('fill', 'none')
                                .attr('stroke', '#a5a4a4')
                                .attr('stroke-width', 2);
                            // Subgroup name.
                            subgroup_svg.selectAll('.add_title').data([subgroup_entry]).join('text')
                                .attr('class', 'add_title')
                                .attr('x', "50%")
                                .attr('y', 10)
                                .text(d => d[0])
                                .attr('font-size', 10)
                                .attr('text-anchor', 'middle');
                            // Subgroup count.
                            subgroup_svg.selectAll('.add_number_text').data([subgroup_entry]).join('text')
                                .attr('class', 'add_number_text')
                                .attr('x', "50%")
                                .attr('y', 20)
                                .text(d => d[1].length)
                                .attr('font-size', 10)
                                .attr('text-anchor', 'middle');
                        });

                    // Thin border around the model's column.
                    d3.select(this).selectAll('.model_border').data([0]).join('rect')
                        .attr('class', "model_border")
                        .attr('width', model_width - 1)
                        .attr('height', model_height - 1)
                        .attr('fill', 'none')
                        .attr('stroke', 'grey')
                        .attr('stroke-width', 0.1);
                });
        });

    //------------ Legend: one colour swatch and label per model
    const l_width = 25;
    const l_height = 45;
    const legend_container = d3.select('.legend').attr('width', 200).attr('height', l_height * models.length);
    legend_container.selectAll('rect').data(models).join('rect')
        .attr('width', l_width)
        .attr('height', l_height)
        .attr('y', (d, i) => i * l_height)
        .attr('fill', d => model_color[d]);
    legend_container.selectAll('text').data(models).join('text')
        .attr('x', l_width + 5)
        .attr('y', (d, i) => i * l_height + l_height / 1.5)
        .text(d => d)
        .attr('font-size', 14);

    //------------ Outer border
    matrix_container.selectAll('.bord_rect').data([0]).join('rect')
        .attr('class', 'bord_rect')
        .attr('x', 0)
        .attr('y', 0)
        .attr('width', matrix_container_width)
        .attr('height', matrix_container_height)
        .attr('stroke-width', 4)
        .attr('stroke', 'gray')
        .attr('fill', 'none');
    // Horizontal and vertical dividers between quadrants
    matrix_container.selectAll('.horizontal_line').data([0]).join('line')
        .attr('class', 'horizontal_line')
        .attr('x1', 0)
        .attr('x2', matrix_container_width)
        .attr('y1', matrix_container_height / 2)
        .attr('y2', matrix_container_height / 2)
        .attr('stroke', 'gray')
        .attr('stroke-width', 2);
    matrix_container.selectAll('.vertical_line').data([0]).join('line')
        .attr('class', 'vertical_line')
        .attr('x1', matrix_container_width / 2 + 1)
        .attr('x2', matrix_container_width / 2 + 1)
        .attr('y1', 0)
        .attr('y2', matrix_container_height)
        .attr('stroke', 'gray')
        .attr('stroke-width', 2);
}

/**
 * Draws a 2x2 confusion-matrix heatmap into the element whose id is `options.container`.
 *
 * Cells are laid out in Object.entries(options.data) order: entry 0 top-left, 1 top-right,
 * 2 bottom-left, 3 bottom-right. Each cell is filled by linearly interpolating between
 * start_color and end_color over the range of the cell counts, and is labelled with its count.
 * Clicking a cell marks it with the `cell_clicked` class. It also dims every `.scat` element to
 * opacity 0.02 and restores opacity 1 on `.scat<index>` for each row in that cell.
 *
 * @param {Object} options - { data, container, start_color, end_color, index }.
 *   `data` maps a cell key to an array of rows (each row with an `index` field).
 *   Row captions are drawn only when `index % 4 === 0`.
 * @param {*} tSne_data - currently unused (kept for interface compatibility).
 * @param {Function} Set_scatter_plot_data - currently unused (kept for interface compatibility).
 */
export function Matrix(options, tSne_data, Set_scatter_plot_data) {
    const margin = { left: 70, right: 0, top: 25, bottom: 0 };
    const cell_width = 80;
    const cell_height = 40;
    const width = cell_width * 2 + margin.left + margin.right;
    const height = cell_height * 2 + margin.top + margin.bottom;
    const {
        data,
        container,
        start_color: startColor,
        end_color: endColor,
        index,
    } = options;

    const data_arr = Object.entries(data); // [key, rows] pairs in the object's key order
    const counts = data_arr.map(([, rows]) => rows.length);
    const colorMap = d3.scaleLinear().domain([d3.min(counts), d3.max(counts)]).range([startColor, endColor]);

    const parent = d3.select("#" + container).attr("width", width).attr("height", height).attr("class", 'parent');

    const svg1 = parent.selectAll('.svg1').data([0]).join('svg')
        .attr("class", 'svg1')
        .attr("width", width)
        .attr("height", height)
        .attr('x', margin.left)
        .attr('y', margin.top);

    svg1.selectAll('svg').data(data_arr).join('svg')
        .attr('class', d => container + d[0])
        .attr('width', cell_width)
        .attr('height', cell_height)
        .attr('x', (d, i) => (i % 2) * cell_width)
        .attr('y', (d, i) => i > 1 ? cell_height : 0)
        .selectAll('.cell').data(d => [d]).join('rect')
        .attr('class', 'cell')
        .on("click", function () {
            // Read the datum from the element itself. The listener's first argument is the
            // datum in d3 v5 but the DOM event in d3 v6+.
            const d = d3.select(this).datum();
            d3.selectAll(".cell").classed('cell_clicked', false);
            d3.select(this).classed('cell_clicked', true);
            d3.selectAll(".scat").attr('opacity', 0.02);
            d[1].forEach(item => d3.selectAll(".scat" + item['index']).attr('opacity', 1));
            // Filtering the scatter plot through Set_scatter_plot_data(tSne_data subset) is disabled.
        })
        .each(function (d) {
            // Count label, placed in the cell's own <svg> (the rect's parent).
            d3.select(this.parentNode).selectAll('text').data([d[1].length]).join('text')
                .attr('x', cell_width / 4)
                .attr('y', cell_height / 1.8)
                .text(count => count);
        })
        .style('fill', d => colorMap(d[1].length))
        .attr('width', cell_width)
        .attr('height', cell_height)
        .attr('x', (d, i) => (i % 2) * cell_width)
        .attr('y', (d, i) => i > 1 ? cell_height : 0);

    //---------------------------------------------------------------- column captions
    parent.selectAll(".col_pos").data([0]).join('text').attr('class', 'col_pos').attr('x', margin.left).attr('y', 12).text("Positive");
    parent.selectAll(".col_neg").data([0]).join('text').attr('class', 'col_neg').attr('x', cell_width + margin.left + 7).attr('y', 12).text("Negative");
    //---------------------------------------------------------------- row captions (only on every 4th matrix)
    if (index % 4 === 0) {
        parent.selectAll(".row_pos").data([0]).join('text').attr('class', 'row_pos').attr('x', 8).attr('y', (cell_height / 1.8) + margin.top).text("Positive");
        parent.selectAll(".row_neg").data([0]).join('text').attr('class', 'row_neg').attr('x', 0).attr('y', ((cell_height * 2) / 1.3) + margin.top).text("Negative");
    }
}



/**
 * Draws a 2x2 confusion-matrix heatmap (100x100 cells) into the element whose id is
 * `options.container`. Each cell also gets one small dot per row.
 *
 * Cells are laid out in Object.entries(options.data) order: entry 0 top-left, 1 top-right,
 * 2 bottom-left, 3 bottom-right. Each cell is filled by linearly interpolating between
 * start_color and end_color over the range of the cell counts, and is labelled with its count.
 * Clicking a cell dims every `.scat` element to opacity 0.02 and restores opacity 1 on
 * `.scat<index>` for each row in that cell.
 *
 * @param {Object} options - { data, container, start_color, end_color, index }.
 *   `data` maps a cell key to an array of rows (each row with an `index` field).
 *   Row captions are drawn only when `index % 3 === 0`.
 * @param {*} tSne_data - currently unused (kept for interface compatibility).
 * @param {Function} Set_scatter_plot_data - currently unused (kept for interface compatibility).
 */
export function MatrixCat(options, tSne_data, Set_scatter_plot_data) {
    const margin = { left: 70, right: 0, top: 25, bottom: 0 };
    const cell_width = 100;
    const cell_height = 100;
    const width = cell_width * 2 + margin.left + margin.right;
    const height = cell_height * 2 + margin.top + margin.bottom;
    const {
        data,
        container,
        start_color: startColor,
        end_color: endColor,
        index,
    } = options;

    const data_arr = Object.entries(data); // [key, rows] pairs in the object's key order
    const counts = data_arr.map(([, rows]) => rows.length);
    const colorMap = d3.scaleLinear().domain([d3.min(counts), d3.max(counts)]).range([startColor, endColor]);

    const parent = d3.select("#" + container).attr("width", width).attr("height", height).attr("class", 'parent');

    const svg1 = parent.selectAll('.svg1').data([0]).join('svg')
        .attr("class", 'svg1')
        .attr("width", width)
        .attr("height", height)
        .attr('x', margin.left)
        .attr('y', margin.top);

    svg1.selectAll('svg').data(data_arr).join('svg')
        .attr('class', d => container + d[0])
        .attr('width', cell_width)
        .attr('height', cell_height)
        .attr('x', (d, i) => (i % 2) * cell_width)
        .attr('y', (d, i) => i > 1 ? cell_height : 0)
        .selectAll('.cell').data(d => [d]).join('rect')
        .attr('class', 'cell')
        .on("click", function () {
            // Read the datum from the element itself. The listener's first argument is the
            // datum in d3 v5 but the DOM event in d3 v6+.
            const d = d3.select(this).datum();
            d3.selectAll(".scat").attr('opacity', 0.02);
            d[1].forEach(item => d3.selectAll(".scat" + item['index']).attr('opacity', 1));
            // Filtering the scatter plot through Set_scatter_plot_data(tSne_data subset) is disabled.
        })
        .each(function (d) {
            // Decorations go into the cell's own <svg> (the rect's parent).
            const cell_svg = d3.select(this.parentNode);
            cell_svg.selectAll('text').data([d[1]]).join('text')
                .attr('x', cell_width / 4)
                .attr('y', cell_height / 1.8)
                .text(rows => rows.length);
            // One dot per row, stepped 10px along the diagonal.
            cell_svg.selectAll('circle').data(d[1]).join('circle')
                .style("stroke", "gray")
                .style("fill", "black")
                .attr("r", 2)
                .attr("cx", (row, i) => i * 10)
                .attr("cy", (row, i) => i * 10);
        })
        .style('fill', d => colorMap(d[1].length))
        .attr('width', cell_width)
        .attr('height', cell_height)
        .attr('x', (d, i) => (i % 2) * cell_width)
        .attr('y', (d, i) => i > 1 ? cell_height : 0);

    //---------------------------------------------------------------- column captions
    parent.selectAll(".col_pos").data([0]).join('text').attr('class', 'col_pos').attr('x', margin.left).attr('y', 12).text("Positive");
    parent.selectAll(".col_neg").data([0]).join('text').attr('class', 'col_neg').attr('x', cell_width + margin.left + 7).attr('y', 12).text("Negative");
    //---------------------------------------------------------------- row captions (only on every 3rd matrix)
    if (index % 3 === 0) {
        parent.selectAll(".row_pos").data([0]).join('text').attr('class', 'row_pos').attr('x', 8).attr('y', (cell_height / 1.8) + margin.top).text("Positive");
        parent.selectAll(".row_neg").data([0]).join('text').attr('class', 'row_neg').attr('x', 0).attr('y', ((cell_height * 2) / 1.3) + margin.top).text("Negative");
    }
}
