unit Word2VecOpt;

{
    Part of AdvancedChatAI.
    For GNU/Linux 64 bit version.
    Version: 1.
    Written in Free Pascal (https://freepascal.org/).
    Copyright (C) 2025-2026 Artyomov Alexander
    Used https://chat.deepseek.com/
    http://self-made-free.ru/
    aralni@mail.ru

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as
    published by the Free Software Foundation, either version 3 of the
    License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
}


{$MODE OBJFPC}
{$H+}
{$RANGECHECKS ON}
{$ASMMODE INTEL}

interface

uses
  SysUtils, Classes, Generics.Collections, Math, DataUtils, LazUTF8, fgl;

const
  // Vector dimension of the built-in fallback model; also the default used
  // when the model header does not yield a usable dimension.
  EMBEDDING_SIZE = 300;
  // Size cut-off for sorting: MergeSort hands a range to QuickSort when
  // R - L <= MERGE_THRESHOLD, and MostSimilar calls QuickSort directly when
  // the vocabulary has at most this many entries.
  MERGE_THRESHOLD = 64;

type
  // Dynamic-array aliases used throughout this unit.
  TStringArray = array of string;
  TDoubleArray = array of Double;
  TDoubleMatrix = array of TDoubleArray;   // array of rows; each row is a TDoubleArray

  // A vocabulary word paired with its ranking score (filled in by MostSimilar).
  TScore = record
    Word: string;
    Score: Double;
  end;
  TScoreArray = array of TScore;

  // Node for LRU cache.
  // Nodes live in an array and are chained into a doubly linked list through
  // the Prev/Next indices. An empty Word marks a slot that is not in use.
  TCacheNode = record
    Word: string;
    Embedding: TDoubleArray;
    Prev, Next: Integer; // indices in array; -1 = none
  end;

  // Word-embedding store. The constructor loads word vectors from a text
  // model file (or builds a tiny fallback model when the file is missing);
  // the public methods offer vector lookup, similarity between two words and
  // a ranked "most similar words" query.
  TWordEmbeddings = class
  private
    FVocab: TStringList;                      // list of words (lowercased)
    FEmbeddings: TDoubleMatrix;               // [vocabSize x embSize]
    FEmbeddingSize: Integer;                  // vector dimension of the loaded model

    // mapping word -> index in FVocab / FEmbeddings
    FWordToIndex: specialize TDictionary<string,Integer>;

    // Precomputed Euclidean norms, one per row of FEmbeddings
    FEmbeddingNorms: TDoubleArray;

    // LRU cache for embeddings: nodes stored in an array and linked by index,
    // plus a dictionary from word to node index
    FCacheNodes: array of TCacheNode;
    FCacheMap: specialize TDictionary<string,Integer>; // word -> node index
    FCacheCapacity: Integer;                  // number of slots in FCacheNodes
    FCacheCount: Integer;                     // number of slots currently in use
    FCacheHead, FCacheTail: Integer; // indices of head (most recent) / tail (eviction candidate); -1 = empty

    // sorting helpers (all sort by Score, highest first)
    procedure QuickSort(var A: TScoreArray; L, R: Integer);
    procedure MergeArrays(var A: TScoreArray; L, M, R: Integer; var Temp: TScoreArray);
    procedure MergeSort(var A: TScoreArray; L, R: Integer; var Temp: TScoreArray);

    // sum of the squared components of Vec
    function SumOfSquares(const Vec: TDoubleArray): Double; inline;

    // cache helpers
    procedure CacheInit(Capacity: Integer);
    procedure CacheMoveToFront(nodeIdx: Integer);
    procedure CacheInsert(const Word: string; const Emb: TDoubleArray);
    function CacheGet(const Word: string; out Emb: TDoubleArray): Boolean;
    procedure CacheClear;

  public
    // CacheSize values below 16 are raised to 16 by the constructor.
    constructor Create(const ModelFile: string; CacheSize: Integer = 8192);
    destructor Destroy; override;

    // accessors; an unknown word / out-of-range index yields -1 or an empty array
    function GetWordIndex(const Word: string): Integer;
    function GetEmbedding(const Word: string): TDoubleArray; // copy
    function GetEmbeddingFastByIndex(Index: Integer): TDoubleArray; // reference to the stored row
    function GetEmbeddingByIndex(Index: Integer): TDoubleArray; // copy
    function GetEmbeddingWithCache(const Word: string): TDoubleArray; // copy, served through the LRU cache

    // similarity
    // FastSimilarityScore is the plain dot product; FastSimilarity and
    // Similarity divide by the vector norms.
    function FastSimilarityScore(const Emb1, Emb2: TDoubleArray): Double; inline;
    function FastSimilarity(const Word1, Word2: string): Double;
    function Similarity(const Word1, Word2: string): Double;

    // top-N: the highest-scoring vocabulary words for Word
    function MostSimilar(const Word: string; TopN: Integer = 5): TStringArray;

    // cache control & stats
    procedure ClearCache;
    function GetCacheStats: string;
    property EmbeddingSize: Integer read FEmbeddingSize;
  end;

implementation

{ ---------------------- Helpers ---------------------- }

// Exchanges the contents of two score records.
procedure SwapScore(var A, B: TScore); inline;
var
  Saved: TScore;
begin
  Saved := A;
  A := B;
  B := Saved;
end;

{ ---------------------- TWordEmbeddings ---------------------- }

// Creates the store and loads word vectors from ModelFile.
//
// ModelFile is read as text. The first line is "<vocabSize> <embedSize>";
// each following line is a word followed by embedSize numbers, separated by
// spaces. Numbers always use '.' as the decimal separator, regardless of the
// active locale. Words are trimmed and lowercased (UTF-8 aware); when two
// lines produce the same word, the first one is kept and the rest are skipped.
// Lines with too few fields are skipped. If ModelFile does not exist, a
// five-word model with small random vectors is created instead.
//
// CacheSize is the LRU cache capacity; values below 16 are raised to 16.
//
// Raises Exception when the file is empty or its header has fewer than two
// fields.
constructor TWordEmbeddings.Create(const ModelFile: string; CacheSize: Integer);
var
  f: TextFile;
  line: string;
  parts: TStringArray;
  vocabSize, embedSize: Integer;
  i, j, loaded: Integer;
  fs: TFormatSettings;
begin
  inherited Create;

  FVocab := TStringList.Create;
  FVocab.Sorted := False;
  // Duplicates only affects sorted lists, so repeated words are filtered
  // explicitly while loading (see below).
  FVocab.Duplicates := dupIgnore;

  FWordToIndex := specialize TDictionary<string,Integer>.Create;
  FCacheMap := specialize TDictionary<string,Integer>.Create;

  // defaults
  FCacheCapacity := Max(16, CacheSize);
  CacheInit(FCacheCapacity);

  // if model file absent, create minimal model
  if not FileExists(ModelFile) then
  begin
    Writeln('Model file not found, creating minimal model...');
    // minimal model: few dummy words
    FVocab.Clear;
    FVocab.Add('and'); FVocab.Add('in'); FVocab.Add('the'); FVocab.Add('of'); FVocab.Add('to');
    FEmbeddingSize := EMBEDDING_SIZE;
    SetLength(FEmbeddings, FVocab.Count, FEmbeddingSize);
    for i := 0 to FVocab.Count - 1 do
      for j := 0 to FEmbeddingSize - 1 do
        FEmbeddings[i][j] := (Random - 0.5) * 0.02;
    // build dictionary
    for i := 0 to FVocab.Count - 1 do
      FWordToIndex.Add(LowerCase(FVocab[i]), i);
    // norms
    SetLength(FEmbeddingNorms, FVocab.Count);
    for i := 0 to FVocab.Count - 1 do
      FEmbeddingNorms[i] := Sqrt(SumOfSquares(FEmbeddings[i]));
    Exit;
  end;

  // Model files are written with '.' as decimal separator; parse them that
  // way even when the process locale uses ','.
  fs := DefaultFormatSettings;
  fs.DecimalSeparator := '.';
  fs.ThousandSeparator := ',';

  // read file header
  AssignFile(f, ModelFile);
  Reset(f);
  try
    if Eof(f) then raise Exception.Create('Empty model file');
    ReadLn(f, line);
    parts := line.Split([' '], TStringSplitOptions.ExcludeEmpty);
    if Length(parts) < 2 then raise Exception.Create('Bad model header');

    vocabSize := StrToIntDef(parts[0], 0);
    embedSize := StrToIntDef(parts[1], EMBEDDING_SIZE);

    // clamp/validate
    // NOTE(review): a header dimension below 50 is raised to 50, after which
    // every data line of such a model is too short and gets skipped - confirm
    // that this lower bound is intended.
    if vocabSize <= 0 then vocabSize := 10000;
    if embedSize <= 0 then embedSize := EMBEDDING_SIZE;
    vocabSize := Min(vocabSize, 500000);
    embedSize := Min(Max(embedSize, 50), 1024);

    FEmbeddingSize := embedSize;
    SetLength(FEmbeddings, vocabSize, embedSize);

    loaded := 0;
    for i := 0 to vocabSize - 1 do
    begin
      if Eof(f) then Break;
      ReadLn(f, line);
      parts := line.Split([' '], TStringSplitOptions.ExcludeEmpty);
      if Length(parts) < embedSize + 1 then Continue;

      // normalize word once: lowercased
      parts[0] := UTF8LowerCase(Trim(parts[0]));

      // Lowercasing can map several file entries onto one key (e.g. "The"
      // and "the"); TDictionary.Add would raise on the second one, so keep
      // the first occurrence and skip the others.
      if (parts[0] = '') or FWordToIndex.ContainsKey(parts[0]) then Continue;

      FVocab.Add(parts[0]);
      FWordToIndex.Add(parts[0], loaded);
      for j := 0 to embedSize - 1 do
        FEmbeddings[loaded][j] := StrToFloatDef(parts[j+1], 0.0, fs);
      Inc(loaded);
      if (loaded mod 10000) = 0 then
        Writeln('Loaded ', loaded, ' embeddings...');
    end;

    // shrink to loaded
    if loaded <> vocabSize then
      SetLength(FEmbeddings, loaded);

    // precompute norms
    SetLength(FEmbeddingNorms, FVocab.Count);
    for i := 0 to FVocab.Count - 1 do
      FEmbeddingNorms[i] := Sqrt(SumOfSquares(FEmbeddings[i]));

    Writeln('Model loaded: vocab=', FVocab.Count, ' embSize=', FEmbeddingSize);
  finally
    CloseFile(f);
  end;
end;

// Releases the cache, the dictionaries, the vocabulary and the vector data.
// Safe to run on a partially constructed instance (the destructor is also
// invoked when the constructor raises), so every object field is checked or
// released nil-safely.
destructor TWordEmbeddings.Destroy;
begin
  // CacheClear dereferences FCacheMap, which may not exist yet if the
  // constructor failed early.
  if Assigned(FCacheMap) then
    CacheClear;
  FreeAndNil(FCacheMap);
  FreeAndNil(FWordToIndex);
  FreeAndNil(FVocab);
  SetLength(FCacheNodes, 0);
  SetLength(FEmbeddings, 0);
  SetLength(FEmbeddingNorms, 0);
  inherited Destroy;
end;

// Returns the sum of the squared components of Vec (0 for an empty vector).
function TWordEmbeddings.SumOfSquares(const Vec: TDoubleArray): Double;
var
  k: Integer;
begin
  Result := 0.0;
  for k := 0 to High(Vec) do
    Result := Result + Vec[k] * Vec[k];
end;

{ -------------------- Cache implementation (LRU via arrays) -------------------- }

// Sizes the cache to Capacity slots and puts it into the empty state:
// every slot blank and unlinked, no map entries, count 0, head/tail = -1.
procedure TWordEmbeddings.CacheInit(Capacity: Integer);
begin
  FCacheCapacity := Capacity;
  SetLength(FCacheNodes, Capacity);
  // CacheClear performs exactly the per-slot reset and counter reset needed
  // here, over all slots of the freshly sized array.
  CacheClear;
end;

// Empties the cache without changing its capacity: drops all map entries,
// blanks and unlinks every slot, and resets count, head and tail.
procedure TWordEmbeddings.CacheClear;
var
  slot: Integer;
begin
  FCacheMap.Clear;

  FCacheCount := 0;
  FCacheHead := -1;
  FCacheTail := -1;

  for slot := Low(FCacheNodes) to High(FCacheNodes) do
  begin
    FCacheNodes[slot].Prev := -1;
    FCacheNodes[slot].Next := -1;
    FCacheNodes[slot].Word := '';
    SetLength(FCacheNodes[slot].Embedding, 0);
  end;
end;

// Marks the node at nodeIdx as most recently used by relinking it at the
// head of the list. Does nothing for a negative index or when the node is
// already the head.
procedure TWordEmbeddings.CacheMoveToFront(nodeIdx: Integer);
var
  before, after: Integer;
begin
  if (nodeIdx < 0) or (nodeIdx = FCacheHead) then
    Exit;

  before := FCacheNodes[nodeIdx].Prev;
  after := FCacheNodes[nodeIdx].Next;

  // Take the node out of its current position.
  if before >= 0 then
    FCacheNodes[before].Next := after;
  if after >= 0 then
    FCacheNodes[after].Prev := before;
  if FCacheTail = nodeIdx then
    FCacheTail := before;

  // Link it in front of the current head.
  FCacheNodes[nodeIdx].Prev := -1;
  FCacheNodes[nodeIdx].Next := FCacheHead;
  if FCacheHead >= 0 then
    FCacheNodes[FCacheHead].Prev := nodeIdx;
  FCacheHead := nodeIdx;

  // A list without a tail has this node as its only element.
  if FCacheTail = -1 then
    FCacheTail := nodeIdx;
end;

// Stores Emb under Word as the most recently used cache entry.
//
// If Word is already cached, its vector is replaced and the entry is moved
// to the front. Otherwise a free slot is used, or - when the cache is full -
// the least recently used entry (the tail) is evicted and its slot reused.
// Emb is stored by reference, not copied; callers that need isolation must
// pass their own copy.
procedure TWordEmbeddings.CacheInsert(const Word: string; const Emb: TDoubleArray);
var
  nodeIdx: Integer;
begin
  // Nothing can be stored in a cache without slots.
  if FCacheCapacity <= 0 then Exit;

  // if already in cache, replace embedding and move to front
  if FCacheMap.TryGetValue(Word, nodeIdx) then
  begin
    FCacheNodes[nodeIdx].Embedding := Emb;
    CacheMoveToFront(nodeIdx);
    Exit;
  end;

  if FCacheCount < FCacheCapacity then
  begin
    // Slots are handed out in index order and are only released all at once
    // (CacheInit / CacheClear), so the next free slot is always FCacheCount.
    // This avoids scanning the whole array for an empty Word on every insert.
    nodeIdx := FCacheCount;
    Inc(FCacheCount);
  end
  else
  begin
    // evict tail
    nodeIdx := FCacheTail;
    if nodeIdx < 0 then nodeIdx := 0;

    // remove mapping of the evicted word
    if FCacheNodes[nodeIdx].Word <> '' then
      FCacheMap.Remove(FCacheNodes[nodeIdx].Word);

    // unlink the evicted node from the end of the list
    FCacheTail := FCacheNodes[nodeIdx].Prev;
    if FCacheTail >= 0 then
      FCacheNodes[FCacheTail].Next := -1
    else
      // The evicted node was the only one; without this the node would be
      // linked to itself below.
      FCacheHead := -1;
  end;

  // write node and link it at the head
  FCacheNodes[nodeIdx].Word := Word;
  FCacheNodes[nodeIdx].Embedding := Emb; // shares the array, no copy
  FCacheNodes[nodeIdx].Prev := -1;
  FCacheNodes[nodeIdx].Next := FCacheHead;
  if FCacheHead >= 0 then FCacheNodes[FCacheHead].Prev := nodeIdx;
  FCacheHead := nodeIdx;
  if FCacheTail = -1 then FCacheTail := nodeIdx;

  FCacheMap.Add(Word, nodeIdx);
end;

// Looks Word up in the cache. On a hit, Emb receives the cached vector (the
// stored array itself, not a copy), the entry becomes the most recently used
// one and the function returns True. On a miss it returns False.
function TWordEmbeddings.CacheGet(const Word: string; out Emb: TDoubleArray): Boolean;
var
  slot: Integer;
begin
  Result := FCacheMap.TryGetValue(Word, slot);
  if not Result then
    Exit;

  Emb := FCacheNodes[slot].Embedding;
  CacheMoveToFront(slot);
end;

{ -------------------- Index & embedding access -------------------- }

// Returns the vocabulary index of Word after trimming and UTF-8 lowercasing
// it, or -1 when the word is blank or not in the vocabulary.
function TWordEmbeddings.GetWordIndex(const Word: string): Integer;
var
  normalized: string;
begin
  Result := -1;
  normalized := UTF8LowerCase(Trim(Word));
  if normalized = '' then
    Exit;
  if not FWordToIndex.TryGetValue(normalized, Result) then
    Result := -1; // TryGetValue may have overwritten Result on a miss
end;

// Returns a private copy of the vector stored for Word (trimmed and
// lowercased first), or an empty array when the word is blank or unknown.
function TWordEmbeddings.GetEmbedding(const Word: string): TDoubleArray;
var
  normalized: string;
  row: Integer;
begin
  SetLength(Result, 0);

  normalized := UTF8LowerCase(Trim(Word));
  if normalized = '' then
    Exit;

  row := GetWordIndex(normalized);
  if (row < 0) or (row >= Length(FEmbeddings)) then
    Exit;

  // Copy so that callers cannot modify the stored vector.
  Result := Copy(FEmbeddings[row]);
end;

// Returns the stored vector at Index without copying it (the result shares
// the array held in FEmbeddings), or an empty array when Index is outside
// 0 .. FVocab.Count - 1.
function TWordEmbeddings.GetEmbeddingFastByIndex(Index: Integer): TDoubleArray;
begin
  SetLength(Result, 0);
  if (Index < 0) or (Index >= FVocab.Count) then
    Exit;
  Result := FEmbeddings[Index]; // reference (fast)
end;

// Returns a copy of the stored vector at Index, or an empty array when Index
// is outside 0 .. FVocab.Count - 1.
function TWordEmbeddings.GetEmbeddingByIndex(Index: Integer): TDoubleArray;
begin
  SetLength(Result, 0);
  if (Index < 0) or (Index >= FVocab.Count) then
    Exit;
  Result := Copy(FEmbeddings[Index]);
end;

// Returns a copy of the vector for Word (trimmed and lowercased first),
// going through the LRU cache: a cached vector is used when present,
// otherwise the vector is looked up in the model and added to the cache.
// Returns an empty array when the word is blank or unknown.
function TWordEmbeddings.GetEmbeddingWithCache(const Word: string): TDoubleArray;
var
  normalized: string;
  vec: TDoubleArray;
  row: Integer;
begin
  SetLength(Result, 0);

  normalized := UTF8LowerCase(Trim(Word));
  if normalized = '' then
    Exit;

  if not CacheGet(normalized, vec) then
  begin
    // Cache miss: fetch from the model and remember it.
    row := GetWordIndex(normalized);
    if row < 0 then
      Exit;
    vec := Copy(FEmbeddings[row]);
    CacheInsert(normalized, vec);
  end;

  // The cache keeps vec itself, so hand out a separate copy.
  Result := Copy(vec);
end;

{ -------------------- Similarity functions -------------------- }

// Returns the dot product of Emb1 and Emb2 over their common length.
// This is NOT normalized: callers divide by the vector norms when they need
// a cosine similarity. Returns 0 when either vector is empty.
function TWordEmbeddings.FastSimilarityScore(const Emb1, Emb2: TDoubleArray): Double;
var
  i: Integer;
begin
  Result := 0.0;
  if (Length(Emb1) = 0) or (Length(Emb2) = 0) then Exit;
  // Vectors of different length are compared over the shorter one only.
  for i := 0 to Min(High(Emb1), High(Emb2)) do
    Result := Result + Emb1[i] * Emb2[i];
end;

// Cosine similarity of two words using the stored vectors and the
// precomputed norms (no cache, no copies). A small constant is added to the
// denominator so zero-length vectors do not divide by zero.
// Returns 0 when either word is unknown.
function TWordEmbeddings.FastSimilarity(const Word1, Word2: string): Double;
var
  first, second, normCount: Integer;
begin
  Result := 0.0;

  first := GetWordIndex(Word1);
  second := GetWordIndex(Word2);
  if (first < 0) or (second < 0) then
    Exit;

  normCount := Length(FEmbeddingNorms);
  if (first >= normCount) or (second >= normCount) then
    Exit;

  Result := FastSimilarityScore(FEmbeddings[first], FEmbeddings[second]) / (FEmbeddingNorms[first] * FEmbeddingNorms[second] + 1e-12);
end;

// Cosine similarity of two words, computed from vectors fetched through the
// LRU cache; dot product and both squared norms are accumulated in one pass.
// Returns 0 when either word has no vector or a squared norm is not positive.
function TWordEmbeddings.Similarity(const Word1, Word2: string): Double;
var
  vecA, vecB: TDoubleArray;
  k, last: Integer;
  dotAB, sqA, sqB: Double;
begin
  Result := 0.0;

  vecA := GetEmbeddingWithCache(Word1);
  vecB := GetEmbeddingWithCache(Word2);
  if (Length(vecA) = 0) or (Length(vecB) = 0) then
    Exit;

  dotAB := 0.0;
  sqA := 0.0;
  sqB := 0.0;
  last := Min(High(vecA), High(vecB));
  for k := 0 to last do
  begin
    dotAB := dotAB + vecA[k] * vecB[k];
    sqA := sqA + vecA[k] * vecA[k];
    sqB := sqB + vecB[k] * vecB[k];
  end;

  if (sqA <= 0) or (sqB <= 0) then
    Exit;

  Result := dotAB / (Sqrt(sqA) * Sqrt(sqB));
end;

{ -------------------- Sorting helpers -------------------- }

// Sorts A[L..R] in place by Score, highest first, using a recursive
// quicksort with the middle element's score as pivot. Not stable.
// L and R are inclusive indices with L <= R.
procedure TWordEmbeddings.QuickSort(var A: TScoreArray; L, R: Integer);
var
  I, J: Integer;
  Pivot: Double;
begin
  I := L; J := R;
  Pivot := A[(L+R) shr 1].Score;
  repeat
    // Skip elements that are already on the correct side of the pivot
    // (larger scores to the left, smaller to the right).
    while A[I].Score > Pivot do Inc(I);
    while A[J].Score < Pivot do Dec(J);
    if I <= J then
    begin
      if I < J then SwapScore(A[I], A[J]);
      Inc(I); Dec(J);
    end;
  until I > J;
  if L < J then QuickSort(A, L, J);
  if I < R then QuickSort(A, I, R);
end;

// Merges the two adjacent runs A[L..M] and A[M+1..R], each already sorted by
// Score in descending order, into one descending run in A[L..R].
// On equal scores the element from the left run comes first.
// Temp is scratch space and must hold at least R - L + 1 elements.
procedure TWordEmbeddings.MergeArrays(var A: TScoreArray; L, M, R: Integer; var Temp: TScoreArray);
var
  I, J, K, N: Integer;
begin
  I := L; J := M+1; K := 0;
  while (I <= M) and (J <= R) do
  begin
    if A[I].Score >= A[J].Score then
    begin
      Temp[K] := A[I]; Inc(I);
    end else
    begin
      Temp[K] := A[J]; Inc(J);
    end;
    Inc(K);
  end;
  while I <= M do begin Temp[K] := A[I]; Inc(I); Inc(K); end;
  while J <= R do begin Temp[K] := A[J]; Inc(J); Inc(K); end;

  // Copy back element by element. TScore contains a string, so the records
  // must be assigned (which maintains string reference counts) rather than
  // block-copied with Move.
  for N := 0 to K - 1 do
    A[L + N] := Temp[N];
end;

// Sorts A[L..R] by Score, highest first. Ranges with R - L above
// MERGE_THRESHOLD are split in half, sorted recursively and merged; smaller
// ranges go straight to QuickSort. Temp is scratch space for the merges.
procedure TWordEmbeddings.MergeSort(var A: TScoreArray; L, R: Integer; var Temp: TScoreArray);
var
  mid: Integer;
begin
  if R - L > MERGE_THRESHOLD then
  begin
    mid := (L + R) shr 1;
    MergeSort(A, L, mid, Temp);
    MergeSort(A, mid + 1, R, Temp);
    MergeArrays(A, L, mid, R, Temp);
  end
  else
    QuickSort(A, L, R);
end;

{ -------------------- MostSimilar -------------------- }

// Returns up to TopN vocabulary words ranked by similarity to Word, best
// first.
//
// Every vocabulary entry is scored as dot(target, entry) / (norm(entry) + eps).
// The target's own norm is left out because it is the same for all entries
// and does not change the order. The whole vocabulary is ranked, so Word
// itself takes part in the ranking.
//
// Returns an empty array when TopN <= 0 or when Word is blank or unknown.
function TWordEmbeddings.MostSimilar(const Word: string; TopN: Integer): TStringArray;
var
  target: TDoubleArray;
  i, n, take: Integer;
  scores: TScoreArray;
  temp: TScoreArray;
begin
  Result := nil;

  // A non-positive request yields nothing; this also keeps a negative TopN
  // from reaching SetLength below.
  if TopN <= 0 then Exit;

  target := GetEmbeddingWithCache(Word);
  if Length(target) = 0 then Exit;

  n := FVocab.Count;
  if n = 0 then Exit;

  SetLength(scores, n);
  for i := 0 to n-1 do
  begin
    scores[i].Word := FVocab[i];
    scores[i].Score := FastSimilarityScore(target, FEmbeddings[i]) / (FEmbeddingNorms[i] + 1e-12);
  end;

  // Small vocabularies are sorted directly; larger ones use the merge sort,
  // which needs a scratch array of the same size.
  if n <= MERGE_THRESHOLD then
    QuickSort(scores, 0, n-1)
  else
  begin
    SetLength(temp, n);
    MergeSort(scores, 0, n-1, temp);
    SetLength(temp, 0);
  end;

  take := Min(TopN, n);
  SetLength(Result, take);
  for i := 0 to take-1 do Result[i] := scores[i].Word;
end;

{ -------------------- Utility -------------------- }

// Public entry point for emptying the embedding cache; the capacity is kept.
procedure TWordEmbeddings.ClearCache;
begin
  Self.CacheClear;
end;

// Returns a one-line summary with the cache capacity and the number of
// entries currently stored.
function TWordEmbeddings.GetCacheStats: string;
const
  StatsLayout = 'Cache: capacity=%d count=%d';
begin
  Result := Format(StatsLayout, [FCacheCapacity, FCacheCount]);
end;

end.