| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267 |
- from __future__ import annotations
- import argparse
- import base64
- import json
- import time
- import uuid
- from pathlib import Path
- from typing import Any
- import requests
- from shared.config import load_config, require_provider_config, resolve_output_dir
def _normalize_base_url(base_url: str) -> str:
    """Return ``base_url`` with every trailing ``/`` character removed.

    Only the right-hand end is touched; slashes elsewhere in the URL are
    preserved.
    """
    trimmed = base_url
    # Peel trailing slashes one at a time until none remain.
    while trimmed.endswith("/"):
        trimmed = trimmed[:-1]
    return trimmed
def _images_endpoint(base_url: str) -> str:
    """Return the image-generation URL derived from ``base_url``.

    The base URL is first stripped of trailing slashes. If it already ends
    with ``/images/generations`` it is returned as is; otherwise that path
    is appended.
    """
    suffix = "/images/generations"
    root = _normalize_base_url(base_url)
    return root if root.endswith(suffix) else f"{root}{suffix}"
def _tasks_endpoint(base_url: str, task_id: str) -> str:
    """Return the task-status URL for ``task_id`` derived from ``base_url``.

    Trailing slashes are ignored. A base URL that already points at the
    image-generation endpoint (ends with ``/images/generations``, which the
    images-endpoint builder in this file accepts) is reduced to its API root
    first, so both endpoints are derived from the same root.

    Args:
        base_url: Provider base URL, with or without a trailing ``/v1``.
        task_id: Identifier of the asynchronous task.

    Returns:
        ``<root>/tasks/<task_id>`` when the root ends with ``/v1``,
        otherwise ``<root>/v1/tasks/<task_id>``.
    """
    images_suffix = "/images/generations"
    # Same trimming the file's base-URL normaliser performs.
    root = base_url.rstrip("/")
    if root.endswith(images_suffix):
        # Without this, a fully qualified images URL would yield
        # ".../images/generations/v1/tasks/<id>".
        root = root[: -len(images_suffix)].rstrip("/")
    if root.endswith("/v1"):
        return f"{root}/tasks/{task_id}"
    return f"{root}/v1/tasks/{task_id}"
def _build_payload(config: dict[str, Any], args_payload: dict[str, Any]) -> dict[str, Any]:
    """Assemble the JSON request body for an image-generation call.

    Args:
        config: Loaded configuration; ``config["provider"]["model"]`` is
            required and ``config["generation"]`` optionally supplies
            ``default_*`` values.
        args_payload: Caller-supplied generation parameters; ``prompt`` is
            required.

    Returns:
        A dict with ``model``, ``prompt``, ``size``, ``n``,
        ``response_format`` and ``quality``, plus ``negative_prompt`` and
        ``seed`` when supplied. If ``extra_body`` is a dict, its entries are
        merged in last and therefore override any earlier key.
    """
    defaults = config.get("generation", {})

    def pick(field: str, default_key: str, fallback: Any) -> Any:
        # A falsy caller value (missing, None, "", 0) defers to the
        # configured default, then to the hard-coded fallback.
        return args_payload.get(field) or defaults.get(default_key, fallback)

    body: dict[str, Any] = {
        "model": config["provider"]["model"],
        "prompt": args_payload["prompt"],
    }
    body["size"] = pick("size", "default_size", "1024x1024")
    body["n"] = pick("n", "default_n", 1)
    body["response_format"] = pick("response_format", "default_response_format", "b64_json")
    body["quality"] = pick("quality", "default_quality", "standard")

    negative_prompt = args_payload.get("negative_prompt")
    if negative_prompt:
        body["negative_prompt"] = negative_prompt

    seed = args_payload.get("seed")
    # Compared against None so that a seed of 0 is still forwarded.
    if seed is not None:
        body["seed"] = seed

    overrides = args_payload.get("extra_body")
    if isinstance(overrides, dict):
        body.update(overrides)
    return body
