feat(transcribe): faster-whisper with telegram premium fallback
This commit is contained in:
parent
37a1a58024
commit
3a9c298a2c
|
|
@ -0,0 +1,42 @@
|
|||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionResult:
|
||||
success: bool
|
||||
text: str
|
||||
source: str # "whisper" | "premium" | "none"
|
||||
|
||||
|
||||
class Transcriber:
|
||||
"""Wraps faster-whisper; injectable model_factory for tests."""
|
||||
|
||||
def __init__(self, model_factory: Optional[Callable[[], object]] = None,
|
||||
model_name: str = "large-v3", device: str = "cpu"):
|
||||
self._factory = model_factory
|
||||
self._model_name = model_name
|
||||
self._device = device
|
||||
self._model = None
|
||||
|
||||
def _get_model(self):
|
||||
if self._factory is not None:
|
||||
return self._factory()
|
||||
if self._model is None:
|
||||
from faster_whisper import WhisperModel
|
||||
self._model = WhisperModel(self._model_name, device=self._device)
|
||||
return self._model
|
||||
|
||||
def transcribe(self, audio_path: Path, premium_text: Optional[str] = None) -> TranscriptionResult:
|
||||
try:
|
||||
model = self._get_model()
|
||||
segments, _info = model.transcribe(str(audio_path), language="de", beam_size=5)
|
||||
text = " ".join(s.text.strip() for s in segments).strip()
|
||||
if text:
|
||||
return TranscriptionResult(True, text, "whisper")
|
||||
except Exception:
|
||||
pass
|
||||
if premium_text:
|
||||
return TranscriptionResult(True, premium_text, "premium")
|
||||
return TranscriptionResult(False, "", "none")
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
from pathlib import Path
|
||||
from journal_bot.transcribe import Transcriber, TranscriptionResult
|
||||
|
||||
|
||||
class FakeModel:
|
||||
def __init__(self, segments_text: str | None = None, raises: Exception | None = None):
|
||||
self.segments_text = segments_text
|
||||
self.raises = raises
|
||||
|
||||
def transcribe(self, audio_path, language=None, beam_size=5):
|
||||
if self.raises:
|
||||
raise self.raises
|
||||
seg = type("S", (), {"text": self.segments_text})()
|
||||
return [seg], type("I", (), {"language": "de"})()
|
||||
|
||||
|
||||
def test_transcribe_returns_text(tmp_path):
|
||||
audio = tmp_path / "x.ogg"
|
||||
audio.write_bytes(b"fake")
|
||||
t = Transcriber(model_factory=lambda: FakeModel("Hallo Welt"))
|
||||
res = t.transcribe(audio)
|
||||
assert res.success
|
||||
assert res.text == "Hallo Welt"
|
||||
|
||||
|
||||
def test_transcribe_with_premium_fallback_when_model_fails(tmp_path):
|
||||
audio = tmp_path / "x.ogg"
|
||||
audio.write_bytes(b"fake")
|
||||
t = Transcriber(model_factory=lambda: FakeModel(raises=RuntimeError("boom")))
|
||||
res = t.transcribe(audio, premium_text="Premium Transkript")
|
||||
assert res.success
|
||||
assert res.text == "Premium Transkript"
|
||||
assert res.source == "premium"
|
||||
|
||||
|
||||
def test_transcribe_returns_failure_when_no_fallback(tmp_path):
|
||||
audio = tmp_path / "x.ogg"
|
||||
audio.write_bytes(b"fake")
|
||||
t = Transcriber(model_factory=lambda: FakeModel(raises=RuntimeError("boom")))
|
||||
res = t.transcribe(audio)
|
||||
assert not res.success
|
||||
assert res.text == ""
|
||||
Loading…
Reference in New Issue