unit Attention;
{$MODE OBJFPC}{$H+}{$RANGECHECKS ON}
{$ASMMODE INTEL}

{
    Part of AdvancedChatAI.
    For GNU/Linux 64 bit version.
    Version: 1.
    Written in Free Pascal (https://freepascal.org/).
    Copyright (C) 2025-2026 Artyomov Alexander
    Used https://chat.deepseek.com/
    http://self-made-free.ru/
    aralni@mail.ru

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as
    published by the Free Software Foundation, either version 3 of the
    License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
}


interface

uses
  SysUtils, Math, DataUtils, MatrixOps, Optimizers;

type
  // State of a single attention head: projection weights, their accumulated
  // gradients, the Adam optimizer state for each weight matrix, and the
  // intermediate results stored by the forward pass.
  TAttentionHead = record
    // Projection weights
    Wq, Wk, Wv: TDoubleMatrix; // [inputSize x headSize]
    Wo: TDoubleMatrix;         // [headSize x inputSize]

    // Gradients (allocated with the same shapes as the matching weights)
    dWq, dWk, dWv, dWo: TDoubleMatrix;

    // Adam optimizer states, one per weight matrix
    Wq_AdamState, Wk_AdamState, Wv_AdamState, Wo_AdamState: TAdamState;

    // Caches filled by the forward pass and reused by the backward pass
    cachedQ, cachedK, cachedV: TDoubleMatrix;
    cachedK_T: TDoubleMatrix;          // Transposed K
    cachedHeadOutput: TDoubleMatrix;   // Head output before the Wo projection
    attentionWeights: TDoubleMatrix;   // Attention weights after softmax (read by ScaledDotProductAttentionBackward)
  end;

  // A multi-head attention layer: the list of heads plus the sizes it was
  // initialised with and a copy of the most recent forward output.
  TMultiHeadAttention = record
    Heads: array of TAttentionHead;   // One entry per head
    HeadSize: Integer;                // headSize passed to InitializeMultiHeadAttention
    NumHeads: Integer;                // numHeads passed to InitializeMultiHeadAttention
    cachedOutput: TDoubleMatrix;      // Copy of the last output of MultiHeadAttentionForward
  end;

// Allocate and initialise the weights, gradients, Adam states and caches of
// one head / of every head of a multi-head attention layer.
procedure InitializeAttentionHead(var head: TAttentionHead; inputSize, headSize: Integer);
procedure InitializeMultiHeadAttention(var mha: TMultiHeadAttention; inputSize, headSize, numHeads: Integer);

// Scaled dot-product attention that receives the key matrix already
// transposed (K_T); `mask`, when given, is added to the scaled scores.
function OptimizedScaledDotProductAttention(const Q, K_T, V: TDoubleMatrix; 
                                          mask: TDoubleMatrix = nil): TDoubleMatrix;

// Forward pass over all heads; the per-head results are summed and averaged.
procedure MultiHeadAttentionForward(var mha: TMultiHeadAttention; 
                                  const input: TDoubleMatrix;
                                  out output: TDoubleMatrix;
                                  mask: TDoubleMatrix = nil);
// Backward pass entry point; delegates to AttentionBackward.
procedure MultiHeadAttentionBackward(var mha: TMultiHeadAttention; 
                                   const input, gradOutput: TDoubleMatrix);
// Releases every matrix and Adam state owned by the layer.
procedure FreeMultiHeadAttention(var mha: TMultiHeadAttention);

// Additive masks: entries are 0 where attention is allowed and -1e9 where it
// is blocked.
function CreateFutureMask(seqLength: Integer): TDoubleMatrix;
function CreatePaddingMask(input: TDoubleMatrix; paddingValue: Double = 0): TDoubleMatrix;

// Backward pass of one head using the values cached by the forward pass.
function ScaledDotProductAttentionBackward(var head: TAttentionHead;
                                         const gradOutput: TDoubleMatrix;
                                         const input: TDoubleMatrix): TDoubleMatrix;
// Backward pass of the whole layer; returns the gradient accumulated over
// all heads.
function AttentionBackward(var attention: TMultiHeadAttention;
                         const gradOutput: TDoubleMatrix;
                         const attnInput: TDoubleMatrix): TDoubleMatrix;

// Applies one Adam step to every weight matrix of every head.
procedure UpdateAttentionLayer(var attention: TMultiHeadAttention; learningRate: Double);
// Prints diagnostic information about the layer's matrices.
procedure CheckAttentionWeights(const Attention: TMultiHeadAttention);

// Helper: simplified per-head backward pass used by AttentionBackward.
function SimpleAttentionBackward(var head: TAttentionHead;
                               const gradOutput: TDoubleMatrix;
                               const input: TDoubleMatrix): TDoubleMatrix;

implementation

{$I asmf.inc}

// Prepares one attention head for training.
//   head      - record to initialise (any previous contents are replaced)
//   inputSize - number of input features (rows of Wq/Wk/Wv, columns of Wo)
//   headSize  - width of the head (columns of Wq/Wk/Wv, rows of Wo)
// Weights are drawn with RandomMatrix in [-0.1, 0.1]; gradients are zeroed;
// Adam states are created with the shape of their weight matrix; all
// forward-pass caches start empty.
procedure InitializeAttentionHead(var head: TAttentionHead; inputSize, headSize: Integer);

  // Allocates a rowCount x colCount gradient matrix and sets it to zero.
  procedure ResetGradient(var grad: TDoubleMatrix; rowCount, colCount: Integer);
  begin
    SetLength(grad, rowCount, colCount);
    FillMatrix(grad, 0.0);
  end;

begin
  WriteLn('    InitializeAttentionHead: ', inputSize, 'x', headSize);

  // Weights (the call order is kept: it determines the random stream)
  head.Wq := RandomMatrix(inputSize, headSize, -0.1, 0.1);
  head.Wk := RandomMatrix(inputSize, headSize, -0.1, 0.1);
  head.Wv := RandomMatrix(inputSize, headSize, -0.1, 0.1);
  head.Wo := RandomMatrix(headSize, inputSize, -0.1, 0.1);

  // Gradients, same shapes as the weights
  ResetGradient(head.dWq, inputSize, headSize);
  ResetGradient(head.dWk, inputSize, headSize);
  ResetGradient(head.dWv, inputSize, headSize);
  ResetGradient(head.dWo, headSize, inputSize);

  // Adam optimizer states
  InitAdamState(head.Wq_AdamState, inputSize, headSize);
  InitAdamState(head.Wk_AdamState, inputSize, headSize);
  InitAdamState(head.Wv_AdamState, inputSize, headSize);
  InitAdamState(head.Wo_AdamState, headSize, inputSize);

  // Caches are empty until the first forward pass
  head.cachedQ := nil;
  head.cachedK := nil;
  head.cachedV := nil;
  head.cachedK_T := nil;
  head.cachedHeadOutput := nil;
  head.attentionWeights := nil;
end;

// Builds a multi-head attention layer with `numHeads` heads, each created by
// InitializeAttentionHead(inputSize, headSize). The sizes are recorded in the
// layer and the cached output starts empty. Progress is logged to stdout.
procedure InitializeMultiHeadAttention(var mha: TMultiHeadAttention; 
                                     inputSize, headSize, numHeads: Integer);
var
  headIdx: Integer;
begin
  WriteLn('InitializeMultiHeadAttention:');
  WriteLn('  inputSize: ', inputSize);
  WriteLn('  headSize: ', headSize);
  WriteLn('  numHeads: ', numHeads);

  mha.HeadSize := headSize;
  mha.NumHeads := numHeads;
  SetLength(mha.Heads, numHeads);

  headIdx := 0;
  while headIdx < numHeads do
  begin
    WriteLn('  Инициализация головы ', headIdx, ':');
    InitializeAttentionHead(mha.Heads[headIdx], inputSize, headSize);
    Inc(headIdx);
  end;

  // No forward pass has run yet
  mha.cachedOutput := nil;
end;

// Scaled dot-product attention with a pre-transposed key matrix:
//   Result = Softmax(Q * K_T / Sqrt(d_k) + mask) * V
//   Q    - queries, one row per position (columns = d_k)
//   K_T  - keys already transposed, so its ROW count is d_k
//   V    - values, one row per position
//   mask - optional additive mask; must cover at least the score matrix
// Raises Exception when Q, K_T or V is empty or when the mask is too small.
// The softmax weights are not returned; a caller that needs them for the
// backward pass has to compute and store them itself.
function OptimizedScaledDotProductAttention(const Q, K_T, V: TDoubleMatrix; 
                                          mask: TDoubleMatrix = nil): TDoubleMatrix;
var
  scores: TDoubleMatrix;
  scaleFactor: Double;
  i, j: Integer;
begin
  WriteLn('      OptimizedScaledDotProductAttention:');

  // Validate before touching Q[0] / K_T[0] / V[0] (range checks are on)
  if (Length(Q) = 0) or (Length(Q[0]) = 0) or
     (Length(K_T) = 0) or (Length(K_T[0]) = 0) or
     (Length(V) = 0) or (Length(V[0]) = 0) then
    raise Exception.Create('OptimizedScaledDotProductAttention: Q, K_T and V must be non-empty');

  WriteLn('        Q: ', Length(Q), 'x', Length(Q[0]));
  WriteLn('        K_T: ', Length(K_T), 'x', Length(K_T[0]));
  WriteLn('        V: ', Length(V), 'x', Length(V[0]));

  // K is already transposed, so no extra transpose is needed here
  scores := MatrixMultiply(Q, K_T);
  WriteLn('        scores: ', Length(scores), 'x', Length(scores[0]));

  // Scale by 1/Sqrt(d_k). K_T is [d_k x seqLen], therefore d_k is the number
  // of ROWS of K_T (Length(K_T[0]) would be the sequence length).
  scaleFactor := 1.0 / Sqrt(Length(K_T));
  ScaleMatrix(scores, scaleFactor);

  // Apply the additive mask
  if mask <> nil then
  begin
    WriteLn('        Применяем маску: ', Length(mask), 'x', Length(mask[0]));
    if Length(mask) < Length(scores) then
      raise Exception.Create('OptimizedScaledDotProductAttention: mask has fewer rows than the score matrix');
    for i := 0 to High(scores) do
    begin
      if Length(mask[i]) < Length(scores[i]) then
        raise Exception.Create('OptimizedScaledDotProductAttention: mask has fewer columns than the score matrix');
      for j := 0 to High(scores[i]) do
        scores[i][j] := scores[i][j] + mask[i][j];
    end;
  end;

  // Softmax over the scores, then weight the values
  WriteLn('        Softmax...');
  Result := MatrixMultiply(Softmax(scores), V);
  WriteLn('        Result: ', Length(Result), 'x', Length(Result[0]));
end;

// Forward pass of the multi-head attention layer.
//   mha    - layer; every head's caches (Q, K, V, K^T, softmax weights and
//            the head output before Wo) are refreshed for the backward pass
//   input  - input matrix, one row per position; must be non-empty
//   output - receives the average over heads of headOutput * Wo, with the
//            same shape as `input`
//   mask   - optional additive mask applied to the scaled scores
// Raises Exception when `input` is empty. A head that fails is logged and
// skipped (best effort); the average still divides by NumHeads.
procedure MultiHeadAttentionForward(var mha: TMultiHeadAttention; 
                                  const input: TDoubleMatrix;
                                  out output: TDoubleMatrix;
                                  mask: TDoubleMatrix = nil);
var
  i, r, c: Integer;
  Q, K, V, scores, headOutput: TDoubleMatrix;
  scaleFactor: Double;
begin
  WriteLn('        MultiHeadAttentionForward (оптимизированная версия):');

  // Validate before indexing input[0] (range checks are on)
  if (Length(input) = 0) or (Length(input[0]) = 0) then
    raise Exception.Create('MultiHeadAttentionForward: input matrix is empty');

  WriteLn('          input: ', Length(input), 'x', Length(input[0]));

  // Start from a zero output of the input's shape
  SetLength(output, Length(input), Length(input[0]));
  FillMatrix(output, 0.0);

  for i := 0 to mha.NumHeads - 1 do
  begin
    try
      WriteLn('          Head ', i, ':');

      // Linear projections
      Q := MatrixMultiply(input, mha.Heads[i].Wq);
      K := MatrixMultiply(input, mha.Heads[i].Wk);
      V := MatrixMultiply(input, mha.Heads[i].Wv);

      if (Length(K) = 0) or (Length(K[0]) = 0) then
        raise Exception.Create('key projection is empty');

      // Keep the projections for the backward pass
      mha.Heads[i].cachedQ := CopyMatrix(Q);
      mha.Heads[i].cachedK := CopyMatrix(K);
      mha.Heads[i].cachedV := CopyMatrix(V);

      // Transpose K once and keep it
      mha.Heads[i].cachedK_T := TransposeMatrix(K);

      // scores = Q * K^T / Sqrt(d_k); d_k is the column count of K, the same
      // quantity ScaledDotProductAttentionBackward uses for its scale factor
      scores := MatrixMultiply(Q, mha.Heads[i].cachedK_T);
      scaleFactor := 1.0 / Sqrt(Length(K[0]));
      ScaleMatrix(scores, scaleFactor);

      // Additive mask
      if mask <> nil then
        for r := 0 to High(scores) do
        begin
          if (r > High(mask)) or (Length(mask[r]) < Length(scores[r])) then
            raise Exception.Create('mask is smaller than the score matrix');
          for c := 0 to High(scores[r]) do
            scores[r][c] := scores[r][c] + mask[r][c];
        end;

      // The attention is computed inline so that the softmax weights can be
      // stored: ScaledDotProductAttentionBackward reads head.attentionWeights
      mha.Heads[i].attentionWeights := Softmax(scores);
      headOutput := MatrixMultiply(mha.Heads[i].attentionWeights, V);

      // Keep the head output BEFORE the Wo projection
      mha.Heads[i].cachedHeadOutput := CopyMatrix(headOutput);

      // Output projection
      headOutput := MatrixMultiply(headOutput, mha.Heads[i].Wo);

      // Sum the head outputs
      output := MatrixAdd(output, headOutput);

    except
      on E: Exception do
      begin
        WriteLn('          ОШИБКА в голове ', i, ': ', E.Message);
        // Skip this head
      end;
    end;
  end;

  // Average over heads; with no heads the output simply stays zero instead
  // of dividing by zero
  if mha.NumHeads > 0 then
    ScaleMatrix(output, 1.0 / mha.NumHeads);

  // Keep the output for the backward pass
  mha.cachedOutput := CopyMatrix(output);

  WriteLn('          final output: ', Length(output), 'x', Length(output[0]));
end;

// Backward pass of one attention head through
//   headOut = A * V,  A = Softmax(Q * K^T / Sqrt(d_k) + mask),
//   Q = input * Wq,  K = input * Wk,  V = input * Wv
// using the values cached by the forward pass.
//   head       - head whose caches are read; dWq, dWk and dWv are accumulated
//   gradOutput - gradient w.r.t. the head output before the Wo projection
//   input      - the matrix that was fed to the forward pass
// Returns the gradient w.r.t. `input`. When a required cache is empty or any
// step raises, the problem is logged and a copy of gradOutput is returned as
// a fallback; in that case no weight gradient is modified.
function ScaledDotProductAttentionBackward(var head: TAttentionHead;
                                         const gradOutput: TDoubleMatrix;
                                         const input: TDoubleMatrix): TDoubleMatrix;
var
  gradQ, gradK, gradV, gradWeights, gradScores, inputT: TDoubleMatrix;
  gradWq, gradWk, gradWv: TDoubleMatrix;
  scaleFactor, rowDot: Double;
  i, j: Integer;
begin
  WriteLn('      ScaledDotProductAttentionBackward (оптимизированная):');

  try
    if (Length(head.cachedK) = 0) or (Length(head.cachedK[0]) = 0) then
    begin
      WriteLn('      ОШИБКА: cachedK пустой');
      Exit(CopyMatrix(gradOutput));
    end;

    // The softmax weights are required for both dV and the softmax Jacobian
    if Length(head.attentionWeights) = 0 then
    begin
      WriteLn('      ОШИБКА: attentionWeights пустой');
      Exit(CopyMatrix(gradOutput));
    end;

    // d_k is the column count of K (guarded above, so no epsilon is needed)
    scaleFactor := 1.0 / Sqrt(Length(head.cachedK[0]));

    // dV = A^T * dOut
    gradV := MatrixMultiply(TransposeMatrix(head.attentionWeights), gradOutput);

    // dA = dOut * V^T
    gradWeights := MatrixMultiply(gradOutput, TransposeMatrix(head.cachedV));

    // Row-wise softmax backward, with the 1/Sqrt(d_k) factor folded in:
    //   dS[i][j] = A[i][j] * (dA[i][j] - sum_k dA[i][k] * A[i][k]) * scale
    // An additive mask is a constant and contributes no gradient.
    SetLength(gradScores, Length(gradWeights));
    for i := 0 to High(gradWeights) do
    begin
      SetLength(gradScores[i], Length(gradWeights[i]));
      rowDot := 0.0;
      for j := 0 to High(gradWeights[i]) do
        rowDot := rowDot + gradWeights[i][j] * head.attentionWeights[i][j];
      for j := 0 to High(gradWeights[i]) do
        gradScores[i][j] := head.attentionWeights[i][j] *
                            (gradWeights[i][j] - rowDot) * scaleFactor;
    end;

    // dQ = dS * K,  dK = dS^T * Q
    gradQ := MatrixMultiply(gradScores, head.cachedK);
    gradK := MatrixMultiply(TransposeMatrix(gradScores), head.cachedQ);

    // Weight gradients: dW = input^T * dProjection, i.e. [inputSize x headSize]
    inputT := TransposeMatrix(input);
    gradWq := MatrixMultiply(inputT, gradQ);
    gradWk := MatrixMultiply(inputT, gradK);
    gradWv := MatrixMultiply(inputT, gradV);

    // Gradient for the previous layer
    Result := MatrixAdd(MatrixMultiply(gradQ, TransposeMatrix(head.Wq)),
                      MatrixAdd(MatrixMultiply(gradK, TransposeMatrix(head.Wk)),
                               MatrixMultiply(gradV, TransposeMatrix(head.Wv))));

    // Accumulate only after everything above succeeded, so that a failure
    // never leaves the head with partially updated gradients
    head.dWq := MatrixAdd(head.dWq, gradWq);
    head.dWk := MatrixAdd(head.dWk, gradWk);
    head.dWv := MatrixAdd(head.dWv, gradWv);

  except
    on E: Exception do
    begin
      WriteLn('      ОШИБКА: ', E.Message);
      Result := CopyMatrix(gradOutput); // Fallback
    end;
  end;
end;

// Simplified backward pass of one head: the incoming gradient is pushed back
// through the transposed Wq, Wk and Wv and the three results are summed.
// The attention weights themselves are ignored, `input` is not used and no
// weight gradient of the head is modified.
// On any exception the message is logged and a copy of gradOutput is
// returned instead.
function SimpleAttentionBackward(var head: TAttentionHead;
                               const gradOutput: TDoubleMatrix;
                               const input: TDoubleMatrix): TDoubleMatrix;
var
  viaQuery, viaKey, viaValue, keyPlusValue: TDoubleMatrix;
begin
  WriteLn('        SimpleAttentionBackward:');

  try
    // Gradient through each linear projection
    viaQuery := MatrixMultiply(gradOutput, TransposeMatrix(head.Wq));
    viaKey := MatrixMultiply(gradOutput, TransposeMatrix(head.Wk));
    viaValue := MatrixMultiply(gradOutput, TransposeMatrix(head.Wv));

    // Sum of the three contributions
    keyPlusValue := MatrixAdd(viaKey, viaValue);
    Result := MatrixAdd(viaQuery, keyPlusValue);

    WriteLn('        Результат: ', Length(Result), 'x', Length(Result[0]));
  except
    on E: Exception do
    begin
      WriteLn('        ОШИБКА: ', E.Message);
      Result := CopyMatrix(gradOutput);
    end;
  end;
end;

// Backward pass of the whole multi-head layer.
//   attention  - layer; each head's dWo is accumulated
//   gradOutput - gradient w.r.t. the layer output
//   attnInput  - the matrix that was fed to the forward pass
// Returns a matrix shaped like gradOutput holding the sum over heads of the
// gradient produced by SimpleAttentionBackward. Whenever a computed gradient
// does not have the expected shape it is resized with ScaleMatrixToSize
// before being added. A head that raises is logged and skipped.
function AttentionBackward(var attention: TMultiHeadAttention;
                         const gradOutput: TDoubleMatrix;
                         const attnInput: TDoubleMatrix): TDoubleMatrix;
var
  headIdx: Integer;
  hd: ^TAttentionHead;
  projGrad, projT, flowGrad: TDoubleMatrix;
  shapesMatch: Boolean;
begin
  WriteLn('    AttentionBackward (исправленная):');
  WriteLn('      gradOutput: ', Length(gradOutput), 'x', Length(gradOutput[0]));
  WriteLn('      attnInput: ', Length(attnInput), 'x', Length(attnInput[0]));

  // Zero accumulator with the shape of gradOutput
  SetLength(Result, Length(gradOutput), Length(gradOutput[0]));
  FillMatrix(Result, 0.0);

  for headIdx := Low(attention.Heads) to High(attention.Heads) do
  begin
    try
      hd := @attention.Heads[headIdx];
      WriteLn('      Голова ', headIdx, ':');

      // Allocate dWo on demand
      if Length(hd^.dWo) = 0 then
      begin
        WriteLn('        Инициализируем dWo...');
        SetLength(hd^.dWo, Length(hd^.Wo), Length(hd^.Wo[0]));
        FillMatrix(hd^.dWo, 0.0);
      end;

      // Gradient of the output projection Wo
      if Length(hd^.cachedHeadOutput) = 0 then
        WriteLn('        Предупреждение: cachedHeadOutput пустой')
      else
      begin
        WriteLn('        Вычисление gradWo...');
        WriteLn('        cachedHeadOutput: ', 
                Length(hd^.cachedHeadOutput), 'x', 
                Length(hd^.cachedHeadOutput[0]));

        // gradWo = cachedHeadOutput^T * gradOutput
        projGrad := MatrixMultiply(TransposeMatrix(hd^.cachedHeadOutput), gradOutput);

        WriteLn('        gradWo: ', Length(projGrad), 'x', Length(projGrad[0]));
        WriteLn('        dWo: ', Length(hd^.dWo), 'x', Length(hd^.dWo[0]));

        shapesMatch := (Length(projGrad) = Length(hd^.dWo)) and
                       (Length(projGrad[0]) = Length(hd^.dWo[0]));
        if shapesMatch then
        begin
          hd^.dWo := MatrixAdd(hd^.dWo, projGrad);
          WriteLn('        Wo градиенты обновлены');
        end
        else
        begin
          // Shapes differ: resize the gradient to the shape of dWo first
          WriteLn('        Масштабируем gradWo...');
          projGrad := ScaleMatrixToSize(projGrad, Length(hd^.dWo), Length(hd^.dWo[0]));
          hd^.dWo := MatrixAdd(hd^.dWo, projGrad);
          WriteLn('        Wo градиенты обновлены (с масштабированием)');
        end;
      end;

      // Gradient flowing back through Wo and then through the attention
      WriteLn('        Вычисление headGrad через Wo...');
      if (Length(hd^.Wo) = 0) or (Length(hd^.Wo[0]) = 0) then
      begin
        WriteLn('        Предупреждение: Wo пустой');
        Continue;
      end;

      projT := TransposeMatrix(hd^.Wo);
      WriteLn('        WoT: ', Length(projT), 'x', Length(projT[0]));

      flowGrad := MatrixMultiply(gradOutput, projT);
      WriteLn('        headGrad после Wo: ', Length(flowGrad), 'x', Length(flowGrad[0]));

      // Simplified backward through the attention itself
      flowGrad := SimpleAttentionBackward(hd^, flowGrad, attnInput);
      WriteLn('        headGrad после attention: ', Length(flowGrad), 'x', Length(flowGrad[0]));

      // Accumulate into the result
      shapesMatch := (Length(flowGrad) = Length(Result)) and
                     (Length(flowGrad[0]) = Length(Result[0]));
      if shapesMatch then
      begin
        Result := MatrixAdd(Result, flowGrad);
        WriteLn('        Градиенты головы добавлены');
      end
      else
      begin
        WriteLn('        Предупреждение: Несовпадение размеров headGrad');
        WriteLn('        headGrad: ', Length(flowGrad), 'x', Length(flowGrad[0]));
        WriteLn('        Result: ', Length(Result), 'x', Length(Result[0]));

        // Resize to the shape of the accumulator before adding
        flowGrad := ScaleMatrixToSize(flowGrad, Length(Result), Length(Result[0]));
        Result := MatrixAdd(Result, flowGrad);
        WriteLn('        Градиенты добавлены (с масштабированием)');
      end;

    except
      on E: Exception do
      begin
        WriteLn('      ОШИБКА в голове ', headIdx, ': ', E.Message);
        // Skip this head
      end;
    end;
  end;
end;

// Procedure-style entry point of the backward pass: logs a notice and runs
// AttentionBackward(mha, gradOutput, input). Only the side effects on `mha`
// are kept; the gradient returned by AttentionBackward is dropped.
procedure MultiHeadAttentionBackward(var mha: TMultiHeadAttention; 
                                   const input, gradOutput: TDoubleMatrix);
var
  droppedInputGrad: TDoubleMatrix;
begin
  WriteLn('MultiHeadAttentionBackward: используем улучшенную версию через AttentionBackward');
  // The returned input gradient is intentionally not used by this procedure
  droppedInputGrad := AttentionBackward(mha, gradOutput, input);
end;

// Runs UpdateMatrixWithAdam on every weight matrix (Wq, Wk, Wv, Wo) of every
// head, passing the matching gradient, the matching Adam state and
// `learningRate`.
procedure UpdateAttentionLayer(var attention: TMultiHeadAttention; learningRate: Double);
var
  headIdx: Integer;
begin
  for headIdx := Low(attention.Heads) to High(attention.Heads) do
    with attention.Heads[headIdx] do
    begin
      UpdateMatrixWithAdam(Wq, dWq, Wq_AdamState, learningRate); // query weights
      UpdateMatrixWithAdam(Wk, dWk, Wk_AdamState, learningRate); // key weights
      UpdateMatrixWithAdam(Wv, dWv, Wv_AdamState, learningRate); // value weights
      UpdateMatrixWithAdam(Wo, dWo, Wo_AdamState, learningRate); // output weights
    end;
end;

// Releases everything the layer owns: for each head the weights, gradients,
// forward-pass caches and Adam states, then the head list itself and the
// cached output. HeadSize and NumHeads are left as they are.
procedure FreeMultiHeadAttention(var mha: TMultiHeadAttention);
var
  headIdx: Integer;
begin
  for headIdx := Low(mha.Heads) to High(mha.Heads) do
    with mha.Heads[headIdx] do
    begin
      // Weights
      Wq := nil;
      Wk := nil;
      Wv := nil;
      Wo := nil;

      // Gradients
      dWq := nil;
      dWk := nil;
      dWv := nil;
      dWo := nil;

      // Forward-pass caches
      cachedQ := nil;
      cachedK := nil;
      cachedV := nil;
      cachedK_T := nil;
      cachedHeadOutput := nil;
      attentionWeights := nil;

      // Adam states
      FreeAdamState(Wq_AdamState);
      FreeAdamState(Wk_AdamState);
      FreeAdamState(Wv_AdamState);
      FreeAdamState(Wo_AdamState);
    end;

  mha.Heads := nil;
  mha.cachedOutput := nil;
end;

// The remaining functions are unchanged
// Builds a seqLength x seqLength additive mask: an entry is -1e9 when its
// column index is greater than its row index and 0 otherwise.
function CreateFutureMask(seqLength: Integer): TDoubleMatrix;
var
  row, col: Integer;
begin
  Result := nil;
  SetLength(Result, seqLength, seqLength);
  for row := 0 to seqLength - 1 do
  begin
    // Columns up to and including the diagonal
    for col := 0 to row do
      Result[row][col] := 0;
    // Columns to the right of the diagonal
    for col := row + 1 to seqLength - 1 do
      Result[row][col] := -1e9;
  end;
end;

// Builds an N x N additive mask, N = Length(input). A row of `input` counts
// as padding when its FIRST element equals paddingValue; entry [i][j] of the
// mask is -1e9 when row i or row j is padding and 0 otherwise.
// Raises Exception when `input` has no rows or when any row is empty.
function CreatePaddingMask(input: TDoubleMatrix; paddingValue: Double = 0): TDoubleMatrix;
var
  i, j: Integer;
  isPadding: array of Boolean;
begin
  Result := nil;
  if Length(input) = 0 then
    raise Exception.Create('Input matrix is empty in CreatePaddingMask');

  // Classify every row once instead of re-reading input[i][0] N*N times,
  // and reject empty rows explicitly instead of failing with a range error
  SetLength(isPadding, Length(input));
  for i := 0 to High(input) do
  begin
    if Length(input[i]) = 0 then
      raise Exception.CreateFmt('Row %d is empty in CreatePaddingMask', [i]);
    isPadding[i] := (input[i][0] = paddingValue);
  end;

  SetLength(Result, Length(input), Length(input));
  for i := 0 to High(input) do
    for j := 0 to High(input) do
      if isPadding[i] or isPadding[j] then
        Result[i][j] := -1e9
      else
        Result[i][j] := 0;
end;

// Prints, for every head, the dimensions of Wq and of the cached transposed
// key matrix. The column count of an empty matrix is printed as '?'.
procedure CheckAttentionWeights(const Attention: TMultiHeadAttention);

  // Column count of M as text, or '?' when M has no rows.
  // A real if/else is required here: IfThen is an ordinary function, so all
  // of its arguments (including M[0]) would be evaluated even for an empty
  // matrix, which raises a range error.
  function ColumnCountText(const M: TDoubleMatrix): string;
  begin
    if Length(M) > 0 then
      Result := IntToStr(Length(M[0]))
    else
      Result := '?';
  end;

var
  i: Integer;
begin
  WriteLn('=== ПРОВЕРКА ВЕСОВ ATTENTION (оптимизированная) ===');
  WriteLn('Количество голов: ', Length(Attention.Heads));

  for i := 0 to High(Attention.Heads) do
  begin
    WriteLn('Голова ', i, ':');
    WriteLn('  Wq: ', Length(Attention.Heads[i].Wq), 'x',
            ColumnCountText(Attention.Heads[i].Wq));
    WriteLn('  cachedK_T: ', Length(Attention.Heads[i].cachedK_T), 'x',
            ColumnCountText(Attention.Heads[i].cachedK_T));
  end;
end;

end.