program TransformerDemo;
{$MODE OBJFPC}{$H+}
{$RANGECHECKS ON}

{
    Part of AdvancedChatAI.
    For GNU/Linux 64 bit version.
    Version: 1.
    Written on FreePascal (https://freepascal.org/).
    Copyright (C) 2025-2026 Artyomov Alexander
    Used https://chat.deepseek.com/
    http://self-made-free.ru/
    aralni@mail.ru

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as
    published by the Free Software Foundation, either version 3 of the
    License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
}


uses
  Transformer, MatrixOps, DataUtils, SysUtils, Math;

var
  model: TTransformer;
  config: TTransformerConfig;
  input, output, target, gradOutput, oldEmbedding: TDoubleMatrix;
  i,j,k,l: Integer;
  loss, initialLoss, finalLoss: Double;
stepLoss: Double;
begin
  WriteLn('=== ТЕСТ ТРАНСФОРМЕРА ===');
  WriteLn;

  // Простая конфигурация для теста
  config.InputSize := 64;
  config.NumLayers := 2;
  config.NumHeads := 2;
  config.FFNDim := 128;
  config.MaxSeqLength := 50;
  config.DropoutRate := 0.1;
  config.WeightDecay := 0.0001;
  config.GradientClipValue := 1.0;
  config.UseLayerNorm := True;

  WriteLn('Конфигурация модели:');
  WriteLn('  InputSize: ', config.InputSize);
  WriteLn('  NumLayers: ', config.NumLayers);
  WriteLn('  NumHeads: ', config.NumHeads);
  WriteLn('  FFNDim: ', config.FFNDim);
  WriteLn('  MaxSeqLength: ', config.MaxSeqLength);
  WriteLn;

  // Инициализация
  WriteLn('1. Инициализация трансформера...');
  InitTransformer(model, config);
  CheckModelDimensions(model);
  WriteLn('✓ Модель инициализирована успешно');
  WriteLn;

  // Тестовые данные
  WriteLn('2. Создание тестовых данных...');
  input := CreateRandomMatrix(5, config.InputSize, -1.0, 1.0); // 5 токенов
  target := CreateRandomMatrix(5, config.InputSize, -0.5, 0.5); // Целевые значения

  WriteLn('   Вход: ', Length(input), 'x', Length(input[0]));
  WriteLn('   Цель: ', Length(target), 'x', Length(target[0]));
  WriteLn;

  // Тест 1: Прямой проход
  WriteLn('3. Тест прямого прохода...');
  try
    ForwardTransformer(model, input, output, nil, True); // isTraining = True
    WriteLn('   ✓ Прямой проход выполнен успешно');
    WriteLn('   Размер вывода: ', Length(output), 'x', Length(output[0]));

    // Проверка что выход не содержит NaN/Inf
    for i := 0 to High(output) do
      for j := 0 to High(output[0]) do
        if IsNan(output[i][j]) or IsInfinite(output[i][j]) then
          WriteLn('   ⚠ Внимание: обнаружены некорректные значения в выводе');

  except
    on E: Exception do
    begin
      WriteLn('   ✗ Ошибка прямого прохода: ', E.Message);
      Exit;
    end;
  end;
  WriteLn;

  // Тест 2: Вычисление потерь
  WriteLn('4. Вычисление начальных потерь...');
  try
    initialLoss := 0.0;
    for i := 0 to High(output) do
      for j := 0 to High(output[0]) do
        initialLoss := initialLoss + Sqr(output[i][j] - target[i][j]);

    initialLoss := initialLoss / (Length(output) * Length(output[0]));
    WriteLn('   Начальные потери (MSE): ', initialLoss:0:6);

  except
    on E: Exception do
    begin
      WriteLn('   ✗ Ошибка вычисления потерь: ', E.Message);
      Exit;
    end;
  end;
  WriteLn;

  // Тест 3: Обратный проход
  WriteLn('5. Тест обратного прохода...');
  try
    // Создаем градиент (разность между выходом и целью)
    SetLength(gradOutput, Length(output), Length(output[0]));
    for i := 0 to High(output) do
      for j := 0 to High(output[0]) do
        gradOutput[i][j] := 2.0 * (output[i][j] - target[i][j]) / 
                           (Length(output) * Length(output[0]));

    WriteLn('   Градиент: ', Length(gradOutput), 'x', Length(gradOutput[0]));

    // Выполняем обратный проход
    BackwardTransformer(model, input, gradOutput);
    WriteLn('   ✓ Обратный проход выполнен успешно');

    // Проверяем что градиенты были вычислены
    WriteLn('   Проверка градиентов...');
    if Length(model.Embedding_Grad) > 0 then
      WriteLn('   ✓ Градиенты эмбеддингов вычислены')
    else
      WriteLn('   ⚠ Градиенты эмбеддингов пусты');

    for i := 0 to High(model.Layers) do
    begin
      if Length(model.Layers[i].FFN1_Grad) > 0 then
        WriteLn('   ✓ Градиенты FFN1 слоя ', i, ' вычислены')
      else
        WriteLn('   ⚠ Градиенты FFN1 слоя ', i, ' пусты');
    end;

  except
    on E: Exception do
    begin
      WriteLn('   ✗ Ошибка обратного прохода: ', E.Message);
      Exit;
    end;
  end;
  WriteLn;

  // Тест 4: Обновление весов
  WriteLn('6. Тест обновления весов...');
  try
    WriteLn('   Обновление весов с learning rate = 0.01...');
    UpdateTransformer(model, 0.01);
    WriteLn('   ✓ Веса успешно обновлены');

    // Проверяем что веса изменились
    WriteLn('   Проверка изменений весов...');

    // Сохраняем старые веса для сравнения
    oldEmbedding := CopyMatrix(model.Embedding);

    // Еще один прямой проход с обновленными весами
    ForwardTransformer(model, input, output, nil, True);

    // Вычисляем новые потери
    finalLoss := 0.0;
    for i := 0 to High(output) do
      for j := 0 to High(output[0]) do
        finalLoss := finalLoss + Sqr(output[i][j] - target[i][j]);

    finalLoss := finalLoss / (Length(output) * Length(output[0]));
    WriteLn('   Потери после обновления: ', finalLoss:0:6);

    if finalLoss < initialLoss then
      WriteLn('   ✓ Потери уменьшились - обучение работает!')
    else if Abs(finalLoss - initialLoss) < 0.001 then
      WriteLn('   ⚠ Потери не изменились (возможно, маленький learning rate)')
    else
      WriteLn('   ⚠ Потери увеличились (возможно, слишком большой learning rate)');

  except
    on E: Exception do
    begin
      WriteLn('   ✗ Ошибка обновления весов: ', E.Message);
      Exit;
    end;
  end;
  WriteLn;

  // Тест 5: Gradient Clipping
  WriteLn('7. Тест gradient clipping...');
  try
    // Создаем искусственно большой градиент
    SetLength(gradOutput, Length(output), Length(output[0]));
    for i := 0 to High(gradOutput) do
      for j := 0 to High(gradOutput[0]) do
        gradOutput[i][j] := 100.0; // Большой градиент

    ApplyGradientClippingToModel(model, config.GradientClipValue);
    WriteLn('   ✓ Gradient clipping выполнен успешно');

  except
    on E: Exception do
    begin
      WriteLn('   ✗ Ошибка gradient clipping: ', E.Message);
    end;
  end;
  WriteLn;

  // Тест 6: Multiple training steps
  WriteLn('8. Тест нескольких шагов обучения...');
  try
    WriteLn('   Выполняем 3 шага обучения...');

    for i := 1 to 3 do
    begin
      // Прямой проход
      ForwardTransformer(model, input, output, nil, True);

      // Вычисление потерь
      stepLoss := 0.0;
      for k := 0 to High(output) do
        for l := 0 to High(output[0]) do
          stepLoss := stepLoss + Sqr(output[k][l] - target[k][l]);

      stepLoss := stepLoss / (Length(output) * Length(output[0]));

      // Обратный проход
      for k := 0 to High(output) do
        for l := 0 to High(output[0]) do
          gradOutput[k][l] := 2.0 * (output[k][l] - target[k][l]) / 
                             (Length(output) * Length(output[0]));

      BackwardTransformer(model, input, gradOutput);
      UpdateTransformer(model, 0.001); // Меньший learning rate

      WriteLn('   Шаг ', i, ': потери = ', stepLoss:0:6);
    end;

    WriteLn('   ✓ Множественные шаги обучения выполнены успешно');

  except
    on E: Exception do
    begin
      WriteLn('   ✗ Ошибка множественных шагов: ', E.Message);
    end;
  end;
  WriteLn;

  // Очистка
  WriteLn('9. Очистка ресурсов...');
  try
    FreeTransformer(model);
    WriteLn('   ✓ Ресурсы освобождены успешно');
  except
    on E: Exception do
    begin
      WriteLn('   ✗ Ошибка очистки: ', E.Message);
    end;
  end;

  WriteLn;
  WriteLn('=== ТЕСТ ЗАВЕРШЕН ===');
  WriteLn('Все основные функции трансформера протестированы.');
  WriteLn('Если вы видите это сообщение, трансформер работает корректно!');
end.