#!/usr/bin/python
"""
Script for running the whole Multi-Dendrix pipeline. Consists of the following
steps:
  1. Runs Multi-Dendrix for every combination of the number of gene sets (t)
     and the maximum gene set size (k_max) given on the command line.
  2. Extracts the core modules from all of the resulting collections.
  3. Optionally runs the network and matrix permutation tests (--permute).
  4. Optionally runs the (sub)type analysis (--subtypes, which requires a
     patient whitelist).
  5. Outputs the results as text and as a website.
"""
# Load required modules and add the lib to the path
import sys, os
# NOTE(review): './lib' is resolved against the current working directory, not
# against this script's location, so the script presumably has to be launched
# from the directory that contains lib/ — confirm.
sys.path.insert(1, os.path.abspath('./lib'))
from output_functions import *
def parse_args(input_list=None):
    """Build the pipeline's command-line parser and parse the arguments.

    **Parameters:**
      * **input_list** (*list of strings*) - arguments to parse. When ``None``
        (the default) argparse reads ``sys.argv[1:]``; any list, including an
        empty one, is parsed as given.

    **Returns:**
      An ``Args`` object with one attribute per option defined below.
    """
    import argparse
    # Plain attribute container that argparse fills in via ``namespace=``.
    class Args: pass
    args = Args()
    description = 'Runs Multi-Dendrix for a set of parameters. Evaluates the '\
                  'results and outputs them as text and as a website.'
    parser = argparse.ArgumentParser(description=description)
    # General options
    parser.add_argument('-o', '--output_dir', required=True,
                        help='Name of output directory.')
    parser.add_argument('-v', '--verbose', default=False, action='store_true',
                        help='Flag verbose mode.')
    # Options for Multi-Dendrix
    parser.add_argument('-k_min', '--min_gene_set_size', required=True, type=int,
                        help='Minimum gene set size.')
    parser.add_argument('-k_max', '--max_gene_set_size', required=True, type=int,
                        help='Maximum gene set size.')
    parser.add_argument('-t_min', '--min_num_gene_sets', required=True, type=int,
                        help='Minimum number of gene sets.')
    parser.add_argument('-t_max', '--max_num_gene_sets', required=True, type=int,
                        help='Maximum number of gene sets.')
    parser.add_argument('-n', '--db_name', required=True,
                        help='Name of mutation data for use in output.')
    parser.add_argument('-m', '--mutation_matrix', required=True,
                        help='File name for mutation data.')
    parser.add_argument('-c', '--cutoff', type=int, default=0,
                        help='Minimum gene mutation frequency.')
    parser.add_argument('-p', '--patient_whitelist', default=None,
                        help='File of patients to be included.')
    parser.add_argument('-bp', '--patient_blacklist', default=None,
                        help='File of patients to be excluded.')
    parser.add_argument('-g', '--gene_whitelist', default=None,
                        help='File of genes to be included.')
    parser.add_argument('-bg', '--gene_blacklist', default=None,
                        help='File of genes to be excluded.')
    parser.add_argument('-a', '--alpha', type=float, default=1.0,
                        help='Parameter that changes weight function W by '
                             'weighting the penalty of coverage overlap.')
    parser.add_argument('--delta', type=int, default=0,
                        help='Number of overlaps allowed per gene set.')
    parser.add_argument('--lmbda', type=int, default=1,
                        help='Number of gene sets a gene can be a member of.')

    # Options for core modules
    # NOTE(review): the help text describes a proportion, but the value is
    # parsed as an int — confirm what core_modules.extract_core_modules expects.
    parser.add_argument('--stability_threshold', type=int, default=1,
                        help='Minimum proportion of gene sets two genes must '
                             'both be a member of to be connected in the '
                             'core modules.')
    # Options for (sub)type analysis
    parser.add_argument('--subtypes', default=False, action='store_true',
                        help='Perform (sub)type analysis.')
    parser.add_argument('--subtype_sig_threshold', default=0.05, type=float,
                        help='Significance threshold for subtype association '
                             '(use 1.0 to report all associations).')
    # Options for permutation tests
    parser.add_argument('--permute', default=False, action='store_true',
                        help='Perform permutation test.')
    parser.add_argument('--network_edgelist', default=None,
                        help='PPI edgelist location.')
    parser.add_argument('--num_permuted_networks', default=5, type=int,
                        help='The number of permuted networks to create '
                             '(only if a directory of permuted networks '
                             'is not provided).')
    parser.add_argument('--permuted_networks_dir', default=None,
                        help='Directory of permuted networks.')
    parser.add_argument('--distance', default=False, action='store_true',
                        help='Flag average pairwise distance test.')
    parser.add_argument('--Q', default=100, type=int,
                        help='Multiplier of edge swaps for permuting networks.')
    parser.add_argument('--permuted_matrices_dir', default=None,
                        help='Directory of permuted matrices.')
    parser.add_argument('--num_permuted_matrices', default=5, type=int,
                        help='The number of permuted matrices to create '
                             '(only if a directory of permuted matrices '
                             'is not provided).')
    # argparse itself falls back to sys.argv[1:] only when given None, so an
    # explicitly supplied (possibly empty) list is parsed as-is rather than
    # silently reading the real command line instead.
    parser.parse_args(input_list, namespace=args)
    return args
 
