#Imports
import pandas as pd
import argparse

#Function Definitions

def file_len(fname):
    """Return the number of lines in the text file ``fname``.

    A final line without a trailing newline is still counted. An empty
    file yields 0 (counting via a generator avoids relying on a loop
    variable that is never bound when the file has no lines).
    """
    with open(fname) as f:
        return sum(1 for _ in f)


def sg_rule(i):
    """Force an sgRNA sequence to begin with "G" without growing it.

    If the first character of ``i`` is not an upper-case "G", a "G" is
    prepended and only the first 19 characters of ``i`` are kept, so a
    20-nt input stays 20 nt long. Otherwise ``i`` is returned unchanged.
    The check is case-sensitive, and ``i`` must be non-empty because its
    first character is indexed.
    """
    needs_leading_g = i[0] != "G"
    adjusted = "G" + i[:19] if needs_leading_g else i
    return str(adjusted)


def sgRNA_dict_custom(library_file):
    """Build a lookup from (processed) sgRNA sequence to sgRNA name.

    ``library_file`` is read with ``pd.read_csv(..., sep=";")``. It must
    contain the columns ``"sgRNA ID"`` and ``"sgRNA Sequence"``, which are
    looked up by header name, not by position.

    Each sequence is upper-cased and then passed through ``sg_rule``. A
    sequence that does not start with 'G' gets one prepended and loses its
    last character, keeping a length of 20. Upper-casing happens first on
    purpose: otherwise a lower-case sequence starting with 'g' would be
    treated as lacking a G and be wrongly modified.

    Returns
    -------
    dict
        ``{processed_sequence: sgRNA_ID}``. If several IDs share the same
        processed sequence, the last one in the file wins.
    """
    df = pd.read_csv(library_file, sep=";")

    # Normalise case BEFORE applying the G rule (see docstring). Series.map
    # also copes with an empty table, unlike a row-wise DataFrame.apply.
    sequences = df["sgRNA Sequence"].str.upper().map(sg_rule)

    return dict(zip(sequences, df["sgRNA ID"]))


def read_identifier(fastq_file, vector):
    """Split the reads of a FASTQ file by presence of the vector identifier.

    Only the second line of every 4-line block (lines 2, 6, 10, ...) is
    inspected, which assumes standard 4-line FASTQ records. For a read that
    contains the identifier, the (up to) 20 characters immediately following
    its first occurrence are collected. A read without the identifier is
    kept whole, stripped of surrounding whitespace.

    Parameters
    ----------
    fastq_file : str
        Path to the FASTQ file (opened as text).
    vector : str
        Expression vector system, "H1" or "U6"; selects the identifier.

    Returns
    -------
    list
        ``[ident_pos, ident_neg]``: the extracted sgRNA sequences from reads
        with the identifier, and the full reads without it.

    Raises
    ------
    ValueError
        If ``vector`` is not "H1" or "U6".
    """
    identifiers = {
        "H1": "CTGTATGAGACCACTCTTTCCC",
        "U6": "CTTGTGGAAAGGACGAAACACC",
    }
    try:
        identifier = identifiers[vector]
    except KeyError:
        # Previously this only printed and then crashed on an unbound name.
        raise ValueError(
            "Invalid vector {!r}; expected 'H1' or 'U6'".format(vector)
        ) from None

    id_len = len(identifier)
    ident_pos = []
    ident_neg = []

    with open(fastq_file, "r") as f:
        for line_no, line in enumerate(f, 1):
            # Sequence line of each 4-line record.
            if line_no % 4 != 2:
                continue
            seq = line.strip()
            pos = seq.find(identifier)
            if pos >= 0:
                start = pos + id_len
                ident_pos.append(seq[start:start + 20])
            else:
                ident_neg.append(seq)

    return [ident_pos, ident_neg]


def map_sgRNAs(read_list, seq_name_dict):
    """Look up each read exactly in ``seq_name_dict``.

    Returns a two-element list ``[mapped, unmapped]``:
    ``mapped`` holds ``(read, name)`` tuples for reads that are keys of
    ``seq_name_dict``, and ``unmapped`` holds the remaining reads. Input
    order is preserved and ``read_list`` is iterated only once.
    """
    hits, misses = [], []

    for read in read_list:
        if read not in seq_name_dict:
            misses.append(read)
            continue
        hits.append((read, seq_name_dict[read]))

    return [hits, misses]


