Skip to content

[Feat] [Snowflake Provider] Lazy import of snowflake modules in snowflake/hooks/snowflake.py #62362

@dwreeves

Description

@dwreeves

Description

Instead of importing snowflake modules in the global scope, I am proposing these should be lazy-loaded instead.

Use case/motivation

Motivation

Snowflake modules take up a lot of resources. Lazy-loading these modules puts less strain on the scheduler and DAG processing in a lot of end-user code.

Because SnowflakeSqlApiHook is a subclass of SnowflakeHook, and the operators module imports the SnowflakeSqlApiHook globally, this improves the load time of effectively every import from airflow.providers.snowflake and puts significantly less stress on the scheduler instance when parsing the DAG.

Benchmarking

Here is a script to measure the impact of lazy-loading Snowflake.

Output on my laptop:

─────────────────────────── Snowflake Hook — Import Cost Benchmark ───────────────────────────
  Python: 3.13.1 (main, Jan 14 2025, 23:48:54) [Clang 19.1.6 ]
  Runs per scenario: 10  |  also benchmark w/o .pyc: no

            avg ms (±stdev)  |  avg RSS MB  —  10 runs each             
╭───────────────────────────────┬───────────────────┬──────────────────╮
│ Scenario                      │ Time w/ .pyc (ms) │ RSS w/ .pyc (MB) │
├───────────────────────────────┼───────────────────┼──────────────────┤
│ No Snowflake, no SQLAlchemy   │             60 ±2 │             17.2 │
│ No Snowflake, yes SQLAlchemy  │     121 ±1  (+61) │    31.0  (+13.9) │
│ Yes Snowflake, yes SQLAlchemy │    291 ±5  (+231) │    59.8  (+42.6) │
╰───────────────────────────────┴───────────────────┴──────────────────╯
  (+N) = delta vs first scenario  |  ±N = stdev across runs

The relevant delta is between rows 2 and 3, as I don't believe SQLAlchemy is ever truly avoidable in the scheduler instance. Regardless, you can see that the cost of importing Snowflake is quite significant: an extra +170 ms and nearly +30 MB of memory.

Here is the script used to run these benchmarks. Note: the script is AI generated/assisted. It is 100% standalone and can be run with `uv run file.py --help`.

#!/usr/bin/env python3
# NOTE: AI generated script.
# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "rich-click",
#   "psutil",
#   "rich",
#   "requests",
#   "tenacity",
#   "cryptography",
#   "snowflake-connector-python",
#   "snowflake-sqlalchemy",
#   "sqlalchemy",
# ]
# ///
"""
Measures the import cost of airflow/providers/snowflake/hooks/snowflake.py
with and without the snowflake-specific imports, across N cold-cache runs.

Optionally also benchmarks with .pyc bytecode caches enabled (--pyc).

Usage:
    uv run benchmark_snowflake_imports.py
    uv run benchmark_snowflake_imports.py -n 20
    uv run benchmark_snowflake_imports.py --pyc
    uv run benchmark_snowflake_imports.py --pyc -n 5
"""
from __future__ import annotations

import os
import shutil
import statistics
import subprocess
import sys
import tempfile
from dataclasses import dataclass, field

import rich_click as click
import psutil
from rich import box
from rich.console import Console
from rich.table import Table
from rich.text import Text

# Shared Rich console used for all terminal output in this script.
console = Console()

# All imports from snowflake.py, minus the snowflake-specific ones
WITHOUT_SNOWFLAKE = """\
import base64
import os
from collections.abc import Callable, Iterable, Mapping
from contextlib import closing, contextmanager
from datetime import datetime, timedelta
from functools import cached_property
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar, overload
from urllib.parse import urlparse
import requests
import tenacity
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from requests.auth import HTTPBasicAuth
from requests.exceptions import ConnectionError, HTTPError, Timeout
"""

# Baseline plus SQLAlchemy, which the hook also pulls in at import time.
WITHOUT_SNOWFLAKE_WITH_SQLALCHEMY = WITHOUT_SNOWFLAKE + """\
from sqlalchemy import create_engine
"""

# Full import set, including the snowflake-specific modules under test.
WITH_SNOWFLAKE = WITHOUT_SNOWFLAKE + """\
from snowflake import connector
from snowflake.connector import DictCursor, SnowflakeConnection, util_text
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine
"""

