- Notifications
You must be signed in to change notification settings - Fork 3.8k
/
Copy pathtest_serialization.py
205 lines (169 loc) · 6.89 KB
/
test_serialization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# Copyright (c) Microsoft. All rights reserved.
importtypingast
importpytest
importtyping_extensionsaste
frompydanticimportField, Json
fromsemantic_kernel.contents.chat_historyimportChatHistory
fromsemantic_kernel.core_plugins.conversation_summary_pluginimportConversationSummaryPlugin
fromsemantic_kernel.core_plugins.http_pluginimportHttpPlugin
fromsemantic_kernel.core_plugins.math_pluginimportMathPlugin
fromsemantic_kernel.core_plugins.text_memory_pluginimportTextMemoryPlugin
fromsemantic_kernel.core_plugins.text_pluginimportTextPlugin
fromsemantic_kernel.core_plugins.time_pluginimportTimePlugin
fromsemantic_kernel.core_plugins.wait_pluginimportWaitPlugin
fromsemantic_kernel.core_plugins.web_search_engine_pluginimportWebSearchEnginePlugin
fromsemantic_kernel.functions.kernel_argumentsimportKernelArguments
fromsemantic_kernel.functions.kernel_functionimportKernelFunction
fromsemantic_kernel.functions.kernel_function_decoratorimportkernel_function
fromsemantic_kernel.functions.kernel_function_metadataimportKernelFunctionMetadata
fromsemantic_kernel.functions.kernel_parameter_metadataimportKernelParameterMetadata
fromsemantic_kernel.kernel_pydanticimportKernelBaseModel
fromsemantic_kernel.memory.null_memoryimportNullMemory
fromsemantic_kernel.memory.semantic_text_memory_baseimportSemanticTextMemoryBase
fromsemantic_kernel.template_engine.blocks.blockimportBlock
fromsemantic_kernel.template_engine.blocks.block_typesimportBlockTypes
fromsemantic_kernel.template_engine.blocks.code_blockimportCodeBlock
fromsemantic_kernel.template_engine.blocks.function_id_blockimportFunctionIdBlock
fromsemantic_kernel.template_engine.blocks.named_arg_blockimportNamedArgBlock
fromsemantic_kernel.template_engine.blocks.text_blockimportTextBlock
fromsemantic_kernel.template_engine.blocks.val_blockimportValBlock
fromsemantic_kernel.template_engine.blocks.var_blockimportVarBlock
# Type variable used to annotate the parametrized ``kernel_type`` arguments below.
# NOTE(review): the bound is KernelBaseModel, but the parametrized types also
# include enums, plugin classes and memory base classes; confirm whether those
# subclass KernelBaseModel, otherwise this bound is only nominal for them.
KernelBaseModelFieldT = t.TypeVar("KernelBaseModelFieldT", bound=KernelBaseModel)


class _Serializable(t.Protocol):
    """A serializable object.

    Structural description of the serialization surface used by the tests in
    this module. ``assert_serializable`` relies on the Pydantic v2 methods
    ``model_dump_json``, ``model_dump`` and ``model_validate_json``, so they are
    declared here. The legacy ``json`` / ``parse_raw`` members are kept so that
    any existing references to this protocol keep type-checking.
    """

    def model_dump_json(self) -> str:
        """Return a JSON string representation of the object."""
        raise NotImplementedError

    def model_dump(self) -> dict[str, t.Any]:
        """Return a Python ``dict`` representation of the object."""
        raise NotImplementedError

    @classmethod
    def model_validate_json(cls: t.Type[te.Self], json_data: str) -> te.Self:
        """Return the object constructed from a JSON string."""
        raise NotImplementedError

    def json(self) -> Json:
        """Return a JSON representation of the object (legacy Pydantic v1 name)."""
        raise NotImplementedError

    @classmethod
    def parse_raw(cls: t.Type[te.Self], json: Json) -> te.Self:
        """Return the constructed object from a JSON representation (legacy Pydantic v1 name)."""
        raise NotImplementedError
@pytest.fixture()
def kernel_factory() -> t.Callable[[t.Type[_Serializable]], _Serializable]:
    """Return a factory for various objects in semantic-kernel.

    The returned callable maps a supported class to a sample instance of that
    class. Every sample is built once, when the fixture is set up, so repeated
    calls for the same class within one test return the same object.
    """

    def create_kernel_function() -> KernelFunction:
        """Return a KernelFunction built from a decorated Python function."""

        @kernel_function(name="function")
        def my_function(arguments: KernelArguments) -> str:
            return f"F({arguments['input']})"

        return KernelFunction.from_method(
            plugin_name="plugin",
            method=my_function,
        )

    def create_chat_history() -> ChatHistory:
        """Return a ChatHistory constructed with default arguments."""
        return ChatHistory()

    cls_obj_map = {
        Block: Block(content="foo"),
        CodeBlock: CodeBlock(content="foo"),
        FunctionIdBlock: FunctionIdBlock(content="foo.bar"),
        TextBlock: TextBlock(content="baz"),
        ValBlock: ValBlock(content="'qux'"),
        VarBlock: VarBlock(content="$quux"),
        NamedArgBlock: NamedArgBlock(content="foo='bar'"),
        # NOTE(review): type_ is "string" here but "str" in the nested parameter
        # below, while both schemas are inferred from "str"; confirm which is intended.
        KernelParameterMetadata: KernelParameterMetadata(
            name="foo",
            description="bar",
            default_value="baz",
            type_="string",
            is_required=True,
            schema_data=KernelParameterMetadata.infer_schema(None, "str", "baz", "bar"),
        ),
        KernelFunctionMetadata: KernelFunctionMetadata(
            name="foo",
            plugin_name="bar",
            description="baz",
            parameters=[
                KernelParameterMetadata(
                    name="qux",
                    description="bar",
                    default_value="baz",
                    type_="str",
                    schema_data=KernelParameterMetadata.infer_schema(None, "str", "baz", "bar"),
                )
            ],
            is_prompt=True,
            is_asynchronous=False,
        ),
        ChatHistory: create_chat_history(),
        NullMemory: NullMemory(),
        KernelFunction: create_kernel_function(),
    }

    def constructor(cls: t.Type[_Serializable]) -> _Serializable:
        """Return the sample instance registered for ``cls``.

        Raises:
            KeyError: if no sample instance is registered for ``cls``.
        """
        try:
            return cls_obj_map[cls]
        except KeyError:
            # Re-raise with the offending class named, so a newly parametrized
            # type that was never registered above is easy to diagnose.
            raise KeyError(f"kernel_factory has no sample instance registered for {cls!r}") from None

    return constructor
