import { LogProbs, LogProbToken } from '../data/request/remoterequest'
import Encoder from './encoder'
import { EncoderType } from './enums'
import { genjiUnitrim } from './unitrim/genji'
import { gpt2Unitrim } from './unitrim/gpt2'
import { pileUnitrim } from './unitrim/pile'
import { nerdstashUnitrim } from './unitrim/nerdstash'
import { nerdstashV2Unitrim } from './unitrim/nerdstash_v2'
import { clipUnitrim } from './unitrim/clip'
import { llama3Unitrim } from './unitrim/llama3'

/**
 * Extra vocabulary entries returned by `getTokenizerExtraTokens` for
 * `EncoderType.PileNAI`, keyed by token text with the token id as the value.
 *
 * NOTE(review): the second and third keys are whitespace-only strings that
 * look alike in most editors; presumably they are distinct Unicode space
 * characters - confirm before editing, and avoid reformatting these lines.
 */
export const PILE_NAI_EXTRA_TOKENS = {
    '─': 50257,
    ' ': 50258,
    ' ': 50259,
    '⁂': 50260,
}

/**
 * Extra vocabulary entries returned by `getTokenizerExtraTokens` for
 * `EncoderType.NAIInline`, keyed by token text with the token id as the value.
 *
 * NOTE(review): `<|standartmask|>` looks misspelled, but the key is runtime
 * data; presumably it has to match the tokenizer's vocabulary exactly -
 * confirm against the tokenizer definition before changing it.
 */
export const NAI_INLINE_EXTRA_TOKENS = {
    '<|infillstart|>': 50257,
    '<|infillend|>': 50258,
    '<|masklen1|>': 50259,
    '<|masklen2|>': 50260,
    '<|masklen3|>': 50261,
    '<|masklen4|>': 50262,
    '<|standartmask|>': 50263,
}

/**
 * Resolves the tokenizer definition file name for an encoder type.
 *
 * `Pile` and `PileNAI` share one definition file. Any encoder type without a
 * dedicated entry below falls back to the GPT-2 definition.
 *
 * @param tokenizer - Encoder type to look up.
 * @returns File name of the matching tokenizer definition.
 */
export function getTokenizerFileUrl(tokenizer: EncoderType): string {
    if (tokenizer === EncoderType.Pile || tokenizer === EncoderType.PileNAI) {
        return 'pile_tokenizer.def'
    }
    if (tokenizer === EncoderType.Genji) {
        return 'genji_tokenizer.def'
    }
    if (tokenizer === EncoderType.CLIP) {
        return 'clip_tokenizer.def'
    }
    if (tokenizer === EncoderType.Nerdstash) {
        return 'nerdstash_tokenizer.def'
    }
    if (tokenizer === EncoderType.NerdstashV2) {
        return 'nerdstash_tokenizer_v2.def'
    }
    if (tokenizer === EncoderType.Llama3) {
        return 'llama3nai_tokenizer.def'
    }
    // Everything else uses the GPT-2 definition.
    return 'gpt2_tokenizer.def'
}

/**
 * Returns the extra-token table associated with an encoder type.
 *
 * Only `PileNAI` and `NAIInline` have a table; every other encoder type gets
 * a newly created empty object on each call.
 *
 * @param tokenizer - Encoder type to look up.
 * @returns The shared table constant for the encoder, or a fresh `{}`.
 */
export function getTokenizerExtraTokens(tokenizer: EncoderType): any {
    if (tokenizer === EncoderType.PileNAI) {
        return PILE_NAI_EXTRA_TOKENS
    }
    if (tokenizer === EncoderType.NAIInline) {
        return NAI_INLINE_EXTRA_TOKENS
    }
    return {}
}

