"""Synthesize the full narration in ONE batched Qwen3-TTS call.

Reads ``output/narration-script.json`` (emitted by ``dist/preflight.js``) and
runs ``Qwen3TTSModel.generate_custom_voice`` with all cue texts as a single
batched list — that way every cue shares the same model state, which keeps
prosody and timbre consistent across cues. Per-cue WAVs and an index manifest
go to ``output/audio/`` for the recording step (which reads measured cue
durations) and the mux step (which drops each WAV at its videoTime).

Run from the ``video/`` directory:

    uv run --project tts python tts/synth.py
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import soundfile as sf
|
|
import torch
|
|
from qwen_tts import Qwen3TTSModel
|
|
|
|
|
|
# Built-in fallbacks for --model / --speaker / --language. parse_args() uses
# them only when neither the flag nor the matching TTS_MODEL / TTS_SPEAKER /
# TTS_LANGUAGE environment variable is supplied.
DEFAULT_MODEL = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"
DEFAULT_SPEAKER = "ryan"
DEFAULT_LANGUAGE = "English"
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse the synth options.

    The model, speaker, language and device defaults are taken from the
    ``TTS_MODEL``, ``TTS_SPEAKER``, ``TTS_LANGUAGE`` and ``TTS_DEVICE``
    environment variables when they are set; an explicit flag always wins.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, so existing ``parse_args()`` callers are
            unaffected; passing a list makes the parser usable without
            touching process state.

    Returns:
        Namespace with ``script`` and ``out_dir`` (``Path``), ``model``,
        ``speaker``, ``language`` and ``device`` (``str``), and
        ``list_speakers`` (``bool``).
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring has paragraphs and an indented command line;
        # the default formatter would reflow all of it into one block.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--script",
        type=Path,
        default=Path("output/narration-script.json"),
        help="Narration script emitted by dist/preflight.js.",
    )
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=Path("output/audio"),
        help="Directory to write WAV files and index.json into.",
    )
    parser.add_argument(
        "--model",
        default=os.environ.get("TTS_MODEL", DEFAULT_MODEL),
        help="Model id handed to from_pretrained (env: TTS_MODEL; default: %(default)s).",
    )
    parser.add_argument(
        "--speaker",
        default=os.environ.get("TTS_SPEAKER", DEFAULT_SPEAKER),
        help="CustomVoice preset speaker name (use --list-speakers to enumerate).",
    )
    parser.add_argument(
        "--language",
        default=os.environ.get("TTS_LANGUAGE", DEFAULT_LANGUAGE),
        help="Language passed to the model for every cue (env: TTS_LANGUAGE; default: %(default)s).",
    )
    parser.add_argument(
        "--device",
        default=os.environ.get("TTS_DEVICE", "cuda:0"),
        help="Device the model is loaded on, e.g. cuda:0 or cpu (env: TTS_DEVICE; default: %(default)s).",
    )
    parser.add_argument(
        "--list-speakers",
        action="store_true",
        help="Load the model, print available speaker names, and exit.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def load_model(model_id: str, device: str) -> Qwen3TTSModel:
    """Load ``model_id`` onto ``device`` and return the model.

    Devices whose name starts with ``"cuda"`` get ``torch.bfloat16`` weights;
    every other device gets ``torch.float32``. The chosen combination is
    logged before loading starts.
    """
    if device.startswith("cuda"):
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32
    print(f"[synth] loading {model_id} on {device} ({weight_dtype})", flush=True)
    tts_model = Qwen3TTSModel.from_pretrained(
        model_id,
        device_map=device,
        dtype=weight_dtype,
    )
    return tts_model
|
|
|
|
|
|
def cached_index_matches(
    index_path: Path,
    cues: list[dict],
    speaker: str,
    language: str,
) -> bool:
    """Return True iff index_path's cue list lines up with `cues` 1:1.

    Compared fields: ``cueIndex``, ``text`` (whitespace-stripped),
    ``gapBeforeMs`` plus the synth settings (``speaker``, ``language``). All
    cue WAV files must also exist next to the manifest. Mismatched length,
    reordered cues, or a missing WAV invalidate the cache.

    A manifest that is missing, unreadable, not valid UTF-8 JSON, or not
    shaped like ``{"items": [{...}, ...]}`` is treated as a cache miss rather
    than an error, so a damaged ``index.json`` just triggers regeneration.

    Args:
        index_path: Path to the ``index.json`` written by a previous run.
        cues: Live cue dicts; each must carry ``cueIndex`` and ``text`` and
            may carry ``gapBeforeMs`` (treated as 0 when absent).
        speaker: Speaker the caller is about to synthesize with.
        language: Language the caller is about to synthesize with.

    Returns:
        True when the cached audio can be reused as-is, False otherwise.
    """
    try:
        cached = json.loads(index_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        return False
    if not isinstance(cached, dict):
        return False
    if cached.get("speaker") != speaker or cached.get("language") != language:
        return False

    cached_items = cached.get("items", [])
    if not isinstance(cached_items, list) or len(cached_items) != len(cues):
        return False

    for live, prev in zip(cues, cached_items):
        if not isinstance(prev, dict):
            return False
        try:
            # The -1 defaults can never equal a real value, so an entry that
            # lacks either field is a mismatch.
            prev_index = int(prev.get("cueIndex", -1))
            prev_gap = int(prev.get("gapBeforeMs", -1))
        except (TypeError, ValueError):
            # e.g. null or a non-numeric string in a hand-edited manifest.
            return False
        if int(live["cueIndex"]) != prev_index:
            return False
        if live["text"].strip() != str(prev.get("text", "")).strip():
            return False
        if int(live.get("gapBeforeMs", 0)) != prev_gap:
            return False
        wav = prev.get("wav")
        if not isinstance(wav, str) or not wav:
            return False
        if not (index_path.parent / wav).is_file():
            return False
    return True
|
|
|
|
|
|
def main() -> int:
    """Run the synth CLI and return a process exit code.

    With ``--list-speakers`` the model is loaded, its supported speakers are
    printed as JSON, and nothing else happens. Otherwise the narration script
    is read, cues with empty text are dropped, and — unless the audio already
    in the output directory matches the script, speaker, language and model —
    every cue is synthesized in one batched ``generate_custom_voice`` call.
    One ``cue_NNN.wav`` per cue plus an ``index.json`` manifest are written
    to the output directory.

    Returns:
        0 on success (including the cache-hit and ``--list-speakers`` paths),
        1 when the script is missing, unparsable or has no cues, or when the
        model returns a different number of waveforms than cues.
    """
    args = parse_args()

    if args.list_speakers:
        model = load_model(args.model, args.device)
        speakers = model.get_supported_speakers()
        print(json.dumps(speakers, indent=2, ensure_ascii=False))
        return 0

    if not args.script.exists():
        print(f"[synth] script not found: {args.script}", file=sys.stderr)
        return 1

    try:
        # Read as UTF-8 explicitly so narration text is not decoded with
        # whatever the platform's locale encoding happens to be.
        script = json.loads(args.script.read_text(encoding="utf-8"))
    except ValueError as exc:
        # Covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[synth] cannot parse {args.script}: {exc}", file=sys.stderr)
        return 1

    cues = [c for c in script.get("items", []) if c.get("text", "").strip()]
    if not cues:
        print("[synth] script has no cues; nothing to generate.", file=sys.stderr)
        return 1

    args.out_dir.mkdir(parents=True, exist_ok=True)
    index_path = args.out_dir / "index.json"

    # Skip generation when the existing audio matches the script — same cue
    # texts and same gapBeforeMs values in the same order. Saves ~30s of GPU
    # time when iterating on activity timing without changing narration.
    if cached_index_matches(index_path, cues, args.speaker, args.language):
        # The helper compares speaker and language only; the manifest also
        # records the model, and audio from a different model is stale too.
        cached_model = json.loads(index_path.read_text(encoding="utf-8")).get("model")
        if cached_model == args.model:
            print(
                f"[synth] cached audio in {args.out_dir} matches the current script — skipping generation",
                flush=True,
            )
            return 0
        print(
            f"[synth] cached audio was generated with model {cached_model!r}; "
            f"regenerating with {args.model!r}",
            flush=True,
        )

    model = load_model(args.model, args.device)

    texts = [c["text"].strip() for c in cues]
    print(f"[synth] generating {len(texts)} cues in one batched call", flush=True)
    for i, t in enumerate(texts):
        print(f"[synth] {i:2d}: {t}", flush=True)

    # ONE batched call. generate_custom_voice handles text=List[str] natively
    # and broadcasts the speaker/language across all items, so the entire
    # narration is decoded in one model pass — same RNG state, same batch,
    # consistent voice from cue to cue.
    wavs, sr = model.generate_custom_voice(
        text=texts,
        language=args.language,
        speaker=args.speaker,
    )
    if len(wavs) != len(texts):
        print(
            f"[synth] model returned {len(wavs)} wavs for {len(texts)} cues",
            file=sys.stderr,
        )
        return 1

    # Drop the old manifest before overwriting any WAV: if this run is
    # interrupted, a leftover index.json must not describe audio that has
    # already been partially replaced.
    if index_path.exists():
        index_path.unlink()

    items = []
    for cue, audio in zip(cues, wavs):
        if hasattr(audio, "cpu"):
            # Tensor output: move to host memory and convert for soundfile.
            audio = audio.cpu().float().numpy()
        # Coerce once, as the cache check does, so a cueIndex that arrives as
        # e.g. 3.0 or "3" cannot break the ``:03d`` format below.
        cue_index = int(cue["cueIndex"])
        wav_name = f"cue_{cue_index:03d}.wav"
        wav_path = args.out_dir / wav_name
        sf.write(str(wav_path), audio, sr)
        duration_ms = int(round(len(audio) * 1000 / sr))
        items.append(
            {
                "cueIndex": cue_index,
                "text": cue["text"],
                "gapBeforeMs": int(cue.get("gapBeforeMs", 0)),
                "wav": wav_name,
                "sampleRate": sr,
                "durationMs": duration_ms,
            }
        )
        print(
            f"[synth] wrote {wav_name} {duration_ms:>5d}ms «{cue['text']}»",
            flush=True,
        )

    out_index = {
        "speaker": args.speaker,
        "language": args.language,
        "model": args.model,
        "items": items,
    }
    # Written last, so its presence means every WAV it lists is on disk.
    index_path.write_text(json.dumps(out_index, indent=2), encoding="utf-8")
    total_ms = sum(it["gapBeforeMs"] + it["durationMs"] for it in items)
    print(
        f"[synth] {len(items)} cues, {total_ms}ms of audio (incl. gaps) -> {args.out_dir}",
        flush=True,
    )
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status.
    sys.exit(main())
|