import os
import pandas as pd


# Workflow settings. Code further down in this file reads
# config["target_tissue"] and config["comparison_tissue"] from it.
configfile: "config/config.yaml"


# `all` only collects the final targets, so it is declared a local rule
# (Snakemake executes local rules on the host running the workflow instead
# of submitting them as cluster/cloud jobs).
localrules:
    all,


# Default target rule: requests the two final figures. They are declared as
# outputs of create_dce_plots (gtex_dce.png) and create_plots
# (gtex_combined.png) further down in this file.
rule all:
    input:
        "results/plots/gtex_dce.png",
        "results/plots/gtex_combined.png",

# Download the GTEx v8 inputs into results/raw_data and unpack them:
#   * GENCODE v26 gene annotation (GTF),
#   * gene TPM expression matrix (downloaded gzipped, then gunzipped),
#   * sample attribute annotations,
#   * sQTL covariates (downloaded as tar.gz, extracted, archive removed).
#
# Declared as a checkpoint because covariates_file() and valid_tissues() list
# the contents of `covariates_dir`, which are only known after extraction.
checkpoint download_input_data:
    output:
        gencode="results/raw_data/gencode.v26.GRCh38.genes.gtf",
        all_expressions_file="results/raw_data/GTEx_Analysis_2017-06-05_v8_RNASeQCv1.1.9_gene_tpm.gct",
        annotations_file="results/raw_data/GTEx_Analysis_v8_Annotations_SampleAttributesDS.txt",
        covariates_dir=directory("results/raw_data/GTEx_Analysis_v8_sQTL_covariates/"),
    # The two archives are intermediate files, not declared outputs, so a copy
    # left behind by an interrupted run is never cleaned up automatically.
    # wget would then store the fresh download under a ".1" suffix and the
    # stale (possibly truncated) archive would be unpacked instead. They are
    # therefore removed before downloading.
    shell:
        """
        cd results/raw_data
        rm -f GTEx_Analysis_v8_sQTL_covariates.tar.gz GTEx_Analysis_2017-06-05_v8_RNASeQCv1.1.9_gene_tpm.gct.gz

        wget \
            https://storage.googleapis.com/gtex_analysis_v8/reference/gencode.v26.GRCh38.genes.gtf \
            https://storage.googleapis.com/gtex_analysis_v8/single_tissue_qtl_data/GTEx_Analysis_v8_sQTL_covariates.tar.gz \
            https://storage.googleapis.com/gtex_analysis_v8/rna_seq_data/GTEx_Analysis_2017-06-05_v8_RNASeQCv1.1.9_gene_tpm.gct.gz \
            https://storage.googleapis.com/gtex_analysis_v8/annotations/GTEx_Analysis_v8_Annotations_SampleAttributesDS.txt

        gunzip GTEx_Analysis_2017-06-05_v8_RNASeQCv1.1.9_gene_tpm.gct.gz
        tar -xvf GTEx_Analysis_v8_sQTL_covariates.tar.gz
        rm GTEx_Analysis_v8_sQTL_covariates.tar.gz
        """

def gencode_file(wildcards):
    """Input function: path of the GENCODE annotation from the download checkpoint.

    Args:
        wildcards: Wildcards of the requesting job (not used).

    Returns:
        The ``gencode`` output of the ``download_input_data`` checkpoint.
    """
    download_job = checkpoints.download_input_data.get()
    downloaded = download_job.output
    return downloaded.gencode

def all_expressions_file(wildcards):
    """Input function: path of the full expression matrix from the download checkpoint.

    Args:
        wildcards: Wildcards of the requesting job (not used).

    Returns:
        The ``all_expressions_file`` output of the ``download_input_data``
        checkpoint.
    """
    download_job = checkpoints.download_input_data.get()
    downloaded = download_job.output
    return downloaded.all_expressions_file

