feat: add LangGraph checkpoint resume for crash recovery (#594)

Long analyses can take many minutes; a crash or interruption forced users
to re-run from scratch and re-pay every LLM call.  This adds an opt-in
checkpoint layer backed by per-ticker SQLite databases so the graph
resumes from the last successful node.

How to use:
- CLI:    tradingagents analyze --checkpoint
- CLI:    tradingagents analyze --clear-checkpoints
- Python: config["checkpoint_enabled"] = True

Lifecycle:
- propagate() recompiles the graph with a SqliteSaver when enabled and
  injects a deterministic thread_id derived from ticker+date so the
  same ticker+date resumes while a different date starts fresh.
- On successful completion the per-thread checkpoint rows are cleared.
- The context manager is closed in a try/finally so a crash never
  leaks the SQLite connection or leaves the graph in checkpoint mode.

Storage: ~/.tradingagents/cache/checkpoints/<TICKER>.db
(override via TRADINGAGENTS_CACHE_DIR).

The checkpointer module is new (tradingagents/graph/checkpointer.py)
and the GraphSetup now returns the uncompiled workflow so it can be
recompiled with a saver when needed.

Adds langgraph-checkpoint-sqlite>=2.0.0 dependency. 3 new tests verify
the crash/resume cycle and that a different date starts fresh.
This commit is contained in:
Yijia-Xiao
2026-04-25 08:39:27 +00:00
parent ebd2e12e67
commit 4cbd4b086f
9 changed files with 349 additions and 21 deletions
+54 -13
View File
@@ -38,6 +38,7 @@ from tradingagents.agents.utils.agent_utils import (
get_global_news
)
from .checkpointer import checkpoint_step, clear_checkpoint, get_checkpointer, thread_id
from .conditional_logic import ConditionalLogic
from .setup import GraphSetup
from .propagation import Propagator
@@ -123,8 +124,10 @@ class TradingAgentsGraph:
self.ticker = None
self.log_states_dict = {} # date to full state dict
# Set up the graph
self.graph = self.graph_setup.setup_graph(selected_analysts)
# Set up the graph: keep the workflow for recompilation with a checkpointer.
self.workflow = self.graph_setup.setup_graph(selected_analysts)
self.graph = self.workflow.compile()
self._checkpointer_ctx = None
def _get_provider_kwargs(self) -> Dict[str, Any]:
"""Get provider-specific kwargs for LLM client creation."""
@@ -259,23 +262,58 @@ class TradingAgentsGraph:
self.memory_log.batch_update_with_outcomes(updates)
def propagate(self, company_name, trade_date):
"""Run the trading agents graph for a company on a specific date."""
"""Run the trading agents graph for a company on a specific date.
When ``checkpoint_enabled`` is set in config, the graph is recompiled
with a per-ticker SqliteSaver so a crashed run can resume from the last
successful node on a subsequent invocation with the same ticker+date.
"""
self.ticker = company_name
# Resolve any pending log entries for this ticker before the pipeline runs.
# This adds the outcome + reflection from the previous run at zero latency cost.
# Resolve any pending memory-log entries for this ticker before the pipeline runs.
self._resolve_pending_entries(company_name)
# Initialize state — inject memory log context for PM
# Recompile with a checkpointer if the user opted in.
if self.config.get("checkpoint_enabled"):
self._checkpointer_ctx = get_checkpointer(
self.config["data_cache_dir"], company_name
)
saver = self._checkpointer_ctx.__enter__()
self.graph = self.workflow.compile(checkpointer=saver)
step = checkpoint_step(
self.config["data_cache_dir"], company_name, str(trade_date)
)
if step is not None:
logger.info(
"Resuming from step %d for %s on %s", step, company_name, trade_date
)
else:
logger.info("Starting fresh for %s on %s", company_name, trade_date)
try:
return self._run_graph(company_name, trade_date)
finally:
if self._checkpointer_ctx is not None:
self._checkpointer_ctx.__exit__(None, None, None)
self._checkpointer_ctx = None
self.graph = self.workflow.compile()
def _run_graph(self, company_name, trade_date):
"""Execute the graph and write the resulting state to disk and memory log."""
# Initialize state — inject memory log context for PM.
past_context = self.memory_log.get_past_context(company_name)
init_agent_state = self.propagator.create_initial_state(
company_name, trade_date, past_context=past_context
)
args = self.propagator.get_graph_args()
# Inject thread_id so same ticker+date resumes, different date starts fresh.
if self.config.get("checkpoint_enabled"):
tid = thread_id(company_name, str(trade_date))
args.setdefault("config", {}).setdefault("configurable", {})["thread_id"] = tid
if self.debug:
# Debug mode with tracing
trace = []
for chunk in self.graph.stream(init_agent_state, **args):
if len(chunk["messages"]) == 0:
@@ -283,26 +321,29 @@ class TradingAgentsGraph:
else:
chunk["messages"][-1].pretty_print()
trace.append(chunk)
final_state = trace[-1]
else:
# Standard mode without tracing
final_state = self.graph.invoke(init_agent_state, **args)
# Store current state for reflection
# Store current state for reflection.
self.curr_state = final_state
# Log state
# Log state to disk.
self._log_state(trade_date, final_state)
# Store decision for deferred reflection.
# Store decision for deferred reflection on the next same-ticker run.
self.memory_log.store_decision(
ticker=company_name,
trade_date=trade_date,
final_trade_decision=final_state["final_trade_decision"],
)
# Return decision and processed signal
# Clear checkpoint on successful completion to avoid stale state.
if self.config.get("checkpoint_enabled"):
clear_checkpoint(
self.config["data_cache_dir"], company_name, str(trade_date)
)
return final_state, self.process_signal(final_state["final_trade_decision"])
def _log_state(self, trade_date, final_state):