import { context_engine } from '@dropbox/api-v2-client/types/dropbox_types';
import { ContentCacheLoaded } from '@mirage/mosaics/ComposeAssistant/data/ComposeSourcesCache';
import { callAssistChatApi } from '@mirage/mosaics/ComposeAssistant/data/llm/llm-apis';
import { messageToPromptMessage } from '@mirage/mosaics/ComposeAssistant/data/llm/llm-prompts';
import { handleToolCall } from '@mirage/mosaics/ComposeAssistant/data/llm/llm-tools';
import {
  AssistantResponse,
  GetAssistantResponseComposeParams,
  GetAssistantResponseModifyVoiceParams,
  GetAssistantResponseParams,
} from '@mirage/mosaics/ComposeAssistant/data/llm/llm-types';
import { tagged } from '@mirage/service-logging';
import { getSourceUUID } from '@mirage/shared/compose/compose-session';
import { composeVoiceSamples } from '@mirage/shared/compose/compose-voice';
import { getSourceContentFromCache, getSources } from './tools/compose-tools';

// Module logger; `tagged` presumably labels entries with 'ComposeAssistant/llm' — confirm in @mirage/service-logging.
const logger = tagged('ComposeAssistant/llm');

/**
 * Debug flag for `getAssistantResponse`.
 *
 * When true, `onResponse` is wrapped by `createAssistantResponseLogger`, so
 * every emitted response is recorded. After the response loop finishes, the
 * call params and all recorded responses are written with `logger.log`.
 * When false, no wrapper is created and nothing extra is logged.
 */
const LOG_ASSISTANT_RESPONSE_CALL = false;

/**
 * Drive the assistant exchange for the current session until a round produces
 * no further tool messages.
 *
 * Every round sends the converted prompt history, the new message, the selected
 * sources and all tool messages accumulated in earlier rounds. Responses are
 * delivered through `onResponse`, which may fire zero or more times. The
 * returned Promise resolves once the final round completes.
 *
 * @param params - session state: message history, new message, source
 *   contents, optional compose / modify-voice parameters and feature flags.
 * @param onResponse - receives each assistant response as it is produced.
 */
export async function getAssistantResponse(
  params: GetAssistantResponseParams,
  onResponse: (response: AssistantResponse) => void,
): Promise<void> {
  // With debug logging enabled, responses go through a recorder so the whole
  // exchange can be logged after the loop finishes.
  const callLogger = LOG_ASSISTANT_RESPONSE_CALL
    ? createAssistantResponseLogger(params, onResponse)
    : undefined;
  const emitResponse = callLogger?.onResponse ?? onResponse;

  // History messages that do not convert into prompt messages are dropped.
  const history = params.messagesHistory.flatMap(
    (message): context_engine.ChatMessage[] => {
      const promptMessage = messageToPromptMessage(message);
      return promptMessage ? [promptMessage] : [];
    },
  );

  const { sources, indexedCachedSources } = getSources(
    params.sourcesContents,
    !params.mustIncludeSourceContents,
  );

  // The new message is assumed to always convert successfully.
  const newMessage = messageToPromptMessage(params.newMessage)!;
  const composeParams = makeCEComposeParams(params.composeParams);
  const modifyVoiceParams = makeCEModifyVoiceParams(params.modifyVoiceParams);

  // At most one of these is sent; modify-voice takes precedence over compose.
  let modeParams: Pick<
    context_engine.AssistChatParams,
    'modify_voice_params' | 'compose_params'
  > = {};
  if (modifyVoiceParams) {
    modeParams = { modify_voice_params: modifyVoiceParams };
  } else if (composeParams) {
    modeParams = { compose_params: composeParams };
  }

  // Tool results accumulate across rounds and are resent on every call.
  const toolMessages: context_engine.ChatMessage[] = [];
  for (;;) {
    const assistChatParams: context_engine.AssistChatParams = {
      history,
      new_message: newMessage,
      include_source_content: params.mustIncludeSourceContents,
      sources,
      tool_messages: toolMessages,
      dash_search_enabled: params.featureFlags.dashSearchEnabled,
      ...modeParams,
    };
    const followUps = await getAssistantResponseIteration(
      assistChatParams,
      indexedCachedSources,
      emitResponse,
    );
    if (followUps.length === 0) {
      break;
    }
    toolMessages.push(...followUps);
  }

  callLogger?.logCompletion();
}