def annotations_file(wildcards):
    """Input function: path of the sample annotations from the download checkpoint.

    Args:
        wildcards: Wildcards of the requesting job (not used).

    Returns:
        The ``annotations_file`` output of the ``download_input_data``
        checkpoint.
    """
    download_job = checkpoints.download_input_data.get()
    downloaded = download_job.output
    return downloaded.annotations_file
    
def covariates_file(wildcards):
    """Input function: locate the covariates file for ``wildcards.tissue``.

    Lists the ``covariates_dir`` output of the ``download_input_data``
    checkpoint and picks the entry whose name up to the first "." equals the
    requested tissue.

    Args:
        wildcards: Wildcards of the requesting job; must provide ``tissue``.

    Returns:
        The directory joined with the name of the matching entry.

    Raises:
        ValueError: If no entry matches the tissue. (Previously the function
            fell off the end and returned None, hiding the real cause.)
    """
    file_directory = checkpoints.download_input_data.get().output.covariates_dir
    # Sorted so the pick is reproducible should several entries share a prefix.
    for file_name in sorted(os.listdir(file_directory)):
        if wildcards.tissue == file_name.split(".")[0]:
            return os.path.join(file_directory, file_name)
    # ValueError rather than FileNotFoundError: the latter is an OSError and
    # may be given special treatment by callers that expect missing inputs.
    raise ValueError(
        "No covariates file for tissue '{tissue}' found in {directory}".format(
            tissue=wildcards.tissue, directory=file_directory
        )
    )

# Run scripts/prepare_expressions.R on the downloaded expression matrix and
# sample annotations; the script is responsible for filling `expressions_dir`.
# Presumably it writes one file per tissue named "<tissue>.<suffix>", since
# that is what expressions_file() searches for -- confirm against the script.
#
# Declared as a checkpoint because expressions_file() lists the directory
# contents, which are only known once this step has run.
checkpoint process_expressions:
    input:
        all_expressions_file=all_expressions_file,
        annotations_file=annotations_file,
    output:
        expressions_dir=directory(
            "results/raw_data/expression_matrices/"
        ),
    script:
        "scripts/prepare_expressions.R"

def expressions_file(wildcards):
    """Input function: locate the expression matrix for ``wildcards.tissue``.

    Lists the ``expressions_dir`` output of the ``process_expressions``
    checkpoint and picks the entry whose name up to the first "." equals the
    requested tissue.

    Args:
        wildcards: Wildcards of the requesting job; must provide ``tissue``.

    Returns:
        The directory joined with the name of the matching entry.

    Raises:
        ValueError: If no entry matches the tissue. (Previously the function
            fell off the end and returned None, hiding the real cause.)
    """
    file_directory = checkpoints.process_expressions.get().output.expressions_dir
    # Sorted so the pick is reproducible should several entries share a prefix.
    for file_name in sorted(os.listdir(file_directory)):
        if wildcards.tissue == file_name.split(".")[0]:
            return os.path.join(file_directory, file_name)
    # ValueError rather than FileNotFoundError: the latter is an OSError and
    # may be given special treatment by callers that expect missing inputs.
    raise ValueError(
        "No expression matrix for tissue '{tissue}' found in {directory}".format(
            tissue=wildcards.tissue, directory=file_directory
        )
    )

# Turn the downloaded GENCODE annotation (resolved through gencode_file())
# into results/processed_data/encodings.Rdata using
# scripts/prepare_gene_encodings.R.
rule process_encodings:
    input:
        gencode=gencode_file,
    output:
        "results/processed_data/encodings.Rdata",
    script:
        "scripts/prepare_gene_encodings.R"

# Build the per-tissue dataset results/processed_data/{tissue}.Rdata with
# scripts/prepare_dataset.R.
rule get_dataset:
    input:
        # Note: despite its name, this input is the processed encodings file
        # written by process_encodings, not the raw GENCODE GTF.
        gencode="results/processed_data/encodings.Rdata",
        # Both are resolved per `tissue` wildcard by listing checkpoint
        # output directories (see expressions_file / covariates_file).
        expressions=expressions_file,
        covariates=covariates_file,
    output:
        "results/processed_data/{tissue}.Rdata",
    script:
        "scripts/prepare_dataset.R"