def _write_b64_image(output_dir: Path, image_data: str, run_id: str, index: int) -> str:
    """Decode a base64 image and write it to ``<run_id>_<index>.png``.

    Args:
        output_dir: Directory that receives the file.
        image_data: Base64-encoded image bytes.
        run_id: Identifier used as the file-name prefix.
        index: Image number used as the file-name suffix.

    Returns:
        The path of the written file, as a string.
    """
    target = output_dir / f"{run_id}_{index}.png"
    target.write_bytes(base64.b64decode(image_data))
    return str(target)
def _download_image(output_dir: Path, session: requests.Session, image_url: str, run_id: str, index: int) -> str:
    """Fetch ``image_url`` and store the body as ``<run_id>_<index>.png``.

    Args:
        output_dir: Directory that receives the file.
        session: HTTP session used for the GET request (60 second timeout).
        image_url: Location of the image to download.
        run_id: Identifier used as the file-name prefix.
        index: Image number used as the file-name suffix.

    Returns:
        The path of the written file, as a string.

    Raises:
        Whatever ``raise_for_status()`` raises for an error response.
    """
    reply = session.get(image_url, timeout=60)
    # Fail before touching the filesystem if the server reported an error.
    reply.raise_for_status()
    destination = output_dir / f"{run_id}_{index}.png"
    destination.write_bytes(reply.content)
    return str(destination)
def _save_images(config: dict[str, Any], response_payload: dict[str, Any], session: requests.Session, run_id: str) -> list[str]:
    """Persist every image listed under ``response_payload["data"]``.

    Entries are numbered from 1. An entry with a truthy ``b64_json`` value is
    decoded and written locally; otherwise an entry with a truthy ``url`` is
    downloaded through ``session``. Entries with neither are skipped, but
    still consume their number.

    Returns:
        Paths of the saved files, in response order.
    """
    output_dir = resolve_output_dir(config)
    saved: list[str] = []
    entries = response_payload.get("data") or []
    for position, entry in enumerate(entries, start=1):
        encoded = entry.get("b64_json")
        if encoded:
            # Inline data takes precedence over a URL on the same entry.
            saved.append(_write_b64_image(output_dir, encoded, run_id, position))
        elif entry.get("url"):
            saved.append(_download_image(output_dir, session, entry["url"], run_id, position))
    return saved
def _should_retry_async(response: requests.Response) -> bool:
    """Tell whether a failed synchronous call should be retried in async mode.

    Returns ``True`` only for an HTTP 400 response whose JSON body is an
    object with an ``errors`` object, whose ``message`` (compared
    case-insensitively) contains both ``does not support synchronous calls``
    and ``async``.

    Any other shape (non-JSON body, a JSON array or scalar, a missing, null
    or non-object ``errors`` value) yields ``False`` rather than raising, so
    the caller can surface the original HTTP error.
    """
    if response.status_code != 400:
        return False
    try:
        payload = response.json()
    except ValueError:
        return False
    # Chaining .get() would raise AttributeError on a list body or a null
    # "errors" value, so check the shape at each level.
    errors = payload.get("errors") if isinstance(payload, dict) else None
    if not isinstance(errors, dict):
        return False
    message = str(errors.get("message", "")).lower()
    return "does not support synchronous calls" in message and "async" in message
def _should_use_async_mode(provider: dict[str, Any]) -> bool:
    """Return ``True`` when the provider's base URL targets ModelScope.

    The check is a case-insensitive substring match of
    ``api-inference.modelscope.cn`` against ``provider["base_url"]``; a
    missing ``base_url`` is treated as an empty string.
    """
    host_marker = "api-inference.modelscope.cn"
    configured_url = str(provider.get("base_url", ""))
    return host_marker in configured_url.lower()
