//
// Clip encoder includes code which was modified from https://github.com/josephrocca/clip-bpe-js
//

const textEncoder = new TextEncoder()

/**
 * Encodes a string as UTF-8 and returns the bytes as a plain number array.
 *
 * @param str - Text to encode.
 * @returns One array element per UTF-8 byte, each in the range 0-255.
 */
const encodeStr = (str: string) => Array.from(textEncoder.encode(str))

const textDecoder = new TextDecoder('utf8')

/**
 * Decodes a sequence of byte values as UTF-8 text.
 *
 * @param arr - Byte values to decode; they are copied into a `Uint8Array` first.
 * @returns The decoded string.
 */
const decodeStr = (arr: Iterable<number>) => {
    const bytes = new Uint8Array(arr)
    return textDecoder.decode(bytes)
}

/**
 * Builds an object whose keys come from `x` and whose values are the elements
 * of `y` at the same positions.
 *
 * @param x - Array of keys.
 * @param y - Values, indexed in parallel with `x`.
 * @returns A plain object mapping `x[i]` to `y[i]`; later duplicates of a key win.
 */
const dictZip = (x: any, y: any) => {
    const zipped: any = {}
    x.forEach((key: any, index: any) => {
        zipped[key] = y[index]
    })
    return zipped
}

/**
 * Returns the integers `0 .. y - 1` with everything before index `x` dropped.
 *
 * The start is applied with `Array.prototype.slice`, so `undefined` means
 * "from 0" and a negative `x` counts back from the end.
 *
 * @param x - Start index passed to `slice`.
 * @param y - Exclusive upper bound (used as an array length).
 * @returns The resulting array of integers.
 */
const range = (x: number | undefined, y: any) => {
    const indices = Array.from({ length: y }, (_, index) => index)
    return indices.slice(x)
}

/**
 * Returns the UTF-16 code unit of the first character of `x`
 * (`NaN` when `x` is empty).
 */
const ord = (x: string): number =>
    // eslint-disable-next-line unicorn/prefer-code-point
    x.charCodeAt(0)

/**
 * Returns the one-character string for the UTF-16 code unit `x`.
 */
const chr = (x: number): string =>
    // eslint-disable-next-line unicorn/prefer-code-point
    String.fromCharCode(x)

/**
 * Collects the distinct adjacent pairs of symbols in `word`.
 *
 * A JavaScript `Set` compares arrays by identity, so adding `[a, b]` twice
 * would store two entries. Pairs are therefore de-duplicated explicitly
 * (elements compared with `Map`/`Set` key equality), giving the returned set
 * real "set of pairs" semantics.
 *
 * @param word - Sequence of symbols (strings, in the BPE use case).
 * @returns A set of two-element arrays `[left, right]`, in order of first
 *          occurrence. Empty when `word` has fewer than two elements.
 */
function get_pairs(word: any[]) {
    const pairs = new Set<any>()
    // left symbol -> right symbols already emitted together with it
    const seen = new Map<any, Set<any>>()
    for (let i = 1; i < word.length; i++) {
        const left = word[i - 1]
        const right = word[i]
        let followers = seen.get(left)
        if (followers === undefined) {
            followers = new Set<any>()
            seen.set(left, followers)
        }
        if (!followers.has(right)) {
            followers.add(right)
            pairs.add([left, right])
        }
    }
    return pairs
}

/**
 * Builds the byte -> character table used to represent raw bytes as text.
 *
 * Bytes in the ranges '!'..'~', '¡'..'¬' and '®'..'ÿ' map to the character
 * with the same code. Every other byte in 0..255 maps to the character with
 * code `256 + n`, where `n` counts those remaining bytes in ascending order.
 *
 * @returns An object keyed by byte value (0-255) holding one character each.
 */
const bytes_to_unicode = () => {
    const printable = [
        ...range(ord('!'), ord('~') + 1),
        ...range(ord('¡'), ord('¬') + 1),
        ...range(ord('®'), ord('ÿ') + 1),
    ]
    const isPrintable = new Set(printable)

    const table: Record<number, string> = {}

    // Bytes that already have a printable character keep their own code.
    for (const byte of printable) {
        table[byte] = chr(byte)
    }

    // All other bytes are shifted to consecutive codes starting at 256.
    let shifted = 0
    for (let byte = 0; byte < 256; byte += 1) {
        if (isPrintable.has(byte)) {
            continue
        }
        table[byte] = chr(256 + shifted)
        shifted += 1
    }

    return table
}

