Skip to content

Commit 96e2df2

Browse files
RobotGF and claude
authored and committed
[style] Apply ruff lint/format fixes and resolve mypy type errors
Auto-fixed by pre-commit: remove unused imports, Optional -> X | None, zip() strict=False, line length formatting. Also add proper type annotation for _metrics in SimpleStorageUnit to fix mypy errors. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bb95462 commit 96e2df2

4 files changed

Lines changed: 39 additions & 35 deletions

File tree

tests/test_metrics.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -132,20 +132,21 @@ def test_collect_with_partitions(self):
132132

133133
# Check partition-level gauges
134134
assert exporter.partition_samples.labels(partition_id="train_0")._value.get() == 20
135-
assert exporter.partition_production_progress.labels(
136-
partition_id="train_0", task_name="gen"
137-
)._value.get() == 0.8
138-
assert exporter.partition_consumption_progress.labels(
139-
partition_id="train_0", task_name="gen"
140-
)._value.get() == 0.5
135+
assert (
136+
exporter.partition_production_progress.labels(partition_id="train_0", task_name="gen")._value.get() == 0.8
137+
)
138+
assert (
139+
exporter.partition_consumption_progress.labels(partition_id="train_0", task_name="gen")._value.get() == 0.5
140+
)
141141

142142
assert exporter.partition_samples.labels(partition_id="train_1")._value.get() == 10
143-
assert exporter.partition_production_progress.labels(
144-
partition_id="train_1", task_name="gen"
145-
)._value.get() == 1.0
146-
assert exporter.partition_consumption_progress.labels(
147-
partition_id="train_1", task_name="train"
148-
)._value.get() == 0.3
143+
assert (
144+
exporter.partition_production_progress.labels(partition_id="train_1", task_name="gen")._value.get() == 1.0
145+
)
146+
assert (
147+
exporter.partition_consumption_progress.labels(partition_id="train_1", task_name="train")._value.get()
148+
== 0.3
149+
)
149150

150151
def test_uptime_increases(self):
151152
"""Controller uptime should be positive after collection."""

transfer_queue/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import subprocess
1919
import time
2020
from importlib import resources
21-
from typing import Any, Callable, Optional
21+
from typing import Any, Callable
2222
from urllib.parse import urlparse
2323

2424
import ray
@@ -384,7 +384,7 @@ def close():
384384

385385

386386
# ==================== Metrics API ====================
387-
def get_metrics_endpoint() -> Optional[str]:
387+
def get_metrics_endpoint() -> str | None:
388388
"""Return the Prometheus metrics endpoint address (``host:port``), or *None* if metrics are disabled.
389389
390390
Works from any process — the endpoint is stored in the Controller's config

transfer_queue/metrics.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import time
1818
from contextlib import contextmanager
1919
from threading import Thread
20-
from typing import Any, Optional
20+
from typing import Any
2121
from uuid import uuid4
2222

2323
import psutil
@@ -334,25 +334,23 @@ def collect_storage_metrics(self) -> None:
334334
active = metrics.get("active_keys", 0)
335335
self.storage_capacity.labels(storage_unit_id=label).set(capacity)
336336
self.storage_active_keys.labels(storage_unit_id=label).set(active)
337-
self.storage_utilization.labels(storage_unit_id=label).set(
338-
active / capacity if capacity > 0 else 0.0
339-
)
337+
self.storage_utilization.labels(storage_unit_id=label).set(active / capacity if capacity > 0 else 0.0)
340338
self.storage_memory_rss.labels(storage_unit_id=label).set(metrics.get("process_rss_bytes", 0))
341339

342340
# Per-operation request stats
343341
for op_type, op_data in metrics.get("op_stats", {}).items():
344-
self.storage_request_ops.labels(
345-
storage_unit_id=label, op_type=op_type
346-
).set(op_data.get("request_count", 0))
347-
self.storage_request_latency_avg.labels(
348-
storage_unit_id=label, op_type=op_type
349-
).set(op_data.get("latency_avg", 0))
350-
self.storage_request_latency_p50.labels(
351-
storage_unit_id=label, op_type=op_type
352-
).set(op_data.get("latency_p50", 0))
353-
self.storage_request_latency_p99.labels(
354-
storage_unit_id=label, op_type=op_type
355-
).set(op_data.get("latency_p99", 0))
342+
self.storage_request_ops.labels(storage_unit_id=label, op_type=op_type).set(
343+
op_data.get("request_count", 0)
344+
)
345+
self.storage_request_latency_avg.labels(storage_unit_id=label, op_type=op_type).set(
346+
op_data.get("latency_avg", 0)
347+
)
348+
self.storage_request_latency_p50.labels(storage_unit_id=label, op_type=op_type).set(
349+
op_data.get("latency_p50", 0)
350+
)
351+
self.storage_request_latency_p99.labels(storage_unit_id=label, op_type=op_type).set(
352+
op_data.get("latency_p99", 0)
353+
)
356354
except Exception as e:
357355
logger.warning(f"Failed to collect metrics from storage unit {su_id}: {e}")
358356

@@ -375,7 +373,7 @@ def _get_or_create_socket(self, su_id: str, su_info: ZMQServerInfo) -> zmq.Socke
375373
self._zmq_sockets[su_id] = sock
376374
return sock
377375

378-
def _query_storage_unit(self, su_info: ZMQServerInfo, su_id: str) -> Optional[dict[str, Any]]:
376+
def _query_storage_unit(self, su_info: ZMQServerInfo, su_id: str) -> dict[str, Any] | None:
379377
"""Send a synchronous GET_METRICS request to a single storage unit."""
380378
try:
381379
sock = self._get_or_create_socket(su_id, su_info)

transfer_queue/storage/simple_storage.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import time
1818
import weakref
1919
from threading import Event, Thread
20-
from typing import Any
20+
from typing import TYPE_CHECKING, Any
2121
from uuid import uuid4
2222

2323
import psutil
@@ -38,6 +38,9 @@
3838
get_node_ip_address,
3939
)
4040

41+
if TYPE_CHECKING:
42+
from transfer_queue.metrics import TQMetricsExporter
43+
4144
logger = get_logger(__name__)
4245

4346
TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5)) # in seconds
@@ -172,7 +175,7 @@ def __init__(self, storage_unit_size: int):
172175
self.proxy_thread: Thread | None = None
173176
self.worker_thread: Thread | None = None
174177

175-
self._metrics = None
178+
self._metrics: TQMetricsExporter | None = None
176179

177180
self._init_zmq_socket()
178181
self._start_process_put_get()
@@ -559,9 +562,11 @@ def _quantile_from_cumulative(hist, cumulative_counts: list[float], q: float) ->
559562
target = q * total
560563
prev_bound = 0.0
561564
prev_cumulative = 0.0
562-
for bound, cum_count in zip(hist._upper_bounds, cumulative_counts):
565+
for bound, cum_count in zip(hist._upper_bounds, cumulative_counts, strict=False):
563566
if cum_count >= target:
564-
fraction = (target - prev_cumulative) / (cum_count - prev_cumulative) if cum_count > prev_cumulative else 0
567+
fraction = (
568+
(target - prev_cumulative) / (cum_count - prev_cumulative) if cum_count > prev_cumulative else 0
569+
)
565570
return prev_bound + (bound - prev_bound) * fraction
566571
prev_bound = bound
567572
prev_cumulative = cum_count

0 commit comments

Comments (0)