import pandas as pd
import csv
import sys


def generator_from_text_file(filename,chunksize=2048,outputstyle='encint',therank=None,thewordlength=None):
    """
    Return a generator over the words stored in an enc64-encoded text file.

    The file is read as a concatenation of fixed-width enc64 words with no
    separators; a word of length n occupies ceil(n/2) characters.

    filename -- path of the data file. When therank/thewordlength are not
        given they are parsed from the filename: the single character after
        'Rank' and the two characters after 'Len'.
    chunksize -- number of words read from disk per read() call.
    outputstyle -- 'int', 'str', 'encint', or 'encstr'. 'int' and 'str' give
        the word in human readable form, as a list of nonzero integers or as a
        string, resp. 'encint' and 'encstr' give the word encoded as an
        integer or string using functions intencode and enc64, resp. 'encint'
        is more memory efficient to work with in python. 'encstr' is more
        efficient for saving to disk.
    therank, thewordlength -- override the values parsed from the filename.

    Raises ValueError (when the generator is first advanced) if outputstyle is
    not recognized, and ValueError if the file does not contain a whole number
    of encoded words.
    """
    if outputstyle not in ('encint','encstr','int','str'):
        raise ValueError("Output format "+str(outputstyle)+" not recognized.")
    if therank is not None:
        rank=therank
    else:
        rank=int(filename[4+filename.index('Rank')]) # try to read the rank from the filename, assumes 1 digit rank
    if thewordlength is not None:
        wordlength=thewordlength
    else:
        wordlength=int(filename[3+filename.index('Len'):5+filename.index('Len')]) # try to read wordlength from the filename, assumes 2 digit wordlength
    # Pick the per-word conversion once instead of re-testing outputstyle for every word.
    if outputstyle=='encint':
        convert=lambda word: intencode(rank,dec64(word))
    elif outputstyle=='encstr':
        convert=lambda word: word
    elif outputstyle=='int':
        convert=lambda word: dec64(word)
    else: # 'str'
        convert=lambda word: freegroupinttoletters(dec64(word))
    # Each encoded character holds two letters, so a word takes ceil(wordlength/2) characters.
    encodedwordlength=wordlength%2+wordlength//2
    readsize=encodedwordlength*chunksize
    # Data files are ASCII. In python3 we have to decode data bytes to ASCII strings.
    decode=sys.version_info.major!=2
    with open(filename,'rb') as infile:
        chunk=infile.read(readsize)
        while chunk:
            if decode:
                chunk=chunk.decode('ASCII')
            if len(chunk)%encodedwordlength:
                raise ValueError("File "+str(filename)+" does not contain a whole number of encoded words of length "+str(encodedwordlength)+".")
            for start in range(0,len(chunk),encodedwordlength):
                yield convert(chunk[start:start+encodedwordlength])
            chunk=infile.read(readsize)

def read_csv(filename,freegrouprank=None,keystyle='encint',outputtype='dataframe',drop=True):
    """
    Read orgcensus data from a csv file.

    filename -- path of the csv file. Each row holds an enc64-encoded word, an
        encoded record (see decrecord) and optionally an encoded hyperbolicity
        reason (see dechreason).
    freegrouprank -- rank of the free group; when None it is parsed from the
        single character following 'Rank' in filename.
    keystyle -- representation of the keys/indices: 'encint' (integer from
        intencode), 'encstr' (the raw enc64 string), 'int' (decoded integer
        word) or 'str' (word as letters).
    outputtype -- either 'dict' or 'dataframe'.
    drop -- only used for 'dataframe': drop the raw/intermediate columns.

    Raises ValueError if outputtype is not recognized.
    """
    if freegrouprank is not None:
        rank=freegrouprank
    else:
        rank=int(filename[4+filename.index('Rank')]) # try to read the rank from the filename, assumes 1 digit rank
    if outputtype=='dataframe':
        return read_csv_to_df(filename,rank,keystyle,drop=drop)
    elif outputtype=='dict':
        return read_csv_to_dict(filename,rank,keystyle)
    else:
        # The original raised an undefined name (InputError), i.e. a NameError.
        raise ValueError("Output type "+str(outputtype)+" not recognized; use 'dict' or 'dataframe'.")

    
