You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
177 lines
6.7 KiB
177 lines
6.7 KiB
#!/usr/bin/env python3
"""
Клиент для DaSiWa API Server (асинхронный, как RunPod).
Запускается на ТВОЁМ ПК. Отправляет задачу, поллит статус, забирает результат.

Использование:
    python client.py --server http://<ip>:8080 --image photo.png --prompt "woman dancing"
    python client.py --server http://<ip>:8080 --image start.png --last-image end.png --prompt "smooth transition"
"""
|
|
|
import argparse |
|
import base64 |
|
import json |
|
import os |
|
import sys |
|
import time |
|
|
|
import requests |
|
|
|
from hmac_auth import sign_request |
|
|
|
# ============================================================================
# Configuration
# ============================================================================

# The credentials file (read by main() for client_id / secret_key) is resolved
# relative to this script, not the current working directory, so the client
# can be launched from any directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
KEYS_FILE = os.path.join(_SCRIPT_DIR, "keys.json")
|
|
|
|
|
def load_keys(path=None):
    """Load the client's HMAC credentials from a JSON keys file.

    Args:
        path: Location of the keys file. When omitted, ``KEYS_FILE`` is used,
            so existing ``load_keys()`` calls behave exactly as before.

    Returns:
        The parsed JSON document (``main`` reads ``client_id`` and
        ``secret_key`` from it).

    Exits:
        Prints a hint and calls ``sys.exit(1)`` when the file does not exist
        or does not contain valid UTF-8 JSON, instead of dumping a traceback.
    """
    if path is None:
        path = KEYS_FILE

    if not os.path.isfile(path):
        print(f"❌ Файл ключей не найден: {path}")
        print(" Запусти: python generate_keys.py")
        sys.exit(1)

    try:
        # Explicit encoding: the platform default (e.g. cp1251 on Windows) is
        # not guaranteed to match the file; JSON is UTF-8 by specification.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError
        print(f"❌ Не удалось прочитать файл ключей {path}: {e}")
        print(" Запусти: python generate_keys.py")
        sys.exit(1)
