#Imports
from collections import Counter
import pandas as pd
from pandas.api.types import is_numeric_dtype
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import argparse

plt.style.use('ggplot')

#Function Definitions

def sg_rule(i):
    '''Return an sgRNA sequence of at most 20 characters that starts with "G".

    A sequence whose first character already is "G" is cut to its first 20
    characters. Any other sequence gets a "G" put in front of its first 19
    characters. The comparison is case-sensitive.
    '''
    already_has_g = i[0] == "G"
    if already_has_g:
        return str(i[:20])
    return str("G" + i[:19])


def sgRNA_dict_pool(pool, vector):
    '''Generates an sgRNA dictionary from an oligo pool file, as produced from the CLUE script

    pool   -- path of a ";"-separated file with the columns "Oligo" and "sgRNA ID"
    vector -- "H1" or "U6"; selects the identifier sequence that directly
              precedes the sgRNA inside each oligo

    Returns a dict mapping the upper-case 20-character sgRNA sequence to its
    sgRNA ID. When several oligos carry the same sgRNA the first one wins.
    IDs from a numeric column are converted to strings.

    Raises ValueError if *vector* is neither "H1" nor "U6".
    '''
    import warnings

    identifiers = {
        "H1": "GTATGAGACCACTCTTTCCC",
        "U6": "TGTGGAAAGGACGAAACACC",
    }
    if vector not in identifiers:
        raise ValueError("Invalid vector {!r}; expected one of {}".format(
            vector, sorted(identifiers)))
    identifier = identifiers[vector]
    offset = len(identifier)

    df = pd.read_csv(pool, sep=";")
    stringify_ids = is_numeric_dtype(df["sgRNA ID"])

    d = {}
    unmatched = 0
    for oligo, sg_id in zip(df["Oligo"], df["sgRNA ID"]):
        # Search case-insensitively: the sgRNA is upper-cased anyway
        oligo = oligo.upper()
        pos = oligo.find(identifier)
        if pos == -1:
            # str.find returns -1 when the identifier is absent; slicing
            # relative to -1 would produce an unrelated 20-mer as key
            unmatched += 1
            continue
        sg = oligo[pos + offset:pos + offset + 20]
        if sg not in d:
            d[sg] = str(sg_id) if stringify_ids else sg_id

    if unmatched:
        warnings.warn("{} oligo(s) without the {} identifier were skipped".format(
            unmatched, vector))

    return d


def sgRNA_dict_custom(library_file):
    '''Generates an sgRNA Library dictionary from an sgRNA sub-library file generated by the CLUE script

    library_file -- path of a ";"-separated file with the columns
                    "sgRNA Sequence" and "sgRNA ID"

    Returns a dict mapping each upper-case sequence, after sg_rule has been
    applied to it, to its sgRNA ID. IDs from a numeric column are converted to
    strings, as sgRNA_dict_pool does.
    '''
    df = pd.read_csv(library_file, sep=";")

    # Upper-case BEFORE applying sg_rule: its leading-"G" test is
    # case-sensitive, so a lower-case "g..." would otherwise get an extra "G"
    sequences = df["sgRNA Sequence"].str.upper().map(sg_rule)

    ids = df["sgRNA ID"]
    if is_numeric_dtype(ids):
        ids = ids.astype(str)

    return dict(zip(sequences, ids))