def batch_multi_dendrix(args):
    """Runs Multi-Dendrix for each parameter setting on the input
    mutation data.

    Every combination of t in [``args.min_num_gene_sets``,
    ``args.max_num_gene_sets``] and k_max in [``args.min_gene_set_size``,
    ``args.max_gene_set_size``] (both inclusive) is run, always with
    ``args.min_gene_set_size`` as the minimum gene set size.

    **Returns:**
      A tuple containing the following:
        * **collections** (*dictionary*) - mapping of t -> k -> output of Multi-Dendrix, stored as a ``(gene_sets, weights)`` pair
        * **mutation_data** (*tuple*) - mutation data tuple (see :func:`multi_dendrix.multi_dendrix` for details).
        * **runtime** (*float*) - total runtime (in seconds) of Multi-Dendrix on all the parameter settings
    """
    # Import required modules
    import multi_dendrix as Multi
    from time import time

    start = time()
    # Load (and white/blacklist-filter) the mutation data once; every run
    # below shares it.
    gene2include, sample2include = Multi.white_and_blacklisting(
        args.patient_whitelist, args.patient_blacklist,
        args.gene_whitelist, args.gene_blacklist)
    mutation_data = Multi.load_mutation_data_w_cutoff(
        args.mutation_matrix, sample2include, gene2include, args.cutoff)

    # Run Multi-Dendrix for the range of parameters
    ts = range(args.min_num_gene_sets, args.max_num_gene_sets + 1)
    ks = range(args.min_gene_set_size, args.max_gene_set_size + 1)
    collections = {t: {} for t in ts}
    for t in ts:
        for k_max in ks:
            collection_w_weights = Multi.multi_dendrix(
                mutation_data, t, args.min_gene_set_size, k_max,
                args.alpha, args.delta, args.lmbda, verbose=args.verbose)
            # Transpose the (gene_set, weight) pairs into (gene_sets, weights).
            # list() keeps the result a reusable sequence: on Python 3 zip is a
            # one-shot iterator, and callers unpack each entry more than once.
            collections[t][k_max] = list(zip(*collection_w_weights))
    return collections, mutation_data, time() - start
 
def run_network_permutation_test(args, collections, core_modules):
    """Runs the direct interactions or average pairwise distance test
    on each of the collections and the core_modules.

    The original network is loaded from ``args.network_edgelist``. The
    permuted networks are either loaded from every file in
    ``args.permuted_networks_dir`` or generated
    (``args.num_permuted_networks`` of them, each using ``args.Q`` as the
    edge-swap multiplier). ``args.distance`` is passed through to
    :func:`network_tests.evaluate_collection` to select the test.

    **Returns**:
      * **evaluation** (*dictionary*) - a mapping of t -> k -> the network evaluation tuple of each collection (see :func:`network_tests.evaluate_collection` for details), plus the key ``"core_modules"`` holding the evaluation of the core modules
    """
    from permute_ppi_network import load_network, permute_network
    from network_tests import evaluate_collection

    # Load original network and load / generate the permuted networks
    G = load_network(args.network_edgelist)
    if args.permuted_networks_dir:
        # os.listdir order is arbitrary; sort so runs are reproducible.
        network_files = [os.path.join(args.permuted_networks_dir, fh)
                         for fh in sorted(os.listdir(args.permuted_networks_dir))]
        Hs = [load_network(H) for H in network_files]
    else:
        Hs = [permute_network(G, args.Q)
              for _ in range(args.num_permuted_networks)]

    # Perform network test on every collection, then on the core modules
    evaluation = {}
    for t in collections:
        evaluation[t] = {}
        for k_max in collections[t]:
            gene_sets, _weights = collections[t][k_max]
            evaluation[t][k_max] = evaluate_collection(gene_sets, G, Hs,
                                                       args.distance)
    evaluation["core_modules"] = evaluate_collection(core_modules, G, Hs,
                                                     args.distance)
    return evaluation
 