/**
 * Decodes HTML entities in `text` twice (so doubly-escaped entities are
 * resolved too) and trims surrounding whitespace.
 *
 * @param text - Raw input text.
 * @param htmlEntities - Object providing the `decode` function to apply.
 * @returns The decoded, trimmed text.
 */
function basic_clean(text: string, htmlEntities: { decode: (_: string) => string }) {
    const decodedOnce = htmlEntities.decode(text)
    const decodedTwice = htmlEntities.decode(decodedOnce)
    return decodedTwice.trim()
}
/**
 * Collapses every run of whitespace into a single space and trims the ends.
 */
function whitespace_clean(text: string) {
    const collapsed = text.replace(/\s+/g, ' ')
    return collapsed.trim()
}

/**
 * Replaces each square or curly bracket character with a space and trims the
 * ends. Interior runs of spaces are left as they are.
 */
function bracket_clean(text: string) {
    const withoutBrackets = text.replace(/[[\]{}]/g, ' ')
    return withoutBrackets.trim()
}

// Pre-tokenization pattern. Each match is one of, in priority order:
//   - the literal markers <|startoftext|> / <|endoftext|>
//   - an English contraction suffix: 's 't 're 've 'm 'll 'd
//   - a run of letters
//   - a single numeric character (digits are never grouped)
//   - a run of characters that are neither whitespace, letters nor numbers
// Flags: global (`g`), case-insensitive (`i`) and Unicode property escapes (`u`).
const clip_pattern =
    /<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+/giu

// byte value -> stand-in character, and the inverse lookup built from it.
const byte_encoder: Record<number, string> = bytes_to_unicode()
const byte_decoder: Record<string, number> = {}
for (const [byteKey, symbol] of Object.entries(byte_encoder)) {
    // Object keys are strings, so turn the byte value back into a number.
    byte_decoder[symbol] = Number.parseInt(byteKey)
}

/**
 * Byte-pair-encoding tokenizer for CLIP text prompts.
 *
 * The vocabulary is laid out as: the 256 byte stand-in characters, the same
 * 256 characters with a `</w>` (end-of-word) suffix, one entry per merge rule,
 * and finally `<|startoftext|>` and `<|endoftext|>`.
 */
export class ClipEncoder {
    /** Separator used to build a single string key from a pair of symbols. */
    private static readonly PAIR_SEPARATOR = '·😎·'

    /** Token string -> token id. */
    private encoder: Record<string, number>
    /** Token id -> token string. */
    private decoder: Record<number, string>
    /**
     * Memoized results of {@link bpe}. This is a `Map` rather than a plain
     * object so that tokens which are also `Object.prototype` property names
     * (for example `constructor`) are not mistaken for cached entries.
     */
    private cache: Map<string, string>
    /** Merge rule key (`first + PAIR_SEPARATOR + second`) -> priority (lower merges first). */
    private bpeRanks: Record<string, number>
    private htmlEntities: { encode: (_: string) => string; decode: (_: string) => string }