def map_sgRNAs_wo_ident(read_list, seq_name_dict):
    """Map reads by searching every 20-nt window against an sgRNA dictionary.

    Windows are tried starting with the rightmost one and moving left. The
    first window found in ``seq_name_dict`` is recorded and the scan of that
    read stops. Reads with no matching window are reported as unmapped; this
    includes reads shorter than 20 nt, which have no window at all (they
    used to be silently dropped from both lists).

    Returns
    -------
    list
        ``[mapped_sgRNAs, unmapped_reads]``, where ``mapped_sgRNAs`` holds
        ``(window_sequence, name)`` tuples.
    """
    mapped_sgRNAs = []
    unmapped_reads = []

    for read in read_list:
        # ``end`` runs from len(read) down to 20, i.e. rightmost window first.
        for end in range(len(read), 19, -1):
            window = read[end - 20:end]
            if window in seq_name_dict:
                mapped_sgRNAs.append((window, seq_name_dict[window]))
                break
        else:
            # Loop finished without a hit (or never ran for short reads).
            unmapped_reads.append(read)

    return [mapped_sgRNAs, unmapped_reads]


def empty_vector_counter(read_list, vector):
    """Count sequences that contain the empty-vector marker of ``vector``.

    Each sequence is counted at most once, however many times the marker
    occurs in it.

    Parameters
    ----------
    read_list : iterable of str
        Sequences to test.
    vector : str
        "H1" or "U6"; selects the marker sequence.

    Raises
    ------
    ValueError
        If ``vector`` is not "H1" or "U6".
    """
    markers = {
        "H1": "GGGTCTTCGTTTGTTTTGTG",
        "U6": "GGAGACGGTTGTAAATGAGC",
    }
    try:
        marker = markers[vector]
    except KeyError:
        # Previously this only printed and then crashed on an unbound name.
        raise ValueError(
            "Invalid vector {!r}; expected 'H1' or 'U6'".format(vector)
        ) from None

    return sum(1 for seq in read_list if marker in seq)


def oligopool_to_sgRNA(oligopool, vector):
    """Extract the unique sgRNA sequences contained in an oligo pool.

    ``oligopool`` is read with ``pd.read_csv(..., sep=";")`` and must have an
    ``"Oligo"`` column. Each oligo is upper-cased and the vector identifier
    is located in it. The 20 characters directly after the full identifier
    are taken as the sgRNA, which is the same offset ``read_identifier``
    applies to sequencing reads.

    Oligos that do not contain the identifier, or that are missing, are
    skipped. Duplicates are removed, keeping first-occurrence order.

    Raises
    ------
    ValueError
        If ``vector`` is not "H1" or "U6".
    """
    identifiers = {
        "H1": "CTGTATGAGACCACTCTTTCCC",
        "U6": "CTTGTGGAAAGGACGAAACACC",
    }
    try:
        identifier = identifiers[vector]
    except KeyError:
        raise ValueError(
            "Invalid vector {!r}; expected 'H1' or 'U6'".format(vector)
        ) from None

    df = pd.read_csv(oligopool, sep=";")
    sgRNAs = []

    for oligo in df["Oligo"].dropna().astype(str):
        # Upper-case before searching so lower-case oligos are matched too.
        oligo = oligo.upper()
        pos = oligo.find(identifier)
        if pos < 0:
            # Identifier absent: slicing from -1 would yield a bogus sgRNA.
            continue
        # The identifier is 22 nt long; the old hard-coded ``pos + 20``
        # included its last two bases in every extracted sgRNA.
        start = pos + len(identifier)
        sgRNAs.append(oligo[start:start + 20])

    # dict keeps insertion order, so this de-duplicates in first-seen order.
    return list(dict.fromkeys(sgRNAs))


def cross_library_counter(unmapped_reads, pool_sgRNAs):
    """Count reads that contain any sgRNA from the full oligo pool.

    Every 20-nt window of each read is tested against ``pool_sgRNAs``. A read
    is counted at most once, even if several windows match. Reads shorter
    than 20 nt have no window and are never counted; a read of exactly
    20 nt is therefore compared as a whole.

    Parameters
    ----------
    unmapped_reads : iterable of str
        Reads that did not map to the sub-library.
    pool_sgRNAs : iterable of str
        sgRNA sequences of the whole oligo pool.

    Returns
    -------
    int
        Number of reads with at least one matching window.
    """
    # Build the lookup set once: list membership is O(len(pool)) per window.
    pool = set(pool_sgRNAs)

    return sum(
        1
        for read in unmapped_reads
        if any(read[k:k + 20] in pool for k in range(len(read) - 19))
    )


