unit Optimizers;

{$MODE OBJFPC}{$H+}
{$RANGECHECKS ON}
{$ASMMODE INTEL}

{
    Part of AdvancedChatAI.
    For the 64-bit GNU/Linux version.
    Version: 1.
    Written in Free Pascal (https://freepascal.org/).
    Copyright (C) 2025-2026 Artyomov Alexander
    Used https://chat.deepseek.com/
    http://self-made-free.ru/
    aralni@mail.ru

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as
    published by the Free Software Foundation, either version 3 of the
    License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
}


interface

uses
  SysUtils, Math, DataUtils;

procedure InitAdamState(var state: TAdamState; rows, cols: Integer);
procedure InitAdamVectorState(var state: TAdamVectorState; size: Integer);
procedure UpdateMatrixWithAdam(var params, grads: TDoubleMatrix; var state: TAdamState; learningRate: Double; weightDecay: Double = 0.0);
procedure FreeAdamState(var state: TAdamState);
procedure FreeAdamVectorState(var state: TAdamVectorState);
procedure AdamUpdate(var params, grads: TDoubleMatrix; var state: TAdamState; learningRate: Double = 0.001);
procedure ApplyGradientClipping(var Gradients: TDoubleMatrix; MaxNorm: Double);
procedure ApplyL2ToMatrix(var Matrix: TDoubleMatrix; WeightDecay, LearningRate: Double);
procedure UpdateVectorAdam(var params, grads: TDoubleArray; var state: TAdamVectorState; learningRate: Double);

implementation

{$I asmf.inc}

const
  { Small positive constant used by the Adam updates in this unit:
    - added to Sqrt(vHat) in the update denominator so it never divides by zero;
    - substituted for a bias-correction denominator (1 - beta^t) that is <= 0. }
  EPS_ADAM = 1e-8;

{
  InitAdamState
  Prepares a matrix Adam state for a parameter matrix of size rows x cols.

  state - receives first/second moment buffers M and V (rows x cols, all zero),
          default hyper-parameters Beta1 = 0.9 and Beta2 = 0.999,
          Timestep = 0 and GradientClipValue = 0.0.
  rows  - number of rows of the parameter matrix.
  cols  - number of columns of the parameter matrix.

  The buffers are cleared explicitly because SetLength keeps existing contents
  when the state is re-initialised with the same dimensions.
}
procedure InitAdamState(var state: TAdamState; rows, cols: Integer);
var
  r: Integer;
begin
  SetLength(state.M, rows, cols);
  SetLength(state.V, rows, cols);

  // With cols = 0 every row is empty, and M[r][0] would fail the range check
  // ({$RANGECHECKS ON}). The rows are already empty, so there is nothing to clear.
  if cols > 0 then
    for r := 0 to rows - 1 do
    begin
      FillChar(state.M[r][0], cols * SizeOf(Double), 0);
      FillChar(state.V[r][0], cols * SizeOf(Double), 0);
    end;

  state.Beta1 := 0.9;
  state.Beta2 := 0.999;
  state.Timestep := 0;
  state.GradientClipValue := 0.0;
end;

{
  InitAdamVectorState
  Prepares a vector Adam state for a parameter vector of the given length.

  state - receives zeroed moment buffers M and V of length size,
          Beta1 = 0.9, Beta2 = 0.999 and Timestep = 0.
  size  - number of parameters in the vector.

  The buffers are cleared explicitly because SetLength keeps existing contents
  when the state is re-initialised with the same length.
}
procedure InitAdamVectorState(var state: TAdamVectorState; size: Integer);
begin
  SetLength(state.M, size);
  SetLength(state.V, size);

  // For size = 0 the arrays are empty, and M[0] would fail the range check
  // ({$RANGECHECKS ON}), so only clear when there is something to clear.
  if size > 0 then
  begin
    FillChar(state.M[0], size * SizeOf(Double), 0);
    FillChar(state.V[0], size * SizeOf(Double), 0);
  end;

  state.Beta1 := 0.9;
  state.Beta2 := 0.999;
  state.Timestep := 0;
end;

{
  FreeAdamState
  Releases the moment buffers M and V of a matrix Adam state and resets its
  step counter to 0. Beta1, Beta2 and GradientClipValue are left untouched.
}
procedure FreeAdamState(var state: TAdamState);
begin
  state.Timestep := 0;
  // Assigning nil to a dynamic array drops it, exactly like SetLength(..., 0).
  state.M := nil;
  state.V := nil;
end;

{
  FreeAdamVectorState
  Releases the moment buffers M and V of a vector Adam state and resets its
  step counter to 0. Beta1 and Beta2 are left untouched.
}
procedure FreeAdamVectorState(var state: TAdamVectorState);
begin
  state.Timestep := 0;
  // Assigning nil to a dynamic array drops it, exactly like SetLength(..., 0).
  state.M := nil;
  state.V := nil;
end;

{
  UpdateMatrixWithAdam
  Performs one in-place Adam step on a parameter matrix.

  params       - parameters to update, modified in place.
  grads        - gradients. Each grads row must be at least as long as the matching
                 params row, otherwise the range check fires.
  state        - Adam state. state.M and state.V must cover the same shape as params.
                 Beta1/Beta2 are read, and Timestep is incremented once per call before
                 the bias correction is computed.
  learningRate - step size.
  weightDecay  - when non-zero, weightDecay * param is added to the gradient before the
                 moment updates (the decay is coupled into M and V).

  Every row is processed over its own length. Ragged matrices are therefore updated
  completely, instead of reusing the column count of row 0 for all rows.
  If params holds no elements at all, the call is a no-op and Timestep is not advanced.
}
procedure UpdateMatrixWithAdam(var params, grads: TDoubleMatrix; var state: TAdamState; learningRate: Double; weightDecay: Double = 0.0);
var
  i, j: Integer;
  hasData, useDecay: Boolean;
  b1, b2, oneMinusB1, oneMinusB2: Double;
  denom1, denom2: Double;
  mHat, vHat, g: Double;
begin
  // Skip entirely (without advancing Timestep) when there is nothing to update.
  hasData := False;
  for i := 0 to High(params) do
    if Length(params[i]) > 0 then
    begin
      hasData := True;
      Break;
    end;
  if not hasData then Exit;

  Inc(state.Timestep);

  // Loop invariants, hoisted out of the element loop.
  b1 := state.Beta1;
  b2 := state.Beta2;
  oneMinusB1 := 1 - b1;
  oneMinusB2 := 1 - b2;
  useDecay := weightDecay <> 0.0;

  // Bias-correction denominators 1 - beta^t, computed once per step.
  denom1 := 1.0 - Power(b1, state.Timestep);
  denom2 := 1.0 - Power(b2, state.Timestep);
  // A beta >= 1 would make these zero or negative; fall back to a tiny positive value.
  if denom1 <= 0 then denom1 := EPS_ADAM;
  if denom2 <= 0 then denom2 := EPS_ADAM;

  for i := 0 to High(params) do
    // Use this row's own length, not row 0's.
    for j := 0 to High(params[i]) do
    begin
      g := grads[i][j];
      if useDecay then g := g + weightDecay * params[i][j];
      state.M[i][j] := b1 * state.M[i][j] + oneMinusB1 * g;
      state.V[i][j] := b2 * state.V[i][j] + oneMinusB2 * (g * g);
      mHat := state.M[i][j] / denom1;
      vHat := state.V[i][j] / denom2;
      params[i][j] := params[i][j] - learningRate * mHat / (Sqrt(vHat) + EPS_ADAM);
    end;
end;

{
  UpdateVectorAdam
  Performs one in-place Adam step on a parameter vector.

  params       - parameters to update, modified in place.
  grads        - gradients, read element by element for every index of params.
  state        - vector Adam state. M and V are updated for every index of params,
                 and Timestep is incremented once per call before the bias correction.
  learningRate - step size.

  An empty params vector is a no-op and does not advance Timestep.
  No weight decay is applied here.
}
procedure UpdateVectorAdam(var params, grads: TDoubleArray; var state: TAdamVectorState; learningRate: Double);
var
  k: Integer;
  corr1, corr2: Double;
  grad, mCorr, vCorr: Double;
begin
  if Length(params) = 0 then
    Exit;

  Inc(state.Timestep);

  // Bias-correction denominators, clamped to a tiny positive value if beta >= 1.
  corr1 := 1.0 - Power(state.Beta1, state.Timestep);
  if corr1 <= 0 then
    corr1 := EPS_ADAM;
  corr2 := 1.0 - Power(state.Beta2, state.Timestep);
  if corr2 <= 0 then
    corr2 := EPS_ADAM;

  for k := 0 to High(params) do
  begin
    grad := grads[k];
    state.M[k] := state.Beta1 * state.M[k] + (1 - state.Beta1) * grad;
    state.V[k] := state.Beta2 * state.V[k] + (1 - state.Beta2) * (grad * grad);
    mCorr := state.M[k] / corr1;
    vCorr := state.V[k] / corr2;
    params[k] := params[k] - learningRate * mCorr / (Sqrt(vCorr) + EPS_ADAM);
  end;
end;

{
  AdamUpdate
  Compatibility entry point: one plain Adam step on a matrix with no weight decay.
  It simply forwards to UpdateMatrixWithAdam, whose weightDecay parameter defaults to 0.0.
}
procedure AdamUpdate(var params, grads: TDoubleMatrix; var state: TAdamState; learningRate: Double = 0.001);
begin
  UpdateMatrixWithAdam(params, grads, state, learningRate);
end;

{
  ApplyGradientClipping
  Rescales a gradient matrix in place so that its global L2 norm (over all
  elements) does not exceed MaxNorm.

  Gradients - matrix to clip, modified in place only when its norm exceeds MaxNorm.
  MaxNorm   - norm limit. A value <= 0 disables clipping.
}
procedure ApplyGradientClipping(var Gradients: TDoubleMatrix; MaxNorm: Double);
var
  row, col: Integer;
  sumSq, norm, factor: Double;
begin
  if MaxNorm <= 0 then
    Exit;

  sumSq := 0.0;
  for row := 0 to High(Gradients) do
    for col := 0 to High(Gradients[row]) do
      sumSq := sumSq + Sqr(Gradients[row][col]);
  norm := Sqrt(sumSq);

  // Written as "not (>)" so a NaN norm also leaves the gradients untouched.
  if not (norm > MaxNorm) then
    Exit;

  factor := MaxNorm / (norm + 1e-12);
  for row := 0 to High(Gradients) do
    for col := 0 to High(Gradients[row]) do
      Gradients[row][col] := Gradients[row][col] * factor;
end;

{
  ApplyL2ToMatrix
  Shrinks every element of Matrix in place by the factor (1 - LearningRate * WeightDecay).

  Matrix       - matrix to scale, modified in place.
  WeightDecay  - decay coefficient. 0.0 makes the call a no-op.
  LearningRate - step size that multiplies WeightDecay.
}
procedure ApplyL2ToMatrix(var Matrix: TDoubleMatrix; WeightDecay, LearningRate: Double);
var
  row, col: Integer;
  shrink: Double;
begin
  if WeightDecay = 0.0 then
    Exit;
  if Length(Matrix) = 0 then
    Exit;

  shrink := 1.0 - LearningRate * WeightDecay;
  for row := Low(Matrix) to High(Matrix) do
    for col := Low(Matrix[row]) to High(Matrix[row]) do
      Matrix[row][col] := shrink * Matrix[row][col];
end;

end.