# Compare two tissues with scripts/compare_tissues.R. The script receives the
# processed datasets of {tissue1} and {tissue2} as positional inputs, in that
# order, and writes results/output/{tissue1}#{tissue2}.Rdata.
# Within this file the requested targets always use config["target_tissue"]
# as tissue1 (see valid_tissues() and create_dce_plots).
rule compare_tissues:
    input:
        "results/processed_data/{tissue1}.Rdata",
        "results/processed_data/{tissue2}.Rdata",
    output:
        "results/output/{tissue1}#{tissue2}.Rdata",
    script:
        "scripts/compare_tissues.R"


def valid_tissues(wildcards):
    """Input function: comparison results for every tissue eligible for analysis.

    Scans the ``covariates_dir`` output of the ``download_input_data``
    checkpoint. A tissue (file name up to the first ".") is kept when its
    covariates table has at least 200 columns and it is neither
    ``config["target_tissue"]`` nor one of "Testis" / "Prostate". Entries
    whose name starts with "." are ignored.

    Args:
        wildcards: Wildcards of the requesting job (not used).

    Returns:
        List of ``results/output/<target_tissue>#<tissue>.Rdata`` paths, one
        per kept tissue, ordered by covariates file name.
    """
    # Local import: only the diagnostic message below needs it.
    import sys

    file_directory = checkpoints.download_input_data.get().output.covariates_dir
    excluded_tissues = (config["target_tissue"], "Testis", "Prostate")

    tissues = []
    # Sorted so the resulting input order does not depend on the file system.
    for file_name in sorted(os.listdir(file_directory)):
        if file_name.startswith("."):
            continue
        tissue = file_name.split(".")[0]
        if tissue in excluded_tissues:
            continue
        # nrows=0 parses the header only; the column count is all that is
        # needed, and this function may be evaluated repeatedly.
        header = pd.read_csv(
            os.path.join(file_directory, file_name), sep="\t", nrows=0
        )
        if len(header.columns) >= 200:
            tissues.append(tissue)

    # Diagnostics go to stderr so that machine-readable Snakemake output on
    # stdout (e.g. `--dag` piped into dot) is not corrupted.
    print(
        "{num} valid tissues: ".format(num=len(tissues)),
        tissues,
        file=sys.stderr,
    )
    return expand(
        "results/output/" + config["target_tissue"] + "#{tissue}.Rdata",
        tissue=tissues,
    )


# Collect the pairwise comparison results selected by valid_tissues() (one
# results/output/<target_tissue>#<tissue>.Rdata per eligible tissue) and let
# scripts/get_statistics.R write results/output/statistics.Rdata from them.
rule aggregate_results:
    input:
        valid_tissues,
    output:
        "results/output/statistics.Rdata",
    script:
        "scripts/get_statistics.R"


# Render the summary figures from the aggregated statistics with
# scripts/plot_results.R. Only gtex_combined.png is requested by rule `all`;
# the other two files are declared as outputs of the same script run.
rule create_plots:
    input:
        "results/output/statistics.Rdata",
    output:
        "results/plots/gtex_mse.png",
        "results/plots/gtex_correlation.png",
        "results/plots/gtex_combined.png",
    script:
        "scripts/plot_results.R"


# Plot results/plots/gtex_dce.png with scripts/plot_dce.R for one specific
# tissue pair taken from the configuration.
rule create_dce_plots:
    input:
        # Evaluates to
        # results/output/<target_tissue>#<comparison_tissue>.Rdata, i.e. an
        # output of compare_tissues with both wildcards fixed by the config.
        "results/output/"
        + config["target_tissue"]
        + "#"
        + config["comparison_tissue"]
        + ".Rdata",
    output:
        "results/plots/gtex_dce.png",
    script:
        "scripts/plot_dce.R"