/**
 * Run a single LLM API call and handle its response.
 *
 * - If the model requested no tools, the response text is emitted through
 *   `onResponse` as a `'message'` response and the round ends.
 * - Otherwise all tool calls are dispatched concurrently through
 *   `handleToolCall`. The assistant message that issued the calls is returned
 *   together with the resulting `tool_message` entries, so the next round
 *   resends them.
 *
 * @returns messages to append for the next iteration. An empty array means no
 *   further round is needed.
 */
async function getAssistantResponseIteration(
  params: context_engine.AssistChatParams,
  indexedCachedSources: Map<number, ContentCacheLoaded>,
  onResponse: (response: AssistantResponse) => void,
): Promise<context_engine.ChatMessage[]> {
  const response = await callAssistChatApi({
    params,
  });

  if (response.toolCalls.length === 0) {
    onResponse({
      type: 'message',
      responseText: response.responseText,
    });
    return [];
  }

  const toolCallResults = await Promise.all(
    response.toolCalls.map((toolCall) =>
      handleToolCall(response, toolCall, indexedCachedSources, onResponse),
    ),
  );

  // Only tool messages are sent back to the model. Check emptiness AFTER
  // filtering: otherwise the assistant's tool-call message could be resent
  // with no tool results attached, leaving its tool calls unanswered.
  const toolMessages = toolCallResults
    .flat()
    .filter((message) => message['.tag'] === 'tool_message');
  if (toolMessages.length === 0) {
    return [];
  }

  return [response.responseMessage, ...toolMessages];
}

/**
 * Convert cached voice source contents into context_engine `SourceContent`
 * entries.
 *
 * Entries whose `state` is not `'loaded'` are skipped and reported through
 * `logger.error` with their source UUID.
 */
function collectLoadedVoiceSources(
  voiceSourceContents:
    | GetAssistantResponseComposeParams['voiceSourceContents']
    | GetAssistantResponseModifyVoiceParams['voiceSourceContents'],
): context_engine.SourceContent[] {
  const voiceSources: context_engine.SourceContent[] = [];
  for (const cacheContent of Object.values(voiceSourceContents)) {
    if (cacheContent.state !== 'loaded') {
      logger.error(
        'unable to load source content',
        getSourceUUID(cacheContent.source),
      );
      continue;
    }
    voiceSources.push(getSourceContentFromCache(cacheContent, false));
  }
  return voiceSources;
}

/**
 * Build the context_engine compose params from the assistant compose params.
 *
 * @returns `null` when no compose params were supplied.
 */
function makeCEComposeParams(
  composeParams?: GetAssistantResponseComposeParams,
): context_engine.ComposeParams | null {
  if (!composeParams) {
    return null;
  }
  return {
    markdown_content: composeParams.markdownContent,
    must_generate_draft: composeParams.mustGenerateDraft,
    voice_id: composeParams.voiceID,
    voice_sources: collectLoadedVoiceSources(composeParams.voiceSourceContents),
    voice_modification_history: composeParams.voiceModificationHistory,
  };
}

/**
 * Build the context_engine modify-voice params from the assistant params.
 * The shared `composeVoiceSamples` are always attached as `voice_samples`.
 *
 * @returns `null` when no modify-voice params were supplied.
 */
function makeCEModifyVoiceParams(
  modifyVoiceParams?: GetAssistantResponseModifyVoiceParams,
): context_engine.ModifyVoiceParams | null {
  if (!modifyVoiceParams) {
    return null;
  }
  return {
    voice_id: modifyVoiceParams.voiceID,
    voice_samples: composeVoiceSamples,
    voice_modification_history: modifyVoiceParams.voiceModificationHistory,
    voice_sources: collectLoadedVoiceSources(
      modifyVoiceParams.voiceSourceContents,
    ),
  };
}

/**
 * Wrap `onResponse` so every response is recorded before being forwarded.
 *
 * The returned object exposes:
 * - `onResponse`: the recording wrapper, which appends each response and then
 *   calls the original callback;
 * - `logCompletion`: writes the call params plus all recorded responses via
 *   `logger.log`;
 * - `getResponses`: returns the live array of recorded responses.
 */
export function createAssistantResponseLogger(
  params: GetAssistantResponseParams,
  onResponse: (response: AssistantResponse) => void,
) {
  const responses: AssistantResponse[] = [];

  function recordAndForward(response: AssistantResponse): void {
    responses.push(response);
    onResponse(response);
  }

  function logCompletion(): void {
    logger.log('getAssistantResponse completion', { params, responses });
  }

  return {
    onResponse: recordAndForward,
    logCompletion,
    getResponses: () => responses,
  };
}
