// Copyright (c) Free Mind Labs, Inc. All rights reserved.
using System.Runtime.CompilerServices;
using Elastic.Clients.Elasticsearch;
using Elastic.Clients.Elasticsearch.Mapping;
using Elastic.Clients.Elasticsearch.QueryDsl;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.MemoryStorage;
namespace FreeMindLabs.KernelMemory.Elasticsearch;
/// <summary>
/// Elasticsearch connector for Kernel Memory.
/// </summary>
public class ElasticsearchMemory : IMemoryDb
{
private readonly ITextEmbeddingGenerator _embeddingGenerator;
private readonly IIndexNameHelper _indexNameHelper;
private readonly ElasticsearchConfig _config;
private readonly ILogger _log;
private readonly ElasticsearchClient _client;
/// <summary>
/// Create a new instance of Elasticsearch KM connector
/// </summary>
/// <param name="config">Elasticsearch configuration</param>
/// <param name="client">Elasticsearch client</param>
/// <param name="embeddingGenerator">Embedding generator</param>
/// <param name="indexNameHelper">Index name helper</param>
/// <param name="log">Application logger; the default logger is used when null</param>
/// <exception cref="ArgumentNullException">
/// Thrown when <paramref name="config"/>, <paramref name="client"/>,
/// <paramref name="embeddingGenerator"/> or <paramref name="indexNameHelper"/> is null.
/// </exception>
public ElasticsearchMemory(
    ElasticsearchConfig config,
    ElasticsearchClient client,
    ITextEmbeddingGenerator embeddingGenerator,
    IIndexNameHelper indexNameHelper,
    ILogger<ElasticsearchMemory>? log = null)
{
    this._embeddingGenerator = embeddingGenerator ?? throw new ArgumentNullException(nameof(embeddingGenerator));
    this._indexNameHelper = indexNameHelper ?? throw new ArgumentNullException(nameof(indexNameHelper));
    this._config = config ?? throw new ArgumentNullException(nameof(config));

    // Fail fast here instead of with a NullReferenceException on the first request.
    this._client = client ?? throw new ArgumentNullException(nameof(client));

    this._log = log ?? DefaultLogger<ElasticsearchMemory>.Instance;
}
/// <inheritdoc />
/// <remarks>
/// Returns immediately when the index already exists. Otherwise the index is created with the
/// configured shard and replica counts, then the field mappings are applied. The dense-vector
/// field is sized from <paramref name="vectorSize"/>.
/// </remarks>
public async Task CreateIndexAsync(
    string index,
    int vectorSize,
    CancellationToken cancellationToken = default)
{
    // Dimension count that used to be hard-coded for every index; kept only as a
    // fallback for callers passing a non-positive size.
    const int DefaultDimensions = 1536;

    index = this._indexNameHelper.Convert(index);

    var existsResponse = await this._client.Indices.ExistsAsync(index, cancellationToken).ConfigureAwait(false);
    if (existsResponse.Exists)
    {
        this._log.LogTrace("{MethodName}: Index {Index} already exists.", nameof(CreateIndexAsync), index);
        return;
    }

    int dimensions = vectorSize > 0 ? vectorSize : DefaultDimensions;

    var createIdxResponse = await this._client.Indices.CreateAsync(index,
        cfg =>
        {
            cfg.Settings(setts =>
            {
                setts.NumberOfShards(this._config.ShardCount);
                setts.NumberOfReplicas(this._config.ReplicaCount);
            });
        },
        cancellationToken).ConfigureAwait(false);

    if (!createIdxResponse.IsSuccess())
    {
        // The mapping is still attempted below, as before: creation may have failed only
        // because another caller created the index in the meantime.
        this._log.LogError("{MethodName}: Index {Index} creation failed.", nameof(CreateIndexAsync), index);
    }

    // Tags are stored as nested name/value pairs so that a name and its value are matched together.
    var np = new NestedProperty()
    {
        Properties = new Properties()
        {
            { ElasticsearchTag.NameField, new KeywordProperty() },
            { ElasticsearchTag.ValueField, new KeywordProperty() }
        }
    };

    var mapResponse = await this._client.Indices.PutMappingAsync<ElasticsearchMemoryRecord>(index, mapping => mapping
        .Properties(propDesc =>
        {
            propDesc.Keyword(rec => rec.Id);
            propDesc.Nested(ElasticsearchMemoryRecord.TagsField, np);
            propDesc.Text(rec => rec.Payload, pd => pd.Index(false));
            propDesc.Text(rec => rec.Content);
            propDesc.DenseVector(rec => rec.Vector, d => d.Index(true).Dims(dimensions).Similarity("cosine"));

            // Optional hook letting the configuration add or adjust mappings.
            this._config.ConfigureProperties?.Invoke(propDesc);
        }),
        cancellationToken).ConfigureAwait(false);

    if (!mapResponse.IsSuccess())
    {
        this._log.LogError("{MethodName}: Index {Index} mapping failed.", nameof(CreateIndexAsync), index);
        return;
    }

    this._log.LogTrace("{MethodName}: Index {Index} created.", nameof(CreateIndexAsync), index);
}
/// <inheritdoc />
/// <remarks>
/// Queries Elasticsearch for every index whose name starts with the configured prefix and
/// returns the names with that leading prefix removed. Names are de-duplicated ignoring case.
/// </remarks>
public async Task<IEnumerable<string>> GetIndexesAsync(
    CancellationToken cancellationToken = default)
{
    string prefix = this._config.IndexPrefix ?? string.Empty;

    var resp = await this._client.Indices.GetAsync(prefix + "*", cancellationToken).ConfigureAwait(false);

    // Strip only the leading prefix. string.Replace would also remove later occurrences
    // inside the name, and throws ArgumentException when the prefix is empty.
    var names = resp.Indices
        .Select(x => x.Key.ToString())
        .Select(name => prefix.Length > 0 && name.StartsWith(prefix, StringComparison.Ordinal)
            ? name.Substring(prefix.Length)
            : name)
        .ToHashSet(StringComparer.OrdinalIgnoreCase);

    this._log.LogTrace("{MethodName}: Returned {IndexCount} indices: {Indices}.", nameof(GetIndexesAsync), names.Count, string.Join(", ", names));

    return names;
}
/// <inheritdoc />
/// <remarks>
/// A failed deletion is logged as a warning; no exception is raised by this method itself.
/// </remarks>
public async Task DeleteIndexAsync(
    string index,
    CancellationToken cancellationToken = default)
{
    // Map the caller-supplied name through the index name helper before talking to Elasticsearch.
    index = this._indexNameHelper.Convert(index);

    var outcome = await this._client.Indices
        .DeleteAsync(index, cancellationToken)
        .ConfigureAwait(false);

    if (!outcome.IsSuccess())
    {
        this._log.LogWarning("{MethodName}: Index {Index} delete failed.", nameof(DeleteIndexAsync), index);
        return;
    }

    this._log.LogTrace("{MethodName}: Index {Index} deleted.", nameof(DeleteIndexAsync), index);
}
/// <inheritdoc />
/// <remarks>
/// The delete request waits for an index refresh (<c>Refresh.WaitFor</c>). A failed deletion
/// is logged as a warning.
/// </remarks>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="record"/> is null.</exception>
public async Task DeleteAsync(
    string index,
    MemoryRecord record,
    CancellationToken cancellationToken = default)
{
    index = this._indexNameHelper.Convert(index);

    if (record == null)
    {
        throw new ArgumentNullException(nameof(record));
    }

    var outcome = await this._client
        .DeleteAsync(
            index,
            record.Id,
            request =>
            {
                request.Refresh(Refresh.WaitFor);
            },
            cancellationToken)
        .ConfigureAwait(false);

    if (!outcome.IsSuccess())
    {
        this._log.LogWarning("{MethodName}: Record {RecordId} delete failed.", nameof(DeleteAsync), record.Id);
        return;
    }

    this._log.LogTrace("{MethodName}: Record {RecordId} deleted.", nameof(DeleteAsync), record.Id);
}
/// <inheritdoc />
/// <remarks>
/// Sent as an update with "doc as upsert", so the document is created when it does not exist.
/// The request waits for an index refresh (<c>Refresh.WaitFor</c>). A failed upsert is logged
/// as an error and the id reported by the response is still returned.
/// </remarks>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="record"/> is null.</exception>
public async Task<string> UpsertAsync(
    string index,
    MemoryRecord record,
    CancellationToken cancellationToken = default)
{
    index = this._indexNameHelper.Convert(index);

    // Same argument validation as DeleteAsync.
    if (record == null)
    {
        throw new ArgumentNullException(nameof(record));
    }

    var memRec = ElasticsearchMemoryRecord.FromMemoryRecord(record);

    var response = await this._client.UpdateAsync<ElasticsearchMemoryRecord, ElasticsearchMemoryRecord>(
        index,
        memRec.Id,
        (updateReq) =>
        {
            updateReq.Refresh(Refresh.WaitFor);
            updateReq.Doc(memRec);
            updateReq.DocAsUpsert(true);
        },
        cancellationToken)
        .ConfigureAwait(false);

    if (response.IsSuccess())
    {
        this._log.LogTrace("{MethodName}: Record {RecordId} upserted.", nameof(UpsertAsync), memRec.Id);
    }
    else
    {
        this._log.LogError("{MethodName}: Record {RecordId} upsert failed.", nameof(UpsertAsync), memRec.Id);
    }

    return response.Id;
}
/// <inheritdoc />
/// <remarks>
/// Generates an embedding for <paramref name="text"/> and runs a kNN search on the vector
/// field, restricted by <paramref name="filters"/>. The relevance returned for each record is
/// the Elasticsearch hit score (0 when the hit has none); hits scoring below
/// <paramref name="minRelevance"/> are skipped. A non-positive <paramref name="limit"/> falls
/// back to 10 results.
/// </remarks>
public async IAsyncEnumerable<(MemoryRecord, double)> GetSimilarListAsync(
    string index,
    string text,
    ICollection<MemoryFilter>? filters = null,
    double minRelevance = 0,
    int limit = 1,
    bool withEmbeddings = false,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    // Zero is treated like a negative limit: it is not a usable value for kNN "k".
    if (limit <= 0)
    {
        limit = 10;
    }

    index = this._indexNameHelper.Convert(index);

    this._log.LogTrace("{MethodName}: Searching for '{Text}' on index '{IndexName}' with filters {Filters}. {MinRelevance} {Limit} {WithEmbeddings}",
        nameof(GetSimilarListAsync), text, index, filters.ToDebugString(), minRelevance, limit, withEmbeddings);

    Embedding qembed = await this._embeddingGenerator.GenerateEmbeddingAsync(text, cancellationToken).ConfigureAwait(false);
    var coll = qembed.Data.ToArray();

    var resp = await this._client.SearchAsync<ElasticsearchMemoryRecord>(s =>
        s.Index(index)
         .Knn(qd =>
         {
             qd.k(limit)
               .Filter(q => this.ConvertTagFilters(q, filters))
               .NumCandidates(limit + 100)
               .Field(x => x.Vector)
               .QueryVector(coll);
         }),
        cancellationToken)
        .ConfigureAwait(false);

    if ((resp.HitsMetadata is null) || (resp.HitsMetadata.Hits is null))
    {
        this._log.LogWarning("The search returned a null result. Should retry?");
        yield break;
    }

    foreach (var hit in resp.HitsMetadata.Hits)
    {
        if (hit?.Source == null)
        {
            continue;
        }

        double score = hit.Score ?? 0;

        // Honour the caller's relevance threshold; previously the parameter was ignored.
        if (score < minRelevance)
        {
            continue;
        }

        this._log.LogTrace("{MethodName} Hit: {HitScore}, {HitId}", nameof(GetSimilarListAsync), hit.Score, hit.Id);
        yield return (hit.Source.ToMemoryRecord(), score);
    }
}
/// <inheritdoc />
/// <remarks>
/// Runs a plain (non-vector) search restricted by <paramref name="filters"/> and returns at
/// most <paramref name="limit"/> records. A non-positive <paramref name="limit"/> falls back
/// to 10 results, matching <c>GetSimilarListAsync</c>.
/// </remarks>
public async IAsyncEnumerable<MemoryRecord> GetListAsync(
    string index,
    ICollection<MemoryFilter>? filters = null,
    int limit = 1,
    bool withEmbeddings = false,
    [EnumeratorCancellation]
    CancellationToken cancellationToken = default)
{
    this._log.LogTrace("{MethodName}: querying index '{IndexName}' with filters {Filters}. {Limit} {WithEmbeddings}",
        nameof(GetListAsync), index, filters.ToDebugString(), limit, withEmbeddings);

    // Zero is treated like a negative limit so both list methods behave the same.
    if (limit <= 0)
    {
        limit = 10;
    }

    index = this._indexNameHelper.Convert(index);

    var resp = await this._client.SearchAsync<ElasticsearchMemoryRecord>(s =>
        s.Index(index)
         .Size(limit)
         .Query(qd =>
         {
             this.ConvertTagFilters(qd, filters);
         }),
        cancellationToken)
        .ConfigureAwait(false);

    if ((resp.HitsMetadata is null) || (resp.HitsMetadata.Hits is null))
    {
        yield break;
    }

    // Iterate the same collection that was just null-checked.
    foreach (var hit in resp.HitsMetadata.Hits)
    {
        if (hit?.Source == null)
        {
            continue;
        }

        this._log.LogTrace("{MethodName} Hit: {HitScore}, {HitId}", nameof(GetListAsync), hit.Score, hit.Id);
        yield return hit.Source.ToMemoryRecord();
    }
}
/// <summary>
/// Translates Kernel Memory tag filters into an Elasticsearch query on <paramref name="qd"/>.
/// </summary>
/// <remarks>
/// With no filters, or only filters without tags, the query matches all documents.
/// Otherwise each tag of a filter becomes a nested query on the tags field: the tag name must
/// match and the tag value must be one of the listed values. All tag conditions of one filter
/// must hold (bool/must), and a document matches when at least one filter matches
/// (bool/should with a minimum of one), following Kernel Memory's convention that multiple
/// filters are alternatives.
/// </remarks>
/// <param name="qd">Query descriptor to configure.</param>
/// <param name="filters">Tag filters to apply; may be null or empty.</param>
/// <returns>The same <paramref name="qd"/> instance, to allow chaining.</returns>
private QueryDescriptor<ElasticsearchMemoryRecord> ConvertTagFilters(
    QueryDescriptor<ElasticsearchMemoryRecord> qd,
    ICollection<MemoryFilter>? filters = null)
{
    // Filters without any tag carry no constraint, so they are dropped up front.
    var activeFilters = filters?.Where(f => f.Keys.Count > 0).ToList();

    if ((activeFilters == null) || (activeFilters.Count == 0))
    {
        qd.MatchAll();
        return qd;
    }

    List<Query> anyOf = new();

    foreach (MemoryFilter filter in activeFilters)
    {
        List<Query> allOf = new();

        foreach (var tagName in filter.Keys)
        {
            List<FieldValue> terms = filter[tagName]
                .Select(x => (FieldValue)(x ?? FieldValue.Null))
                .ToList();

            // Name and value are checked inside the same nested document so that a value
            // is only accepted when it belongs to this tag name.
            Query tagQuery = new TermQuery(ElasticsearchMemoryRecord.Tags_Name) { Value = tagName };
            tagQuery &= new TermsQuery()
            {
                Field = ElasticsearchMemoryRecord.Tags_Value,
                Terms = new TermsQueryField(terms)
            };

            allOf.Add(new NestedQuery()
            {
                Path = ElasticsearchMemoryRecord.TagsField,
                Query = tagQuery
            });
        }

        anyOf.Add(new BoolQuery() { Must = allOf.ToArray() });
    }

    // Set the descriptor exactly once. Calling qd.Bool(...) inside the loops overwrote the
    // previous query, so only the last filter was ever applied.
    qd.Bool(bq => bq.Should(anyOf.ToArray()).MinimumShouldMatch(1));

    return qd;
}
}