From dc0ee2e178dc41672d703746edcbd17ff6281bd1 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Mon, 6 Apr 2026 01:43:08 +0100 Subject: [PATCH] feat: add music generation tooling --- .../OpenClawKit/Resources/tool-display.json | 23 + docs/.generated/config-baseline.sha256 | 8 +- .../.generated/plugin-sdk-api-baseline.sha256 | 4 +- extensions/google/index.ts | 2 + .../google/music-generation-provider.test.ts | 98 +++ .../google/music-generation-provider.ts | 186 +++++ extensions/google/openclaw.plugin.json | 1 + .../image-generation-core/src/runtime.ts | 90 +-- extensions/minimax/index.ts | 2 + .../minimax/music-generation-provider.test.ts | 104 +++ .../minimax/music-generation-provider.ts | 232 ++++++ extensions/minimax/openclaw.plugin.json | 1 + .../music-generation-providers.live.test.ts | 104 +++ .../video-generation-core/src/runtime.ts | 90 +-- package.json | 8 + src/agents/internal-events.ts | 2 +- .../media-generation-task-status-shared.ts | 100 +++ src/agents/music-generation-task-status.ts | 65 ++ src/agents/openclaw-tools.ts | 11 + .../run/attempt.prompt-helpers.test.ts | 20 +- .../run/attempt.prompt-helpers.ts | 15 +- .../pi-embedded-subscribe.tools.media.test.ts | 4 + src/agents/pi-embedded-subscribe.tools.ts | 1 + src/agents/tool-catalog.test.ts | 1 + src/agents/tool-catalog.ts | 8 + src/agents/tool-display-config.ts | 14 + .../tools/media-generate-background-shared.ts | 223 ++++++ src/agents/tools/media-tool-shared.ts | 139 +++- .../tools/music-generate-background.test.ts | 121 +++ src/agents/tools/music-generate-background.ts | 81 ++ .../tools/music-generate-tool.actions.ts | 130 ++++ .../tools/music-generate-tool.status.test.ts | 106 +++ src/agents/tools/music-generate-tool.test.ts | 273 +++++++ src/agents/tools/music-generate-tool.ts | 703 ++++++++++++++++++ src/agents/tools/video-generate-background.ts | 195 +---- src/agents/tools/video-generate-tool.ts | 111 +-- src/agents/video-generation-task-status.ts | 80 +- src/config/schema.help.ts | 4 
+ src/config/schema.labels.ts | 2 + src/config/types.agent-defaults.ts | 2 + src/config/zod-schema.agent-defaults.ts | 1 + src/gateway/server-plugins.test.ts | 1 + src/gateway/test-helpers.plugin-registry.ts | 1 + src/image-generation/runtime.ts | 94 +-- src/media-generation/runtime-shared.ts | 93 +++ src/music-generation/live-test-helpers.ts | 4 + src/music-generation/model-ref.ts | 16 + src/music-generation/provider-registry.ts | 77 ++ src/music-generation/runtime.ts | 129 ++++ src/music-generation/types.ts | 69 ++ src/plugin-sdk/index.ts | 1 + src/plugin-sdk/music-generation-core.ts | 28 + src/plugin-sdk/music-generation.ts | 11 + src/plugins/.DS_Store | Bin 0 -> 6148 bytes src/plugins/api-builder.ts | 5 + .../bundled-capability-metadata.test.ts | 2 + src/plugins/bundled-capability-runtime.ts | 13 + src/plugins/capability-provider-runtime.ts | 7 +- src/plugins/captured-registration.ts | 7 + src/plugins/channel-plugin-ids.ts | 1 + .../inventory/bundled-capability-metadata.ts | 3 + src/plugins/contracts/registry.ts | 28 + .../contracts/speech-vitest-registry.ts | 20 +- src/plugins/hooks.test-helpers.ts | 1 + src/plugins/loader.ts | 1 + src/plugins/manifest-registry.ts | 1 + src/plugins/manifest.ts | 3 + src/plugins/registry-empty.ts | 1 + src/plugins/registry.ts | 21 + src/plugins/runtime.test.ts | 2 + src/plugins/runtime/index.ts | 28 +- src/plugins/runtime/types-core.ts | 4 + src/plugins/status.test-helpers.ts | 2 + src/plugins/types.ts | 4 + src/test-utils/channel-plugins.ts | 1 + src/video-generation/runtime.ts | 94 +-- test/helpers/plugins/plugin-api.ts | 1 + test/helpers/plugins/plugin-runtime-mock.ts | 4 + test/helpers/plugins/provider-registration.ts | 15 +- 79 files changed, 3538 insertions(+), 620 deletions(-) create mode 100644 extensions/google/music-generation-provider.test.ts create mode 100644 extensions/google/music-generation-provider.ts create mode 100644 extensions/minimax/music-generation-provider.test.ts create mode 100644 
extensions/minimax/music-generation-provider.ts create mode 100644 extensions/music-generation-providers.live.test.ts create mode 100644 src/agents/media-generation-task-status-shared.ts create mode 100644 src/agents/music-generation-task-status.ts create mode 100644 src/agents/tools/media-generate-background-shared.ts create mode 100644 src/agents/tools/music-generate-background.test.ts create mode 100644 src/agents/tools/music-generate-background.ts create mode 100644 src/agents/tools/music-generate-tool.actions.ts create mode 100644 src/agents/tools/music-generate-tool.status.test.ts create mode 100644 src/agents/tools/music-generate-tool.test.ts create mode 100644 src/agents/tools/music-generate-tool.ts create mode 100644 src/media-generation/runtime-shared.ts create mode 100644 src/music-generation/live-test-helpers.ts create mode 100644 src/music-generation/model-ref.ts create mode 100644 src/music-generation/provider-registry.ts create mode 100644 src/music-generation/runtime.ts create mode 100644 src/music-generation/types.ts create mode 100644 src/plugin-sdk/music-generation-core.ts create mode 100644 src/plugin-sdk/music-generation.ts create mode 100644 src/plugins/.DS_Store diff --git a/apps/shared/OpenClawKit/Sources/OpenClawKit/Resources/tool-display.json b/apps/shared/OpenClawKit/Sources/OpenClawKit/Resources/tool-display.json index 5453bc9a34c..b8e699bc83b 100644 --- a/apps/shared/OpenClawKit/Sources/OpenClawKit/Resources/tool-display.json +++ b/apps/shared/OpenClawKit/Sources/OpenClawKit/Resources/tool-display.json @@ -1030,6 +1030,29 @@ } } }, + "music_generate": { + "emoji": "🎵", + "title": "Music Generation", + "actions": { + "generate": { + "label": "generate", + "detailKeys": [ + "prompt", + "model", + "durationSeconds", + "format", + "instrumental" + ] + }, + "list": { + "label": "list", + "detailKeys": [ + "provider", + "model" + ] + } + } + }, "video_generate": { "emoji": "🎬", "title": "Video Generation", diff --git 
a/docs/.generated/config-baseline.sha256 b/docs/.generated/config-baseline.sha256 index cf97f1f64a0..1b586337049 100644 --- a/docs/.generated/config-baseline.sha256 +++ b/docs/.generated/config-baseline.sha256 @@ -1,4 +1,4 @@ -73fbcd00d17685b462dfb11aff74baae99265ae5671db28893d8608456daa44e config-baseline.json -effaf240920c16fce2c78af52dec15aa9ceb049e34f703c568669cb6beef3f91 config-baseline.core.json -3c999707b167138de34f6255e3488b99e404c5132d3fc5879a1fa12d815c31f5 config-baseline.channel.json -031b237717ca108ea2cd314413db4c91edfdfea55f808179e3066331f41af134 config-baseline.plugin.json +fb2c88ef41657f1aa7237dcce655d16313dc849fd03991b221346367c569a482 config-baseline.json +ff8f64e1866748644776b229bdf334762875e3139b717a3adb8e5c587286ada3 config-baseline.core.json +ba5f7e89aad95d3eae0bc4e3b590c8dbb87bd921bba0d8f12fe67545af5887c6 config-baseline.channel.json +dc19ac1c60544d87fe08944d1184e0ade7b469367cdf8d6ce61452f64f9e0a47 config-baseline.plugin.json diff --git a/docs/.generated/plugin-sdk-api-baseline.sha256 b/docs/.generated/plugin-sdk-api-baseline.sha256 index fe013d1bb9e..15e040cd869 100644 --- a/docs/.generated/plugin-sdk-api-baseline.sha256 +++ b/docs/.generated/plugin-sdk-api-baseline.sha256 @@ -1,2 +1,2 @@ -97509287d728c8f5d1736f7ea07521451ada4b9d7ef56555dbe860a89e1b6e08 plugin-sdk-api-baseline.json -a22b3d427953cc8394b28c87ef7a992d2eb4f2c9f6a76fa58b33079e2306661b plugin-sdk-api-baseline.jsonl +4e024092a28987e1a826b0c731e9ee5adb9d28e73b5cac51ca055c46d9067258 plugin-sdk-api-baseline.json +9e3279a3e78e24b72952ab0f1707dcf465f8c283acf568f043e9b232fd0ae5dd plugin-sdk-api-baseline.jsonl diff --git a/extensions/google/index.ts b/extensions/google/index.ts index c756d50e9de..0e47db0ce71 100644 --- a/extensions/google/index.ts +++ b/extensions/google/index.ts @@ -12,6 +12,7 @@ import { resolveGoogleGenerativeAiTransport, } from "./api.js"; import { registerGoogleGeminiCliProvider } from "./gemini-cli-provider.js"; +import { buildGoogleMusicGenerationProvider } from 
"./music-generation-provider.js"; import { isModernGoogleModel, resolveGoogleGeminiForwardCompatModel } from "./provider-models.js"; import { createGeminiWebSearchProvider } from "./src/gemini-web-search-provider.js"; import { buildGoogleVideoGenerationProvider } from "./video-generation-provider.js"; @@ -166,6 +167,7 @@ export default definePluginEntry({ }); api.registerImageGenerationProvider(createLazyGoogleImageGenerationProvider()); api.registerMediaUnderstandingProvider(createLazyGoogleMediaUnderstandingProvider()); + api.registerMusicGenerationProvider(buildGoogleMusicGenerationProvider()); api.registerVideoGenerationProvider(buildGoogleVideoGenerationProvider()); api.registerWebSearchProvider(createGeminiWebSearchProvider()); }, diff --git a/extensions/google/music-generation-provider.test.ts b/extensions/google/music-generation-provider.test.ts new file mode 100644 index 00000000000..a6c260f425d --- /dev/null +++ b/extensions/google/music-generation-provider.test.ts @@ -0,0 +1,98 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; + +const { GoogleGenAIMock, generateContentMock } = vi.hoisted(() => { + const generateContentMock = vi.fn(); + const GoogleGenAIMock = vi.fn(function GoogleGenAI() { + return { + models: { + generateContent: generateContentMock, + }, + }; + }); + return { GoogleGenAIMock, generateContentMock }; +}); + +vi.mock("@google/genai", () => ({ + GoogleGenAI: GoogleGenAIMock, +})); + +import * as providerAuthRuntime from "openclaw/plugin-sdk/provider-auth-runtime"; +import { buildGoogleMusicGenerationProvider } from "./music-generation-provider.js"; + +describe("google music generation provider", () => { + afterEach(() => { + vi.restoreAllMocks(); + generateContentMock.mockReset(); + GoogleGenAIMock.mockClear(); + }); + + it("submits generation and returns inline audio bytes plus lyrics", async () => { + vi.spyOn(providerAuthRuntime, "resolveApiKeyForProvider").mockResolvedValue({ + apiKey: "google-key", + source: "env", + 
mode: "api-key", + }); + generateContentMock.mockResolvedValue({ + candidates: [ + { + content: { + parts: [ + { text: "wake the city up" }, + { + inlineData: { + data: Buffer.from("mp3-bytes").toString("base64"), + mimeType: "audio/mpeg", + }, + }, + ], + }, + }, + ], + }); + + const provider = buildGoogleMusicGenerationProvider(); + const result = await provider.generateMusic({ + provider: "google", + model: "lyria-3-clip-preview", + prompt: "upbeat synthpop anthem", + cfg: {}, + instrumental: true, + }); + + expect(generateContentMock).toHaveBeenCalledWith( + expect.objectContaining({ + model: "lyria-3-clip-preview", + config: { + responseModalities: ["AUDIO", "TEXT"], + }, + }), + ); + expect(result.tracks).toHaveLength(1); + expect(result.tracks[0]?.mimeType).toBe("audio/mpeg"); + expect(result.lyrics).toEqual(["wake the city up"]); + expect(GoogleGenAIMock).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: "google-key", + }), + ); + }); + + it("rejects unsupported wav output on clip model", async () => { + vi.spyOn(providerAuthRuntime, "resolveApiKeyForProvider").mockResolvedValue({ + apiKey: "google-key", + source: "env", + mode: "api-key", + }); + const provider = buildGoogleMusicGenerationProvider(); + + await expect( + provider.generateMusic({ + provider: "google", + model: "lyria-3-clip-preview", + prompt: "ambient ocean", + cfg: {}, + format: "wav", + }), + ).rejects.toThrow("supports mp3 output"); + }); +}); diff --git a/extensions/google/music-generation-provider.ts b/extensions/google/music-generation-provider.ts new file mode 100644 index 00000000000..ce51a1b19e7 --- /dev/null +++ b/extensions/google/music-generation-provider.ts @@ -0,0 +1,186 @@ +import { GoogleGenAI } from "@google/genai"; +import { extensionForMime } from "openclaw/plugin-sdk/msteams"; +import type { + GeneratedMusicAsset, + MusicGenerationProvider, + MusicGenerationRequest, +} from "openclaw/plugin-sdk/music-generation"; +import { isProviderApiKeyConfigured } from 
"openclaw/plugin-sdk/provider-auth"; +import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; +import { normalizeGoogleApiBaseUrl } from "./api.js"; + +const DEFAULT_GOOGLE_MUSIC_MODEL = "lyria-3-clip-preview"; +const GOOGLE_PRO_MUSIC_MODEL = "lyria-3-pro-preview"; +const DEFAULT_TIMEOUT_MS = 180_000; +const GOOGLE_MAX_INPUT_IMAGES = 10; + +type GoogleInlineDataPart = { + mimeType?: string; + mime_type?: string; + data?: string; +}; + +type GoogleGenerateMusicResponse = { + candidates?: Array<{ + content?: { + parts?: Array<{ + text?: string; + inlineData?: GoogleInlineDataPart; + inline_data?: GoogleInlineDataPart; + }>; + }; + }>; +}; + +function resolveConfiguredGoogleMusicBaseUrl(req: MusicGenerationRequest): string | undefined { + const configured = req.cfg?.models?.providers?.google?.baseUrl?.trim(); + return configured ? normalizeGoogleApiBaseUrl(configured) : undefined; +} + +function buildMusicPrompt(req: MusicGenerationRequest): string { + const parts = [req.prompt.trim()]; + const lyrics = req.lyrics?.trim(); + if (req.instrumental === true) { + parts.push("Instrumental only. No vocals, no sung lyrics, no spoken word."); + } + if (lyrics) { + parts.push(`Lyrics:\n${lyrics}`); + } + return parts.join("\n\n"); +} + +function resolveSupportedFormats(model: string): readonly string[] { + return model === GOOGLE_PRO_MUSIC_MODEL ? ["mp3", "wav"] : ["mp3"]; +} + +function resolveTrackFileName(params: { index: number; mimeType: string; model: string }): string { + const ext = + extensionForMime(params.mimeType)?.replace(/^\./u, "") || + (params.model === GOOGLE_PRO_MUSIC_MODEL ? 
"wav" : "mp3"); + return `track-${params.index + 1}.${ext}`; +} + +function extractTracks(params: { payload: GoogleGenerateMusicResponse; model: string }): { + tracks: GeneratedMusicAsset[]; + lyrics: string[]; +} { + const lyrics: string[] = []; + const tracks: GeneratedMusicAsset[] = []; + for (const part of params.payload.candidates?.[0]?.content?.parts ?? []) { + if (part.text?.trim()) { + lyrics.push(part.text.trim()); + continue; + } + const inline = part.inlineData ?? part.inline_data; + const data = inline?.data?.trim(); + if (!data) { + continue; + } + const mimeType = inline?.mimeType?.trim() || inline?.mime_type?.trim() || "audio/mpeg"; + tracks.push({ + buffer: Buffer.from(data, "base64"), + mimeType, + fileName: resolveTrackFileName({ + index: tracks.length, + mimeType, + model: params.model, + }), + }); + } + return { tracks, lyrics }; +} + +export function buildGoogleMusicGenerationProvider(): MusicGenerationProvider { + return { + id: "google", + label: "Google", + defaultModel: DEFAULT_GOOGLE_MUSIC_MODEL, + models: [DEFAULT_GOOGLE_MUSIC_MODEL, GOOGLE_PRO_MUSIC_MODEL], + isConfigured: ({ agentDir }) => + isProviderApiKeyConfigured({ + provider: "google", + agentDir, + }), + capabilities: { + maxTracks: 1, + maxInputImages: GOOGLE_MAX_INPUT_IMAGES, + supportsLyrics: true, + supportsInstrumental: true, + supportsFormat: true, + supportedFormatsByModel: { + [DEFAULT_GOOGLE_MUSIC_MODEL]: ["mp3"], + [GOOGLE_PRO_MUSIC_MODEL]: ["mp3", "wav"], + }, + }, + async generateMusic(req) { + if ((req.inputImages?.length ?? 
0) > GOOGLE_MAX_INPUT_IMAGES) { + throw new Error( + `Google music generation supports at most ${GOOGLE_MAX_INPUT_IMAGES} reference images.`, + ); + } + const auth = await resolveApiKeyForProvider({ + provider: "google", + cfg: req.cfg, + agentDir: req.agentDir, + store: req.authStore, + }); + if (!auth.apiKey) { + throw new Error("Google API key missing"); + } + + const model = req.model?.trim() || DEFAULT_GOOGLE_MUSIC_MODEL; + if (req.format) { + const supportedFormats = resolveSupportedFormats(model); + if (!supportedFormats.includes(req.format)) { + throw new Error( + `Google music generation model ${model} supports ${supportedFormats.join(", ")} output.`, + ); + } + } + + const client = new GoogleGenAI({ + apiKey: auth.apiKey, + httpOptions: { + ...(resolveConfiguredGoogleMusicBaseUrl(req) + ? { baseUrl: resolveConfiguredGoogleMusicBaseUrl(req) } + : {}), + timeout: req.timeoutMs ?? DEFAULT_TIMEOUT_MS, + }, + }); + const response = (await client.models.generateContent({ + model, + contents: [ + { text: buildMusicPrompt(req) }, + ...(req.inputImages ?? []).map((image) => ({ + inlineData: { + mimeType: image.mimeType?.trim() || "image/png", + data: image.buffer?.toString("base64") ?? "", + }, + })), + ], + config: { + responseModalities: ["AUDIO", "TEXT"], + }, + })) as GoogleGenerateMusicResponse; + + const { tracks, lyrics } = extractTracks({ + payload: response, + model, + }); + if (tracks.length === 0) { + throw new Error("Google music generation response missing audio data"); + } + return { + tracks, + ...(lyrics.length > 0 ? { lyrics } : {}), + model, + metadata: { + inputImageCount: req.inputImages?.length ?? 0, + instrumental: req.instrumental === true, + ...(req.lyrics?.trim() ? { requestedLyrics: true } : {}), + ...(req.format ? 
{ requestedFormat: req.format } : {}), + }, + }; + }, + }; +} diff --git a/extensions/google/openclaw.plugin.json b/extensions/google/openclaw.plugin.json index bdc3e960a41..67e9436e643 100644 --- a/extensions/google/openclaw.plugin.json +++ b/extensions/google/openclaw.plugin.json @@ -46,6 +46,7 @@ "contracts": { "mediaUnderstandingProviders": ["google"], "imageGenerationProviders": ["google"], + "musicGenerationProviders": ["google"], "videoGenerationProviders": ["google"], "webSearchProviders": ["gemini"] }, diff --git a/extensions/image-generation-core/src/runtime.ts b/extensions/image-generation-core/src/runtime.ts index 2977b32a5cc..0ad5edceb40 100644 --- a/extensions/image-generation-core/src/runtime.ts +++ b/extensions/image-generation-core/src/runtime.ts @@ -1,13 +1,15 @@ +import { + buildNoCapabilityModelConfiguredMessage, + resolveCapabilityModelCandidates, + throwCapabilityGenerationFailure, +} from "../../../src/media-generation/runtime-shared.js"; import { createSubsystemLogger, describeFailoverError, getImageGenerationProvider, - getProviderEnvVars, isFailoverError, listImageGenerationProviders, parseImageGenerationModelRef, - resolveAgentModelFallbackValues, - resolveAgentModelPrimaryValue, type AuthProfileStore, type FallbackAttempt, type GeneratedImageAsset, @@ -40,73 +42,13 @@ export type GenerateImageRuntimeResult = { metadata?: Record; }; -function resolveImageGenerationCandidates(params: { - cfg: OpenClawConfig; - modelOverride?: string; -}): Array<{ provider: string; model: string }> { - const candidates: Array<{ provider: string; model: string }> = []; - const seen = new Set(); - const add = (raw: string | undefined) => { - const parsed = parseImageGenerationModelRef(raw); - if (!parsed) { - return; - } - const key = `${parsed.provider}/${parsed.model}`; - if (seen.has(key)) { - return; - } - seen.add(key); - candidates.push(parsed); - }; - - add(params.modelOverride); - 
add(resolveAgentModelPrimaryValue(params.cfg.agents?.defaults?.imageGenerationModel)); - for (const fallback of resolveAgentModelFallbackValues( - params.cfg.agents?.defaults?.imageGenerationModel, - )) { - add(fallback); - } - return candidates; -} - -function throwImageGenerationFailure(params: { - attempts: FallbackAttempt[]; - lastError: unknown; -}): never { - if (params.attempts.length <= 1 && params.lastError) { - throw params.lastError; - } - const summary = - params.attempts.length > 0 - ? params.attempts - .map((attempt) => `${attempt.provider}/${attempt.model}: ${attempt.error}`) - .join(" | ") - : "unknown"; - throw new Error(`All image generation models failed (${params.attempts.length}): ${summary}`, { - cause: params.lastError instanceof Error ? params.lastError : undefined, - }); -} - function buildNoImageGenerationModelConfiguredMessage(cfg: OpenClawConfig): string { - const providers = listImageGenerationProviders(cfg); - const sampleModel = - providers.find((provider) => provider.defaultModel) ?? - ({ id: "google", defaultModel: "gemini-3-pro-image-preview" } as const); - const authHints = providers - .flatMap((provider) => { - const envVars = getProviderEnvVars(provider.id); - if (envVars.length === 0) { - return []; - } - return [`${provider.id}: ${envVars.join(" / ")}`]; - }) - .slice(0, 3); - return [ - `No image-generation model configured. Set agents.defaults.imageGenerationModel.primary to a provider/model like "${sampleModel.id}/${sampleModel.defaultModel}".`, - authHints.length > 0 - ? 
`If you want a specific provider, also configure that provider's auth/API key first (${authHints.join("; ")}).` - : "If you want a specific provider, also configure that provider's auth/API key first.", - ].join(" "); + return buildNoCapabilityModelConfiguredMessage({ + capabilityLabel: "image-generation", + modelConfigKey: "imageGenerationModel", + providers: listImageGenerationProviders(cfg), + fallbackSampleRef: "google/gemini-3-pro-image-preview", + }); } export function listRuntimeImageGenerationProviders(params?: { config?: OpenClawConfig }) { @@ -116,9 +58,11 @@ export function listRuntimeImageGenerationProviders(params?: { config?: OpenClaw export async function generateImage( params: GenerateImageParams, ): Promise { - const candidates = resolveImageGenerationCandidates({ + const candidates = resolveCapabilityModelCandidates({ cfg: params.cfg, + modelConfig: params.cfg.agents?.defaults?.imageGenerationModel, modelOverride: params.modelOverride, + parseModelRef: parseImageGenerationModelRef, }); if (candidates.length === 0) { throw new Error(buildNoImageGenerationModelConfiguredMessage(params.cfg)); @@ -179,5 +123,9 @@ export async function generateImage( } } - throwImageGenerationFailure({ attempts, lastError }); + throwCapabilityGenerationFailure({ + capabilityLabel: "image generation", + attempts, + lastError, + }); } diff --git a/extensions/minimax/index.ts b/extensions/minimax/index.ts index 9950e170fbf..7b219e9b47b 100644 --- a/extensions/minimax/index.ts +++ b/extensions/minimax/index.ts @@ -23,6 +23,7 @@ import { minimaxMediaUnderstandingProvider, minimaxPortalMediaUnderstandingProvider, } from "./media-understanding-provider.js"; +import { buildMinimaxMusicGenerationProvider } from "./music-generation-provider.js"; import type { MiniMaxRegion } from "./oauth.js"; import { applyMinimaxApiConfig, applyMinimaxApiConfigCn } from "./onboard.js"; import { buildMinimaxPortalProvider, buildMinimaxProvider } from "./provider-catalog.js"; @@ -314,6 +315,7 @@ 
export default definePluginEntry({ }); api.registerImageGenerationProvider(buildMinimaxImageGenerationProvider()); api.registerImageGenerationProvider(buildMinimaxPortalImageGenerationProvider()); + api.registerMusicGenerationProvider(buildMinimaxMusicGenerationProvider()); api.registerVideoGenerationProvider(buildMinimaxVideoGenerationProvider()); api.registerSpeechProvider(buildMinimaxSpeechProvider()); api.registerWebSearchProvider(createMiniMaxWebSearchProvider()); diff --git a/extensions/minimax/music-generation-provider.test.ts b/extensions/minimax/music-generation-provider.test.ts new file mode 100644 index 00000000000..2528c8b28a0 --- /dev/null +++ b/extensions/minimax/music-generation-provider.test.ts @@ -0,0 +1,104 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import { buildMinimaxMusicGenerationProvider } from "./music-generation-provider.js"; + +const { + resolveApiKeyForProviderMock, + postJsonRequestMock, + fetchWithTimeoutMock, + assertOkOrThrowHttpErrorMock, + resolveProviderHttpRequestConfigMock, +} = vi.hoisted(() => ({ + resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "minimax-key" })), + postJsonRequestMock: vi.fn(), + fetchWithTimeoutMock: vi.fn(), + assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), + resolveProviderHttpRequestConfigMock: vi.fn((params) => ({ + baseUrl: params.baseUrl ?? 
params.defaultBaseUrl, + allowPrivateNetwork: false, + headers: new Headers(params.defaultHeaders), + dispatcherPolicy: undefined, + })), +})); + +vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ + resolveApiKeyForProvider: resolveApiKeyForProviderMock, +})); + +vi.mock("openclaw/plugin-sdk/provider-http", () => ({ + assertOkOrThrowHttpError: assertOkOrThrowHttpErrorMock, + fetchWithTimeout: fetchWithTimeoutMock, + postJsonRequest: postJsonRequestMock, + resolveProviderHttpRequestConfig: resolveProviderHttpRequestConfigMock, +})); + +describe("minimax music generation provider", () => { + afterEach(() => { + resolveApiKeyForProviderMock.mockClear(); + postJsonRequestMock.mockReset(); + fetchWithTimeoutMock.mockReset(); + assertOkOrThrowHttpErrorMock.mockClear(); + resolveProviderHttpRequestConfigMock.mockClear(); + }); + + it("creates music and downloads the generated track", async () => { + postJsonRequestMock.mockResolvedValue({ + response: { + json: async () => ({ + task_id: "task-123", + audio_url: "https://example.com/out.mp3", + lyrics: "our city wakes", + base_resp: { status_code: 0 }, + }), + }, + release: vi.fn(async () => {}), + }); + fetchWithTimeoutMock.mockResolvedValue({ + headers: new Headers({ "content-type": "audio/mpeg" }), + arrayBuffer: async () => Buffer.from("mp3-bytes"), + }); + + const provider = buildMinimaxMusicGenerationProvider(); + const result = await provider.generateMusic({ + provider: "minimax", + model: "music-2.5+", + prompt: "upbeat dance-pop with female vocals", + cfg: {}, + lyrics: "our city wakes", + durationSeconds: 45, + }); + + expect(postJsonRequestMock).toHaveBeenCalledWith( + expect.objectContaining({ + url: "https://api.minimax.io/v1/music_generation", + body: expect.objectContaining({ + model: "music-2.5+", + lyrics: "our city wakes", + output_format: "url", + }), + }), + ); + expect(result.tracks).toHaveLength(1); + expect(result.lyrics).toEqual(["our city wakes"]); + expect(result.metadata).toEqual( + 
expect.objectContaining({ + taskId: "task-123", + audioUrl: "https://example.com/out.mp3", + }), + ); + }); + + it("rejects instrumental requests that also include lyrics", async () => { + const provider = buildMinimaxMusicGenerationProvider(); + + await expect( + provider.generateMusic({ + provider: "minimax", + model: "music-2.5+", + prompt: "driving techno", + cfg: {}, + instrumental: true, + lyrics: "do not sing this", + }), + ).rejects.toThrow("cannot use lyrics when instrumental=true"); + }); +}); diff --git a/extensions/minimax/music-generation-provider.ts b/extensions/minimax/music-generation-provider.ts new file mode 100644 index 00000000000..ee5ac8d2396 --- /dev/null +++ b/extensions/minimax/music-generation-provider.ts @@ -0,0 +1,232 @@ +import { extensionForMime } from "openclaw/plugin-sdk/msteams"; +import type { + GeneratedMusicAsset, + MusicGenerationProvider, + MusicGenerationRequest, +} from "openclaw/plugin-sdk/music-generation"; +import { isProviderApiKeyConfigured } from "openclaw/plugin-sdk/provider-auth"; +import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth-runtime"; +import { + assertOkOrThrowHttpError, + fetchWithTimeout, + postJsonRequest, + resolveProviderHttpRequestConfig, +} from "openclaw/plugin-sdk/provider-http"; + +const DEFAULT_MINIMAX_MUSIC_BASE_URL = "https://api.minimax.io"; +const DEFAULT_MINIMAX_MUSIC_MODEL = "music-2.5+"; +const DEFAULT_TIMEOUT_MS = 120_000; + +type MinimaxBaseResp = { + status_code?: number; + status_msg?: string; +}; + +type MinimaxMusicCreateResponse = { + task_id?: string; + audio?: string; + audio_url?: string; + lyrics?: string; + data?: { + audio?: string; + audio_url?: string; + lyrics?: string; + }; + base_resp?: MinimaxBaseResp; +}; + +function resolveMinimaxMusicBaseUrl( + cfg: Parameters<typeof resolveApiKeyForProvider>[0]["cfg"], +): string { + const direct = cfg?.models?.providers?.minimax?.baseUrl?.trim(); + if (!direct) { + return DEFAULT_MINIMAX_MUSIC_BASE_URL; + } + try { + return new 
URL(direct).origin; + } catch { + return DEFAULT_MINIMAX_MUSIC_BASE_URL; + } +} + +function assertMinimaxBaseResp(baseResp: MinimaxBaseResp | undefined, context: string): void { + if (!baseResp || typeof baseResp.status_code !== "number" || baseResp.status_code === 0) { + return; + } + throw new Error( + `${context} (${baseResp.status_code}): ${baseResp.status_msg ?? "unknown error"}`, + ); +} + +function decodePossibleBinary(data: string): Buffer { + const trimmed = data.trim(); + if (/^[0-9a-f]+$/iu.test(trimmed) && trimmed.length % 2 === 0) { + return Buffer.from(trimmed, "hex"); + } + return Buffer.from(trimmed, "base64"); +} + +function decodePossibleText(data: string): string { + const trimmed = data.trim(); + if (!trimmed) { + return ""; + } + if (/^[0-9a-f]+$/iu.test(trimmed) && trimmed.length % 2 === 0) { + return Buffer.from(trimmed, "hex").toString("utf8").trim(); + } + return trimmed; +} + +async function downloadTrackFromUrl(params: { + url: string; + timeoutMs?: number; + fetchFn: typeof fetch; +}): Promise<GeneratedMusicAsset> { + const response = await fetchWithTimeout( + params.url, + { method: "GET" }, + params.timeoutMs ?? 
DEFAULT_TIMEOUT_MS, + params.fetchFn, + ); + await assertOkOrThrowHttpError(response, "MiniMax generated music download failed"); + const mimeType = response.headers.get("content-type")?.trim() || "audio/mpeg"; + const ext = extensionForMime(mimeType)?.replace(/^\./u, "") || "mp3"; + return { + buffer: Buffer.from(await response.arrayBuffer()), + mimeType, + fileName: `track-1.${ext}`, + }; +} + +function buildPrompt(req: MusicGenerationRequest): string { + const parts = [req.prompt.trim()]; + if (typeof req.durationSeconds === "number" && Number.isFinite(req.durationSeconds)) { + parts.push(`Target duration: about ${Math.max(1, Math.round(req.durationSeconds))} seconds.`); + } + return parts.join("\n\n"); +} + +export function buildMinimaxMusicGenerationProvider(): MusicGenerationProvider { + return { + id: "minimax", + label: "MiniMax", + defaultModel: DEFAULT_MINIMAX_MUSIC_MODEL, + models: [DEFAULT_MINIMAX_MUSIC_MODEL, "music-2.5", "music-2.0"], + isConfigured: ({ agentDir }) => + isProviderApiKeyConfigured({ + provider: "minimax", + agentDir, + }), + capabilities: { + maxTracks: 1, + supportsLyrics: true, + supportsInstrumental: true, + supportsDuration: true, + supportsFormat: true, + supportedFormats: ["mp3"], + }, + async generateMusic(req) { + if ((req.inputImages?.length ?? 
0) > 0) { + throw new Error("MiniMax music generation does not support image reference inputs."); + } + if (req.instrumental === true && req.lyrics?.trim()) { + throw new Error("MiniMax music generation cannot use lyrics when instrumental=true."); + } + if (req.format && req.format !== "mp3") { + throw new Error("MiniMax music generation currently supports mp3 output only."); + } + + const auth = await resolveApiKeyForProvider({ + provider: "minimax", + cfg: req.cfg, + agentDir: req.agentDir, + store: req.authStore, + }); + if (!auth.apiKey) { + throw new Error("MiniMax API key missing"); + } + + const fetchFn = fetch; + const { baseUrl, allowPrivateNetwork, headers, dispatcherPolicy } = + resolveProviderHttpRequestConfig({ + baseUrl: resolveMinimaxMusicBaseUrl(req.cfg), + defaultBaseUrl: DEFAULT_MINIMAX_MUSIC_BASE_URL, + allowPrivateNetwork: false, + defaultHeaders: { + Authorization: `Bearer ${auth.apiKey}`, + }, + }); + + const model = req.model?.trim() || DEFAULT_MINIMAX_MUSIC_MODEL; + const body = { + model, + prompt: buildPrompt(req), + ...(req.instrumental === true ? { is_instrumental: true } : {}), + ...(req.lyrics?.trim() + ? { lyrics: req.lyrics.trim() } + : req.instrumental === true + ? {} + : { lyrics_optimizer: true }), + output_format: "url", + audio_setting: { + format: "mp3", + }, + }; + + const { response: res, release } = await postJsonRequest({ + url: `${baseUrl}/v1/music_generation`, + headers, + body, + timeoutMs: req.timeoutMs ?? 
DEFAULT_TIMEOUT_MS, + fetchFn, + pinDns: false, + allowPrivateNetwork, + dispatcherPolicy, + }); + + try { + await assertOkOrThrowHttpError(res, "MiniMax music generation failed"); + const payload = (await res.json()) as MinimaxMusicCreateResponse; + assertMinimaxBaseResp(payload.base_resp, "MiniMax music generation failed"); + + const audioUrl = payload.audio_url?.trim() || payload.data?.audio_url?.trim(); + const inlineAudio = payload.audio?.trim() || payload.data?.audio?.trim(); + const lyrics = decodePossibleText(payload.lyrics ?? payload.data?.lyrics ?? ""); + + const track = audioUrl + ? await downloadTrackFromUrl({ + url: audioUrl, + timeoutMs: req.timeoutMs, + fetchFn, + }) + : inlineAudio + ? { + buffer: decodePossibleBinary(inlineAudio), + mimeType: "audio/mpeg", + fileName: "track-1.mp3", + } + : null; + if (!track) { + throw new Error("MiniMax music generation response missing audio output"); + } + + return { + tracks: [track], + ...(lyrics ? { lyrics: [lyrics] } : {}), + model, + metadata: { + ...(payload.task_id?.trim() ? { taskId: payload.task_id.trim() } : {}), + ...(audioUrl ? { audioUrl } : {}), + instrumental: req.instrumental === true, + ...(req.lyrics?.trim() ? { requestedLyrics: true } : {}), + ...(typeof req.durationSeconds === "number" + ? 
{ requestedDurationSeconds: req.durationSeconds } + : {}), + }, + }; + } finally { + await release(); + } + }, + }; +} diff --git a/extensions/minimax/openclaw.plugin.json b/extensions/minimax/openclaw.plugin.json index 69c641b8254..e70aaa155bf 100644 --- a/extensions/minimax/openclaw.plugin.json +++ b/extensions/minimax/openclaw.plugin.json @@ -64,6 +64,7 @@ "speechProviders": ["minimax"], "mediaUnderstandingProviders": ["minimax", "minimax-portal"], "imageGenerationProviders": ["minimax", "minimax-portal"], + "musicGenerationProviders": ["minimax"], "videoGenerationProviders": ["minimax"], "webSearchProviders": ["minimax"] }, diff --git a/extensions/music-generation-providers.live.test.ts b/extensions/music-generation-providers.live.test.ts new file mode 100644 index 00000000000..f4b26b1c496 --- /dev/null +++ b/extensions/music-generation-providers.live.test.ts @@ -0,0 +1,104 @@ +import { describe, expect, it } from "vitest"; +import { collectProviderApiKeys } from "../src/agents/live-auth-keys.js"; +import { isLiveTestEnabled } from "../src/agents/live-test-helpers.js"; +import type { OpenClawConfig } from "../src/config/config.js"; +import { DEFAULT_LIVE_MUSIC_MODELS } from "../src/music-generation/live-test-helpers.js"; +import { parseMusicGenerationModelRef } from "../src/music-generation/model-ref.js"; +import { getProviderEnvVars } from "../src/secrets/provider-env-vars.js"; +import { + parseCsvFilter, + parseProviderModelMap, +} from "../src/video-generation/live-test-helpers.js"; +import { + registerProviderPlugin, + requireRegisteredProvider, +} from "../test/helpers/plugins/provider-registration.js"; +import googlePlugin from "./google/index.js"; +import minimaxPlugin from "./minimax/index.js"; + +const LIVE = isLiveTestEnabled(); +const providerFilter = parseCsvFilter(process.env.OPENCLAW_LIVE_MUSIC_GENERATION_PROVIDERS); +const envModelMap = parseProviderModelMap(process.env.OPENCLAW_LIVE_MUSIC_GENERATION_MODELS); + +type LiveProviderCase = { + 
plugin: Parameters[0]["plugin"]; + pluginId: string; + pluginName: string; + providerId: string; +}; + +const CASES: LiveProviderCase[] = [ + { + plugin: googlePlugin, + pluginId: "google", + pluginName: "Google Provider", + providerId: "google", + }, + { + plugin: minimaxPlugin, + pluginId: "minimax", + pluginName: "MiniMax Provider", + providerId: "minimax", + }, +] + .filter((entry) => (providerFilter ? providerFilter.has(entry.providerId) : true)) + .toSorted((left, right) => left.providerId.localeCompare(right.providerId)); + +function asConfig(value: unknown): OpenClawConfig { + return value as OpenClawConfig; +} + +function resolveProviderModelForLiveTest(providerId: string, modelRef: string): string { + const parsed = parseMusicGenerationModelRef(modelRef); + if (parsed && parsed.provider === providerId) { + return parsed.model; + } + return modelRef; +} + +describe.skipIf(!LIVE)("music generation provider live", () => { + for (const testCase of CASES) { + const modelRef = + envModelMap.get(testCase.providerId) ?? DEFAULT_LIVE_MUSIC_MODELS[testCase.providerId]; + const hasAuth = collectProviderApiKeys(testCase.providerId).length > 0; + const expectedEnvVars = getProviderEnvVars(testCase.providerId).join(", "); + + const liveIt = hasAuth && modelRef ? 
it : it.skip; + liveIt( + `generates a short track via ${testCase.providerId}`, + async () => { + const { musicProviders } = await registerProviderPlugin({ + plugin: testCase.plugin, + id: testCase.pluginId, + name: testCase.pluginName, + }); + const provider = requireRegisteredProvider( + musicProviders, + testCase.providerId, + "music provider", + ); + const providerModel = resolveProviderModelForLiveTest(testCase.providerId, modelRef!); + + const result = await provider.generateMusic({ + provider: testCase.providerId, + model: providerModel, + prompt: "Upbeat instrumental synthwave with warm neon pads and a simple driving beat.", + cfg: asConfig({ plugins: { enabled: true } }), + agentDir: "/tmp/openclaw-live-music", + instrumental: true, + ...(provider.capabilities.supportsDuration ? { durationSeconds: 12 } : {}), + ...(provider.capabilities.supportsFormat ? { format: "mp3" as const } : {}), + }); + + expect(result.tracks.length).toBeGreaterThan(0); + expect(result.tracks[0]?.mimeType.startsWith("audio/")).toBe(true); + expect(result.tracks[0]?.buffer.byteLength).toBeGreaterThan(1024); + }, + 6 * 60_000, + ); + + if (!hasAuth || !modelRef) { + it.skip(`skips ${testCase.providerId} without live auth/model (${expectedEnvVars || "no env vars"})`, () => {}); + } + } +}); diff --git a/extensions/video-generation-core/src/runtime.ts b/extensions/video-generation-core/src/runtime.ts index 37a7368ff00..7efc716be34 100644 --- a/extensions/video-generation-core/src/runtime.ts +++ b/extensions/video-generation-core/src/runtime.ts @@ -1,13 +1,15 @@ +import { + buildNoCapabilityModelConfiguredMessage, + resolveCapabilityModelCandidates, + throwCapabilityGenerationFailure, +} from "../../../src/media-generation/runtime-shared.js"; import { createSubsystemLogger, describeFailoverError, - getProviderEnvVars, getVideoGenerationProvider, isFailoverError, listVideoGenerationProviders, parseVideoGenerationModelRef, - resolveAgentModelFallbackValues, - 
resolveAgentModelPrimaryValue, type AuthProfileStore, type FallbackAttempt, type GeneratedVideoAsset, @@ -45,73 +47,13 @@ export type GenerateVideoRuntimeResult = { ignoredOverrides: VideoGenerationIgnoredOverride[]; }; -function resolveVideoGenerationCandidates(params: { - cfg: OpenClawConfig; - modelOverride?: string; -}): Array<{ provider: string; model: string }> { - const candidates: Array<{ provider: string; model: string }> = []; - const seen = new Set(); - const add = (raw: string | undefined) => { - const parsed = parseVideoGenerationModelRef(raw); - if (!parsed) { - return; - } - const key = `${parsed.provider}/${parsed.model}`; - if (seen.has(key)) { - return; - } - seen.add(key); - candidates.push(parsed); - }; - - add(params.modelOverride); - add(resolveAgentModelPrimaryValue(params.cfg.agents?.defaults?.videoGenerationModel)); - for (const fallback of resolveAgentModelFallbackValues( - params.cfg.agents?.defaults?.videoGenerationModel, - )) { - add(fallback); - } - return candidates; -} - -function throwVideoGenerationFailure(params: { - attempts: FallbackAttempt[]; - lastError: unknown; -}): never { - if (params.attempts.length <= 1 && params.lastError) { - throw params.lastError; - } - const summary = - params.attempts.length > 0 - ? params.attempts - .map((attempt) => `${attempt.provider}/${attempt.model}: ${attempt.error}`) - .join(" | ") - : "unknown"; - throw new Error(`All video generation models failed (${params.attempts.length}): ${summary}`, { - cause: params.lastError instanceof Error ? params.lastError : undefined, - }); -} - function buildNoVideoGenerationModelConfiguredMessage(cfg: OpenClawConfig): string { - const providers = listVideoGenerationProviders(cfg); - const sampleModel = - providers.find((provider) => provider.defaultModel) ?? 
- ({ id: "qwen", defaultModel: "wan2.6-t2v" } as const); - const authHints = providers - .flatMap((provider) => { - const envVars = getProviderEnvVars(provider.id); - if (envVars.length === 0) { - return []; - } - return [`${provider.id}: ${envVars.join(" / ")}`]; - }) - .slice(0, 3); - return [ - `No video-generation model configured. Set agents.defaults.videoGenerationModel.primary to a provider/model like "${sampleModel.id}/${sampleModel.defaultModel}".`, - authHints.length > 0 - ? `If you want a specific provider, also configure that provider's auth/API key first (${authHints.join("; ")}).` - : "If you want a specific provider, also configure that provider's auth/API key first.", - ].join(" "); + return buildNoCapabilityModelConfiguredMessage({ + capabilityLabel: "video-generation", + modelConfigKey: "videoGenerationModel", + providers: listVideoGenerationProviders(cfg), + fallbackSampleRef: "qwen/wan2.6-t2v", + }); } export function listRuntimeVideoGenerationProviders(params?: { config?: OpenClawConfig }) { @@ -172,9 +114,11 @@ function resolveProviderVideoGenerationOverrides(params: { export async function generateVideo( params: GenerateVideoParams, ): Promise { - const candidates = resolveVideoGenerationCandidates({ + const candidates = resolveCapabilityModelCandidates({ cfg: params.cfg, + modelConfig: params.cfg.agents?.defaults?.videoGenerationModel, modelOverride: params.modelOverride, + parseModelRef: parseVideoGenerationModelRef, }); if (candidates.length === 0) { throw new Error(buildNoVideoGenerationModelConfiguredMessage(params.cfg)); @@ -247,5 +191,9 @@ export async function generateVideo( } } - throwVideoGenerationFailure({ attempts, lastError }); + throwCapabilityGenerationFailure({ + capabilityLabel: "video generation", + attempts, + lastError, + }); } diff --git a/package.json b/package.json index d5f517de360..1f2e910527d 100644 --- a/package.json +++ b/package.json @@ -555,6 +555,14 @@ "types": "./dist/plugin-sdk/image-generation-core.d.ts", 
"default": "./dist/plugin-sdk/image-generation-core.js" }, + "./plugin-sdk/music-generation": { + "types": "./dist/plugin-sdk/music-generation.d.ts", + "default": "./dist/plugin-sdk/music-generation.js" + }, + "./plugin-sdk/music-generation-core": { + "types": "./dist/plugin-sdk/music-generation-core.d.ts", + "default": "./dist/plugin-sdk/music-generation-core.js" + }, "./plugin-sdk/video-generation": { "types": "./dist/plugin-sdk/video-generation.d.ts", "default": "./dist/plugin-sdk/video-generation.js" diff --git a/src/agents/internal-events.ts b/src/agents/internal-events.ts index 5efe8a03ab3..d2106a67b33 100644 --- a/src/agents/internal-events.ts +++ b/src/agents/internal-events.ts @@ -8,7 +8,7 @@ export type AgentInternalEventType = "task_completion"; export type AgentTaskCompletionInternalEvent = { type: "task_completion"; - source: "subagent" | "cron" | "video_generation"; + source: "subagent" | "cron" | "video_generation" | "music_generation"; childSessionKey: string; childSessionId?: string; announceType: string; diff --git a/src/agents/media-generation-task-status-shared.ts b/src/agents/media-generation-task-status-shared.ts new file mode 100644 index 00000000000..0bf069ec70e --- /dev/null +++ b/src/agents/media-generation-task-status-shared.ts @@ -0,0 +1,100 @@ +import type { TaskRecord } from "../tasks/task-registry.types.js"; +import { + buildSessionAsyncTaskStatusDetails, + findActiveSessionTask, +} from "./session-async-task-status.js"; + +export function isActiveMediaGenerationTask(params: { + task: TaskRecord; + taskKind: string; +}): boolean { + return ( + params.task.runtime === "cli" && + params.task.scopeKind === "session" && + params.task.taskKind === params.taskKind && + (params.task.status === "queued" || params.task.status === "running") + ); +} + +export function getMediaGenerationTaskProviderId( + task: TaskRecord, + sourcePrefix: string, +): string | undefined { + const sourceId = task.sourceId?.trim() ?? 
""; + if (!sourceId.startsWith(`${sourcePrefix}:`)) { + return undefined; + } + const providerId = sourceId.slice(`${sourcePrefix}:`.length).trim(); + return providerId || undefined; +} + +export function findActiveMediaGenerationTaskForSession(params: { + sessionKey?: string; + taskKind: string; + sourcePrefix: string; +}): TaskRecord | null { + return findActiveSessionTask({ + sessionKey: params.sessionKey, + runtime: "cli", + taskKind: params.taskKind, + sourceIdPrefix: params.sourcePrefix, + }); +} + +export function buildMediaGenerationTaskStatusDetails(params: { + task: TaskRecord; + sourcePrefix: string; +}): Record { + const provider = getMediaGenerationTaskProviderId(params.task, params.sourcePrefix); + return { + ...buildSessionAsyncTaskStatusDetails(params.task), + ...(provider ? { provider } : {}), + }; +} + +export function buildMediaGenerationTaskStatusText(params: { + task: TaskRecord; + sourcePrefix: string; + nounLabel: string; + toolName: string; + completionLabel: string; + duplicateGuard?: boolean; +}): string { + const provider = getMediaGenerationTaskProviderId(params.task, params.sourcePrefix); + const lines = [ + `${params.nounLabel} task ${params.task.taskId} is already ${params.task.status}${provider ? ` with ${provider}` : ""}.`, + params.task.progressSummary ? `Progress: ${params.task.progressSummary}.` : null, + params.duplicateGuard + ? `Do not call ${params.toolName} again for this request. 
Wait for the completion event; I will post the finished ${params.completionLabel} here.` + : `Wait for the completion event; I will post the finished ${params.completionLabel} here when it's ready.`, + ].filter((entry): entry is string => Boolean(entry)); + return lines.join("\n"); +} + +export function buildActiveMediaGenerationTaskPromptContextForSession(params: { + sessionKey?: string; + taskKind: string; + sourcePrefix: string; + nounLabel: string; + toolName: string; + completionLabel: string; +}): string | undefined { + const task = findActiveMediaGenerationTaskForSession({ + sessionKey: params.sessionKey, + taskKind: params.taskKind, + sourcePrefix: params.sourcePrefix, + }); + if (!task) { + return undefined; + } + const provider = getMediaGenerationTaskProviderId(task, params.sourcePrefix); + const lines = [ + `An active ${params.nounLabel.toLowerCase()} background task already exists for this session.`, + `Task ${task.taskId} is currently ${task.status}${provider ? ` via ${provider}` : ""}.`, + task.progressSummary ? 
`Current progress: ${task.progressSummary}.` : null, + `Do not call \`${params.toolName}\` again for the same request while that task is queued or running.`, + `If the user asks for progress or whether the work is async, explain the active task state or call \`${params.toolName}\` with \`action:"status"\` instead of starting a new generation.`, + `Only start a new \`${params.toolName}\` call if the user clearly asks for different/new ${params.completionLabel}.`, + ].filter((entry): entry is string => Boolean(entry)); + return lines.join("\n"); +} diff --git a/src/agents/music-generation-task-status.ts b/src/agents/music-generation-task-status.ts new file mode 100644 index 00000000000..c2481bcb91e --- /dev/null +++ b/src/agents/music-generation-task-status.ts @@ -0,0 +1,65 @@ +import type { TaskRecord } from "../tasks/task-registry.types.js"; +import { + buildActiveMediaGenerationTaskPromptContextForSession, + buildMediaGenerationTaskStatusDetails, + buildMediaGenerationTaskStatusText, + findActiveMediaGenerationTaskForSession, + getMediaGenerationTaskProviderId, + isActiveMediaGenerationTask, +} from "./media-generation-task-status-shared.js"; + +export const MUSIC_GENERATION_TASK_KIND = "music_generation"; +const MUSIC_GENERATION_SOURCE_PREFIX = "music_generate"; + +export function isActiveMusicGenerationTask(task: TaskRecord): boolean { + return isActiveMediaGenerationTask({ + task, + taskKind: MUSIC_GENERATION_TASK_KIND, + }); +} + +export function getMusicGenerationTaskProviderId(task: TaskRecord): string | undefined { + return getMediaGenerationTaskProviderId(task, MUSIC_GENERATION_SOURCE_PREFIX); +} + +export function findActiveMusicGenerationTaskForSession(sessionKey?: string): TaskRecord | null { + return findActiveMediaGenerationTaskForSession({ + sessionKey, + taskKind: MUSIC_GENERATION_TASK_KIND, + sourcePrefix: MUSIC_GENERATION_SOURCE_PREFIX, + }); +} + +export function buildMusicGenerationTaskStatusDetails(task: TaskRecord): Record { + return 
buildMediaGenerationTaskStatusDetails({ + task, + sourcePrefix: MUSIC_GENERATION_SOURCE_PREFIX, + }); +} + +export function buildMusicGenerationTaskStatusText( + task: TaskRecord, + params?: { duplicateGuard?: boolean }, +): string { + return buildMediaGenerationTaskStatusText({ + task, + sourcePrefix: MUSIC_GENERATION_SOURCE_PREFIX, + nounLabel: "Music generation", + toolName: "music_generate", + completionLabel: "music", + duplicateGuard: params?.duplicateGuard, + }); +} + +export function buildActiveMusicGenerationTaskPromptContextForSession( + sessionKey?: string, +): string | undefined { + return buildActiveMediaGenerationTaskPromptContextForSession({ + sessionKey, + taskKind: MUSIC_GENERATION_TASK_KIND, + sourcePrefix: MUSIC_GENERATION_SOURCE_PREFIX, + nounLabel: "Music generation", + toolName: "music_generate", + completionLabel: "music tracks", + }); +} diff --git a/src/agents/openclaw-tools.ts b/src/agents/openclaw-tools.ts index 2962cf15552..77f423e25a7 100644 --- a/src/agents/openclaw-tools.ts +++ b/src/agents/openclaw-tools.ts @@ -21,6 +21,7 @@ import { createGatewayTool } from "./tools/gateway-tool.js"; import { createImageGenerateTool } from "./tools/image-generate-tool.js"; import { createImageTool } from "./tools/image-tool.js"; import { createMessageTool } from "./tools/message-tool.js"; +import { createMusicGenerateTool } from "./tools/music-generate-tool.js"; import { createNodesTool } from "./tools/nodes-tool.js"; import { createPdfTool } from "./tools/pdf-tool.js"; import { createSessionStatusTool } from "./tools/session-status-tool.js"; @@ -170,6 +171,15 @@ export function createOpenClawTools( sandbox, fsPolicy: options?.fsPolicy, }); + const musicGenerateTool = createMusicGenerateTool({ + config: options?.config, + agentDir: options?.agentDir, + agentSessionKey: options?.agentSessionKey, + requesterOrigin: deliveryContext ?? undefined, + workspaceDir, + sandbox, + fsPolicy: options?.fsPolicy, + }); const pdfTool = options?.agentDir?.trim() ? 
createPdfTool({ config: options?.config, @@ -227,6 +237,7 @@ export function createOpenClawTools( config: options?.config, }), ...(imageGenerateTool ? [imageGenerateTool] : []), + ...(musicGenerateTool ? [musicGenerateTool] : []), ...(videoGenerateTool ? [videoGenerateTool] : []), createGatewayTool({ agentSessionKey: options?.agentSessionKey, diff --git a/src/agents/pi-embedded-runner/run/attempt.prompt-helpers.test.ts b/src/agents/pi-embedded-runner/run/attempt.prompt-helpers.test.ts index a11c6295c01..c6915703687 100644 --- a/src/agents/pi-embedded-runner/run/attempt.prompt-helpers.test.ts +++ b/src/agents/pi-embedded-runner/run/attempt.prompt-helpers.test.ts @@ -1,9 +1,14 @@ import { describe, expect, it, vi } from "vitest"; +const musicGenerationTaskStatusMocks = vi.hoisted(() => ({ + buildActiveMusicGenerationTaskPromptContextForSession: vi.fn(), +})); + const videoGenerationTaskStatusMocks = vi.hoisted(() => ({ buildActiveVideoGenerationTaskPromptContextForSession: vi.fn(), })); +vi.mock("../../music-generation-task-status.js", () => musicGenerationTaskStatusMocks); vi.mock("../../video-generation-task-status.js", () => videoGenerationTaskStatusMocks); import { resolveAttemptPrependSystemContext } from "./attempt.prompt-helpers.js"; @@ -13,6 +18,9 @@ describe("resolveAttemptPrependSystemContext", () => { videoGenerationTaskStatusMocks.buildActiveVideoGenerationTaskPromptContextForSession.mockReturnValue( "Active task hint", ); + musicGenerationTaskStatusMocks.buildActiveMusicGenerationTaskPromptContextForSession.mockReturnValue( + "Music task hint", + ); const result = resolveAttemptPrependSystemContext({ sessionKey: "agent:main:discord:direct:123", @@ -23,7 +31,10 @@ describe("resolveAttemptPrependSystemContext", () => { expect( videoGenerationTaskStatusMocks.buildActiveVideoGenerationTaskPromptContextForSession, ).toHaveBeenCalledWith("agent:main:discord:direct:123"); - expect(result).toBe("Active task hint\n\nHook system context"); + expect( + 
musicGenerationTaskStatusMocks.buildActiveMusicGenerationTaskPromptContextForSession, + ).toHaveBeenCalledWith("agent:main:discord:direct:123"); + expect(result).toBe("Active task hint\n\nMusic task hint\n\nHook system context"); }); it("skips active video task guidance for non-user triggers", () => { @@ -31,6 +42,10 @@ describe("resolveAttemptPrependSystemContext", () => { videoGenerationTaskStatusMocks.buildActiveVideoGenerationTaskPromptContextForSession.mockReturnValue( "Should not be used", ); + musicGenerationTaskStatusMocks.buildActiveMusicGenerationTaskPromptContextForSession.mockReset(); + musicGenerationTaskStatusMocks.buildActiveMusicGenerationTaskPromptContextForSession.mockReturnValue( + "Should not be used", + ); const result = resolveAttemptPrependSystemContext({ sessionKey: "agent:main:discord:direct:123", @@ -41,6 +56,9 @@ describe("resolveAttemptPrependSystemContext", () => { expect( videoGenerationTaskStatusMocks.buildActiveVideoGenerationTaskPromptContextForSession, ).not.toHaveBeenCalled(); + expect( + musicGenerationTaskStatusMocks.buildActiveMusicGenerationTaskPromptContextForSession, + ).not.toHaveBeenCalled(); expect(result).toBe("Hook system context"); }); }); diff --git a/src/agents/pi-embedded-runner/run/attempt.prompt-helpers.ts b/src/agents/pi-embedded-runner/run/attempt.prompt-helpers.ts index 14a636fb536..50b34bc07ad 100644 --- a/src/agents/pi-embedded-runner/run/attempt.prompt-helpers.ts +++ b/src/agents/pi-embedded-runner/run/attempt.prompt-helpers.ts @@ -6,6 +6,7 @@ import type { } from "../../../plugins/types.js"; import { isCronSessionKey, isSubagentSessionKey } from "../../../routing/session-key.js"; import { joinPresentTextSegments } from "../../../shared/text/join-segments.js"; +import { buildActiveMusicGenerationTaskPromptContextForSession } from "../../music-generation-task-status.js"; import { prependSystemPromptAdditionAfterCacheBoundary } from "../../system-prompt-cache-boundary.js"; import { 
resolveEffectiveToolFsWorkspaceOnly } from "../../tool-fs-policy.js"; import { buildActiveVideoGenerationTaskPromptContextForSession } from "../../video-generation-task-status.js"; @@ -125,11 +126,17 @@ export function resolveAttemptPrependSystemContext(params: { trigger?: EmbeddedRunAttemptParams["trigger"]; hookPrependSystemContext?: string; }): string | undefined { - const activeVideoTaskPromptContext = + const activeMediaTaskPromptContexts = params.trigger === "user" || params.trigger === "manual" - ? buildActiveVideoGenerationTaskPromptContextForSession(params.sessionKey) - : undefined; - return joinPresentTextSegments([activeVideoTaskPromptContext, params.hookPrependSystemContext]); + ? [ + buildActiveVideoGenerationTaskPromptContextForSession(params.sessionKey), + buildActiveMusicGenerationTaskPromptContextForSession(params.sessionKey), + ] + : []; + return joinPresentTextSegments([ + ...activeMediaTaskPromptContexts, + params.hookPrependSystemContext, + ]); } /** Build runtime context passed into context-engine afterTurn hooks. 
*/ diff --git a/src/agents/pi-embedded-subscribe.tools.media.test.ts b/src/agents/pi-embedded-subscribe.tools.media.test.ts index 55ac9d7fc13..1d67a9048fb 100644 --- a/src/agents/pi-embedded-subscribe.tools.media.test.ts +++ b/src/agents/pi-embedded-subscribe.tools.media.test.ts @@ -265,6 +265,10 @@ describe("extractToolResultMediaPaths", () => { expect(isToolResultMediaTrusted("image_generate")).toBe(true); }); + it("trusts music_generate local MEDIA paths", () => { + expect(isToolResultMediaTrusted("music_generate")).toBe(true); + }); + it("trusts video_generate local MEDIA paths", () => { expect(isToolResultMediaTrusted("video_generate")).toBe(true); }); diff --git a/src/agents/pi-embedded-subscribe.tools.ts b/src/agents/pi-embedded-subscribe.tools.ts index 0349d17c179..6171dd97596 100644 --- a/src/agents/pi-embedded-subscribe.tools.ts +++ b/src/agents/pi-embedded-subscribe.tools.ts @@ -147,6 +147,7 @@ const TRUSTED_TOOL_RESULT_MEDIA = new Set([ "memory_get", "memory_search", "message", + "music_generate", "nodes", "process", "read", diff --git a/src/agents/tool-catalog.test.ts b/src/agents/tool-catalog.test.ts index 216b0fcab84..b959c1abb19 100644 --- a/src/agents/tool-catalog.test.ts +++ b/src/agents/tool-catalog.test.ts @@ -10,6 +10,7 @@ describe("tool-catalog", () => { expect(policy!.allow).toContain("x_search"); expect(policy!.allow).toContain("web_fetch"); expect(policy!.allow).toContain("image_generate"); + expect(policy!.allow).toContain("music_generate"); expect(policy!.allow).toContain("video_generate"); expect(policy!.allow).toContain("update_plan"); }); diff --git a/src/agents/tool-catalog.ts b/src/agents/tool-catalog.ts index c28249fd584..6b634c56888 100644 --- a/src/agents/tool-catalog.ts +++ b/src/agents/tool-catalog.ts @@ -277,6 +277,14 @@ const CORE_TOOL_DEFINITIONS: CoreToolDefinition[] = [ profiles: ["coding"], includeInOpenClawGroup: true, }, + { + id: "music_generate", + label: "music_generate", + description: "Music generation", + 
sectionId: "media", + profiles: ["coding"], + includeInOpenClawGroup: true, + }, { id: "video_generate", label: "video_generate", diff --git a/src/agents/tool-display-config.ts b/src/agents/tool-display-config.ts index f5dcd69f6a2..a8e85ac526d 100644 --- a/src/agents/tool-display-config.ts +++ b/src/agents/tool-display-config.ts @@ -640,6 +640,20 @@ export const TOOL_DISPLAY_CONFIG: ToolDisplayConfig = { }, }, }, + music_generate: { + emoji: "🎵", + title: "Music Generation", + actions: { + generate: { + label: "generate", + detailKeys: ["prompt", "model", "durationSeconds", "format", "instrumental"], + }, + list: { + label: "list", + detailKeys: ["provider", "model"], + }, + }, + }, video_generate: { emoji: "🎬", title: "Video Generation", diff --git a/src/agents/tools/media-generate-background-shared.ts b/src/agents/tools/media-generate-background-shared.ts new file mode 100644 index 00000000000..7438694eac6 --- /dev/null +++ b/src/agents/tools/media-generate-background-shared.ts @@ -0,0 +1,223 @@ +import crypto from "node:crypto"; +import { createSubsystemLogger } from "../../logging/subsystem.js"; +import { + completeTaskRunByRunId, + createRunningTaskRun, + failTaskRunByRunId, + recordTaskRunProgressByRunId, +} from "../../tasks/task-executor.js"; +import type { DeliveryContext } from "../../utils/delivery-context.js"; +import { INTERNAL_MESSAGE_CHANNEL } from "../../utils/message-channel.js"; +import { formatAgentInternalEventsForPrompt, type AgentInternalEvent } from "../internal-events.js"; +import { deliverSubagentAnnouncement } from "../subagent-announce-delivery.js"; + +const log = createSubsystemLogger("agents/tools/media-generate-background-shared"); + +export type MediaGenerationTaskHandle = { + taskId: string; + runId: string; + requesterSessionKey: string; + requesterOrigin?: DeliveryContext; + taskLabel: string; +}; + +export function createMediaGenerationTaskRun(params: { + sessionKey?: string; + requesterOrigin?: DeliveryContext; + prompt: string; 
+ providerId?: string; + toolName: string; + taskKind: string; + label: string; + queuedProgressSummary: string; +}): MediaGenerationTaskHandle | null { + const sessionKey = params.sessionKey?.trim(); + if (!sessionKey) { + return null; + } + const runId = `tool:${params.toolName}:${crypto.randomUUID()}`; + try { + const task = createRunningTaskRun({ + runtime: "cli", + taskKind: params.taskKind, + sourceId: params.providerId ? `${params.toolName}:${params.providerId}` : params.toolName, + requesterSessionKey: sessionKey, + ownerKey: sessionKey, + scopeKind: "session", + requesterOrigin: params.requesterOrigin, + childSessionKey: sessionKey, + runId, + label: params.label, + task: params.prompt, + deliveryStatus: "not_applicable", + notifyPolicy: "silent", + startedAt: Date.now(), + lastEventAt: Date.now(), + progressSummary: params.queuedProgressSummary, + }); + return { + taskId: task.taskId, + runId, + requesterSessionKey: sessionKey, + requesterOrigin: params.requesterOrigin, + taskLabel: params.prompt, + }; + } catch (error) { + log.warn("Failed to create media generation task ledger record", { + sessionKey, + toolName: params.toolName, + providerId: params.providerId, + error, + }); + return null; + } +} + +export function recordMediaGenerationTaskProgress(params: { + handle: MediaGenerationTaskHandle | null; + progressSummary: string; + eventSummary?: string; +}) { + if (!params.handle) { + return; + } + recordTaskRunProgressByRunId({ + runId: params.handle.runId, + runtime: "cli", + sessionKey: params.handle.requesterSessionKey, + lastEventAt: Date.now(), + progressSummary: params.progressSummary, + eventSummary: params.eventSummary, + }); +} + +export function completeMediaGenerationTaskRun(params: { + handle: MediaGenerationTaskHandle | null; + provider: string; + model: string; + count: number; + paths: string[]; + generatedLabel: string; +}) { + if (!params.handle) { + return; + } + const endedAt = Date.now(); + const target = params.count === 1 ? 
params.paths[0] : `${params.count} files`; + completeTaskRunByRunId({ + runId: params.handle.runId, + runtime: "cli", + sessionKey: params.handle.requesterSessionKey, + endedAt, + lastEventAt: endedAt, + progressSummary: `Generated ${params.count} ${params.generatedLabel}${params.count === 1 ? "" : "s"}`, + terminalSummary: `Generated ${params.count} ${params.generatedLabel}${params.count === 1 ? "" : "s"} with ${params.provider}/${params.model}${target ? ` -> ${target}` : ""}.`, + }); +} + +export function failMediaGenerationTaskRun(params: { + handle: MediaGenerationTaskHandle | null; + error: unknown; + progressSummary: string; +}) { + if (!params.handle) { + return; + } + const endedAt = Date.now(); + const errorText = params.error instanceof Error ? params.error.message : String(params.error); + failTaskRunByRunId({ + runId: params.handle.runId, + runtime: "cli", + sessionKey: params.handle.requesterSessionKey, + endedAt, + lastEventAt: endedAt, + error: errorText, + progressSummary: params.progressSummary, + terminalSummary: errorText, + }); +} + +function buildMediaGenerationReplyInstruction(params: { + status: "ok" | "error"; + completionLabel: string; +}) { + if (params.status === "ok") { + return [ + `A completed ${params.completionLabel} generation task is ready for user delivery.`, + `Reply in your normal assistant voice and post the finished ${params.completionLabel} to the original message channel now.`, + "If the result includes MEDIA: lines, include those exact MEDIA: lines in your reply so OpenClaw attaches the generated media.", + "Keep internal task/session details private and do not copy the internal event text verbatim.", + ].join(" "); + } + return [ + `${params.completionLabel[0]?.toUpperCase() ?? 
"T"}${params.completionLabel.slice(1)} generation task failed.`, + "Reply in your normal assistant voice with the failure summary now.", + "Keep internal task/session details private and do not copy the internal event text verbatim.", + ].join(" "); +} + +export async function wakeMediaGenerationTaskCompletion(params: { + handle: MediaGenerationTaskHandle | null; + status: "ok" | "error"; + statusLabel: string; + result: string; + statsLine?: string; + eventSource: AgentInternalEvent["source"]; + announceType: string; + toolName: string; + completionLabel: string; +}) { + if (!params.handle) { + return; + } + const internalEvents: AgentInternalEvent[] = [ + { + type: "task_completion", + source: params.eventSource, + childSessionKey: `${params.toolName}:${params.handle.taskId}`, + childSessionId: params.handle.taskId, + announceType: params.announceType, + taskLabel: params.handle.taskLabel, + status: params.status, + statusLabel: params.statusLabel, + result: params.result, + ...(params.statsLine?.trim() ? { statsLine: params.statsLine } : {}), + replyInstruction: buildMediaGenerationReplyInstruction({ + status: params.status, + completionLabel: params.completionLabel, + }), + }, + ]; + const triggerMessage = + formatAgentInternalEventsForPrompt(internalEvents) || + `A ${params.completionLabel} generation task finished. 
Process the completion update now.`; + const announceId = `${params.toolName}:${params.handle.taskId}:${params.status}`; + const delivery = await deliverSubagentAnnouncement({ + requesterSessionKey: params.handle.requesterSessionKey, + targetRequesterSessionKey: params.handle.requesterSessionKey, + announceId, + triggerMessage, + steerMessage: triggerMessage, + internalEvents, + summaryLine: params.handle.taskLabel, + requesterSessionOrigin: params.handle.requesterOrigin, + requesterOrigin: params.handle.requesterOrigin, + completionDirectOrigin: params.handle.requesterOrigin, + directOrigin: params.handle.requesterOrigin, + sourceSessionKey: `${params.toolName}:${params.handle.taskId}`, + sourceChannel: INTERNAL_MESSAGE_CHANNEL, + sourceTool: params.toolName, + requesterIsSubagent: false, + expectsCompletionMessage: true, + bestEffortDeliver: true, + directIdempotencyKey: announceId, + }); + if (!delivery.delivered && delivery.error) { + log.warn("Media generation completion wake failed", { + taskId: params.handle.taskId, + runId: params.handle.runId, + toolName: params.toolName, + error: delivery.error, + }); + } +} diff --git a/src/agents/tools/media-tool-shared.ts b/src/agents/tools/media-tool-shared.ts index 4e384380442..2449d08b836 100644 --- a/src/agents/tools/media-tool-shared.ts +++ b/src/agents/tools/media-tool-shared.ts @@ -1,8 +1,17 @@ import { type Api, type Model } from "@mariozechner/pi-ai"; import type { OpenClawConfig } from "../../config/config.js"; +import type { AgentModelConfig } from "../../config/types.agents-shared.js"; import { getDefaultLocalRoots } from "../../media/web-media.js"; +import { normalizeProviderId } from "../provider-id.js"; import type { ImageModelConfig } from "./image-tool.helpers.js"; -import type { ToolModelConfig } from "./model-config.helpers.js"; +import { + buildToolModelConfigFromCandidates, + coerceToolModelConfig, + hasAuthForProvider, + hasToolModelConfig, + resolveDefaultModelRef, + type ToolModelConfig, +} from 
"./model-config.helpers.js"; import { getApiKeyForModel, normalizeWorkspaceDir, requireApiKey } from "./tool-runtime.helpers.js"; type TextToolAttempt = { @@ -39,9 +48,16 @@ export function applyVideoGenerationModelConfigDefaults( return applyAgentDefaultModelConfig(cfg, "videoGenerationModel", videoGenerationModelConfig); } +export function applyMusicGenerationModelConfigDefaults( + cfg: OpenClawConfig | undefined, + musicGenerationModelConfig: ToolModelConfig, +): OpenClawConfig | undefined { + return applyAgentDefaultModelConfig(cfg, "musicGenerationModel", musicGenerationModelConfig); +} + function applyAgentDefaultModelConfig( cfg: OpenClawConfig | undefined, - key: "imageModel" | "imageGenerationModel" | "videoGenerationModel", + key: "imageModel" | "imageGenerationModel" | "videoGenerationModel" | "musicGenerationModel", modelConfig: ToolModelConfig, ): OpenClawConfig | undefined { if (!cfg) { @@ -59,6 +75,125 @@ function applyAgentDefaultModelConfig( }; } +type CapabilityProvider = { + id: string; + aliases?: string[]; + defaultModel?: string; + isConfigured?: (ctx: { cfg?: OpenClawConfig; agentDir?: string }) => boolean; +}; + +export function findCapabilityProviderById(params: { + providers: T[]; + providerId?: string; +}): T | undefined { + const selectedProvider = normalizeProviderId(params.providerId ?? ""); + return params.providers.find( + (provider) => + normalizeProviderId(provider.id) === selectedProvider || + (provider.aliases ?? []).some((alias) => normalizeProviderId(alias) === selectedProvider), + ); +} + +export function isCapabilityProviderConfigured(params: { + providers: T[]; + provider?: T; + providerId?: string; + cfg?: OpenClawConfig; + agentDir?: string; +}): boolean { + const provider = + params.provider ?? + findCapabilityProviderById({ + providers: params.providers, + providerId: params.providerId, + }); + if (!provider) { + return params.providerId + ? 
hasAuthForProvider({ provider: params.providerId, agentDir: params.agentDir }) + : false; + } + if (provider.isConfigured) { + return provider.isConfigured({ + cfg: params.cfg, + agentDir: params.agentDir, + }); + } + return hasAuthForProvider({ provider: provider.id, agentDir: params.agentDir }); +} + +export function resolveCapabilityModelCandidatesForTool(params: { + cfg?: OpenClawConfig; + agentDir?: string; + providers: T[]; +}): string[] { + const providerDefaults = new Map(); + for (const provider of params.providers) { + const providerId = provider.id.trim(); + const modelId = provider.defaultModel?.trim(); + if ( + !providerId || + !modelId || + providerDefaults.has(providerId) || + !isCapabilityProviderConfigured({ + providers: params.providers, + provider, + cfg: params.cfg, + agentDir: params.agentDir, + }) + ) { + continue; + } + providerDefaults.set(providerId, `${providerId}/${modelId}`); + } + + const primaryProvider = resolveDefaultModelRef(params.cfg).provider; + const orderedProviders = [ + primaryProvider, + ...[...providerDefaults.keys()] + .filter((providerId) => providerId !== primaryProvider) + .toSorted(), + ]; + const orderedRefs: string[] = []; + const seen = new Set(); + for (const providerId of orderedProviders) { + const ref = providerDefaults.get(providerId); + if (!ref || seen.has(ref)) { + continue; + } + seen.add(ref); + orderedRefs.push(ref); + } + return orderedRefs; +} + +export function resolveCapabilityModelConfigForTool(params: { + cfg?: OpenClawConfig; + agentDir?: string; + modelConfig?: AgentModelConfig; + providers: T[]; +}): ToolModelConfig | null { + const explicit = coerceToolModelConfig(params.modelConfig); + if (hasToolModelConfig(explicit)) { + return explicit; + } + return buildToolModelConfigFromCandidates({ + explicit, + agentDir: params.agentDir, + candidates: resolveCapabilityModelCandidatesForTool({ + cfg: params.cfg, + agentDir: params.agentDir, + providers: params.providers, + }), + isProviderConfigured: 
(providerId) => + isCapabilityProviderConfigured({ + providers: params.providers, + providerId, + cfg: params.cfg, + agentDir: params.agentDir, + }), + }); +} + export function resolveMediaToolLocalRoots( workspaceDirRaw: string | undefined, options?: { workspaceOnly?: boolean }, diff --git a/src/agents/tools/music-generate-background.test.ts b/src/agents/tools/music-generate-background.test.ts new file mode 100644 index 00000000000..b75a1406846 --- /dev/null +++ b/src/agents/tools/music-generate-background.test.ts @@ -0,0 +1,121 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { MUSIC_GENERATION_TASK_KIND } from "../music-generation-task-status.js"; +import { + createMusicGenerationTaskRun, + recordMusicGenerationTaskProgress, + wakeMusicGenerationTaskCompletion, +} from "./music-generate-background.js"; + +const taskExecutorMocks = vi.hoisted(() => ({ + createRunningTaskRun: vi.fn(), + recordTaskRunProgressByRunId: vi.fn(), + completeTaskRunByRunId: vi.fn(), + failTaskRunByRunId: vi.fn(), +})); + +const announceDeliveryMocks = vi.hoisted(() => ({ + deliverSubagentAnnouncement: vi.fn(), +})); + +vi.mock("../../tasks/task-executor.js", () => taskExecutorMocks); +vi.mock("../subagent-announce-delivery.js", () => announceDeliveryMocks); + +describe("music generate background helpers", () => { + beforeEach(() => { + taskExecutorMocks.createRunningTaskRun.mockReset(); + taskExecutorMocks.recordTaskRunProgressByRunId.mockReset(); + announceDeliveryMocks.deliverSubagentAnnouncement.mockReset(); + }); + + it("creates a running task with queued progress text", () => { + taskExecutorMocks.createRunningTaskRun.mockReturnValue({ + taskId: "task-123", + }); + + const handle = createMusicGenerationTaskRun({ + sessionKey: "agent:main:discord:direct:123", + requesterOrigin: { + channel: "discord", + to: "channel:1", + }, + prompt: "night-drive synthwave", + providerId: "google", + }); + + expect(handle).toMatchObject({ + taskId: "task-123", + 
requesterSessionKey: "agent:main:discord:direct:123", + taskLabel: "night-drive synthwave", + }); + expect(taskExecutorMocks.createRunningTaskRun).toHaveBeenCalledWith( + expect.objectContaining({ + taskKind: MUSIC_GENERATION_TASK_KIND, + sourceId: "music_generate:google", + progressSummary: "Queued music generation", + }), + ); + }); + + it("records task progress updates", () => { + recordMusicGenerationTaskProgress({ + handle: { + taskId: "task-123", + runId: "tool:music_generate:abc", + requesterSessionKey: "agent:main:discord:direct:123", + taskLabel: "night-drive synthwave", + }, + progressSummary: "Saving generated music", + }); + + expect(taskExecutorMocks.recordTaskRunProgressByRunId).toHaveBeenCalledWith( + expect.objectContaining({ + runId: "tool:music_generate:abc", + progressSummary: "Saving generated music", + }), + ); + }); + + it("wakes the session with a music-generation completion event", async () => { + announceDeliveryMocks.deliverSubagentAnnouncement.mockResolvedValue({ + delivered: true, + path: "direct", + }); + + await wakeMusicGenerationTaskCompletion({ + handle: { + taskId: "task-123", + runId: "tool:music_generate:abc", + requesterSessionKey: "agent:main:discord:direct:123", + requesterOrigin: { + channel: "discord", + to: "channel:1", + threadId: "thread-1", + }, + taskLabel: "night-drive synthwave", + }, + status: "ok", + statusLabel: "completed successfully", + result: "Generated 1 track.\nMEDIA:/tmp/generated-night-drive.mp3", + }); + + expect(announceDeliveryMocks.deliverSubagentAnnouncement).toHaveBeenCalledWith( + expect.objectContaining({ + requesterSessionKey: "agent:main:discord:direct:123", + requesterOrigin: expect.objectContaining({ + channel: "discord", + to: "channel:1", + }), + expectsCompletionMessage: true, + internalEvents: [ + expect.objectContaining({ + source: "music_generation", + announceType: "music generation task", + status: "ok", + result: expect.stringContaining("MEDIA:/tmp/generated-night-drive.mp3"), + 
replyInstruction: expect.stringContaining("include those exact MEDIA: lines"), + }), + ], + }), + ); + }); +}); diff --git a/src/agents/tools/music-generate-background.ts b/src/agents/tools/music-generate-background.ts new file mode 100644 index 00000000000..b4508665ad5 --- /dev/null +++ b/src/agents/tools/music-generate-background.ts @@ -0,0 +1,81 @@ +import type { DeliveryContext } from "../../utils/delivery-context.js"; +import { MUSIC_GENERATION_TASK_KIND } from "../music-generation-task-status.js"; +import { + completeMediaGenerationTaskRun, + createMediaGenerationTaskRun, + failMediaGenerationTaskRun, + recordMediaGenerationTaskProgress, + wakeMediaGenerationTaskCompletion, + type MediaGenerationTaskHandle, +} from "./media-generate-background-shared.js"; + +export type MusicGenerationTaskHandle = MediaGenerationTaskHandle; + +export function createMusicGenerationTaskRun(params: { + sessionKey?: string; + requesterOrigin?: DeliveryContext; + prompt: string; + providerId?: string; +}): MusicGenerationTaskHandle | null { + return createMediaGenerationTaskRun({ + sessionKey: params.sessionKey, + requesterOrigin: params.requesterOrigin, + prompt: params.prompt, + providerId: params.providerId, + toolName: "music_generate", + taskKind: MUSIC_GENERATION_TASK_KIND, + label: "Music generation", + queuedProgressSummary: "Queued music generation", + }); +} + +export function recordMusicGenerationTaskProgress(params: { + handle: MusicGenerationTaskHandle | null; + progressSummary: string; + eventSummary?: string; +}) { + recordMediaGenerationTaskProgress(params); +} + +export function completeMusicGenerationTaskRun(params: { + handle: MusicGenerationTaskHandle | null; + provider: string; + model: string; + count: number; + paths: string[]; +}) { + completeMediaGenerationTaskRun({ + ...params, + generatedLabel: "track", + }); +} + +export function failMusicGenerationTaskRun(params: { + handle: MusicGenerationTaskHandle | null; + error: unknown; +}) { + 
failMediaGenerationTaskRun({ + ...params, + progressSummary: "Music generation failed", + }); +} + +export async function wakeMusicGenerationTaskCompletion(params: { + handle: MusicGenerationTaskHandle | null; + status: "ok" | "error"; + statusLabel: string; + result: string; + statsLine?: string; +}) { + await wakeMediaGenerationTaskCompletion({ + handle: params.handle, + status: params.status, + statusLabel: params.statusLabel, + result: params.result, + statsLine: params.statsLine, + eventSource: "music_generation", + announceType: "music generation task", + toolName: "music_generate", + completionLabel: "music", + }); +} diff --git a/src/agents/tools/music-generate-tool.actions.ts b/src/agents/tools/music-generate-tool.actions.ts new file mode 100644 index 00000000000..5c477ef5c00 --- /dev/null +++ b/src/agents/tools/music-generate-tool.actions.ts @@ -0,0 +1,130 @@ +import type { OpenClawConfig } from "../../config/config.js"; +import { listRuntimeMusicGenerationProviders } from "../../music-generation/runtime.js"; +import { getProviderEnvVars } from "../../secrets/provider-env-vars.js"; +import { + buildMusicGenerationTaskStatusDetails, + buildMusicGenerationTaskStatusText, + findActiveMusicGenerationTaskForSession, +} from "../music-generation-task-status.js"; + +type MusicGenerateActionResult = { + content: Array<{ type: "text"; text: string }>; + details: Record; +}; + +function getMusicGenerationProviderAuthEnvVars(providerId: string): string[] { + return getProviderEnvVars(providerId); +} + +export function createMusicGenerateListActionResult( + config?: OpenClawConfig, +): MusicGenerateActionResult { + const providers = listRuntimeMusicGenerationProviders({ config }); + if (providers.length === 0) { + return { + content: [{ type: "text", text: "No music-generation providers are registered." 
}], + details: { providers: [] }, + }; + } + const lines = providers.map((provider) => { + const authHints = getMusicGenerationProviderAuthEnvVars(provider.id); + const capabilities = [ + provider.capabilities.maxTracks ? `maxTracks=${provider.capabilities.maxTracks}` : null, + provider.capabilities.maxInputImages + ? `maxInputImages=${provider.capabilities.maxInputImages}` + : null, + provider.capabilities.maxDurationSeconds + ? `maxDurationSeconds=${provider.capabilities.maxDurationSeconds}` + : null, + provider.capabilities.supportsLyrics ? "lyrics" : null, + provider.capabilities.supportsInstrumental ? "instrumental" : null, + provider.capabilities.supportsDuration ? "duration" : null, + provider.capabilities.supportsFormat ? "format" : null, + provider.capabilities.supportedFormats?.length + ? `supportedFormats=${provider.capabilities.supportedFormats.join("/")}` + : null, + provider.capabilities.supportedFormatsByModel && + Object.keys(provider.capabilities.supportedFormatsByModel).length > 0 + ? `supportedFormatsByModel=${Object.entries(provider.capabilities.supportedFormatsByModel) + .map(([modelId, formats]) => `${modelId}:${formats.join("/")}`) + .join("; ")}` + : null, + ] + .filter((entry): entry is string => Boolean(entry)) + .join(", "); + return [ + `${provider.id}: default=${provider.defaultModel ?? "none"}`, + provider.models?.length ? `models=${provider.models.join(", ")}` : null, + capabilities ? `capabilities=${capabilities}` : null, + authHints.length > 0 ? `auth=${authHints.join(" / ")}` : null, + ] + .filter((entry): entry is string => Boolean(entry)) + .join(" | "); + }); + return { + content: [{ type: "text", text: lines.join("\n") }], + details: { + providers: providers.map((provider) => ({ + id: provider.id, + defaultModel: provider.defaultModel, + models: provider.models ?? 
[], + authEnvVars: getMusicGenerationProviderAuthEnvVars(provider.id), + capabilities: provider.capabilities, + })), + }, + }; +} + +export function createMusicGenerateStatusActionResult( + sessionKey?: string, +): MusicGenerateActionResult { + const activeTask = findActiveMusicGenerationTaskForSession(sessionKey); + if (!activeTask) { + return { + content: [ + { + type: "text", + text: "No active music generation task is currently running for this session.", + }, + ], + details: { + action: "status", + active: false, + }, + }; + } + return { + content: [ + { + type: "text", + text: buildMusicGenerationTaskStatusText(activeTask), + }, + ], + details: { + action: "status", + ...buildMusicGenerationTaskStatusDetails(activeTask), + }, + }; +} + +export function createMusicGenerateDuplicateGuardResult( + sessionKey?: string, +): MusicGenerateActionResult | null { + const activeTask = findActiveMusicGenerationTaskForSession(sessionKey); + if (!activeTask) { + return null; + } + return { + content: [ + { + type: "text", + text: buildMusicGenerationTaskStatusText(activeTask, { duplicateGuard: true }), + }, + ], + details: { + action: "status", + duplicateGuard: true, + ...buildMusicGenerationTaskStatusDetails(activeTask), + }, + }; +} diff --git a/src/agents/tools/music-generate-tool.status.test.ts b/src/agents/tools/music-generate-tool.status.test.ts new file mode 100644 index 00000000000..b201af9e118 --- /dev/null +++ b/src/agents/tools/music-generate-tool.status.test.ts @@ -0,0 +1,106 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import * as musicGenerationRuntime from "../../music-generation/runtime.js"; +import { MUSIC_GENERATION_TASK_KIND } from "../music-generation-task-status.js"; +import { + createMusicGenerateDuplicateGuardResult, + createMusicGenerateStatusActionResult, +} from "./music-generate-tool.actions.js"; + +const taskRuntimeInternalMocks = vi.hoisted(() => ({ + listTasksForOwnerKey: vi.fn(), +})); + 
+vi.mock("../../tasks/runtime-internal.js", () => taskRuntimeInternalMocks); + +describe("createMusicGenerateTool status actions", () => { + beforeEach(() => { + vi.restoreAllMocks(); + vi.spyOn(musicGenerationRuntime, "listRuntimeMusicGenerationProviders").mockReturnValue([]); + taskRuntimeInternalMocks.listTasksForOwnerKey.mockReset(); + taskRuntimeInternalMocks.listTasksForOwnerKey.mockReturnValue([]); + }); + + afterEach(() => { + vi.unstubAllEnvs(); + }); + + it("returns active task status instead of starting a duplicate generation", async () => { + taskRuntimeInternalMocks.listTasksForOwnerKey.mockReturnValue([ + { + taskId: "task-active", + runtime: "cli", + taskKind: MUSIC_GENERATION_TASK_KIND, + sourceId: "music_generate:google", + requesterSessionKey: "agent:main:discord:direct:123", + ownerKey: "agent:main:discord:direct:123", + scopeKind: "session", + runId: "tool:music_generate:active", + task: "night-drive synthwave", + status: "running", + deliveryStatus: "not_applicable", + notifyPolicy: "silent", + createdAt: Date.now(), + progressSummary: "Generating music", + }, + ]); + + const result = createMusicGenerateDuplicateGuardResult("agent:main:discord:direct:123"); + const text = (result?.content?.[0] as { text: string } | undefined)?.text ?? 
""; + + expect(result).not.toBeNull(); + expect(text).toContain("Music generation task task-active is already running with google."); + expect(text).toContain("Do not call music_generate again for this request."); + expect(result?.details).toMatchObject({ + action: "status", + duplicateGuard: true, + active: true, + existingTask: true, + status: "running", + taskKind: MUSIC_GENERATION_TASK_KIND, + provider: "google", + task: { + taskId: "task-active", + runId: "tool:music_generate:active", + }, + progressSummary: "Generating music", + }); + }); + + it("reports active task status when action=status is requested", async () => { + taskRuntimeInternalMocks.listTasksForOwnerKey.mockReturnValue([ + { + taskId: "task-active", + runtime: "cli", + taskKind: MUSIC_GENERATION_TASK_KIND, + sourceId: "music_generate:minimax", + requesterSessionKey: "agent:main:discord:direct:123", + ownerKey: "agent:main:discord:direct:123", + scopeKind: "session", + runId: "tool:music_generate:active", + task: "night-drive synthwave", + status: "queued", + deliveryStatus: "not_applicable", + notifyPolicy: "silent", + createdAt: Date.now(), + progressSummary: "Queued music generation", + }, + ]); + + const result = createMusicGenerateStatusActionResult("agent:main:discord:direct:123"); + const text = (result.content?.[0] as { text: string } | undefined)?.text ?? 
""; + + expect(text).toContain("Music generation task task-active is already queued with minimax."); + expect(result.details).toMatchObject({ + action: "status", + active: true, + existingTask: true, + status: "queued", + taskKind: MUSIC_GENERATION_TASK_KIND, + provider: "minimax", + task: { + taskId: "task-active", + }, + progressSummary: "Queued music generation", + }); + }); +}); diff --git a/src/agents/tools/music-generate-tool.test.ts b/src/agents/tools/music-generate-tool.test.ts new file mode 100644 index 00000000000..710db3fe3cf --- /dev/null +++ b/src/agents/tools/music-generate-tool.test.ts @@ -0,0 +1,273 @@ +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import type { OpenClawConfig } from "../../config/config.js"; +import * as mediaStore from "../../media/store.js"; +import * as musicGenerationRuntime from "../../music-generation/runtime.js"; +import * as musicGenerateBackground from "./music-generate-background.js"; +import { createMusicGenerateTool } from "./music-generate-tool.js"; + +const taskRuntimeInternalMocks = vi.hoisted(() => ({ + listTasksForOwnerKey: vi.fn(), +})); + +const taskExecutorMocks = vi.hoisted(() => ({ + createRunningTaskRun: vi.fn(), + completeTaskRunByRunId: vi.fn(), + failTaskRunByRunId: vi.fn(), + recordTaskRunProgressByRunId: vi.fn(), +})); + +vi.mock("../../tasks/runtime-internal.js", () => taskRuntimeInternalMocks); +vi.mock("../../tasks/task-executor.js", () => taskExecutorMocks); + +function asConfig(value: unknown): OpenClawConfig { + return value as OpenClawConfig; +} + +describe("createMusicGenerateTool", () => { + beforeEach(() => { + vi.restoreAllMocks(); + vi.spyOn(musicGenerationRuntime, "listRuntimeMusicGenerationProviders").mockReturnValue([]); + taskRuntimeInternalMocks.listTasksForOwnerKey.mockReset(); + taskRuntimeInternalMocks.listTasksForOwnerKey.mockReturnValue([]); + taskExecutorMocks.createRunningTaskRun.mockReset(); + taskExecutorMocks.completeTaskRunByRunId.mockReset(); + 
taskExecutorMocks.failTaskRunByRunId.mockReset(); + taskExecutorMocks.recordTaskRunProgressByRunId.mockReset(); + }); + + afterEach(() => { + vi.unstubAllEnvs(); + }); + + it("returns null when no music-generation config or auth-backed provider is available", () => { + vi.spyOn(musicGenerationRuntime, "listRuntimeMusicGenerationProviders").mockReturnValue([]); + expect(createMusicGenerateTool({ config: asConfig({}) })).toBeNull(); + }); + + it("registers when music-generation config is present", () => { + expect( + createMusicGenerateTool({ + config: asConfig({ + agents: { + defaults: { + musicGenerationModel: { primary: "google/lyria-3-clip-preview" }, + }, + }, + }), + }), + ).not.toBeNull(); + }); + + it("generates tracks, saves them, and emits MEDIA paths without a session-backed detach", async () => { + taskExecutorMocks.createRunningTaskRun.mockReturnValue({ + taskId: "task-123", + runtime: "cli", + requesterSessionKey: "agent:main:discord:direct:123", + ownerKey: "agent:main:discord:direct:123", + scopeKind: "session", + task: "night-drive synthwave", + status: "running", + deliveryStatus: "not_applicable", + notifyPolicy: "silent", + createdAt: Date.now(), + }); + vi.spyOn(musicGenerationRuntime, "generateMusic").mockResolvedValue({ + provider: "google", + model: "lyria-3-clip-preview", + attempts: [], + tracks: [ + { + buffer: Buffer.from("music-bytes"), + mimeType: "audio/mpeg", + fileName: "night-drive.mp3", + }, + ], + lyrics: ["wake the city up"], + metadata: { taskId: "music-task-1" }, + }); + vi.spyOn(mediaStore, "saveMediaBuffer").mockResolvedValueOnce({ + path: "/tmp/generated-night-drive.mp3", + id: "generated-night-drive.mp3", + size: 11, + contentType: "audio/mpeg", + }); + + const tool = createMusicGenerateTool({ + config: asConfig({ + agents: { + defaults: { + musicGenerationModel: { primary: "google/lyria-3-clip-preview" }, + }, + }, + }), + }); + expect(tool).not.toBeNull(); + if (!tool) { + throw new Error("expected music_generate tool"); + 
} + + const result = await tool.execute("call-1", { + prompt: "night-drive synthwave", + instrumental: true, + }); + const text = (result.content?.[0] as { text: string } | undefined)?.text ?? ""; + + expect(text).toContain("Generated 1 track with google/lyria-3-clip-preview."); + expect(text).toContain("Lyrics returned."); + expect(text).toContain("MEDIA:/tmp/generated-night-drive.mp3"); + expect(result.details).toMatchObject({ + provider: "google", + model: "lyria-3-clip-preview", + count: 1, + instrumental: true, + lyrics: ["wake the city up"], + task: { + taskId: "task-123", + }, + media: { + mediaUrls: ["/tmp/generated-night-drive.mp3"], + }, + paths: ["/tmp/generated-night-drive.mp3"], + metadata: { taskId: "music-task-1" }, + }); + expect(taskExecutorMocks.createRunningTaskRun).not.toHaveBeenCalled(); + expect(taskExecutorMocks.completeTaskRunByRunId).not.toHaveBeenCalled(); + }); + + it("starts background generation and wakes the session with MEDIA lines", async () => { + taskExecutorMocks.createRunningTaskRun.mockReturnValue({ + taskId: "task-123", + runtime: "cli", + requesterSessionKey: "agent:main:discord:direct:123", + ownerKey: "agent:main:discord:direct:123", + scopeKind: "session", + task: "night-drive synthwave", + status: "running", + deliveryStatus: "not_applicable", + notifyPolicy: "silent", + createdAt: Date.now(), + }); + const wakeSpy = vi + .spyOn(musicGenerateBackground, "wakeMusicGenerationTaskCompletion") + .mockResolvedValue(undefined); + vi.spyOn(musicGenerationRuntime, "generateMusic").mockResolvedValue({ + provider: "google", + model: "lyria-3-clip-preview", + attempts: [], + tracks: [ + { + buffer: Buffer.from("music-bytes"), + mimeType: "audio/mpeg", + fileName: "night-drive.mp3", + }, + ], + metadata: { taskId: "music-task-1" }, + }); + vi.spyOn(mediaStore, "saveMediaBuffer").mockResolvedValueOnce({ + path: "/tmp/generated-night-drive.mp3", + id: "generated-night-drive.mp3", + size: 11, + contentType: "audio/mpeg", + }); + + let 
scheduledWork: (() => Promise) | undefined; + const tool = createMusicGenerateTool({ + config: asConfig({ + agents: { + defaults: { + musicGenerationModel: { primary: "google/lyria-3-clip-preview" }, + }, + }, + }), + agentSessionKey: "agent:main:discord:direct:123", + requesterOrigin: { + channel: "discord", + to: "channel:1", + }, + scheduleBackgroundWork: (work) => { + scheduledWork = work; + }, + }); + if (!tool) { + throw new Error("expected music_generate tool"); + } + + const result = await tool.execute("call-1", { + prompt: "night-drive synthwave", + instrumental: true, + }); + const text = (result.content?.[0] as { text: string } | undefined)?.text ?? ""; + + expect(text).toContain("Background task started for music generation (task-123)."); + expect(text).toContain("Do not call music_generate again for this request."); + expect(result.details).toMatchObject({ + async: true, + status: "started", + task: { + taskId: "task-123", + }, + instrumental: true, + }); + expect(typeof scheduledWork).toBe("function"); + await scheduledWork?.(); + expect(taskExecutorMocks.recordTaskRunProgressByRunId).toHaveBeenCalledWith( + expect.objectContaining({ + runId: expect.stringMatching(/^tool:music_generate:/), + progressSummary: "Generating music", + }), + ); + expect(taskExecutorMocks.completeTaskRunByRunId).toHaveBeenCalledWith( + expect.objectContaining({ + runId: expect.stringMatching(/^tool:music_generate:/), + }), + ); + expect(wakeSpy).toHaveBeenCalledWith( + expect.objectContaining({ + handle: expect.objectContaining({ + taskId: "task-123", + }), + status: "ok", + result: expect.stringContaining("MEDIA:/tmp/generated-night-drive.mp3"), + }), + ); + }); + + it("lists provider capabilities", async () => { + vi.spyOn(musicGenerationRuntime, "listRuntimeMusicGenerationProviders").mockReturnValue([ + { + id: "minimax", + defaultModel: "music-2.5+", + models: ["music-2.5+"], + capabilities: { + maxTracks: 1, + supportsLyrics: true, + supportsInstrumental: true, + 
supportsDuration: true, + supportsFormat: true, + supportedFormats: ["mp3"], + }, + generateMusic: vi.fn(async () => { + throw new Error("not used"); + }), + }, + ]); + + const tool = createMusicGenerateTool({ + config: asConfig({ + agents: { + defaults: { + musicGenerationModel: { primary: "minimax/music-2.5+" }, + }, + }, + }), + }); + if (!tool) { + throw new Error("expected music_generate tool"); + } + + const result = await tool.execute("call-1", { action: "list" }); + const text = (result.content?.[0] as { text: string } | undefined)?.text ?? ""; + expect(text).toContain("supportedFormats=mp3"); + expect(text).toContain("instrumental"); + }); +}); diff --git a/src/agents/tools/music-generate-tool.ts b/src/agents/tools/music-generate-tool.ts new file mode 100644 index 00000000000..bfc2a805612 --- /dev/null +++ b/src/agents/tools/music-generate-tool.ts @@ -0,0 +1,703 @@ +import { Type } from "@sinclair/typebox"; +import type { OpenClawConfig } from "../../config/config.js"; +import { loadConfig } from "../../config/config.js"; +import { createSubsystemLogger } from "../../logging/subsystem.js"; +import { saveMediaBuffer } from "../../media/store.js"; +import { loadWebMedia } from "../../media/web-media.js"; +import { parseMusicGenerationModelRef } from "../../music-generation/model-ref.js"; +import { + generateMusic, + listRuntimeMusicGenerationProviders, +} from "../../music-generation/runtime.js"; +import type { MusicGenerationOutputFormat } from "../../music-generation/types.js"; +import type { + MusicGenerationProvider, + MusicGenerationSourceImage, +} from "../../music-generation/types.js"; +import { readSnakeCaseParamRaw } from "../../param-key.js"; +import { resolveUserPath } from "../../utils.js"; +import type { DeliveryContext } from "../../utils/delivery-context.js"; +import { + ToolInputError, + readNumberParam, + readStringArrayParam, + readStringParam, +} from "./common.js"; +import { decodeDataUrl } from "./image-tool.helpers.js"; +import { + 
applyMusicGenerationModelConfigDefaults, + findCapabilityProviderById, + resolveCapabilityModelConfigForTool, + resolveMediaToolLocalRoots, +} from "./media-tool-shared.js"; +import { type ToolModelConfig } from "./model-config.helpers.js"; +import { + completeMusicGenerationTaskRun, + createMusicGenerationTaskRun, + failMusicGenerationTaskRun, + recordMusicGenerationTaskProgress, + type MusicGenerationTaskHandle, + wakeMusicGenerationTaskCompletion, +} from "./music-generate-background.js"; +import { + createMusicGenerateDuplicateGuardResult, + createMusicGenerateListActionResult, + createMusicGenerateStatusActionResult, +} from "./music-generate-tool.actions.js"; +import { + createSandboxBridgeReadFile, + resolveSandboxedBridgeMediaPath, + type AnyAgentTool, + type SandboxFsBridge, + type ToolFsPolicy, +} from "./tool-runtime.helpers.js"; + +const log = createSubsystemLogger("agents/tools/music-generate"); +const MAX_INPUT_IMAGES = 10; +const SUPPORTED_OUTPUT_FORMATS = new Set(["mp3", "wav"]); + +const MusicGenerateToolSchema = Type.Object({ + action: Type.Optional( + Type.String({ + description: + 'Optional action: "generate" (default), "status" to inspect the active session task, or "list" to inspect available providers/models.', + }), + ), + prompt: Type.Optional(Type.String({ description: "Music generation prompt." })), + lyrics: Type.Optional( + Type.String({ + description: "Optional lyrics to guide sung output when the provider supports it.", + }), + ), + instrumental: Type.Optional( + Type.Boolean({ + description: "Optional toggle for instrumental-only output when the provider supports it.", + }), + ), + image: Type.Optional( + Type.String({ + description: "Optional single reference image path or URL.", + }), + ), + images: Type.Optional( + Type.Array(Type.String(), { + description: `Optional reference images (up to ${MAX_INPUT_IMAGES}).`, + }), + ), + model: Type.Optional( + Type.String({ + description: "Optional provider/model override, e.g. 
google/lyria-3-pro-preview.", + }), + ), + durationSeconds: Type.Optional( + Type.Number({ + description: "Optional target duration in seconds when the provider supports duration hints.", + minimum: 1, + }), + ), + format: Type.Optional( + Type.String({ + description: 'Optional output format hint: "mp3" or "wav" when the provider supports it.', + }), + ), + filename: Type.Optional( + Type.String({ + description: + "Optional output filename hint. OpenClaw preserves the basename and saves under its managed media directory.", + }), + ), +}); + +export function resolveMusicGenerationModelConfigForTool(params: { + cfg?: OpenClawConfig; + agentDir?: string; +}): ToolModelConfig | null { + return resolveCapabilityModelConfigForTool({ + cfg: params.cfg, + agentDir: params.agentDir, + modelConfig: params.cfg?.agents?.defaults?.musicGenerationModel, + providers: listRuntimeMusicGenerationProviders({ config: params.cfg }), + }); +} + +function resolveSelectedMusicGenerationProvider(params: { + config?: OpenClawConfig; + musicGenerationModelConfig: ToolModelConfig; + modelOverride?: string; +}): MusicGenerationProvider | undefined { + const selectedRef = + parseMusicGenerationModelRef(params.modelOverride) ?? 
+ parseMusicGenerationModelRef(params.musicGenerationModelConfig.primary); + if (!selectedRef) { + return undefined; + } + return findCapabilityProviderById({ + providers: listRuntimeMusicGenerationProviders({ config: params.config }), + providerId: selectedRef.provider, + }); +} + +function resolveAction(args: Record): "generate" | "list" | "status" { + const raw = readStringParam(args, "action"); + if (!raw) { + return "generate"; + } + const normalized = raw.trim().toLowerCase(); + if (normalized === "generate" || normalized === "list" || normalized === "status") { + return normalized; + } + throw new ToolInputError('action must be "generate", "status", or "list"'); +} + +function readBooleanParam(params: Record, key: string): boolean | undefined { + const raw = readSnakeCaseParamRaw(params, key); + if (typeof raw === "boolean") { + return raw; + } + if (typeof raw === "string") { + const normalized = raw.trim().toLowerCase(); + if (normalized === "true") { + return true; + } + if (normalized === "false") { + return false; + } + } + return undefined; +} + +function normalizeOutputFormat(raw: string | undefined): MusicGenerationOutputFormat | undefined { + const normalized = raw?.trim().toLowerCase() as MusicGenerationOutputFormat | undefined; + if (!normalized) { + return undefined; + } + if (SUPPORTED_OUTPUT_FORMATS.has(normalized)) { + return normalized; + } + throw new ToolInputError('format must be one of "mp3" or "wav"'); +} + +function normalizeReferenceImageInputs(args: Record): string[] { + const single = readStringParam(args, "image"); + const multiple = readStringArrayParam(args, "images"); + const combined = [...(single ? [single] : []), ...(multiple ?? [])]; + const deduped: string[] = []; + const seen = new Set(); + for (const candidate of combined) { + const trimmed = candidate.trim(); + const dedupe = trimmed.startsWith("@") ? 
trimmed.slice(1).trim() : trimmed; + if (!dedupe || seen.has(dedupe)) { + continue; + } + seen.add(dedupe); + deduped.push(trimmed); + } + if (deduped.length > MAX_INPUT_IMAGES) { + throw new ToolInputError( + `Too many reference images: ${deduped.length} provided, maximum is ${MAX_INPUT_IMAGES}.`, + ); + } + return deduped; +} + +function validateMusicGenerationCapabilities(params: { + provider: MusicGenerationProvider | undefined; + model?: string; + inputImageCount: number; + lyrics?: string; + instrumental?: boolean; + durationSeconds?: number; + format?: MusicGenerationOutputFormat; +}) { + const provider = params.provider; + if (!provider) { + return; + } + const caps = provider.capabilities; + if (params.inputImageCount > 0) { + const maxInputImages = caps.maxInputImages ?? MAX_INPUT_IMAGES; + if (params.inputImageCount > maxInputImages) { + throw new ToolInputError( + `${provider.id} supports at most ${maxInputImages} reference image${maxInputImages === 1 ? "" : "s"}.`, + ); + } + } + if (params.lyrics?.trim() && !caps.supportsLyrics) { + throw new ToolInputError(`${provider.id} does not support explicit lyrics input.`); + } + if (typeof params.instrumental === "boolean" && !caps.supportsInstrumental) { + throw new ToolInputError(`${provider.id} does not support instrumental toggles.`); + } + if (typeof params.durationSeconds === "number" && !caps.supportsDuration) { + throw new ToolInputError(`${provider.id} does not support duration hints.`); + } + if (typeof params.durationSeconds === "number" && typeof caps.maxDurationSeconds === "number") { + if (params.durationSeconds > caps.maxDurationSeconds) { + throw new ToolInputError( + `${provider.id} supports at most ${caps.maxDurationSeconds} seconds per track.`, + ); + } + } + if (params.format) { + if (!caps.supportsFormat) { + throw new ToolInputError(`${provider.id} does not support explicit output-format overrides.`); + } + const supportedFormats = + caps.supportedFormatsByModel?.[params.model ?? ""] ?? 
caps.supportedFormats ?? []; + if (supportedFormats.length > 0 && !supportedFormats.includes(params.format)) { + throw new ToolInputError( + `${provider.id} supports ${supportedFormats.join(", ")} output${params.model ? ` for ${params.model}` : ""}.`, + ); + } + } +} + +type MusicGenerateSandboxConfig = { + root: string; + bridge: SandboxFsBridge; +}; + +type MusicGenerateBackgroundScheduler = (work: () => Promise) => void; + +function defaultScheduleMusicGenerateBackgroundWork(work: () => Promise) { + queueMicrotask(() => { + void work().catch((error) => { + log.error("Detached music generation job crashed", { + error, + }); + }); + }); +} + +async function loadReferenceImages(params: { + inputs: string[]; + workspaceDir?: string; + sandboxConfig: { root: string; bridge: SandboxFsBridge; workspaceOnly: boolean } | null; +}): Promise< + Array<{ + sourceImage: MusicGenerationSourceImage; + resolvedInput: string; + rewrittenFrom?: string; + }> +> { + const loaded: Array<{ + sourceImage: MusicGenerationSourceImage; + resolvedInput: string; + rewrittenFrom?: string; + }> = []; + + for (const rawInput of params.inputs) { + const trimmed = rawInput.trim(); + const inputRaw = trimmed.startsWith("@") ? trimmed.slice(1).trim() : trimmed; + if (!inputRaw) { + throw new ToolInputError("image required (empty string in array)"); + } + const looksLikeWindowsDrivePath = /^[a-zA-Z]:[\\/]/.test(inputRaw); + const hasScheme = /^[a-z][a-z0-9+.-]*:/i.test(inputRaw); + const isFileUrl = /^file:/i.test(inputRaw); + const isHttpUrl = /^https?:\/\//i.test(inputRaw); + const isDataUrl = /^data:/i.test(inputRaw); + if (hasScheme && !looksLikeWindowsDrivePath && !isFileUrl && !isHttpUrl && !isDataUrl) { + throw new ToolInputError( + `Unsupported image reference: ${rawInput}. 
Use a file path, a file:// URL, a data: URL, or an http(s) URL.`, + ); + } + if (params.sandboxConfig && isHttpUrl) { + throw new ToolInputError("Sandboxed music_generate does not allow remote image URLs."); + } + + const resolvedInput = params.sandboxConfig + ? inputRaw + : inputRaw.startsWith("~") + ? resolveUserPath(inputRaw) + : inputRaw; + const resolvedPathInfo: { resolved: string; rewrittenFrom?: string } = isDataUrl + ? { resolved: "" } + : params.sandboxConfig + ? await resolveSandboxedBridgeMediaPath({ + sandbox: params.sandboxConfig, + mediaPath: resolvedInput, + inboundFallbackDir: "media/inbound", + }) + : { + resolved: resolvedInput.startsWith("file://") + ? resolvedInput.slice("file://".length) + : resolvedInput, + }; + const resolvedPath = isDataUrl ? null : resolvedPathInfo.resolved; + const localRoots = resolveMediaToolLocalRoots( + params.workspaceDir, + { + workspaceOnly: params.sandboxConfig?.workspaceOnly === true, + }, + resolvedPath ? [resolvedPath] : undefined, + ); + const media = isDataUrl + ? decodeDataUrl(resolvedInput) + : params.sandboxConfig + ? await loadWebMedia(resolvedPath ?? resolvedInput, { + sandboxValidated: true, + readFile: createSandboxBridgeReadFile({ sandbox: params.sandboxConfig }), + }) + : await loadWebMedia(resolvedPath ?? resolvedInput, { + localRoots, + }); + if (media.kind !== "image") { + throw new ToolInputError(`Unsupported media type: ${media.kind ?? "unknown"}`); + } + const mimeType = "mimeType" in media ? media.mimeType : media.contentType; + const fileName = "fileName" in media ? media.fileName : undefined; + loaded.push({ + sourceImage: { + buffer: media.buffer, + mimeType, + fileName, + }, + resolvedInput, + ...(resolvedPathInfo.rewrittenFrom ? 
{ rewrittenFrom: resolvedPathInfo.rewrittenFrom } : {}), + }); + } + + return loaded; +} + +type LoadedReferenceImage = Awaited>[number]; + +type ExecutedMusicGeneration = { + provider: string; + model: string; + savedPaths: string[]; + contentText: string; + details: Record; + wakeResult: string; +}; + +async function executeMusicGenerationJob(params: { + effectiveCfg: OpenClawConfig; + prompt: string; + agentDir?: string; + model?: string; + lyrics?: string; + instrumental?: boolean; + durationSeconds?: number; + format?: MusicGenerationOutputFormat; + filename?: string; + loadedReferenceImages: LoadedReferenceImage[]; + taskHandle?: MusicGenerationTaskHandle | null; +}): Promise { + if (params.taskHandle) { + recordMusicGenerationTaskProgress({ + handle: params.taskHandle, + progressSummary: "Generating music", + }); + } + const result = await generateMusic({ + cfg: params.effectiveCfg, + prompt: params.prompt, + agentDir: params.agentDir, + modelOverride: params.model, + lyrics: params.lyrics, + instrumental: params.instrumental, + durationSeconds: params.durationSeconds, + format: params.format, + inputImages: params.loadedReferenceImages.map((entry) => entry.sourceImage), + }); + if (params.taskHandle) { + recordMusicGenerationTaskProgress({ + handle: params.taskHandle, + progressSummary: "Saving generated music", + }); + } + const savedTracks = await Promise.all( + result.tracks.map((track) => + saveMediaBuffer( + track.buffer, + track.mimeType, + "tool-music-generation", + undefined, + params.filename || track.fileName, + ), + ), + ); + const lines = [ + `Generated ${savedTracks.length} track${savedTracks.length === 1 ? "" : "s"} with ${result.provider}/${result.model}.`, + ...(result.lyrics?.length ? 
["Lyrics returned.", ...result.lyrics] : []), + ...savedTracks.map((track) => `MEDIA:${track.path}`), + ]; + return { + provider: result.provider, + model: result.model, + savedPaths: savedTracks.map((track) => track.path), + contentText: lines.join("\n"), + wakeResult: lines.join("\n"), + details: { + provider: result.provider, + model: result.model, + count: savedTracks.length, + media: { + mediaUrls: savedTracks.map((track) => track.path), + }, + paths: savedTracks.map((track) => track.path), + ...(params.taskHandle + ? { + task: { + taskId: params.taskHandle.taskId, + runId: params.taskHandle.runId, + }, + } + : {}), + ...(params.lyrics ? { requestedLyrics: params.lyrics } : {}), + ...(typeof params.instrumental === "boolean" ? { instrumental: params.instrumental } : {}), + ...(typeof params.durationSeconds === "number" + ? { durationSeconds: params.durationSeconds } + : {}), + ...(params.format ? { format: params.format } : {}), + ...(params.filename ? { filename: params.filename } : {}), + ...(params.loadedReferenceImages.length === 1 + ? { + image: params.loadedReferenceImages[0]?.resolvedInput, + ...(params.loadedReferenceImages[0]?.rewrittenFrom + ? { rewrittenFrom: params.loadedReferenceImages[0].rewrittenFrom } + : {}), + } + : params.loadedReferenceImages.length > 1 + ? { + images: params.loadedReferenceImages.map((entry) => ({ + image: entry.resolvedInput, + ...(entry.rewrittenFrom ? { rewrittenFrom: entry.rewrittenFrom } : {}), + })), + } + : {}), + ...(result.lyrics?.length ? 
{ lyrics: result.lyrics } : {}), + attempts: result.attempts, + metadata: result.metadata, + }, + }; +} + +export function createMusicGenerateTool(options?: { + config?: OpenClawConfig; + agentDir?: string; + agentSessionKey?: string; + requesterOrigin?: DeliveryContext; + workspaceDir?: string; + sandbox?: MusicGenerateSandboxConfig; + fsPolicy?: ToolFsPolicy; + scheduleBackgroundWork?: MusicGenerateBackgroundScheduler; +}): AnyAgentTool | null { + const cfg: OpenClawConfig = options?.config ?? loadConfig(); + const musicGenerationModelConfig = resolveMusicGenerationModelConfigForTool({ + cfg, + agentDir: options?.agentDir, + }); + if (!musicGenerationModelConfig) { + return null; + } + + const sandboxConfig = options?.sandbox + ? { + root: options.sandbox.root, + bridge: options.sandbox.bridge, + workspaceOnly: options.fsPolicy?.workspaceOnly === true, + } + : null; + const scheduleBackgroundWork = + options?.scheduleBackgroundWork ?? defaultScheduleMusicGenerateBackgroundWork; + + return { + label: "Music Generation", + name: "music_generate", + displaySummary: "Generate music", + description: + "Generate music using configured providers. Generated tracks are saved under OpenClaw-managed media storage and delivered automatically as attachments.", + parameters: MusicGenerateToolSchema, + execute: async (_toolCallId, rawArgs) => { + const args = rawArgs as Record; + const action = resolveAction(args); + const effectiveCfg = + applyMusicGenerationModelConfigDefaults(cfg, musicGenerationModelConfig) ?? 
cfg; + + if (action === "list") { + return createMusicGenerateListActionResult(effectiveCfg); + } + + if (action === "status") { + return createMusicGenerateStatusActionResult(options?.agentSessionKey); + } + + const duplicateGuardResult = createMusicGenerateDuplicateGuardResult( + options?.agentSessionKey, + ); + if (duplicateGuardResult) { + return duplicateGuardResult; + } + + const prompt = readStringParam(args, "prompt", { required: true }); + const lyrics = readStringParam(args, "lyrics"); + const instrumental = readBooleanParam(args, "instrumental"); + const model = readStringParam(args, "model"); + const durationSeconds = readNumberParam(args, "durationSeconds", { + integer: true, + strict: true, + }); + const format = normalizeOutputFormat(readStringParam(args, "format")); + const filename = readStringParam(args, "filename"); + const imageInputs = normalizeReferenceImageInputs(args); + const selectedProvider = resolveSelectedMusicGenerationProvider({ + config: effectiveCfg, + musicGenerationModelConfig, + modelOverride: model, + }); + const loadedReferenceImages = await loadReferenceImages({ + inputs: imageInputs, + workspaceDir: options?.workspaceDir, + sandboxConfig, + }); + validateMusicGenerationCapabilities({ + provider: selectedProvider, + model: + parseMusicGenerationModelRef(model)?.model ?? model ?? 
selectedProvider?.defaultModel, + inputImageCount: loadedReferenceImages.length, + lyrics, + instrumental, + durationSeconds, + format, + }); + const taskHandle = createMusicGenerationTaskRun({ + sessionKey: options?.agentSessionKey, + requesterOrigin: options?.requesterOrigin, + prompt, + providerId: selectedProvider?.id, + }); + const shouldDetach = Boolean(taskHandle && options?.agentSessionKey?.trim()); + + if (shouldDetach) { + scheduleBackgroundWork(async () => { + try { + const executed = await executeMusicGenerationJob({ + effectiveCfg, + prompt, + agentDir: options?.agentDir, + model, + lyrics, + instrumental, + durationSeconds, + format, + filename, + loadedReferenceImages, + taskHandle, + }); + completeMusicGenerationTaskRun({ + handle: taskHandle, + provider: executed.provider, + model: executed.model, + count: executed.savedPaths.length, + paths: executed.savedPaths, + }); + try { + await wakeMusicGenerationTaskCompletion({ + handle: taskHandle, + status: "ok", + statusLabel: "completed successfully", + result: executed.wakeResult, + }); + } catch (error) { + log.warn("Music generation completion wake failed after successful generation", { + taskId: taskHandle?.taskId, + runId: taskHandle?.runId, + error, + }); + } + } catch (error) { + failMusicGenerationTaskRun({ + handle: taskHandle, + error, + }); + await wakeMusicGenerationTaskCompletion({ + handle: taskHandle, + status: "error", + statusLabel: "failed", + result: error instanceof Error ? error.message : String(error), + }); + return; + } + }); + + return { + content: [ + { + type: "text", + text: `Background task started for music generation (${taskHandle?.taskId ?? "unknown"}). Do not call music_generate again for this request. Wait for the completion event; I'll post the finished music here when it's ready.`, + }, + ], + details: { + async: true, + status: "started", + ...(taskHandle + ? 
{ + task: { + taskId: taskHandle.taskId, + runId: taskHandle.runId, + }, + } + : {}), + ...(loadedReferenceImages.length === 1 + ? { + image: loadedReferenceImages[0]?.resolvedInput, + ...(loadedReferenceImages[0]?.rewrittenFrom + ? { rewrittenFrom: loadedReferenceImages[0].rewrittenFrom } + : {}), + } + : loadedReferenceImages.length > 1 + ? { + images: loadedReferenceImages.map((entry) => ({ + image: entry.resolvedInput, + ...(entry.rewrittenFrom ? { rewrittenFrom: entry.rewrittenFrom } : {}), + })), + } + : {}), + ...(model ? { model } : {}), + ...(lyrics ? { requestedLyrics: lyrics } : {}), + ...(typeof instrumental === "boolean" ? { instrumental } : {}), + ...(typeof durationSeconds === "number" ? { durationSeconds } : {}), + ...(format ? { format } : {}), + ...(filename ? { filename } : {}), + }, + }; + } + + try { + const executed = await executeMusicGenerationJob({ + effectiveCfg, + prompt, + agentDir: options?.agentDir, + lyrics, + instrumental, + durationSeconds, + model, + format, + filename, + loadedReferenceImages, + taskHandle, + }); + completeMusicGenerationTaskRun({ + handle: taskHandle, + provider: executed.provider, + model: executed.model, + count: executed.savedPaths.length, + paths: executed.savedPaths, + }); + return { + content: [{ type: "text", text: executed.contentText }], + details: executed.details, + }; + } catch (error) { + failMusicGenerationTaskRun({ + handle: taskHandle, + error, + }); + throw error; + } + }, + }; +} diff --git a/src/agents/tools/video-generate-background.ts b/src/agents/tools/video-generate-background.ts index 3628afea1dd..a2ad266d1a7 100644 --- a/src/agents/tools/video-generate-background.ts +++ b/src/agents/tools/video-generate-background.ts @@ -1,26 +1,15 @@ -import crypto from "node:crypto"; -import { createSubsystemLogger } from "../../logging/subsystem.js"; -import { - completeTaskRunByRunId, - createRunningTaskRun, - failTaskRunByRunId, - recordTaskRunProgressByRunId, -} from "../../tasks/task-executor.js"; 
import type { DeliveryContext } from "../../utils/delivery-context.js"; -import { INTERNAL_MESSAGE_CHANNEL } from "../../utils/message-channel.js"; -import { formatAgentInternalEventsForPrompt, type AgentInternalEvent } from "../internal-events.js"; -import { deliverSubagentAnnouncement } from "../subagent-announce-delivery.js"; import { VIDEO_GENERATION_TASK_KIND } from "../video-generation-task-status.js"; +import { + completeMediaGenerationTaskRun, + createMediaGenerationTaskRun, + failMediaGenerationTaskRun, + recordMediaGenerationTaskProgress, + wakeMediaGenerationTaskCompletion, + type MediaGenerationTaskHandle, +} from "./media-generate-background-shared.js"; -const log = createSubsystemLogger("agents/tools/video-generate-background"); - -export type VideoGenerationTaskHandle = { - taskId: string; - runId: string; - requesterSessionKey: string; - requesterOrigin?: DeliveryContext; - taskLabel: string; -}; +export type VideoGenerationTaskHandle = MediaGenerationTaskHandle; export function createVideoGenerationTaskRun(params: { sessionKey?: string; @@ -28,45 +17,16 @@ export function createVideoGenerationTaskRun(params: { prompt: string; providerId?: string; }): VideoGenerationTaskHandle | null { - const sessionKey = params.sessionKey?.trim(); - if (!sessionKey) { - return null; - } - const runId = `tool:video_generate:${crypto.randomUUID()}`; - try { - const task = createRunningTaskRun({ - runtime: "cli", - taskKind: VIDEO_GENERATION_TASK_KIND, - sourceId: params.providerId ? 
`video_generate:${params.providerId}` : "video_generate", - requesterSessionKey: sessionKey, - ownerKey: sessionKey, - scopeKind: "session", - requesterOrigin: params.requesterOrigin, - childSessionKey: sessionKey, - runId, - label: "Video generation", - task: params.prompt, - deliveryStatus: "not_applicable", - notifyPolicy: "silent", - startedAt: Date.now(), - lastEventAt: Date.now(), - progressSummary: "Queued video generation", - }); - return { - taskId: task.taskId, - runId, - requesterSessionKey: sessionKey, - requesterOrigin: params.requesterOrigin, - taskLabel: params.prompt, - }; - } catch (error) { - log.warn("Failed to create video generation task ledger record", { - sessionKey, - providerId: params.providerId, - error, - }); - return null; - } + return createMediaGenerationTaskRun({ + sessionKey: params.sessionKey, + requesterOrigin: params.requesterOrigin, + prompt: params.prompt, + providerId: params.providerId, + toolName: "video_generate", + taskKind: VIDEO_GENERATION_TASK_KIND, + label: "Video generation", + queuedProgressSummary: "Queued video generation", + }); } export function recordVideoGenerationTaskProgress(params: { @@ -74,17 +34,7 @@ export function recordVideoGenerationTaskProgress(params: { progressSummary: string; eventSummary?: string; }) { - if (!params.handle) { - return; - } - recordTaskRunProgressByRunId({ - runId: params.handle.runId, - runtime: "cli", - sessionKey: params.handle.requesterSessionKey, - lastEventAt: Date.now(), - progressSummary: params.progressSummary, - eventSummary: params.eventSummary, - }); + recordMediaGenerationTaskProgress(params); } export function completeVideoGenerationTaskRun(params: { @@ -94,19 +44,9 @@ export function completeVideoGenerationTaskRun(params: { count: number; paths: string[]; }) { - if (!params.handle) { - return; - } - const endedAt = Date.now(); - const target = params.count === 1 ? 
params.paths[0] : `${params.count} files`; - completeTaskRunByRunId({ - runId: params.handle.runId, - runtime: "cli", - sessionKey: params.handle.requesterSessionKey, - endedAt, - lastEventAt: endedAt, - progressSummary: `Generated ${params.count} video${params.count === 1 ? "" : "s"}`, - terminalSummary: `Generated ${params.count} video${params.count === 1 ? "" : "s"} with ${params.provider}/${params.model}${target ? ` -> ${target}` : ""}.`, + completeMediaGenerationTaskRun({ + ...params, + generatedLabel: "video", }); } @@ -114,39 +54,12 @@ export function failVideoGenerationTaskRun(params: { handle: VideoGenerationTaskHandle | null; error: unknown; }) { - if (!params.handle) { - return; - } - const endedAt = Date.now(); - const errorText = params.error instanceof Error ? params.error.message : String(params.error); - failTaskRunByRunId({ - runId: params.handle.runId, - runtime: "cli", - sessionKey: params.handle.requesterSessionKey, - endedAt, - lastEventAt: endedAt, - error: errorText, + failMediaGenerationTaskRun({ + ...params, progressSummary: "Video generation failed", - terminalSummary: errorText, }); } -function buildVideoGenerationReplyInstruction(status: "ok" | "error"): string { - if (status === "ok") { - return [ - "A completed video generation task is ready for user delivery.", - "Reply in your normal assistant voice and post the finished video to the original message channel now.", - "If the result includes MEDIA: lines, include those exact MEDIA: lines in your reply so OpenClaw attaches the video.", - "Keep internal task/session details private and do not copy the internal event text verbatim.", - ].join(" "); - } - return [ - "A video generation task failed.", - "Reply in your normal assistant voice with the failure summary now.", - "Keep internal task/session details private and do not copy the internal event text verbatim.", - ].join(" "); -} - export async function wakeVideoGenerationTaskCompletion(params: { handle: VideoGenerationTaskHandle | 
null; status: "ok" | "error"; @@ -154,53 +67,15 @@ export async function wakeVideoGenerationTaskCompletion(params: { result: string; statsLine?: string; }) { - if (!params.handle) { - return; - } - const internalEvents: AgentInternalEvent[] = [ - { - type: "task_completion", - source: "video_generation", - childSessionKey: `video_generate:${params.handle.taskId}`, - childSessionId: params.handle.taskId, - announceType: "video generation task", - taskLabel: params.handle.taskLabel, - status: params.status, - statusLabel: params.statusLabel, - result: params.result, - ...(params.statsLine?.trim() ? { statsLine: params.statsLine } : {}), - replyInstruction: buildVideoGenerationReplyInstruction(params.status), - }, - ]; - const triggerMessage = - formatAgentInternalEventsForPrompt(internalEvents) || - "A video generation task finished. Process the completion update now."; - const announceId = `video-generate:${params.handle.taskId}:${params.status}`; - const delivery = await deliverSubagentAnnouncement({ - requesterSessionKey: params.handle.requesterSessionKey, - targetRequesterSessionKey: params.handle.requesterSessionKey, - announceId, - triggerMessage, - steerMessage: triggerMessage, - internalEvents, - summaryLine: params.handle.taskLabel, - requesterSessionOrigin: params.handle.requesterOrigin, - requesterOrigin: params.handle.requesterOrigin, - completionDirectOrigin: params.handle.requesterOrigin, - directOrigin: params.handle.requesterOrigin, - sourceSessionKey: `video_generate:${params.handle.taskId}`, - sourceChannel: INTERNAL_MESSAGE_CHANNEL, - sourceTool: "video_generate", - requesterIsSubagent: false, - expectsCompletionMessage: true, - bestEffortDeliver: true, - directIdempotencyKey: announceId, + await wakeMediaGenerationTaskCompletion({ + handle: params.handle, + status: params.status, + statusLabel: params.statusLabel, + result: params.result, + statsLine: params.statsLine, + eventSource: "video_generation", + announceType: "video generation task", + 
toolName: "video_generate", + completionLabel: "video", }); - if (!delivery.delivered && delivery.error) { - log.warn("Video generation completion wake failed", { - taskId: params.handle.taskId, - runId: params.handle.runId, - error: delivery.error, - }); - } } diff --git a/src/agents/tools/video-generate-tool.ts b/src/agents/tools/video-generate-tool.ts index 0a71ad13c53..fb18309c8e8 100644 --- a/src/agents/tools/video-generate-tool.ts +++ b/src/agents/tools/video-generate-tool.ts @@ -19,7 +19,6 @@ import type { VideoGenerationResolution, VideoGenerationSourceAsset, } from "../../video-generation/types.js"; -import { normalizeProviderId } from "../provider-id.js"; import { ToolInputError, readNumberParam, @@ -29,16 +28,11 @@ import { import { decodeDataUrl } from "./image-tool.helpers.js"; import { applyVideoGenerationModelConfigDefaults, + findCapabilityProviderById, + resolveCapabilityModelConfigForTool, resolveMediaToolLocalRoots, } from "./media-tool-shared.js"; -import { - buildToolModelConfigFromCandidates, - coerceToolModelConfig, - hasAuthForProvider, - hasToolModelConfig, - resolveDefaultModelRef, - type ToolModelConfig, -} from "./model-config.helpers.js"; +import { type ToolModelConfig } from "./model-config.helpers.js"; import { createSandboxBridgeReadFile, resolveSandboxedBridgeMediaPath, @@ -148,99 +142,18 @@ const VideoGenerateToolSchema = Type.Object({ ), }); -function resolveVideoGenerationModelCandidates(params: { - cfg?: OpenClawConfig; - agentDir?: string; -}): Array { - const providerDefaults = new Map(); - for (const provider of listRuntimeVideoGenerationProviders({ config: params.cfg })) { - const providerId = provider.id.trim(); - const modelId = provider.defaultModel?.trim(); - if ( - !providerId || - !modelId || - providerDefaults.has(providerId) || - !isVideoGenerationProviderConfigured({ - provider, - cfg: params.cfg, - agentDir: params.agentDir, - }) - ) { - continue; - } - providerDefaults.set(providerId, `${providerId}/${modelId}`); 
- } - - const primaryProvider = resolveDefaultModelRef(params.cfg).provider; - const orderedProviders = [ - primaryProvider, - ...[...providerDefaults.keys()] - .filter((providerId) => providerId !== primaryProvider) - .toSorted(), - ]; - const orderedRefs: string[] = []; - const seen = new Set(); - for (const providerId of orderedProviders) { - const ref = providerDefaults.get(providerId); - if (!ref || seen.has(ref)) { - continue; - } - seen.add(ref); - orderedRefs.push(ref); - } - return orderedRefs; -} - export function resolveVideoGenerationModelConfigForTool(params: { cfg?: OpenClawConfig; agentDir?: string; }): ToolModelConfig | null { - const explicit = coerceToolModelConfig(params.cfg?.agents?.defaults?.videoGenerationModel); - if (hasToolModelConfig(explicit)) { - return explicit; - } - return buildToolModelConfigFromCandidates({ - explicit, + return resolveCapabilityModelConfigForTool({ + cfg: params.cfg, agentDir: params.agentDir, - candidates: resolveVideoGenerationModelCandidates(params), - isProviderConfigured: (providerId) => - isVideoGenerationProviderConfigured({ - providerId, - cfg: params.cfg, - agentDir: params.agentDir, - }), + modelConfig: params.cfg?.agents?.defaults?.videoGenerationModel, + providers: listRuntimeVideoGenerationProviders({ config: params.cfg }), }); } -function isVideoGenerationProviderConfigured(params: { - provider?: VideoGenerationProvider; - providerId?: string; - cfg?: OpenClawConfig; - agentDir?: string; -}): boolean { - const provider = - params.provider ?? - listRuntimeVideoGenerationProviders({ config: params.cfg }).find((candidate) => { - const normalizedId = normalizeProviderId(params.providerId ?? ""); - return ( - normalizeProviderId(candidate.id) === normalizedId || - (candidate.aliases ?? []).some((alias) => normalizeProviderId(alias) === normalizedId) - ); - }); - if (!provider) { - return params.providerId - ? 
hasAuthForProvider({ provider: params.providerId, agentDir: params.agentDir }) - : false; - } - if (provider.isConfigured) { - return provider.isConfigured({ - cfg: params.cfg, - agentDir: params.agentDir, - }); - } - return hasAuthForProvider({ provider: provider.id, agentDir: params.agentDir }); -} - function resolveAction(args: Record): "generate" | "list" | "status" { const raw = readStringParam(args, "action"); if (!raw) { @@ -333,12 +246,10 @@ function resolveSelectedVideoGenerationProvider(params: { if (!selectedRef) { return undefined; } - const selectedProvider = normalizeProviderId(selectedRef.provider); - return listRuntimeVideoGenerationProviders({ config: params.config }).find( - (provider) => - normalizeProviderId(provider.id) === selectedProvider || - (provider.aliases ?? []).some((alias) => normalizeProviderId(alias) === selectedProvider), - ); + return findCapabilityProviderById({ + providers: listRuntimeVideoGenerationProviders({ config: params.config }), + providerId: selectedRef.provider, + }); } function validateVideoGenerationCapabilities(params: { diff --git a/src/agents/video-generation-task-status.ts b/src/agents/video-generation-task-status.ts index c4566ccbda8..d6cb3e0db73 100644 --- a/src/agents/video-generation-task-status.ts +++ b/src/agents/video-generation-task-status.ts @@ -1,77 +1,65 @@ import type { TaskRecord } from "../tasks/task-registry.types.js"; import { - buildSessionAsyncTaskStatusDetails, - findActiveSessionTask, -} from "./session-async-task-status.js"; + buildActiveMediaGenerationTaskPromptContextForSession, + buildMediaGenerationTaskStatusDetails, + buildMediaGenerationTaskStatusText, + findActiveMediaGenerationTaskForSession, + getMediaGenerationTaskProviderId, + isActiveMediaGenerationTask, +} from "./media-generation-task-status-shared.js"; export const VIDEO_GENERATION_TASK_KIND = "video_generation"; const VIDEO_GENERATION_SOURCE_PREFIX = "video_generate"; export function isActiveVideoGenerationTask(task: 
TaskRecord): boolean { - return ( - task.runtime === "cli" && - task.scopeKind === "session" && - task.taskKind === VIDEO_GENERATION_TASK_KIND && - (task.status === "queued" || task.status === "running") - ); + return isActiveMediaGenerationTask({ + task, + taskKind: VIDEO_GENERATION_TASK_KIND, + }); } export function getVideoGenerationTaskProviderId(task: TaskRecord): string | undefined { - const sourceId = task.sourceId?.trim() ?? ""; - if (!sourceId.startsWith(`${VIDEO_GENERATION_SOURCE_PREFIX}:`)) { - return undefined; - } - const providerId = sourceId.slice(`${VIDEO_GENERATION_SOURCE_PREFIX}:`.length).trim(); - return providerId || undefined; + return getMediaGenerationTaskProviderId(task, VIDEO_GENERATION_SOURCE_PREFIX); } export function findActiveVideoGenerationTaskForSession(sessionKey?: string): TaskRecord | null { - return findActiveSessionTask({ + return findActiveMediaGenerationTaskForSession({ sessionKey, - runtime: "cli", taskKind: VIDEO_GENERATION_TASK_KIND, - sourceIdPrefix: VIDEO_GENERATION_SOURCE_PREFIX, + sourcePrefix: VIDEO_GENERATION_SOURCE_PREFIX, }); } export function buildVideoGenerationTaskStatusDetails(task: TaskRecord): Record { - const provider = getVideoGenerationTaskProviderId(task); - return { - ...buildSessionAsyncTaskStatusDetails(task), - ...(provider ? { provider } : {}), - }; + return buildMediaGenerationTaskStatusDetails({ + task, + sourcePrefix: VIDEO_GENERATION_SOURCE_PREFIX, + }); } export function buildVideoGenerationTaskStatusText( task: TaskRecord, params?: { duplicateGuard?: boolean }, ): string { - const provider = getVideoGenerationTaskProviderId(task); - const lines = [ - `Video generation task ${task.taskId} is already ${task.status}${provider ? ` with ${provider}` : ""}.`, - task.progressSummary ? `Progress: ${task.progressSummary}.` : null, - params?.duplicateGuard - ? "Do not call video_generate again for this request. Wait for the completion event; I will post the finished video here." 
- : "Wait for the completion event; I will post the finished video here when it's ready.", - ].filter((entry): entry is string => Boolean(entry)); - return lines.join("\n"); + return buildMediaGenerationTaskStatusText({ + task, + sourcePrefix: VIDEO_GENERATION_SOURCE_PREFIX, + nounLabel: "Video generation", + toolName: "video_generate", + completionLabel: "video", + duplicateGuard: params?.duplicateGuard, + }); } export function buildActiveVideoGenerationTaskPromptContextForSession( sessionKey?: string, ): string | undefined { - const task = findActiveVideoGenerationTaskForSession(sessionKey); - if (!task) { - return undefined; - } - const provider = getVideoGenerationTaskProviderId(task); - const lines = [ - "An active video generation background task already exists for this session.", - `Task ${task.taskId} is currently ${task.status}${provider ? ` via ${provider}` : ""}.`, - task.progressSummary ? `Current progress: ${task.progressSummary}.` : null, - "Do not call `video_generate` again for the same request while that task is queued or running.", - 'If the user asks for progress or whether the work is async, explain the active task state or call `video_generate` with `action:"status"` instead of starting a new generation.', - "Only start a new `video_generate` call if the user clearly asks for a different/new video.", - ].filter((entry): entry is string => Boolean(entry)); - return lines.join("\n"); + return buildActiveMediaGenerationTaskPromptContextForSession({ + sessionKey, + taskKind: VIDEO_GENERATION_TASK_KIND, + sourcePrefix: VIDEO_GENERATION_SOURCE_PREFIX, + nounLabel: "Video generation", + toolName: "video_generate", + completionLabel: "videos", + }); } diff --git a/src/config/schema.help.ts b/src/config/schema.help.ts index 0074844effb..3f8cf953920 100644 --- a/src/config/schema.help.ts +++ b/src/config/schema.help.ts @@ -1089,6 +1089,10 @@ export const FIELD_HELP: Record = { "Optional video-generation model (provider/model) used by the shared video 
generation capability.", "agents.defaults.videoGenerationModel.fallbacks": "Ordered fallback video-generation models (provider/model).", + "agents.defaults.musicGenerationModel.primary": + "Optional music-generation model (provider/model) used by the shared music generation capability.", + "agents.defaults.musicGenerationModel.fallbacks": + "Ordered fallback music-generation models (provider/model).", "agents.defaults.pdfModel.primary": "Optional PDF model (provider/model) for the PDF analysis tool. Defaults to imageModel, then session model.", "agents.defaults.pdfModel.fallbacks": "Ordered fallback PDF models (provider/model).", diff --git a/src/config/schema.labels.ts b/src/config/schema.labels.ts index c8434d2e450..65524d4f4b6 100644 --- a/src/config/schema.labels.ts +++ b/src/config/schema.labels.ts @@ -496,6 +496,8 @@ export const FIELD_LABELS: Record = { "agents.defaults.imageGenerationModel.fallbacks": "Image Generation Model Fallbacks", "agents.defaults.videoGenerationModel.primary": "Video Generation Model", "agents.defaults.videoGenerationModel.fallbacks": "Video Generation Model Fallbacks", + "agents.defaults.musicGenerationModel.primary": "Music Generation Model", + "agents.defaults.musicGenerationModel.fallbacks": "Music Generation Model Fallbacks", "agents.defaults.pdfModel.primary": "PDF Model", "agents.defaults.pdfModel.fallbacks": "PDF Model Fallbacks", "agents.defaults.pdfMaxBytesMb": "PDF Max Size (MB)", diff --git a/src/config/types.agent-defaults.ts b/src/config/types.agent-defaults.ts index e98f1a88ce9..1885b36da11 100644 --- a/src/config/types.agent-defaults.ts +++ b/src/config/types.agent-defaults.ts @@ -55,6 +55,8 @@ export type AgentDefaultsConfig = { imageGenerationModel?: AgentModelConfig; /** Optional video-generation model and fallbacks (provider/model). Accepts string or {primary,fallbacks}. */ videoGenerationModel?: AgentModelConfig; + /** Optional music-generation model and fallbacks (provider/model). 
Accepts string or {primary,fallbacks}. */ + musicGenerationModel?: AgentModelConfig; /** Optional PDF-capable model and fallbacks (provider/model). Accepts string or {primary,fallbacks}. */ pdfModel?: AgentModelConfig; /** Maximum PDF file size in megabytes (default: 10). */ diff --git a/src/config/zod-schema.agent-defaults.ts b/src/config/zod-schema.agent-defaults.ts index faa6d4227be..94165de41ac 100644 --- a/src/config/zod-schema.agent-defaults.ts +++ b/src/config/zod-schema.agent-defaults.ts @@ -21,6 +21,7 @@ export const AgentDefaultsSchema = z imageModel: AgentModelSchema.optional(), imageGenerationModel: AgentModelSchema.optional(), videoGenerationModel: AgentModelSchema.optional(), + musicGenerationModel: AgentModelSchema.optional(), pdfModel: AgentModelSchema.optional(), pdfMaxBytesMb: z.number().positive().optional(), pdfMaxPages: z.number().int().positive().optional(), diff --git a/src/gateway/server-plugins.test.ts b/src/gateway/server-plugins.test.ts index 075d2dd39df..ced9cecca23 100644 --- a/src/gateway/server-plugins.test.ts +++ b/src/gateway/server-plugins.test.ts @@ -73,6 +73,7 @@ const createRegistry = (diagnostics: PluginDiagnostic[]): PluginRegistry => ({ realtimeVoiceProviders: [], mediaUnderstandingProviders: [], imageGenerationProviders: [], + musicGenerationProviders: [], videoGenerationProviders: [], webFetchProviders: [], webSearchProviders: [], diff --git a/src/gateway/test-helpers.plugin-registry.ts b/src/gateway/test-helpers.plugin-registry.ts index c3c39398cc5..468c51ebdea 100644 --- a/src/gateway/test-helpers.plugin-registry.ts +++ b/src/gateway/test-helpers.plugin-registry.ts @@ -19,6 +19,7 @@ function createStubPluginRegistry(): PluginRegistry { mediaUnderstandingProviders: [], imageGenerationProviders: [], videoGenerationProviders: [], + musicGenerationProviders: [], webFetchProviders: [], webSearchProviders: [], memoryEmbeddingProviders: [], diff --git a/src/image-generation/runtime.ts b/src/image-generation/runtime.ts index 
698a6367231..ad8687589c8 100644 --- a/src/image-generation/runtime.ts +++ b/src/image-generation/runtime.ts @@ -2,12 +2,12 @@ import type { AuthProfileStore } from "../agents/auth-profiles.js"; import { describeFailoverError, isFailoverError } from "../agents/failover-error.js"; import type { FallbackAttempt } from "../agents/model-fallback.types.js"; import type { OpenClawConfig } from "../config/config.js"; -import { - resolveAgentModelFallbackValues, - resolveAgentModelPrimaryValue, -} from "../config/model-input.js"; import { createSubsystemLogger } from "../logging/subsystem.js"; -import { getProviderEnvVars } from "../secrets/provider-env-vars.js"; +import { + buildNoCapabilityModelConfiguredMessage, + resolveCapabilityModelCandidates, + throwCapabilityGenerationFailure, +} from "../media-generation/runtime-shared.js"; import { parseImageGenerationModelRef } from "./model-ref.js"; import { getImageGenerationProvider, listImageGenerationProviders } from "./provider-registry.js"; import type { @@ -42,76 +42,12 @@ export type GenerateImageRuntimeResult = { ignoredOverrides: ImageGenerationIgnoredOverride[]; }; -function resolveImageGenerationCandidates(params: { - cfg: OpenClawConfig; - modelOverride?: string; -}): Array<{ provider: string; model: string }> { - const candidates: Array<{ provider: string; model: string }> = []; - const seen = new Set(); - const add = (raw: string | undefined) => { - const parsed = parseImageGenerationModelRef(raw); - if (!parsed) { - return; - } - const key = `${parsed.provider}/${parsed.model}`; - if (seen.has(key)) { - return; - } - seen.add(key); - candidates.push(parsed); - }; - - add(params.modelOverride); - add(resolveAgentModelPrimaryValue(params.cfg.agents?.defaults?.imageGenerationModel)); - for (const fallback of resolveAgentModelFallbackValues( - params.cfg.agents?.defaults?.imageGenerationModel, - )) { - add(fallback); - } - return candidates; -} - -function throwImageGenerationFailure(params: { - attempts: 
FallbackAttempt[]; - lastError: unknown; -}): never { - if (params.attempts.length <= 1 && params.lastError) { - throw params.lastError; - } - const summary = - params.attempts.length > 0 - ? params.attempts - .map((attempt) => `${attempt.provider}/${attempt.model}: ${attempt.error}`) - .join(" | ") - : "unknown"; - throw new Error(`All image generation models failed (${params.attempts.length}): ${summary}`, { - cause: params.lastError instanceof Error ? params.lastError : undefined, - }); -} - function buildNoImageGenerationModelConfiguredMessage(cfg: OpenClawConfig): string { - const providers = listImageGenerationProviders(cfg); - const sampleModel = providers.find( - (provider) => provider.id.trim().length > 0 && provider.defaultModel?.trim(), - ); - const sampleRef = sampleModel - ? `${sampleModel.id}/${sampleModel.defaultModel}` - : "/"; - const authHints = providers - .flatMap((provider) => { - const envVars = getProviderEnvVars(provider.id); - if (envVars.length === 0) { - return []; - } - return [`${provider.id}: ${envVars.join(" / ")}`]; - }) - .slice(0, 3); - return [ - `No image-generation model configured. Set agents.defaults.imageGenerationModel.primary to a provider/model like "${sampleRef}".`, - authHints.length > 0 - ? 
`If you want a specific provider, also configure that provider's auth/API key first (${authHints.join("; ")}).` - : "If you want a specific provider, also configure that provider's auth/API key first.", - ].join(" "); + return buildNoCapabilityModelConfiguredMessage({ + capabilityLabel: "image-generation", + modelConfigKey: "imageGenerationModel", + providers: listImageGenerationProviders(cfg), + }); } export function listRuntimeImageGenerationProviders(params?: { config?: OpenClawConfig }) { @@ -173,9 +109,11 @@ function resolveProviderImageGenerationOverrides(params: { export async function generateImage( params: GenerateImageParams, ): Promise { - const candidates = resolveImageGenerationCandidates({ + const candidates = resolveCapabilityModelCandidates({ cfg: params.cfg, + modelConfig: params.cfg.agents?.defaults?.imageGenerationModel, modelOverride: params.modelOverride, + parseModelRef: parseImageGenerationModelRef, }); if (candidates.length === 0) { throw new Error(buildNoImageGenerationModelConfiguredMessage(params.cfg)); @@ -244,5 +182,9 @@ export async function generateImage( } } - throwImageGenerationFailure({ attempts, lastError }); + throwCapabilityGenerationFailure({ + capabilityLabel: "image generation", + attempts, + lastError, + }); } diff --git a/src/media-generation/runtime-shared.ts b/src/media-generation/runtime-shared.ts new file mode 100644 index 00000000000..3f5e83d147e --- /dev/null +++ b/src/media-generation/runtime-shared.ts @@ -0,0 +1,93 @@ +import type { FallbackAttempt } from "../agents/model-fallback.types.js"; +import type { OpenClawConfig } from "../config/config.js"; +import { + resolveAgentModelFallbackValues, + resolveAgentModelPrimaryValue, +} from "../config/model-input.js"; +import type { AgentModelConfig } from "../config/types.agents-shared.js"; +import { getProviderEnvVars } from "../secrets/provider-env-vars.js"; + +export type ParsedProviderModelRef = { + provider: string; + model: string; +}; + +export function 
resolveCapabilityModelCandidates(params: { + cfg: OpenClawConfig; + modelConfig: AgentModelConfig | undefined; + modelOverride?: string; + parseModelRef: (raw: string | undefined) => ParsedProviderModelRef | null; +}): ParsedProviderModelRef[] { + const candidates: ParsedProviderModelRef[] = []; + const seen = new Set(); + const add = (raw: string | undefined) => { + const parsed = params.parseModelRef(raw); + if (!parsed) { + return; + } + const key = `${parsed.provider}/${parsed.model}`; + if (seen.has(key)) { + return; + } + seen.add(key); + candidates.push(parsed); + }; + + add(params.modelOverride); + add(resolveAgentModelPrimaryValue(params.modelConfig)); + for (const fallback of resolveAgentModelFallbackValues(params.modelConfig)) { + add(fallback); + } + return candidates; +} + +export function throwCapabilityGenerationFailure(params: { + capabilityLabel: string; + attempts: FallbackAttempt[]; + lastError: unknown; +}): never { + if (params.attempts.length <= 1 && params.lastError) { + throw params.lastError; + } + const summary = + params.attempts.length > 0 + ? params.attempts + .map((attempt) => `${attempt.provider}/${attempt.model}: ${attempt.error}`) + .join(" | ") + : "unknown"; + throw new Error( + `All ${params.capabilityLabel} models failed (${params.attempts.length}): ${summary}`, + { + cause: params.lastError instanceof Error ? params.lastError : undefined, + }, + ); +} + +export function buildNoCapabilityModelConfiguredMessage(params: { + capabilityLabel: string; + modelConfigKey: string; + providers: Array<{ id: string; defaultModel?: string | null }>; + fallbackSampleRef?: string; +}): string { + const sampleModel = params.providers.find( + (provider) => provider.id.trim().length > 0 && provider.defaultModel?.trim(), + ); + const sampleRef = sampleModel + ? `${sampleModel.id}/${sampleModel.defaultModel}` + : (params.fallbackSampleRef ?? 
"/"); + const authHints = params.providers + .flatMap((provider) => { + const envVars = getProviderEnvVars(provider.id); + if (envVars.length === 0) { + return []; + } + return [`${provider.id}: ${envVars.join(" / ")}`]; + }) + .slice(0, 3); + return [ + `No ${params.capabilityLabel} model configured. Set agents.defaults.${params.modelConfigKey}.primary to a provider/model like "${sampleRef}".`, + authHints.length > 0 + ? `If you want a specific provider, also configure that provider's auth/API key first (${authHints.join("; ")}).` + : "If you want a specific provider, also configure that provider's auth/API key first.", + ].join(" "); +} diff --git a/src/music-generation/live-test-helpers.ts b/src/music-generation/live-test-helpers.ts new file mode 100644 index 00000000000..0ac204f3211 --- /dev/null +++ b/src/music-generation/live-test-helpers.ts @@ -0,0 +1,4 @@ +export const DEFAULT_LIVE_MUSIC_MODELS: Record = { + google: "google/lyria-3-clip-preview", + minimax: "minimax/music-2.5+", +}; diff --git a/src/music-generation/model-ref.ts b/src/music-generation/model-ref.ts new file mode 100644 index 00000000000..d58562570bc --- /dev/null +++ b/src/music-generation/model-ref.ts @@ -0,0 +1,16 @@ +export function parseMusicGenerationModelRef( + raw: string | undefined, +): { provider: string; model: string } | null { + const trimmed = raw?.trim(); + if (!trimmed) { + return null; + } + const slashIndex = trimmed.indexOf("/"); + if (slashIndex <= 0 || slashIndex === trimmed.length - 1) { + return null; + } + return { + provider: trimmed.slice(0, slashIndex).trim(), + model: trimmed.slice(slashIndex + 1).trim(), + }; +} diff --git a/src/music-generation/provider-registry.ts b/src/music-generation/provider-registry.ts new file mode 100644 index 00000000000..c8a0bbb1999 --- /dev/null +++ b/src/music-generation/provider-registry.ts @@ -0,0 +1,77 @@ +import { normalizeProviderId } from "../agents/model-selection.js"; +import type { OpenClawConfig } from 
"../config/config.js"; +import { isBlockedObjectKey } from "../infra/prototype-keys.js"; +import { resolvePluginCapabilityProviders } from "../plugins/capability-provider-runtime.js"; +import type { MusicGenerationProviderPlugin } from "../plugins/types.js"; + +const BUILTIN_MUSIC_GENERATION_PROVIDERS: readonly MusicGenerationProviderPlugin[] = []; +const UNSAFE_PROVIDER_IDS = new Set(["__proto__", "constructor", "prototype"]); + +function normalizeMusicGenerationProviderId(id: string | undefined): string | undefined { + const normalized = normalizeProviderId(id ?? ""); + if (!normalized || isBlockedObjectKey(normalized)) { + return undefined; + } + return normalized; +} + +function isSafeMusicGenerationProviderId(id: string | undefined): id is string { + return Boolean(id && !UNSAFE_PROVIDER_IDS.has(id)); +} + +function resolvePluginMusicGenerationProviders( + cfg?: OpenClawConfig, +): MusicGenerationProviderPlugin[] { + return resolvePluginCapabilityProviders({ + key: "musicGenerationProviders", + cfg, + }); +} + +function buildProviderMaps(cfg?: OpenClawConfig): { + canonical: Map; + aliases: Map; +} { + const canonical = new Map(); + const aliases = new Map(); + const register = (provider: MusicGenerationProviderPlugin) => { + const id = normalizeMusicGenerationProviderId(provider.id); + if (!isSafeMusicGenerationProviderId(id)) { + return; + } + canonical.set(id, provider); + aliases.set(id, provider); + for (const alias of provider.aliases ?? 
[]) { + const normalizedAlias = normalizeMusicGenerationProviderId(alias); + if (isSafeMusicGenerationProviderId(normalizedAlias)) { + aliases.set(normalizedAlias, provider); + } + } + }; + + for (const provider of BUILTIN_MUSIC_GENERATION_PROVIDERS) { + register(provider); + } + for (const provider of resolvePluginMusicGenerationProviders(cfg)) { + register(provider); + } + + return { canonical, aliases }; +} + +export function listMusicGenerationProviders( + cfg?: OpenClawConfig, +): MusicGenerationProviderPlugin[] { + return [...buildProviderMaps(cfg).canonical.values()]; +} + +export function getMusicGenerationProvider( + providerId: string | undefined, + cfg?: OpenClawConfig, +): MusicGenerationProviderPlugin | undefined { + const normalized = normalizeMusicGenerationProviderId(providerId); + if (!normalized) { + return undefined; + } + return buildProviderMaps(cfg).aliases.get(normalized); +} diff --git a/src/music-generation/runtime.ts b/src/music-generation/runtime.ts new file mode 100644 index 00000000000..8603a8e617c --- /dev/null +++ b/src/music-generation/runtime.ts @@ -0,0 +1,129 @@ +import type { AuthProfileStore } from "../agents/auth-profiles.js"; +import { describeFailoverError, isFailoverError } from "../agents/failover-error.js"; +import type { FallbackAttempt } from "../agents/model-fallback.types.js"; +import type { OpenClawConfig } from "../config/config.js"; +import { createSubsystemLogger } from "../logging/subsystem.js"; +import { + buildNoCapabilityModelConfiguredMessage, + resolveCapabilityModelCandidates, + throwCapabilityGenerationFailure, +} from "../media-generation/runtime-shared.js"; +import { parseMusicGenerationModelRef } from "./model-ref.js"; +import { getMusicGenerationProvider, listMusicGenerationProviders } from "./provider-registry.js"; +import type { + GeneratedMusicAsset, + MusicGenerationOutputFormat, + MusicGenerationResult, + MusicGenerationSourceImage, +} from "./types.js"; + +const log = 
createSubsystemLogger("music-generation"); + +export type GenerateMusicParams = { + cfg: OpenClawConfig; + prompt: string; + agentDir?: string; + authStore?: AuthProfileStore; + modelOverride?: string; + lyrics?: string; + instrumental?: boolean; + durationSeconds?: number; + format?: MusicGenerationOutputFormat; + inputImages?: MusicGenerationSourceImage[]; +}; + +export type GenerateMusicRuntimeResult = { + tracks: GeneratedMusicAsset[]; + provider: string; + model: string; + attempts: FallbackAttempt[]; + lyrics?: string[]; + metadata?: Record; +}; + +export function listRuntimeMusicGenerationProviders(params?: { config?: OpenClawConfig }) { + return listMusicGenerationProviders(params?.config); +} + +export async function generateMusic( + params: GenerateMusicParams, +): Promise { + const candidates = resolveCapabilityModelCandidates({ + cfg: params.cfg, + modelConfig: params.cfg.agents?.defaults?.musicGenerationModel, + modelOverride: params.modelOverride, + parseModelRef: parseMusicGenerationModelRef, + }); + if (candidates.length === 0) { + throw new Error( + buildNoCapabilityModelConfiguredMessage({ + capabilityLabel: "music-generation", + modelConfigKey: "musicGenerationModel", + providers: listMusicGenerationProviders(params.cfg), + fallbackSampleRef: "google/lyria-3-clip-preview", + }), + ); + } + + const attempts: FallbackAttempt[] = []; + let lastError: unknown; + + for (const candidate of candidates) { + const provider = getMusicGenerationProvider(candidate.provider, params.cfg); + if (!provider) { + const error = `No music-generation provider registered for ${candidate.provider}`; + attempts.push({ + provider: candidate.provider, + model: candidate.model, + error, + }); + lastError = new Error(error); + continue; + } + + try { + const result: MusicGenerationResult = await provider.generateMusic({ + provider: candidate.provider, + model: candidate.model, + prompt: params.prompt, + cfg: params.cfg, + agentDir: params.agentDir, + authStore: 
params.authStore, + lyrics: params.lyrics, + instrumental: params.instrumental, + durationSeconds: params.durationSeconds, + format: params.format, + inputImages: params.inputImages, + }); + if (!Array.isArray(result.tracks) || result.tracks.length === 0) { + throw new Error("Music generation provider returned no tracks."); + } + return { + tracks: result.tracks, + provider: candidate.provider, + model: result.model ?? candidate.model, + attempts, + lyrics: result.lyrics, + metadata: result.metadata, + }; + } catch (err) { + lastError = err; + const described = isFailoverError(err) ? describeFailoverError(err) : undefined; + attempts.push({ + provider: candidate.provider, + model: candidate.model, + error: described?.message ?? (err instanceof Error ? err.message : String(err)), + reason: described?.reason, + status: described?.status, + code: described?.code, + }); + log.debug(`music-generation candidate failed: ${candidate.provider}/${candidate.model}`); + } + } + + throwCapabilityGenerationFailure({ + capabilityLabel: "music generation", + attempts, + lastError, + }); +} diff --git a/src/music-generation/types.ts b/src/music-generation/types.ts new file mode 100644 index 00000000000..4889f7deb03 --- /dev/null +++ b/src/music-generation/types.ts @@ -0,0 +1,69 @@ +import type { AuthProfileStore } from "../agents/auth-profiles.js"; +import type { OpenClawConfig } from "../config/config.js"; + +export type MusicGenerationOutputFormat = "mp3" | "wav"; + +export type GeneratedMusicAsset = { + buffer: Buffer; + mimeType: string; + fileName?: string; + metadata?: Record; +}; + +export type MusicGenerationSourceImage = { + url?: string; + buffer?: Buffer; + mimeType?: string; + fileName?: string; + metadata?: Record; +}; + +export type MusicGenerationProviderConfiguredContext = { + cfg?: OpenClawConfig; + agentDir?: string; +}; + +export type MusicGenerationRequest = { + provider: string; + model: string; + prompt: string; + cfg: OpenClawConfig; + agentDir?: string; + 
authStore?: AuthProfileStore; + timeoutMs?: number; + lyrics?: string; + instrumental?: boolean; + durationSeconds?: number; + format?: MusicGenerationOutputFormat; + inputImages?: MusicGenerationSourceImage[]; +}; + +export type MusicGenerationResult = { + tracks: GeneratedMusicAsset[]; + model?: string; + lyrics?: string[]; + metadata?: Record; +}; + +export type MusicGenerationProviderCapabilities = { + maxTracks?: number; + maxInputImages?: number; + maxDurationSeconds?: number; + supportsLyrics?: boolean; + supportsInstrumental?: boolean; + supportsDuration?: boolean; + supportsFormat?: boolean; + supportedFormats?: readonly MusicGenerationOutputFormat[]; + supportedFormatsByModel?: Readonly>; +}; + +export type MusicGenerationProvider = { + id: string; + aliases?: string[]; + label?: string; + defaultModel?: string; + models?: string[]; + capabilities: MusicGenerationProviderCapabilities; + isConfigured?: (ctx: MusicGenerationProviderConfiguredContext) => boolean; + generateMusic: (req: MusicGenerationRequest) => Promise; +}; diff --git a/src/plugin-sdk/index.ts b/src/plugin-sdk/index.ts index 92fd3f44afa..b2703cf3094 100644 --- a/src/plugin-sdk/index.ts +++ b/src/plugin-sdk/index.ts @@ -78,6 +78,7 @@ export type { OpenClawConfig } from "../config/config.js"; /** @deprecated Use OpenClawConfig instead */ export type { OpenClawConfig as ClawdbotConfig } from "../config/config.js"; export * from "./image-generation.js"; +export * from "./music-generation.js"; export type { SecretInput, SecretRef } from "../config/types.secrets.js"; export type { RuntimeEnv } from "../runtime.js"; export type { HookEntry } from "../hooks/types.js"; diff --git a/src/plugin-sdk/music-generation-core.ts b/src/plugin-sdk/music-generation-core.ts new file mode 100644 index 00000000000..687a5f9a612 --- /dev/null +++ b/src/plugin-sdk/music-generation-core.ts @@ -0,0 +1,28 @@ +// Shared music-generation implementation helpers for bundled and third-party plugins. 
+ +export type { AuthProfileStore } from "../agents/auth-profiles.js"; +export type { FallbackAttempt } from "../agents/model-fallback.types.js"; +export type { OpenClawConfig } from "../config/config.js"; +export type { MusicGenerationProviderPlugin } from "../plugins/types.js"; +export type { + GeneratedMusicAsset, + MusicGenerationOutputFormat, + MusicGenerationProvider, + MusicGenerationProviderCapabilities, + MusicGenerationRequest, + MusicGenerationResult, + MusicGenerationSourceImage, +} from "../music-generation/types.js"; + +export { describeFailoverError, isFailoverError } from "../agents/failover-error.js"; +export { + resolveAgentModelFallbackValues, + resolveAgentModelPrimaryValue, +} from "../config/model-input.js"; +export { createSubsystemLogger } from "../logging/subsystem.js"; +export { parseMusicGenerationModelRef } from "../music-generation/model-ref.js"; +export { + getMusicGenerationProvider, + listMusicGenerationProviders, +} from "../music-generation/provider-registry.js"; +export { getProviderEnvVars } from "../secrets/provider-env-vars.js"; diff --git a/src/plugin-sdk/music-generation.ts b/src/plugin-sdk/music-generation.ts new file mode 100644 index 00000000000..a2e9c33e944 --- /dev/null +++ b/src/plugin-sdk/music-generation.ts @@ -0,0 +1,11 @@ +// Public music-generation helpers and types for provider plugins. 
+ +export type { + GeneratedMusicAsset, + MusicGenerationProvider, + MusicGenerationProviderCapabilities, + MusicGenerationRequest, + MusicGenerationResult, + MusicGenerationSourceImage, + MusicGenerationOutputFormat, +} from "../music-generation/types.js"; diff --git a/src/plugins/.DS_Store b/src/plugins/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..d7a901275a9342c8a39f300780c16094bc555750 GIT binary patch literal 6148 zcmeHKO-jR15S|wcR&>#&%N#(#?uEt@PtXfUZLx*Kl-91z%Drdk1&Zhyyn;J-9>JBM z-~3Qp(uyllnSq(_<-M8ATAI6Js$UEfLejpxpb z&uYKpzf!lnKU}(tGvEw31Al@6^lX;;uA*CKz!`7`77XzBA%-$0hNYr>I$&fA0IXqJ z1Y_Pya7=8N7?z4yfdmZ&YG|-6MzCM0py60g%)eCBa0-$c=dsMeb|^uzV?Gvl3W=gy zXTTX)WnfDW8{Gd-H}C&fgZ#)Da0dPr1EH37(l##1dTZt4xYx$eGbjuDm5PfH45Ssq ems{~ZR0Q)e4}ghbsfZbf{|E>TZk&N1W#9u8%UUP^ literal 0 HcmV?d00001 diff --git a/src/plugins/api-builder.ts b/src/plugins/api-builder.ts index 33519cc8228..95fdd3474b9 100644 --- a/src/plugins/api-builder.ts +++ b/src/plugins/api-builder.ts @@ -37,6 +37,7 @@ export type BuildPluginApiParams = { | "registerMediaUnderstandingProvider" | "registerImageGenerationProvider" | "registerVideoGenerationProvider" + | "registerMusicGenerationProvider" | "registerWebFetchProvider" | "registerWebSearchProvider" | "registerInteractiveHandler" @@ -77,6 +78,8 @@ const noopRegisterImageGenerationProvider: OpenClawPluginApi["registerImageGener () => {}; const noopRegisterVideoGenerationProvider: OpenClawPluginApi["registerVideoGenerationProvider"] = () => {}; +const noopRegisterMusicGenerationProvider: OpenClawPluginApi["registerMusicGenerationProvider"] = + () => {}; const noopRegisterWebFetchProvider: OpenClawPluginApi["registerWebFetchProvider"] = () => {}; const noopRegisterWebSearchProvider: OpenClawPluginApi["registerWebSearchProvider"] = () => {}; const noopRegisterInteractiveHandler: OpenClawPluginApi["registerInteractiveHandler"] = () => {}; @@ -130,6 +133,8 @@ export function buildPluginApi(params: BuildPluginApiParams): 
OpenClawPluginApi handlers.registerImageGenerationProvider ?? noopRegisterImageGenerationProvider, registerVideoGenerationProvider: handlers.registerVideoGenerationProvider ?? noopRegisterVideoGenerationProvider, + registerMusicGenerationProvider: + handlers.registerMusicGenerationProvider ?? noopRegisterMusicGenerationProvider, registerWebFetchProvider: handlers.registerWebFetchProvider ?? noopRegisterWebFetchProvider, registerWebSearchProvider: handlers.registerWebSearchProvider ?? noopRegisterWebSearchProvider, registerInteractiveHandler: diff --git a/src/plugins/bundled-capability-metadata.test.ts b/src/plugins/bundled-capability-metadata.test.ts index 29515065650..32104012ae8 100644 --- a/src/plugins/bundled-capability-metadata.test.ts +++ b/src/plugins/bundled-capability-metadata.test.ts @@ -36,6 +36,7 @@ describe("bundled capability metadata", () => { ), imageGenerationProviderIds: uniqueStrings(manifest.contracts?.imageGenerationProviders), videoGenerationProviderIds: uniqueStrings(manifest.contracts?.videoGenerationProviders), + musicGenerationProviderIds: uniqueStrings(manifest.contracts?.musicGenerationProviders), webFetchProviderIds: uniqueStrings(manifest.contracts?.webFetchProviders), webSearchProviderIds: uniqueStrings(manifest.contracts?.webSearchProviders), toolNames: uniqueStrings(manifest.contracts?.tools), @@ -49,6 +50,7 @@ describe("bundled capability metadata", () => { entry.mediaUnderstandingProviderIds.length > 0 || entry.imageGenerationProviderIds.length > 0 || entry.videoGenerationProviderIds.length > 0 || + entry.musicGenerationProviderIds.length > 0 || entry.webFetchProviderIds.length > 0 || entry.webSearchProviderIds.length > 0 || entry.toolNames.length > 0, diff --git a/src/plugins/bundled-capability-runtime.ts b/src/plugins/bundled-capability-runtime.ts index 54dcd1cda77..59be1cea448 100644 --- a/src/plugins/bundled-capability-runtime.ts +++ b/src/plugins/bundled-capability-runtime.ts @@ -126,6 +126,7 @@ function 
createCapabilityPluginRecord(params: { mediaUnderstandingProviderIds: [], imageGenerationProviderIds: [], videoGenerationProviderIds: [], + musicGenerationProviderIds: [], webFetchProviderIds: [], webSearchProviderIds: [], memoryEmbeddingProviderIds: [], @@ -289,6 +290,9 @@ export function loadBundledCapabilityRuntimeRegistry(params: { record.videoGenerationProviderIds.push( ...captured.videoGenerationProviders.map((entry) => entry.id), ); + record.musicGenerationProviderIds.push( + ...captured.musicGenerationProviders.map((entry) => entry.id), + ); record.webFetchProviderIds.push(...captured.webFetchProviders.map((entry) => entry.id)); record.webSearchProviderIds.push(...captured.webSearchProviders.map((entry) => entry.id)); record.memoryEmbeddingProviderIds.push( @@ -359,6 +363,15 @@ export function loadBundledCapabilityRuntimeRegistry(params: { rootDir: record.rootDir, })), ); + registry.musicGenerationProviders.push( + ...captured.musicGenerationProviders.map((provider) => ({ + pluginId: record.id, + pluginName: record.name, + provider, + source: record.source, + rootDir: record.rootDir, + })), + ); registry.webFetchProviders.push( ...captured.webFetchProviders.map((provider) => ({ pluginId: record.id, diff --git a/src/plugins/capability-provider-runtime.ts b/src/plugins/capability-provider-runtime.ts index 8a19a606a63..968ab23f034 100644 --- a/src/plugins/capability-provider-runtime.ts +++ b/src/plugins/capability-provider-runtime.ts @@ -14,7 +14,8 @@ type CapabilityProviderRegistryKey = | "realtimeVoiceProviders" | "mediaUnderstandingProviders" | "imageGenerationProviders" - | "videoGenerationProviders"; + | "videoGenerationProviders" + | "musicGenerationProviders"; type CapabilityContractKey = | "memoryEmbeddingProviders" @@ -23,7 +24,8 @@ type CapabilityContractKey = | "realtimeVoiceProviders" | "mediaUnderstandingProviders" | "imageGenerationProviders" - | "videoGenerationProviders"; + | "videoGenerationProviders" + | "musicGenerationProviders"; type 
CapabilityProviderForKey = PluginRegistry[K][number] extends { provider: infer T } ? T : never; @@ -36,6 +38,7 @@ const CAPABILITY_CONTRACT_KEY: Record 0 || entry.imageGenerationProviderIds.length > 0 || entry.videoGenerationProviderIds.length > 0 || + entry.musicGenerationProviderIds.length > 0 || entry.webFetchProviderIds.length > 0 || entry.webSearchProviderIds.length > 0 || entry.toolNames.length > 0, diff --git a/src/plugins/contracts/registry.ts b/src/plugins/contracts/registry.ts index de997f35b95..9002ef939de 100644 --- a/src/plugins/contracts/registry.ts +++ b/src/plugins/contracts/registry.ts @@ -6,6 +6,7 @@ import { import type { ImageGenerationProviderPlugin, MediaUnderstandingProviderPlugin, + MusicGenerationProviderPlugin, ProviderPlugin, RealtimeTranscriptionProviderPlugin, RealtimeVoiceProviderPlugin, @@ -17,6 +18,7 @@ import type { import { loadVitestImageGenerationProviderContractRegistry, loadVitestMediaUnderstandingProviderContractRegistry, + loadVitestMusicGenerationProviderContractRegistry, loadVitestRealtimeTranscriptionProviderContractRegistry, loadVitestRealtimeVoiceProviderContractRegistry, loadVitestSpeechProviderContractRegistry, @@ -44,6 +46,7 @@ type MediaUnderstandingProviderContractEntry = CapabilityContractEntry; type ImageGenerationProviderContractEntry = CapabilityContractEntry; type VideoGenerationProviderContractEntry = CapabilityContractEntry; +type MusicGenerationProviderContractEntry = CapabilityContractEntry; type PluginRegistrationContractEntry = { pluginId: string; @@ -54,6 +57,7 @@ type PluginRegistrationContractEntry = { mediaUnderstandingProviderIds: string[]; imageGenerationProviderIds: string[]; videoGenerationProviderIds: string[]; + musicGenerationProviderIds: string[]; webFetchProviderIds: string[]; webSearchProviderIds: string[]; toolNames: string[]; @@ -66,6 +70,7 @@ type ManifestContractKey = | "mediaUnderstandingProviders" | "imageGenerationProviders" | "videoGenerationProviders" + | "musicGenerationProviders" 
| "webFetchProviders" | "webSearchProviders" | "tools"; @@ -97,6 +102,7 @@ function resolveBundledManifestContracts(): PluginRegistrationContractEntry[] { (plugin.contracts?.mediaUnderstandingProviders?.length ?? 0) > 0 || (plugin.contracts?.imageGenerationProviders?.length ?? 0) > 0 || (plugin.contracts?.videoGenerationProviders?.length ?? 0) > 0 || + (plugin.contracts?.musicGenerationProviders?.length ?? 0) > 0 || (plugin.contracts?.webFetchProviders?.length ?? 0) > 0 || (plugin.contracts?.webSearchProviders?.length ?? 0) > 0 || (plugin.contracts?.tools?.length ?? 0) > 0), @@ -114,6 +120,7 @@ function resolveBundledManifestContracts(): PluginRegistrationContractEntry[] { ), imageGenerationProviderIds: uniqueStrings(plugin.contracts?.imageGenerationProviders ?? []), videoGenerationProviderIds: uniqueStrings(plugin.contracts?.videoGenerationProviders ?? []), + musicGenerationProviderIds: uniqueStrings(plugin.contracts?.musicGenerationProviders ?? []), webFetchProviderIds: uniqueStrings(plugin.contracts?.webFetchProviders ?? []), webSearchProviderIds: uniqueStrings(plugin.contracts?.webSearchProviders ?? []), toolNames: uniqueStrings(plugin.contracts?.tools ?? 
[]), @@ -166,6 +173,8 @@ function resolveBundledManifestPluginIdsForContract(contract: ManifestContractKe return entry.imageGenerationProviderIds.length > 0; case "videoGenerationProviders": return entry.videoGenerationProviderIds.length > 0; + case "musicGenerationProviders": + return entry.musicGenerationProviderIds.length > 0; case "webFetchProviders": return entry.webFetchProviderIds.length > 0; case "webSearchProviders": @@ -202,6 +211,8 @@ let imageGenerationProviderContractRegistryCache: ImageGenerationProviderContrac null; let videoGenerationProviderContractRegistryCache: VideoGenerationProviderContractEntry[] | null = null; +let musicGenerationProviderContractRegistryCache: MusicGenerationProviderContractEntry[] | null = + null; export let providerContractLoadError: Error | undefined; @@ -564,6 +575,21 @@ function loadVideoGenerationProviderContractRegistry(): VideoGenerationProviderC return videoGenerationProviderContractRegistryCache; } +function loadMusicGenerationProviderContractRegistry(): MusicGenerationProviderContractEntry[] { + if (!musicGenerationProviderContractRegistryCache) { + musicGenerationProviderContractRegistryCache = process.env.VITEST + ? 
loadVitestMusicGenerationProviderContractRegistry() + : loadBundledCapabilityRuntimeRegistry({ + pluginIds: resolveBundledManifestPluginIdsForContract("musicGenerationProviders"), + pluginSdkResolution: "dist", + }).musicGenerationProviders.map((entry) => ({ + pluginId: entry.pluginId, + provider: entry.provider, + })); + } + return musicGenerationProviderContractRegistryCache; +} + function createLazyArrayView(load: () => T[]): T[] { return new Proxy([] as T[], { get(_target, prop) { @@ -667,6 +693,8 @@ export const imageGenerationProviderContractRegistry: ImageGenerationProviderCon createLazyArrayView(loadImageGenerationProviderContractRegistry); export const videoGenerationProviderContractRegistry: VideoGenerationProviderContractEntry[] = createLazyArrayView(loadVideoGenerationProviderContractRegistry); +export const musicGenerationProviderContractRegistry: MusicGenerationProviderContractEntry[] = + createLazyArrayView(loadMusicGenerationProviderContractRegistry); function loadPluginRegistrationContractRegistry(): PluginRegistrationContractEntry[] { return resolveBundledManifestContracts(); diff --git a/src/plugins/contracts/speech-vitest-registry.ts b/src/plugins/contracts/speech-vitest-registry.ts index 95d7a548779..e4867206711 100644 --- a/src/plugins/contracts/speech-vitest-registry.ts +++ b/src/plugins/contracts/speech-vitest-registry.ts @@ -3,6 +3,7 @@ import { resolveManifestContractPluginIds } from "../manifest-registry.js"; import type { ImageGenerationProviderPlugin, MediaUnderstandingProviderPlugin, + MusicGenerationProviderPlugin, RealtimeTranscriptionProviderPlugin, RealtimeVoiceProviderPlugin, SpeechProviderPlugin, @@ -39,13 +40,19 @@ export type VideoGenerationProviderContractEntry = { provider: VideoGenerationProviderPlugin; }; +export type MusicGenerationProviderContractEntry = { + pluginId: string; + provider: MusicGenerationProviderPlugin; +}; + type ManifestContractKey = | "imageGenerationProviders" | "speechProviders" | 
"mediaUnderstandingProviders" | "realtimeVoiceProviders" | "realtimeTranscriptionProviders" - | "videoGenerationProviders"; + | "videoGenerationProviders" + | "musicGenerationProviders"; function loadVitestCapabilityContractEntries(params: { contract: ManifestContractKey; @@ -134,3 +141,14 @@ export function loadVitestVideoGenerationProviderContractRegistry(): VideoGenera })), }); } + +export function loadVitestMusicGenerationProviderContractRegistry(): MusicGenerationProviderContractEntry[] { + return loadVitestCapabilityContractEntries({ + contract: "musicGenerationProviders", + pickEntries: (registry) => + registry.musicGenerationProviders.map((entry) => ({ + pluginId: entry.pluginId, + provider: entry.provider, + })), + }); +} diff --git a/src/plugins/hooks.test-helpers.ts b/src/plugins/hooks.test-helpers.ts index 146ce8b291e..b9b9a168f4f 100644 --- a/src/plugins/hooks.test-helpers.ts +++ b/src/plugins/hooks.test-helpers.ts @@ -39,6 +39,7 @@ export function createMockPluginRegistry( mediaUnderstandingProviders: [], imageGenerationProviders: [], videoGenerationProviders: [], + musicGenerationProviders: [], webSearchProviders: [], httpRoutes: [], gatewayHandlers: {}, diff --git a/src/plugins/loader.ts b/src/plugins/loader.ts index e4371251604..17c759493b3 100644 --- a/src/plugins/loader.ts +++ b/src/plugins/loader.ts @@ -596,6 +596,7 @@ function createPluginRecord(params: { mediaUnderstandingProviderIds: [], imageGenerationProviderIds: [], videoGenerationProviderIds: [], + musicGenerationProviderIds: [], webFetchProviderIds: [], webSearchProviderIds: [], memoryEmbeddingProviderIds: [], diff --git a/src/plugins/manifest-registry.ts b/src/plugins/manifest-registry.ts index 75b7c87d34b..072692e763e 100644 --- a/src/plugins/manifest-registry.ts +++ b/src/plugins/manifest-registry.ts @@ -36,6 +36,7 @@ type PluginManifestContractListKey = | "realtimeTranscriptionProviders" | "imageGenerationProviders" | "videoGenerationProviders" + | "musicGenerationProviders" | 
"memoryEmbeddingProviders" | "webFetchProviders" | "webSearchProviders"; diff --git a/src/plugins/manifest.ts b/src/plugins/manifest.ts index 63d0537c186..3327d2cf595 100644 --- a/src/plugins/manifest.ts +++ b/src/plugins/manifest.ts @@ -76,6 +76,7 @@ export type PluginManifestContracts = { mediaUnderstandingProviders?: string[]; imageGenerationProviders?: string[]; videoGenerationProviders?: string[]; + musicGenerationProviders?: string[]; webFetchProviders?: string[]; webSearchProviders?: string[]; tools?: string[]; @@ -157,6 +158,7 @@ function normalizeManifestContracts(value: unknown): PluginManifestContracts | u const mediaUnderstandingProviders = normalizeStringList(value.mediaUnderstandingProviders); const imageGenerationProviders = normalizeStringList(value.imageGenerationProviders); const videoGenerationProviders = normalizeStringList(value.videoGenerationProviders); + const musicGenerationProviders = normalizeStringList(value.musicGenerationProviders); const webFetchProviders = normalizeStringList(value.webFetchProviders); const webSearchProviders = normalizeStringList(value.webSearchProviders); const tools = normalizeStringList(value.tools); @@ -168,6 +170,7 @@ function normalizeManifestContracts(value: unknown): PluginManifestContracts | u ...(mediaUnderstandingProviders.length > 0 ? { mediaUnderstandingProviders } : {}), ...(imageGenerationProviders.length > 0 ? { imageGenerationProviders } : {}), ...(videoGenerationProviders.length > 0 ? { videoGenerationProviders } : {}), + ...(musicGenerationProviders.length > 0 ? { musicGenerationProviders } : {}), ...(webFetchProviders.length > 0 ? { webFetchProviders } : {}), ...(webSearchProviders.length > 0 ? { webSearchProviders } : {}), ...(tools.length > 0 ? 
{ tools } : {}), diff --git a/src/plugins/registry-empty.ts b/src/plugins/registry-empty.ts index fb7529864d3..f7622e4764b 100644 --- a/src/plugins/registry-empty.ts +++ b/src/plugins/registry-empty.ts @@ -15,6 +15,7 @@ export function createEmptyPluginRegistry(): PluginRegistry { mediaUnderstandingProviders: [], imageGenerationProviders: [], videoGenerationProviders: [], + musicGenerationProviders: [], webFetchProviders: [], webSearchProviders: [], memoryEmbeddingProviders: [], diff --git a/src/plugins/registry.ts b/src/plugins/registry.ts index 3066aac2ea9..4921ececbb1 100644 --- a/src/plugins/registry.ts +++ b/src/plugins/registry.ts @@ -45,6 +45,7 @@ import { } from "./types.js"; import type { ImageGenerationProviderPlugin, + MusicGenerationProviderPlugin, RealtimeTranscriptionProviderPlugin, OpenClawPluginApi, OpenClawPluginChannelRegistration, @@ -157,6 +158,8 @@ export type PluginImageGenerationProviderRegistration = PluginOwnedProviderRegistration; export type PluginVideoGenerationProviderRegistration = PluginOwnedProviderRegistration; +export type PluginMusicGenerationProviderRegistration = + PluginOwnedProviderRegistration; export type PluginWebFetchProviderRegistration = PluginOwnedProviderRegistration; export type PluginWebSearchProviderRegistration = @@ -254,6 +257,7 @@ export type PluginRecord = { mediaUnderstandingProviderIds: string[]; imageGenerationProviderIds: string[]; videoGenerationProviderIds: string[]; + musicGenerationProviderIds: string[]; webFetchProviderIds: string[]; webSearchProviderIds: string[]; memoryEmbeddingProviderIds: string[]; @@ -284,6 +288,7 @@ export type PluginRegistry = { mediaUnderstandingProviders: PluginMediaUnderstandingProviderRegistration[]; imageGenerationProviders: PluginImageGenerationProviderRegistration[]; videoGenerationProviders: PluginVideoGenerationProviderRegistration[]; + musicGenerationProviders: PluginMusicGenerationProviderRegistration[]; webFetchProviders: PluginWebFetchProviderRegistration[]; 
webSearchProviders: PluginWebSearchProviderRegistration[]; memoryEmbeddingProviders: PluginMemoryEmbeddingProviderRegistration[]; @@ -787,6 +792,19 @@ export function createPluginRegistry(registryParams: PluginRegistryParams) { }); }; + const registerMusicGenerationProvider = ( + record: PluginRecord, + provider: MusicGenerationProviderPlugin, + ) => { + registerUniqueProviderLike({ + record, + provider, + kindLabel: "music-generation provider", + registrations: registry.musicGenerationProviders, + ownedIds: record.musicGenerationProviderIds, + }); + }; + const registerWebFetchProvider = (record: PluginRecord, provider: WebFetchProviderPlugin) => { registerUniqueProviderLike({ record, @@ -1179,6 +1197,8 @@ export function createPluginRegistry(registryParams: PluginRegistryParams) { registerImageGenerationProvider(record, provider), registerVideoGenerationProvider: (provider) => registerVideoGenerationProvider(record, provider), + registerMusicGenerationProvider: (provider) => + registerMusicGenerationProvider(record, provider), registerWebFetchProvider: (provider) => registerWebFetchProvider(record, provider), registerWebSearchProvider: (provider) => registerWebSearchProvider(record, provider), registerGatewayMethod: (method, handler, opts) => @@ -1381,6 +1401,7 @@ export function createPluginRegistry(registryParams: PluginRegistryParams) { registerMediaUnderstandingProvider, registerImageGenerationProvider, registerVideoGenerationProvider, + registerMusicGenerationProvider, registerWebSearchProvider, registerGatewayMethod, registerCli, diff --git a/src/plugins/runtime.test.ts b/src/plugins/runtime.test.ts index 2ee8e2a2d82..af33bd91bda 100644 --- a/src/plugins/runtime.test.ts +++ b/src/plugins/runtime.test.ts @@ -203,6 +203,7 @@ describe("setActivePluginRegistry", () => { mediaUnderstandingProviderIds: [], imageGenerationProviderIds: [], videoGenerationProviderIds: [], + musicGenerationProviderIds: [], webFetchProviderIds: [], webSearchProviderIds: [], 
memoryEmbeddingProviderIds: [], @@ -232,6 +233,7 @@ describe("setActivePluginRegistry", () => { mediaUnderstandingProviderIds: [], imageGenerationProviderIds: [], videoGenerationProviderIds: [], + musicGenerationProviderIds: [], webFetchProviderIds: [], webSearchProviderIds: [], memoryEmbeddingProviderIds: [], diff --git a/src/plugins/runtime/index.ts b/src/plugins/runtime/index.ts index 2685800af36..ae62a2d25ab 100644 --- a/src/plugins/runtime/index.ts +++ b/src/plugins/runtime/index.ts @@ -3,6 +3,10 @@ import { generateImage as generateRuntimeImage, listRuntimeImageGenerationProviders, } from "../../image-generation/runtime.js"; +import { + generateMusic as generateRuntimeMusic, + listRuntimeMusicGenerationProviders, +} from "../../music-generation/runtime.js"; import { resolveGlobalSingleton } from "../../shared/global-singleton.js"; import { createLazyRuntimeMethod, @@ -73,6 +77,13 @@ function createRuntimeVideoGeneration(): PluginRuntime["videoGeneration"] { }; } +function createRuntimeMusicGeneration(): PluginRuntime["musicGeneration"] { + return { + generate: (params) => generateRuntimeMusic(params), + listProviders: (params) => listRuntimeMusicGenerationProviders(params), + }; +} + function createRuntimeModelAuth(): PluginRuntime["modelAuth"] { const getApiKeyForModel = createLazyRuntimeMethod( loadModelAuthRuntime, @@ -211,12 +222,24 @@ export function createPluginRuntime(_options: CreatePluginRuntimeOptions = {}): taskFlow, } satisfies Omit< PluginRuntime, - "tts" | "mediaUnderstanding" | "stt" | "modelAuth" | "imageGeneration" | "videoGeneration" + | "tts" + | "mediaUnderstanding" + | "stt" + | "modelAuth" + | "imageGeneration" + | "videoGeneration" + | "musicGeneration" > & Partial< Pick< PluginRuntime, - "tts" | "mediaUnderstanding" | "stt" | "modelAuth" | "imageGeneration" | "videoGeneration" + | "tts" + | "mediaUnderstanding" + | "stt" + | "modelAuth" + | "imageGeneration" + | "videoGeneration" + | "musicGeneration" > >; @@ -228,6 +251,7 @@ export 
function createPluginRuntime(_options: CreatePluginRuntimeOptions = {}): defineCachedValue(runtime, "modelAuth", createRuntimeModelAuth); defineCachedValue(runtime, "imageGeneration", createRuntimeImageGeneration); defineCachedValue(runtime, "videoGeneration", createRuntimeVideoGeneration); + defineCachedValue(runtime, "musicGeneration", createRuntimeMusicGeneration); return runtime as PluginRuntime; } diff --git a/src/plugins/runtime/types-core.ts b/src/plugins/runtime/types-core.ts index 6c2fe8e0a13..48c86a5795e 100644 --- a/src/plugins/runtime/types-core.ts +++ b/src/plugins/runtime/types-core.ts @@ -86,6 +86,10 @@ export type PluginRuntimeCore = { generate: typeof import("../../video-generation/runtime.js").generateVideo; listProviders: typeof import("../../video-generation/runtime.js").listRuntimeVideoGenerationProviders; }; + musicGeneration: { + generate: typeof import("../../music-generation/runtime.js").generateMusic; + listProviders: typeof import("../../music-generation/runtime.js").listRuntimeMusicGenerationProviders; + }; webSearch: { listProviders: typeof import("../../web-search/runtime.js").listWebSearchProviders; search: typeof import("../../web-search/runtime.js").runWebSearch; diff --git a/src/plugins/status.test-helpers.ts b/src/plugins/status.test-helpers.ts index 8cf4cedb480..138323d35e3 100644 --- a/src/plugins/status.test-helpers.ts +++ b/src/plugins/status.test-helpers.ts @@ -55,6 +55,7 @@ export function createPluginRecord( mediaUnderstandingProviderIds: [], imageGenerationProviderIds: [], videoGenerationProviderIds: [], + musicGenerationProviderIds: [], webFetchProviderIds: [], webSearchProviderIds: [], memoryEmbeddingProviderIds: [], @@ -121,6 +122,7 @@ export function createPluginLoadResult( mediaUnderstandingProviders: [], imageGenerationProviders: [], videoGenerationProviders: [], + musicGenerationProviders: [], webFetchProviders: [], webSearchProviders: [], memoryEmbeddingProviders: [], diff --git a/src/plugins/types.ts 
b/src/plugins/types.ts index 50164c7de08..c0807b7e92d 100644 --- a/src/plugins/types.ts +++ b/src/plugins/types.ts @@ -32,6 +32,7 @@ import type { HookEntry } from "../hooks/types.js"; import type { ImageGenerationProvider } from "../image-generation/types.js"; import type { ProviderUsageSnapshot } from "../infra/provider-usage.types.js"; import type { MediaUnderstandingProvider } from "../media-understanding/types.js"; +import type { MusicGenerationProvider } from "../music-generation/types.js"; import type { RealtimeTranscriptionProviderConfig, RealtimeTranscriptionProviderConfiguredContext, @@ -1724,6 +1725,7 @@ export type PluginRealtimeVoiceProviderEntry = RealtimeVoiceProviderPlugin & { export type MediaUnderstandingProviderPlugin = MediaUnderstandingProvider; export type ImageGenerationProviderPlugin = ImageGenerationProvider; export type VideoGenerationProviderPlugin = VideoGenerationProvider; +export type MusicGenerationProviderPlugin = MusicGenerationProvider; export type OpenClawPluginGatewayMethod = { method: string; @@ -2090,6 +2092,8 @@ export type OpenClawPluginApi = { registerImageGenerationProvider: (provider: ImageGenerationProviderPlugin) => void; /** Register a video generation provider (video generation capability). */ registerVideoGenerationProvider: (provider: VideoGenerationProviderPlugin) => void; + /** Register a music generation provider (music generation capability). */ + registerMusicGenerationProvider: (provider: MusicGenerationProviderPlugin) => void; /** Register a web fetch provider (web fetch capability). */ registerWebFetchProvider: (provider: WebFetchProviderPlugin) => void; /** Register a web search provider (web search capability). 
*/ diff --git a/src/test-utils/channel-plugins.ts b/src/test-utils/channel-plugins.ts index 2d33eebc3cd..8da6a64837b 100644 --- a/src/test-utils/channel-plugins.ts +++ b/src/test-utils/channel-plugins.ts @@ -32,6 +32,7 @@ export const createTestRegistry = (channels: TestChannelRegistration[] = []): Pl mediaUnderstandingProviders: [], imageGenerationProviders: [], videoGenerationProviders: [], + musicGenerationProviders: [], webFetchProviders: [], webSearchProviders: [], memoryEmbeddingProviders: [], diff --git a/src/video-generation/runtime.ts b/src/video-generation/runtime.ts index 68eeb6e3fd9..ed76aa03678 100644 --- a/src/video-generation/runtime.ts +++ b/src/video-generation/runtime.ts @@ -2,12 +2,12 @@ import type { AuthProfileStore } from "../agents/auth-profiles.js"; import { describeFailoverError, isFailoverError } from "../agents/failover-error.js"; import type { FallbackAttempt } from "../agents/model-fallback.types.js"; import type { OpenClawConfig } from "../config/config.js"; -import { - resolveAgentModelFallbackValues, - resolveAgentModelPrimaryValue, -} from "../config/model-input.js"; import { createSubsystemLogger } from "../logging/subsystem.js"; -import { getProviderEnvVars } from "../secrets/provider-env-vars.js"; +import { + buildNoCapabilityModelConfiguredMessage, + resolveCapabilityModelCandidates, + throwCapabilityGenerationFailure, +} from "../media-generation/runtime-shared.js"; import { normalizeVideoGenerationDuration, resolveVideoGenerationSupportedDurations, @@ -49,76 +49,12 @@ export type GenerateVideoRuntimeResult = { ignoredOverrides: VideoGenerationIgnoredOverride[]; }; -function resolveVideoGenerationCandidates(params: { - cfg: OpenClawConfig; - modelOverride?: string; -}): Array<{ provider: string; model: string }> { - const candidates: Array<{ provider: string; model: string }> = []; - const seen = new Set(); - const add = (raw: string | undefined) => { - const parsed = parseVideoGenerationModelRef(raw); - if (!parsed) { - 
return; - } - const key = `${parsed.provider}/${parsed.model}`; - if (seen.has(key)) { - return; - } - seen.add(key); - candidates.push(parsed); - }; - - add(params.modelOverride); - add(resolveAgentModelPrimaryValue(params.cfg.agents?.defaults?.videoGenerationModel)); - for (const fallback of resolveAgentModelFallbackValues( - params.cfg.agents?.defaults?.videoGenerationModel, - )) { - add(fallback); - } - return candidates; -} - -function throwVideoGenerationFailure(params: { - attempts: FallbackAttempt[]; - lastError: unknown; -}): never { - if (params.attempts.length <= 1 && params.lastError) { - throw params.lastError; - } - const summary = - params.attempts.length > 0 - ? params.attempts - .map((attempt) => `${attempt.provider}/${attempt.model}: ${attempt.error}`) - .join(" | ") - : "unknown"; - throw new Error(`All video generation models failed (${params.attempts.length}): ${summary}`, { - cause: params.lastError instanceof Error ? params.lastError : undefined, - }); -} - function buildNoVideoGenerationModelConfiguredMessage(cfg: OpenClawConfig): string { - const providers = listVideoGenerationProviders(cfg); - const sampleModel = providers.find( - (provider) => provider.id.trim().length > 0 && provider.defaultModel?.trim(), - ); - const sampleRef = sampleModel - ? `${sampleModel.id}/${sampleModel.defaultModel}` - : "/"; - const authHints = providers - .flatMap((provider) => { - const envVars = getProviderEnvVars(provider.id); - if (envVars.length === 0) { - return []; - } - return [`${provider.id}: ${envVars.join(" / ")}`]; - }) - .slice(0, 3); - return [ - `No video-generation model configured. Set agents.defaults.videoGenerationModel.primary to a provider/model like "${sampleRef}".`, - authHints.length > 0 - ? 
`If you want a specific provider, also configure that provider's auth/API key first (${authHints.join("; ")}).` - : "If you want a specific provider, also configure that provider's auth/API key first.", - ].join(" "); + return buildNoCapabilityModelConfiguredMessage({ + capabilityLabel: "video-generation", + modelConfigKey: "videoGenerationModel", + providers: listVideoGenerationProviders(cfg), + }); } export function listRuntimeVideoGenerationProviders(params?: { config?: OpenClawConfig }) { @@ -179,9 +115,11 @@ function resolveProviderVideoGenerationOverrides(params: { export async function generateVideo( params: GenerateVideoParams, ): Promise { - const candidates = resolveVideoGenerationCandidates({ + const candidates = resolveCapabilityModelCandidates({ cfg: params.cfg, + modelConfig: params.cfg.agents?.defaults?.videoGenerationModel, modelOverride: params.modelOverride, + parseModelRef: parseVideoGenerationModelRef, }); if (candidates.length === 0) { throw new Error(buildNoVideoGenerationModelConfiguredMessage(params.cfg)); @@ -277,5 +215,9 @@ export async function generateVideo( } } - throwVideoGenerationFailure({ attempts, lastError }); + throwCapabilityGenerationFailure({ + capabilityLabel: "video generation", + attempts, + lastError, + }); } diff --git a/test/helpers/plugins/plugin-api.ts b/test/helpers/plugins/plugin-api.ts index 910e60799fc..564e8502549 100644 --- a/test/helpers/plugins/plugin-api.ts +++ b/test/helpers/plugins/plugin-api.ts @@ -29,6 +29,7 @@ export function createTestPluginApi(api: TestPluginApiInput = {}): OpenClawPlugi registerRealtimeVoiceProvider() {}, registerMediaUnderstandingProvider() {}, registerImageGenerationProvider() {}, + registerMusicGenerationProvider() {}, registerVideoGenerationProvider() {}, registerWebFetchProvider() {}, registerWebSearchProvider() {}, diff --git a/test/helpers/plugins/plugin-runtime-mock.ts b/test/helpers/plugins/plugin-runtime-mock.ts index 674feaf8d4d..b08330ad1e9 100644 --- 
a/test/helpers/plugins/plugin-runtime-mock.ts +++ b/test/helpers/plugins/plugin-runtime-mock.ts @@ -162,6 +162,10 @@ export function createPluginRuntimeMock(overrides: DeepPartial = generate: vi.fn() as unknown as PluginRuntime["imageGeneration"]["generate"], listProviders: vi.fn() as unknown as PluginRuntime["imageGeneration"]["listProviders"], }, + musicGeneration: { + generate: vi.fn() as unknown as PluginRuntime["musicGeneration"]["generate"], + listProviders: vi.fn() as unknown as PluginRuntime["musicGeneration"]["listProviders"], + }, videoGeneration: { generate: vi.fn() as unknown as PluginRuntime["videoGeneration"]["generate"], listProviders: vi.fn() as unknown as PluginRuntime["videoGeneration"]["listProviders"], diff --git a/test/helpers/plugins/provider-registration.ts b/test/helpers/plugins/provider-registration.ts index a254e1c3284..cd5d77c657f 100644 --- a/test/helpers/plugins/provider-registration.ts +++ b/test/helpers/plugins/provider-registration.ts @@ -1,6 +1,7 @@ import type { ImageGenerationProviderPlugin, MediaUnderstandingProviderPlugin, + MusicGenerationProviderPlugin, ProviderPlugin, SpeechProviderPlugin, VideoGenerationProviderPlugin, @@ -12,6 +13,7 @@ type RegisteredProviderCollections = { speechProviders: SpeechProviderPlugin[]; mediaProviders: MediaUnderstandingProviderPlugin[]; imageProviders: ImageGenerationProviderPlugin[]; + musicProviders: MusicGenerationProviderPlugin[]; videoProviders: VideoGenerationProviderPlugin[]; }; @@ -28,6 +30,7 @@ export async function registerProviderPlugin(params: { const speechProviders: SpeechProviderPlugin[] = []; const mediaProviders: MediaUnderstandingProviderPlugin[] = []; const imageProviders: ImageGenerationProviderPlugin[] = []; + const musicProviders: MusicGenerationProviderPlugin[] = []; const videoProviders: VideoGenerationProviderPlugin[] = []; await params.plugin.register( @@ -49,13 +52,23 @@ export async function registerProviderPlugin(params: { registerImageGenerationProvider: (provider) 
=> { imageProviders.push(provider); }, + registerMusicGenerationProvider: (provider) => { + musicProviders.push(provider); + }, registerVideoGenerationProvider: (provider) => { videoProviders.push(provider); }, }), ); - return { providers, speechProviders, mediaProviders, imageProviders, videoProviders }; + return { + providers, + speechProviders, + mediaProviders, + imageProviders, + musicProviders, + videoProviders, + }; } export function requireRegisteredProvider(