import os

import pandas as pd

from snakemake.utils import Paramspace


# Workflow settings. The code visible in this file reads the keys 'studies',
# 'pathway_list' and 'pathway_constraints' from the loaded config.
configfile: 'config/config.yaml'
# Parameter space read from config/params.csv (pandas ignores everything after
# a '#' on each line). Its wildcard pattern is embedded in the output paths of
# the parameterised rules below.
paramspace = Paramspace(pd.read_csv('config/params.csv', comment='#'))


# gather input configuration
#
# Every perturbation case is described by a counts file named
#   resources/data/<study>/Counts_<gene>_<treatment>.csv
# Files whose name contains 'Ctrl' are the matching controls and are skipped.
# Only studies listed in config['studies'] are considered.
#
# os.scandir yields entries in arbitrary, filesystem-dependent order, so both
# levels are sorted by name to make case_list (and everything derived from its
# order, e.g. the row order of aggregated CSVs) reproducible across machines.
case_list = []
all_perturbed_genes = set()
with os.scandir('resources/data/') as study_entries:
    for study_entry in sorted(study_entries, key=lambda entry: entry.name):
        study = study_entry.name

        if study not in config['studies']:
            continue

        with os.scandir(study_entry) as counts_entries:
            for counts_entry in sorted(counts_entries, key=lambda entry: entry.name):
                if 'Ctrl' in counts_entry.name:
                    continue

                # 'Counts_<gene>_<treatment>.csv' -> ('Counts', gene, treatment);
                # any other number of '_'-separated fields raises ValueError.
                _, gene, treatment = counts_entry.name.split('.')[0].split('_')

                case_list.append((study, gene, treatment))
                all_perturbed_genes.add(gene)


# helper functions
def aggregate_pathway_files(wildcards):
    """Input function: list every pathway CSV produced by `retrieve_pathways`.

    Requesting the checkpoint output makes Snakemake re-evaluate this function
    once the checkpoint has finished, at which point the files under
    ``<pathway_dir>/csv_files/`` can be discovered by globbing.

    Args:
        wildcards: Wildcards of the requesting job (forwarded to the checkpoint).

    Returns:
        Paths of all ``csv_files/<pathway>.csv`` files found in the checkpoint
        output directory.
    """
    pathway_dir = checkpoints.retrieve_pathways.get(**wildcards).output.pathway_dir
    csv_pattern = os.path.join(pathway_dir, 'csv_files/{pathway}.csv')

    discovered = glob_wildcards(csv_pattern)
    return expand(csv_pattern, pathway=discovered.pathway)


def aggregate_dce_results(wildcards):
    """Input function: list the DCE result directories to aggregate.

    Used as the input of rule `aggregate_dce_lists`. Both checkpoints
    (`retrieve_pathways`, `pathway_statistics`) are requested so that this
    function is only fully evaluated after their outputs exist.

    Pathway selection:
      * if config['pathway_list'] is set, exactly those pathways are used;
      * if it is None (or the key is absent), every pathway listed in
        ``pathway_statistics.csv`` that satisfies the node/edge size limits in
        config['pathway_constraints'] is used.

    A (pathway, case) combination is included only if at least one of the
    case's perturbed genes (comma-separated in the gene field) occurs as a
    'source' or 'sink' node in the pathway CSV.

    Args:
        wildcards: Wildcards of the requesting job; they carry the paramspace
            values that select the parameter sub-directory.

    Returns:
        list[str]: Directories of the form
        ``results/dce_results/<pathway>/<gene>/<study>/<treatment>/<params>/``.
    """
    checkpoint_output = checkpoints.retrieve_pathways.get(**wildcards).output.pathway_dir
    pw_stats_dir = checkpoints.pathway_statistics.get(**wildcards).output.out_dir

    pathway_template = os.path.join(checkpoint_output, 'csv_files/{pathway}.csv')

    pathway_list = config.get('pathway_list')
    if pathway_list is None:
        # The statistics table is only needed when no explicit list is given.
        pw_stats_fname = os.path.join(pw_stats_dir, 'pathway_statistics.csv')
        pw_stats = (pd.read_csv(pw_stats_fname)
                      .set_index('pathway')
                      .to_dict(orient='index'))

        constraints = config['pathway_constraints']
        min_nodes = constraints['min_pathway_node_size']
        max_edges = constraints['max_pathway_edge_size']
        pathway_list = [
            pw for pw, stats in pw_stats.items()
            if stats['pathway_node_size'] >= min_nodes
            and stats['pathway_edge_size'] <= max_edges
        ]

    # The parameter sub-path depends only on the wildcards, so build it once
    # instead of once per (pathway, case) pair.
    # NOTE(review): Paramspace.instance appears to return pandas Series values
    # in some Snakemake releases and plain scalars in others; both are
    # unwrapped here -- confirm against the pinned Snakemake version.
    # TODO choose correct one from wildcards
    param = paramspace.wildcard_pattern.format(
        **{k: (v.iloc[0] if hasattr(v, 'iloc') else v)
           for k, v in paramspace.instance(wildcards).items()}
    )

    # Split the comma-separated gene field once per case rather than once per
    # pathway.
    cases = [
        (study, gene, treatment, frozenset(gene.split(',')))
        for study, gene, treatment in case_list
    ]

    file_list = []
    for pathway in pathway_list:
        df = pd.read_csv(pathway_template.format(pathway=pathway))
        pathway_nodes = set(df['source']) | set(df['sink'])

        for study, gene, treatment, perturbed in cases:
            # check if perturbed genes occur in pathway
            if not perturbed.isdisjoint(pathway_nodes):
                file_list.append(
                    f'results/dce_results/{pathway}/{gene}/{study}/{treatment}/{param}/'
                )

    return file_list