def read_csv_to_dict(inputfilename,rank,keystyle):
    """
    Read orgcensus csv data into a dict keyed by word.

    Each row holds an enc64-encoded word, an encoded record (see decrecord)
    and optionally an encoded hyperbolicity reason (see dechreason). Blank
    rows are skipped.

    rank -- rank of the free group (used by keystyle 'encint').
    keystyle -- key representation: 'encint', 'str', 'int' (tuple of
        integers) or 'encstr' (the raw enc64 string).

    Returns a dict mapping each key to a dict with entry 'irank' and, when
    known, 'hyperbolic', 'hyperbolicparabolic' and 'hyperbolicreason'.

    Raises ValueError if keystyle is not recognized.
    """
    if keystyle not in ('encint','str','int','encstr'):
        raise ValueError("Key style "+str(keystyle)+" not recognized.")
    if sys.version_info.major==2:
        myfile=open(inputfilename,'rb') # python2's csv module expects binary mode
    else:
        # python3's csv module needs text mode; in 'rb' mode csv.reader rejects the bytes lines.
        myfile=open(inputfilename,'r',newline='',encoding='ascii')
    data=dict()
    with myfile:
        for row in csv.reader(myfile):
            if not row:
                continue
            wletters=dec64(row[0])
            if keystyle=='encint':
                k=intencode(rank,wletters)
            elif keystyle=='str':
                k=freegroupinttoletters(wletters)
            elif keystyle=='int':
                k=tuple(wletters) # tuple so the key is hashable
            else: # 'encstr'
                k=row[0]
            (irank,hyp,hyppar)=decrecord(row[1])
            entry={'irank':irank}
            if hyp is not None:
                entry['hyperbolic']=hyp
            if hyppar is not None:
                entry['hyperbolicparabolic']=hyppar
            if len(row)>2:
                hreason=dechreason(row[2])
                if hreason: # dechreason returns '' or None when there is no reason
                    entry['hyperbolicreason']=hreason
            data[k]=entry
    return data

def read_csv_to_df(filename,rank,keystyle,drop=True):
    """
    Read orgcensus csv data into a pandas DataFrame indexed by word.

    Each csv row holds up to three fields: the enc64-encoded word ('k64'),
    the encoded record ('record', see decrecord) and an optional encoded
    hyperbolicity reason ('hr', see dechreason).

    rank -- rank of the free group (used by keystyle 'encint').
    keystyle -- index representation: 'encint', 'str', 'int' (list of
        integers) or 'encstr' (the raw enc64 string).
    drop -- if True, drop the raw and intermediate columns so that only
        'irank', 'hyperbolic', 'hyperbolicparabolic' and 'hyperbolicreason'
        remain.

    Raises ValueError if keystyle is not recognized.
    """
    if keystyle not in ('encint','str','int','encstr'):
        raise ValueError("Key style "+str(keystyle)+" not recognized.")
    # enc64 keys use the characters A-Z, a-z, 0-9, '+', '/', so a key can look
    # like a number ('12') or like one of pandas' default NA markers ('NA',
    # 'nan', 'null', 'None'). Read keys and records verbatim and treat only an
    # empty 'hr' field as missing.
    df=pd.read_csv(filename,header=None,names=['k64','record','hr'],
                   dtype={'k64':str,'record':str},
                   keep_default_na=False,na_values={'hr':['']})
    if keystyle=='encint':
        df['k']=df['k64'].map(lambda e: intencode(rank,dec64(e)))
    elif keystyle=='str':
        df['k']=df['k64'].map(lambda e: freegroupinttoletters(dec64(e)))
    elif keystyle=='int':
        df['k']=df['k64'].map(lambda e: dec64(e))
    else: # 'encstr': the raw encoded string is already the key
        df.rename(columns={'k64':'k'}, inplace=True)
    df['longrecord']=df['record'].map(lambda r: decrecord(r))
    df['irank']=df['longrecord'].map(lambda t: t[0])
    df['hyperbolic']=df['longrecord'].map(lambda t: t[1])
    df['hyperbolicparabolic']=df['longrecord'].map(lambda t: t[2])
    df['hyperbolicreason']=df['hr'].map(lambda e: dechreason(e))
    df.set_index('k',inplace=True)
    df.index.name=None
    if drop:
        intermediate=['record','hr','longrecord']
        if keystyle!='encstr':
            intermediate=['k64']+intermediate
        df.drop(intermediate,axis=1,inplace=True)
    return df


            

        
