config.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from __future__ import annotations
  2. import json
  3. from copy import deepcopy
  4. from pathlib import Path
  5. from typing import Any
  6. BASE_DIR = Path(__file__).resolve().parents[2]
  7. CONFIG_PATH = BASE_DIR / "config.json"
  8. EXAMPLE_CONFIG_PATH = BASE_DIR / "config.example.json"
  9. DEFAULT_CONFIG: dict[str, Any] = {
  10. "provider": {
  11. "name": "qwen-compatible",
  12. "base_url": "https://api-inference.modelscope.cn/v1",
  13. "model": "qwen-image",
  14. },
  15. "generation": {
  16. "output_dir": "./outputs",
  17. "timeout_seconds": 300,
  18. "poll_interval_seconds": 3,
  19. "default_size": "1024x1024",
  20. "default_n": 1,
  21. "default_response_format": "b64_json",
  22. "default_quality": "standard",
  23. },
  24. }
  25. def load_json(path: Path) -> dict[str, Any]:
  26. with path.open("r", encoding="utf-8") as handle:
  27. return json.load(handle)
  28. def _merge_defaults(defaults: dict[str, Any], value: dict[str, Any]) -> dict[str, Any]:
  29. merged = deepcopy(defaults)
  30. for key, item in value.items():
  31. if isinstance(item, dict) and isinstance(merged.get(key), dict):
  32. merged[key] = _merge_defaults(merged[key], item)
  33. else:
  34. merged[key] = item
  35. return merged
  36. def load_config() -> dict[str, Any]:
  37. if CONFIG_PATH.exists():
  38. return _merge_defaults(DEFAULT_CONFIG, load_json(CONFIG_PATH))
  39. return _merge_defaults(DEFAULT_CONFIG, load_json(EXAMPLE_CONFIG_PATH))
  40. def resolve_output_dir(config: dict[str, Any]) -> Path:
  41. raw_output_dir = config.get("generation", {}).get("output_dir", "./outputs")
  42. output_dir = Path(raw_output_dir)
  43. if not output_dir.is_absolute():
  44. output_dir = BASE_DIR / output_dir
  45. output_dir.mkdir(parents=True, exist_ok=True)
  46. return output_dir
  47. def require_provider_config(config: dict[str, Any]) -> dict[str, Any]:
  48. provider = config.get("provider", {})
  49. missing = [
  50. key
  51. for key in ("api_key",)
  52. if not str(provider.get(key, "")).strip()
  53. ]
  54. if missing:
  55. raise ValueError(
  56. "Missing provider config fields: "
  57. + ", ".join(missing)
  58. + ". Update config.json before running the skill."
  59. )
  60. return provider