- Notifications
You must be signed in to change notification settings - Fork 359
/
Copy pathEmbedding.cs
122 lines (100 loc) · 4.27 KB
/
Embedding.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Numerics.Tensors;
using System.Runtime.InteropServices;
using System.Text.Json;
using System.Text.Json.Serialization;

#pragma warning disable IDE0130 // first class concept we want to have readily available
#pragma warning disable CA2225 // no need for explicit methods

// ReSharper disable once CheckNamespace - reduce number of "using" statements
namespace Microsoft.KernelMemory;

/// <summary>
/// A vector of single-precision floats, stored as a <see cref="ReadOnlyMemory{T}"/>.
/// Note: use Embedding.JsonConverter to serialize objects using this type.
/// Example:
///   [JsonPropertyName("vector")]
///   [JsonConverter(typeof(Embedding.JsonConverter))]
///   public Embedding Vector { get; set; }
/// </summary>
public struct Embedding : IEquatable<Embedding>
{
    /// <summary>
    /// The vector values.
    /// Note: use Embedding.JsonConverter to serialize objects using this type.
    /// </summary>
    [JsonIgnore]
    public ReadOnlyMemory<float> Data { get; set; } = new();

    /// <summary>
    /// Number of values in <see cref="Data"/>.
    /// Note: use Embedding.JsonConverter to serialize objects using this type.
    /// </summary>
    [JsonIgnore]
    public readonly int Length => this.Data.Length;

    /// <summary>
    /// Create an embedding over the given array.
    /// </summary>
    /// <param name="vector">Values of the embedding.</param>
    public Embedding(float[] vector)
    {
        this.Data = vector;
    }

    /// <summary>
    /// This is not a ctor on purpose so we can use collections syntax with
    /// the main ctor, and surface the extra casting cost when not using floats.
    /// </summary>
    /// <param name="vector">Values to convert; each one is cast to <see cref="float"/>.</param>
    /// <returns>A new embedding holding the converted values.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="vector"/> is null.</exception>
    public static Embedding FromDoubles(double[] vector)
    {
        if (vector is null)
        {
            throw new ArgumentNullException(nameof(vector));
        }

        float[] f = new float[vector.Length];
        for (int i = 0; i < vector.Length; i++)
        {
            f[i] = (float)vector[i];
        }

        return new Embedding(f);
    }

    /// <summary>
    /// Create an embedding over the given memory region.
    /// </summary>
    /// <param name="vector">Values of the embedding.</param>
    public Embedding(ReadOnlyMemory<float> vector)
    {
        this.Data = vector;
    }

    /// <summary>
    /// Create an embedding backed by a newly allocated array of the given size.
    /// </summary>
    /// <param name="size">Number of values in the embedding.</param>
    public Embedding(int size)
    {
        this.Data = new ReadOnlyMemory<float>(new float[size]);
    }

    /// <summary>
    /// Calculate the cosine similarity between this embedding and another one.
    /// </summary>
    /// <param name="embedding">The embedding to compare with.</param>
    /// <returns>The value computed by <c>TensorPrimitives.CosineSimilarity</c>.</returns>
    /// <exception cref="InvalidOperationException">The two embeddings have different lengths.</exception>
    public readonly double CosineSimilarity(Embedding embedding)
    {
        var size1 = this.Data.Span.Length;
        var size2 = embedding.Data.Span.Length;
        if (size1 != size2)
        {
            throw new InvalidOperationException(
                "Embedding vectors must have the same length to calculate cosine similarity. " +
                $"Embedding 1 length: {size1}; Embedding 2 length: {size2}.");
        }

        return TensorPrimitives.CosineSimilarity(this.Data.Span, embedding.Data.Span);
    }

    /// <summary>
    /// Convert Semantic Kernel data type
    /// </summary>
    public static implicit operator Embedding(ReadOnlyMemory<float> data) => new(data);

    /// <summary>
    /// Allows simple embedding definition using float[]
    /// </summary>
    public static implicit operator Embedding(float[] data) => new(data);

    /// <summary>
    /// Compare two embeddings by delegating to <see cref="ReadOnlyMemory{T}"/> equality,
    /// which checks that both refer to the same memory region, not that the values match.
    /// </summary>
    public readonly bool Equals(Embedding other) => this.Data.Equals(other.Data);

    /// <inheritdoc />
    public override readonly bool Equals(object? obj) => (obj is Embedding other && this.Equals(other));

    /// <summary>Equality operator, same semantics as <see cref="Equals(Embedding)"/>.</summary>
    public static bool operator ==(Embedding v1, Embedding v2) => v1.Equals(v2);

    /// <summary>Inequality operator, negation of the equality operator.</summary>
    public static bool operator !=(Embedding v1, Embedding v2) => !(v1 == v2);

    /// <summary>Hash code of <see cref="Data"/>, consistent with <see cref="Equals(Embedding)"/>.</summary>
    public override readonly int GetHashCode() => this.Data.GetHashCode();

    /// <summary>
    /// Note: use Embedding.JsonConverter to serialize objects using
    /// the Embedding type, for example:
    ///   [JsonPropertyName("vector")]
    ///   [JsonConverter(typeof(Embedding.JsonConverter))]
    ///   public Embedding Vector { get; set; }
    /// </summary>
    public sealed class JsonConverter : JsonConverter<Embedding>
    {
        /// <summary>An instance of a converter for float[] that all operations delegate to</summary>
        private static readonly JsonConverter<float[]> s_converter =
            (JsonConverter<float[]>)new JsonSerializerOptions().GetConverter(typeof(float[]));

        /// <summary>
        /// Read a JSON array of numbers into an embedding; a null result from the
        /// float[] converter becomes an empty embedding.
        /// </summary>
        public override Embedding Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
        {
            return new Embedding(s_converter.Read(ref reader, typeof(float[]), options) ?? []);
        }

        /// <summary>
        /// Write the embedding values as a JSON array of numbers.
        /// </summary>
        public override void Write(Utf8JsonWriter writer, Embedding value, JsonSerializerOptions options)
        {
            // The backing array can be handed to the float[] converter only when the
            // memory covers all of it. The segment's Count always equals Data.Length,
            // so the offset and the array's own length are what must be checked;
            // otherwise a slice would be serialized as the whole underlying array.
            float[] values =
                MemoryMarshal.TryGetArray(value.Data, out ArraySegment<float> segment)
                && segment.Array is { } backing
                && segment.Offset == 0
                && backing.Length == segment.Count
                    ? backing
                    : value.Data.ToArray();

            s_converter.Write(writer, values, options);
        }
    }
}