unit DecisionTree;

{
    Part of AdvancedChatAI.
    For GNU/Linux 64 bit version.
    Version: 1.
    Written in FreePascal (https://freepascal.org/).
    Copyright (C) 2025-2026 Artyomov Alexander
    Used https://chat.deepseek.com/
    http://self-made-free.ru/
    aralni@mail.ru

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as
    published by the Free Software Foundation, either version 3 of the
    License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
}

{$MODE OBJFPC}{$H+}//{$RANGECHECKS ON}
//{$OPTIMIZATION LEVEL4}
{$OPTIMIZATION LEVEL3}
{$OPTIMIZATION PEEPHOLE}
{$INLINE ON}

interface

uses
  SysUtils, Math, DataUtils;

type
  // Pointer to a tree node. Nodes are allocated with New in BuildTree and
  // released with Dispose in FreeTree.
  PTreeNode = ^TTreeNode;
  // One node of the regression tree.
  TTreeNode = record
    isLeaf: Boolean;        // True for leaves; BuildTree gives leaves nil children
    classLabel: Integer;    // set to 0 for leaves; not read anywhere in this unit
    value: Double;          // leaf prediction: mean of y over the rows that reached the leaf
    splitFeature: Integer;  // internal nodes: index of the feature column tested
    splitValue: Double;     // internal nodes: threshold; x[splitFeature] <= splitValue goes left
    left, right: PTreeNode; // child subtrees (nil for leaves)
  end;

  // A regression tree plus the hyper-parameters it was trained with.
  TDecisionTree = record
    root: PTreeNode;          // PredictDecisionTree treats nil as "not trained"; FreeDecisionTree resets it to nil
    maxDepth: Integer;        // nodes at depth >= maxDepth are made leaves
    minSamplesSplit: Integer; // nodes holding fewer rows than this are made leaves
  end;

  // Row numbers into the training matrix x / target array y.
  TDataIndices = array of Integer;

{ Train a regression tree on the rows of x with targets y (one target per row).
  Raises Exception if x is empty or if x and y differ in length. }
procedure TrainDecisionTree(var tree: TDecisionTree; const x: TDoubleMatrix; 
                          const y: TDoubleArray; maxDepth: Integer; 
                          minSamplesSplit: Integer);
{ Return the tree's prediction for one sample x. Raises Exception if tree.root is nil. }
function PredictDecisionTree(const tree: TDecisionTree; const x: TDoubleArray): Double;
{ Dispose every node of the tree and set tree.root to nil. }
procedure FreeDecisionTree(var tree: TDecisionTree);

implementation

{ Mean squared deviation of the selected targets y[indices[k]] around their
  own mean, i.e. their population variance (the sum is divided by the count).
  Returns MaxDouble when indices is empty. }
function CalculateMSE(const y: TDoubleArray; const indices: TDataIndices): Double; inline;
var
  n, k: Integer;
  total, avg, diff, acc: Double;
begin
  n := Length(indices);
  if n = 0 then
  begin
    Result := MaxDouble;
    Exit;
  end;

  // First pass: mean of the selected targets.
  total := 0;
  for k := 0 to n - 1 do
    total := total + y[indices[k]];
  avg := total / n;

  // Second pass: average squared distance from that mean.
  acc := 0;
  for k := 0 to n - 1 do
  begin
    diff := y[indices[k]] - avg;
    acc := acc + diff * diff;
  end;
  Result := acc / n;
end;

{ Partition the row numbers in indices by one feature. Rows with
  x[row][featureIndex] <= threshold go to leftIndices, all others go to
  rightIndices. The relative order of the rows is preserved in both outputs.
  y is not read; it stays in the signature so the existing call in
  FindBestSplit remains valid.

  This index-based routine had been commented out. A train/test splitter with
  the same name and an incompatible signature had taken its place. That
  splitter was neither exported nor called in this unit, and FindBestSplit
  could not compile against it, so it has been removed. }
procedure SplitData(const x: TDoubleMatrix; const y: TDoubleArray;
                   const indices: TDataIndices; featureIndex: Integer;
                   threshold: Double; var leftIndices, rightIndices: TDataIndices); inline;
var
  i, row, leftCount, rightCount: Integer;
begin
  // Size both outputs for the worst case, fill them, then trim once at the end.
  SetLength(leftIndices, Length(indices));
  SetLength(rightIndices, Length(indices));
  leftCount := 0;
  rightCount := 0;

  for i := 0 to High(indices) do
  begin
    row := indices[i];
    if x[row][featureIndex] <= threshold then
    begin
      leftIndices[leftCount] := row;
      Inc(leftCount);
    end
    else
    begin
      rightIndices[rightCount] := row;
      Inc(rightCount);
    end;
  end;

  SetLength(leftIndices, leftCount);
  SetLength(rightIndices, rightCount);
end;

{ Search for a split of the rows in indices that reduces the mean squared
  error of y. The search is randomized:
    - maxFeatures feature columns are drawn with Random, with replacement.
    - For each column, maxThresholds candidate thresholds are taken from the
      values of randomly drawn rows.
  The defaults (10 and 20) are the values that were previously hard-coded.

  On success:
    - bestFeature and bestValue receive the chosen column and threshold.
    - leftIndices and rightIndices receive the rows with x <= bestValue and
      x > bestValue.
  When no split is found, bestFeature = -1, bestValue = 0, and the index
  arrays are left untouched.

  Returns True only if some candidate leaves both sides non-empty and gives a
  strictly positive reduction in sample-weighted MSE. }
function FindBestSplit(const x: TDoubleMatrix; const y: TDoubleArray;
                      const indices: TDataIndices; var bestFeature: Integer;
                      var bestValue: Double; var leftIndices, rightIndices: TDataIndices;
                      maxFeatures: Integer = 10; maxThresholds: Integer = 20): Boolean;
var
  f, t, feature, row, featureCount, n: Integer;
  candidate, gain, bestGain, mseBefore: Double;
  currentLeft, currentRight: TDataIndices;
begin
  Result := False;
  bestFeature := -1;
  bestValue := 0;

  // Random(0) returns 0. Without these guards, an empty row set or a matrix
  // with no feature columns would be indexed past the end of its array.
  n := Length(indices);
  if (n = 0) or (Length(x) = 0) then
    Exit;
  featureCount := Length(x[0]);
  if featureCount = 0 then
    Exit;

  bestGain := -1;
  mseBefore := CalculateMSE(y, indices);

  for f := 1 to maxFeatures do
  begin
    feature := Random(featureCount);
    for t := 1 to maxThresholds do
    begin
      row := indices[Random(n)];
      candidate := x[row][feature];

      SplitData(x, y, indices, feature, candidate, currentLeft, currentRight);
      if (Length(currentLeft) = 0) or (Length(currentRight) = 0) then
        Continue;

      // MSE before the split minus the size-weighted MSE of the two halves.
      gain := mseBefore - (CalculateMSE(y, currentLeft) * Length(currentLeft) +
              CalculateMSE(y, currentRight) * Length(currentRight)) / n;

      if (gain > 0) and (gain > bestGain) then
      begin
        bestGain := gain;
        bestFeature := feature;
        bestValue := candidate;
        // Dynamic-array assignment shares storage. SetLength inside SplitData
        // makes its target unique again, so later candidates cannot
        // overwrite the arrays kept here.
        leftIndices := currentLeft;
        rightIndices := currentRight;
        Result := True;
      end;
    end;
  end;
end;

{ Recursively grow a regression tree over the rows listed in indices.
  A node becomes a leaf in any of these cases:
    - depth has reached maxDepth;
    - it holds fewer than minSamplesSplit rows;
    - FindBestSplit finds no split that reduces the MSE.
  A leaf stores the mean of y over its rows; an empty leaf stores 0.
  Returns a node allocated with New; release it with FreeTree. }
function BuildTree(const x: TDoubleMatrix; const y: TDoubleArray; 
                  const indices: TDataIndices; depth: Integer; 
                  maxDepth: Integer; minSamplesSplit: Integer): PTreeNode;
var
  bestFeature, i: Integer;
  bestValue, sum: Double;
  leftIndices, rightIndices: TDataIndices;
begin
  New(Result);
  // New does not zero the record, so give every field a defined value.
  // Otherwise leaves would carry garbage split data and internal nodes a
  // garbage value.
  Result^.isLeaf := False;
  Result^.classLabel := 0;
  Result^.value := 0;
  Result^.splitFeature := -1;
  Result^.splitValue := 0;
  Result^.left := nil;
  Result^.right := nil;

  // Stopping conditions. With short-circuit boolean evaluation, FindBestSplit
  // runs only when the cheaper checks pass.
  if (depth >= maxDepth) or (Length(indices) < minSamplesSplit) or
     not FindBestSplit(x, y, indices, bestFeature, bestValue, leftIndices, rightIndices) then
  begin
    Result^.isLeaf := True;
    // Leaf prediction is the mean target of the rows that reached it.
    // The length guard avoids a 0/0 division for an empty row set.
    if Length(indices) > 0 then
    begin
      sum := 0;
      for i := 0 to High(indices) do
        sum := sum + y[indices[i]];
      Result^.value := sum / Length(indices);
    end;
    Exit;
  end;

  // Internal node: rows with x[bestFeature] <= bestValue go left, the rest go right.
  Result^.splitFeature := bestFeature;
  Result^.splitValue := bestValue;
  Result^.left := BuildTree(x, y, leftIndices, depth + 1, maxDepth, minSamplesSplit);
  Result^.right := BuildTree(x, y, rightIndices, depth + 1, maxDepth, minSamplesSplit);
end;

{ Train a regression tree on the rows of x with targets y (one target per row).
  Nodes at depth >= maxDepth, or holding fewer than minSamplesSplit rows,
  become leaves.

  Raises Exception if any of these hold:
    - x is empty;
    - x and y differ in length;
    - the rows of x do not all have the same number of features.

  tree.root is overwritten without freeing whatever it pointed to before;
  call FreeDecisionTree first when retraining an existing tree. }
procedure TrainDecisionTree(var tree: TDecisionTree; const x: TDoubleMatrix; 
                          const y: TDoubleArray; maxDepth: Integer; 
                          minSamplesSplit: Integer);
var
  indices: TDataIndices;
  i, featureCount: Integer;
begin
  if Length(x) = 0 then
    raise Exception.Create('Training data is empty');
  if Length(x) <> Length(y) then
    raise Exception.Create('X and Y must have same length');

  // FindBestSplit draws feature columns from Length(x[0]) and reads that
  // column from every row, so a ragged row would be indexed out of bounds.
  featureCount := Length(x[0]);
  for i := 1 to High(x) do
    if Length(x[i]) <> featureCount then
      raise Exception.CreateFmt('Row %d has %d features, expected %d',
        [i, Length(x[i]), featureCount]);

  // The root node starts with every training row.
  SetLength(indices, Length(x));
  for i := 0 to High(indices) do
    indices[i] := i;

  tree.maxDepth := maxDepth;
  tree.minSamplesSplit := minSamplesSplit;
  tree.root := BuildTree(x, y, indices, 0, maxDepth, minSamplesSplit);
end;

{ Walk from node down to a leaf and return that leaf's value. At each
  internal node, go left when x[splitFeature] <= splitValue and right
  otherwise. node must be non-nil. The code does not check that x has at
  least splitFeature + 1 elements. }
function PredictTree(node: PTreeNode; const x: TDoubleArray): Double; inline;
var
  current: PTreeNode;
begin
  current := node;
  while not current^.isLeaf do
    if x[current^.splitFeature] <= current^.splitValue then
      current := current^.left
    else
      current := current^.right;
  Result := current^.value;
end;

{ Return the tree's prediction for a single sample x.
  Raises Exception when tree.root is nil, i.e. the tree was never trained or
  has been freed with FreeDecisionTree. }
function PredictDecisionTree(const tree: TDecisionTree; const x: TDoubleArray): Double;
begin
  if tree.root <> nil then
    Result := PredictTree(tree.root, x)
  else
    raise Exception.Create('Decision tree not trained');
end;

{ Dispose node and, for internal nodes, both of its subtrees (post-order).
  A nil node is ignored. }
procedure FreeTree(node: PTreeNode);
begin
  if node <> nil then
  begin
    // Only internal nodes own children; leaves are built with nil links.
    if not node^.isLeaf then
    begin
      FreeTree(node^.left);
      FreeTree(node^.right);
    end;
    Dispose(node);
  end;
end;

{ Release every node of the tree and leave tree.root = nil, so that a later
  PredictDecisionTree reports "not trained" instead of touching freed memory. }
procedure FreeDecisionTree(var tree: TDecisionTree);
var
  oldRoot: PTreeNode;
begin
  oldRoot := tree.root;
  tree.root := nil;
  FreeTree(oldRoot);
end;

end.