def run_matrix_permutation_test(args, collections, mutation_data):
    """Runs the matrix permutation test on each of the collections.

    The permuted mutation matrices are either loaded from
    ``args.permuted_matrices_dir`` or generated from ``mutation_data``
    (``args.num_permuted_matrices`` of them). For every (t, k_max) the summed
    weight of the collection, the permuted matrices and the Multi-Dendrix
    parameters (t, ``args.min_gene_set_size``, k_max, ``args.alpha``,
    ``args.delta``, ``args.lmbda``) are passed to
    :func:`matrix_permutation_test.matrix_permutation_test`.

    **Returns**:
      * **evaluation** (*dictionary*) - a mapping of t -> k -> the value returned by ``matrix_permutation_test`` for that collection (presumably its p-value)
    """
    import permute_mutation_data as Permut
    from matrix_permutation_test import load_permuted_matrices, matrix_permutation_test

    # Load / generate the permuted mutation matrices
    if args.permuted_matrices_dir:
        permuted_matrices = load_permuted_matrices(args.permuted_matrices_dir)
    else:
        _m, _n, genes, patients, G2T, T2G = mutation_data
        G = Permut.construct_mutation_graph(G2T, T2G)
        Hs = [Permut.permute_mutation_data(G, genes, patients)
              for _ in range(args.num_permuted_matrices)]
        permuted_matrices = [Permut.graph_to_mutation_data(H) for H in Hs]

    # Perform the matrix permutation test for every parameter setting
    evaluation = {}
    for t in collections:
        evaluation[t] = {}
        for k_max in collections[t]:
            _gene_sets, weights = collections[t][k_max]
            evaluation[t][k_max] = matrix_permutation_test(
                sum(weights), permuted_matrices, t, args.min_gene_set_size,
                k_max, args.alpha, args.delta, args.lmbda)
    return evaluation
 
def permutation_tests(args, collections, core_modules, mutation_data):
    """Run both permutation tests on the Multi-Dendrix output.

    The matrix permutation test runs first, followed by the network
    permutation test (which additionally evaluates ``core_modules``).

    **Returns:**
      A ``(network_results, matrix_results)`` tuple.
    """
    matrix_eval = run_matrix_permutation_test(args, collections, mutation_data)
    network_eval = run_network_permutation_test(args, collections, core_modules)
    return (network_eval, matrix_eval)
def flatten_collections(collections):
    """Flatten a t -> k_max -> (collection, weights) mapping (as returned by
    :func:`batch_multi_dendrix`) into a single list of collections, in
    iteration order, discarding the weights."""
    flattened = []
    for by_kmax in collections.values():
        flattened.extend(coll for coll, _weights in by_kmax.values())
    return flattened
def run(args):
    """Runs the whole Multi-Dendrix pipeline for the given command-line arguments.

    Runs Multi-Dendrix for every parameter setting and extracts the core
    modules. It then optionally runs the permutation tests (``args.permute``)
    and the (sub)type analysis (``args.subtypes``, which also needs
    ``args.patient_whitelist``). Finally it writes the results as text and
    HTML.
    """
    # Run Multi-Dendrix for all parameter settings
    collections, mutation_data, runtime = batch_multi_dendrix(args)

    # Extract the stable modules
    import core_modules as Core
    all_collections = flatten_collections(collections)
    core_modules, module_graph = Core.extract_core_modules(
        all_collections, args.stability_threshold)

    # Perform the permutation test (if required). In both branches
    # evaluation is a (network_results, matrix_results) pair.
    if args.permute:
        evaluation = permutation_tests(args, collections, core_modules,
                                       mutation_data)
    else:
        evaluation = None, None

    # Perform subtype analysis (if required); the (sub)types come from the
    # patient whitelist, so the analysis is skipped without one.
    if args.subtypes and args.patient_whitelist:
        from subtype_specific_genes import subtype_analysis, create_subtype_tbl
        gene2specificity = subtype_analysis(mutation_data,
                                            args.patient_whitelist,
                                            args.subtype_sig_threshold)
        subtype_tbl = create_subtype_tbl(gene2specificity)
    else:
        gene2specificity, subtype_tbl = None, None
        if args.subtypes:
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print('No patient whitelist (w/ (sub)types) provided, '
                  'skipping (sub)type analysis.')

    # Create tables used for text and/or html output
    collection_tbls = create_collection_tbls(args, collections,
                                             core_modules, evaluation)
    params_tbl = create_params_tbl(args, mutation_data)
    if args.permute:
        network_results, matrix_results = evaluation
        network_tbl = create_network_results_tbl(network_results)
        matrix_tbl = create_matrix_results_tbl(matrix_results)
    else:
        network_tbl, matrix_tbl = None, None

    # Output results to text and html
    output_to_text(args, collection_tbls, runtime, params_tbl,
                   network_tbl, matrix_tbl, subtype_tbl)
    output_to_html(args, collections, runtime, module_graph, evaluation,
                   params_tbl, subtype_tbl, gene2specificity)
 
if __name__ == "__main__":
    # Entry point when executed as a script: parse sys.argv and run everything.
    run(parse_args())