|
|
@@ -0,0 +1,267 @@
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import argparse
|
|
|
+import base64
|
|
|
+import json
|
|
|
+import time
|
|
|
+import uuid
|
|
|
+from pathlib import Path
|
|
|
+from typing import Any
|
|
|
+
|
|
|
+import requests
|
|
|
+
|
|
|
+from shared.config import load_config, require_provider_config, resolve_output_dir
|
|
|
+
|
|
|
+
|
|
|
+def _normalize_base_url(base_url: str) -> str:
|
|
|
+ return base_url.rstrip("/")
|
|
|
+
|
|
|
+
|
|
|
+def _images_endpoint(base_url: str) -> str:
|
|
|
+ normalized = _normalize_base_url(base_url)
|
|
|
+ if normalized.endswith("/images/generations"):
|
|
|
+ return normalized
|
|
|
+ return f"{normalized}/images/generations"
|
|
|
+
|
|
|
+
|
|
|
+def _tasks_endpoint(base_url: str, task_id: str) -> str:
|
|
|
+ normalized = _normalize_base_url(base_url)
|
|
|
+ if normalized.endswith("/v1"):
|
|
|
+ return f"{normalized}/tasks/{task_id}"
|
|
|
+ return f"{normalized}/v1/tasks/{task_id}"
|
|
|
+
|
|
|
+
|
|
|
+def _build_payload(config: dict[str, Any], args_payload: dict[str, Any]) -> dict[str, Any]:
|
|
|
+ provider = config["provider"]
|
|
|
+ generation = config.get("generation", {})
|
|
|
+
|
|
|
+ payload: dict[str, Any] = {
|
|
|
+ "model": provider["model"],
|
|
|
+ "prompt": args_payload["prompt"],
|
|
|
+ "size": args_payload.get("size") or generation.get("default_size", "1024x1024"),
|
|
|
+ "n": args_payload.get("n") or generation.get("default_n", 1),
|
|
|
+ "response_format": args_payload.get("response_format") or generation.get("default_response_format", "b64_json"),
|
|
|
+ "quality": args_payload.get("quality") or generation.get("default_quality", "standard"),
|
|
|
+ }
|
|
|
+
|
|
|
+ if args_payload.get("negative_prompt"):
|
|
|
+ payload["negative_prompt"] = args_payload["negative_prompt"]
|
|
|
+ if args_payload.get("seed") is not None:
|
|
|
+ payload["seed"] = args_payload["seed"]
|
|
|
+
|
|
|
+ extra_body = args_payload.get("extra_body")
|
|
|
+ if isinstance(extra_body, dict):
|
|
|
+ payload.update(extra_body)
|
|
|
+
|
|
|
+ return payload
|
|
|
+
|
|
|
+
|
|
|
+def _write_b64_image(output_dir: Path, image_data: str, run_id: str, index: int) -> str:
|
|
|
+ image_bytes = base64.b64decode(image_data)
|
|
|
+ output_path = output_dir / f"{run_id}_{index}.png"
|
|
|
+ output_path.write_bytes(image_bytes)
|
|
|
+ return str(output_path)
|
|
|
+
|
|
|
+
|
|
|
+def _download_image(output_dir: Path, session: requests.Session, image_url: str, run_id: str, index: int) -> str:
|
|
|
+ response = session.get(image_url, timeout=60)
|
|
|
+ response.raise_for_status()
|
|
|
+ output_path = output_dir / f"{run_id}_{index}.png"
|
|
|
+ output_path.write_bytes(response.content)
|
|
|
+ return str(output_path)
|
|
|
+
|
|
|
+
|
|
|
def _save_images(config: dict[str, Any], response_payload: dict[str, Any], session: requests.Session, run_id: str) -> list[str]:
    """Persist every image entry of an OpenAI-style ``data`` list to disk.

    Inline ``b64_json`` entries are written directly; ``url`` entries are
    downloaded through *session*. Entries with neither field are skipped.
    Returns the list of written file paths (possibly empty).
    """
    output_dir = resolve_output_dir(config)
    saved: list[str] = []

    for position, entry in enumerate(response_payload.get("data") or [], start=1):
        encoded = entry.get("b64_json")
        if encoded:
            saved.append(_write_b64_image(output_dir, encoded, run_id, position))
        elif entry.get("url"):
            saved.append(_download_image(output_dir, session, entry["url"], run_id, position))

    return saved
