import { GraphStateType } from "duck/agents/common/utils/state";
import { JsonOutputToolsParser } from "@langchain/core/output_parsers/openai_tools";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { StructuredTool } from "@langchain/core/tools";
import { ChatOpenAI, ChatOpenAICallOptions } from "@langchain/openai";

import { NodeOutputType } from "./utils";

/** Destinations the validation step is allowed to route to. */
const ROUTE_TARGETS = ["agent", "reject", "clarify"] as const;

/**
 * JSON schema for the `route` tool's arguments: a single required `next`
 * field restricted to one of {@link ROUTE_TARGETS}. The schema is closed
 * (`additionalProperties: false`, every property listed in `required`).
 */
const routeParameters = {
  title: "routeSchema",
  type: "object",
  properties: {
    next: { title: "Next", anyOf: [{ enum: ROUTE_TARGETS }] },
  },
  required: ["next"],
  additionalProperties: false,
} as const;

/**
 * OpenAI function-tool definition for `route`. The validation node binds it
 * alongside the regular tools and forces the model to call it, so the chosen
 * destination comes back in the tool call's `args.next`.
 */
const toolDef = {
  type: "function",
  function: {
    name: "route",
    description: "Select the next node based on the last message",
    parameters: routeParameters,
  },
} as const;

/**
 * @summary Create and return the node responsible for running the validation LLM.
 *
 * The LLM is bound to the caller's tools plus the `route` tool, and
 * `tool_choice` forces it to call `route`. The node therefore yields the parsed
 * arguments of that call (e.g. `{ next: "agent" }`).
 *
 * @param llmAgent The LLM agent that performs the validation
 * @param tools The tools available to the LLM (bound so the model can see them;
 *   the forced `route` call is what the node returns)
 * @param validatePrompt The prompt to send to the LLM; it is invoked with
 *   `messages` and `current_state` (the JSON-serialized page state)
 * @returns The node responsible for running the validation LLM
 * @throws Error (from the returned node) when the model response contains no
 *   tool call with arguments, instead of an opaque `TypeError`
 */
const getValidateNode = (
  llmAgent: ChatOpenAI<ChatOpenAICallOptions>,
  tools: StructuredTool[],
  validatePrompt: ChatPromptTemplate
): ((state: GraphStateType) => Promise<NodeOutputType>) => {
  const validateLLM = llmAgent.bindTools([...tools, toolDef], {
    tool_choice: { type: "function", function: { name: "route" } },
    strict: true,
  });

  const validateChain = validatePrompt
    .pipe(validateLLM)
    .pipe(new JsonOutputToolsParser())
    // Select the first parsed tool call. Because `route` is forced, this is
    // expected to be the routing decision. Guard against an empty result so a
    // malformed response fails with a clear message rather than
    // "Cannot read properties of undefined".
    // eslint-disable-next-line @typescript-eslint/no-explicit-any -- parser output is loosely typed
    .pipe((toolCalls: any[]) => {
      const routeCall = toolCalls?.[0];
      if (routeCall?.args === undefined) {
        throw new Error(
          "Validation LLM response did not contain a `route` tool call with arguments"
        );
      }
      return routeCall.args;
    });

  return async (state: GraphStateType) => {
    const response = await validateChain.invoke({
      messages: state.messages,
      current_state: JSON.stringify(state.pageState),
    });
    return response;
  };
};

export default getValidateNode;
