import { StatusCodes } from 'http-status-codes';
import { OpenAI } from 'openai';
import { AiClient } from './AiClient';
import { StiggFeatures, TextCompletionRequest } from '../types';
import { computeAiModelRequestedUsage } from '../billing/utils';
import { OpenAIModelName } from '../code-analysis';
import { QuotaExceededError } from './utils';
import { removeSuffix } from '../utils/string-utils';

/**
 * `AiClient` variant that gates text completions behind a remote allowance
 * service: before each completion the requested usage is validated against the
 * workspace's quota, and after it the actual cost is reported back.
 */
export class OnPremAiClient extends AiClient {
  /** Base URL of the allowance service, as normalized by `removeSuffix(usageHost, '/')`. */
  private readonly usageHost: string;

  /**
   * @param getDBAuthToken Async provider of the bearer token sent to the allowance service
   *   (also forwarded to the base `AiClient`).
   * @param host Forwarded unchanged to the base `AiClient`.
   * @param usageHost Base URL of the allowance service.
   */
  constructor(getDBAuthToken: () => Promise<string>, host: string, usageHost: string) {
    super(getDBAuthToken, host);
    this.usageHost = removeSuffix(usageHost, '/');
  }

  /**
   * Builds the allowance endpoint URL for a workspace/feature pair.
   * Dynamic path segments are percent-encoded so that an identifier containing
   * reserved characters (`/`, `?`, `#`, ...) cannot alter the request path.
   */
  private allowanceUrl(action: 'validate' | 'increment', workspaceId: string, featureId: StiggFeatures): string {
    const workspaceSegment = encodeURIComponent(workspaceId);
    const featureSegment = encodeURIComponent(featureId);
    return `${this.usageHost}/ask-swimm-backend/allowance/${action}/workspaces/${workspaceSegment}/features/${featureSegment}`;
  }

  /**
   * Asks the allowance service whether the workspace may consume the usage
   * computed for this request.
   *
   * @throws QuotaExceededError when the service answers 429 (Too Many Requests).
   * @throws Error when the service answers with any other non-2xx status; the
   *   request is rejected rather than silently treated as allowed.
   */
  private async checkAllowed(
    workspaceId: string,
    featureId: StiggFeatures,
    model: OpenAIModelName,
    prompts: string[],
    maxTokens?: number
  ) {
    const authToken = await this.getDBAuthToken();
    const requestedUsage = computeAiModelRequestedUsage(featureId, prompts, model, maxTokens);
    const canRunRequest = await fetch(`${this.allowanceUrl('validate', workspaceId, featureId)}/${requestedUsage}`, {
      headers: { Authorization: `Bearer ${authToken}` },
    });
    if (canRunRequest.status === StatusCodes.TOO_MANY_REQUESTS) {
      throw new QuotaExceededError(`Quota exceeded for feature ${featureId} in workspace ${workspaceId}`);
    }
    // Fail closed: an auth failure or server error means the quota was never
    // actually validated, so do not proceed to the (billable) completion.
    if (!canRunRequest.ok) {
      throw new Error(
        `Failed to validate feature ${featureId} allowance for workspace ${workspaceId} (HTTP ${canRunRequest.status})`
      );
    }
  }

  /**
   * Reports consumed usage for a completed request to the allowance service.
   *
   * @param requestId Sent in the body alongside the usage count.
   * @param consumption Amount reported as `usageCount`.
   * @throws Error when the service answers with a non-2xx status.
   */
  private async incrementUsage(workspaceId: string, featureId: StiggFeatures, requestId: string, consumption: number) {
    const authToken = await this.getDBAuthToken();
    const response = await fetch(this.allowanceUrl('increment', workspaceId, featureId), {
      headers: { Authorization: `Bearer ${authToken}`, 'Content-Type': 'application/json' },
      method: 'POST',
      body: JSON.stringify({ requestId, usageCount: consumption }),
    });
    if (!response.ok) {
      throw new Error(`Failed to increment feature ${featureId} usage for workspace ${workspaceId}`);
    }
  }

  /**
   * Runs a text completion with quota enforcement: validates the allowance,
   * delegates to the base implementation, then records the resulting cost.
   *
   * Note that a failure while recording usage propagates to the caller even
   * though the completion itself succeeded.
   *
   * @throws QuotaExceededError when the workspace's quota is exhausted.
   * @throws Error when the allowance service cannot validate or record usage.
   */
  protected override async _completeText(
    request: TextCompletionRequest
  ): Promise<{ generatedText: string; usage?: OpenAI.Completions.CompletionUsage; cost: number }> {
    // The suffix (when present) is passed together with the prompt to the usage estimate.
    const prompts = [request.openAIParameters.prompt];
    if (request.openAIParameters.suffix) {
      prompts.push(request.openAIParameters.suffix);
    }
    await this.checkAllowed(
      request.workspaceId,
      StiggFeatures.TEXT_COMPLETION_CAP,
      request.openAIParameters.model,
      prompts,
      request.openAIParameters.max_tokens
    );
    const response = await super._completeText(request);
    await this.incrementUsage(request.workspaceId, StiggFeatures.TEXT_COMPLETION_CAP, request.requestId, response.cost);
    return response;
  }
}
