import sys
import pandas as pd

from snakemake.utils import Paramspace


# setup workflow
# Load the workflow configuration; the `config` dictionary used throughout
# this file is filled from it.
configfile: "config/config.yaml"


# NOTE: `subworkflow_run` is accessed with [], so it must be present in the
# configuration (a missing key raises KeyError).
if not config["subworkflow_run"]:

    # do not install package multiple times when run as subworkflow
    # package installation is needed for multiprocessing
    # The command calls devtools::install() without a path argument and with
    # dependency installation switched off — presumably it installs the R
    # package located in the current working directory; confirm.
    onstart:
        shell("Rscript -e 'devtools::install(quick = TRUE, dependencies = FALSE)'")


# method selection
# An explicit list in the configuration takes precedence. A null entry means
# "discover the methods": every `{method}.R` file matching the pattern below
# contributes its `{method}` wildcard value.
method_list = config["method_list"]
if method_list is None:
    discovered = glob_wildcards(
        workflow.source_path("../resources/method_definitions/{method}.R")
    )
    method_list = discovered.method

# report the selected methods on stderr
print(method_list, file=sys.stderr)


# setup paramspace
# Base parameter combinations are read from the CSV at config["params_path"];
# lines starting with "#" in that file are treated as comments.
df_params = pd.read_csv(config["params_path"], comment="#")

# Cross every row collected so far with each additional parameter listed
# under config["parameters"], so the final table is the Cartesian product of
# the CSV rows and all configured value lists.
# `or {}` tolerates an empty `parameters:` key, which YAML loads as None.
for param, values in (config["parameters"] or {}).items():
    # a bare scalar is treated as a one-element value list
    if not isinstance(values, (list, tuple)):
        values = [values]

    # how="cross" builds the product directly; no helper join column is
    # needed, so a parameter or CSV column named "tmp_col" cannot clash.
    df_params = df_params.merge(pd.DataFrame({param: values}), how="cross")

# filename_params="*" is forwarded to Paramspace unchanged
# (presumably: encode all parameters in the file-name part of the pattern —
# see the Snakemake Paramspace documentation).
paramspace = Paramspace(df_params, filename_params="*")


# rule definitions
# Target rule: requests one plot directory per parameter-space instance plus
# the directory of aggregated plots.
rule all:
    input:
        # one entry per element of paramspace.instance_patterns
        expand("results/plots/{params}/", params=paramspace.instance_patterns),
        "results/aggregated_plots/",


# Writes the term database CSV for one {termsource} by running
# scripts/create_term_database.R. The rule has no input files.
rule create_term_database:
    output:
        fname="results/terms/{termsource}/term_database.csv",
    params:
        # taken verbatim from the configuration and handed to the script
        term_filter=config["term_filter"],
    resources:
        mem_mb=15_000,
    script:
        "scripts/create_term_database.R"


# Derives the term similarity CSV and a directory of plots for one
# {termsource}/{similaritymeasure} pair from that term source's database.
rule compute_term_similarities:
    input:
        fname_terms="results/terms/{termsource}/term_database.csv",
    output:
        fname="results/terms/{termsource}/{similaritymeasure}/term_similarities.csv",
        plotdir=directory("results/terms/{termsource}/{similaritymeasure}/plots/"),
    params:
        # NOTE(review): this rule only has the wildcards {termsource} and
        # {similaritymeasure}, so the instance presumably contains at most
        # those two entries — confirm against the script's expectations.
        params=paramspace.instance,
    resources:
        mem_mb=15_000,
    script:
        "scripts/compute_term_similarities.R"


# Creates one synthetic study (an RDS file plus a plot directory) for each
# parameter-space instance and {replicate}, using the term database and the
# term similarities as inputs.
rule create_synthetic_study:
    input:
        # NOTE(review): {termsource} and {similaritymeasure} must be resolvable
        # from the output wildcards, i.e. presumably they are columns of the
        # parameter space — confirm.
        fname_terms="results/terms/{termsource}/term_database.csv",
        fname_sim="results/terms/{termsource}/{similaritymeasure}/term_similarities.csv",
    output:
        # f-strings: the paramspace pattern is inserted at parse time, while
        # the doubled braces leave a literal {replicate} wildcard behind.
        fname_rds=f"results/studies/{paramspace.wildcard_pattern}/replicates/{{replicate}}/study.rds",
        plotdir=directory(
            f"results/studies/{paramspace.wildcard_pattern}/replicates/{{replicate}}/plots/"
        ),
    params:
        params=paramspace.instance,
    resources:
        mem_mb=5_000,
    script:
        "scripts/create_synthetic_study.R"