####### auxiliary functions for encoding and decoding relevant python data.########
def enc64(w):
    """
    Encode a word of a free group of rank at most 4 as a compact string.

    w is a nonempty list of nonzero integers of absolute value at most 4.
    Consecutive pairs of letters are packed into one character of the
    64-symbol alphabet A-Z, a-z, 0-9, '+', '/'.
    """
    # The word need not actually be freely reduced, except that an odd-length
    # word is padded in front with the inverse of its first letter; dec64
    # recognises a leading cancelling pair and discards the padding.
    assert(len(w))
    assert(all(0<abs(x)<=4 for x in w))
    alphabet='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'
    letters=list(w)
    if len(letters)%2:
        letters.insert(0,-letters[0])
    # [-4,-3,-2,-1,1,2,3,4] -> [0,1,2,3,4,5,6,7]
    digits=[x+4 if x<0 else x+3 for x in letters]
    return ''.join(alphabet[8*hi+lo] for hi,lo in zip(digits[0::2],digits[1::2]))

def dec64(e):
    """
    Inverse of enc64: decode a string into a list of nonzero integers.

    Each character stands for two letters. If the first two decoded letters
    cancel (x, -x), the first one is treated as enc64's odd-length padding and
    dropped. This assumes encoded words are freely reduced.
    """
    alphabet='ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/'
    position={ch:i for i,ch in enumerate(alphabet)}
    def unshift(d): # [0,1,2,3,4,5,6,7] -> [-4,-3,-2,-1,1,2,3,4]
        return d-4 if d<4 else d-3
    letters=[]
    for ch in e:
        hi,lo=divmod(position[ch],8)
        letters.extend((unshift(hi),unshift(lo)))
    return letters[1:] if letters[0]==-letters[1] else letters

def encrecord(t):
    """
    Encode data about a free group element as a single character.

    t is a tuple (irank, hyperbolic, hyperbolicparabolic). Accepted tuples are
    (1, True, True), (float('inf'), True, None), and (r, h, p) for r in 2, 3, 4
    with h and p each one of False, None, True. They map in that order (for
    each r, h varies slowest, in the order False, None, True) onto 'A'..'Z',
    'a', 'b', 'c'. Any other tuple raises KeyError.
    """
    records=[(1,True,True)]
    for irank in (2,3,4):
        for hyp in (False,None,True):
            for hyppar in (False,None,True):
                records.append((irank,hyp,hyppar))
    records.append((float('inf'),True,None))
    codes='ABCDEFGHIJKLMNOPQRSTUVWXYZabc'
    return dict(zip(records,codes))[t]

def decrecord(encodedrecord):
    """
    Inverse of encrecord: decode one character into the tuple
    (irank, hyperbolic, hyperbolicparabolic). Unknown characters raise KeyError.
    """
    codes='ABCDEFGHIJKLMNOPQRSTUVWXYZabc'
    records=[(1,True,True)]
    records.extend((irank,hyp,hyppar)
                   for irank in (2,3,4)
                   for hyp in (False,None,True)
                   for hyppar in (False,None,True))
    records.append((float('inf'),True,None))
    return dict(zip(codes,records))[encodedrecord]


