#!/usr/bin/python
# Load required modules
import sys, os, networkx as nx
from itertools import combinations, product
def parse_args(input_list=None):
    """Build the argument parser and parse the program options.

    Args:
        input_list: Optional list of argument strings. When it is None or
            empty, the arguments are taken from the command line instead.

    Returns:
        An object whose attributes hold the parsed values:
        network_edgelist, permuted_networks_dir, pathway_set_file,
        output_file, distance and verbose.
    """
    import argparse

    class Args:
        pass

    parser = argparse.ArgumentParser(
        description='Calculates network permutation test for given pathway '
                    'set on the given network.')

    # Mandatory options that each take a value.
    required_options = (
        ('-n', '--network_edgelist', 'PPI edgelist location.'),
        ('-i', '--permuted_networks_dir', 'Directory of permuted networks.'),
        ('-p', '--pathway_set_file', 'File containing input pathway set.'),
    )
    for short_name, long_name, help_text in required_options:
        parser.add_argument(short_name, long_name, required=True,
                            help=help_text)

    parser.add_argument('-o', '--output_file', default=None,
                        help='Name of output file.')

    # Boolean switches that default to off.
    switches = (
        ('-d', '--distance', 'Flag average pairwise distance test.'),
        ('-v', '--verbose', 'Flag verbose mode.'),
    )
    for short_name, long_name, help_text in switches:
        parser.add_argument(short_name, long_name, default=False,
                            action='store_true', help=help_text)

    # Passing None makes argparse read the command line arguments.
    argv = input_list if input_list else None
    parsed = Args()
    parser.parse_args(argv, namespace=parsed)
    return parsed
def load_pathway_set(pathway_set_file):
    """Read a pathway set from a tab-separated file.

    Every line of the file describes one pathway. Trailing whitespace
    (including the newline) is stripped, the line is split on tabs, the
    first field is skipped and the remaining fields are returned as the
    genes of that pathway.

    Args:
        pathway_set_file: Path of the pathway set file.

    Returns:
        A list with one list of gene name strings per line of the file.
    """
    # The context manager closes the file deterministically instead of
    # leaving the open handle to the garbage collector.
    with open(pathway_set_file) as pathway_file:
        return [line.rstrip().split("\t")[1:] for line in pathway_file]
 
def load_network(network_edgelist):
    """Load a network from an edge list file.

    Args:
        network_edgelist: Path of the edge list file; it is handed
            unchanged to networkx.read_edgelist.

    Returns:
        The graph created by networkx.read_edgelist.
    """
    return nx.read_edgelist(network_edgelist)
 
def remove_name_annotation(genes):
    """Strip the "(A)" and "(D)" markers from gene names.

    Args:
        genes: Iterable of gene name strings.

    Returns:
        A list of the gene names with every "(A)" and "(D)" substring
        removed, in input order.
    """
    # A list comprehension returns a real list on every Python version;
    # map() would hand back a one-shot iterator on Python 3.
    return [gene.replace("(A)", "").replace("(D)", "") for gene in genes]
 
def interact(network, g1, g2):
    """Return True if g1 interacts with g2 in the PPI network.

    Args:
        network: networkx graph holding the interactions.
        g1: First gene.
        g2: Second gene.

    Returns:
        True when the network contains the edge (g1, g2); False otherwise,
        including when either gene is not a node of the network.
    """
    # has_edge is a direct adjacency lookup that reports False for unknown
    # nodes, so no neighbor collection has to be built and scanned and no
    # exception has to be caught for genes missing from the network.
    return network.has_edge(g1, g2)
 
def combos(xs, n):
    """Return all n-element combinations of xs as a list of tuples."""
    subsets = combinations(xs, n)
    return list(subsets)
def prod(xs, ys):
    """Return the Cartesian product of xs and ys as a list of 2-tuples."""
    crossed = product(xs, ys)
    return list(crossed)
def pairs_within(pathways):
    """Return all gene pairs drawn from within each pathway.

    For every pathway all "n choose 2" pairs of its genes are generated;
    the pairs of the individual pathways are concatenated in pathway order.

    Args:
        pathways: Iterable of pathways, each a sequence of genes.

    Returns:
        A list of 2-tuples of genes.
    """
    # One flat comprehension replaces reduce() with list concatenation:
    # reduce is not a builtin on Python 3, and every concatenation copied
    # the whole list accumulated so far.
    return [pair
            for pathway in pathways
            for pair in combinations(pathway, 2)]
 