def map_sgRNAs(fastq_file, seq_name_dict, vector):
    '''Reads in a fastQ file and a sequence-to-name dictionary for sgRNAs to generate a list of tuples,
    where each tuple carries the sgRNA_seq and sgRNA_name from a read of the fastQ file. This
    list is subsequently to be used for counting of sgRNAs

    fastq_file    -- path of the fastQ file; the 2nd line of every 4-line
                     record is taken as the read sequence
    seq_name_dict -- dict mapping sgRNA sequence to sgRNA name
    vector        -- "H1" or "U6"; selects the identifier sequence that
                     directly precedes the sgRNA inside a read

    Reads without the identifier, or whose sgRNA is not a key of
    *seq_name_dict*, are ignored.

    Raises ValueError if *vector* is neither "H1" nor "U6".
    '''
    identifiers = {
        "H1": "GTATGAGACCACTCTTTCCC",
        "U6": "TGTGGAAAGGACGAAACACC",
    }
    if vector not in identifiers:
        raise ValueError("Invalid vector {!r}; expected one of {}".format(
            vector, sorted(identifiers)))
    identifier = identifiers[vector]
    offset = len(identifier)

    #List Initializing
    export = []

    #FastQ File Reading
    with open(fastq_file, "r") as f:
        for line_no, line in enumerate(f, 1):
            if line_no % 4 != 2:
                continue
            seq = line.strip()
            pos = seq.find(identifier)
            # str.find returns -1 for "not found"; index 0 is a valid hit
            # (identifier at the very start of the read) and must be kept
            if pos == -1:
                continue
            sgRNA = seq[pos + offset:pos + offset + 20]
            if sgRNA in seq_name_dict:
                export.append((sgRNA, seq_name_dict[sgRNA]))

    return export


def sgRNA_counter(sgRNA_reads, seq_name_dict, name):
    '''Count reads per sgRNA name and return them as a two-column DataFrame.

    sgRNA_reads   -- list of (sgRNA_seq, sgRNA_name) tuples as produced by "map_sgRNAs"
    seq_name_dict -- sequence-to-name dictionary; every name in it gets a
                     row, with a count of 0 if no read carries it
    name          -- label of the count column

    The rows are sorted by sgRNA name; the columns are "sgRNA ID" and *name*.
    '''
    # Start every known name at zero, then add the observed reads on top
    tally = Counter(dict.fromkeys(seq_name_dict.values(), 0))
    tally.update(read[1] for read in sgRNA_reads)

    rows = sorted(tally.items())
    return pd.DataFrame(rows, columns=["sgRNA ID", name])


def read_count_table(file_list, seq_name_dict, vector):
    '''Build a read-count table with one count column per fastQ file.

    file_list     -- fastQ file names; each column is labelled with the file
                     name up to (not including) its first "."
    seq_name_dict -- sequence-to-name dictionary for the sgRNAs
    vector        -- vector type, passed on to "map_sgRNAs"

    Returns a DataFrame whose "sgRNA ID" column holds the values of
    *seq_name_dict*, left-merged with the counts of every file.
    '''
    table = pd.DataFrame({"sgRNA ID" : list(seq_name_dict.values())})
    for fastq_file in file_list:
        mapped_reads = map_sgRNAs(fastq_file, seq_name_dict, vector)
        # NOTE: str.find gives -1 for a name without ".", in which case this
        # slice drops the last character of the name
        sample = fastq_file[:fastq_file.find(".")]
        counts = sgRNA_counter(mapped_reads, seq_name_dict, sample)
        table = table.merge(counts, how='left', on='sgRNA ID')
    return table


def density_graph(rct, names):
    '''Saves one density plot per sample of a given read-count table.

    rct   -- read-count table; every entry of *names* is used as a column label
    names -- sample names; each is used as plot title and as prefix of the
             output file "<name>_density.png"
    '''
    for sample in names:
        #Read Count Normalization: count relative to the column mean, times
        #1000, then log10 of (value + 1)
        mean_count = rct[sample].mean()
        log_counts = [np.log10((count / mean_count) * 1000 + 1)
                      for count in rct[sample].tolist()]

        sns.kdeplot(log_counts, shade=True, color="k")
        sns.rugplot(log_counts, color="k")
        plt.title(sample)
        plt.xlabel("log10(norm. sgRNA counts)")
        plt.ylabel("Density")
        plt.xticks(np.arange(0, 4.1, 0.5))
        plt.savefig(str(sample) + "_density.png", dpi=600)
        # Clear the figure so the next sample starts from an empty plot
        plt.clf()


