feat: add edge tts fallback provider

This commit is contained in:
Peter Steinberger
2026-01-25 01:05:23 +00:00
parent 6a7a1d7085
commit fc0e303e05
11 changed files with 466 additions and 32 deletions
+118 -1
View File
@@ -4,7 +4,7 @@ import { completeSimple } from "@mariozechner/pi-ai";
import { getApiKeyForModel } from "../agents/model-auth.js";
import { resolveModel } from "../agents/pi-embedded-runner/model.js";
import { _test, resolveTtsConfig } from "./tts.js";
import { _test, getTtsProvider, resolveTtsConfig } from "./tts.js";
vi.mock("@mariozechner/pi-ai", () => ({
completeSimple: vi.fn(),
@@ -47,6 +47,7 @@ const {
resolveModelOverridePolicy,
summarizeText,
resolveOutputFormat,
resolveEdgeOutputFormat,
} = _test;
describe("tts", () => {
@@ -149,6 +150,30 @@ describe("tts", () => {
});
});
describe("resolveEdgeOutputFormat", () => {
const baseCfg = {
agents: { defaults: { model: { primary: "openai/gpt-4o-mini" } } },
messages: { tts: {} },
};
it("uses default output format when edge output format is not configured", () => {
const config = resolveTtsConfig(baseCfg);
expect(resolveEdgeOutputFormat(config)).toBe("audio-24khz-48kbitrate-mono-mp3");
});
it("uses configured output format when provided", () => {
const config = resolveTtsConfig({
...baseCfg,
messages: {
tts: {
edge: { outputFormat: "audio-24khz-96kbitrate-mono-mp3" },
},
},
});
expect(resolveEdgeOutputFormat(config)).toBe("audio-24khz-96kbitrate-mono-mp3");
});
});
describe("parseTtsDirectives", () => {
it("extracts overrides and strips directives when enabled", () => {
const policy = resolveModelOverridePolicy({ enabled: true });
@@ -165,6 +190,14 @@ describe("tts", () => {
expect(result.overrides.elevenlabs?.voiceSettings?.speed).toBe(1.1);
});
it("accepts edge as provider override", () => {
const policy = resolveModelOverridePolicy({ enabled: true });
const input = "Hello [[tts:provider=edge]] world";
const result = parseTtsDirectives(input, policy);
expect(result.overrides.provider).toBe("edge");
});
it("keeps text intact when overrides are disabled", () => {
const policy = resolveModelOverridePolicy({ enabled: false });
const input = "Hello [[tts:voice=alloy]] world";
@@ -314,4 +347,88 @@ describe("tts", () => {
).rejects.toThrow("No summary returned");
});
});
describe("getTtsProvider", () => {
const baseCfg = {
agents: { defaults: { model: { primary: "openai/gpt-4o-mini" } } },
messages: { tts: {} },
};
const restoreEnv = (snapshot: Record<string, string | undefined>) => {
const keys = ["OPENAI_API_KEY", "ELEVENLABS_API_KEY", "XI_API_KEY"] as const;
for (const key of keys) {
const value = snapshot[key];
if (value === undefined) {
delete process.env[key];
} else {
process.env[key] = value;
}
}
};
const withEnv = (env: Record<string, string | undefined>, run: () => void) => {
const snapshot = {
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
ELEVENLABS_API_KEY: process.env.ELEVENLABS_API_KEY,
XI_API_KEY: process.env.XI_API_KEY,
};
try {
for (const [key, value] of Object.entries(env)) {
if (value === undefined) {
delete process.env[key];
} else {
process.env[key] = value;
}
}
run();
} finally {
restoreEnv(snapshot);
}
};
it("prefers OpenAI when no provider is configured and API key exists", () => {
withEnv(
{
OPENAI_API_KEY: "test-openai-key",
ELEVENLABS_API_KEY: undefined,
XI_API_KEY: undefined,
},
() => {
const config = resolveTtsConfig(baseCfg);
const provider = getTtsProvider(config, "/tmp/tts-prefs-openai.json");
expect(provider).toBe("openai");
},
);
});
it("prefers ElevenLabs when OpenAI is missing and ElevenLabs key exists", () => {
withEnv(
{
OPENAI_API_KEY: undefined,
ELEVENLABS_API_KEY: "test-elevenlabs-key",
XI_API_KEY: undefined,
},
() => {
const config = resolveTtsConfig(baseCfg);
const provider = getTtsProvider(config, "/tmp/tts-prefs-elevenlabs.json");
expect(provider).toBe("elevenlabs");
},
);
});
it("falls back to Edge when no API keys are present", () => {
withEnv(
{
OPENAI_API_KEY: undefined,
ELEVENLABS_API_KEY: undefined,
XI_API_KEY: undefined,
},
() => {
const config = resolveTtsConfig(baseCfg);
const provider = getTtsProvider(config, "/tmp/tts-prefs-edge.json");
expect(provider).toBe("edge");
},
);
});
});
});
+166 -10
View File
@@ -12,6 +12,7 @@ import { tmpdir } from "node:os";
import path from "node:path";
import { completeSimple, type TextContent } from "@mariozechner/pi-ai";
import { EdgeTTS } from "node-edge-tts";
import type { ReplyPayload } from "../auto-reply/types.js";
import { normalizeChannelId } from "../channels/plugins/index.js";
@@ -24,6 +25,7 @@ import type {
TtsModelOverrideConfig,
} from "../config/types.tts.js";
import { logVerbose } from "../globals.js";
import { isVoiceCompatibleAudio } from "../media/audio.js";
import { CONFIG_DIR, resolveUserPath } from "../utils.js";
import { getApiKeyForModel, requireApiKey } from "../agents/model-auth.js";
import {
@@ -45,6 +47,9 @@ const DEFAULT_ELEVENLABS_VOICE_ID = "pMsXgVXv3BLzUgSXRplE";
const DEFAULT_ELEVENLABS_MODEL_ID = "eleven_multilingual_v2";
const DEFAULT_OPENAI_MODEL = "gpt-4o-mini-tts";
const DEFAULT_OPENAI_VOICE = "alloy";
const DEFAULT_EDGE_VOICE = "en-US-MichelleNeural";
const DEFAULT_EDGE_LANG = "en-US";
const DEFAULT_EDGE_OUTPUT_FORMAT = "audio-24khz-48kbitrate-mono-mp3";
const DEFAULT_ELEVENLABS_VOICE_SETTINGS = {
stability: 0.5,
@@ -74,6 +79,7 @@ export type ResolvedTtsConfig = {
enabled: boolean;
mode: TtsMode;
provider: TtsProvider;
providerSource: "config" | "default";
summaryModel?: string;
modelOverrides: ResolvedTtsModelOverrides;
elevenlabs: {
@@ -97,6 +103,19 @@ export type ResolvedTtsConfig = {
model: string;
voice: string;
};
edge: {
enabled: boolean;
voice: string;
lang: string;
outputFormat: string;
outputFormatConfigured: boolean;
pitch?: string;
rate?: string;
volume?: string;
saveSubtitles: boolean;
proxy?: string;
timeoutMs?: number;
};
prefsPath?: string;
maxTextLength: number;
timeoutMs: number;
@@ -199,10 +218,13 @@ function resolveModelOverridePolicy(
export function resolveTtsConfig(cfg: ClawdbotConfig): ResolvedTtsConfig {
const raw: TtsConfig = cfg.messages?.tts ?? {};
const providerSource = raw.provider ? "config" : "default";
const edgeOutputFormat = raw.edge?.outputFormat?.trim();
return {
enabled: raw.enabled ?? false,
mode: raw.mode ?? "final",
provider: raw.provider ?? "elevenlabs",
provider: raw.provider ?? "edge",
providerSource,
summaryModel: raw.summaryModel?.trim() || undefined,
modelOverrides: resolveModelOverridePolicy(raw.modelOverrides),
elevenlabs: {
@@ -231,6 +253,19 @@ export function resolveTtsConfig(cfg: ClawdbotConfig): ResolvedTtsConfig {
model: raw.openai?.model ?? DEFAULT_OPENAI_MODEL,
voice: raw.openai?.voice ?? DEFAULT_OPENAI_VOICE,
},
edge: {
enabled: raw.edge?.enabled ?? true,
voice: raw.edge?.voice?.trim() || DEFAULT_EDGE_VOICE,
lang: raw.edge?.lang?.trim() || DEFAULT_EDGE_LANG,
outputFormat: edgeOutputFormat || DEFAULT_EDGE_OUTPUT_FORMAT,
outputFormatConfigured: Boolean(edgeOutputFormat),
pitch: raw.edge?.pitch?.trim() || undefined,
rate: raw.edge?.rate?.trim() || undefined,
volume: raw.edge?.volume?.trim() || undefined,
saveSubtitles: raw.edge?.saveSubtitles ?? false,
proxy: raw.edge?.proxy?.trim() || undefined,
timeoutMs: raw.edge?.timeoutMs,
},
prefsPath: raw.prefsPath,
maxTextLength: raw.maxTextLength ?? DEFAULT_MAX_TEXT_LENGTH,
timeoutMs: raw.timeoutMs ?? DEFAULT_TIMEOUT_MS,
@@ -302,7 +337,12 @@ export function setTtsEnabled(prefsPath: string, enabled: boolean): void {
export function getTtsProvider(config: ResolvedTtsConfig, prefsPath: string): TtsProvider {
const prefs = readPrefs(prefsPath);
return prefs.tts?.provider ?? config.provider;
if (prefs.tts?.provider) return prefs.tts.provider;
if (config.providerSource === "config") return config.provider;
if (resolveTtsApiKey(config, "openai")) return "openai";
if (resolveTtsApiKey(config, "elevenlabs")) return "elevenlabs";
return "edge";
}
export function setTtsProvider(prefsPath: string, provider: TtsProvider): void {
@@ -350,6 +390,10 @@ function resolveChannelId(channel: string | undefined): ChannelId | null {
return channel ? normalizeChannelId(channel) : null;
}
function resolveEdgeOutputFormat(config: ResolvedTtsConfig): string {
return config.edge.outputFormat;
}
export function resolveTtsApiKey(
config: ResolvedTtsConfig,
provider: TtsProvider,
@@ -363,6 +407,17 @@ export function resolveTtsApiKey(
return undefined;
}
export const TTS_PROVIDERS = ["openai", "elevenlabs", "edge"] as const;
export function resolveTtsProviderOrder(primary: TtsProvider): TtsProvider[] {
return [primary, ...TTS_PROVIDERS.filter((provider) => provider !== primary)];
}
export function isTtsProviderConfigured(config: ResolvedTtsConfig, provider: TtsProvider): boolean {
if (provider === "edge") return config.edge.enabled;
return Boolean(resolveTtsApiKey(config, provider));
}
function isValidVoiceId(voiceId: string): boolean {
return /^[a-zA-Z0-9]{10,40}$/.test(voiceId);
}
@@ -459,7 +514,7 @@ function parseTtsDirectives(
switch (key) {
case "provider":
if (!policy.allowProvider) break;
if (rawValue === "openai" || rawValue === "elevenlabs") {
if (rawValue === "openai" || rawValue === "elevenlabs" || rawValue === "edge") {
overrides.provider = rawValue;
} else {
warnings.push(`unsupported provider "${rawValue}"`);
@@ -893,6 +948,38 @@ async function openaiTTS(params: {
}
}
function inferEdgeExtension(outputFormat: string): string {
const normalized = outputFormat.toLowerCase();
if (normalized.includes("webm")) return ".webm";
if (normalized.includes("ogg")) return ".ogg";
if (normalized.includes("opus")) return ".opus";
if (normalized.includes("wav") || normalized.includes("riff") || normalized.includes("pcm")) {
return ".wav";
}
return ".mp3";
}
async function edgeTTS(params: {
text: string;
outputPath: string;
config: ResolvedTtsConfig["edge"];
timeoutMs: number;
}): Promise<void> {
const { text, outputPath, config, timeoutMs } = params;
const tts = new EdgeTTS({
voice: config.voice,
lang: config.lang,
outputFormat: config.outputFormat,
saveSubtitles: config.saveSubtitles,
proxy: config.proxy,
rate: config.rate,
pitch: config.pitch,
volume: config.volume,
timeout: config.timeoutMs ?? timeoutMs,
});
await tts.ttsPromise(text, outputPath);
}
export async function textToSpeech(params: {
text: string;
cfg: ClawdbotConfig;
@@ -915,19 +1002,87 @@ export async function textToSpeech(params: {
const userProvider = getTtsProvider(config, prefsPath);
const overrideProvider = params.overrides?.provider;
const provider = overrideProvider ?? userProvider;
const providers: TtsProvider[] = [provider, provider === "openai" ? "elevenlabs" : "openai"];
const providers = resolveTtsProviderOrder(provider);
let lastError: string | undefined;
for (const provider of providers) {
const apiKey = resolveTtsApiKey(config, provider);
if (!apiKey) {
lastError = `No API key for ${provider}`;
continue;
}
const providerStart = Date.now();
try {
if (provider === "edge") {
if (!config.edge.enabled) {
lastError = "edge: disabled";
continue;
}
const tempDir = mkdtempSync(path.join(tmpdir(), "tts-"));
let edgeOutputFormat = resolveEdgeOutputFormat(config);
const fallbackEdgeOutputFormat =
edgeOutputFormat !== DEFAULT_EDGE_OUTPUT_FORMAT ? DEFAULT_EDGE_OUTPUT_FORMAT : undefined;
const attemptEdgeTts = async (outputFormat: string) => {
const extension = inferEdgeExtension(outputFormat);
const audioPath = path.join(tempDir, `voice-${Date.now()}${extension}`);
await edgeTTS({
text: params.text,
outputPath: audioPath,
config: {
...config.edge,
outputFormat,
},
timeoutMs: config.timeoutMs,
});
return { audioPath, outputFormat };
};
let edgeResult: { audioPath: string; outputFormat: string };
try {
edgeResult = await attemptEdgeTts(edgeOutputFormat);
} catch (err) {
if (fallbackEdgeOutputFormat && fallbackEdgeOutputFormat !== edgeOutputFormat) {
logVerbose(
`TTS: Edge output ${edgeOutputFormat} failed; retrying with ${fallbackEdgeOutputFormat}.`,
);
edgeOutputFormat = fallbackEdgeOutputFormat;
try {
edgeResult = await attemptEdgeTts(edgeOutputFormat);
} catch (fallbackErr) {
try {
rmSync(tempDir, { recursive: true, force: true });
} catch {
// ignore cleanup errors
}
throw fallbackErr;
}
} else {
try {
rmSync(tempDir, { recursive: true, force: true });
} catch {
// ignore cleanup errors
}
throw err;
}
}
scheduleCleanup(tempDir);
const voiceCompatible = isVoiceCompatibleAudio({ fileName: edgeResult.audioPath });
return {
success: true,
audioPath: edgeResult.audioPath,
latencyMs: Date.now() - providerStart,
provider,
outputFormat: edgeResult.outputFormat,
voiceCompatible,
};
}
const apiKey = resolveTtsApiKey(config, provider);
if (!apiKey) {
lastError = `No API key for ${provider}`;
continue;
}
let audioBuffer: Buffer;
if (provider === "elevenlabs") {
const voiceIdOverride = params.overrides?.elevenlabs?.voiceId;
@@ -1120,4 +1275,5 @@ export const _test = {
resolveModelOverridePolicy,
summarizeText,
resolveOutputFormat,
resolveEdgeOutputFormat,
};