    /**
     * @param bpeArr - Lines of the merges list. The first entry is skipped and
     *                 at most 48894 following entries are used; each entry is
     *                 two symbols separated by a single space.
     * @param htmlEntities - HTML entity codec; only `decode` is used, by `encode`.
     */
    constructor(
        bpeArr: string[],
        htmlEntities: { encode: (_: string) => string; decode: (_: string) => string }
    ) {
        this.htmlEntities = htmlEntities
        const merges = bpeArr.slice(1, 49152 - 256 - 2 + 1).map((merge) => merge.split(' '))
        // eslint-disable-next-line max-len, prettier/prettier
        const baseVocab = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', '¡', '¢', '£', '¤', '¥', '¦', '§', '¨', '©', 'ª', '«', '¬', '®', '¯', '°', '±', '²', '³', '´', 'µ', '¶', '·', '¸', '¹', 'º', '»', '¼', '½', '¾', '¿', 'À', 'Á', 'Â', 'Ã', 'Ä', 'Å', 'Æ', 'Ç', 'È', 'É', 'Ê', 'Ë', 'Ì', 'Í', 'Î', 'Ï', 'Ð', 'Ñ', 'Ò', 'Ó', 'Ô', 'Õ', 'Ö', '×', 'Ø', 'Ù', 'Ú', 'Û', 'Ü', 'Ý', 'Þ', 'ß', 'à', 'á', 'â', 'ã', 'ä', 'å', 'æ', 'ç', 'è', 'é', 'ê', 'ë', 'ì', 'í', 'î', 'ï', 'ð', 'ñ', 'ò', 'ó', 'ô', 'õ', 'ö', '÷', 'ø', 'ù', 'ú', 'û', 'ü', 'ý', 'þ', 'ÿ', 'Ā', 'ā', 'Ă', 'ă', 'Ą', 'ą', 'Ć', 'ć', 'Ĉ', 'ĉ', 'Ċ', 'ċ', 'Č', 'č', 'Ď', 'ď', 'Đ', 'đ', 'Ē', 'ē', 'Ĕ', 'ĕ', 'Ė', 'ė', 'Ę', 'ę', 'Ě', 'ě', 'Ĝ', 'ĝ', 'Ğ', 'ğ', 'Ġ', 'ġ', 'Ģ', 'ģ', 'Ĥ', 'ĥ', 'Ħ', 'ħ', 'Ĩ', 'ĩ', 'Ī', 'ī', 'Ĭ', 'ĭ', 'Į', 'į', 'İ', 'ı', 'Ĳ', 'ĳ', 'Ĵ', 'ĵ', 'Ķ', 'ķ', 'ĸ', 'Ĺ', 'ĺ', 'Ļ', 'ļ', 'Ľ', 'ľ', 'Ŀ', 'ŀ', 'Ł', 'ł', 'Ń']
        const vocab = [...baseVocab, ...baseVocab.map((v) => v + '</w>')]
        for (const merge of merges) {
            vocab.push(merge.join(''))
        }
        vocab.push('<|startoftext|>', '<|endoftext|>')
        this.encoder = Object.fromEntries(vocab.map((v, i) => [v, i]))
        this.decoder = Object.fromEntries(Object.entries(this.encoder).map(([k, v]) => [v, k]))
        // JS has no tuple keys, so each pair is flattened into one string.
        this.bpeRanks = Object.fromEntries(
            merges.map((v, i) => [v.join(ClipEncoder.PAIR_SEPARATOR), i])
        )
        // The special markers are never split: they map to themselves.
        this.cache = new Map([
            ['<|startoftext|>', '<|startoftext|>'],
            ['<|endoftext|>', '<|endoftext|>'],
        ])
    }

    /**
     * Applies the merge rules to a single pre-token.
     *
     * @param token - A pre-token already mapped through `byte_encoder`.
     * @returns The resulting sub-word symbols joined by single spaces; the
     *          last symbol carries the `</w>` suffix.
     */
    private bpe(token: string): string {
        const cached = this.cache.get(token)
        if (cached !== undefined) {
            return cached
        }

        // One symbol per character, with the end-of-word marker on the last.
        let word: string[] = [...token.slice(0, -1), token.slice(-1) + '</w>']
        let pairs = get_pairs(word)

        if (pairs.size === 0) {
            return token + '</w>'
        }

        // eslint-disable-next-line no-constant-condition
        while (true) {
            // Pick the adjacent pair with the lowest rank (highest priority).
            let bigram: [string, string] | null = null
            let minRank = Number.POSITIVE_INFINITY
            for (const p of pairs) {
                const r = this.bpeRanks[p.join(ClipEncoder.PAIR_SEPARATOR)]
                if (r === undefined) continue
                if (r < minRank) {
                    minRank = r
                    bigram = p
                }
            }

            // No remaining pair has a merge rule: done.
            if (bigram === null) {
                break
            }

            // Rebuild the word, fusing every occurrence of the chosen pair.
            const [first, second] = bigram
            const newWord: string[] = []
            let i = 0
            while (i < word.length) {
                const j = word.indexOf(first, i)

                if (j === -1) {
                    newWord.push(...word.slice(i))
                    break
                }

                newWord.push(...word.slice(i, j))
                i = j

                if (word[i] === first && i < word.length - 1 && word[i + 1] === second) {
                    newWord.push(first + second)
                    i += 2
                } else {
                    newWord.push(word[i])
                    i += 1
                }
            }
            word = newWord
            if (word.length === 1) {
                break
            }
            pairs = get_pairs(word)
        }

        const joined = word.join(' ')
        this.cache.set(token, joined)
        return joined
    }

