From 81e8bb62ae0d008f7b26f13655da6d30fa074c6f Mon Sep 17 00:00:00 2001 From: zhayujie Date: Wed, 22 Apr 2026 20:39:49 +0800 Subject: [PATCH] feat(skill): support gpt-image-2 in image generation skill --- agent/protocol/agent_stream.py | 11 +- agent/tools/bash/bash.py | 14 +- app.py | 36 ++ bridge/agent_bridge.py | 2 +- channel/web/static/js/console.js | 62 ++- channel/web/web_channel.py | 17 +- cli/commands/skill.py | 58 ++- skills/image-generation/SKILL.md | 124 +++++ skills/image-generation/scripts/generate.py | 503 ++++++++++++++++++++ 9 files changed, 794 insertions(+), 33 deletions(-) create mode 100644 skills/image-generation/SKILL.md create mode 100644 skills/image-generation/scripts/generate.py diff --git a/agent/protocol/agent_stream.py b/agent/protocol/agent_stream.py index a0b0af5c..547603f0 100644 --- a/agent/protocol/agent_stream.py +++ b/agent/protocol/agent_stream.py @@ -330,13 +330,18 @@ class AgentStreamExecutor: }) break - # Log tool calls with arguments + # Log tool calls with arguments (truncate long values like base64) tool_calls_str = [] for tc in tool_calls: - # Safely handle None or missing arguments args = tc.get('arguments') or {} if isinstance(args, dict): - args_str = ', '.join([f"{k}={v}" for k, v in args.items()]) + parts = [] + for k, v in args.items(): + v_str = str(v) + if len(v_str) > 200: + v_str = v_str[:200] + f"...({len(v_str)} chars)" + parts.append(f"{k}={v_str}") + args_str = ', '.join(parts) if args_str: tool_calls_str.append(f"{tc['name']}({args_str})") else: diff --git a/agent/tools/bash/bash.py b/agent/tools/bash/bash.py index dd9114ed..312f5fb5 100644 --- a/agent/tools/bash/bash.py +++ b/agent/tools/bash/bash.py @@ -169,10 +169,16 @@ SAFETY: except Exception as retry_err: logger.warning(f"[Bash] Retry failed: {retry_err}") - # Combine stdout and stderr - output = result.stdout - if result.stderr: - output += "\n" + result.stderr + # When command succeeds with stdout, keep output clean (stderr goes to server log only). 
+ # When command fails or stdout is empty, include stderr so the agent can diagnose. + if result.returncode == 0 and result.stdout.strip(): + output = result.stdout + if result.stderr: + logger.info(f"[Bash] stderr (not forwarded): {result.stderr[:500]}") + else: + output = result.stdout + if result.stderr: + output += "\n" + result.stderr # Check if we need to save full output to temp file temp_file_path = None diff --git a/app.py b/app.py index d401a00b..4503ac67 100644 --- a/app.py +++ b/app.py @@ -274,6 +274,39 @@ def sigterm_handler_wrap(_signo): signal.signal(_signo, func) +def _sync_builtin_skills(): + """Sync builtin skills from project skills/ to workspace skills/ on startup.""" + import shutil + try: + workspace = conf().get("agent_workspace", "~/cow") + workspace = os.path.expanduser(workspace) + project_root = os.path.dirname(os.path.abspath(__file__)) + builtin_dir = os.path.join(project_root, "skills") + custom_dir = os.path.join(workspace, "skills") + + if not os.path.isdir(builtin_dir): + return + + os.makedirs(custom_dir, exist_ok=True) + synced = 0 + for name in os.listdir(builtin_dir): + src = os.path.join(builtin_dir, name) + if not os.path.isdir(src) or not os.path.isfile(os.path.join(src, "SKILL.md")): + continue + dst = os.path.join(custom_dir, name) + try: + if os.path.isdir(dst): + shutil.rmtree(dst) + shutil.copytree(src, dst) + synced += 1 + except Exception as e: + logger.warning(f"[App] Failed to sync builtin skill '{name}': {e}") + if synced: + logger.info(f"[App] Synced {synced} builtin skill(s) to workspace") + except Exception as e: + logger.warning(f"[App] Builtin skills sync failed: {e}") + + def run(): global _channel_mgr try: @@ -299,6 +332,9 @@ def run(): if web_console_enabled and "web" not in channel_names: channel_names.append("web") + # Sync builtin skills to workspace before channels start + _sync_builtin_skills() + logger.info(f"[App] Starting channels: {channel_names}") _channel_mgr = ChannelManager() diff --git 
a/bridge/agent_bridge.py b/bridge/agent_bridge.py index c72e30a7..2aac1221 100644 --- a/bridge/agent_bridge.py +++ b/bridge/agent_bridge.py @@ -446,7 +446,7 @@ class AgentBridge: except Exception as e: logger.warning(f"[AgentBridge] Failed to clear DB after recovery: {e}") - # Check if there are files to send (from read tool) + # Check if there are files to send (from send/read tool) if hasattr(agent, 'stream_executor') and hasattr(agent.stream_executor, 'files_to_send'): files_to_send = agent.stream_executor.files_to_send if files_to_send: diff --git a/channel/web/static/js/console.js b/channel/web/static/js/console.js index c0ae38a2..ae155524 100644 --- a/channel/web/static/js/console.js +++ b/channel/web/static/js/console.js @@ -341,23 +341,35 @@ const md = createMd(); const VIDEO_EXT_RE = /\.(?:mp4|webm|mov|avi|mkv)$/i; // tested against URL without query string const IMAGE_EXT_RE = /\.(?:jpg|jpeg|png|gif|webp|bmp|svg)$/i; // tested against URL without query string +function _toWebUrl(url) { + if (/^\/[A-Za-z]/.test(url) && !url.startsWith('/api/')) { + return '/api/file?path=' + encodeURIComponent(url); + } + if (/^file:\/\/\//i.test(url)) { + return '/api/file?path=' + encodeURIComponent(url.replace(/^file:\/\/\//i, '/')); + } + return url; +} + function _buildVideoHtml(url) { + const webUrl = _toWebUrl(url); const fileName = url.split('/').pop().split('?')[0]; return `
` + `` + - `` + + `` + ` ${escapeHtml(fileName)}
`; } function _buildImageHtml(url) { - const safeUrl = url.replace(/"/g, '"'); + const webUrl = _toWebUrl(url); + const safeUrl = webUrl.replace(/"/g, '"'); return `
` + `image` + + `style="max-width:520px;width:100%;border-radius:10px;box-shadow:0 2px 8px rgba(0,0,0,0.15);display:block;cursor:pointer;">` + `
`; } @@ -400,9 +412,20 @@ function injectImagePreviews(html) { }).join(''); } +function _rewriteLocalImgSrc(html) { + return html.replace(/]*?)src="([^"]+)"/gi, (match, pre, src) => { + const webSrc = _toWebUrl(src); + if (webSrc !== src) { + return `` : ''} `; + // If this tool sent a file (send/read tool), render the media inline + // so it persists across page refreshes (SSE-only file events are not stored). + const mediaHtml = _renderSentFileFromToolResult(step); + if (mediaHtml) html += mediaHtml; } } return { stepsHtml: html, lastContentText }; } +// Extract file-to-send metadata from a tool's result and render an inline preview. +// Returns '' if the result isn't a file_to_send payload. +function _renderSentFileFromToolResult(step) { + if (!step || !step.result) return ''; + let payload; + try { + payload = typeof step.result === 'string' ? JSON.parse(step.result) : step.result; + } catch (_) { return ''; } + if (!payload || payload.type !== 'file_to_send' || !payload.path) return ''; + const webUrl = _toWebUrl(payload.path); + const fileType = payload.file_type || 'file'; + const fileName = payload.file_name || payload.path.split('/').pop(); + if (fileType === 'image') { + return `
${_buildImageHtml(webUrl)}
`; + } + if (fileType === 'video') { + return `
${_buildVideoHtml(webUrl)}
`; + } + return ``; +} + function createBotMessageEl(content, timestamp, requestId, msg) { const el = document.createElement('div'); el.className = 'flex gap-3 px-4 sm:px-6 py-3'; diff --git a/channel/web/web_channel.py b/channel/web/web_channel.py index ce60ac62..c19beb9b 100644 --- a/channel/web/web_channel.py +++ b/channel/web/web_channel.py @@ -208,9 +208,24 @@ class WebChannel(ChatChannel): # Fallback: polling mode if session_id in self.session_queues: + content = reply.content if reply.content is not None else "" + # Skip file:// IMAGE_URL/FILE replies originating from an SSE-enabled + # request: they were already pushed via the `file_to_send` event during + # agent execution. By the time the chat_channel sends the IMAGE_URL reply, + # the SSE stream has typically closed (after the text "done") and the + # request_id is gone from sse_queues, so we'd otherwise duplicate the file + # as a polling bubble. Scheduler/push tasks have no on_event and must + # still go through polling normally. 
+ if ( + reply.type in (ReplyType.IMAGE_URL, ReplyType.FILE) + and content.startswith("file://") + and context.get("on_event") is not None + ): + logger.debug(f"Polling skipped duplicate file reply for session {session_id}") + return response_data = { "type": str(reply.type), - "content": reply.content, + "content": content, "timestamp": time.time(), "request_id": request_id } diff --git a/cli/commands/skill.py b/cli/commands/skill.py index 23005d6a..a591ed9c 100644 --- a/cli/commands/skill.py +++ b/cli/commands/skill.py @@ -644,32 +644,52 @@ def _list_local(): skills_dir = get_skills_dir() builtin_dir = get_builtin_skills_dir() + # Merge builtin skills that are on disk but missing from config + _merge_builtin_into_config(config, builtin_dir, skills_dir) + if not config: - # Fallback: scan directories directly - entries = [] - for d in [builtin_dir, skills_dir]: - if not os.path.isdir(d): - continue - source = "builtin" if d == builtin_dir else "custom" - for name in sorted(os.listdir(d)): - skill_path = os.path.join(d, name) - if os.path.isdir(skill_path) and not name.startswith("."): - has_skill_md = os.path.exists(os.path.join(skill_path, "SKILL.md")) - if has_skill_md: - entries.append({"name": name, "source": source, "enabled": True, "description": ""}) - if not entries: - click.echo("No skills installed.") - return - _print_skill_table(entries) + click.echo("No skills installed.") return entries = sorted(config.values(), key=lambda x: x.get("name", "")) - if not entries: - click.echo("No skills installed.") - return _print_skill_table(entries) +def _merge_builtin_into_config(config: dict, builtin_dir: str, skills_dir: str): + """Scan builtin and custom dirs, add any new skills into config dict.""" + dirty = False + for d, source in [(builtin_dir, "builtin"), (skills_dir, "custom")]: + if not os.path.isdir(d): + continue + for name in os.listdir(d): + if name.startswith(".") or name in ("skills_config.json",): + continue + skill_path = os.path.join(d, name) + 
if not os.path.isdir(skill_path): + continue + if not os.path.isfile(os.path.join(skill_path, "SKILL.md")): + continue + if name in config: + continue + desc = _read_skill_description(skill_path) + config[name] = { + "name": name, + "description": desc, + "source": source, + "enabled": True, + "category": "skill", + } + dirty = True + if dirty: + config_path = os.path.join(skills_dir, "skills_config.json") + try: + os.makedirs(skills_dir, exist_ok=True) + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=4, ensure_ascii=False) + except Exception: + pass + + def _print_skill_table(entries): """Print skills as a formatted table.""" def _display_label(e): diff --git a/skills/image-generation/SKILL.md b/skills/image-generation/SKILL.md new file mode 100644 index 00000000..0195d7c7 --- /dev/null +++ b/skills/image-generation/SKILL.md @@ -0,0 +1,124 @@ +--- +name: image-generation +description: Generate or edit images from text prompts. Use when the user asks to create, draw, design, or edit an image, illustration, photo, icon, poster, or any visual content. +metadata: + cowagent: + requires: + anyEnv: + - OPENAI_API_KEY + - LINKAI_API_KEY +--- + +# Image Generation + +Generate and edit images using AI models (GPT-Image-2, GPT-Image-1, etc.). + +## Usage + +Run `scripts/generate.py` with a JSON argument. The path is relative to this skill's `base_dir`. + +```bash +python /scripts/generate.py '' +``` + +**Set bash timeout to at least 300 seconds**, as image generation can take 30–200s depending on quality/size. 
+ +### Parameters + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `prompt` | string | yes | — | Image description | +| `model` | string | no | `gpt-image-2` | Model name (`gpt-image-2`, `gpt-image-1`) | +| `image_url` | string / list | no | null | Input image(s) for editing: local file path or URL | +| `quality` | string | no | auto | `low` / `medium` / `high`; omit to let the model choose | +| `size` | string | no | auto | `1K`/`2K`/`4K`, pixel value (`1024x1024`), or omit to let the model choose | +| `aspect_ratio` | string | no | null | `1:1` / `3:2` / `2:3` / `16:9` / `9:16` | + +### Example — generate + +```bash +python <base_dir>/scripts/generate.py '{"prompt": "A corgi astronaut floating in space"}' +``` + +With explicit quality/size: + +```bash +python <base_dir>/scripts/generate.py '{"prompt": "A corgi astronaut", "quality": "low", "size": "1K", "aspect_ratio": "1:1"}' +``` + +### Important: Editing vs Generating + +When the user asks to **edit, modify, or improve an existing image**, you need to pass the original image via `image_url`. Prefer passing **local file paths** directly — the script handles file reading internally. Without `image_url`, the script generates a brand-new image instead of editing. + +### Example — edit (image-to-image) + +Local file (preferred): + +```bash +python <base_dir>/scripts/generate.py '{"prompt": "Add a Santa hat to the dog", "image_url": "/path/to/dog.png"}' +``` + +URL: + +```bash +python <base_dir>/scripts/generate.py '{"prompt": "Make the background blue", "image_url": "https://example.com/photo.png"}' +``` + +### Output + +Prints JSON to stdout: + +```json +{ + "images": [ + {"url": "/path/to/output.png"} + ] +} +``` + +After success, display the image to the user. You can either embed it in markdown (`![description](/path/to/output.png)`) or use the `send` tool. 
+ +On error: + +```json +{ + "error": "error message" +} +``` + +### Environment Variables + +| Variable | Required | Description | +|----------|----------|-------------| +| `OPENAI_API_KEY` | yes (unless using LinkAI) | OpenAI API key | +| `OPENAI_API_BASE` | no | Custom API base URL (default: `https://api.openai.com/v1`) | +| `LINKAI_API_KEY` | alt | LinkAI API key (used when `OPENAI_API_KEY` is absent) | +| `LINKAI_API_BASE` | no | LinkAI API base URL | + +### Size + Aspect Ratio Resolution + +`size` and `aspect_ratio` are combined to determine the actual pixel dimensions: + +| size | aspect_ratio | pixels | +|------|-------------|--------| +| `1K` | `1:1` | 1024×1024 | +| `1K` | `3:2` | 1536×1024 | +| `1K` | `2:3` | 1024×1536 | +| `2K` | `1:1` | 2048×2048 | +| `2K` | `16:9` | 2048×1152 | +| `2K` | `9:16` | 1152×2048 | +| `4K` | `16:9` | 3840×2160 | +| `4K` | `9:16` | 2160×3840 | + +When an exact match isn't found, the script tries: exact match → upgrade to higher tier with same ratio → cross-tier match by ratio → tier default. + +### Error Handling + +The script internally tries all available providers (OpenAI → LinkAI) in sequence. If it returns an error, **do NOT retry with the same or similar parameters** — the failure is a configuration issue (wrong API key, unsupported API base, etc.), not a transient error. Instead, inform the user about the configuration problem and ask them to fix it (e.g. set the correct `OPENAI_API_KEY` / `OPENAI_API_BASE` via `env_config`), then retry after the configuration is updated. + +### Notes + +- HTTP timeout is 300s — high-resolution + high-quality generation can take over 200s. +- When `quality` and `size` are omitted, the API uses `auto` — the model picks the best quality/size based on the prompt. +- `quality=low` + `size=1K` is the fastest combination (~20s). Use when speed matters more than fidelity. +- Input images for editing are auto-compressed to ≤ 4MB / longest edge ≤ 4096px. 
diff --git a/skills/image-generation/scripts/generate.py b/skills/image-generation/scripts/generate.py new file mode 100644 index 00000000..99e8d24d --- /dev/null +++ b/skills/image-generation/scripts/generate.py @@ -0,0 +1,503 @@ +#!/usr/bin/env python3 +""" +Unified image generation script. + +Usage: + python generate.py '' + +Supports GPT-Image-2 / GPT-Image-1 via the OpenAI-compatible Images API. +Designed for easy extension to other providers (Gemini, etc.). + +Dependencies: requests (stdlib: json, sys, os, base64, io, abc, uuid, pathlib, urllib) +""" + +import json +import sys +import os +import base64 +import io +import uuid +import re +from abc import ABC, abstractmethod +from pathlib import Path +from urllib.request import urlopen, Request +from urllib.parse import urlparse +from urllib.error import URLError + +try: + import requests + + _HAS_REQUESTS = True +except ImportError: + _HAS_REQUESTS = False + + +# --------------------------------------------------------------------------- +# Size / aspect-ratio resolution +# --------------------------------------------------------------------------- + +_SIZE_TABLE = { + # (tier, ratio) -> "WxH" + ("1K", "1:1"): "1024x1024", + ("1K", "3:2"): "1536x1024", + ("1K", "2:3"): "1024x1536", + ("2K", "1:1"): "2048x2048", + ("2K", "16:9"): "2048x1152", + ("2K", "9:16"): "1152x2048", + ("4K", "16:9"): "3840x2160", + ("4K", "9:16"): "2160x3840", +} + +_TIER_ORDER = ["1K", "2K", "4K"] +_RATIO_DEFAULT = {"1K": "1:1", "2K": "1:1", "4K": "16:9"} + +_PIXEL_RE = re.compile(r"^\d+x\d+$") + + +def resolve_size(size: str | None, aspect_ratio: str | None) -> str | None: + """Resolve (size, aspect_ratio) to a concrete 'WxH' string or None.""" + if size and _PIXEL_RE.match(size): + return size + if size and size.lower() == "auto": + size = None + if not size and not aspect_ratio: + return None + + tier = size.upper() if size else None + ratio = aspect_ratio + + if tier and ratio: + key = (tier, ratio) + if key in _SIZE_TABLE: + return 
_SIZE_TABLE[key] + # Upgrade: try higher tiers with same ratio + start = _TIER_ORDER.index(tier) + 1 if tier in _TIER_ORDER else 0 + for t in _TIER_ORDER[start:]: + if (t, ratio) in _SIZE_TABLE: + return _SIZE_TABLE[(t, ratio)] + # Cross-tier: any tier with this ratio + for t in _TIER_ORDER: + if (t, ratio) in _SIZE_TABLE: + return _SIZE_TABLE[(t, ratio)] + # Tier default + if tier in _RATIO_DEFAULT: + return _SIZE_TABLE.get((tier, _RATIO_DEFAULT[tier])) + + if tier and not ratio: + default_ratio = _RATIO_DEFAULT.get(tier) + if default_ratio: + return _SIZE_TABLE.get((tier, default_ratio)) + + if ratio and not tier: + for t in _TIER_ORDER: + if (t, ratio) in _SIZE_TABLE: + return _SIZE_TABLE[(t, ratio)] + + return None + + +# --------------------------------------------------------------------------- +# Image helpers +# --------------------------------------------------------------------------- + +def _load_image(source: str) -> bytes: + """Load image from a local file path or URL.""" + if os.path.isfile(source): + with open(source, "rb") as f: + return f.read() + if _HAS_REQUESTS: + resp = requests.get(source, timeout=60) + resp.raise_for_status() + return resp.content + req = Request(source) + with urlopen(req, timeout=60) as resp: + return resp.read() + + +def _compress_image(data: bytes, max_bytes: int = 4 * 1024 * 1024, max_edge: int = 4096) -> bytes: + """Compress image to fit size/dimension limits. 
Requires Pillow only when needed.""" + if len(data) <= max_bytes: + try: + from PIL import Image + + img = Image.open(io.BytesIO(data)) + w, h = img.size + if max(w, h) <= max_edge: + return data + except ImportError: + return data + except Exception: + return data + + try: + from PIL import Image + except ImportError: + return data + + img = Image.open(io.BytesIO(data)) + w, h = img.size + + if max(w, h) > max_edge: + ratio = max_edge / max(w, h) + w, h = int(w * ratio), int(h * ratio) + img = img.resize((w, h), Image.LANCZOS) + + buf = io.BytesIO() + fmt = img.format or "PNG" + if fmt.upper() == "JPEG": + quality = 85 + while True: + buf.seek(0) + buf.truncate() + img.save(buf, format="JPEG", quality=quality) + if buf.tell() <= max_bytes or quality <= 20: + break + quality -= 10 + else: + img.save(buf, format=fmt) + if buf.tell() > max_bytes: + buf.seek(0) + buf.truncate() + img.save(buf, format="JPEG", quality=75) + return buf.getvalue() + + +def _save_image(data: bytes, output_dir: str) -> str: + """Save image bytes to output_dir and return the path.""" + os.makedirs(output_dir, exist_ok=True) + ext = "png" + if data[:3] == b"\xff\xd8\xff": + ext = "jpg" + elif data[:4] == b"RIFF": + ext = "webp" + filename = f"{uuid.uuid4().hex[:12]}.{ext}" + path = os.path.join(output_dir, filename) + with open(path, "wb") as f: + f.write(data) + return path + + +# --------------------------------------------------------------------------- +# Provider interface +# --------------------------------------------------------------------------- + +class ImageProvider(ABC): + """Abstract base class for image generation providers.""" + + @abstractmethod + def generate( + self, + prompt: str, + *, + image_url: str | list | None = None, + quality: str | None = None, + size: str | None = None, + output_dir: str = ".", + ) -> list[str]: + """Generate image(s) and return list of local file paths.""" + ... 
+ + +# --------------------------------------------------------------------------- +# OpenAI-compatible provider (gpt-image-2, gpt-image-1) +# --------------------------------------------------------------------------- + +class OpenAIProvider(ImageProvider): + """Provider for OpenAI Image API (generations + edits).""" + + def __init__(self, api_key: str, api_base: str, model: str): + self.api_key = api_key + self.api_base = api_base.rstrip("/") + self.model = model + + def _headers(self) -> dict: + return { + "Authorization": f"Bearer {self.api_key}", + } + + @staticmethod + def _raise_for_api_error(resp): + """Raise with server error details instead of bare HTTP status.""" + if resp.status_code >= 400: + try: + body = resp.json() + msg = body.get("error", {}).get("message") or body.get("message") or resp.text + except Exception: + msg = resp.text or resp.reason + raise RuntimeError(f"API {resp.status_code}: {msg} (url: {resp.url})") + + def _post_json(self, url: str, payload: dict) -> dict: + headers = {**self._headers(), "Content-Type": "application/json"} + if _HAS_REQUESTS: + resp = requests.post(url, headers=headers, json=payload, timeout=300) + self._raise_for_api_error(resp) + return resp.json() + data = json.dumps(payload).encode() + req = Request(url, data=data, headers=headers, method="POST") + with urlopen(req, timeout=300) as r: + return json.loads(r.read()) + + def _post_multipart(self, url: str, fields: dict, files: list[tuple]) -> dict: + """POST multipart/form-data using requests (or fall back to urllib).""" + headers = self._headers() + if _HAS_REQUESTS: + resp = requests.post(url, headers=headers, data=fields, files=files, timeout=300) + self._raise_for_api_error(resp) + return resp.json() + boundary = uuid.uuid4().hex + body = b"" + for key, val in fields.items(): + body += f"--{boundary}\r\nContent-Disposition: form-data; name=\"{key}\"\r\n\r\n{val}\r\n".encode() + for field_name, (filename, filedata, content_type) in files: + body += ( + 
f"--{boundary}\r\n" + f"Content-Disposition: form-data; name=\"{field_name}\"; filename=\"{filename}\"\r\n" + f"Content-Type: {content_type}\r\n\r\n" + ).encode() + filedata + b"\r\n" + body += f"--{boundary}--\r\n".encode() + headers["Content-Type"] = f"multipart/form-data; boundary={boundary}" + req = Request(url, data=body, headers=headers, method="POST") + with urlopen(req, timeout=300) as r: + return json.loads(r.read()) + + def generate( + self, + prompt: str, + *, + image_url=None, + quality: str | None = None, + size: str | None = None, + output_dir: str = ".", + ) -> list[str]: + if image_url: + return self._edit(prompt, image_url=image_url, quality=quality, size=size, output_dir=output_dir) + return self._create(prompt, quality=quality, size=size, output_dir=output_dir) + + def _create(self, prompt: str, *, quality: str | None, size: str | None, output_dir: str) -> list[str]: + url = f"{self.api_base}/images/generations" + payload: dict = { + "model": self.model, + "prompt": prompt, + } + if quality: + payload["quality"] = quality + if size: + payload["size"] = size + result = self._post_json(url, payload) + return self._save_results(result, output_dir) + + def _edit( + self, + prompt: str, + *, + image_url, + quality: str | None, + size: str | None, + output_dir: str, + ) -> list[str]: + urls = image_url if isinstance(image_url, list) else [image_url] + image_data_list = [_compress_image(_load_image(u)) for u in urls] + + url = f"{self.api_base}/images/edits" + + fields = {"model": self.model, "prompt": prompt} + if quality: + fields["quality"] = quality + if size: + fields["size"] = size + + files = [] + for i, img_bytes in enumerate(image_data_list): + ext = "png" + if img_bytes[:3] == b"\xff\xd8\xff": + ext = "jpg" + field_name = "image[]" if len(image_data_list) > 1 else "image" + files.append((field_name, (f"image_{i}.{ext}", img_bytes, f"image/{ext}"))) + + result = self._post_multipart(url, fields, files) + return self._save_results(result, 
output_dir) + + @staticmethod + def _save_results(result: dict, output_dir: str) -> list[str]: + paths = [] + for item in result.get("data", []): + if "b64_json" in item: + raw = base64.b64decode(item["b64_json"]) + paths.append(_save_image(raw, output_dir)) + elif "url" in item: + raw = _load_image(item["url"]) + paths.append(_save_image(raw, output_dir)) + return paths + + +# --------------------------------------------------------------------------- +# LinkAI provider (uses unified /v1/images/generations) +# --------------------------------------------------------------------------- + +class LinkAIProvider(ImageProvider): + """Provider for LinkAI unified image generation API.""" + + def __init__(self, api_key: str, api_base: str, model: str): + self.api_key = api_key + self.api_base = api_base.rstrip("/") + self.model = model + + def generate( + self, + prompt: str, + *, + image_url=None, + quality: str | None = None, + size: str | None = None, + output_dir: str = ".", + ) -> list[str]: + url = f"{self.api_base}/v1/images/generations" + payload: dict = { + "model": self.model, + "prompt": prompt, + } + if quality: + payload["quality"] = quality + if size: + payload["size"] = size + if image_url: + urls = image_url if isinstance(image_url, list) else [image_url] + resolved = [] + for u in urls: + if os.path.isfile(u): + data = _load_image(u) + ext = u.rsplit(".", 1)[-1].lower() if "." 
in u else "png" + mime = {"jpg": "image/jpeg", "jpeg": "image/jpeg", "webp": "image/webp"}.get(ext, "image/png") + resolved.append(f"data:{mime};base64,{base64.b64encode(data).decode()}") + else: + resolved.append(u) + payload["image_url"] = resolved if len(resolved) > 1 else resolved[0] + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + if _HAS_REQUESTS: + resp = requests.post(url, headers=headers, json=payload, timeout=300) + if resp.status_code >= 400: + try: + body = resp.json() + msg = body.get("error", {}).get("message") or body.get("message") or resp.text + except Exception: + msg = resp.text or resp.reason + raise RuntimeError(f"API {resp.status_code}: {msg}") + result = resp.json() + else: + data = json.dumps(payload).encode() + req = Request(url, data=data, headers=headers, method="POST") + with urlopen(req, timeout=300) as r: + result = json.loads(r.read()) + + if "error" in result: + raise RuntimeError(result["error"].get("message", str(result["error"]))) + + paths = [] + for item in result.get("data", []): + if "url" in item: + raw = _load_image(item["url"]) + paths.append(_save_image(raw, output_dir)) + elif "b64_json" in item: + raw = base64.b64decode(item["b64_json"]) + paths.append(_save_image(raw, output_dir)) + return paths + + +# --------------------------------------------------------------------------- +# Provider factory +# --------------------------------------------------------------------------- + +def _build_providers(model: str) -> list[tuple[str, ImageProvider]]: + """Build an ordered list of (label, provider) to try.""" + openai_key = os.environ.get("OPENAI_API_KEY", "") + openai_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1") + linkai_key = os.environ.get("LINKAI_API_KEY", "") + linkai_base = os.environ.get("LINKAI_API_BASE", "https://api.link-ai.tech") + + providers = [] + if openai_key: + providers.append(("OpenAI", OpenAIProvider(api_key=openai_key, 
api_base=openai_base, model=model))) + if linkai_key: + providers.append(("LinkAI", LinkAIProvider(api_key=linkai_key, api_base=linkai_base, model=model))) + return providers + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + if len(sys.argv) < 2: + print(json.dumps({"error": "Usage: python generate.py ''"})) + sys.exit(1) + + try: + args = json.loads(sys.argv[1]) + except json.JSONDecodeError as e: + print(json.dumps({"error": f"Invalid JSON: {e}"})) + sys.exit(1) + + prompt = args.get("prompt") + if not prompt: + print(json.dumps({"error": "Missing required parameter: prompt"})) + sys.exit(1) + + model = args.get("model", "gpt-image-2") + quality = args.get("quality") + raw_size = args.get("size") + aspect_ratio = args.get("aspect_ratio") + image_url = args.get("image_url") + + resolved_size = resolve_size(raw_size, aspect_ratio) + + output_dir = os.environ.get("IMAGE_OUTPUT_DIR", os.path.join(os.getcwd(), "images")) + + providers = _build_providers(model) + if not providers: + print(json.dumps({ + "error": "No API key configured. Please set OPENAI_API_KEY or LINKAI_API_KEY via env_config tool, then try again." 
+ }, ensure_ascii=False)) + sys.exit(1) + + import time + + errors = [] + for label, provider in providers: + try: + print(f"[image-generation] Trying {label} (model={model})...", file=sys.stderr) + t0 = time.time() + paths = provider.generate( + prompt, + image_url=image_url, + quality=quality, + size=resolved_size, + output_dir=output_dir, + ) + elapsed = time.time() - t0 + print(f"[image-generation] ✅ {label} succeeded in {elapsed:.1f}s", file=sys.stderr) + result = {"images": [{"url": p} for p in paths]} + print(json.dumps(result, ensure_ascii=False)) + return + except Exception as e: + elapsed = time.time() - t0 + print(f"[image-generation] ❌ {label} failed in {elapsed:.1f}s: {e}", file=sys.stderr) + errors.append(f"{label}: {e}") + + hint = " | ".join(errors) + print(json.dumps({ + "error": f"All providers failed — {hint}. " + "This is likely an API key or base URL configuration issue. " + "Do NOT retry with the same parameters. " + "Ask the user to verify their OPENAI_API_KEY / OPENAI_API_BASE " + "(or LINKAI_API_KEY / LINKAI_API_BASE) settings via env_config." + }, ensure_ascii=False)) + sys.exit(1) + + +if __name__ == "__main__": + main()