You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

177 lines
6.7 KiB

#!/usr/bin/env python3
"""
Client for the DaSiWa API Server (asynchronous, RunPod-style).
Runs on YOUR PC: submits a job, polls its status, and downloads the result.
Usage:
    python client.py --server http://<ip>:8080 --image photo.png --prompt "woman dancing"
    python client.py --server http://<ip>:8080 --image start.png --last-image end.png --prompt "smooth transition"
"""
import argparse
import base64
import json
import os
import sys
import time
import requests
from hmac_auth import sign_request
# ============================================================================
# Configuration
# ============================================================================
# keys.json is looked up in the same directory as this script.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
KEYS_FILE = os.path.join(_SCRIPT_DIR, "keys.json")
def load_keys():
    """Load the client credentials stored in ``KEYS_FILE``.

    Returns:
        The parsed JSON content of the keys file.

    If the file does not exist, a hint is printed and the process exits
    with status 1.
    """
    if os.path.exists(KEYS_FILE):
        with open(KEYS_FILE, "r") as keys_fp:
            return json.load(keys_fp)

    # Missing credentials: tell the user how to create them, then stop.
    print(f"❌ Файл ключей не найден: {KEYS_FILE}")
    print(" Запусти: python generate_keys.py")
    sys.exit(1)
def image_to_base64(path: str) -> str:
    """Return the contents of the file at ``path`` as a base64 text string."""
    with open(path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
def signed_post(server_url: str, path: str, payload: dict, client_id: str, secret_key: str):
    """Send a signed JSON POST request.

    The payload is serialized once; the same bytes are passed to
    ``sign_request`` and sent as the request body, so the signature always
    matches what goes over the wire.

    Args:
        server_url: Base URL of the server; a trailing slash is tolerated.
        path: Endpoint path starting with "/", e.g. "/run".
        payload: JSON-serializable request body.
        client_id: Client identifier passed to ``sign_request``.
        secret_key: Secret passed to ``sign_request``.

    Returns:
        A ``(status_code, data)`` tuple. ``data`` is the decoded JSON body,
        or ``{"error": <raw response text>}`` when the body is not JSON.

    Network-level exceptions raised by ``requests`` propagate to the caller.
    """
    body = json.dumps(payload).encode("utf-8")
    auth_headers = sign_request(body, secret_key, client_id)
    headers = {"Content-Type": "application/json", **auth_headers}
    # Avoid "//path" when the user passes a URL ending in "/".
    url = f"{server_url.rstrip('/')}{path}"
    response = requests.post(url, data=body, headers=headers, timeout=30)
    try:
        data = response.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page): keep the status code
        # usable and expose the raw text instead of raising here.
        data = {"error": response.text}
    return response.status_code, data
def signed_get(server_url: str, path: str, client_id: str, secret_key: str):
    """Send a signed GET request.

    A GET carries no body, so the signature is computed over empty bytes.

    Args:
        server_url: Base URL of the server; a trailing slash is tolerated.
        path: Endpoint path starting with "/", e.g. "/status/<job_id>".
        client_id: Client identifier passed to ``sign_request``.
        secret_key: Secret passed to ``sign_request``.

    Returns:
        A ``(status_code, data)`` tuple. ``data`` is the decoded JSON body,
        or ``{"error": <raw response text>}`` when the body is not JSON.

    Network-level exceptions raised by ``requests`` propagate to the caller.
    """
    body = b""
    auth_headers = sign_request(body, secret_key, client_id)
    # Avoid "//path" when the user passes a URL ending in "/".
    url = f"{server_url.rstrip('/')}{path}"
    response = requests.get(url, headers=auth_headers, timeout=30)
    try:
        data = response.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page): keep the status code
        # usable and expose the raw text instead of raising here.
        data = {"error": response.text}
    return response.status_code, data
def submit_job(server_url: str, payload: dict, client_id: str, secret_key: str):
    """Submit a generation job to ``/run`` and return its job id.

    Args:
        server_url: Base URL of the server.
        payload: JSON-serializable job description.
        client_id: Client identifier used for signing.
        secret_key: Secret used for signing.

    Returns:
        The value of the ``"id"`` field in the server response.

    Raises:
        RuntimeError: If the server answers with a non-200 status, or if a
            200 response does not contain an ``"id"`` field.
    """
    code, data = signed_post(server_url, "/run", payload, client_id, secret_key)
    if code != 200:
        # The body is normally a dict with an "error" key, but do not assume it.
        detail = data.get("error", data) if isinstance(data, dict) else data
        raise RuntimeError(f"Submit failed ({code}): {detail}")
    try:
        return data["id"]
    except (KeyError, TypeError) as err:
        # Report a malformed success response the same way as other submit
        # failures, so callers only need to handle RuntimeError.
        raise RuntimeError(f"Submit response has no job id: {data}") from err