# Runs one enrichment {method} on one synthetic study replicate and writes
# its result CSV. The method's R script is declared as an input and is also
# the script that gets executed (via params.script_name).
rule run_method:
    input:
        # NOTE(review): uses srcdir() whereas the rest of this file resolves
        # workflow-relative paths with workflow.source_path() — confirm this
        # difference is intended.
        script=srcdir("../resources/method_definitions/{method}.R"),
        fname_study=f"results/studies/{paramspace.wildcard_pattern}/replicates/{{replicate}}/study.rds",
        fname_terms="results/terms/{termsource}/term_database.csv",
        fname_term_sim="results/terms/{termsource}/{similaritymeasure}/term_similarities.csv",
    output:
        fname=f"results/enrichments/{paramspace.wildcard_pattern}/replicates/{{replicate}}/{{method}}/enrichment_result.csv",
    params:
        params=paramspace.instance,
        # forwards the resolved method script path so `script:` can refer to it
        script_name=lambda wildcards, input: input.script,
        setup_code_fname=workflow.source_path("scripts/model_setup_code.R"),
    # methods whose name contains "cv" get 32 threads, all others 2
    threads: lambda wildcards: (32 if "cv" in wildcards.method else 2)
    resources:
        # NOTE(review): "cv" methods request less memory (1_000) than the
        # others (10_000) although they use more threads — presumably the
        # value is interpreted per core by the execution profile; confirm.
        mem_mb=lambda wildcards: (1_000 if "cv" in wildcards.method else 10_000),
        # Base limit: 24 h for "cv" methods, 4 h otherwise. The trailing
        # `* attempt` is part of the lambda body, so the limit is multiplied
        # by the attempt number on every retry.
        time_min=lambda wildcards, attempt: (
            60 * 24 if "cv" in wildcards.method else 60 * 4
        )
        * attempt,
    script:
        "{params.script_name}"


# Collects the enrichment results of all replicates and methods of one
# parameter-space instance into a single CSV.
rule aggregate_results:
    input:
        # One enrichment result per (replicate, method) combination.
        # allow_missing=True keeps the paramspace wildcards unexpanded.
        fname_list_enr=expand(
            f"results/enrichments/{paramspace.wildcard_pattern}/replicates/{{replicate}}/{{method}}/enrichment_result.csv",
            replicate=range(config["replicate_count"]),
            method=method_list,
            allow_missing=True,
        ),
        # Each study path is repeated once per method (the inner loop variable
        # is unused), giving this list the same length as fname_list_enr.
        # NOTE(review): presumably this keeps both lists index-aligned for the
        # script — confirm against scripts/aggregate_results.R.
        fname_list_study=[
            f"results/studies/{paramspace.wildcard_pattern}/replicates/{replicate}/study.rds"
            for replicate in range(config["replicate_count"])
            for _ in method_list
        ],
    output:
        fname=f"results/enrichments/{paramspace.wildcard_pattern}/all_enr.csv",
    script:
        "scripts/aggregate_results.R"


# Produces the plot directory for one parameter-space instance from its
# aggregated enrichment CSV.
rule visualize_results:
    input:
        fname_enr=f"results/enrichments/{paramspace.wildcard_pattern}/all_enr.csv",
    output:
        # the whole directory is the declared output
        outdir=directory(f"results/plots/{paramspace.wildcard_pattern}/"),
    resources:
        mem_mb=5_000,
    script:
        "scripts/visualize_results.R"


# Produces plots across all parameter-space instances, together with a CSV
# named roc_aucs.csv, from the per-instance aggregated enrichment CSVs.
rule aggregated_visualizations:
    input:
        # one aggregated CSV per element of paramspace.instance_patterns
        fname_list=expand(
            "results/enrichments/{params}/all_enr.csv",
            params=paramspace.instance_patterns,
        ),
    output:
        outdir=directory("results/aggregated_plots/"),
        # NOTE(review): this file lives inside the directory output declared
        # just above — confirm Snakemake accepts a nested output of the same
        # rule in the version in use.
        fname_aucs="results/aggregated_plots/roc_aucs.csv",
    resources:
        mem_mb=20_000,
    script:
        "scripts/aggregated_visualization.R"