def pairs_between(pathways):
    """Return all gene pairs that span two distinct pathways.

    For every pair of distinct pathways (taken in input order) the
    Cartesian product of their genes is generated, and the products are
    concatenated.

    Args:
        pathways: Iterable of pathways, each a sequence of genes.

    Returns:
        A list of 2-tuples (gene of the earlier pathway, gene of the later
        pathway).
    """
    # Written as a comprehension because the original tuple-parameter
    # lambda is a syntax error on Python 3 and reduce() with list
    # concatenation copies the accumulated list on every step.
    return [pair
            for first, second in combinations(pathways, 2)
            for pair in product(first, second)]
 
def count_interactions(network, pairs):
    """Count the gene pairs that interact in the PPI network.

    Args:
        network: networkx graph holding the interactions.
        pairs: Iterable of (gene, gene) pairs.

    Returns:
        The number of pairs for which interact() reports an interaction.
    """
    return sum(1 for g1, g2 in pairs if interact(network, g1, g2))
 
def num_interactions_in_pathway(network, pathway):
    """Count the interactions among the distinct genes of one pathway.

    Args:
        network: networkx graph holding the interactions.
        pathway: Iterable of genes; duplicates are ignored.

    Returns:
        The number of pairs of distinct pathway genes that interact in the
        network.
    """
    unique_genes = set(pathway)
    return count_interactions(network, combinations(unique_genes, 2))
 
def num_interaction_statistic(network, pathways):
    """Compute the number-of-interactions statistic of a pathway set.

    The statistic is the fraction of interacting gene pairs within the
    pathways minus the fraction of interacting gene pairs between
    different pathways.

    Args:
        network: networkx graph holding the interactions.
        pathways: List of pathways, each a sequence of genes.

    Returns:
        The difference of the two fractions as a float. A fraction over an
        empty set of pairs (e.g. no between-pathway pairs because only one
        pathway was given) is taken to be 0.0.
    """
    def interacting_fraction(pairs):
        # No pairs means no interactions; report 0.0 instead of failing
        # with a ZeroDivisionError.
        if not pairs:
            return 0.0
        return float(count_interactions(network, pairs)) / float(len(pairs))

    nu_within = interacting_fraction(pairs_within(pathways))
    nu_between = interacting_fraction(pairs_between(pathways))
    return nu_within - nu_between
 
def num_interactions_network_test(pathways, G, Hs):
    """Run the number-of-interactions permutation test on a pathway set.

    Args:
        pathways: Iterable of pathways, each an iterable of gene names
            that may carry "(A)"/"(D)" annotations.
        G: The original network.
        Hs: Sequence of permuted networks.

    Returns:
        A tuple (stat, pval): the statistic in the original network and
        the fraction of permuted networks whose statistic is at least as
        large.

    Raises:
        ZeroDivisionError: If Hs is empty.
    """
    # Remove name annotation for genes in the pathways
    pathways = [list(set(remove_name_annotation(P))) for P in pathways]

    # Calculate the number of interactions statistic in the original network
    # and the permuted networks. Higher statistics indicate stronger
    # pathway sets.
    stat = num_interaction_statistic(G, pathways)
    permuted_stats = [num_interaction_statistic(H, pathways) for H in Hs]
    count = sum(1 for s in permuted_stats if s >= stat)
    pval = float(count) / float(len(Hs))
    return stat, pval
 
def avg_pair_dist_of_pathway(network, pathway):
    """Compute the average pairwise distance among the genes of a pathway.

    Args:
        network: networkx graph holding the interactions.
        pathway: Iterable of genes; duplicates are ignored.

    Returns:
        The sum of the distances over all pairs of distinct pathway genes
        divided by the number of pairs, as a float. A pathway with fewer
        than two distinct genes has no pairs and yields 0.0.
    """
    pairs = list(combinations(set(pathway), 2))
    # Guard the empty case explicitly; dividing by len(pairs) would raise
    # ZeroDivisionError for single-gene or empty pathways.
    if not pairs:
        return 0.0
    return float(sum_dist(network, pairs)) / float(len(pairs))
 
def sum_dist(network, pairs):
    """Sum the shortest-path distances of a collection of gene pairs.

    Args:
        network: networkx graph holding the interactions.
        pairs: Iterable of (gene, gene) pairs.

    Returns:
        The sum of dist(network, g1, g2) over all pairs.
    """
    return sum(dist(network, g1, g2) for g1, g2 in pairs)
 