# rules
# The rules named here run on the host executing Snakemake instead of being
# submitted as cluster/cloud jobs.
localrules: all, aggregate_expression_statistics, aggregate_dce_lists


# Default target: requests the pathway statistics, the performance plots and
# aggregated-statistics details for every parameter instance, and the
# cross-parameter performance comparison. Plots and comparison are requested
# for the thresholds 0.1, 0.05 and 0.01.
# The expression-statistics target is currently disabled (commented out).
rule all:
    input:
        'results/pathway_statistics/',
        # 'results/perturbation_analysis/expression_statistics.csv',
        expand(
            'results/performance/plots/{params}/threshold_{threshold}/',
            params=paramspace.instance_patterns,
            threshold=[0.1, 0.05, 0.01]
        ),
        expand(
            'results/aggregated_statistics/{params}/details/',
            params=paramspace.instance_patterns
        ),
        expand(
            'results/performance/comparison/threshold_{threshold}/',
            threshold=[0.1, 0.05, 0.01]
        )


# Runs scripts/retrieve_pathways.R to populate results/pathways/.
# Declared as a checkpoint because the set of files it writes is only known
# after it has run; aggregate_pathway_files and aggregate_dce_results look for
# 'csv_files/<pathway>.csv' inside this directory.
checkpoint retrieve_pathways:
    output:
        pathway_dir = directory('results/pathways/')
    resources:
        mem_mb = 5_000
    script:
        'scripts/retrieve_pathways.R'


# Runs scripts/pathway_statistics.R on all retrieved pathway CSVs and writes
# its results to results/pathway_statistics/. aggregate_dce_results reads
# 'pathway_statistics.csv' from that directory, hence the checkpoint.
# NOTE(review): the control counts file is hard-coded to the 'epistasis' study
# with treatment '1' -- confirm this is intended for all configured studies.
# NOTE(review): perturbed_genes is a Python set, so the order in which the
# script receives the genes is not guaranteed -- verify the script does not
# depend on it.
checkpoint pathway_statistics:
    input:
        count_wt_file = 'resources/data/epistasis/Counts_Ctrl_1.csv',
        graph_files = aggregate_pathway_files
    output:
        out_dir = directory('results/pathway_statistics/')
    params:
        perturbed_genes = all_perturbed_genes
    resources:
        mem_mb = 10_000
    script:
        'scripts/pathway_statistics.R'


# Runs scripts/analyze_single_perturbed_gene.R for one (study, gene, treatment)
# case, given the control counts and the perturbed counts of the same study and
# treatment. Output is a directory; aggregate_expression_statistics expects an
# 'expression_stats.csv' inside it.
rule analyze_single_perturbed_gene:
    input:
        count_wt_file = 'resources/data/{study}/Counts_Ctrl_{treatment}.csv',
        count_mt_file = 'resources/data/{study}/Counts_{gene}_{treatment}.csv'
    output:
        out_dir = directory('results/perturbation_analysis/{study}/{gene}/{treatment}/')
    resources:
        mem_mb = 10_000
    script:
        'scripts/analyze_single_perturbed_gene.R'


# Concatenates the 'expression_stats.csv' of every case in case_list into one
# table, adding 'study' and 'treatment' columns taken from the directory path.
# Raises ValueError if the first 'perturbed.gene' entry of a table does not
# match the gene encoded in its directory path.
rule aggregate_expression_statistics:
    input:
        dir_list = [f'results/perturbation_analysis/{study}/{gene}/{treatment}/'
                    for study, gene, treatment in case_list]
    output:
        fname = 'results/perturbation_analysis/expression_statistics.csv'
    run:
        import pandas as pd

        df_list = []
        for dname in input.dir_list:
            tmp = pd.read_csv(os.path.join(dname, 'expression_stats.csv'))

            # dname is results/perturbation_analysis/<study>/<gene>/<treatment>
            # (optionally followed by an empty component from a trailing '/').
            _, _, study, gene, treatment, *_ = dname.split('/')

            # Explicit check instead of `assert`: it is not stripped under -O
            # and reports which directory is inconsistent.
            found_gene = tmp['perturbed.gene'].iloc[0]
            if found_gene != gene:
                raise ValueError(
                    f'{dname}: expected perturbed gene {gene!r}, '
                    f'found {found_gene!r}'
                )

            tmp['study'] = study
            tmp['treatment'] = treatment

            df_list.append(tmp)

        pd.concat(df_list).to_csv(output.fname, index=False)


