import * as uuid from 'uuid';
import { OpenAI } from 'openai';
import { CloudFunctions } from '@/common/utils/cloud-functions-utils';
import { snippets2DocResponseToSwmd } from '@swimm/swmd';
import {
  type ChatTextCompletionPrompt,
  type LLMChatCompletionMessage,
  type LLMCodeCompletion,
  type LlmCallGenericRequest,
  type OpenAIModelName,
  QuotaExceededError,
  type SnippetsToDocPrompt,
  type TextCompletionParams,
  type TextCompletionPrompt,
  countTokensInPromptMessages,
  getMaxContextTokenCountForModel,
  getSnippets2DocPrompt,
  openAIModelValues,
} from '@swimm/shared';
import {
  type GeneratePRToDocRequest,
  GenerativeAIStreamingCloseReason,
  GenerativeAiRequestType,
  type LLMResponseFormat,
  StiggFeatures,
  type SwimmDocument,
  getLoggerNew,
  parseGenerateMermaidResponse,
  streamOpenAIResponse,
} from '@swimm/shared';
import { StatusCodes } from 'http-status-codes';
import { SWIMM_ONPREM_AGENT_CLOUD_RUN_URL, hostUrl } from '@/config';

// Module-scoped logger shared by every function in this file.
// NOTE(review): `__modulename` is not declared in this file — presumably a build-time injected global; confirm.
const logger = getLoggerNew(__modulename);

/**
 * Streams the messages produced while generating a document for a PR-to-doc request.
 *
 * The prompt is built with `getSnippets2DocPrompt(request)` and streamed through
 * `streamTextFromOpenAI`; every message from that stream is re-yielded unchanged.
 *
 * This generator does not throw: any failure (including a request without a
 * `swimmDocument`) is reported as a single yielded message of type `'error'`
 * carrying `GenerativeAIStreamingCloseReason.ERROR_FROM_CLIENT`.
 *
 * @param request - The PR-to-doc request; must contain `swimmDocument`.
 * @param streamingEndpoint - Base URL forwarded to `streamTextFromOpenAI`.
 * @param readLines - Callback forwarded to `streamTextFromOpenAI` for reading file lines.
 * @returns An async generator of streaming messages.
 */
export async function* createPRToDoc(
  request: GeneratePRToDocRequest,
  streamingEndpoint: string,
  readLines: (path: string, startLine: number, endLine: number) => Promise<string[]>
): AsyncGenerator<{
  type: string;
  swimmDocument?: SwimmDocument;
  reason?: string;
  code?: GenerativeAIStreamingCloseReason;
}> {
  try {
    const { swimmDocument, repoId, workspaceId } = request;
    if (!swimmDocument) {
      logger.error(
        { repoId: repoId, workspaceId: workspaceId },
        `Error generating PR to Doc: The request is missing Swimm Doc`
      );
      // Thrown inside the try so the catch below converts it into an 'error' message.
      throw new Error('The request is missing Swimm Doc');
    }

    const prompt = getSnippets2DocPrompt(request);

    const stream = await streamTextFromOpenAI(prompt, swimmDocument, streamingEndpoint, readLines);

    for await (const message of stream) {
      yield message;
    }
  } catch (error: unknown) {
    // A thrown value is not guaranteed to be an Error; narrow before reading `.message`
    // so `reason` is always a meaningful string.
    const reason = error instanceof Error ? error.message : String(error);
    yield {
      type: 'error',
      code: GenerativeAIStreamingCloseReason.ERROR_FROM_CLIENT,
      reason,
    };
  }
}

/**
 * Opens a streaming completion request and returns a generator over its parsed output.
 *
 * Starts with the `GPT_4_SHORT` model and switches to `GPT_4_LONG` when the prompt token
 * count (computed by `countTokensInPromptMessages`, which also receives `options.max_tokens`)
 * exceeds `getMaxContextTokenCountForModel` for the short model. The request is POSTed to
 * `${streamingEndpoint}/generative-ai/completion/workspaces/${workspaceId}` and the response
 * is handed to `streamOpenAIResponse`, using `snippets2DocResponseToSwmd` to process it.
 *
 * @param prompt - Destructured prompt: workspace id, chat messages, completion options,
 *   request id and cancellation hook.
 * @param swimmDocument - Document passed through to the response processor.
 * @param streamingEndpoint - Base URL of the streaming service.
 * @param readLines - Callback passed to `snippets2DocResponseToSwmd` for reading file lines.
 * @returns The async generator produced by `streamOpenAIResponse`.
 * @throws QuotaExceededError when the service answers with HTTP 429.
 * @throws Error when the response is otherwise unsuccessful or has no body.
 */
export async function streamTextFromOpenAI(
  { workspaceId, messages, options, requestId, shouldCancel }: SnippetsToDocPrompt,
  swimmDocument: SwimmDocument,
  streamingEndpoint: string,
  readLines: (path: string, startLine: number, endLine: number) => Promise<string[]>
): Promise<
  AsyncGenerator<{
    type: string;
    swimmDocument?: SwimmDocument;
    reason?: string;
    code?: GenerativeAIStreamingCloseReason;
  }>
