import { MODEL_MESSAGE_ROLES, ModelTrainingExample } from "@super-real/types";
import { MODEL_MESSAGE_ROLE_RECORD } from "../consts/MODEL_MESSAGE_ROLE_RECORD";
import { MODEL_NEW } from "../consts/MODEL_NEW";

/** Minimum number of examples a training file must contain. */
const MIN_TRAINING_EXAMPLES = 20;

/**
 * Parses the plain-text training format into a list of training examples and
 * validates the result.
 *
 * Every row is either the `MODEL_NEW` separator, which starts a new example,
 * or a message of the form `<short>: <content>`, where `<short>` is the
 * `short` label that `MODEL_MESSAGE_ROLE_RECORD` holds for one of
 * `MODEL_MESSAGE_ROLES`. Rows may be separated by `\n` or `\r\n`. Message
 * content is trimmed.
 *
 * Validation rules, checked per example:
 * - at least 2 messages in total;
 * - ignoring `SYSTEM` messages, the first message is a `USER` message and
 *   roles strictly alternate `USER`, `ASSISTANT`, `USER`, ...;
 * - the very last message is an `ASSISTANT` message.
 *
 * @param value - Raw training text.
 * @returns The parsed examples in file order.
 * @throws Error if the first or last row is the separator, if a row is
 *   neither a separator nor a recognised message (this includes empty rows),
 *   if there are fewer than 20 examples, or if an example breaks one of the
 *   validation rules above.
 */
export const toModelTrainingExamples = (
  value: string
): ModelTrainingExample[] => {
  // Accept Windows line endings as well: with a plain "\n" split, a trailing
  // "\r" would keep separator rows from ever matching MODEL_NEW.
  const rows = value.split(/\r?\n/);

  if (rows[0] === MODEL_NEW) {
    throw new Error(`First row cannot be ${MODEL_NEW}`);
  }

  if (rows[rows.length - 1] === MODEL_NEW) {
    throw new Error(`Last row cannot be ${MODEL_NEW}`);
  }

  // The first example is implicit; each separator row opens the next one.
  let currentTrainingExample: ModelTrainingExample = { messages: [] };
  const trainingExamples: ModelTrainingExample[] = [currentTrainingExample];

  rows.forEach((row, rowIndex) => {
    if (row === MODEL_NEW) {
      currentTrainingExample = { messages: [] };
      trainingExamples.push(currentTrainingExample);
      return;
    }

    for (const role of MODEL_MESSAGE_ROLES) {
      const prefix = `${MODEL_MESSAGE_ROLE_RECORD[role].short}: `;

      if (row.startsWith(prefix)) {
        currentTrainingExample.messages.push({
          role,
          content: row.slice(prefix.length).trim(),
        });
        return;
      }
    }

    // No role prefix matched: report the offending row (1-based).
    const content = shortenString(row.trim()) || "empty line";
    throw new Error(`Line #${rowIndex + 1} is invalid: "${content}"`);
  });

  if (trainingExamples.length < MIN_TRAINING_EXAMPLES) {
    throw new Error(
      `Training file has ${trainingExamples.length} examples, but must have at least ${MIN_TRAINING_EXAMPLES}`
    );
  }

  trainingExamples.forEach(({ messages }, exampleIndex) => {
    const exampleNumber = exampleIndex + 1;

    if (messages.length < 2) {
      throw new Error(`Example #${exampleNumber} must have at least 2 messages`);
    }

    // SYSTEM messages do not take part in the USER/ASSISTANT turn order.
    const messagesWithoutSystem = messages.filter(
      ({ role }) => role !== "SYSTEM"
    );
    const firstMessage = messagesWithoutSystem[0];

    // An example made only of SYSTEM messages has no first turn at all; fail
    // with a validation error instead of dereferencing `undefined`.
    if (firstMessage === undefined) {
      throw new Error(
        `Example #${exampleNumber} must have at least one user message`
      );
    }

    if (firstMessage.role !== "USER") {
      const content = shortenString(firstMessage.content);
      throw new Error(
        `Example #${exampleNumber} must start with a user message: "${content}"`
      );
    }

    // The closing check looks at the very last message, SYSTEM included.
    const lastMessage = messages[messages.length - 1];

    if (lastMessage.role !== "ASSISTANT") {
      const content = shortenString(lastMessage.content);
      throw new Error(
        `Example #${exampleNumber} must end with an assistant message: "${content}"`
      );
    }

    const hasAlternatingRoles = messagesWithoutSystem.every(
      ({ role }, position) =>
        role === (position % 2 === 0 ? "USER" : "ASSISTANT")
    );

    if (!hasAlternatingRoles) {
      throw new Error(`Example #${exampleNumber} must have alternating roles`);
    }
  });

  return trainingExamples;
};

/**
 * Shortens `value` so that it is at most `maxLength` UTF-16 code units long.
 *
 * Strings that already fit are returned unchanged. Longer strings are cut and
 * suffixed with "..." (the suffix counts towards `maxLength`). When
 * `maxLength` is too small to hold the suffix, the string is hard-truncated
 * without it, so the result never exceeds `maxLength`.
 *
 * @param value - Text to shorten.
 * @param maxLength - Maximum length of the result; defaults to 100.
 * @returns The original or shortened text.
 */
function shortenString(value: string, maxLength = 100): string {
  if (value.length <= maxLength) {
    return value;
  }

  const ellipsis = "...";
  const fitsEllipsis = !(maxLength < ellipsis.length);
  const suffix = fitsEllipsis ? ellipsis : "";

  // Clamp at 0: a negative end index would make `slice` count from the end of
  // the string and return far more than requested.
  const keep = Math.max(0, maxLength - suffix.length);
  let head = value.slice(0, keep);

  // Do not leave half of a surrogate pair behind when the cut lands inside an
  // astral character (e.g. an emoji).
  const lastCodeUnit = head.charCodeAt(head.length - 1);
  if (lastCodeUnit >= 0xd800 && lastCodeUnit <= 0xdbff) {
    head = head.slice(0, -1);
  }

  return `${head}${suffix}`;
}
