feat: add LangGraph checkpoint resume for crash recovery (#594)

Long analyses can take many minutes; a crash or interruption forced users
to re-run from scratch and re-pay every LLM call.  This adds an opt-in
checkpoint layer backed by per-ticker SQLite databases so the graph
resumes from the last successful node.

How to use:
- CLI:    tradingagents analyze --checkpoint
- CLI:    tradingagents analyze --clear-checkpoints
- Python: config["checkpoint_enabled"] = True

Lifecycle:
- propagate() recompiles the graph with a SqliteSaver when enabled and
  injects a deterministic thread_id derived from ticker+date so the
  same ticker+date resumes while a different date starts fresh.
- On successful completion the per-thread checkpoint rows are cleared.
- The context manager is closed in a try/finally so a crash never
  leaks the SQLite connection or leaves the graph in checkpoint mode.

Storage: ~/.tradingagents/cache/checkpoints/<TICKER>.db
(override via TRADINGAGENTS_CACHE_DIR).

The checkpointer module is new (tradingagents/graph/checkpointer.py)
and the GraphSetup now returns the uncompiled workflow so it can be
recompiled with a saver when needed.

Adds the langgraph-checkpoint-sqlite>=2.0.0 dependency. Three new tests verify
the crash/resume cycle and that a different date starts fresh.
This commit is contained in:
Yijia-Xiao
2026-04-25 08:39:27 +00:00
parent ebd2e12e67
commit 4cbd4b086f
9 changed files with 349 additions and 21 deletions
+3
View File
@@ -16,6 +16,9 @@ DEFAULT_CONFIG = {
"google_thinking_level": None, # "high", "minimal", etc.
"openai_reasoning_effort": None, # "medium", "high", "low"
"anthropic_effort": None, # "high", "medium", "low"
# Checkpoint/resume: when True, LangGraph saves state after each node
# so a crashed run can resume from the last successful step.
"checkpoint_enabled": False,
# Output language for analyst reports and final decision
# Internal agent debate stays in English for reasoning quality
"output_language": "English",
+86
View File
@@ -0,0 +1,86 @@
"""LangGraph checkpoint support for resumable analysis runs.
Uses one SQLite database per ticker so concurrent runs on different tickers don't contend.
"""
from __future__ import annotations
import hashlib
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from langgraph.checkpoint.sqlite import SqliteSaver
def _db_path(data_dir: str | Path, ticker: str) -> Path:
"""Return the SQLite checkpoint DB path for a ticker."""
p = Path(data_dir) / "checkpoints"
p.mkdir(parents=True, exist_ok=True)
return p / f"{ticker.upper()}.db"
def thread_id(ticker: str, date: str) -> str:
    """Build the deterministic checkpoint thread ID for a ticker+date pair.

    The ticker is upper-cased, joined to the date with ``:``, hashed with
    SHA-256, and the first 16 hex characters of the digest are returned.
    """
    key = "{}:{}".format(ticker.upper(), date)
    digest = hashlib.sha256(key.encode()).hexdigest()
    return digest[:16]
@contextmanager
def get_checkpointer(data_dir: str | Path, ticker: str) -> Generator[SqliteSaver, None, None]:
    """Yield a ``SqliteSaver`` bound to the ticker's checkpoint database.

    A SQLite connection to the per-ticker DB file is opened with
    ``check_same_thread=False``, wrapped in a ``SqliteSaver`` whose
    ``setup()`` is called before it is yielded. The connection is closed
    when the ``with`` block exits, whether normally or through an exception.

    Args:
        data_dir: Base cache directory holding the ``checkpoints`` folder.
        ticker: Ticker symbol selecting the database file.
    """
    connection = sqlite3.connect(
        str(_db_path(data_dir, ticker)), check_same_thread=False
    )
    try:
        checkpointer = SqliteSaver(connection)
        checkpointer.setup()
        yield checkpointer
    finally:
        # Always release the connection, even if setup or the caller failed.
        connection.close()
def has_checkpoint(data_dir: str | Path, ticker: str, date: str) -> bool:
    """Report whether a resumable checkpoint exists for ticker+date.

    Returns ``True`` exactly when ``checkpoint_step`` yields a step number
    for the same arguments, ``False`` when it yields ``None``.
    """
    latest_step = checkpoint_step(data_dir, ticker, date)
    return latest_step is not None
def checkpoint_step(data_dir: str | Path, ticker: str, date: str) -> int | None:
    """Look up the step number of the latest checkpoint for ticker+date.

    Returns:
        The ``"step"`` entry of the latest checkpoint's metadata, or ``None``
        when the ticker has no database file, the thread has no checkpoint,
        or the metadata carries no ``"step"`` key.
    """
    db_file = _db_path(data_dir, ticker)
    if not db_file.exists():
        return None

    # Checkpoints are keyed by the deterministic ticker+date thread ID.
    lookup = {"configurable": {"thread_id": thread_id(ticker, date)}}
    with get_checkpointer(data_dir, ticker) as saver:
        latest = saver.get_tuple(lookup)

    return None if latest is None else latest.metadata.get("step")
def clear_all_checkpoints(data_dir: str | Path) -> int:
    """Remove all checkpoint DBs under ``<data_dir>/checkpoints``.

    For every ``*.db`` file, any SQLite ``-wal`` / ``-shm`` sidecar files
    next to it are removed as well. SQLite databases in WAL journal mode keep
    those sidecars beside the main file, and an interrupted process can leave
    them behind, so deleting only the ``.db`` would orphan them.

    Files that vanish between listing and deletion (e.g. a concurrent clear)
    are tolerated and not counted.

    Args:
        data_dir: Base cache directory holding the ``checkpoints`` folder.

    Returns:
        Number of ``.db`` files actually deleted (sidecars are not counted);
        ``0`` when the checkpoints directory does not exist.
    """
    cp_dir = Path(data_dir) / "checkpoints"
    if not cp_dir.is_dir():
        return 0

    removed = 0
    # Materialise the listing first so deletions don't disturb the iteration.
    for db in sorted(cp_dir.glob("*.db")):
        for suffix in ("-wal", "-shm"):
            Path(f"{db}{suffix}").unlink(missing_ok=True)
        try:
            db.unlink()
        except FileNotFoundError:
            # Someone else removed it first; nothing left to do for this DB.
            continue
        removed += 1
    return removed
def clear_checkpoint(data_dir: str | Path, ticker: str, date: str) -> None:
    """Remove the checkpoint for a specific ticker+date by deleting its rows.

    Rows whose ``thread_id`` matches the ticker+date thread are deleted from
    the ``writes`` and ``checkpoints`` tables of the ticker's database. The
    operation is best-effort: ``sqlite3.OperationalError`` (for example a
    missing table) is swallowed rather than raised. Each table is handled on
    its own, so a failure on one does not prevent the other from being
    cleared.

    Args:
        data_dir: Base cache directory holding the ``checkpoints`` folder.
        ticker: Ticker symbol selecting the database file.
        date: Analysis date; together with the ticker it selects the thread.
    """
    db = _db_path(data_dir, ticker)
    if not db.exists():
        return
    tid = thread_id(ticker, date)
    conn = sqlite3.connect(str(db))
    try:
        # Table names come from this fixed tuple, never from caller input,
        # so interpolating them into the statement is safe.
        for table in ("writes", "checkpoints"):
            try:
                conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (tid,))
            except sqlite3.OperationalError:
                # Best-effort: skip this table but still try the next one.
                continue
        conn.commit()
    except sqlite3.OperationalError:
        # Commit failed; closing the connection discards the pending changes.
        pass
    finally:
        conn.close()
+1 -2
View File
@@ -179,5 +179,4 @@ class GraphSetup:
workflow.add_edge("Portfolio Manager", END)
# Compile and return
return workflow.compile()
return workflow
+54 -13
View File
@@ -38,6 +38,7 @@ from tradingagents.agents.utils.agent_utils import (
get_global_news
)
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
@@ -123,8 +124,10 @@ class TradingAgentsGraph:
self.ticker = None
self.log_states_dict = {} # date to full state dict
# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)
# Set up the graph: keep the workflow for recompilation with a checkpointer.
self.workflow = self.graph_setup.setup_graph(selected_analysts)
self.graph = self.workflow.compile()
self._checkpointer_ctx = None
def _get_provider_kwargs(self) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation."""
@@ -259,23 +262,58 @@ class TradingAgentsGraph:
self.memory_log.batch_update_with_outcomes(updates)
def propagate(self, company_name, trade_date):
    """Run the trading agents graph for a company on a specific date.

    When ``checkpoint_enabled`` is set in config, the graph is recompiled
    with a per-ticker SqliteSaver so a crashed run can resume from the last
    successful node on a subsequent invocation with the same ticker+date.

    The whole checkpointed section runs inside the checkpointer's ``with``
    block, so the SQLite connection is released even if recompiling the
    graph or probing for an existing checkpoint fails. Afterwards the graph
    is always recompiled without a checkpointer.

    Args:
        company_name: Ticker symbol to analyse.
        trade_date: Analysis date; ``str(trade_date)`` is used when looking
            up an existing checkpoint.

    Returns:
        Whatever ``_run_graph`` returns for this ticker and date.
    """
    self.ticker = company_name

    # Resolve any pending memory-log entries for this ticker before the pipeline runs.
    self._resolve_pending_entries(company_name)

    if not self.config.get("checkpoint_enabled"):
        return self._run_graph(company_name, trade_date)

    # Recompile with a checkpointer because the user opted in.
    data_dir = self.config["data_cache_dir"]
    self._checkpointer_ctx = get_checkpointer(data_dir, company_name)
    try:
        with self._checkpointer_ctx as saver:
            self.graph = self.workflow.compile(checkpointer=saver)

            step = checkpoint_step(data_dir, company_name, str(trade_date))
            if step is not None:
                logger.info(
                    "Resuming from step %d for %s on %s", step, company_name, trade_date
                )
            else:
                logger.info("Starting fresh for %s on %s", company_name, trade_date)

            return self._run_graph(company_name, trade_date)
    finally:
        # Never leave the instance in checkpoint mode, whatever happened above.
        self._checkpointer_ctx = None
        self.graph = self.workflow.compile()
def _run_graph(self, company_name, trade_date):
"""Execute the graph and write the resulting state to disk and memory log."""
# Initialize state — inject memory log context for PM.
past_context = self.memory_log.get_past_context(company_name)
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date, past_context=past_context
)
args = self.propagator.get_graph_args()
# Inject thread_id so same ticker+date resumes, different date starts fresh.
if self.config.get("checkpoint_enabled"):
tid = thread_id(company_name, str(trade_date))
args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid
if self.debug:
# Debug mode with tracing
trace = []
for chunk in self.graph.stream(init_agent_state, **args):
if len(chunk["messages"]) == 0:
@@ -283,26 +321,29 @@ class TradingAgentsGraph:
else:
chunk["messages"][-1].pretty_print()
trace.append(chunk)
final_state = trace[-1]
else:
# Standard mode without tracing
final_state = self.graph.invoke(init_agent_state, **args)
# Store current state for reflection
# Store current state for reflection.
self.curr_state = final_state
# Log state
# Log state to disk.
self._log_state(trade_date, final_state)
# Store decision for deferred reflection.
# Store decision for deferred reflection on the next same-ticker run.
self.memory_log.store_decision(
ticker=company_name,
trade_date=trade_date,
final_trade_decision=final_state["final_trade_decision"],
)
# Return decision and processed signal
# Clear checkpoint on successful completion to avoid stale state.
if self.config.get("checkpoint_enabled"):
clear_checkpoint(
self.config["data_cache_dir"], company_name, str(trade_date)
)
return final_state, self.process_signal(final_state["final_trade_decision"])
def _log_state(self, trade_date, final_state):