> {
  const authToken = await CloudFunctions.getAuthToken();
  let model: OpenAIModelName = 'GPT_4_SHORT';

  const promptTokens = countTokensInPromptMessages(
    messages as LLMChatCompletionMessage[],
    openAIModelValues[model].tiktokenModelName,
    options.max_tokens
  );
  if (promptTokens > getMaxContextTokenCountForModel(model)) {
    logger.info(`Received long prompt (${promptTokens}) in streamTextFromOpenAI, switching to GPT_4_LONG`);
    model = 'GPT_4_LONG';
  }

  const openAIParameters = {
    model,
    messages,
    ...options,
    stream: true,
  };

  const response = await fetch(`${streamingEndpoint}/generative-ai/completion/workspaces/${workspaceId}`, {
    headers: {
      Authorization: `Bearer ${authToken}`,
      Accept: 'text/event-stream',
      'Content-Type': 'application/json',
    },
    method: 'POST',
    body: JSON.stringify({ openAIParameters, askSwimmRequestId: requestId }),
  });

  if (!response.ok || !response.body) {
    if (response.status === StatusCodes.TOO_MANY_REQUESTS) {
      throw new QuotaExceededError(`Error streaming text from Azure OpenAI: Cap limit reached`);
    }
    // `statusText` is empty over HTTP/2, and reads "OK" when only the body is missing,
    // so fall back to the numeric status / an explicit explanation to keep the error useful.
    const failureDetail = response.ok
      ? 'response has no body'
      : response.statusText || `HTTP ${response.status}`;
    throw new Error(`Error streaming text from Azure OpenAI: ${failureDetail}`);
  }
  return streamOpenAIResponse({
    response,
    swimmDocument,
    processOpenAIResponseSwmd: (response, swimmDocument) =>
      snippets2DocResponseToSwmd(response, swimmDocument, readLines),
    shouldCancel,
  });
}

/**
 * Issues a generic LLM request against the `generative-ai/generate-text` route.
 *
 * @param requestParams - Supplies the prompt (messages + options), request id, request type
 *   and workspace id.
 * @param generateTextEndpoint - Base URL of the service to call.
 * @param model - Model to request; defaults to `'GPT_3_5'`.
 * @returns The result of `callOpenAI`, typed as `LLMResponseFormat`.
 */
export async function callGenerateText(
  requestParams: LlmCallGenericRequest,
  generateTextEndpoint: string,
  model: OpenAIModelName = 'GPT_3_5'
) {
  const { prompt, requestId, type, workspaceId } = requestParams;

  // `model` comes last so it wins over any `model` key carried by the prompt options.
  const openAIParameters = {
    messages: prompt.messages,
    ...prompt.options,
    model,
  };

  const pendingResult = callOpenAI(
    { openAIParameters, askSwimmRequestId: requestId, requestType: type },
    generateTextEndpoint,
    'generative-ai/generate-text',
    workspaceId,
    type
  );
  return pendingResult as Promise<LLMResponseFormat>;
}

/**
 * Requests a text completion through one of the two supported routes.
 *
 * The route is validated before any request is sent, so an unsupported route fails fast
 * instead of first issuing a network call whose result would be discarded.
 *
 * @param prompt - Completion parameters; `prompt.prompt` is sent as the OpenAI parameters.
 * @param serverUrl - Base URL of the service to call.
 * @param workspaceId - Workspace the request is made for.
 * @param route - Either `'text-completion/create'` or `'generative-ai/generate-text'`.
 * @returns For `'text-completion/create'`, the result as-is; for
 *   `'generative-ai/generate-text'`, the result wrapped as `{ completion }`.
 * @throws Error with message `'Invalid route'` for any other route.
 */
export async function callCompleteText(
  prompt: TextCompletionParams,
  serverUrl: string,
  workspaceId: string,
  route: string
) {
  if (route !== 'text-completion/create' && route !== 'generative-ai/generate-text') {
    throw new Error('Invalid route');
  }

  const result = await callOpenAI(
    {
      openAIParameters: prompt.prompt,
      askSwimmRequestId: uuid.v4(),
      requestType: GenerativeAiRequestType.GENERATE_TEXT_COMPLETION,
    },
    serverUrl,
    route,
    workspaceId,
    GenerativeAiRequestType.GENERATE_TEXT_COMPLETION
  );

  if (route === 'text-completion/create') {
    return result as LLMResponseFormat & { cost: number };
  }
  // NOTE(review): `callOpenAI` can return an `{ error, code }` object on failure; on this
  // route it ends up wrapped inside `completion` unchanged — confirm callers handle that.
  return { completion: result } as LLMCodeCompletion;
}