# (label, import block) pairs benchmarked in order; the first entry is the
# baseline against which later rows' deltas are reported.
SCENARIOS = [
    ("No Snowflake, no SQLAlchemy", WITHOUT_SNOWFLAKE),
    ("No Snowflake, yes SQLAlchemy", WITHOUT_SNOWFLAKE_WITH_SQLALCHEMY),
    ("Yes Snowflake, yes SQLAlchemy", WITH_SNOWFLAKE),
]


@dataclass
class RunStats:
    """Raw per-run measurements for one scenario, plus summary statistics."""

    times_ms: list[float] = field(default_factory=list)
    rss_mb: list[float] = field(default_factory=list)

    @property
    def avg_ms(self) -> float:
        """Mean import time across runs, in milliseconds."""
        return statistics.mean(self.times_ms)

    @property
    def stdev_ms(self) -> float:
        """Sample stdev of the import times; 0.0 when fewer than two runs."""
        if len(self.times_ms) > 1:
            return statistics.stdev(self.times_ms)
        return 0.0

    @property
    def avg_rss(self) -> float:
        """Mean RSS delta across runs, in megabytes."""
        return statistics.mean(self.rss_mb)


@dataclass
class ScenarioResult:
    """Aggregated benchmark results for one import scenario."""

    label: str  # human-readable scenario name shown in the results table
    with_pyc: RunStats  # runs with warm .pyc bytecode caches
    without_pyc: RunStats | None = None  # cold-cache runs; populated only with --no-pyc


# Code executed in each child interpreter. `{indented}` is substituted with the
# (already-indented) import block via str.format; the doubled braces survive
# .format() as literal braces in the child's f-string. On success the child
# prints "<elapsed_seconds> <rss_delta_mb>" to stdout; on failure it writes the
# error to stderr and exits 1.
SUBPROCESS_CODE_TEMPLATE = """\
import time, os, sys
import psutil

proc = psutil.Process(os.getpid())
rss_before = proc.memory_info().rss
t0 = time.perf_counter()
try:
{indented}
    elapsed = time.perf_counter() - t0
    rss_after = proc.memory_info().rss
    print(f"{{elapsed:.6f}} {{(rss_after - rss_before) / 1024 / 1024:.4f}}")
except Exception as e:
    print(f"ERROR: {{e}}", file=sys.stderr)
    sys.exit(1)
"""


def _single_run(import_block: str, use_pyc: bool) -> tuple[float, float]:
    """Spawn a fresh interpreter, execute *import_block*, and measure it.

    Args:
        import_block: Newline-separated import statements to benchmark.
        use_pyc: When False, redirect the bytecode cache to a fresh empty
            directory so every module must be compiled from source.

    Returns:
        Tuple of ``(elapsed_ms, rss_mb)`` as reported by the child process.

    Exits the program (status 1) if the child fails.
    """
    indented = "\n".join("    " + line for line in import_block.splitlines())
    code = SUBPROCESS_CODE_TEMPLATE.format(indented=indented)

    env = None
    tmp_dir = None
    if not use_pyc:
        # Redirect bytecode cache to a fresh empty dir so Python finds no .pyc
        # files and must compile every module from source.
        tmp_dir = tempfile.mkdtemp(prefix="no_pyc_bench_")
        env = {**os.environ, "PYTHONPYCACHEPREFIX": tmp_dir}

    try:
        result = subprocess.run(
            [sys.executable, "-c", code],
            capture_output=True,
            text=True,
            env=env,
        )
    finally:
        if tmp_dir:
            shutil.rmtree(tmp_dir, ignore_errors=True)

    if result.returncode != 0:
        console.print(f"[red]ERROR:[/] {result.stderr.strip()}")
        sys.exit(1)

    # Parse only the LAST stdout line: some imported modules emit stray output
    # (warnings, banners) to stdout, which would otherwise break the split.
    elapsed_s, rss_mb = result.stdout.strip().splitlines()[-1].split()
    return float(elapsed_s) * 1000, float(rss_mb)