# Runs scripts/compute_dces.R for one pathway, one perturbation case and one
# parameter instance. The output directory ends in the paramspace wildcard
# pattern, so each parameter combination gets its own sub-directory.
# `paramspace.instance` is passed as a callable; Snakemake evaluates it with
# the job's wildcards and hands the result to the script as
# params.computation.
rule compute_dces:
    input:
        count_wt_file = 'resources/data/{study}/Counts_Ctrl_{treatment}.csv',
        count_mt_file = 'resources/data/{study}/Counts_{gene}_{treatment}.csv',
        graph_file = 'results/pathways/csv_files/{pathway}.csv'
    output:
        out_dir = directory(f'results/dce_results/{{pathway}}/{{gene}}/{{study}}/{{treatment}}/{paramspace.wildcard_pattern}/')
    params:
        computation = paramspace.instance
    resources:
        mem_mb = 10_000
    script:
        'scripts/compute_dces.R'


# Concatenates the per-case DCE tables ('dce_list_<pathway>_<gene>.csv') of one
# parameter instance into a single CSV, adding 'study', 'treatment',
# 'perturbed_gene' and 'pathway' columns taken from the directory path.
rule aggregate_dce_lists:
    input:
        dir_list = aggregate_dce_results
    output:
        fname = f'results/aggregated_statistics/{paramspace.wildcard_pattern}/dce_lists.csv'
    run:
        import pandas as pd
        from tqdm import tqdm

        df_list = []
        for dir_ in tqdm(input.dir_list):
            # Directories look like
            #   results/dce_results/<pathway>/<gene>/<study>/<treatment>/<params...>
            # Read the fields relative to the 'dce_results' component rather
            # than from the right-hand end: the parameter part may span
            # several path components (one per paramspace column) and a
            # trailing '/' would otherwise shift every field by one.
            parts = dir_.split('/')
            anchor = parts.index('dce_results')
            pathway, gene, study, treatment = parts[anchor + 1:anchor + 5]

            fname = os.path.join(dir_, f'dce_list_{pathway}_{gene}.csv')

            tmp = pd.read_csv(fname)
            tmp['study'] = study
            tmp['treatment'] = treatment
            tmp['perturbed_gene'] = gene
            tmp['pathway'] = pathway

            df_list.append(tmp)

        df = pd.concat(df_list)
        df.to_csv(output.fname, index=False)


# Runs scripts/aggregated_statistics.R on the aggregated DCE list of one
# parameter instance and writes its results to the sibling 'details/'
# directory.
rule aggregated_statistics:
    input:
        fname = f'results/aggregated_statistics/{paramspace.wildcard_pattern}/dce_lists.csv'
    output:
        out_dir = directory(f'results/aggregated_statistics/{paramspace.wildcard_pattern}/details/')
    script:
        'scripts/aggregated_statistics.R'


# Runs scripts/compute_performance.py on the aggregated DCE list of one
# parameter instance. {threshold} is a free wildcard, so the rule is executed
# once per requested threshold value.
rule compute_performance:
    input:
        fname = f'results/aggregated_statistics/{paramspace.wildcard_pattern}/dce_lists.csv'
    output:
        out_dir = directory(f'results/performance/results/{paramspace.wildcard_pattern}/threshold_{{threshold}}/')
    script:
        'scripts/compute_performance.py'


# Runs scripts/plot_performance.py on the performance results of one parameter
# instance and threshold, writing to the matching directory under
# results/performance/plots/.
rule plot_performance:
    input:
        dname = f'results/performance/results/{paramspace.wildcard_pattern}/threshold_{{threshold}}/'
    output:
        out_dir = directory(f'results/performance/plots/{paramspace.wildcard_pattern}/threshold_{{threshold}}/')
    script:
        'scripts/plot_performance.py'


# Runs scripts/compare_performance.py on the performance results of ALL
# parameter instances for a single threshold. `allow_missing=True` makes
# expand() fill in only {params} and leave {threshold} as a wildcard of this
# rule.
rule compare_performance:
    input:
        dname_list = expand(
            'results/performance/results/{params}/threshold_{threshold}/',
            params=paramspace.instance_patterns,
            allow_missing=True
        )
    output:
        dname = directory('results/performance/comparison/threshold_{threshold}/')
    script:
        'scripts/compare_performance.py'