def _poll_async_task(
    config: dict[str, Any],
    provider: dict[str, Any],
    session: requests.Session,
    task_id: str,
    run_id: str,
) -> dict[str, Any]:
    """Poll an asynchronous generation task until it finishes, then save its images.

    The task endpoint is queried repeatedly, sleeping
    ``generation.poll_interval_seconds`` (default 3) between attempts, for at
    most ``generation.timeout_seconds`` (default 300) measured on the
    monotonic clock.

    Args:
        config: Loaded configuration; supplies the polling settings and the
            output directory.
        provider: Provider settings; ``base_url`` and ``model`` are read.
        session: HTTP session used for polling and for image downloads.
        task_id: Identifier of the task to wait for.
        run_id: Identifier used to name the saved image files.

    Returns:
        A result dict with ``status``, ``run_id``, ``task_id``, ``model``,
        ``images`` (saved file paths) and ``raw_response`` (the final task
        payload).

    Raises:
        TimeoutError: The deadline passed before the task finished.
        RuntimeError: The task reported ``FAILED`` (message is the JSON
            payload), or reported ``SUCCEED`` without any ``output_images``.
    """
    generation = config.get("generation", {})
    poll_interval = float(generation.get("poll_interval_seconds", 3))
    timeout_seconds = int(generation.get("timeout_seconds", 300))
    started_at = time.monotonic()
    endpoint = _tasks_endpoint(provider["base_url"], task_id)

    while True:
        if time.monotonic() - started_at > timeout_seconds:
            raise TimeoutError(f"Timed out waiting for task {task_id} after {timeout_seconds} seconds.")

        response = session.get(
            endpoint,
            headers={"X-ModelScope-Task-Type": "image_generation"},
            timeout=60,
        )
        response.raise_for_status()
        payload = response.json()
        task_status = str(payload.get("task_status", "")).upper()

        if task_status == "SUCCEED":
            output_images = payload.get("output_images") or []
            if not output_images:
                raise RuntimeError(f"Task {task_id} succeeded but returned no output_images.")
            # Resolve the destination once; it does not vary per image.
            output_dir = resolve_output_dir(config)
            images = [
                _download_image(output_dir, session, image_url, run_id, index)
                for index, image_url in enumerate(output_images, start=1)
            ]
            return {
                "status": "success",
                "run_id": run_id,
                "task_id": task_id,
                "model": provider["model"],
                "images": images,
                "raw_response": payload,
            }

        if task_status == "FAILED":
            raise RuntimeError(json.dumps(payload, ensure_ascii=False))

        # Any other status is treated as "still running": wait, then poll again.
        time.sleep(poll_interval)
def _submit_async_task(
    session: requests.Session,
    endpoint: str,
    payload: dict[str, Any],
    timeout: int,
) -> str:
    """Submit ``payload`` with the async-mode header and return the task id.

    Raises:
        RuntimeError: The response JSON carries no non-blank ``task_id``.
    """
    async_response = session.post(
        endpoint,
        json=payload,
        headers={"X-ModelScope-Async-Mode": "true"},
        timeout=timeout,
    )
    async_response.raise_for_status()
    task_id = async_response.json().get("task_id")
    if not str(task_id or "").strip():
        raise RuntimeError("Async image generation did not return a task_id.")
    return str(task_id)


