unit DecisionTreeGB;

{
    Part of AdvancedChatAI.
    For the 64-bit GNU/Linux version.
    Version: 1.
    Written in Free Pascal (https://freepascal.org/).
    Copyright (C) 2025-2026 Artyomov Alexander
    Developed with the help of https://chat.deepseek.com/
    http://self-made-free.ru/
    aralni@mail.ru

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as
    published by the Free Software Foundation, either version 3 of the
    License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
}

{$MODE OBJFPC}{$H+}{$RANGECHECKS ON}
{$OPTIMIZATION LEVEL4}

interface

uses
  SysUtils, Math, DataUtils;

type
  { Pointer to a tree node. Nodes are allocated with New by the builder and
    released with Dispose by FreeDecisionTree. }
  PTreeNode = ^TTreeNode;
  { One node of a regression decision tree. }
  TTreeNode = record
    isLeaf: Boolean;          // True for leaves; left/right are not followed for leaves
    classLabel: Integer;      // set to -1 by CreateLeaf; not read by prediction in this unit
    value: Double;            // leaf prediction (CreateLeaf stores the mean of the targets)
    splitFeature: Integer;    // internal nodes: index into the feature vector
    splitValue: Double;       // internal nodes: x[splitFeature] <= splitValue goes left
    left, right: PTreeNode;   // children of an internal node
  end;

  { A tree together with the hyper-parameters it was trained with. }
  TDecisionTree = record
    root: PTreeNode;          // nil means empty (PredictDecisionTree returns 0.0)
    maxDepth: Integer;        // as passed to TrainDecisionTree
    minSamplesSplit: Integer; // value used after TrainDecisionTree's auto-correction
  end;

{ Builds a regression tree from rows x and targets y and stores it in tree.
  For non-empty x, raises an Exception when Length(x) <> Length(y).
  Release the nodes with FreeDecisionTree. }
procedure TrainDecisionTree(var tree: TDecisionTree; const x: TDoubleMatrix; 
                          const y: TDoubleArray; maxDepth: Integer; 
                          minSamplesSplit: Integer);
{ Returns the prediction for feature vector x; 0.0 when tree.root is nil. }
function PredictDecisionTree(const tree: TDecisionTree; const x: TDoubleArray): Double;
{ Disposes every node of tree and resets tree.root to nil. }
procedure FreeDecisionTree(var tree: TDecisionTree);

implementation

{ Mean squared deviation of y from its own mean (population variance),
  computed in two passes. Returns MaxDouble when y is empty.
  Not referenced elsewhere in this unit. }
function CalculateMSE(const y: TDoubleArray): Double;
var
  n: Integer;
  v, avg, total: Double;
begin
  n := Length(y);
  if n = 0 then
    Exit(MaxDouble);

  // Pass 1: arithmetic mean.
  total := 0.0;
  for v in y do
    total := total + v;
  avg := total / n;

  // Pass 2: average of squared deviations from that mean.
  total := 0.0;
  for v in y do
    total := total + (v - avg) * (v - avg);
  Result := total / n;
end;

{ Returns a leaf node whose prediction (value) is the arithmetic mean of y.
  Split fields are cleared, both children are nil and classLabel is -1.
  An empty y yields value 0.0 instead of evaluating 0/0. }
function CreateLeaf(const y: TDoubleArray): TTreeNode;
var
  i: Integer;
  sum: Double;
begin
  Result.isLeaf := True;
  Result.splitFeature := -1;
  Result.splitValue := 0.0;
  Result.left := nil;
  Result.right := nil;
  Result.classLabel := -1;

  if Length(y) = 0 then
  begin
    // Nothing to average; 0.0 is also what BuildTree uses for an empty node.
    Result.value := 0.0;
    Exit;
  end;

  sum := 0.0;
  for i := 0 to High(y) do
    sum := sum + y[i];
  Result.value := sum / Length(y);
end;

{ Sorts idx[lo..hi] in ascending order of x[idx[k]][feature].
  Quicksort with a middle-element pivot. It recurses into the smaller part
  and loops on the larger one, so stack depth stays logarithmic. }
procedure SortIndicesByFeature(const x: TDoubleMatrix; feature: Integer;
                               var idx: array of Integer; lo, hi: Integer);
var
  i, j, tmp: Integer;
  pivot: Double;
begin
  while lo < hi do
  begin
    i := lo;
    j := hi;
    pivot := x[idx[(lo + hi) div 2]][feature];
    repeat
      while x[idx[i]][feature] < pivot do
        Inc(i);
      while x[idx[j]][feature] > pivot do
        Dec(j);
      if i <= j then
      begin
        tmp := idx[i];
        idx[i] := idx[j];
        idx[j] := tmp;
        Inc(i);
        Dec(j);
      end;
    until i > j;
    if (j - lo) < (hi - i) then
    begin
      SortIndicesByFeature(x, feature, idx, lo, j);
      lo := i;
    end
    else
    begin
      SortIndicesByFeature(x, feature, idx, i, hi);
      hi := j;
    end;
  end;
end;

{ Searches every feature for the threshold that maximises the reduction in
  squared error of y.
  On success returns True with bestFeature/bestValue set so that rows with
  x[r][bestFeature] <= bestValue form a non-empty left part and the remaining
  rows a non-empty right part.
  Returns False when no split reduces the error, for example when all
  targets are equal or every feature is constant. }
function FindBestSplit(const x: TDoubleMatrix; const y: TDoubleArray;
                       out bestFeature: Integer; out bestValue: Double): Boolean;
var
  n, nFeatures, f, k, nLeft: Integer;
  order: array of Integer;
  totalSum, sumSq, leftSum, rightSum, baseScore: Double;
  gain, bestGain, curVal, nextVal, threshold: Double;
begin
  Result := False;
  bestFeature := -1;
  bestValue := 0.0;

  n := Length(y);
  if n < 2 then
    Exit;
  nFeatures := Length(x[0]);
  if nFeatures = 0 then
    Exit;

  totalSum := 0.0;
  sumSq := 0.0;
  for k := 0 to n - 1 do
  begin
    totalSum := totalSum + y[k];
    sumSq := sumSq + y[k] * y[k];
  end;
  baseScore := totalSum * totalSum / n;
  // Gains at or below this level are rounding noise (e.g. constant targets).
  bestGain := 1e-12 * sumSq;

  SetLength(order, n);
  for f := 0 to nFeatures - 1 do
  begin
    for k := 0 to n - 1 do
      order[k] := k;
    SortIndicesByFeature(x, f, order, 0, n - 1);

    leftSum := 0.0;
    for k := 0 to n - 2 do
    begin
      leftSum := leftSum + y[order[k]];
      curVal := x[order[k]][f];
      nextVal := x[order[k + 1]][f];
      // A threshold can only separate strictly increasing neighbours.
      if not (curVal < nextVal) then
        Continue;
      nLeft := k + 1;
      rightSum := totalSum - leftSum;
      // SSE(parent) - SSE(left) - SSE(right): the sum-of-squares terms cancel.
      gain := leftSum * leftSum / nLeft
            + rightSum * rightSum / (n - nLeft) - baseScore;
      if gain > bestGain then
      begin
        threshold := curVal + (nextVal - curVal) / 2;
        // If the midpoint rounds up to nextVal, fall back to curVal so the
        // "<= threshold" partition still matches this sorted prefix.
        if not (threshold < nextVal) then
          threshold := curVal;
        bestGain := gain;
        bestFeature := f;
        bestValue := threshold;
        Result := True;
      end;
    end;
  end;
end;

{ Recursively builds a regression tree for rows x with targets y.
  - A node becomes a leaf (mean of its targets) when depth >= maxDepth,
    when it holds at most minSamplesSplit samples, or when no split lowers
    the squared error.
  - Otherwise rows with x[r][splitFeature] <= splitValue go to the left child
    and the rest go to the right child. This matches the comparison used
    by PredictTree.
  - Empty x yields a leaf with value 0.0.
  - Raises an Exception when Length(x) <> Length(y).
  The caller owns the returned nodes. }
function BuildTree(const x: TDoubleMatrix; const y: TDoubleArray;
                  depth: Integer; maxDepth: Integer;
                  minSamplesSplit: Integer): PTreeNode;
var
  i, nLeft, nRight, feature: Integer;
  threshold: Double;
  leftX, rightX: TDoubleMatrix;
  leftY, rightY: TDoubleArray;
begin
  if Length(x) = 0 then
  begin
    // Fully initialise the node: New does not clear the record.
    New(Result);
    Result^.isLeaf := True;
    Result^.classLabel := -1;
    Result^.value := 0.0;
    Result^.splitFeature := -1;
    Result^.splitValue := 0.0;
    Result^.left := nil;
    Result^.right := nil;
    Exit;
  end;

  if Length(x) <> Length(y) then
    raise Exception.Create('X and Y sizes mismatch');

  // Stopping conditions; FindBestSplit only runs if the cheap checks pass.
  if (depth >= maxDepth) or (Length(y) <= minSamplesSplit)
     or not FindBestSplit(x, y, feature, threshold) then
  begin
    New(Result);
    Result^ := CreateLeaf(y);
    Exit;
  end;

  nLeft := 0;
  for i := 0 to High(x) do
    if x[i][feature] <= threshold then
      Inc(nLeft);
  nRight := Length(x) - nLeft;

  // Defensive: FindBestSplit guarantees both sides are non-empty.
  if (nLeft = 0) or (nRight = 0) then
  begin
    New(Result);
    Result^ := CreateLeaf(y);
    Exit;
  end;

  // Partition rows; row arrays are shared by reference, not copied.
  SetLength(leftX, nLeft);
  SetLength(leftY, nLeft);
  SetLength(rightX, nRight);
  SetLength(rightY, nRight);
  nLeft := 0;
  nRight := 0;
  for i := 0 to High(x) do
    if x[i][feature] <= threshold then
    begin
      leftX[nLeft] := x[i];
      leftY[nLeft] := y[i];
      Inc(nLeft);
    end
    else
    begin
      rightX[nRight] := x[i];
      rightY[nRight] := y[i];
      Inc(nRight);
    end;

  New(Result);
  Result^.isLeaf := False;
  Result^.classLabel := -1;
  Result^.value := 0.0;
  Result^.splitFeature := feature;
  Result^.splitValue := threshold;
  Result^.left := BuildTree(leftX, leftY, depth + 1, maxDepth, minSamplesSplit);
  Result^.right := BuildTree(rightX, rightY, depth + 1, maxDepth, minSamplesSplit);
end;

{ Trains a regression tree on rows x with targets y and stores it in tree.
  - maxDepth: nodes at this depth become leaves.
  - minSamplesSplit: nodes with at most this many samples become leaves.
    A value >= Length(x) is lowered to Max(2, Length(x) div 2).
  - Empty x yields an empty tree (root = nil), which predicts 0.0.
  - BuildTree raises an Exception when Length(x) <> Length(y); tree.root is
    then nil.
  The previous tree.root is overwritten, not freed: the record may be
  uninitialised, so call FreeDecisionTree yourself before retraining. }
procedure TrainDecisionTree(var tree: TDecisionTree; const x: TDoubleMatrix;
                          const y: TDoubleArray; maxDepth: Integer;
                          minSamplesSplit: Integer);
begin
  // Never leave a stale or uninitialised root behind, even on early exit.
  tree.root := nil;
  tree.maxDepth := maxDepth;

  if Length(x) = 0 then
  begin
    tree.minSamplesSplit := minSamplesSplit;
    Exit;
  end;

  // Auto-correct minSamplesSplit when it would force the root to be a leaf.
  if minSamplesSplit >= Length(x) then
    minSamplesSplit := Max(2, Length(x) div 2);

  tree.minSamplesSplit := minSamplesSplit;
  tree.root := BuildTree(x, y, 0, maxDepth, minSamplesSplit);
end;

{ Walks from node down to a leaf and returns that leaf's value.
  At each internal node the walk goes left when x[splitFeature] <= splitValue
  and right otherwise, including when that comparison is false because of NaN.
  Reaching a nil pointer yields 0.0. }
function PredictTree(node: PTreeNode; const x: TDoubleArray): Double;
var
  cur: PTreeNode;
begin
  Result := 0.0;
  cur := node;
  while cur <> nil do
  begin
    if cur^.isLeaf then
    begin
      Result := cur^.value;
      Break;
    end;
    if x[cur^.splitFeature] <= cur^.splitValue then
      cur := cur^.left
    else
      cur := cur^.right;
  end;
end;

{ Predicts the target for feature vector x.
  An empty tree (root = nil) predicts 0.0. }
function PredictDecisionTree(const tree: TDecisionTree; const x: TDoubleArray): Double;
begin
  if tree.root <> nil then
    Result := PredictTree(tree.root, x)
  else
    Result := 0.0;
end;

{ Releases node and, for internal nodes, both subtrees (children first).
  The children of a leaf are never followed, so a leaf's left/right fields
  need not be valid pointers. A nil node is ignored. }
procedure FreeTree(node: PTreeNode);
var
  hasChildren: Boolean;
begin
  if node <> nil then
  begin
    hasChildren := not node^.isLeaf;
    if hasChildren then
    begin
      FreeTree(node^.left);
      FreeTree(node^.right);
    end;
    Dispose(node);
  end;
end;

{ Disposes every node of tree and leaves tree.root = nil.
  The hyper-parameter fields are left unchanged. }
procedure FreeDecisionTree(var tree: TDecisionTree);
var
  oldRoot: PTreeNode;
begin
  oldRoot := tree.root;
  tree.root := nil;
  FreeTree(oldRoot);
end;

end.