def wait_for_completion(server_url: str, job_id: str, client_id: str, secret_key: str,
                        poll_interval: int = 5, max_wait: int = 1800):
    """Poll ``/status/<job_id>`` until the job completes, fails, or times out.

    Progress is written to stdout on a single line that is overwritten on
    every poll.

    Args:
        server_url: Base URL of the server.
        job_id: Identifier returned by ``submit_job``.
        client_id: Client identifier used for signing.
        secret_key: Secret used for signing.
        poll_interval: Seconds to sleep between status requests.
        max_wait: Maximum number of seconds to keep polling.

    Returns:
        The status response of the poll that reported ``"COMPLETED"``.

    Raises:
        RuntimeError: On a non-200 status response, when the job reports
            ``"FAILED"``, or when ``max_wait`` seconds elapse.
    """
    # A monotonic clock keeps the timeout and the displayed elapsed time
    # correct even if the system clock is adjusted while waiting.
    start = time.monotonic()
    while time.monotonic() - start < max_wait:
        code, data = signed_get(server_url, f"/status/{job_id}", client_id, secret_key)
        if code != 200:
            raise RuntimeError(f"Status check failed ({code}): {data}")

        status = data.get("status")
        elapsed = int(time.monotonic() - start)

        if status == "COMPLETED":
            print(f"\r✅ COMPLETED ({elapsed}s)")
            return data
        if status == "FAILED":
            raise RuntimeError(f"Job failed: {data.get('error', 'Unknown error')}")

        # Any other status means the job is still pending or running.
        print(f"\r{status}... ({elapsed}s)", end="", flush=True)
        time.sleep(poll_interval)

    raise RuntimeError(f"Timeout waiting for job ({max_wait}s)")
def _build_arg_parser():
    """Create the command-line argument parser for the client."""
    parser = argparse.ArgumentParser(description="DaSiWa API Client (async)")
    parser.add_argument("--server", required=True, help="Server URL, e.g. http://1.2.3.4:8080")
    parser.add_argument("--image", required=True, help="Path to first frame image")
    parser.add_argument("--last-image", default=None, help="Path to last frame image (FLF2V mode)")
    parser.add_argument("--prompt", required=True, help="Text prompt")
    parser.add_argument("--negative-prompt", default=None, help="Negative prompt")
    parser.add_argument("--width", type=int, default=528)
    parser.add_argument("--height", type=int, default=768)
    parser.add_argument("--length", type=int, default=81, help="Frame count")
    parser.add_argument("--steps", type=int, default=4)
    parser.add_argument("--cfg", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=-1)
    parser.add_argument("--fps", type=int, default=16)
    parser.add_argument("--poll-interval", type=int, default=5, help="Status poll interval (seconds)")
    parser.add_argument("--output", "-o", default="output.mp4", help="Output video path")
    return parser


def _build_payload(args):
    """Assemble the ``/run`` request payload from parsed CLI arguments."""
    payload = {
        "prompt": args.prompt,
        "image_base64": image_to_base64(args.image),
        "width": args.width,
        "height": args.height,
        "length": args.length,
        "steps": args.steps,
        "cfg": args.cfg,
        "seed": args.seed,
        "fps": args.fps,
    }
    # Optional fields are only sent when the user supplied them.
    if args.negative_prompt:
        payload["negative_prompt"] = args.negative_prompt
    if args.last_image:
        payload["last_image_base64"] = image_to_base64(args.last_image)
    return payload


def main():
    """Run the client: submit a job, wait for it, save the video, purge the job.

    Exits with status 1 if submission or polling fails, or if the final
    response contains no video.
    """
    args = _build_arg_parser().parse_args()
    keys = load_keys()
    cid, secret = keys["client_id"], keys["secret_key"]

    # Build the payload
    payload = _build_payload(args)
    if args.last_image:
        print("🎬 Режим: FLF2V (first + last frame)")
    else:
        print("🎬 Режим: I2V (image to video)")
    print(f"📐 {args.width}x{args.height}, {args.length} frames, {args.steps} steps")

    # 1. Submit job
    print(f"📤 Отправляю задачу на {args.server}...")
    try:
        job_id = submit_job(args.server, payload, cid, secret)
    except (RuntimeError, requests.RequestException) as e:
        # Connection errors and timeouts are reported like server-side
        # errors instead of surfacing as a raw traceback.
        print(f"{e}")
        sys.exit(1)
    print(f"📝 Job ID: {job_id}")

    # 2. Poll for completion
    print(f"⏳ Жду результат (поллинг каждые {args.poll_interval}s)...")
    try:
        result = wait_for_completion(args.server, job_id, cid, secret,
                                     poll_interval=args.poll_interval)
    except (RuntimeError, requests.RequestException) as e:
        # Leading newline: the progress line was printed without one.
        print(f"\n{e}")
        sys.exit(1)

    # 3. Save video
    # "or {}" also covers a response whose "output" is present but null.
    output = result.get("output") or {}
    video_b64 = output.get("video")
    if not video_b64:
        print("❌ Нет видео в ответе")
        sys.exit(1)
    video_bytes = base64.b64decode(video_b64)
    with open(args.output, "wb") as f:
        f.write(video_bytes)
    print(f"✅ Видео сохранено: {args.output} ({len(video_bytes) / 1024 / 1024:.1f} MB)")
    print(f"⏱ Сервер: {output.get('elapsed', '?')}s | Seed: {output.get('seed', '?')} | Mode: {output.get('mode', '?')}")

    # 4. Purge job from server memory
    try:
        signed_post(args.server, f"/purge/{job_id}", {}, cid, secret)
    except Exception:
        # Deliberately best-effort: the video is already saved, so a failed
        # purge must not turn a successful run into an error.
        pass
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()