|
|
|
|
|
def image_to_base64(path: str) -> str:
    """Read the file at ``path`` as raw bytes and return them base64-encoded as text."""
    with open(path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    # b64encode only emits ASCII characters, so this decode is lossless.
    return encoded.decode("ascii")
|
|
|
|
|
def signed_post(server_url: str, path: str, payload: dict, client_id: str, secret_key: str): |
|
"""Отправляет подписанный POST запрос.""" |
|
body = json.dumps(payload).encode("utf-8") |
|
auth_headers = sign_request(body, secret_key, client_id) |
|
headers = {"Content-Type": "application/json", **auth_headers} |
|
response = requests.post(f"{server_url}{path}", data=body, headers=headers, timeout=30) |
|
return response.status_code, response.json() |
|
|
|
|
|
def signed_get(server_url: str, path: str, client_id: str, secret_key: str): |
|
"""Отправляет подписанный GET запрос.""" |
|
body = b"" |
|
auth_headers = sign_request(body, secret_key, client_id) |
|
response = requests.get(f"{server_url}{path}", headers=auth_headers, timeout=30) |
|
return response.status_code, response.json() |
|
|
|
|
|
def submit_job(server_url: str, payload: dict, client_id: str, secret_key: str):
    """Submit a generation job to ``/run`` and return its job id.

    Raises:
        RuntimeError: if the server answers with a non-200 status, or with a
            200 response that carries no ``id``. Using RuntimeError for both
            cases lets ``main``'s ``except RuntimeError`` report them cleanly.
    """
    code, data = signed_post(server_url, "/run", payload, client_id, secret_key)
    is_dict = isinstance(data, dict)

    if code != 200:
        # Non-dict bodies (e.g. a JSON string or list) have no .get(); show them as-is.
        details = data.get("error", data) if is_dict else data
        raise RuntimeError(f"Submit failed ({code}): {details}")

    job_id = data.get("id") if is_dict else None
    if job_id is None:
        raise RuntimeError(f"Submit failed ({code}): response has no job id: {data}")
    return job_id
|
|
|
|
|
def wait_for_completion(server_url: str, job_id: str, client_id: str, secret_key: str,
                        poll_interval: int = 5, max_wait: int = 1800):
    """Poll ``/status/<job_id>`` until the job reaches a terminal state.

    Progress is shown on a single console line that is rewritten in place with
    a carriage return and no newline. Callers should print a newline before any
    follow-up error output.

    Args:
        server_url: Base URL of the API server.
        job_id: Id returned by ``submit_job``.
        client_id, secret_key: HMAC credentials used to sign each status request.
        poll_interval: Seconds to sleep between status requests.
        max_wait: Total time budget in seconds before giving up.

    Returns:
        The full status response of a ``COMPLETED`` job.

    Raises:
        RuntimeError: on a non-200 status response, on a ``FAILED``,
            ``CANCELLED`` or ``TIMED_OUT`` job, or when ``max_wait`` elapses.
    """
    # monotonic() is unaffected by wall-clock changes (NTP sync, manual
    # adjustment) during what can be a 30-minute wait.
    start = time.monotonic()
    prev_len = 0  # length of the last progress line, so a shorter one can blank it out

    while time.monotonic() - start < max_wait:
        code, data = signed_get(server_url, f"/status/{job_id}", client_id, secret_key)
        if code != 200:
            raise RuntimeError(f"Status check failed ({code}): {data}")

        status = data.get("status")
        elapsed = int(time.monotonic() - start)

        if status == "COMPLETED":
            print("\r" + f"✅ COMPLETED ({elapsed}s)".ljust(prev_len))
            return data
        if status == "FAILED":
            raise RuntimeError(f"Job failed: {data.get('error', 'Unknown error')}")
        if status in ("CANCELLED", "TIMED_OUT"):
            # RunPod's other terminal states, presumably mirrored by this
            # RunPod-like server. Without this check a dead job would be polled
            # until max_wait. It is harmless if the server never sends them.
            raise RuntimeError(f"Job {status}: {data.get('error', 'Unknown error')}")

        line = f"⏳ {status}... ({elapsed}s)"
        print("\r" + line.ljust(prev_len), end="", flush=True)
        prev_len = len(line)

        # Don't sleep past the deadline, otherwise the timeout fires up to one
        # poll_interval late.
        remaining = max_wait - (time.monotonic() - start)
        time.sleep(min(poll_interval, max(remaining, 0)))

    raise RuntimeError(f"Timeout waiting for job ({max_wait}s)")
|
|
|
|
|
def _parse_args():
    """Define the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="DaSiWa API Client (async)")
    parser.add_argument("--server", required=True, help="Server URL, e.g. http://1.2.3.4:8080")
    parser.add_argument("--image", required=True, help="Path to first frame image")
    parser.add_argument("--last-image", default=None, help="Path to last frame image (FLF2V mode)")
    parser.add_argument("--prompt", required=True, help="Text prompt")
    parser.add_argument("--negative-prompt", default=None, help="Negative prompt")
    parser.add_argument("--width", type=int, default=528)
    parser.add_argument("--height", type=int, default=768)
    parser.add_argument("--length", type=int, default=81, help="Frame count")
    parser.add_argument("--steps", type=int, default=4)
    parser.add_argument("--cfg", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--fps", type=int, default=16)
    parser.add_argument("--poll-interval", type=int, default=5, help="Status poll interval (seconds)")
    parser.add_argument("--output", "-o", default="output.mp4", help="Output video path")
    return parser.parse_args()


def _build_payload(args) -> dict:
    """Build the JSON body for ``/run`` from the parsed CLI arguments.

    ``negative_prompt`` and ``last_image_base64`` are included only when the
    corresponding flag was given.
    """
    payload = {
        "prompt": args.prompt,
        "image_base64": image_to_base64(args.image),
        "width": args.width,
        "height": args.height,
        "length": args.length,
        "steps": args.steps,
        "cfg": args.cfg,
        "seed": args.seed,
        "fps": args.fps,
    }
    if args.negative_prompt:
        payload["negative_prompt"] = args.negative_prompt
    if args.last_image:
        payload["last_image_base64"] = image_to_base64(args.last_image)
    return payload


def _save_video(output: dict, output_path: str) -> int:
    """Decode the base64 ``video`` field of a job's output and write it to ``output_path``.

    Returns the number of bytes written. Prints a message and exits with
    status 1 when the field is missing or is not valid base64.
    """
    video_b64 = output.get("video")
    if not video_b64:
        print("❌ Нет видео в ответе")
        sys.exit(1)
    try:
        video_bytes = base64.b64decode(video_b64)
    except ValueError as e:  # binascii.Error (bad padding) subclasses ValueError
        print(f"❌ Видео в ответе повреждено (некорректный base64): {e}")
        sys.exit(1)
    with open(output_path, "wb") as f:
        f.write(video_bytes)
    return len(video_bytes)


def main():
    """CLI entry point: submit a job, poll until it finishes, save the video, purge the job.

    Expected failures exit with status 1 after a one-line message:
    - a missing input image or malformed keys file;
    - HTTP or network errors;
    - a failed job or a malformed result.
    """
    args = _parse_args()

    keys = load_keys()
    try:
        cid, secret = keys["client_id"], keys["secret_key"]
    except (KeyError, TypeError):
        # A hand-edited or truncated keys file would otherwise surface as a bare traceback.
        print(f"❌ В файле ключей нет client_id/secret_key: {KEYS_FILE}")
        sys.exit(1)

    # Check inputs up front so a typo in a path fails with a clear message.
    for image_path in (args.image, args.last_image):
        if image_path and not os.path.isfile(image_path):
            print(f"❌ Файл не найден: {image_path}")
            sys.exit(1)

    payload = _build_payload(args)
    if args.last_image:
        print("🎬 Режим: FLF2V (first + last frame)")
    else:
        print("🎬 Режим: I2V (image to video)")
    print(f"📐 {args.width}x{args.height}, {args.length} frames, {args.steps} steps")

    # 1. Submit the job. Network errors from requests (connection refused,
    #    timeouts) are reported the same way as HTTP-level failures.
    print(f"📤 Отправляю задачу на {args.server}...")
    try:
        job_id = submit_job(args.server, payload, cid, secret)
    except (RuntimeError, requests.RequestException) as e:
        print(f"❌ {e}")
        sys.exit(1)
    print(f"📝 Job ID: {job_id}")

    # 2. Poll for completion.
    print(f"⏳ Жду результат (поллинг каждые {args.poll_interval}s)...")
    try:
        result = wait_for_completion(args.server, job_id, cid, secret,
                                     poll_interval=args.poll_interval)
    except (RuntimeError, requests.RequestException) as e:
        # Leading newline terminates the in-place progress line.
        print(f"\n❌ {e}")
        sys.exit(1)

    # 3. Save the video. `or {}` also covers an explicit "output": null.
    output = result.get("output") or {}
    size = _save_video(output, args.output)
    print(f"✅ Видео сохранено: {args.output} ({size / 1024 / 1024:.1f} MB)")
    print(f"⏱ Сервер: {output.get('elapsed', '?')}s | Seed: {output.get('seed', '?')} | Mode: {output.get('mode', '?')}")

    # 4. Purge the job from server memory. Best-effort: the video is already on
    #    disk, so a failure is only reported and does not change the exit status.
    try:
        code, _ = signed_post(args.server, f"/purge/{job_id}", {}, cid, secret)
    except Exception as e:  # deliberately broad: nothing here should fail the run
        print(f"⚠️ Не удалось очистить задачу на сервере: {e}")
    else:
        if code != 200:
            print(f"⚠️ Сервер не очистил задачу ({code})")
|
|
|
|
|
# Start the CLI only when this file is executed directly (e.g. `python client.py ...`);
# importing the module just defines the helpers above.
if __name__ == "__main__":
    main()
|
|
|