#
# Script for generating the matrix of interaction-based functional similarities for the proteins
# in a multiple sequence alignment (MSA).
#
# For more information see:
#
#   Pitarch B, Ranea JAG, Pazos F. (2021)
#   Protein Residues Determining Interaction Specificity In Paralogous Families. 
#   Bioinformatics. 37(8):1076-1082. PMID: 33135068.
#
# This script reads a multiple sequence alignment (in PIR format) and an interactome (.dl or raw A<TAB>B;
# in both cases the first line of the interactome file is treated as a header and skipped)
# and reports the absolute ("abs") or relative ("rel") number of shared interactors for all pairs
# of proteins in the MSA. Such output can be directly read by Xdet (http://csbg.cnb.csic.es/pazos/Xdet/)
# with the "-M" option. If the % of MSA proteins without interactome info is higher than the
# provided value (4th argument, allowed_%prot_noint) the program stops. The protein identifiers in the MSA should be in
# the format ">P1;xxxxxxx|ID", where "xxxxxxx" is ignored and "ID" is used to match the corresponding
# proteins in the interactome.
#
# Example of usage:
#
#   python3  interaction_based_matrix.py  intact_igraph_human_interactome.dl  Ras_curated.pir  rel  30  > Ras_si.dist
#   xdet  Ras_curated.pir  Maxhom_McLachlan.metric  -M=Ras_si.dist | sort -nr -k 9  > Ras_int_SDPs
#
# The files used in that example command line are available at:
# http://csbg.cnb.csic.es/pazos/Xdet/intact_igraph_human_interactome.dl
# http://csbg.cnb.csic.es/pazos/Xdet/Ras_curated.pir
# http://csbg.cnb.csic.es/pazos/Xdet/Ras_si.dist
#
# Author: Borja Pitarch
#
 
import re
import sys
# --- Command-line arguments ---------------------------------------------------
# Usage and error messages go to stderr so they never end up in the similarity
# matrix, which is written to stdout (usually redirected to a .dist file).
if len(sys.argv) != 5:
        print ("** Usage : python3  int_map.py  interactome  alignment.pir  abs/rel  allowed_%prot_noint", file=sys.stderr)
        sys.exit(1)

interactome = sys.argv[1]   # interactome path (.dl or raw A<TAB>B)
MSA = sys.argv[2]           # alignment path (PIR)
mode  = sys.argv[3]         # "abs" or "rel"
try:
        # Minimum % of MSA proteins that must be present in the interactome
        # (100 minus the allowed % of proteins without interaction data).
        cut = 100 - float(sys.argv[4])
except ValueError:
        print("allowed_%prot_noint must be a number", file=sys.stderr)
        sys.exit(1)
# Interactome format is decided from the file extension ("dl" or anything else).
f1type= interactome.split(".")[-1]

if mode == "abs":
        absolute = True
elif mode == "rel":
        absolute = False
else:
        print ("No mode selected", file=sys.stderr)
        sys.exit(1)

# --- Interactome --------------------------------------------------------------
# Builds:
#   interaction : one list of tab-separated fields per interaction line
#   proteins    : every distinct whitespace-separated token on those lines
try:
        fh = open(interactome,"r")
except OSError:
        print("Interactome file does not exist", file=sys.stderr)
        sys.exit(2)

with fh:
        # The first line is a header (e.g. the "DL n=..." line of a .dl file) and is skipped.
        # NOTE(review): raw A<TAB>B files lose their first line too — confirm they carry a header.
        header=fh.readline().split()
        lines =fh.read()

# Preamble lines of the DL "edgelist1" format that are not interactions.
DL_PREAMBLE = ("format = edgelist1", "labels embedded:", "data:")

# Blank lines (e.g. the trailing newline) carry no interaction.
data_lines = [row for row in lines.split("\n") if row.strip()]
if f1type == "dl":
        # Filter rather than list.remove(): a missing or differently spaced/cased
        # preamble line no longer aborts the script with a ValueError.
        data_lines = [row for row in data_lines if row.strip().lower() not in DL_PREAMBLE]

interaction = [row.split("\t") for row in data_lines]
proteins = list(dict.fromkeys(tok for row in data_lines for tok in row.split()))



def interactions_list(prot, interactions=None):
        """Return the interaction partners of *prot*.

        Every interaction record (a list of fields) that contains *prot*
        contributes all of its other fields. Only the first occurrence of *prot*
        in a record is dropped, so a self-interaction [prot, prot] lists *prot*
        as its own partner. Partners are returned in record order and duplicates
        are kept; callers that need distinct partners wrap the result in set().

        :param prot: protein identifier to look up.
        :param interactions: iterable of records to scan; defaults to the
            module-level ``interaction`` list built from the interactome file.
        :return: list of partner identifiers (may contain duplicates).
        """
        if interactions is None:
                interactions = interaction
        int_prot = []
        for element in interactions:
                if prot in element:
                        # Work on a copy so the shared interaction records are never mutated.
                        partners = list(element)
                        partners.remove(prot)
                        int_prot.extend(partners)
        return int_prot


# --- Multiple sequence alignment ----------------------------------------------
try:
        fh = open(MSA,"r")
except OSError:
        print("MSA file does not exist", file=sys.stderr)
        sys.exit(2)

# Protein IDs in MSA order: for a header ">P1;xxxxxxx|ID ..." keep the text
# after the last "|" of the first whitespace-separated token.
with fh:
        prot_alin = [line.split()[0].split("|")[-1].strip()
                     for line in fh if line.startswith(">")]

known_proteins = set(proteins)  # O(1) membership instead of scanning a list
num_prot = sum(1 for prot in prot_alin if prot in known_proteins)

print ("Number of proteins from the MSA file: " + str(len(prot_alin)),file=sys.stderr)
print ("Number of proteins from the MSA file in the interactome:" + str(num_prot),file=sys.stderr)

# Guard the percentage below against division by zero.
if not prot_alin:
        print ("No proteins found in the MSA file",file=sys.stderr)
        sys.exit(3)

if (100*(num_prot/len(prot_alin))) < cut:
        print ("Not enough MSA proteins in the interactome",file=sys.stderr)
        sys.exit(3)


# --- Pairwise shared-interactor matrix (written to stdout) --------------------
# Each interactions_list() call scans the whole interactome, so partners are
# looked up once per MSA protein instead of several times per pair.
partner_sets = [set(interactions_list(p)) for p in prot_alin]

n_prot = len(prot_alin)
for i in range(n_prot - 1):
        for j in range(i + 1, n_prot):
                common_int = len(partner_sets[i] & partner_sets[j])
                if absolute:
                        print (str(i+1) + "\t" + str(j+1) + "\t" + str(common_int))
                else:
                        # Shared partners as a % of all distinct partners of the pair
                        # (Jaccard index x 100). Distinct partners are used in both the
                        # numerator and the denominator, so duplicated edges do not
                        # deflate the score.
                        total = len(partner_sets[i] | partner_sets[j])
                        perc_int = 100*common_int/total if total else 0.00
                        print("%d\t%d\t%.2f"%((i+1),(j+1),perc_int))