def dist(network, g1, g2):
    """Return the length of the shortest path between g1 and g2 in the PPI.

    Args:
        network: networkx graph holding the interactions.
        g1: Source gene.
        g2: Target gene.

    Returns:
        The number of edges on a shortest path from g1 to g2, or 1e100 if
        no path exists or either gene is not a node of the network.
    """
    # Check membership up front: newer networkx releases report a missing
    # node with NodeNotFound, which is not a NetworkXError and therefore
    # would escape the handler below.
    if g1 not in network or g2 not in network:
        return 1e100
    try:
        # Ask for the length directly instead of materializing the path
        # only to measure it.
        return nx.shortest_path_length(network, g1, g2)
    except (nx.NetworkXNoPath, nx.NetworkXError):
        return 1e100
 
def avg_pair_dist_ratio(network, pathways):
    """Compute the within/between average pairwise distance ratio.

    Args:
        network: networkx graph holding the interactions.
        pathways: List of pathways, each a sequence of genes.

    Returns:
        1e100 if some gene is not in the network or two genes of the same
        pathway are not connected; 0 if all within-pathway pairs are
        connected but some between-pathway pair is not; otherwise the
        average within-pathway distance divided by the average
        between-pathway distance.

    Raises:
        ZeroDivisionError: If there are no within-pathway pairs, no
            between-pathway pairs, or every between-pathway distance is 0.
    """
    # Ensure all genes are in the network. Testing membership on the graph
    # itself avoids rebuilding the node collection for every gene.
    if not all(g in network for P in pathways for g in P):
        return 1e100

    # Enumerate pairs of genes within the same and between pathways
    within, between = pairs_within(pathways), pairs_between(pathways)

    # Ensure all genes in the same pathway are in the same component.
    # This is checked before the between-pathway distances are computed so
    # that those shortest-path searches are skipped when the answer is
    # already known.
    within_dists = [dist(network, g1, g2) for g1, g2 in within]
    if any(d == 1e100 for d in within_dists):
        return 1e100

    # Check if any genes in different pathways are not in the same component
    between_dists = [dist(network, g1, g2) for g1, g2 in between]
    if any(d == 1e100 for d in between_dists):
        return 0

    # If all checks out, return the ratio
    d_within = float(sum(within_dists)) / float(len(within))
    d_between = float(sum(between_dists)) / float(len(between))
    return d_within / d_between
 
def avg_pair_dist_network_test(pathways, G, Hs):
    """Run the average pairwise distance permutation test on a pathway set.

    Args:
        pathways: Iterable of pathways, each an iterable of gene names
            that may carry "(A)"/"(D)" annotations.
        G: The original network.
        Hs: Sequence of permuted networks.

    Returns:
        A tuple (ratio, pval): the distance ratio in the original network
        and the fraction of permuted networks whose ratio is at most as
        large.

    Raises:
        ZeroDivisionError: If Hs is empty.
    """
    # Remove name annotation for genes in the pathways
    pathways = [list(set(remove_name_annotation(P))) for P in pathways]

    # Calculate the average pairwise distance ratio in the original network
    # and the permuted networks. Lower ratios indicate stronger pathway sets.
    ratio = avg_pair_dist_ratio(G, pathways)
    permuted_ratios = [avg_pair_dist_ratio(H, pathways) for H in Hs]
    count = sum(1 for r in permuted_ratios if r <= ratio)
    pval = float(count) / float(len(Hs))
    return ratio, pval
 
def eval_pathways_by_dist(pathways, G, Hs):
    """Test every pathway individually by its average pairwise distance.

    Args:
        pathways: Iterable of pathways, each an iterable of genes.
        G: The original network.
        Hs: Sequence of permuted networks.

    Returns:
        A list of (pathway, statistic, pval) tuples, one per pathway, where
        statistic is the average pairwise distance in G and pval is the
        fraction of permuted networks with a distance at most as large.

    Raises:
        ZeroDivisionError: If Hs is empty and pathways is not.
    """
    # NOTE(review): unlike avg_pair_dist_network_test, the "(A)"/"(D)"
    # annotations are not stripped from the gene names here - confirm
    # whether that is intended.
    results = []
    for pathway in pathways:
        observed = avg_pair_dist_of_pathway(G, pathway)
        permuted = [avg_pair_dist_of_pathway(H, pathway) for H in Hs]
        num_extreme = sum(1 for d in permuted if d <= observed)
        pval = float(num_extreme) / float(len(Hs))
        results.append((pathway, observed, pval))
    # A real list (not a lazy zip object) so callers can iterate the
    # results more than once on every Python version.
    return results
 