/**
 * Selects the unitrim table for an encoder type.
 *
 * The table maps a token id to a number; the helpers in this file treat a
 * missing entry as 0 and a run of tokens as complete when its values sum to 0.
 * `Pile`/`PileNAI` share one table, as do `GPT2`/`NAIInline`.
 *
 * The `switch` deliberately has no `default` branch: an encoder type that is
 * not listed here makes the function fall off the end and yield `undefined`
 * at runtime.
 *
 * @param tokenizer - Encoder type to look up.
 * @returns The unitrim table imported for that encoder.
 */
export function getUnitrim(tokenizer: EncoderType): { [key: number]: number } {
    switch (tokenizer) {
        case EncoderType.Pile:
        case EncoderType.PileNAI:
            return pileUnitrim
        case EncoderType.Genji:
            return genjiUnitrim
        case EncoderType.CLIP:
            return clipUnitrim
        case EncoderType.Nerdstash:
            return nerdstashUnitrim
        case EncoderType.NerdstashV2:
            return nerdstashV2Unitrim
        case EncoderType.GPT2:
        case EncoderType.NAIInline:
            return gpt2Unitrim
        case EncoderType.Llama3:
            return llama3Unitrim
    }
}

/**
 * Builds a ban-list snippet listing every token that contains any of the
 * given strings.
 *
 * Tokens are gathered through `encoder.tokensContaining` for each input
 * string, de-duplicated by id (the first occurrence is kept), sorted by
 * ascending id and rendered one per line as `[id], //token`.
 *
 * @param encoder - Encoder whose `tokensContaining` lookup is queried.
 * @param input - Strings to search the vocabulary for.
 * @returns One `\n`-terminated line per unique token, or an empty string
 *          when nothing matched.
 */
export function generateBanString(encoder: Encoder, input: string[]): string {
    // Track ids already collected so de-duplication is a constant-time lookup
    // instead of rescanning the accumulated list for every candidate.
    const seenIds = new Set<number>()
    const tokens: {
        token: string
        id: number
    }[] = []
    for (const s of input) {
        for (const candidate of encoder.tokensContaining(s)) {
            if (seenIds.has(candidate.id)) {
                continue
            }
            seenIds.add(candidate.id)
            tokens.push(candidate)
        }
    }
    tokens.sort((a, b) => a.id - b.id)
    return tokens.map(({ token, id }) => `[${id}], //${token}\n`).join('')
}

/**
 * Checks whether a run of tokens balances out according to the encoder's
 * unitrim table.
 *
 * Each token contributes its table value (0 when the token has no entry).
 * The run is `complete` when the contributions sum to 0. Encountering a
 * zero-valued token after any non-zero-valued token aborts immediately with
 * `error: true`.
 * NOTE(review): presumably non-zero entries mark tokens that carry only part
 * of a multi-byte character - confirm against how the unitrim tables are built.
 *
 * @param tokens - Token ids to inspect, in order.
 * @param encoderType - Encoder whose unitrim table is used.
 * @returns `complete` is true when the values sum to 0; `error` is true when
 *          a zero-valued token follows a non-zero-valued one.
 */
export function checkNeed(tokens: number[], encoderType: EncoderType): { complete: boolean; error: boolean } {
    const unitrim = getUnitrim(encoderType)
    let balance = 0
    let sawNonZero = false
    for (const token of tokens) {
        const delta = unitrim[token] ?? 0
        if (delta !== 0) {
            sawNonZero = true
            balance += delta
            continue
        }
        // A zero-valued token is only acceptable before any non-zero one.
        if (sawNonZero) {
            return { complete: false, error: true }
        }
    }
    return { complete: balance === 0, error: false }
}

/**
 * Trims tokens off the end that do not result in a complete unitrim.
 *
 * Walks the tokens while summing their unitrim table values (0 when a token
 * has no entry) and keeps everything up to the last position where the
 * running sum is 0. If the sum never returns to 0 after some point, all
 * tokens from that point on are discarded.
 *
 * @param tokens - Token ids to trim; the input array is not modified.
 * @param encoderType - Encoder whose unitrim table is used.
 * @returns A new array holding the longest balanced prefix of `tokens`.
 */
