program ii3_classify;

{
    Part of AdvancedChatAI.
    For GNU/Linux 64 bit version.
    Version: 1.
    Written in Free Pascal (https://freepascal.org/).
    Copyright (C) 2025-2026 Artyomov Alexander
    Used https://chat.deepseek.com/
    http://self-made-free.ru/
    aralni@mail.ru

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as
    published by the Free Software Foundation, either version 3 of the
    License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
}

{$MODE OBJFPC}{$H+}

uses
  SysUtils, Classes, DataUtils, DecisionTree, RandomForest;

type
  { Tag saying which of the two model payloads a TModelWrapper is using. }
  TModelType = (mtTree, mtForest);

  { Holds a decision tree and a random forest side by side, together with
    a tag naming the one in use. The destructor releases only the payload
    selected by ModelType. }
  TModelWrapper = class
  public
    ModelType: TModelType;   // selects Tree (mtTree) or Forest (mtForest)
    Tree: TDecisionTree;     // payload used when ModelType = mtTree
    Forest: TRandomForest;   // payload used when ModelType = mtForest
    constructor Create;
    destructor Destroy; override;
  end;

{ Builds an empty wrapper: tree mode by default, with a nil tree root and
  a zero-length forest. }
constructor TModelWrapper.Create;
begin
  inherited Create;
  // Start with both payloads empty.
  SetLength(Forest.trees, 0);
  Tree.root := nil;
  // Tree mode is the initial selection until a caller changes it.
  ModelType := mtTree;
end;

{ Releases whichever payload ModelType selects, then runs the inherited
  destructor. The payload that is not selected is not passed to its
  Free* routine. }
destructor TModelWrapper.Destroy;
begin
  case ModelType of
    mtTree:
      FreeDecisionTree(Tree);
  else
    FreeRandomForest(Forest);
  end;
  inherited Destroy;
end;

{ Writes the command-line usage text to standard output, one line per
  entry of HelpLines, and then terminates the program with exit code 0.
  This procedure never returns to its caller. }
procedure PrintHelp;
const
  HelpLines: array[0..11] of string = (
    'Usage: ii3_classify [options]',
    'Options:',
    '  --train <file>        Train model using data from file',
    '  --predict <file>      Make predictions using trained model',
    '  --model <file>        Model file to save/load',
    '  --algorithm <name>    Algorithm to use (tree or forest)',
    '  --max-depth <n>       Maximum tree depth (default: 5)',
    '  --min-samples <n>     Minimum samples to split (default: 2)',
    '  --num-trees <n>       Number of trees in forest (default: 10)',
    '  --target-col <n>      Target column index (default: last column)',
    '  --polynomial <n>      Add polynomial features up to degree n',
    '  --help                Show this help message'
  );
var
  lineIdx: Integer;
begin
  for lineIdx := Low(HelpLines) to High(HelpLines) do
    WriteLn(HelpLines[lineIdx]);
  Halt(0);
end;

type
  { Settings collected from the command line by ParseOptions. The defaults
    quoted below are the values ParseOptions assigns before reading any
    argument. }
  TAppOptions = record
    TrainFile: string;         // --train <file>; '' (default) means no training run
    PredictFile: string;       // --predict <file>; '' (default) means no prediction run
    ModelFile: string;         // --model <file>; default 'model.bin'
    Algorithm: string;         // --algorithm <name>, stored lower-cased; default 'forest'
    MaxDepth: Integer;         // --max-depth <n>; default 5
    MinSamples: Integer;       // --min-samples <n>; default 2
    NumTrees: Integer;         // --num-trees <n>; default 10
    TargetCol: Integer;        // --target-col <n>; default -1, meaning the last column
    PolynomialDegree: Integer; // --polynomial <n>; default 0, only applied when > 1
  end;

{ Parses the process command line into a TAppOptions record.

  Every field starts at its default (ModelFile 'model.bin', Algorithm
  'forest', MaxDepth 5, MinSamples 2, NumTrees 10, TargetCol -1,
  PolynomialDegree 0, TrainFile and PredictFile empty) and is overridden by
  the matching option. Arguments that are not recognised options are
  ignored. '--help' prints the usage text and terminates the program.

  Raises Exception when an option that requires a value is the last
  argument, and EConvertError when the value of a numeric option is not an
  integer. Both messages name the offending option.

  If neither --train nor --predict was supplied, an error line and the
  usage text are printed. PrintHelp then halts the program, so this
  function does not return in that case.
  NOTE(review): PrintHelp halts with exit code 0, so that error path
  currently reports success to the calling shell. }
function ParseOptions: TAppOptions;
var
  i: Integer;
  opt: string;

  { Returns the argument following the current option and advances i onto
    it, so that the main loop's Inc(i) steps past the value. }
  function NextValue: string;
  begin
    // ParamStr beyond ParamCount yields '', which would silently pass as
    // a value; reject it explicitly instead.
    if i >= ParamCount then
      raise Exception.Create('Missing value for option ' + opt);
    Inc(i);
    Result := ParamStr(i);
  end;

  { Same as NextValue, but converts the value to an Integer. }
  function NextInt: Integer;
  var
    raw: string;
  begin
    raw := NextValue;
    if not TryStrToInt(raw, Result) then
      raise EConvertError.Create(
        'Invalid integer "' + raw + '" for option ' + opt);
  end;

begin
  Result.TrainFile := '';
  Result.PredictFile := '';
  Result.ModelFile := 'model.bin';
  Result.Algorithm := 'forest';
  Result.MaxDepth := 5;
  Result.MinSamples := 2;
  Result.NumTrees := 10;
  Result.TargetCol := -1; // -1 means last column
  Result.PolynomialDegree := 0;

  i := 1;
  while i <= ParamCount do
  begin
    opt := ParamStr(i);
    if opt = '--train' then
      Result.TrainFile := NextValue
    else if opt = '--predict' then
      Result.PredictFile := NextValue
    else if opt = '--model' then
      Result.ModelFile := NextValue
    else if opt = '--algorithm' then
      Result.Algorithm := LowerCase(NextValue)
    else if opt = '--max-depth' then
      Result.MaxDepth := NextInt
    else if opt = '--min-samples' then
      Result.MinSamples := NextInt
    else if opt = '--num-trees' then
      Result.NumTrees := NextInt
    else if opt = '--target-col' then
      Result.TargetCol := NextInt
    else if opt = '--polynomial' then
      Result.PolynomialDegree := NextInt
    else if opt = '--help' then
      PrintHelp;
    Inc(i);
  end;

  if (Result.TrainFile = '') and (Result.PredictFile = '') then
  begin
    WriteLn('Error: Either --train or --predict must be specified');
    PrintHelp;
  end;
end;

{ Splits a data matrix into a feature matrix and a target vector.

  data      - input rows; the column count is taken from the first row.
  targetCol - index of the column copied into y; any negative value
              selects the last column.
  x         - receives every row with the target column removed
              (cols - 1 values per row).
  y         - receives the target column, one value per row.

  When data is empty, x and y are set to zero length.
  Raises Exception when the first row has no columns, when targetCol is
  beyond the last column, or when a row has fewer columns than the first
  row. Columns beyond the first row's width are ignored. }
procedure SplitData(const data: TDoubleMatrix; targetCol: Integer; 
                   var x: TDoubleMatrix; var y: TDoubleArray);
var
  i, j, cols: Integer;
begin
  if Length(data) = 0 then
  begin
    // Do not leave whatever the caller's variables previously held.
    SetLength(x, 0);
    SetLength(y, 0);
    Exit;
  end;

  cols := Length(data[0]);
  if cols = 0 then
    raise Exception.Create('Cannot split data: rows have no columns');

  if targetCol < 0 then targetCol := cols - 1;
  if targetCol >= cols then
    raise Exception.CreateFmt(
      'Target column %d is out of range (data has %d columns)',
      [targetCol, cols]);

  SetLength(x, Length(data));
  SetLength(y, Length(data));

  for i := 0 to High(data) do
  begin
    // The copy loops below index up to cols - 1 in every row.
    if Length(data[i]) < cols then
      raise Exception.CreateFmt(
        'Row %d has %d columns, expected %d',
        [i + 1, Length(data[i]), cols]);

    SetLength(x[i], cols - 1);
    // Columns before the target keep their index.
    for j := 0 to targetCol - 1 do
      x[i][j] := data[i][j];
    // Columns after the target shift left by one.
    for j := targetCol + 1 to cols - 1 do
      x[i][j - 1] := data[i][j];
    y[i] := data[i][targetCol];
  end;
end;

{ Writes the wrapper to filename, creating or overwriting the file.

  Layout: the wrapper's class name as an AnsiString, followed by the raw
  bytes of the active payload record (Tree when ModelType = mtTree,
  Forest otherwise). The stream is always released, even if a write
  raises.

  NOTE(review): elsewhere in this file Tree.root is assigned nil and
  Forest.trees is sized with SetLength, which suggests these records hold
  a pointer and a dynamic array. A raw byte copy would then store
  addresses rather than the model itself - confirm against the
  DecisionTree / RandomForest units. }
procedure SaveModel(wrapper: TModelWrapper; const filename: string);
var
  outFile: TFileStream;
begin
  outFile := TFileStream.Create(filename, fmCreate);
  try
    // Header: identifies the file as a TModelWrapper dump.
    outFile.WriteAnsiString(wrapper.ClassName);
    // Payload: only the record selected by ModelType is written.
    case wrapper.ModelType of
      mtTree:
        outFile.Write(wrapper.Tree, SizeOf(wrapper.Tree));
    else
      outFile.Write(wrapper.Forest, SizeOf(wrapper.Forest));
    end;
  finally
    outFile.Free;
  end;
end;

{ Reads a model file written by SaveModel and returns a new wrapper that
  the caller owns and must free.

  The file starts with an AnsiString that must equal 'TModelWrapper',
  followed by the raw bytes of either the Tree or the Forest record. The
  file carries no explicit model-type tag, so the type is inferred from
  the number of payload bytes left after the header:
    - exactly SizeOf(Forest) bytes -> forest
    - exactly SizeOf(Tree) bytes   -> tree
  When both records have the same size the payload is loaded as a forest.

  Raises Exception when the header does not match or the payload size fits
  neither record. Stream errors (file missing, short read) propagate. On
  any failure the partially built wrapper is freed before the exception
  leaves this function.

  NOTE(review): the payload is copied byte-for-byte into records that
  appear to contain a pointer (Tree.root) and a dynamic array
  (Forest.trees). Addresses restored this way are presumably not valid in
  a new process - confirm against the DecisionTree / RandomForest units. }
function LoadModel(const filename: string): TModelWrapper;
var
  stream: TFileStream;
  storedName: string;
  remaining: Int64;
begin
  Result := TModelWrapper.Create;
  try
    stream := TFileStream.Create(filename, fmOpenRead);
    try
      storedName := stream.ReadAnsiString;
      if storedName <> 'TModelWrapper' then
        raise Exception.Create('Unknown model type in file');

      // Payload bytes that follow the header.
      remaining := stream.Size - stream.Position;

      if remaining = SizeOf(Result.Forest) then
      begin
        // ReadBuffer raises on a short read instead of returning silently.
        stream.ReadBuffer(Result.Forest, SizeOf(Result.Forest));
        Result.ModelType := mtForest;
      end
      else if remaining = SizeOf(Result.Tree) then
      begin
        stream.ReadBuffer(Result.Tree, SizeOf(Result.Tree));
        Result.ModelType := mtTree;
      end
      else
        raise Exception.Create('Model file has an unexpected payload size');
    finally
      stream.Free;
    end;
  except
    // Do not leak the wrapper when opening or reading the file fails.
    FreeAndNil(Result);
    raise;
  end;
end;

{ Trains the model selected by options.Algorithm on options.TrainFile and
  saves it to options.ModelFile.

  Steps: load the file, add polynomial features when PolynomialDegree > 1,
  normalize, split into features and targets, train, save, and report on
  standard output. The wrapper is freed whether or not training or saving
  succeeds.

  Raises Exception('Unknown algorithm: ...') when Algorithm is neither
  'tree' nor 'forest'. The check happens after the data has been loaded.

  NOTE(review): the polynomial expansion and the normalization run on the
  whole matrix before the target column is split off, so they presumably
  also touch the target column - confirm against DataUtils. }
procedure TrainModel(const options: TAppOptions);
var
  samples, features: TDoubleMatrix;
  targets: TDoubleArray;
  model: TModelWrapper;
  useTree: Boolean;
begin
  // Data preparation
  LoadData(options.TrainFile, samples);

  if options.PolynomialDegree > 1 then
    AddPolynomialFeatures(samples, options.PolynomialDegree);

  NormalizeData(samples);
  SplitData(samples, options.TargetCol, features, targets);

  // Training
  model := TModelWrapper.Create;
  try
    useTree := options.Algorithm = 'tree';
    if (not useTree) and (options.Algorithm <> 'forest') then
      raise Exception.Create('Unknown algorithm: ' + options.Algorithm);

    if useTree then
    begin
      model.ModelType := mtTree;
      TrainDecisionTree(model.Tree, features, targets,
                        options.MaxDepth, options.MinSamples);
    end
    else
    begin
      model.ModelType := mtForest;
      TrainRandomForest(model.Forest, features, targets,
                        options.NumTrees, options.MaxDepth, options.MinSamples);
    end;

    SaveModel(model, options.ModelFile);
    WriteLn('Model trained and saved to ', options.ModelFile);
  finally
    model.Free;
  end;
end;

{ Loads options.PredictFile, applies the same preparation steps used for
  training (polynomial features when PolynomialDegree > 1, normalization,
  removal of the target column), loads the model from options.ModelFile
  and prints one line per sample:
    Sample <1-based index>: <prediction with 4 decimals>

  The target values split out of the file are not used afterwards. The
  loaded wrapper is freed even if a prediction raises.

  NOTE(review): normalization is computed from the prediction file itself,
  not from the training data - confirm this is intended. }
procedure Predict(const options: TAppOptions);
var
  samples, features: TDoubleMatrix;
  targets: TDoubleArray;
  model: TModelWrapper;
  row: Integer;
  score: Double;
begin
  // Data preparation
  LoadData(options.PredictFile, samples);

  if options.PolynomialDegree > 1 then
    AddPolynomialFeatures(samples, options.PolynomialDegree);

  NormalizeData(samples);
  SplitData(samples, options.TargetCol, features, targets);

  model := LoadModel(options.ModelFile);
  try
    for row := 0 to High(features) do
    begin
      // Dispatch on the stored model type; print nothing for any other tag.
      case model.ModelType of
        mtTree:
          score := PredictDecisionTree(model.Tree, features[row]);
        mtForest:
          score := PredictRandomForest(model.Forest, features[row]);
      else
        Continue;
      end;
      WriteLn('Sample ', row + 1, ': ', score:0:4);
    end;
  finally
    model.Free;
  end;
end;

{ Program entry point: parse the command line, seed the random number
  generator, then run training and/or prediction depending on which files
  were given. When both are given, training runs first. Any exception is
  reported as 'Error: <message>' and ends the program with exit code 1. }
var
  cliOptions: TAppOptions;
begin
  try
    cliOptions := ParseOptions;
    Randomize;

    // Training step, only when a training file was supplied.
    if Length(cliOptions.TrainFile) > 0 then
      TrainModel(cliOptions);

    // Prediction step, only when a prediction file was supplied.
    if Length(cliOptions.PredictFile) > 0 then
      Predict(cliOptions);
  except
    on E: Exception do
    begin
      WriteLn('Error: ', E.Message);
      Halt(1);
    end;
  end;
end.