import { RequestParams } from "duck/agents/ClaimAnalytics/types";
import getPrompts from "duck/agents/ClaimAnalytics/utils/getPrompts";
import getTools from "duck/agents/ClaimAnalytics/utils/tools/getTools";
import getAgentNode from "duck/agents/common/utils/nodes/getAgentNode";
import getRejectOrClarifyNode from "duck/agents/common/utils/nodes/getRejectOrClarifyNode";
import getRespondToUserNode from "duck/agents/common/utils/nodes/getRespondToUserNode";
import getToolNode from "duck/agents/common/utils/nodes/getToolNode";
import getValidateNode from "duck/agents/common/utils/nodes/getValidateNode";
import { graphState, GraphStateType } from "duck/agents/common/utils/state";
import { Runnable } from "@langchain/core/runnables";
import { END, MemorySaver, START, StateGraph } from "@langchain/langgraph/web";
import { ChatOpenAI } from "@langchain/openai";

import { MODELSPEC, OPENAI_API_KEY } from "./constants";
import { shouldContinue } from "./utils";

/**
 * @summary Build the LangGraph state machine that powers the claim analytics expert.
 * @description Every run enters at the "validate" node, whose conditional edge
 *   routes to "agent", "clarify" or "reject". The agent alternates with "tools"
 *   until `shouldContinue` sends it to END. Both "clarify" and "reject" finish
 *   through the "respond" node before ending.
 * @param params The parameters for the agent from the UI; used to build the tool set
 * @param withMemory When true, compile the graph with an in-memory `MemorySaver`
 *   checkpointer; when false (the default), no checkpointer is attached
 * @returns The compiled graph, ready to process the user utterance
 */
const getAgentExecutor = async (
  params: RequestParams,
  withMemory: boolean = false
): Promise<Runnable> => {
  const chatModel = new ChatOpenAI({
    openAIApiKey: OPENAI_API_KEY,
    model: MODELSPEC.modelName,
    temperature: MODELSPEC.temperature,
    modelKwargs: MODELSPEC.modelKwargs,
  });
  const agentTools = getTools(params);
  const {
    claimAnalyticsPrompt,
    validatePrompt,
    rejectPrompt,
    clarifyPrompt,
  } = await getPrompts();

  // TODO: MemorySaver is intended for experimentation and is not suitable for
  //   production. Replace it with a persistent checkpoint saver (e.g. Postgres
  //   or local storage).
  const checkpointer = withMemory ? new MemorySaver() : undefined;

  // The routing key after validation is read from `state.next`
  // (presumably written by the validate node — confirm in getValidateNode).
  const routeAfterValidation = (state: GraphStateType) => state.next;

  // Keep this as one builder chain: each addNode call widens the graph's
  // node-name type, which the later addEdge/addConditionalEdges calls rely on.
  return new StateGraph(graphState)
    .addNode("agent", getAgentNode(chatModel, agentTools, claimAnalyticsPrompt))
    .addNode("clarify", getRejectOrClarifyNode(chatModel, agentTools, clarifyPrompt))
    .addNode("reject", getRejectOrClarifyNode(chatModel, agentTools, rejectPrompt))
    .addNode("validate", getValidateNode(chatModel, agentTools, validatePrompt))
    .addNode("tools", getToolNode(agentTools))
    .addNode("respond", getRespondToUserNode())
    .addEdge(START, "validate")
    .addConditionalEdges("validate", routeAfterValidation, {
      agent: "agent",
      clarify: "clarify",
      reject: "reject",
    })
    .addConditionalEdges("agent", shouldContinue, {
      tools: "tools",
      // respond: "respond", // disabled: the agent currently ends directly
      [END]: END,
    })
    .addEdge("tools", "agent")
    .addEdge("clarify", "respond")
    .addEdge("reject", "respond")
    .addEdge("respond", END)
    .compile({ checkpointer });
};

export default getAgentExecutor;
