"""
Pre-warm the defeatbeta httpfs cache to a persistent directory.

By default the library caches to /tmp/defeatbeta/cache/, which is wiped on
reboot. This script redirects the cache to ~/.cache/defeatbeta/ and then does
a full SELECT * on every table in parallel, forcing all parquet blocks to be
fetched and stored.

Re-running is safe: already-warmed tables are skipped, and tables that failed
on a previous run are retried.

Usage:
    uv run python warmup_cache.py

After this runs once (~3-4 GB download), use persistent_cache.py in your
notebooks/scripts to read from the cache with no network:

    from persistent_cache import enable_persistent_cache
    enable_persistent_cache()

    from defeatbeta_api.data.ticker import Ticker
    t = Ticker("AAPL")
"""

import json
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

# ── Persistent cache location ────────────────────────────────────────────────

CACHE_DIR = Path.home() / ".cache" / "defeatbeta"
CACHE_DIR.mkdir(parents=True, exist_ok=True)
STATE_FILE = CACHE_DIR / "warmup_done.json"

# ── Redirect cache dir before the library touches it ────────────────────────
# The patch is installed before the client modules below are imported, so
# they see the persistent directory instead of the library default.
# NOTE(review): the replacement takes no arguments — confirm this matches the
# signature of the original validate_httpfs_cache_directory.

import defeatbeta_api.utils.util as _util  # noqa: E402

_util.validate_httpfs_cache_directory = lambda: str(CACHE_DIR)

from defeatbeta_api.client.duckdb_conf import Configuration  # noqa: E402
from defeatbeta_api.client.duckdb_client import get_duckdb_client  # noqa: E402
from defeatbeta_api.client.hugging_face_client import HuggingFaceClient  # noqa: E402
from defeatbeta_api.utils.const import tables  # noqa: E402

WORKERS = 2  # keep concurrent connections low to avoid HuggingFace 429s
MAX_ATTEMPTS = 3  # per-table attempts, on top of the http_retries configured below
RETRY_WAIT_S = 60  # linear back-off between attempts: 60s, 120s, ...

config = Configuration(
    # NOTE(review): the full warm-up is ~3-4 GB; confirm this 500 MB setting
    # is not a cap on total cache size, or warmed blocks may be evicted.
    cache_httpfs_disk_size=500 * 1024 * 1024,
    http_retries=10,
    http_retry_wait_ms=10_000,  # 10s base wait on 429/5xx
    http_retry_backoff=2.0,  # doubles each retry: 10s, 20s, 40s …
    http_timeout=180,
)
client = get_duckdb_client(config=config)
hf = HuggingFaceClient()


# ── Resume state ─────────────────────────────────────────────────────────────

def _load_done() -> set[str]:
    """Return the table names recorded as warmed by a previous run.

    A missing, unreadable or malformed state file is treated as "nothing
    done yet" so the warm-up simply starts from scratch.
    """
    try:
        return set(json.loads(STATE_FILE.read_text())["done"])
    except (OSError, ValueError, KeyError, TypeError):
        return set()


def _save_done(done: set[str]) -> None:
    """Persist *done* to STATE_FILE via a temp file swapped into place.

    Swapping a fully written temp file means a crash mid-write never leaves a
    truncated state file behind.
    """
    tmp = STATE_FILE.with_suffix(".tmp")
    tmp.write_text(json.dumps({"done": sorted(done)}))
    # Path.replace overwrites an existing target on every platform, whereas
    # Path.rename raises FileExistsError on Windows once STATE_FILE exists.
    tmp.replace(STATE_FILE)


already_done = _load_done()
# Only tables still in the current table list count as done; stale names from
# an older list would otherwise push the done counter past the total.
skipped = [t for t in tables if t in already_done]
todo = [t for t in tables if t not in already_done]


# ── Helpers ───────────────────────────────────────────────────────────────────

def cache_mb() -> float:
    """Return the total size of regular files directly inside CACHE_DIR, in MB (1e6 bytes)."""
    try:
        entries = list(CACHE_DIR.iterdir())
    except OSError:
        return 0.0
    total = 0
    for entry in entries:
        try:
            if entry.is_file():
                total += entry.stat().st_size
        except OSError:
            # A file may disappear between listing and stat() while downloads
            # are writing to the cache; skip it rather than report 0 MB.
            continue
    return total / 1e6


# ── Shared state for the live reporter ───────────────────────────────────────

_lock = threading.Lock()  # guards _in_flight and _done
_print_lock = threading.Lock()  # serialises terminal output across threads
_in_flight: set[str] = set()
_done: list[str] = list(skipped)  # pre-seed with already-completed tables
_stop_reporter = threading.Event()
_live_width = 0  # width of the live status line currently on screen (0 = none)


def _erase_live_locked() -> None:
    """Blank out the live status line. Caller must hold _print_lock."""
    global _live_width
    if _live_width:
        print("\r" + " " * _live_width + "\r", end="")
        _live_width = 0


