"""Command-line image generation client.

Builds a generation request from a JSON argument payload, posts it to the
configured provider's ``/images/generations`` endpoint, and stores the
returned images in the configured output directory. Providers that require
asynchronous processing are handled by submitting a task and polling it.
"""

from __future__ import annotations

import argparse
import base64
import json
import time
import uuid
from pathlib import Path
from typing import Any

import requests

from shared.config import load_config, require_provider_config, resolve_output_dir

# Path suffix of the image generation endpoint, relative to the API root.
_IMAGES_PATH = "/images/generations"

# Extra header that asks the provider to process a request asynchronously.
_ASYNC_HEADERS = {"X-ModelScope-Async-Mode": "true"}


def _normalize_base_url(base_url: str) -> str:
    """Return *base_url* without trailing slashes."""
    return base_url.rstrip("/")


def _images_endpoint(base_url: str) -> str:
    """Return the image generation URL for *base_url*.

    A base URL that already ends with ``/images/generations`` is returned
    as is (minus trailing slashes); otherwise the path is appended.
    """
    normalized = _normalize_base_url(base_url)
    if normalized.endswith(_IMAGES_PATH):
        return normalized
    return f"{normalized}{_IMAGES_PATH}"


def _tasks_endpoint(base_url: str, task_id: str) -> str:
    """Return the task status URL for *task_id*.

    The URL has the form ``<root>/v1/tasks/<task_id>``; ``/v1`` is only
    added when the base URL does not already end with it.
    """
    normalized = _normalize_base_url(base_url)
    # _images_endpoint accepts a base URL that already points at the
    # generation endpoint, so drop that suffix here to reach the API root
    # instead of building ".../images/generations/v1/tasks/<id>".
    if normalized.endswith(_IMAGES_PATH):
        normalized = normalized[: -len(_IMAGES_PATH)].rstrip("/")
    if normalized.endswith("/v1"):
        return f"{normalized}/tasks/{task_id}"
    return f"{normalized}/v1/tasks/{task_id}"


def _build_payload(config: dict[str, Any], args_payload: dict[str, Any]) -> dict[str, Any]:
    """Assemble the JSON request body for an image generation call.

    Values from *args_payload* win over the ``generation`` defaults in
    *config*; falsy argument values fall back to the defaults. When
    ``extra_body`` is a dict it is merged last, so its keys override
    everything set before it.
    """
    provider = config["provider"]
    generation = config.get("generation", {})
    payload: dict[str, Any] = {
        "model": provider["model"],
        "prompt": args_payload["prompt"],
        "size": args_payload.get("size") or generation.get("default_size", "1024x1024"),
        "n": args_payload.get("n") or generation.get("default_n", 1),
        "response_format": args_payload.get("response_format")
        or generation.get("default_response_format", "b64_json"),
        "quality": args_payload.get("quality") or generation.get("default_quality", "standard"),
    }
    if args_payload.get("negative_prompt"):
        payload["negative_prompt"] = args_payload["negative_prompt"]
    # A seed of 0 is valid, so test against None rather than truthiness.
    if args_payload.get("seed") is not None:
        payload["seed"] = args_payload["seed"]
    extra_body = args_payload.get("extra_body")
    if isinstance(extra_body, dict):
        payload.update(extra_body)
    return payload


def _write_b64_image(output_dir: Path, image_data: str, run_id: str, index: int) -> str:
    """Decode base64 *image_data* into ``<run_id>_<index>.png`` and return its path."""
    image_bytes = base64.b64decode(image_data)
    output_path = output_dir / f"{run_id}_{index}.png"
    output_path.write_bytes(image_bytes)
    return str(output_path)


def _download_image(
    output_dir: Path,
    session: requests.Session,
    image_url: str,
    run_id: str,
    index: int,
) -> str:
    """Download *image_url* into ``<run_id>_<index>.png`` and return its path.

    Raises:
        requests.HTTPError: If the download responds with an error status.
    """
    response = session.get(image_url, timeout=60)
    response.raise_for_status()
    output_path = output_dir / f"{run_id}_{index}.png"
    output_path.write_bytes(response.content)
    return str(output_path)


