mirror of
https://github.com/farcasclaudiu/TradingAgents.git
synced 2026-06-28 13:01:12 +03:00
feat: add LangGraph checkpoint resume for crash recovery (#594)
Long analyses can take many minutes; a crash or interruption forced users to re-run from scratch and re-pay for every LLM call. This adds an opt-in checkpoint layer backed by per-ticker SQLite databases so the graph resumes from the last successful node.

How to use:
- CLI: tradingagents analyze --checkpoint
- CLI: tradingagents analyze --clear-checkpoints
- Python: config["checkpoint_enabled"] = True

Lifecycle:
- propagate() recompiles the graph with a SqliteSaver when enabled and injects a deterministic thread_id derived from ticker+date, so the same ticker+date resumes while a different date starts fresh.
- On successful completion the per-thread checkpoint rows are cleared.
- The context manager is closed in a try/finally so a crash never leaks the SQLite connection or leaves the graph in checkpoint mode.

Storage: ~/.tradingagents/cache/checkpoints/<TICKER>.db (override via TRADINGAGENTS_CACHE_DIR).

The checkpointer module is new (tradingagents/graph/checkpointer.py), and GraphSetup now returns the uncompiled workflow so it can be recompiled with a saver when needed. Adds the langgraph-checkpoint-sqlite>=2.0.0 dependency. 3 new tests verify the crash/resume cycle and that a different date starts fresh.
This commit is contained in:
@@ -0,0 +1,86 @@
|
||||
"""LangGraph checkpoint support for resumable analysis runs.
|
||||
|
||||
Per-ticker SQLite databases so concurrent tickers don't contend.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
|
||||
|
||||
def _db_path(data_dir: str | Path, ticker: str) -> Path:
|
||||
"""Return the SQLite checkpoint DB path for a ticker."""
|
||||
p = Path(data_dir) / "checkpoints"
|
||||
p.mkdir(parents=True, exist_ok=True)
|
||||
return p / f"{ticker.upper()}.db"
|
||||
|
||||
|
||||
def thread_id(ticker: str, date: str) -> str:
    """Build the deterministic thread ID for a ticker+date pair.

    The ticker is upper-cased first, so the ID is case-insensitive with
    respect to the ticker. The result is the first 16 hex characters of the
    SHA-256 digest of ``"<TICKER>:<date>"``.
    """
    key = f"{ticker.upper()}:{date}"
    digest = hashlib.sha256(key.encode())
    return digest.hexdigest()[:16]
|
||||
|
||||
|
||||
@contextmanager
def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]:
    """Yield a ``SqliteSaver`` bound to the ticker's checkpoint database.

    A SQLite connection to the per-ticker DB file is opened on entry and is
    always closed on exit, including when the ``with`` body raises.
    ``setup()`` is called on the saver before it is handed to the caller.

    Args:
        data_dir: Base directory that holds the ``checkpoints`` folder.
        ticker: Ticker symbol selecting which DB file to open.

    Yields:
        The ``SqliteSaver`` wrapping the open connection.
    """
    # check_same_thread=False lifts sqlite3's restriction that a connection
    # may only be used from the thread that created it.
    connection = sqlite3.connect(
        str(_db_path(data_dir, ticker)),
        check_same_thread=False,
    )
    try:
        checkpointer = SqliteSaver(connection)
        checkpointer.setup()
        yield checkpointer
    finally:
        connection.close()
|
||||
|
||||
|
||||
def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool:
    """Report whether a resumable checkpoint exists for ticker+date.

    Returns ``True`` exactly when ``checkpoint_step`` yields a step value
    (i.e. anything other than ``None``) for the same arguments.
    """
    latest_step = checkpoint_step(data_dir, ticker, date)
    return latest_step is not None
|
||||
|
||||
|
||||
def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None:
    """Look up the step number of the latest checkpoint for ticker+date.

    Args:
        data_dir: Base directory that holds the ``checkpoints`` folder.
        ticker: Ticker symbol selecting the DB file.
        date: Date string that, with the ticker, identifies the thread.

    Returns:
        The ``"step"`` entry of the latest checkpoint's metadata, or ``None``
        when the DB file is absent, the thread has no checkpoint, or the
        metadata carries no ``"step"`` key.
    """
    # Nothing to query if the ticker's DB file was never created.
    if not _db_path(data_dir, ticker).exists():
        return None

    run_config = {"configurable": {"thread_id": thread_id(ticker, date)}}
    with get_checkpointer(data_dir, ticker) as saver:
        latest = saver.get_tuple(run_config)
        return latest.metadata.get("step") if latest is not None else None
|
||||
|
||||
|
||||
def clear_all_checkpoints(data_dir: str | Path) -> int:
    """Remove all checkpoint DBs and their SQLite sidecar files.

    SQLite can leave ``-wal``, ``-shm`` or ``-journal`` files next to a
    database (for example in WAL mode, or after a process was killed
    mid-write). Those are removed together with each ``*.db`` file so that no
    orphaned state is left behind. Files that vanish while the sweep is
    running are ignored rather than raising.

    Args:
        data_dir: Base directory that holds the ``checkpoints`` folder.

    Returns:
        Number of ``*.db`` files found and removed; sidecar files are not
        counted. ``0`` if the checkpoints directory does not exist.
    """
    cp_dir = Path(data_dir) / "checkpoints"
    if not cp_dir.is_dir():
        return 0

    # Materialise the listing first so we never delete from a directory
    # while still iterating over it.
    dbs = list(cp_dir.glob("*.db"))
    for db in dbs:
        db.unlink(missing_ok=True)
        for suffix in ("-wal", "-shm", "-journal"):
            db.with_name(db.name + suffix).unlink(missing_ok=True)
    return len(dbs)
|
||||
|
||||
|
||||
def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None:
    """Remove the checkpoint for a ticker+date by deleting the thread's rows.

    Rows matching the thread ID are deleted from the ``writes`` and
    ``checkpoints`` tables in a single transaction. The call is best-effort
    and never raises ``sqlite3.OperationalError``:

    * a table that does not exist is skipped, and the other table is still
      cleaned;
    * any other operational error (e.g. a locked database) aborts the
      transaction without committing and is logged as a warning, so a
      checkpoint that could not be cleared does not go unnoticed.

    Args:
        data_dir: Base directory that holds the ``checkpoints`` folder.
        ticker: Ticker symbol selecting the DB file.
        date: Date string that, with the ticker, identifies the thread.
    """
    import logging

    db = _db_path(data_dir, ticker)
    if not db.exists():
        return

    tid = thread_id(ticker, date)
    conn = sqlite3.connect(str(db))
    try:
        # Table names come from this fixed tuple, never from caller input,
        # so interpolating them into the statement is safe.
        for table in ("writes", "checkpoints"):
            try:
                conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,))
            except sqlite3.OperationalError as exc:
                # A DB without this table simply has nothing to delete here.
                if "no such table" not in str(exc):
                    raise
        conn.commit()
    except sqlite3.OperationalError as exc:
        # Nothing was committed; closing the connection discards the
        # partial transaction.
        logging.getLogger(__name__).warning(
            "Could not clear checkpoint for %s on %s: %s", ticker, date, exc
        )
    finally:
        conn.close()
|
||||
Reference in New Issue
Block a user