def enchreason(hreason):
    """
    Encode a hyperbolicity reason as a single character.

    None encodes to 'A'. If hreason contains one of the special flags 'k',
    'f', 't' (checked in that order), that flag is returned unchanged.
    Otherwise the flags 'c', 'i', 's', 'b', 'w' are read as the bits of a
    5-bit number ('c' is the most significant bit). That number indexes the
    32-character alphabet A-Z, a-e, g.
    """
    if hreason is None:
        return 'A'
    for special in ('k','f','t'):
        if special in hreason:
            return special
    reasoncode=0
    for flag in 'cisbw':
        reasoncode=2*reasoncode+(1 if flag in hreason else 0)
    # NOTE: 'f' is absent from this alphabet, presumably so it cannot collide with the special 'f' code.
    return 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdeg'[reasoncode]

def dechreason(e):
    """
    Inverse of enchreason.

    Returns '' for a missing value: None, '', or any float, such as the NaN
    pandas uses for an empty csv field. isinstance is used so numpy.float64
    NaN is recognised too. The special codes 'k', 'f', 't' are returned
    unchanged. Any other code is decoded into the flags it encodes, in the
    order 'c', 'i', 's', 'b', 'w'. The code 'A' (no flags set) decodes to None.

    Raises ValueError for an unrecognized code.
    """
    if e is None or e=='' or isinstance(e,float):
        return ''
    if e in {'k','f','t'}:
        return e
    codes={ch:i for i,ch in enumerate('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdeg')}
    if e not in codes:
        raise ValueError("Hyperbolicity reason code "+repr(e)+" not recognized.")
    reasoncode=codes[e]
    reasons=''.join(flag for bit,flag in zip((16,8,4,2,1),'cisbw') if reasoncode&bit)
    return reasons if reasons else None
        


def intencode(rank,sequenceofnonzerointegers):
    """
    Encode a word of a free group of the given rank as a single integer.

    The word is a container of nonzero integers: k for the k-th generator and
    -k for its inverse. Each letter x becomes a base-(2*rank+1) digit: x+rank
    for x>0 and x+rank+1 for x<0. The first letter is the most significant
    digit. For letters with |x|<=rank the digit 0 never occurs, so the empty
    word encodes to 0 and distinct words encode to distinct integers.
    """
    base=2*rank+1
    code=0
    for x in sequenceofnonzerointegers:
        # Horner's scheme: shift the previous digits up one place and add the new one.
        code=code*base+(x+rank if x>0 else x+rank+1)
    return code

def intdecode(rank,theint):
    """
    Inverse of intencode: decode an integer into a word of the free group of
    the given rank. Returns a list of nonzero integers; 0 decodes to [].

    Raises ValueError if theint is negative. Floor division of a negative
    integer never reaches 0, so the digit loop would never terminate.
    """
    if theint<0:
        raise ValueError("Cannot decode negative integer "+str(theint)+".")
    base=2*rank+1
    thelist=[]
    while theint:
        theint,digit=divmod(theint,base)
        # digits 1..rank encode inverses -rank..-1, digits rank+1..2*rank encode generators 1..rank
        thelist.append(digit-(rank+1) if digit<=rank else digit-rank)
    thelist.reverse()
    return thelist

def freegroupinttoletters(intlist):
    """
    Convert a word given as nonzero integers into letters.

    1..4 map to 'a'..'d' and -1..-4 map to 'A'..'D'. Any other value raises KeyError.
    """
    lettermap={1:'a',2:'b',3:'c',4:'d',-1:'A',-2:'B',-3:'C',-4:'D'}
    return ''.join(lettermap[x] for x in intlist)

def freegroupletterstoints(string):
    """
    Inverse of freegroupinttoletters.

    'a'..'d' map to 1..4 and 'A'..'D' map to -1..-4. Any other character raises KeyError.
    """
    intmap={'a':1,'b':2,'c':3,'d':4,'A':-1,'B':-2,'C':-3,'D':-4}
    return [intmap[ch] for ch in string]
