1010)
1111logger = logging .getLogger (__name__ )
1212
# Tuning thresholds for the memory upward-gradient (possible leak) check
# implemented by detect_memory_upward_gradient / _split_memory_segments.
#
# Number of leading readings dropped from the series before any analysis.
13+ MEMORY_GRADIENT_WARMUP_STEPS = 5
# Minimum number of readings a segment must contain to be analysed; shorter
# segments are discarded by the splitter and skipped by the detector.
14+ MEMORY_GRADIENT_MIN_SEGMENT_LEN = 8
# Minimum fraction of consecutive-reading deltas that must be positive
# (compared against a small epsilon in the detector) for a segment to count
# as "mostly increasing".
15+ MEMORY_GRADIENT_POSITIVE_RATIO = 0.65
# Minimum slope of the degree-1 fit over a segment for it to count as rising
# (reported by the detector as GB/step).
16+ MEMORY_GRADIENT_MIN_SLOPE_GB = 1e-4
# Minimum relative drift, computed as slope * (segment length - 1) / segment
# mean, for a segment to be flagged.
17+ MEMORY_GRADIENT_MIN_REL_DRIFT = 0.00015
# A reading that is at least this much lower than the previous one starts a
# new segment. NOTE(review): the name suggests such drops correspond to a
# resumed run -- confirm with the author.
18+ MEMORY_GRADIENT_RESUME_DROP_GB = 0.005
19+
1320
1421def extract_value (file , metrics ):
1522 metric_all = {metric : [] for metric in metrics }
@@ -25,7 +32,57 @@ def extract_value(file, metrics):
2532 return total_step , metric_all
2633
2734
28- def check_result (case_name , base_path , cur_path , check_metric ):
def _split_memory_segments(values: np.ndarray) -> list[np.ndarray]:
    """Cut a memory series into runs separated by sharp drops.

    A boundary is placed in front of every reading that lies at least
    ``MEMORY_GRADIENT_RESUME_DROP_GB`` below its predecessor. Runs with fewer
    than ``MEMORY_GRADIENT_MIN_SEGMENT_LEN`` readings are thrown away.

    Args:
        values: The series to split.

    Returns:
        The surviving runs in their original order. When the input itself is
        shorter than the minimum segment length, or when no run survives, the
        whole series is returned as the only segment.
    """
    total = len(values)
    if total < MEMORY_GRADIENT_MIN_SEGMENT_LEN:
        return [values]

    # Indices at which a new run begins; the series end closes the last run.
    boundaries = [
        pos
        for pos in range(1, total)
        if values[pos - 1] - values[pos] >= MEMORY_GRADIENT_RESUME_DROP_GB
    ]
    boundaries.append(total)

    kept: list[np.ndarray] = []
    run_start = 0
    for run_end in boundaries:
        if run_end - run_start >= MEMORY_GRADIENT_MIN_SEGMENT_LEN:
            kept.append(values[run_start:run_end])
        # A too-short run is dropped, but the next run still starts here.
        run_start = run_end

    if not kept:
        # Every run was too short: fall back to analysing the full series.
        return [values]
    return kept
50+
51+
def detect_memory_upward_gradient(values: list[float]) -> tuple[bool, str]:
    """Report whether the run's memory readings keep climbing (possible leak).

    The first ``MEMORY_GRADIENT_WARMUP_STEPS`` readings are ignored. The rest
    is split by ``_split_memory_segments`` and every segment is judged on its
    own. A segment is flagged only when all three conditions hold:

    * the slope of a degree-1 least-squares fit exceeds
      ``MEMORY_GRADIENT_MIN_SLOPE_GB``;
    * at least ``MEMORY_GRADIENT_POSITIVE_RATIO`` of the deltas between
      consecutive readings are larger than 1e-4;
    * the fitted rise over the segment, divided by the segment mean, exceeds
      ``MEMORY_GRADIENT_MIN_REL_DRIFT``.

    Args:
        values: Memory readings of the current run, in step order.

    Returns:
        ``(True, description)`` for the first flagged segment, otherwise
        ``(False, "")`` -- including when there are too few readings to
        analyse.
    """
    required = MEMORY_GRADIENT_WARMUP_STEPS + MEMORY_GRADIENT_MIN_SEGMENT_LEN
    if len(values) <= required:
        return False, ""

    post_warmup = np.asarray(values[MEMORY_GRADIENT_WARMUP_STEPS:], dtype=float)

    for seg_idx, segment in enumerate(_split_memory_segments(post_warmup)):
        n_points = len(segment)
        if n_points < MEMORY_GRADIENT_MIN_SEGMENT_LEN:
            # The splitter may hand back the whole (short) series as a fallback.
            continue

        positive_ratio = float(np.mean(np.diff(segment) > 1e-4))
        slope, _ = np.polyfit(np.arange(n_points), segment, 1)
        mean_val = float(np.mean(segment))
        if mean_val < 1e-10:
            # Avoid dividing by a (near-)zero or negative mean.
            continue

        # Total rise predicted by the fit, as a fraction of the mean level.
        relative_drift = float(slope * (n_points - 1) / mean_val)

        flagged = (
            slope > MEMORY_GRADIENT_MIN_SLOPE_GB
            and positive_ratio >= MEMORY_GRADIENT_POSITIVE_RATIO
            and relative_drift > MEMORY_GRADIENT_MIN_REL_DRIFT
        )
        if not flagged:
            continue

        description = (
            f"segment { seg_idx } : slope={ slope :.6f} GB/step, "
            f"relative_drift={ relative_drift :.4f} , positive_ratio={ positive_ratio :.2f} "
        )
        return True, description

    return False, ""
