qwen_image_client.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. from __future__ import annotations
  2. import argparse
  3. import base64
  4. import json
  5. import time
  6. import uuid
  7. from pathlib import Path
  8. from typing import Any
  9. import requests
  10. from shared.config import load_config, require_provider_config, resolve_output_dir
  11. def _normalize_base_url(base_url: str) -> str:
  12. return base_url.rstrip("/")
  13. def _images_endpoint(base_url: str) -> str:
  14. normalized = _normalize_base_url(base_url)
  15. if normalized.endswith("/images/generations"):
  16. return normalized
  17. return f"{normalized}/images/generations"
  18. def _tasks_endpoint(base_url: str, task_id: str) -> str:
  19. normalized = _normalize_base_url(base_url)
  20. if normalized.endswith("/v1"):
  21. return f"{normalized}/tasks/{task_id}"
  22. return f"{normalized}/v1/tasks/{task_id}"
  23. def _build_payload(config: dict[str, Any], args_payload: dict[str, Any]) -> dict[str, Any]:
  24. provider = config["provider"]
  25. generation = config.get("generation", {})
  26. payload: dict[str, Any] = {
  27. "model": provider["model"],
  28. "prompt": args_payload["prompt"],
  29. "size": args_payload.get("size") or generation.get("default_size", "1024x1024"),
  30. "n": args_payload.get("n") or generation.get("default_n", 1),
  31. "response_format": args_payload.get("response_format") or generation.get("default_response_format", "b64_json"),
  32. "quality": args_payload.get("quality") or generation.get("default_quality", "standard"),
  33. }
  34. if args_payload.get("negative_prompt"):
  35. payload["negative_prompt"] = args_payload["negative_prompt"]
  36. if args_payload.get("seed") is not None:
  37. payload["seed"] = args_payload["seed"]
  38. extra_body = args_payload.get("extra_body")
  39. if isinstance(extra_body, dict):
  40. payload.update(extra_body)
  41. return payload
  42. def _write_b64_image(output_dir: Path, image_data: str, run_id: str, index: int) -> str:
  43. image_bytes = base64.b64decode(image_data)
  44. output_path = output_dir / f"{run_id}_{index}.png"
  45. output_path.write_bytes(image_bytes)
  46. return str(output_path)
  47. def _download_image(output_dir: Path, session: requests.Session, image_url: str, run_id: str, index: int) -> str:
  48. response = session.get(image_url, timeout=60)
  49. response.raise_for_status()
  50. output_path = output_dir / f"{run_id}_{index}.png"
  51. output_path.write_bytes(response.content)
  52. return str(output_path)
  53. def _save_images(config: dict[str, Any], response_payload: dict[str, Any], session: requests.Session, run_id: str) -> list[str]:
  54. output_dir = resolve_output_dir(config)
  55. images: list[str] = []
  56. data = response_payload.get("data") or []
  57. for index, item in enumerate(data, start=1):
  58. if item.get("b64_json"):
  59. images.append(_write_b64_image(output_dir, item["b64_json"], run_id, index))
  60. continue
  61. if item.get("url"):
  62. images.append(_download_image(output_dir, session, item["url"], run_id, index))
  63. continue
  64. return images
  65. def _should_retry_async(response: requests.Response) -> bool:
  66. if response.status_code != 400:
  67. return False
  68. try:
  69. payload = response.json()
  70. except ValueError:
  71. return False
  72. message = str(payload.get("errors", {}).get("message", "")).lower()
  73. return "does not support synchronous calls" in message and "async" in message
  74. def _should_use_async_mode(provider: dict[str, Any]) -> bool:
  75. base_url = str(provider.get("base_url", "")).lower()
  76. return "api-inference.modelscope.cn" in base_url
  77. def _poll_async_task(
  78. config: dict[str, Any],
  79. provider: dict[str, Any],
  80. session: requests.Session,
  81. task_id: str,
  82. run_id: str,
  83. ) -> dict[str, Any]:
  84. generation = config.get("generation", {})
  85. poll_interval = float(generation.get("poll_interval_seconds", 3))
  86. timeout_seconds = int(generation.get("timeout_seconds", 300))
  87. started_at = time.monotonic()
  88. endpoint = _tasks_endpoint(provider["base_url"], task_id)
  89. while True:
  90. if time.monotonic() - started_at > timeout_seconds:
  91. raise TimeoutError(f"Timed out waiting for task {task_id} after {timeout_seconds} seconds.")
  92. response = session.get(
  93. endpoint,
  94. headers={"X-ModelScope-Task-Type": "image_generation"},
  95. timeout=60,
  96. )
  97. response.raise_for_status()
  98. payload = response.json()
  99. task_status = str(payload.get("task_status", "")).upper()
  100. if task_status == "SUCCEED":
  101. output_images = payload.get("output_images") or []
  102. images = [
  103. _download_image(output_dir=resolve_output_dir(config), session=session, image_url=image_url, run_id=run_id, index=index)
  104. for index, image_url in enumerate(output_images, start=1)
  105. ]
  106. if not images:
  107. raise RuntimeError(f"Task {task_id} succeeded but returned no output_images.")
  108. return {
  109. "status": "success",
  110. "run_id": run_id,
  111. "task_id": task_id,
  112. "model": provider["model"],
  113. "images": images,
  114. "raw_response": payload,
  115. }
  116. if task_status == "FAILED":
  117. raise RuntimeError(json.dumps(payload, ensure_ascii=False))
  118. time.sleep(poll_interval)
  119. def execute_generation(args_payload: dict[str, Any]) -> dict[str, Any]:
  120. config = load_config()
  121. provider = require_provider_config(config)
  122. payload = _build_payload(config, args_payload)
  123. endpoint = _images_endpoint(provider["base_url"])
  124. timeout = int(config.get("generation", {}).get("timeout_seconds", 300))
  125. run_id = str(uuid.uuid4())
  126. with requests.Session() as session:
  127. session.headers.update(
  128. {
  129. "Authorization": f"Bearer {provider['api_key']}",
  130. "Content-Type": "application/json",
  131. }
  132. )
  133. if _should_use_async_mode(provider):
  134. async_response = session.post(
  135. endpoint,
  136. json=payload,
  137. headers={"X-ModelScope-Async-Mode": "true"},
  138. timeout=timeout,
  139. )
  140. async_response.raise_for_status()
  141. async_payload = async_response.json()
  142. task_id = async_payload.get("task_id")
  143. if not str(task_id or "").strip():
  144. raise RuntimeError("Async image generation did not return a task_id.")
  145. return _poll_async_task(config, provider, session, str(task_id), run_id)
  146. response = session.post(endpoint, json=payload, timeout=timeout)
  147. if _should_retry_async(response):
  148. async_response = session.post(
  149. endpoint,
  150. json=payload,
  151. headers={"X-ModelScope-Async-Mode": "true"},
  152. timeout=timeout,
  153. )
  154. async_response.raise_for_status()
  155. async_payload = async_response.json()
  156. task_id = async_payload.get("task_id")
  157. if not str(task_id or "").strip():
  158. raise RuntimeError("Async image generation did not return a task_id.")
  159. return _poll_async_task(config, provider, session, str(task_id), run_id)
  160. response.raise_for_status()
  161. response_payload = response.json()
  162. images = _save_images(config, response_payload, session, run_id)
  163. if not images:
  164. raise RuntimeError("Provider returned no downloadable images.")
  165. return {
  166. "status": "success",
  167. "run_id": run_id,
  168. "model": provider["model"],
  169. "images": images,
  170. "raw_response": response_payload,
  171. }
  172. def _parse_args_json(raw_args: str) -> dict[str, Any]:
  173. try:
  174. payload = json.loads(raw_args)
  175. except json.JSONDecodeError as exc:
  176. raise ValueError(f"Invalid JSON in --args: {exc}") from exc
  177. if not isinstance(payload, dict):
  178. raise ValueError("--args must decode to a JSON object.")
  179. if not str(payload.get("prompt", "")).strip():
  180. raise ValueError("The 'prompt' field is required.")
  181. return payload
  182. def _cmd_run(raw_args: str) -> dict[str, Any]:
  183. payload = _parse_args_json(raw_args)
  184. return execute_generation(payload)
  185. def main() -> None:
  186. parser = argparse.ArgumentParser(description="Qwen Image Client for OpenClaw Skill")
  187. subparsers = parser.add_subparsers(dest="command")
  188. sp_run = subparsers.add_parser("run", help="Generate images from a JSON payload")
  189. sp_run.add_argument("--args", required=True, help="JSON string of generation parameters")
  190. parser.add_argument("--workflow", help="Legacy compatibility flag. Only qwen/text-to-image is supported.")
  191. parser.add_argument("--args", dest="legacy_args", help="Legacy compatibility flag when no subcommand is used.")
  192. parsed = parser.parse_args()
  193. result: dict[str, Any]
  194. if parsed.command == "run":
  195. result = _cmd_run(parsed.args)
  196. elif parsed.workflow or parsed.legacy_args:
  197. workflow = parsed.workflow or "qwen/text-to-image"
  198. if workflow != "qwen/text-to-image":
  199. raise SystemExit("Only qwen/text-to-image is supported by this skill.")
  200. if not parsed.legacy_args:
  201. raise SystemExit("--args is required.")
  202. result = _cmd_run(parsed.legacy_args)
  203. else:
  204. parser.print_help()
  205. return
  206. print(json.dumps(result, ensure_ascii=False, indent=2))
  207. if __name__ == "__main__":
  208. main()