import { Tiktoken, TiktokenModel, encodingForModel } from 'js-tiktoken';
import { OpenAIModelName, isOpenAIModelName, openAIModelValues } from './openai-models';
import OpenAI from 'openai';

// A chat-completion message param from the OpenAI SDK, intersected with `{ content: string | null }`
// so that `content` is always either a plain string or null.
export type LLMChatCompletionMessage = OpenAI.Chat.ChatCompletionMessageParam & { content: string | null };
// Alias for the OpenAI SDK's tool-call type carried on chat-completion messages.
export type LLMChatCompletionFunctionCall = OpenAI.ChatCompletionMessageToolCall;

// Tokenizer instances already created, keyed by tiktoken model name.
const encodingCache = new Map<TiktokenModel, Tiktoken>();

/**
 * Returns the tokenizer for `model`.
 *
 * The tokenizer is obtained from `encodingForModel` the first time a model is
 * requested and stored in `encodingCache`; later calls for the same model
 * return the stored instance.
 *
 * @param model - tiktoken model name to look up.
 * @returns The cached (or newly created) tokenizer for `model`.
 */
export const getEncodingForModelCached = (model: TiktokenModel) => {
  const cached = encodingCache.get(model);
  if (cached) {
    return cached;
  }

  const created = encodingForModel(model);
  encodingCache.set(model, created);
  return created;
};

/**
 * Counts how many tokens `text` encodes to for the given model.
 *
 * @param text - The text to tokenize.
 * @param model - Either a project-level OpenAI model name, which is mapped to
 *   its `tiktokenModelName` via `openAIModelValues`, or a tiktoken model name
 *   used as-is.
 * @returns The length of the encoded token sequence.
 */
export function countTokens(text: string, model: OpenAIModelName | TiktokenModel): number {
  let tiktokenModel: TiktokenModel;
  if (isOpenAIModelName(model)) {
    tiktokenModel = openAIModelValues[model].tiktokenModelName;
  } else {
    tiktokenModel = model;
  }

  return getEncodingForModelCached(tiktokenModel).encode(text).length;
}

// These numbers are based on: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// Fixed overhead added for every message. Officially it's 4 for gpt-3.5-turbo and 3 for the others,
// but 4 seems to be needed for gpt-4 as well.
const EXTRA_TOKENS_PER_MESSAGE = 4;
// Fixed overhead added once per request. From OpenAI: every reply is primed with <|start|>assistant<|message|>.
export const EXTRA_TOKENS_PER_REQUEST = 3;

/**
 * Estimates the number of tokens a single chat message contributes to a prompt.
 *
 * The estimate is the sum of:
 *  - the fixed per-message overhead (`EXTRA_TOKENS_PER_MESSAGE`),
 *  - the tokens in `message.content` (skipped when the content is null or empty),
 *  - for assistant messages with tool calls, the tokens in
 *    `"<function name> <function arguments>"` for each tool call.
 *
 * @param message - The chat message to measure.
 * @param countTokens - Callback returning the token count of a piece of text.
 * @returns The estimated token count for the message.
 */
export function countTokensInMessage(message: LLMChatCompletionMessage, countTokens: (text: string) => number): number {
  // Content is counted before the tool calls, so the callback sees the text in message order.
  const contentTokens = message.content ? countTokens(message.content) : 0;

  let toolCallTokens = 0;
  if (message.role === 'assistant' && message.tool_calls?.length) {
    for (const toolCall of message.tool_calls) {
      toolCallTokens += countTokens(`${toolCall.function.name} ${toolCall.function.arguments}`);
    }
  }

  return EXTRA_TOKENS_PER_MESSAGE + contentTokens + toolCallTokens;
}

/**
 * Estimates the total number of tokens for a chat-completion request.
 *
 * The result is the sum of the per-message estimates from
 * `countTokensInMessage`, the fixed per-request overhead
 * (`EXTRA_TOKENS_PER_REQUEST`), and `maxCompletionTokens` when it is provided.
 *
 * @param messages - The prompt messages to be sent.
 * @param model - Model whose tokenizer is used for counting.
 * @param maxCompletionTokens - Optional number of tokens to add for the
 *   completion; treated as 0 when omitted.
 * @returns The estimated token total.
 */
export function countTokensInPromptMessages(
  messages: LLMChatCompletionMessage[],
  model: OpenAIModelName | TiktokenModel,
  maxCompletionTokens?: number
): number {
  // Typed explicitly: an unannotated parameter here is an implicit `any` and fails under `strict`.
  const countTokensFunction = (text: string): number => countTokens(text, model);
  const messageTokens = messages.reduce(
    (sum, message) => sum + countTokensInMessage(message, countTokensFunction),
    0
  );
  return messageTokens + EXTRA_TOKENS_PER_REQUEST + (maxCompletionTokens ?? 0);
}
