- Notifications
You must be signed in to change notification settings - Fork 2.8k
/
Copy pathbase_rerank_unit_tests.py
128 lines (105 loc) · 4.1 KB
/
base_rerank_unit_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
importasyncio
importhttpx
importjson
importpytest
importsys
fromtypingimportAny, Dict, List
fromunittest.mockimportMagicMock, Mock, patch
importos
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
importlitellm
fromlitellm.exceptionsimportBadRequestError
fromlitellm.llms.custom_httpx.http_handlerimportAsyncHTTPHandler, HTTPHandler
fromlitellm.utilsimport (
CustomStreamWrapper,
get_supported_openai_params,
get_optional_params,
)
# test_example.py
fromabcimportABC, abstractmethod
def assert_response_shape(response, custom_llm_provider):
    """
    Assert that a rerank ``response`` has the expected field types.

    For every provider:
      * ``response.id`` is a ``str``;
      * ``response.results`` is a ``list`` whose items are subscriptable with
        an ``int`` under ``"index"`` and a ``float`` under ``"relevance_score"``;
      * ``response.meta`` is a ``dict``.

    When ``custom_llm_provider == "cohere"`` the ``meta`` payload is checked
    further: ``meta["api_version"]`` must be a ``dict`` whose ``"version"`` is
    a ``str``, and ``meta["billed_units"]`` must be a ``dict`` holding an
    ``int`` under ``"total_tokens"`` if that key is present, otherwise an
    ``int`` under ``"search_units"``.

    Args:
        response: The rerank response object under test.
        custom_llm_provider: Provider name string (e.g. ``"cohere"``).

    Raises:
        AssertionError: If a checked field has the wrong type.
        KeyError: If a subscripted key is missing.
    """
    assert isinstance(response.id, str)
    assert isinstance(response.results, list)
    for item in response.results:
        assert isinstance(item["index"], int)
        assert isinstance(item["relevance_score"], float)
    assert isinstance(response.meta, dict)

    # Only the cohere response carries a meta payload we validate in detail.
    if custom_llm_provider != "cohere":
        return

    meta = response.meta

    api_version = meta["api_version"]
    assert isinstance(api_version, dict)
    assert isinstance(api_version["version"], str)

    billed_units = meta["billed_units"]
    assert isinstance(billed_units, dict)
    # Token-based billing takes precedence; fall back to search units.
    unit_key = "total_tokens" if "total_tokens" in billed_units else "search_units"
    assert isinstance(billed_units[unit_key], int)
class BaseLLMRerankTest(ABC):
    """
    Abstract base test class that enforces a common test across all test classes.

    Concrete subclasses implement the two abstract hooks below. Pytest then
    runs ``test_basic_rerank`` for each subclass, once in sync mode and once
    in async mode.
    """

    @abstractmethod
    def get_base_rerank_call_args(self) -> dict:
        """
        Must return the base rerank call args.

        They are passed as keyword arguments to the rerank call alongside
        ``query``, ``documents`` and ``top_n``.
        """

    @abstractmethod
    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        """
        Must return the custom llm provider.

        Its ``.value`` is passed to ``assert_response_shape`` to select
        provider-specific checks.
        """

    @pytest.mark.asyncio()
    @pytest.mark.parametrize("sync_mode", [True, False])
    async def test_basic_rerank(self, sync_mode):
        """
        Rerank two short documents through litellm and validate the response.

        Args:
            sync_mode: ``True`` calls ``litellm.rerank``; ``False`` awaits
                ``litellm.arerank``.
        """
        litellm._turn_on_debug()
        # Load the cost map with LITELLM_LOCAL_MODEL_COST_MAP set (presumably
        # this selects the bundled map instead of a remote fetch; confirm).
        # os.environ and ``litellm.model_cost`` are process/module globals.
        # Both are therefore patched only for this test and restored on exit,
        # instead of leaking into tests that run afterwards.
        with patch.dict(os.environ, {"LITELLM_LOCAL_MODEL_COST_MAP": "True"}):
            with patch.object(
                litellm, "model_cost", litellm.get_model_cost_map(url="")
            ):
                await self._check_basic_rerank(sync_mode)

    async def _check_basic_rerank(self, sync_mode):
        """Issue the rerank call selected by ``sync_mode`` and assert on the response."""
        rerank_call_args = self.get_base_rerank_call_args()
        custom_llm_provider = self.get_custom_llm_provider()
        request_kwargs = {
            "query": "hello",
            "documents": ["hello", "world"],
            "top_n": 2,
        }

        if sync_mode:
            response = litellm.rerank(**rerank_call_args, **request_kwargs)
            print("re rank response: ", response)
        else:
            response = await litellm.arerank(**rerank_call_args, **request_kwargs)
            print("async re rank response: ", response)

        assert response.id is not None
        assert response.results is not None
        if sync_mode:
            # NOTE(review): response cost is asserted only on the sync path (as
            # in the original test). Confirm whether async should check it too.
            assert response._hidden_params["response_cost"] is not None
            assert response._hidden_params["response_cost"] > 0

        assert_response_shape(
            response=response, custom_llm_provider=custom_llm_provider.value
        )