#!/usr/bin/python
# Load required modules
import sys
def parse_args(input_list=None):
    """Parses the options for extracting core modules.

    :type input_list: list of strings
    :param input_list: arguments to parse in place of the command line. When
        ``None`` (the default) the arguments are read from ``sys.argv``.
    :returns: **args** (*argparse.Namespace*) - parsed options, with the
        attributes ``stability_threshold`` (default 1; converted to *float*
        when given) and ``input_collections`` (*list* of file names).
    """
    import argparse

    description = 'Extracts core modules from input collections of gene sets.'\
                  ' Core modules are defined by how often genes appear in the'\
                  ' same gene set together.'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--stability_threshold', type=float, default=1,
                        help='Minimum proportion of pathways two genes must '\
                             'both be a member of to be connected in the '\
                             'core modules.')
    parser.add_argument('-i', '--input_collections', nargs="*", required=True,
                        help='Input files that are collections of gene sets '\
                              'output by Multi-Dendrix')

    # argparse itself falls back to the command line when given None, so the
    # list is passed through untouched. Testing its truthiness instead would
    # make an explicit empty list silently read sys.argv.
    return parser.parse_args(input_list)
def load_collection(collection_file):
    """Extracts the gene sets from a collection file output by Multi-Dendrix.

    Every line of the file is stripped of trailing whitespace and split on
    tabs; the first field (the weight) is discarded and the remaining fields
    form one gene set.

    :type collection_file: string
    :param collection_file: file location of collection to be loaded.
    :returns: **collection** (*list*) - the gene sets (*without* weights) in the input collection file.
    """
    # Read inside a context manager so the file handle is always closed.
    with open(collection_file) as handle:
        return [line.rstrip().split("\t")[1:] for line in handle]
def extract_core_modules(collections, stability_threshold):
    """Extracts the core modules from a set of collections output by Multi-Dendrix.

    Core modules are defined by how often genes appear in the same gene set
    together for different parameters of Multi-Dendrix ("how often" is tunable
    using the stability threshold).

    :type collections: dict of dicts of (collection, weights) pairs, or list of lists of lists
    :param collections: multiple collections of gene sets identified by
        Multi-Dendrix. Either a two-level mapping whose innermost values are
        ``(collection, weights)`` pairs (the weights are ignored), or a plain
        list of collections, where each collection is a list of gene sets.
    :type stability_threshold: int
    :param stability_threshold: two genes are connected only if the number of
        gene sets that contain both of them is strictly greater than this
        value.
    :returns: **(modules, G)** - the connected components of *G* that hold
        more than one gene, and *G* itself: the networkx graph whose nodes are
        all genes seen and whose edges store the co-occurrence count as
        ``weight``.
    """
    # Load required modules
    import networkx as nx
    from itertools import combinations

    # Gather every collection into one flat list, whichever layout was given.
    if hasattr(collections, "keys"):
        all_collections = [collection
                           for by_k_max in collections.values()
                           for collection, _weights in by_k_max.values()]
    else:
        all_collections = list(collections)

    # Record every gene, and count for each unordered pair of genes how many
    # gene sets contain both. A gene set is reduced to a set first so that a
    # repeated gene cannot make one gene set count twice for the same pair.
    genespace = set()
    pair_counts = {}
    for collection in all_collections:
        for gene_set in collection:
            members = set(gene_set)
            genespace.update(members)
            for pair in combinations(members, 2):
                key = frozenset(pair)
                pair_counts[key] = pair_counts.get(key, 0) + 1

    # Create a graph using the co-occurrence counts as edge weights.
    # NOTE(review): the command-line help describes the threshold as a
    # "minimum proportion", yet raw counts are compared here with a strict
    # '>' - confirm which semantics are intended before changing either.
    G = nx.Graph()
    G.add_nodes_from(genespace)
    stable_edges = []
    for pair, count in pair_counts.items():
        if count > stability_threshold:
            gene_a, gene_b = pair
            stable_edges.append((gene_a, gene_b, dict(weight=count)))
    G.add_edges_from(stable_edges)

    # Return the connected components with more than one gene, and the graph
    modules = [component for component in nx.connected_components(G)
               if len(component) > 1]
    return modules, G
def run(args):
    """Placeholder entry point: performs no work yet and returns ``None``.

    :param args: parsed options; currently unused.
    """
    # Loading the pathway sets is not implemented yet.
    return None
	
if __name__ == "__main__":
    # Script entry point: parse the command line and hand the options to run().
    run(parse_args())