/**
 * POSTs `requestBody` as JSON to `${generateTextEndpoint}/${route}/workspaces/${workspaceId}`
 * and normalizes the reply according to `requestType`.
 *
 * This function does not throw. Failures are returned as `{ error, code }` objects:
 * HTTP 429 maps to `TOO_MANY_REQUESTS`, any other unsuccessful or body-less response maps to
 * `BAD_REQUEST`, a Mermaid parse failure and any unexpected exception map to
 * `INTERNAL_SERVER_ERROR`.
 *
 * On success:
 * - text modifier / snippet comment / text completion requests: `response` is parsed as JSON
 *   (an absent `response` yields `{}`);
 * - Mermaid requests: `response` is passed to `parseGenerateMermaidResponse`;
 * - otherwise: `response` when truthy, else the whole decoded payload.
 *
 * @param requestBody - JSON payload sent to the service.
 * @param generateTextEndpoint - Base URL of the service.
 * @param route - Route segment placed between the base URL and `/workspaces/`.
 * @param workspaceId - Workspace the request is made for.
 * @param requestType - Selects how the reply is decoded; optional.
 */
async function callOpenAI(
  requestBody: {
    openAIParameters: TextCompletionPrompt | ChatTextCompletionPrompt | Partial<OpenAI.Chat.ChatCompletionCreateParams>;
    askSwimmRequestId: string;
    requestType: GenerativeAiRequestType;
  },
  generateTextEndpoint: string,
  route: string,
  workspaceId: string,
  requestType?: GenerativeAiRequestType
): Promise<LLMResponseFormat | LLMCodeCompletion | string> {
  // Request types whose `response` field carries a JSON-encoded string.
  const jsonEncodedRequestTypes: readonly GenerativeAiRequestType[] = [
    GenerativeAiRequestType.GENERATE_TEXT_MODIFIER,
    GenerativeAiRequestType.GENERATE_SNIPPET_COMMENT,
    GenerativeAiRequestType.GENERATE_TEXT_COMPLETION,
  ];

  try {
    const authToken = await CloudFunctions.getAuthToken();
    const response = await fetch(`${generateTextEndpoint}/${route}/workspaces/${workspaceId}`, {
      headers: {
        Authorization: `Bearer ${authToken}`,
        'Content-Type': 'application/json',
      },
      method: 'POST',
      body: JSON.stringify(requestBody),
    });
    if (!response.ok || !response.body) {
      if (response.status === StatusCodes.TOO_MANY_REQUESTS) {
        return { error: 'Quota exceeded for GenAI request', code: StatusCodes.TOO_MANY_REQUESTS };
      }
      // The caller only ever sees a generic "Bad request"; record the real status so the
      // underlying failure can still be diagnosed from the logs.
      logger.error(
        { status: response.status, statusText: response.statusText, route, workspaceId },
        'OpenAI request returned an unsuccessful response'
      );
      return { error: 'Bad request', code: StatusCodes.BAD_REQUEST };
    }

    const result = await response.json();

    // `requestType` is optional; guard before the membership test.
    if (requestType !== undefined && jsonEncodedRequestTypes.includes(requestType)) {
      // A JSON syntax error here is handled by the outer catch.
      return JSON.parse(result.response ?? '{}') as LLMResponseFormat;
    }

    if (requestType === GenerativeAiRequestType.GENERATE_MERMAID) {
      try {
        return parseGenerateMermaidResponse(result.response as string);
      } catch (err) {
        logger.error({ err }, 'Failed to parse response from OpenAI');
        return { error: 'Failed to parse response from OpenAI', code: StatusCodes.INTERNAL_SERVER_ERROR };
      }
    }

    if (result.response) {
      return result.response;
    }

    return result;
  } catch (err) {
    logger.error({ err }, 'Failed to call OpenAI');
    return { error: 'Unknown error has occurred', code: StatusCodes.INTERNAL_SERVER_ERROR };
  }
}

/**
 * Best-effort increment of the generative-AI usage counter for a workspace.
 *
 * Nothing is sent when `endpoint` equals `SWIMM_ONPREM_AGENT_CLOUD_RUN_URL`. Otherwise an
 * authenticated POST carrying `{ requestId }` is issued to the allowance-increment API under
 * `hostUrl`. Failures are logged and never propagated to the caller.
 *
 * @param endpoint - Endpoint the generative-AI request was served from.
 * @param workspaceId - Workspace whose usage should be incremented.
 * @param requestId - Identifier of the request being accounted for.
 */
export async function incrementGenerativeAICap(endpoint: string, workspaceId: string, requestId: string) {
  try {
    const shouldSkip = endpoint === SWIMM_ONPREM_AGENT_CLOUD_RUN_URL;
    if (shouldSkip) {
      logger.debug(`Skipping incrementing cap for cloud environment`);
      return;
    }

    const token = await CloudFunctions.getAuthToken();
    const incrementUrl = `${hostUrl}api/allowance/increment/workspaces/${workspaceId}/features/${StiggFeatures.GENERATIVE_AI_CAP}`;
    const { ok } = await fetch(incrementUrl, {
      headers: { Authorization: `Bearer ${token}`, 'Content-Type': 'application/json' },
      method: 'POST',
      body: JSON.stringify({ requestId }),
    });

    // Raised inside the try on purpose: the catch below turns it into a log entry.
    if (!ok) {
      throw new Error('Failed to increment feature usage for workspace');
    }
  } catch (err) {
    logger.error({ err }, `Error incrementing ${StiggFeatures.GENERATIVE_AI_CAP} feature usage`);
  }
}