|
|
|
+
|
|
|
+
|
|
|
+def _should_retry_async(response: requests.Response) -> bool:
|
|
|
+ if response.status_code != 400:
|
|
|
+ return False
|
|
|
+
|
|
|
+ try:
|
|
|
+ payload = response.json()
|
|
|
+ except ValueError:
|
|
|
+ return False
|
|
|
+
|
|
|
+ message = str(payload.get("errors", {}).get("message", "")).lower()
|
|
|
+ return "does not support synchronous calls" in message and "async" in message
|
|
|
+
|
|
|
+
|
|
|
+def _should_use_async_mode(provider: dict[str, Any]) -> bool:
|
|
|
+ base_url = str(provider.get("base_url", "")).lower()
|
|
|
+ return "api-inference.modelscope.cn" in base_url
|
|
|
+
|
|
|
+
|
|
|
def _poll_async_task(
    config: dict[str, Any],
    provider: dict[str, Any],
    session: requests.Session,
    task_id: str,
    run_id: str,
) -> dict[str, Any]:
    """Poll a ModelScope async task until it succeeds, fails, or times out.

    On success, downloads every URL in ``output_images`` and returns a result
    dict with the local image paths plus the raw provider payload.

    Raises:
        TimeoutError: when ``generation.timeout_seconds`` elapses first.
        RuntimeError: when the task reports FAILED, or succeeds with no images.
    """
    generation = config.get("generation", {})
    poll_interval = float(generation.get("poll_interval_seconds", 3))
    timeout_seconds = int(generation.get("timeout_seconds", 300))
    started_at = time.monotonic()
    endpoint = _tasks_endpoint(provider["base_url"], task_id)

    while True:
        # Enforce the overall deadline before issuing another status request.
        if time.monotonic() - started_at > timeout_seconds:
            raise TimeoutError(f"Timed out waiting for task {task_id} after {timeout_seconds} seconds.")

        response = session.get(
            endpoint,
            # NOTE(review): presumably required by ModelScope to route the
            # lookup to the image-generation task queue — confirm against docs.
            headers={"X-ModelScope-Task-Type": "image_generation"},
            timeout=60,
        )
        response.raise_for_status()
        payload = response.json()
        task_status = str(payload.get("task_status", "")).upper()

        if task_status == "SUCCEED":
            output_images = payload.get("output_images") or []
            images = [
                _download_image(output_dir=resolve_output_dir(config), session=session, image_url=image_url, run_id=run_id, index=index)
                for index, image_url in enumerate(output_images, start=1)
            ]
            if not images:
                raise RuntimeError(f"Task {task_id} succeeded but returned no output_images.")

            return {
                "status": "success",
                "run_id": run_id,
                "task_id": task_id,
                "model": provider["model"],
                "images": images,
                "raw_response": payload,
            }

        if task_status == "FAILED":
            # Surface the provider's full failure payload to the caller.
            raise RuntimeError(json.dumps(payload, ensure_ascii=False))

        # Any other status (pending/running) — wait, then poll again.
        time.sleep(poll_interval)
|
|
|
+
|
|
|
+
|
|
|
def _submit_async_task(
    session: requests.Session,
    endpoint: str,
    payload: dict[str, Any],
    timeout: int,
) -> str:
    """Submit an async generation request and return the provider task id.

    Raises:
        RuntimeError: when the response carries no usable ``task_id``.
    """
    response = session.post(
        endpoint,
        json=payload,
        headers={"X-ModelScope-Async-Mode": "true"},
        timeout=timeout,
    )
    response.raise_for_status()
    task_id = response.json().get("task_id")
    if not str(task_id or "").strip():
        raise RuntimeError("Async image generation did not return a task_id.")
    return str(task_id)