    /**
     * Converts text into token ids.
     *
     * The text is cleaned (brackets replaced by spaces, HTML entities decoded
     * twice, whitespace collapsed), lower-cased, split with `clip_pattern`,
     * and each piece is byte-encoded and run through BPE. No start/end marker
     * ids are added by this method.
     *
     * @param text - Input text.
     * @returns Token ids in order of appearance.
     */
    encode(text: string): number[] {
        const bpeTokens: number[] = []
        const cleaned = whitespace_clean(
            basic_clean(bracket_clean(text), this.htmlEntities)
        ).toLowerCase()
        for (const match of cleaned.matchAll(clip_pattern)) {
            const token = encodeStr(match[0])
                .map((b) => byte_encoder[b])
                .join('')
            for (const bpeToken of this.bpe(token).split(' ')) {
                bpeTokens.push(this.encoder[bpeToken])
            }
        }
        return bpeTokens
    }

    /**
     * Converts token ids back into text.
     *
     * Ids are mapped to their token strings and concatenated; each character
     * is turned back into its byte through `byte_decoder` (characters without
     * an entry contribute their own UTF-8 bytes); the bytes are decoded as
     * UTF-8 and every `</w>` marker becomes a space. Ids missing from the
     * vocabulary contribute nothing.
     *
     * @param tokens - Token ids.
     * @returns The decoded text (a word-final token leaves a trailing space).
     */
    decode = (tokens: any[]): string => {
        const text = tokens.map((x) => this.decoder[x]).join('')
        const arr = [...text].flatMap((x) => byte_decoder[x] ?? encodeStr(x))
        return decodeStr(arr).replace(/<\/w>/g, ' ')
    }

    /**
     * Lists every vocabulary entry whose token string contains `str`.
     *
     * @param str - Substring to look for (case-sensitive).
     * @returns Matching tokens with their ids, in vocabulary order.
     */
    tokensContaining = (str: string): { token: string; id: number }[] => {
        const matches: { token: string; id: number }[] = []
        for (const [token, id] of Object.entries(this.encoder)) {
            if (token.includes(str)) matches.push({ token, id })
        }
        return matches
    }

    /**
     * Computes one number per token id from the UTF-8 bytes of its token string.
     *
     * While scanning the bytes, `need` is reset to 0 by a single-byte
     * character, set to 1/2/3 by a 2/3/4-byte lead byte, and decremented by a
     * continuation byte. The value recorded for the token is the final
     * `need`, or the lowest value `need` reached if the final value is 0.
     *
     * NOTE(review): the bytes scanned are the UTF-8 encoding of the vocabulary
     * string itself (byte stand-in characters), not the raw bytes those
     * characters stand for — confirm this is what callers expect.
     *
     * @returns An array indexed by token id.
     */
    makeUnitrim(): number[] {
        const unicodeReq: number[] = []
        // Hoisted: recomputing the key list on every iteration made this loop quadratic.
        const tokenCount = Object.keys(this.encoder).length
        for (let i = 0; i < tokenCount; i++) {
            const v = this.decoder[i]
            let need = 0
            let min_need = 0
            // Turn the string into bytes.
            const bytes = encodeStr(v)

            for (const c of bytes) {
                if ((c & 0b10000000) === 0) {
                    need = 0
                } else if ((c & 0b11000000) === 0b10000000) {
                    need -= 1
                } else if ((c & 0b11100000) === 0b11000000) {
                    need = 1
                } else if ((c & 0b11110000) === 0b11100000) {
                    need = 2
                } else if ((c & 0b11111000) === 0b11110000) {
                    need = 3
                }
                if (need < min_need) {
                    min_need = need
                }
            }
            if (need === 0) {
                need = min_need
            }
            unicodeReq.push(need)
        }

        return unicodeReq
    }

    /** Returns the number of entries in the vocabulary. */
    totalTokens(): number {
        return Object.keys(this.encoder).length
    }
}