def _save_images(
    config: dict[str, Any],
    response_payload: dict[str, Any],
    session: requests.Session,
    run_id: str,
) -> list[str]:
    """Persist every image listed under ``data`` in *response_payload*.

    Entries carrying ``b64_json`` are decoded locally; entries carrying
    ``url`` are downloaded. Entries with neither, and entries that are not
    JSON objects, are skipped. File names use the 1-based position of the
    entry in ``data``.

    Returns:
        The paths of the files written, in response order.
    """
    output_dir = resolve_output_dir(config)
    images: list[str] = []
    data = response_payload.get("data") or []
    for index, item in enumerate(data, start=1):
        if not isinstance(item, dict):
            # A malformed entry must not abort the remaining images.
            continue
        if item.get("b64_json"):
            images.append(_write_b64_image(output_dir, item["b64_json"], run_id, index))
        elif item.get("url"):
            images.append(_download_image(output_dir, session, item["url"], run_id, index))
    return images


def _should_retry_async(response: requests.Response) -> bool:
    """Tell whether a rejected synchronous call should be retried asynchronously.

    True only for an HTTP 400 whose JSON body has an ``errors.message``
    stating that synchronous calls are not supported and mentioning async.
    """
    if response.status_code != 400:
        return False
    try:
        payload = response.json()
    except ValueError:
        return False
    # The error body is provider-controlled: tolerate a non-object body and
    # a missing, null or non-object "errors" field instead of crashing.
    if not isinstance(payload, dict):
        return False
    errors = payload.get("errors")
    if not isinstance(errors, dict):
        return False
    message = str(errors.get("message", "")).lower()
    return "does not support synchronous calls" in message and "async" in message


def _should_use_async_mode(provider: dict[str, Any]) -> bool:
    """Return True when the provider base URL points at api-inference.modelscope.cn."""
    base_url = str(provider.get("base_url", "")).lower()
    return "api-inference.modelscope.cn" in base_url


def _submit_async_task(
    session: requests.Session,
    endpoint: str,
    payload: dict[str, Any],
    timeout: int,
) -> str:
    """Post *payload* with the async-mode header and return the task id.

    Raises:
        requests.HTTPError: If the submission responds with an error status.
        RuntimeError: If the response carries no usable ``task_id``.
    """
    response = session.post(endpoint, json=payload, headers=_ASYNC_HEADERS, timeout=timeout)
    response.raise_for_status()
    task_id = response.json().get("task_id")
    if not str(task_id or "").strip():
        raise RuntimeError("Async image generation did not return a task_id.")
    return str(task_id)


def _poll_async_task(
    config: dict[str, Any],
    provider: dict[str, Any],
    session: requests.Session,
    task_id: str,
    run_id: str,
) -> dict[str, Any]:
    """Poll an asynchronous task until it finishes and download its images.

    The polling interval and overall timeout come from the ``generation``
    section of *config* (``poll_interval_seconds``, default 3, and
    ``timeout_seconds``, default 300).

    Returns:
        A result dict with ``status``, ``run_id``, ``task_id``, ``model``,
        ``images`` (saved file paths) and ``raw_response``.

    Raises:
        TimeoutError: If the task is still unfinished after the timeout.
        RuntimeError: If the task reports ``FAILED``, or succeeds without
            any ``output_images``.
        requests.HTTPError: If a status or download request fails.
    """
    generation = config.get("generation", {})
    poll_interval = float(generation.get("poll_interval_seconds", 3))
    timeout_seconds = int(generation.get("timeout_seconds", 300))
    started_at = time.monotonic()
    endpoint = _tasks_endpoint(provider["base_url"], task_id)

    while True:
        if time.monotonic() - started_at > timeout_seconds:
            raise TimeoutError(
                f"Timed out waiting for task {task_id} after {timeout_seconds} seconds."
            )
        response = session.get(
            endpoint,
            headers={"X-ModelScope-Task-Type": "image_generation"},
            timeout=60,
        )
        response.raise_for_status()
        payload = response.json()
        task_status = str(payload.get("task_status", "")).upper()

        if task_status == "SUCCEED":
            output_images = payload.get("output_images") or []
            if not output_images:
                raise RuntimeError(f"Task {task_id} succeeded but returned no output_images.")
            # Resolve the target directory once rather than once per image.
            output_dir = resolve_output_dir(config)
            images = [
                _download_image(
                    output_dir=output_dir,
                    session=session,
                    image_url=image_url,
                    run_id=run_id,
                    index=index,
                )
                for index, image_url in enumerate(output_images, start=1)
            ]
            return {
                "status": "success",
                "run_id": run_id,
                "task_id": task_id,
                "model": provider["model"],
                "images": images,
                "raw_response": payload,
            }

        if task_status == "FAILED":
            raise RuntimeError(json.dumps(payload, ensure_ascii=False))

        # Any other status is treated as still in progress.
        time.sleep(poll_interval)