export function cutNeedyTokens(tokens: number[], encoderType: EncoderType): number[] {
    const unitrim = getUnitrim(encoderType)
    let balance = 0
    // Length of the longest prefix whose running sum is back at zero.
    let keepLength = 0
    for (let index = 0; index < tokens.length; index++) {
        balance += unitrim[tokens[index]] ?? 0
        if (balance === 0) {
            keepLength = index + 1
        }
    }
    return tokens.slice(0, keepLength)
}

/**
 * Splits a token sequence into groups that `checkNeed` reports as complete.
 *
 * Tokens are accumulated one at a time. As soon as the accumulated run is
 * complete it is emitted as one group. If `checkNeed` reports an error, the
 * run without its last token is emitted as one group and the last token as a
 * group of its own. A run that is still neither complete nor in error when
 * the input ends is not emitted.
 *
 * @param tokens - Token ids to group, in order.
 * @param encoderType - Encoder passed through to `checkNeed`.
 * @returns The groups in input order.
 */
export function groupMultiCharacterTokens(tokens: number[], encoderType: EncoderType): number[][] {
    const groups: number[][] = []
    let pending: number[] = []
    for (const token of tokens) {
        pending.push(token)
        const { complete, error } = checkNeed(pending, encoderType)
        if (!complete && !error) {
            // Still waiting for more tokens to balance the run.
            continue
        }
        if (complete) {
            groups.push(pending)
        } else {
            // The newest token broke the run: emit what came before, then it.
            groups.push(pending.slice(0, -1), [token])
        }
        pending = []
    }
    return groups
}

/**
 * Splits a logprob sequence into groups whose chosen tokens `checkNeed`
 * reports as complete.
 *
 * Entries are accumulated one at a time and their `chosen.token` ids are
 * checked together. A complete run is emitted as one group. On an error the
 * run without its last entry is emitted as one group and the last entry as a
 * group of its own. A run that is still neither complete nor in error when
 * the input ends is not emitted.
 *
 * @param logprobs - Logprob entries to group, in order.
 * @param encoderType - Encoder passed through to `checkNeed`.
 * @returns The groups in input order.
 */
export function groupMultiCharacterLogprobs(logprobs: LogProbs[], encoderType: EncoderType): LogProbs[][] {
    const groups: LogProbs[][] = []
    let pending: LogProbs[] = []
    for (const entry of logprobs) {
        pending.push(entry)
        const chosenTokens = pending.map((item) => item.chosen.token)
        const { complete, error } = checkNeed(chosenTokens, encoderType)
        if (!complete && !error) {
            // Still waiting for more entries to balance the run.
            continue
        }
        if (complete) {
            groups.push(pending)
        } else {
            // The newest entry broke the run: emit what came before, then it.
            groups.push(pending.slice(0, -1), [entry])
        }
        pending = []
    }
    return groups
}

/**
 * Merges the `afters`, `befores` and `chosen` entries of a logprob record
 * into one list with a single entry per token id.
 *
 * A token can appear in several of the lists. When it does, the entry from
 * `chosen` wins over one from `befores`, which wins over one from `afters`.
 * The merged entries are sorted by their `after` value in descending order;
 * entries whose `after` is null or undefined sort last.
 *
 * @param logprob - Record providing `chosen`, `befores` and `afters`.
 * @returns The de-duplicated entries, highest `after` first.
 */
export function mergeLogprobs(logprob: LogProbs): LogProbToken[] {
    // Later writes replace the stored entry for a token id, so the order of
    // this list defines precedence: afters < befores < chosen.
    const byToken = new Map<number, LogProbToken>()
    for (const entry of [...logprob.afters, ...logprob.befores, logprob.chosen]) {
        byToken.set(entry.token, entry)
    }
    const rank = (entry: LogProbToken): number => entry.after ?? Number.NEGATIVE_INFINITY
    return Array.from(byToken.values()).sort((x, y) => rank(y) - rank(x))
}