83+
84+
85+ def check_result (case_name , base_path , cur_path , check_metric , phase = None ):
2986 fail_metric = {}
3087 metric_list = list (check_metric .keys ())
3188 base_steps , base_metrics = extract_value (base_path , metric_list )
@@ -34,28 +91,57 @@ def check_result(case_name, base_path, cur_path, check_metric):
3491 f"current steps is not equal to base steps, current steps: { cur_steps } , base steps: { base_steps } "
3592 )
3693
37- publish_comparison_report (case_name , check_metric , base_metrics , cur_metrics , base_path , cur_path )
94+ publish_comparison_report (case_name , check_metric , base_metrics , cur_metrics , base_path , cur_path , phase = phase )
3895
3996 for metric , threshold in check_metric .items ():
4097 max_error = 0.0
4198 max_error_idx = 0
4299 check_flag = True
43100 if metric == "runtime_info/tgs" :
44101 if cur_steps > 10 :
45- relative_errors = abs (np .array (base_metrics [metric ][10 :- 1 ]) - np .array (cur_metrics [metric ][10 :- 1 ])) / (
46- np .array (base_metrics [metric ][10 :- 1 ])
102+ base_vals = np .array (base_metrics [metric ][10 :- 1 ], dtype = float )
103+ cur_vals = np .array (cur_metrics [metric ][10 :- 1 ], dtype = float )
104+ degradation = np .zeros_like (base_vals , dtype = float )
105+ valid_base = np .abs (base_vals ) >= 1e-10
106+ degradation [valid_base ] = np .maximum (
107+ (base_vals [valid_base ] - cur_vals [valid_base ]) / np .abs (base_vals [valid_base ]),
108+ 0.0 ,
47109 )
48- max_error = np .percentile (relative_errors , 80 )
110+ max_error = float ( np .percentile (degradation , 80 ) )
49111 if max_error > threshold :
50112 fail_metric [metric ] = (
51- f"{ metric } relative error bigger than { threshold } after 10 step, baseline: { base_metrics [metric ][10 :- 1 ]} , now: { cur_metrics [metric ][10 :- 1 ]} , relative error: { relative_errors } "
113+ f"{ metric } degradation bigger than { threshold } after step 10, "
114+ f"baseline: { base_metrics [metric ][10 :- 1 ]} , now: { cur_metrics [metric ][10 :- 1 ]} , "
115+ f"degradation: { degradation .tolist ()} "
52116 )
53117 check_flag = False
54118 else :
55119 check_flag = True
56120 else :
57121 logger .warning ("It's meaningless to compare tgs because of the small steps." )
58122 check_flag = False
123+ elif metric == "memory/max_memory_GB" :
124+ for idx , (old , cur ) in enumerate (zip (base_metrics [metric ], cur_metrics [metric ])):
125+ if abs (old ) < 1e-10 :
126+ relative_error = float ("inf" ) if abs (cur ) > 1e-10 else 0.0
127+ else :
128+ relative_error = round (abs (old - cur ) / abs (old ), 2 )
129+ if relative_error > max_error :
130+ max_error = relative_error
131+ max_error_idx = idx
132+ if relative_error > threshold :
133+ fail_metric [metric ] = (
134+ f"{ metric } relative error bigger than { threshold } in { idx } steps, "
135+ f"baseline: { old :.6f} , now: { cur :.6f} , relative error: { relative_error } "
136+ )
137+ check_flag = False
138+ break
139+
140+ if check_flag :
141+ has_gradient , gradient_info = detect_memory_upward_gradient (cur_metrics [metric ])
142+ if has_gradient :
143+ fail_metric [metric ] = f"{ metric } shows sustained upward gradient in current run, { gradient_info } "
144+ check_flag = False
59145 else :
60146 for idx , (old , cur ) in enumerate (zip (base_metrics [metric ], cur_metrics [metric ])):
61147 if abs (old ) < 1e-10 :
@@ -82,7 +168,7 @@ def check_result(case_name, base_path, cur_path, check_metric):
82168 return result , f"Some metric check failed: { fail_metric } "
83169
84170
85- def check_rl_result (case_name , base_path , cur_path , assert_info ):
171+ def check_rl_result (case_name , base_path , cur_path , assert_info , phase = None ):
86172 fail_metric = {}
87173 check_metrics_list = assert_info ["check_metrics" ]
88174
@@ -96,7 +182,9 @@ def check_rl_result(case_name, base_path, cur_path, assert_info):
96182 )
97183
98184 check_metric_dict = {item ["metric" ]: item ["threshold" ] for item in check_metrics_list }
99- publish_comparison_report (case_name , check_metric_dict , base_metrics , cur_metrics , base_path , cur_path )
185+ publish_comparison_report (
186+ case_name , check_metric_dict , base_metrics , cur_metrics , base_path , cur_path , phase = phase
187+ )
100188
101189 for config in check_metrics_list :
102190 metric = config ["metric" ]
0 commit comments