# Core plugin classes. They are parametrized only into the optional-field test.
PROTOCOLS = [
    ConversationSummaryPlugin,
    HttpPlugin,
    MathPlugin,
    TextMemoryPlugin,
    TextPlugin,
    TimePlugin,
    WaitPlugin,
    WebSearchEnginePlugin,
]

# Base classes. They are parametrized only into the optional-field test.
BASE_CLASSES = [SemanticTextMemoryBase]

# Classes grouped as stateless. They are used in both the optional-field and
# the required-field tests.
STATELESS_CLASSES = [NullMemory]

# Enumerations. They are parametrized only into the optional-field test.
ENUMS = [BlockTypes]

# Pydantic models. They are used in both the optional-field and the
# required-field tests.
PYDANTIC_MODELS = [
    Block,
    CodeBlock,
    FunctionIdBlock,
    TextBlock,
    ValBlock,
    VarBlock,
    NamedArgBlock,
    KernelParameterMetadata,
    KernelFunctionMetadata,
    ChatHistory,
]

# KernelFunction as an optional field is expected to work.
KERNEL_FUNCTION_OPTIONAL = [KernelFunction]

# KernelFunction as a required field is currently expected to fail.
_pickle_xfail = pytest.mark.xfail(reason="Need to implement Pickle serialization.")
KERNEL_FUNCTION_REQUIRED = [pytest.param(KernelFunction, marks=_pickle_xfail)]
class TestUsageInPydanticFields:
    @pytest.mark.parametrize(
        "kernel_type",
        BASE_CLASSES + PROTOCOLS + ENUMS + PYDANTIC_MODELS + STATELESS_CLASSES + KERNEL_FUNCTION_OPTIONAL,
    )
    def test_usage_as_optional_field(
        self,
        kernel_type: t.Type[KernelBaseModelFieldT],
    ) -> None:
        """A Semantic Kernel type must be usable as an optional Pydantic field.

        A model declaring ``field: kernel_type | None = None`` is instantiated
        with its default and must survive a JSON round trip.
        """

        class OptionalHolder(KernelBaseModel):
            """Model holding an optional field of the parametrized type."""

            field: kernel_type | None = None

        assert_serializable(OptionalHolder(), OptionalHolder)

    @pytest.mark.parametrize("kernel_type", PYDANTIC_MODELS + STATELESS_CLASSES + KERNEL_FUNCTION_REQUIRED)
    def test_usage_as_required_field(
        self,
        kernel_factory: t.Callable[[t.Type[KernelBaseModelFieldT]], KernelBaseModelFieldT],
        kernel_type: t.Type[KernelBaseModelFieldT],
    ) -> None:
        """A Semantic Kernel type must be usable as a required Pydantic field.

        The field is populated once through its default factory and once by an
        explicit constructor argument. Both models must survive a JSON round trip.
        """

        def make_default():
            # Zero-argument default factory that pulls a sample from the fixture.
            return kernel_factory(kernel_type)

        class RequiredHolder(KernelBaseModel):
            """Model holding a field of the parametrized type, filled by a factory."""

            field: kernel_type = Field(default_factory=make_default)

        from_default = RequiredHolder()
        assert_serializable(from_default, RequiredHolder)

        sample = kernel_factory(kernel_type)
        from_explicit = RequiredHolder(field=sample)
        assert_serializable(from_explicit, RequiredHolder)
def assert_serializable(obj: _Serializable, obj_type) -> None:
    """Check that ``obj`` survives a JSON round trip through ``obj_type``.

    ``obj`` is dumped with ``model_dump_json``. The resulting string is parsed
    back with ``obj_type.model_validate_json``. The ``model_dump`` output of the
    restored object must equal that of the original.
    """
    assert obj is not None

    payload = obj.model_dump_json()
    assert isinstance(payload, str)

    round_tripped = obj_type.model_validate_json(payload)
    # Dump the restored object before the original, matching the original evaluation order.
    round_tripped_dump = round_tripped.model_dump()
    original_dump = obj.model_dump()
    assert round_tripped_dump == original_dump