def eval_pathways_by_interactions(pathways, G, Hs):
    """Test every pathway individually by its number of interactions.

    Args:
        pathways: Iterable of pathways, each an iterable of genes.
        G: The original network.
        Hs: Sequence of permuted networks.

    Returns:
        A list of (pathway, statistic, pval) tuples, one per pathway, where
        statistic is the number of interactions in G and pval is the
        fraction of permuted networks with at least as many interactions.

    Raises:
        ZeroDivisionError: If Hs is empty and pathways is not.
    """
    # NOTE(review): unlike num_interactions_network_test, the "(A)"/"(D)"
    # annotations are not stripped from the gene names here - confirm
    # whether that is intended.
    results = []
    for pathway in pathways:
        observed = num_interactions_in_pathway(G, pathway)
        permuted = [num_interactions_in_pathway(H, pathway) for H in Hs]
        num_extreme = sum(1 for n in permuted if n >= observed)
        pval = float(num_extreme) / float(len(Hs))
        results.append((pathway, observed, pval))
    # A real list (not a lazy zip object) so callers can iterate the
    # results more than once on every Python version.
    return results
 
def evaluate_pathway_set(pathways, G, Hs, distance=False):
    """Evaluate a pathway set as a whole and pathway by pathway.

    Args:
        pathways: List of pathways, each a list of gene names.
        G: The original network.
        Hs: Sequence of permuted networks.
        distance: If true, use the average pairwise distance test;
            otherwise use the number of interactions test.

    Returns:
        A tuple (test_name, statistic, pval, pathway_results) where
        statistic and pval describe the whole pathway set and
        pathway_results holds the per-pathway (pathway, statistic, pval)
        results.
    """
    # Calculate test statistic and p-value
    if distance:
        test_name = 'Average_Pairwise_Distance_Ratio'
        statistic, pval = avg_pair_dist_network_test(pathways, G, Hs)
        pathway_results = eval_pathways_by_dist(pathways, G, Hs)
    else:
        test_name = 'Number_of_Interactions_Statistic'
        statistic, pval = num_interactions_network_test(pathways, G, Hs)
        pathway_results = eval_pathways_by_interactions(pathways, G, Hs)
    return test_name, statistic, pval, pathway_results
 
def run(args):
    """Load the inputs, run the permutation test and report the results.

    Args:
        args: Object with the attributes produced by parse_args:
            pathway_set_file, network_edgelist, permuted_networks_dir,
            output_file, distance and verbose.

    Returns:
        A tuple (statistic, pval, pathway_results): the statistic and
        p-value of the whole pathway set and the list of per-pathway
        (pathway, statistic, pval) results.
    """
    # Load input
    if args.verbose:
        print('Loading networks...')
    pathways = load_pathway_set(args.pathway_set_file)
    G = load_network(args.network_edgelist)
    Hs = [load_network(os.path.join(args.permuted_networks_dir, fh))
          for fh in os.listdir(args.permuted_networks_dir)]

    # Evaluate the pathway set
    if args.verbose:
        print('Evaluating input pathway set...')
    results = evaluate_pathway_set(pathways, G, Hs, args.distance)
    test_name, statistic, pval, pathway_results = results
    # The per-pathway results are traversed for the output file, for the
    # verbose report and by the caller, so make sure they are a list and
    # not a one-shot iterator.
    pathway_results = list(pathway_results)

    # Output and return results
    if args.output_file:
        header = 'Testing\t' + test_name + '\tP-Value\n'
        rows = ['\t'.join(['Pathway set', str(statistic), str(pval)])]
        for pathway, stat, p in pathway_results:
            rows.append('\t'.join([', '.join(pathway), str(stat), str(p)]))
        # The context manager flushes and closes the output file.
        with open(args.output_file, 'w') as out:
            out.write(header + '\n'.join(rows))

    if args.verbose:
        # Each line is printed as one pre-built string so the output is the
        # same under the Python 2 print statement and the Python 3 function.
        print('Pathway set test')
        print('\t' + test_name + ': ' + str(statistic))
        print('\tP-value: ' + str(pval))
        print('\nIndividual pathways test')
        # Distinct loop names: reusing "pval" here would overwrite the
        # pathway-set p-value that is returned below.
        for pathway, stat, pathway_pval in pathway_results:
            print('\tGenes: ' + ', '.join(pathway))
            print('\tStatistic: ' + str(stat))
            print('\tP-value: ' + str(pathway_pval) + ' \n')
    return statistic, pval, pathway_results
 
# Run the test with the command line arguments when executed as a script.
if __name__ == "__main__":
    cli_args = parse_args()
    run(cli_args)