def execute_generation(args_payload: dict[str, Any]) -> dict[str, Any]:
    """Run one image-generation request and save the resulting images.

    The request is sent synchronously unless the provider's base URL selects
    async mode up front, or the synchronous attempt is rejected with the
    provider's "use async" error. In either async case the task is submitted
    with the ``X-ModelScope-Async-Mode`` header and polled to completion.

    Args:
        args_payload: Generation parameters; must contain ``prompt``.

    Returns:
        A result dict with ``status``, ``run_id``, ``model``, ``images``
        (saved file paths) and ``raw_response``; async results also include
        ``task_id``.

    Raises:
        RuntimeError: No images could be saved, or the async submission
            returned no ``task_id``.
    """
    config = load_config()
    provider = require_provider_config(config)
    payload = _build_payload(config, args_payload)
    endpoint = _images_endpoint(provider["base_url"])
    timeout = int(config.get("generation", {}).get("timeout_seconds", 300))
    run_id = str(uuid.uuid4())

    with requests.Session() as session:
        session.headers.update(
            {
                "Authorization": f"Bearer {provider['api_key']}",
                "Content-Type": "application/json",
            }
        )

        if not _should_use_async_mode(provider):
            response = session.post(endpoint, json=payload, timeout=timeout)
            if not _should_retry_async(response):
                response.raise_for_status()
                response_payload = response.json()
                images = _save_images(config, response_payload, session, run_id)
                if not images:
                    raise RuntimeError("Provider returned no downloadable images.")
                return {
                    "status": "success",
                    "run_id": run_id,
                    "model": provider["model"],
                    "images": images,
                    "raw_response": response_payload,
                }

        # Reached when async mode was selected up front, or when the
        # synchronous attempt was rejected in favour of async.
        task_id = _submit_async_task(session, endpoint, payload, timeout)
        return _poll_async_task(config, provider, session, task_id, run_id)
def _parse_args_json(raw_args: str) -> dict[str, Any]:
    """Parse and validate the JSON string passed via ``--args``.

    Args:
        raw_args: JSON text expected to encode an object with a ``prompt``.

    Returns:
        The decoded object.

    Raises:
        ValueError: The text is not valid JSON, does not decode to an
            object, or its ``prompt`` is missing, ``null`` or blank.
    """
    try:
        payload = json.loads(raw_args)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON in --args: {exc}") from exc
    if not isinstance(payload, dict):
        raise ValueError("--args must decode to a JSON object.")
    prompt = payload.get("prompt")
    # A JSON null must count as missing: str(None) is the non-empty string
    # "None" and would otherwise slip through the blank check.
    if prompt is None or not str(prompt).strip():
        raise ValueError("The 'prompt' field is required.")
    return payload
def _cmd_run(raw_args: str) -> dict[str, Any]:
    """Validate the raw ``--args`` JSON and run the generation it describes.

    Returns:
        The result dict produced by ``execute_generation``.
    """
    generation_args = _parse_args_json(raw_args)
    outcome = execute_generation(generation_args)
    return outcome
def main() -> None:
    """Command-line entry point.

    Two invocation forms are accepted:

    * ``run --args <json>``: the subcommand form.
    * ``[--workflow qwen/text-to-image] --args <json>``: the legacy form
      without a subcommand; any other workflow name is rejected.

    With neither form the help text is printed. On success the result is
    printed to stdout as indented JSON.
    """
    parser = argparse.ArgumentParser(description="Qwen Image Client for OpenClaw Skill")
    subcommands = parser.add_subparsers(dest="command")

    run_parser = subcommands.add_parser("run", help="Generate images from a JSON payload")
    run_parser.add_argument("--args", required=True, help="JSON string of generation parameters")

    parser.add_argument("--workflow", help="Legacy compatibility flag. Only qwen/text-to-image is supported.")
    parser.add_argument("--args", dest="legacy_args", help="Legacy compatibility flag when no subcommand is used.")

    options = parser.parse_args()

    # Work out which JSON string to execute; the actual run happens once below.
    raw_json: str
    if options.command == "run":
        raw_json = options.args
    elif options.workflow or options.legacy_args:
        selected_workflow = options.workflow or "qwen/text-to-image"
        if selected_workflow != "qwen/text-to-image":
            raise SystemExit("Only qwen/text-to-image is supported by this skill.")
        if not options.legacy_args:
            raise SystemExit("--args is required.")
        raw_json = options.legacy_args
    else:
        parser.print_help()
        return

    outcome: dict[str, Any] = _cmd_run(raw_json)
    print(json.dumps(outcome, ensure_ascii=False, indent=2))
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|