def bar_graph(rct, names):
    '''Saves one bar chart of read counts per sample of a given read-count table.

    rct   -- read-count table with an "sgRNA ID" column and one column per
             entry of *names*
    names -- sample names; each is used as plot title and as prefix of the
             output file "<name>_bar.png"
    '''
    sgRNA_ids = rct["sgRNA ID"]
    for sample in names:
        plt.bar(x=sgRNA_ids, height=rct[sample], color="k")
        plt.title(sample)
        plt.xlabel("sgRNAs")
        plt.ylabel("Read Counts")

        # Horizontal grid lines only; no tick per sgRNA on the x axis
        axes = plt.gca()
        axes.yaxis.grid(True)
        axes.get_xaxis().set_ticks([])

        plt.savefig(str(sample) + "_bar.png", dpi = 600)
        # Clear the figure so the next sample starts from an empty plot
        plt.clf()


#Parser
# Command-line interface. The fastQ files, library file, vector and library
# type are all needed by the code that follows, so they are marked required:
# leaving one out now gives a usage error instead of a failure further down.
my_parser = argparse.ArgumentParser()

my_parser.add_argument(
    "-fq", "--fastq",
    dest="fastq",
    type=str,
    action="store",
    nargs="*",
    required=True,
    help="input file names separated by space",
)

# nargs=1 stores the library as a one-element list (read as args.library[0])
my_parser.add_argument(
    "-l", "--library",
    dest="library",
    type=str,
    action="store",
    nargs=1,
    required=True,
    help="library file name, i.e. oligo-pool file or sgRNA sub-library file",
)

my_parser.add_argument(
    "-v", "--vector",
    dest="vec",
    type=str,
    action="store",
    choices=["H1", "U6"],
    required=True,
    help="type of sgRNA expression vector system used (H1 or U6)",
)

my_parser.add_argument(
    "-t", "--type",
    dest="type",
    type=str,
    action="store",
    choices=["sub-lib", "oligo-pool"],
    required=True,
    help="library type, cloned sgRNA sub-library (sub-lib) or TOPO cloned oligo-pool (oligo-pool)",
)

# Optional: the read count table is written to "<output>.txt"
my_parser.add_argument(
    "-o", "--output",
    dest="output",
    type=str,
    default="output",
    action="store",
    help="file name for the read count table output",
)

args = my_parser.parse_args()



#Input-to-use
# Sample names: each fastQ file name up to its first "." (the same rule
# read_count_table uses for its column labels, so the two must stay in sync)
names = [i[0:i.find(".")] for i in args.fastq]


# Build the sequence-to-ID dictionary for the selected library type
if args.type == "sub-lib":
    d = sgRNA_dict_custom(args.library[0])
elif args.type == "oligo-pool":
    d = sgRNA_dict_pool(args.library[0], args.vec)
else:
    # Without this branch a missing --type left "d" undefined
    my_parser.error("-t/--type must be 'sub-lib' or 'oligo-pool'")

print(">>>sgRNA Dictionary generated<<<")
rct = read_count_table(args.fastq, d, args.vec)

bar_graph(rct, names)
print(">>>Bar Graph generated<<<")

density_graph(rct, names)
print(">>>Density Plot generated<<<")

# Gene label: the sgRNA ID up to its last "_". rsplit leaves an ID without
# "_" unchanged (slicing at rfind's -1 used to drop its last character);
# str() keeps this working for non-string IDs.
genes = [str(i).rsplit("_", 1)[0] for i in rct["sgRNA ID"].tolist()]

rct.insert(loc=1, column="Gene", value=genes)
rct.set_index("sgRNA ID", inplace=True)
rct.to_csv(args.output+".txt", sep="\t")
print(">>>Read count table generated<<<")
print("--------Script finished--------")