43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
from pathlib import Path
from types import SimpleNamespace

from journal_bot.transcribe import Transcriber, TranscriptionResult
|
|
|
|
|
|
class FakeModel:
|
|
def __init__(self, segments_text: str | None = None, raises: Exception | None = None):
|
|
self.segments_text = segments_text
|
|
self.raises = raises
|
|
|
|
def transcribe(self, audio_path, language=None, beam_size=5):
|
|
if self.raises:
|
|
raise self.raises
|
|
seg = type("S", (), {"text": self.segments_text})()
|
|
return [seg], type("I", (), {"language": "de"})()
|
|
|
|
|
|
def test_transcribe_returns_text(tmp_path):
    """A model that succeeds yields a successful result carrying the model's text."""
    audio_file = tmp_path.joinpath("x.ogg")
    audio_file.write_bytes(b"fake")

    def working_model():
        return FakeModel("Hallo Welt")

    transcriber = Transcriber(model_factory=working_model)
    result = transcriber.transcribe(audio_file)

    assert result.success
    assert result.text == "Hallo Welt"
|
|
|
|
|
|
def test_transcribe_with_premium_fallback_when_model_fails(tmp_path):
    """When the model raises, the supplied premium transcript is returned instead."""
    audio_file = tmp_path.joinpath("x.ogg")
    audio_file.write_bytes(b"fake")

    def failing_model():
        return FakeModel(raises=RuntimeError("boom"))

    result = Transcriber(model_factory=failing_model).transcribe(
        audio_file, premium_text="Premium Transkript"
    )

    # The fallback must count as success, carry the premium text, and report its origin.
    assert result.success
    assert result.text == "Premium Transkript"
    assert result.source == "premium"
|
|
|
|
|
|
def test_transcribe_returns_failure_when_no_fallback(tmp_path):
    """Without a premium transcript, a failing model yields an unsuccessful, empty result."""
    audio_file = tmp_path.joinpath("x.ogg")
    audio_file.write_bytes(b"fake")

    def failing_model():
        return FakeModel(raises=RuntimeError("boom"))

    transcriber = Transcriber(model_factory=failing_model)
    outcome = transcriber.transcribe(audio_file)

    assert not outcome.success
    assert outcome.text == ""
|