def execute_generation(args_payload: dict[str, Any]) -> dict[str, Any]:
    """Generate images for *args_payload* and save them locally.

    The request goes out asynchronously when the provider base URL calls
    for it, or when a synchronous attempt is rejected with the provider's
    "does not support synchronous calls" error; otherwise the synchronous
    response is used directly.

    Returns:
        A result dict with ``status``, ``run_id``, ``model``, ``images``
        (saved file paths) and ``raw_response``; async runs also include
        ``task_id``.

    Raises:
        requests.HTTPError: If the provider responds with an error status.
        RuntimeError: If no image could be obtained from the response.
        TimeoutError: If an asynchronous task does not finish in time.
    """
    config = load_config()
    provider = require_provider_config(config)
    payload = _build_payload(config, args_payload)
    endpoint = _images_endpoint(provider["base_url"])
    timeout = int(config.get("generation", {}).get("timeout_seconds", 300))
    run_id = str(uuid.uuid4())

    with requests.Session() as session:
        session.headers.update(
            {
                "Authorization": f"Bearer {provider['api_key']}",
                "Content-Type": "application/json",
            }
        )

        if _should_use_async_mode(provider):
            task_id = _submit_async_task(session, endpoint, payload, timeout)
            return _poll_async_task(config, provider, session, task_id, run_id)

        response = session.post(endpoint, json=payload, timeout=timeout)
        # Inspect the body before raise_for_status so that a "sync not
        # supported" 400 is retried in async mode rather than raised.
        if _should_retry_async(response):
            task_id = _submit_async_task(session, endpoint, payload, timeout)
            return _poll_async_task(config, provider, session, task_id, run_id)

        response.raise_for_status()
        response_payload = response.json()
        images = _save_images(config, response_payload, session, run_id)
        if not images:
            raise RuntimeError("Provider returned no downloadable images.")
        return {
            "status": "success",
            "run_id": run_id,
            "model": provider["model"],
            "images": images,
            "raw_response": response_payload,
        }


def _parse_args_json(raw_args: str) -> dict[str, Any]:
    """Decode the ``--args`` JSON string and validate its basic shape.

    Raises:
        ValueError: If the text is not valid JSON, is not a JSON object,
            or has a missing or blank ``prompt``.
    """
    try:
        payload = json.loads(raw_args)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in --args: {exc}") from exc
    if not isinstance(payload, dict):
        raise ValueError("--args must decode to a JSON object.")
    if not str(payload.get("prompt", "")).strip():
        raise ValueError("The 'prompt' field is required.")
    return payload


def _cmd_run(raw_args: str) -> dict[str, Any]:
    """Parse *raw_args* as JSON and run the generation it describes."""
    payload = _parse_args_json(raw_args)
    return execute_generation(payload)


def main() -> None:
    """Run the command-line interface.

    Supports ``run --args <json>`` as well as the legacy form without a
    subcommand (``--workflow`` / ``--args``). The result is printed as
    indented JSON; with no recognised input the help text is printed.
    """
    parser = argparse.ArgumentParser(description="Qwen Image Client for OpenClaw Skill")
    subparsers = parser.add_subparsers(dest="command")

    sp_run = subparsers.add_parser("run", help="Generate images from a JSON payload")
    sp_run.add_argument("--args", required=True, help="JSON string of generation parameters")

    parser.add_argument(
        "--workflow",
        help="Legacy compatibility flag. Only qwen/text-to-image is supported.",
    )
    # Stored under a separate dest so it cannot clash with the subcommand's --args.
    parser.add_argument(
        "--args",
        dest="legacy_args",
        help="Legacy compatibility flag when no subcommand is used.",
    )

    parsed = parser.parse_args()

    result: dict[str, Any]
    if parsed.command == "run":
        result = _cmd_run(parsed.args)
    elif parsed.workflow or parsed.legacy_args:
        workflow = parsed.workflow or "qwen/text-to-image"
        if workflow != "qwen/text-to-image":
            raise SystemExit("Only qwen/text-to-image is supported by this skill.")
        if not parsed.legacy_args:
            raise SystemExit("--args is required.")
        result = _cmd_run(parsed.legacy_args)
    else:
        parser.print_help()
        return

    print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()