def execute_generation(args_payload: dict[str, Any]) -> dict[str, Any]:
    """Generate images via the configured OpenAI-compatible endpoint.

    Uses async task submission up front for providers known to require it
    (ModelScope); otherwise tries a synchronous call and falls back to async
    when the provider rejects sync calls. Returns a result dict with the
    saved image paths and the raw provider response.

    Raises:
        RuntimeError: when no images could be produced or saved.
    """
    config = load_config()
    provider = require_provider_config(config)
    payload = _build_payload(config, args_payload)
    endpoint = _images_endpoint(provider["base_url"])
    timeout = int(config.get("generation", {}).get("timeout_seconds", 300))

    run_id = str(uuid.uuid4())
    with requests.Session() as session:
        session.headers.update(
            {
                "Authorization": f"Bearer {provider['api_key']}",
                "Content-Type": "application/json",
            }
        )

        # Providers that only accept async submission skip the sync attempt.
        if _should_use_async_mode(provider):
            task_id = _submit_async_task(session, endpoint, payload, timeout)
            return _poll_async_task(config, provider, session, task_id, run_id)

        response = session.post(endpoint, json=payload, timeout=timeout)

        # Some providers reject sync calls with a 400 asking for async mode;
        # resubmit the identical payload asynchronously.
        if _should_retry_async(response):
            task_id = _submit_async_task(session, endpoint, payload, timeout)
            return _poll_async_task(config, provider, session, task_id, run_id)

        response.raise_for_status()
        response_payload = response.json()

        images = _save_images(config, response_payload, session, run_id)
        if not images:
            raise RuntimeError("Provider returned no downloadable images.")

        return {
            "status": "success",
            "run_id": run_id,
            "model": provider["model"],
            "images": images,
            "raw_response": response_payload,
        }
|
|
|
+
|
|
|
+
|
|
|
+def _parse_args_json(raw_args: str) -> dict[str, Any]:
|
|
|
+ try:
|
|
|
+ payload = json.loads(raw_args)
|
|
|
+ except json.JSONDecodeError as exc:
|
|
|
+ raise ValueError(f"Invalid JSON in --args: {exc}") from exc
|
|
|
+
|
|
|
+ if not isinstance(payload, dict):
|
|
|
+ raise ValueError("--args must decode to a JSON object.")
|
|
|
+ if not str(payload.get("prompt", "")).strip():
|
|
|
+ raise ValueError("The 'prompt' field is required.")
|
|
|
+ return payload
|
|
|
+
|
|
|
+
|
|
|
def _cmd_run(raw_args: str) -> dict[str, Any]:
    """Validate the raw JSON argument string, then run a generation with it."""
    return execute_generation(_parse_args_json(raw_args))
|
|
|
+
|
|
|
+
|
|
|
def main() -> None:
    """CLI entry point: `run --args ...` subcommand plus legacy flag support."""
    parser = argparse.ArgumentParser(description="Qwen Image Client for OpenClaw Skill")
    subparsers = parser.add_subparsers(dest="command")

    run_parser = subparsers.add_parser("run", help="Generate images from a JSON payload")
    run_parser.add_argument("--args", required=True, help="JSON string of generation parameters")

    # Legacy invocation shape: --workflow/--args without a subcommand.
    parser.add_argument("--workflow", help="Legacy compatibility flag. Only qwen/text-to-image is supported.")
    parser.add_argument("--args", dest="legacy_args", help="Legacy compatibility flag when no subcommand is used.")
    parsed = parser.parse_args()

    result: dict[str, Any]
    if parsed.command == "run":
        result = _cmd_run(parsed.args)
    elif parsed.workflow or parsed.legacy_args:
        workflow = parsed.workflow or "qwen/text-to-image"
        if workflow != "qwen/text-to-image":
            raise SystemExit("Only qwen/text-to-image is supported by this skill.")
        if not parsed.legacy_args:
            raise SystemExit("--args is required.")
        result = _cmd_run(parsed.legacy_args)
    else:
        # No subcommand and no legacy flags: show usage and exit quietly.
        parser.print_help()
        return

    print(json.dumps(result, ensure_ascii=False, indent=2))
|
|
|
+
|
|
|
+
|
|
|
# Allow the module to be executed directly as a CLI script.
if __name__ == "__main__":
    main()
|