def _emit(line: str) -> None:
    """Print *line* on its own row, first erasing any live status line."""
    with _print_lock:
        _erase_live_locked()
        print(line, flush=True)


def _show_live(line: str) -> None:
    """Overwrite the live status line in place (no newline)."""
    global _live_width
    with _print_lock:
        # Pad to the previous width so a shorter line fully hides a longer one.
        print("\r" + line.ljust(_live_width), end="", flush=True)
        _live_width = len(line)


def _clear_live() -> None:
    """Erase the live status line, leaving the cursor at column 0."""
    with _print_lock:
        _erase_live_locked()
        print(end="", flush=True)


def _reporter() -> None:
    """Refresh a one-line status (cache size, progress, in-flight tables) every 2s until stopped."""
    while not _stop_reporter.is_set():
        with _lock:
            n_done = len(_done)
            flying = ", ".join(sorted(_in_flight)) or "—"
            n_flight = len(_in_flight)
        mb = cache_mb()  # disk scan happens outside the lock
        _show_live(
            f" [live] cache={mb:>6.0f} MB | "
            f"done={n_done}/{len(tables)} | "
            f"in-flight({n_flight}): {flying}"
        )
        _stop_reporter.wait(timeout=2)
    _clear_live()


def _query_with_retries(table: str, sql: str):
    """Run *sql* via the shared client, retrying up to MAX_ATTEMPTS times.

    Returns whatever ``client.query`` returns (sized; callers take ``len()``).
    Re-raises the last exception once all attempts are exhausted.
    """
    for attempt in range(1, MAX_ATTEMPTS + 1):
        try:
            return client.query(sql)
        except Exception as exc:
            if attempt == MAX_ATTEMPTS:
                raise
            wait = RETRY_WAIT_S * attempt
            _emit(f" ! {table} failed (attempt {attempt}): {exc} — retrying in {wait}s")
            time.sleep(wait)
    raise RuntimeError("MAX_ATTEMPTS must be >= 1")


def fetch_table(table: str) -> tuple[str, int, float]:
    """Read every row of *table* so all of its parquet blocks land in the cache.

    On success the table is appended to the resume state file.

    Returns:
        ``(table, row_count, elapsed_seconds)``; elapsed includes retry waits.

    Raises:
        Exception: whatever the final failed query attempt raised.
    """
    url = hf.get_url_path(table)
    # Double any single quotes so the URL cannot break out of the SQL literal.
    sql = "SELECT * FROM '{}'".format(url.replace("'", "''"))
    with _lock:
        _in_flight.add(table)
    t0 = time.perf_counter()
    try:
        rows = len(_query_with_retries(table, sql))
    except BaseException:
        with _lock:
            _in_flight.discard(table)
        raise
    elapsed = time.perf_counter() - t0
    with _lock:
        _in_flight.discard(table)
        _done.append(table)
        _save_done(set(_done))
    return table, rows, elapsed


# ── Run ───────────────────────────────────────────────────────────────────────

def main() -> int:
    """Warm every pending table, print a summary, and return the process exit code."""
    print(f"Cache dir : {CACHE_DIR}")
    print(f"Tables : {len(todo)} to fetch ({len(skipped)} already done, skipping)")
    print(f"Workers : {WORKERS}")
    print()

    if not todo:
        print("All tables already warmed. Nothing to do.")
        return 0

    for name in sorted(skipped):
        print(f" - {name:<40} (skip)")
    print()

    reporter = threading.Thread(target=_reporter, daemon=True)
    reporter.start()

    t_start = time.perf_counter()
    results: list[tuple[str, int, float]] = []
    failures: list[tuple[str, BaseException]] = []
    try:
        with ThreadPoolExecutor(max_workers=WORKERS) as pool:
            futures = {pool.submit(fetch_table, t): t for t in todo}
            for future in as_completed(futures):
                try:
                    table, rows, elapsed = future.result()
                except Exception as exc:
                    # Keep going: one bad table must not hide the others' results.
                    failed = futures[future]
                    failures.append((failed, exc))
                    _emit(f" ✗ {failed:<40} FAILED: {exc}")
                    continue
                results.append((table, rows, elapsed))
                with _lock:
                    n_done = len(_done)
                mb = cache_mb()
                _emit(
                    f" ✓ {table:<40} {rows:>9,} rows {elapsed:5.1f}s "
                    f"[{n_done}/{len(tables)} cache={mb:.0f} MB]"
                )
    finally:
        _stop_reporter.set()
        reporter.join()

    total_elapsed = time.perf_counter() - t_start
    total_rows = sum(r for _, r, _ in results)
    print()
    print(f"Finished in {total_elapsed:.0f}s")
    print(f"Fetched : {total_rows:,} rows across {len(results)} tables")
    print(f"Cache size : {cache_mb():.0f} MB → {CACHE_DIR}")
    if failures:
        names = ", ".join(sorted(name for name, _ in failures))
        print(f"Failed : {len(failures)} table(s), re-run to retry: {names}")
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())