- Notifications
You must be signed in to change notification settings - Fork 2.8k
/
Copy path base_audio_transcription_unit_tests.py
86 lines (75 loc) · 2.92 KB
/
base_audio_transcription_unit_tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import json
import os
import sys
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch

import httpx
import pytest

# Prepend the directory two levels above the *current working directory* to
# the import path so the local checkout of ``litellm`` wins over any installed
# copy. NOTE(review): ``os.path.abspath`` resolves "../.." against the CWD,
# not against this file, so this only works when pytest is launched from the
# expected directory.
sys.path.insert(0, os.path.abspath("../.."))

# These imports must stay below the sys.path tweak above.
import litellm
from litellm import transcription
from litellm.litellm_core_utils.get_supported_openai_params import (
    get_supported_openai_params,
)
from litellm.llms.base_llm.audio_transcription.transformation import (
    BaseAudioTranscriptionConfig,
)
from litellm.utils import ProviderConfigManager

# Directory containing this test module; the audio fixture lives next to it.
pwd = os.path.dirname(os.path.realpath(__file__))
print(pwd)

file_path = os.path.join(pwd, "gettysburg.wav")

# NOTE(review): opened once at import time and never closed in this file.
# Every reader shares this one stream, so its read position depends on
# whatever consumed it earlier.
audio_file = open(file_path, "rb")
class BaseLLMAudioTranscriptionTest(ABC):
    """Shared audio-transcription test suite for LLM providers.

    Concrete subclasses implement the two abstract hooks below: one for the
    provider-specific call arguments and one for the provider enum. The
    ``test_*`` methods defined here then run against that provider.
    """

    @abstractmethod
    def get_base_audio_transcription_call_args(self) -> dict:
        """Must return the base audio transcription call args.

        The returned dict is splatted into ``litellm.transcription`` and must
        contain at least a ``"model"`` key (read by the param/config tests).
        """
        pass

    @abstractmethod
    def get_custom_llm_provider(self) -> litellm.LlmProviders:
        """Must return the custom llm provider"""
        pass

    def test_audio_transcription(self):
        """
        Test that a transcription call against the provider returns text.

        A fresh handle to the ``gettysburg.wav`` fixture is opened for this
        call instead of reusing the module-level ``audio_file`` stream. That
        shared stream is used by every subclass in the process, so its read
        position depends on earlier calls. Opening per call guarantees the
        upload starts at byte 0, and the ``with`` block closes the handle
        afterwards.

        ``litellm.set_verbose`` is a process-wide flag. It is switched on for
        debugging output during this test and restored in ``finally`` so it
        does not leak into unrelated tests.
        """
        previous_verbose = litellm.set_verbose
        litellm.set_verbose = True
        try:
            transcription_call_args = self.get_base_audio_transcription_call_args()
            with open(file_path, "rb") as audio:
                transcript = transcription(**transcription_call_args, file=audio)
            print(f"transcript: {transcript.model_dump()}")
            print(f"transcript hidden params: {transcript._hidden_params}")
            assert transcript.text is not None
        finally:
            litellm.set_verbose = previous_verbose

    def test_audio_transcription_optional_params(self):
        """
        Test that the provider reports transcription-specific supported params.

        ``get_supported_openai_params`` is queried with
        ``request_type="transcription"``. The result must exist and must not
        be the chat-completion parameter list, which is detected by the
        presence of ``max_completion_tokens``.
        """
        transcription_args = self.get_base_audio_transcription_call_args()
        model = transcription_args["model"]
        custom_llm_provider = self.get_custom_llm_provider()
        optional_params = get_supported_openai_params(
            model=model,
            custom_llm_provider=custom_llm_provider.value,
            request_type="transcription",
        )
        print(f"optional_params: {optional_params}")
        assert optional_params is not None
        # A chat-only param showing up means the chat defaults leaked through.
        assert "max_completion_tokens" not in optional_params

    def test_audio_transcription_config(self):
        """
        Test that the audio transcription config is implemented and correctly instrumented.

        ``ProviderConfigManager`` must return a ``BaseAudioTranscriptionConfig``
        instance for this provider/model pair.
        """
        transcription_args = self.get_base_audio_transcription_call_args()
        model = transcription_args["model"]
        custom_llm_provider = self.get_custom_llm_provider()
        config = ProviderConfigManager.get_provider_audio_transcription_config(
            model=model,
            provider=custom_llm_provider,
        )
        print(f"config: {config}")
        assert config is not None
        assert isinstance(config, BaseAudioTranscriptionConfig)