###Parser###
def _build_parser():
    """Create the command-line parser for this script.

    The file arguments use ``nargs=1`` and are therefore stored as
    one-element lists, hence the ``[0]`` indexing in ``main``.
    """
    parser = argparse.ArgumentParser(
        description="Count NGS reads mapped to sub-library, empty vector, cross-library or unmapped")

    parser.add_argument("-fq", "--fastq",
                        dest="fastq",
                        type=str,
                        action="store",
                        nargs=1,
                        required=True,
                        help="fastq file name")

    parser.add_argument("-l", "--library",
                        dest="library",
                        type=str,
                        action="store",
                        nargs=1,
                        required=True,
                        help="sgRNA library file name")

    parser.add_argument("-p", "--pool",
                        dest="pool",
                        type=str,
                        action="store",
                        nargs=1,
                        required=True,
                        help="oligo-pool file name")

    parser.add_argument("-v", "--vector",
                        dest="vec",
                        type=str,
                        action="store",
                        choices=["H1", "U6"],
                        required=True,
                        help="type of sgRNA expression vector system used (H1 or U6)")

    parser.add_argument("-o", "--output",
                        dest="output",
                        type=str,
                        action="store",
                        nargs=1,
                        required=True,
                        help="output file name")

    return parser


def main(argv=None):
    """Classify the reads of one FASTQ file and write a summary table.

    Reads are first split by presence of the vector identifier. Each group
    is then assigned to one of four categories:

    - sub-library mapped
    - empty vector
    - cross-mapped (found in the oligo pool)
    - unmapped (the remainder)

    The four totals, their counts and their percentages are written
    tab-separated to ``<output>.txt``.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    args = _build_parser().parse_args(argv)
    vector = args.vec

    library = sgRNA_dict_custom(args.library[0])
    with_ident, without_ident = read_identifier(args.fastq[0], vector)

    # Reads with the identifier: the 20 nt after it must match exactly.
    ident_mapped, ident_unmapped = map_sgRNAs(with_ident, library)
    # Reads without the identifier: every 20-nt window is searched.
    no_ident_mapped, no_ident_unmapped = map_sgRNAs_wo_ident(without_ident, library)

    perfect_read = len(ident_mapped)        #1 Perfect reads
    perfect_sg = len(no_ident_mapped)       #2 sgRNA found without identifier
    node1 = len(ident_unmapped)
    node2 = len(no_ident_unmapped)

    perfect_empty_v = empty_vector_counter(ident_unmapped, vector)       #3 Perfect EV
    imperfect_empty_v = empty_vector_counter(no_ident_unmapped, vector)  #4 Imperfect EV

    pool_sgRNAs = oligopool_to_sgRNA(args.pool[0], vector)
    perfect_cross_map = cross_library_counter(ident_unmapped, pool_sgRNAs)       #5 Perfect cross-map
    imperfect_cross_map = cross_library_counter(no_ident_unmapped, pool_sgRNAs)  #6 Imperfect cross-map

    # NOTE(review): a read that contains both an empty-vector marker and a
    # pool sgRNA is subtracted twice below, so the unmapped counts can be
    # under-estimated (even negative). Confirm the categories are disjoint.
    perfect_unmapped = node1 - perfect_empty_v - perfect_cross_map          #7
    imperfect_unmapped = node2 - imperfect_empty_v - imperfect_cross_map    #8

    labels2 = ["Sub-Library mapped", "Empty Vector", "Cross-Mapped", "Unmapped"]
    values2 = [perfect_read + perfect_sg,
               perfect_empty_v + imperfect_empty_v,
               perfect_cross_map + imperfect_cross_map,
               perfect_unmapped + imperfect_unmapped]

    total_reads = sum(values2)
    if total_reads:
        values_percent = [(count / total_reads) * 100 for count in values2]
    else:
        # No reads at all: report 0 % instead of raising ZeroDivisionError.
        values_percent = [0.0 for _ in values2]

    out = pd.DataFrame({"Read Type": labels2,
                        "Read Count": values2,
                        "Percentage": values_percent})
    out.to_csv(args.output[0] + ".txt", sep="\t")

    print("--------Script finished--------")


if __name__ == "__main__":
    main()