def measure(label: str, import_block: str, n: int, include_no_pyc: bool) -> ScenarioResult:
    """Benchmark one import scenario over *n* runs and collect its statistics."""
    with_pyc = RunStats()
    without_pyc = RunStats() if include_no_pyc else None

    # Warm-up run to ensure .pyc files are written before timed runs.
    with console.status(f"  [italic]{label}[/]  warming up .pyc cache…", spinner="dots"):
        _single_run(import_block, use_pyc=True)

    def _collect(stats: RunStats, use_pyc: bool, suffix: str) -> None:
        # One timed subprocess per iteration; results accumulate into *stats*.
        for run_idx in range(n):
            with console.status(f"  [italic]{label}[/]  run {run_idx + 1}/{n} ({suffix})…", spinner="dots"):
                ms, mb = _single_run(import_block, use_pyc=use_pyc)
            stats.times_ms.append(ms)
            stats.rss_mb.append(mb)

    _collect(with_pyc, True, "w/ .pyc")
    if without_pyc is not None:
        _collect(without_pyc, False, "w/o .pyc")

    return ScenarioResult(label=label, with_pyc=with_pyc, without_pyc=without_pyc)


def _time_cell(stats: RunStats, baseline_ms: float) -> Text:
    """Render an average-time table cell with stdev and delta-vs-baseline markup."""
    cell = Text(f"{stats.avg_ms:.0f}", style="white")
    cell.append(f" ±{stats.stdev_ms:.0f}", style="dim")
    delta = stats.avg_ms - baseline_ms
    # The baseline row itself gets no "(+N)" suffix.
    if stats.avg_ms != baseline_ms:
        cell.append(f"  (+{delta:.0f})", style="dim cyan")
    return cell


def _rss_cell(stats: RunStats, baseline_mb: float) -> Text:
    """Render an average-RSS table cell with delta-vs-baseline markup."""
    cell = Text(f"{stats.avg_rss:.1f}", style="white")
    # The baseline row itself gets no "(+N)" suffix.
    if stats.avg_rss != baseline_mb:
        cell.append(f"  (+{stats.avg_rss - baseline_mb:.1f})", style="dim cyan")
    return cell


@click.command()
@click.option(
    "-n", "--runs",
    default=10,
    show_default=True,
    help="Number of cold-cache subprocess runs to average per scenario.",
)
@click.option(
    "--no-pyc",
    is_flag=True,
    default=False,
    help="Also run benchmarks without .pyc bytecode caches.",
)
@click.rich_config({"theme": "cargo-slim"})
def main(runs: int, no_pyc: bool) -> None:
    """Run every scenario and print a summary table of the results."""
    console.print()
    console.rule("[bold cyan]Snowflake Hook — Import Cost Benchmark[/]")
    console.print(f"  [dim]Python: {sys.version}[/]")
    console.print(f"  [dim]Runs per scenario: {runs}  |  also benchmark w/o .pyc: {'yes' if no_pyc else 'no'}[/]")

    console.print()

    results = []
    for label, import_block in SCENARIOS:
        results.append(measure(label, import_block, runs, no_pyc))

    table = Table(
        box=box.ROUNDED,
        show_footer=False,
        title_style="bold cyan",
        border_style="grey50",
        title=f"avg ms (±stdev)  |  avg RSS MB  —  {runs} runs each",
    )
    table.add_column("Scenario", style="white", no_wrap=True)
    for heading in ("Time w/ .pyc (ms)", "RSS w/ .pyc (MB)"):
        table.add_column(heading, justify="right")
    if no_pyc:
        for heading in ("Time w/o .pyc (ms)", "RSS w/o .pyc (MB)"):
            table.add_column(heading, justify="right")

    # Deltas in every row are measured against the first scenario.
    first = results[0]
    baseline_pyc_ms = first.with_pyc.avg_ms
    baseline_pyc_mb = first.with_pyc.avg_rss
    if no_pyc:
        baseline_nopyc_ms = first.without_pyc.avg_ms
        baseline_nopyc_mb = first.without_pyc.avg_rss
    else:
        baseline_nopyc_ms = baseline_nopyc_mb = 0.0

    for res in results:
        cells = [
            res.label,
            _time_cell(res.with_pyc, baseline_pyc_ms),
            _rss_cell(res.with_pyc, baseline_pyc_mb),
        ]
        if no_pyc and res.without_pyc is not None:
            cells.append(_time_cell(res.without_pyc, baseline_nopyc_ms))
            cells.append(_rss_cell(res.without_pyc, baseline_nopyc_mb))
        table.add_row(*cells)

    console.print(table)
    console.print("  [dim](+N) = delta vs first scenario  |  ±N = stdev across runs[/]")
    console.print()


if __name__ == "__main__":
    main()

Related issues

No response

Are you willing to submit a PR?

  • Yes I am willing to submit a PR!

Code of Conduct

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions