Skip to content

Commit f41dd51

Browse files
committed
add backend support for toleration lists.
clarify toleration json docs Signed-off-by: Humair Khan <[email protected]>
1 parent e21bbba commit f41dd51

File tree

4 files changed

+241
-11
lines changed

4 files changed

+241
-11
lines changed

backend/src/v2/driver/driver.go

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -805,19 +805,51 @@ func extendPodSpecPatch(
805805
if toleration != nil {
806806
k8sToleration := &k8score.Toleration{}
807807
if toleration.TolerationJson != nil {
808-
err := resolveK8sJsonParameter(ctx, opts, dag, pipeline, mlmd,
809-
toleration.GetTolerationJson(), inputParams, k8sToleration)
808+
resolvedParam, err := resolveInputParameter(ctx, dag, pipeline, opts, mlmd,
809+
toleration.GetTolerationJson(), inputParams)
810810
if err != nil {
811811
return fmt.Errorf("failed to resolve toleration: %w", err)
812812
}
813+
814+
// TolerationJson can be either a single toleration or list of tolerations
815+
// the field accepts both, and in both cases the tolerations are appended
816+
// to the total executor pod toleration list.
817+
var paramJSON []byte
818+
isSingleToleration := resolvedParam.GetStructValue() != nil
819+
isListToleration := resolvedParam.GetListValue() != nil
820+
if isSingleToleration {
821+
paramJSON, err = resolvedParam.GetStructValue().MarshalJSON()
822+
if err != nil {
823+
return err
824+
}
825+
var singleToleration k8score.Toleration
826+
if err = json.Unmarshal(paramJSON, &singleToleration); err != nil {
827+
return fmt.Errorf("failed to marshal single toleration to json: %w", err)
828+
}
829+
k8sTolerations = append(k8sTolerations, singleToleration)
830+
} else if isListToleration {
831+
paramJSON, err = resolvedParam.GetListValue().MarshalJSON()
832+
if err != nil {
833+
return err
834+
}
835+
var k8sTolerationsList []k8score.Toleration
836+
if err = json.Unmarshal(paramJSON, &k8sTolerationsList); err != nil {
837+
return fmt.Errorf("failed to marshal list toleration to json: %w", err)
838+
}
839+
k8sTolerations = append(k8sTolerations, k8sTolerationsList...)
840+
} else {
841+
return fmt.Errorf("encountered unexpected toleration proto value, "+
842+
"must be either struct or list type: %w", err)
843+
}
813844
} else {
814845
k8sToleration.Key = toleration.Key
815846
k8sToleration.Operator = k8score.TolerationOperator(toleration.Operator)
816847
k8sToleration.Value = toleration.Value
817848
k8sToleration.Effect = k8score.TaintEffect(toleration.Effect)
818849
k8sToleration.TolerationSeconds = toleration.TolerationSeconds
850+
k8sTolerations = append(k8sTolerations, *k8sToleration)
819851
}
820-
k8sTolerations = append(k8sTolerations, *k8sToleration)
852+
821853
}
822854
}
823855
podSpec.Tolerations = k8sTolerations

backend/src/v2/driver/driver_test.go

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ func Test_initPodSpecPatch_publishLogs(t *testing.T) {
511511
}
512512

513513
}
514+
514515
func Test_makeVolumeMountPatch(t *testing.T) {
515516
type args struct {
516517
pvcMount []*kubernetesplatform.PvcMount
@@ -2132,6 +2133,165 @@ func Test_extendPodSpecPatch_Tolerations(t *testing.T) {
21322133
}),
21332134
},
21342135
},
2136+
{
2137+
"Valid - toleration json - toleration list",
2138+
&kubernetesplatform.KubernetesExecutorConfig{
2139+
Tolerations: []*kubernetesplatform.Toleration{
2140+
{
2141+
TolerationJson: inputParamComponent("param_1"),
2142+
},
2143+
},
2144+
},
2145+
&k8score.PodSpec{
2146+
Containers: []k8score.Container{
2147+
{
2148+
Name: "main",
2149+
},
2150+
},
2151+
Tolerations: []k8score.Toleration{
2152+
{
2153+
Key: "key1",
2154+
Operator: "Equal",
2155+
Value: "value1",
2156+
Effect: "NoSchedule",
2157+
TolerationSeconds: int64Ptr(3601),
2158+
},
2159+
{
2160+
Key: "key2",
2161+
Operator: "Equal",
2162+
Value: "value2",
2163+
Effect: "NoSchedule",
2164+
TolerationSeconds: int64Ptr(3602),
2165+
},
2166+
{
2167+
Key: "key3",
2168+
Operator: "Equal",
2169+
Value: "value3",
2170+
Effect: "NoSchedule",
2171+
TolerationSeconds: int64Ptr(3603),
2172+
},
2173+
},
2174+
},
2175+
map[string]*structpb.Value{
2176+
"param_1": validListOfStructsOrPanic([]map[string]interface{}{
2177+
{
2178+
"key": "key1",
2179+
"operator": "Equal",
2180+
"value": "value1",
2181+
"effect": "NoSchedule",
2182+
"tolerationSeconds": 3601,
2183+
},
2184+
{
2185+
"key": "key2",
2186+
"operator": "Equal",
2187+
"value": "value2",
2188+
"effect": "NoSchedule",
2189+
"tolerationSeconds": 3602,
2190+
},
2191+
{
2192+
"key": "key3",
2193+
"operator": "Equal",
2194+
"value": "value3",
2195+
"effect": "NoSchedule",
2196+
"tolerationSeconds": 3603,
2197+
},
2198+
}),
2199+
},
2200+
},
2201+
{
2202+
"Valid - toleration json - list toleration & single toleration & constant toleration",
2203+
&kubernetesplatform.KubernetesExecutorConfig{
2204+
Tolerations: []*kubernetesplatform.Toleration{
2205+
{
2206+
TolerationJson: inputParamComponent("param_1"),
2207+
},
2208+
{
2209+
TolerationJson: inputParamComponent("param_2"),
2210+
},
2211+
{
2212+
Key: "key5",
2213+
Operator: "Equal",
2214+
Value: "value5",
2215+
Effect: "NoSchedule",
2216+
},
2217+
},
2218+
},
2219+
&k8score.PodSpec{
2220+
Containers: []k8score.Container{
2221+
{
2222+
Name: "main",
2223+
},
2224+
},
2225+
Tolerations: []k8score.Toleration{
2226+
{
2227+
Key: "key1",
2228+
Operator: "Equal",
2229+
Value: "value1",
2230+
Effect: "NoSchedule",
2231+
TolerationSeconds: int64Ptr(3601),
2232+
},
2233+
{
2234+
Key: "key2",
2235+
Operator: "Equal",
2236+
Value: "value2",
2237+
Effect: "NoSchedule",
2238+
TolerationSeconds: int64Ptr(3602),
2239+
},
2240+
{
2241+
Key: "key3",
2242+
Operator: "Equal",
2243+
Value: "value3",
2244+
Effect: "NoSchedule",
2245+
TolerationSeconds: int64Ptr(3603),
2246+
},
2247+
{
2248+
Key: "key4",
2249+
Operator: "Equal",
2250+
Value: "value4",
2251+
Effect: "NoSchedule",
2252+
TolerationSeconds: int64Ptr(3604),
2253+
},
2254+
{
2255+
Key: "key5",
2256+
Operator: "Equal",
2257+
Value: "value5",
2258+
Effect: "NoSchedule",
2259+
},
2260+
},
2261+
},
2262+
map[string]*structpb.Value{
2263+
"param_1": validListOfStructsOrPanic([]map[string]interface{}{
2264+
{
2265+
"key": "key1",
2266+
"operator": "Equal",
2267+
"value": "value1",
2268+
"effect": "NoSchedule",
2269+
"tolerationSeconds": 3601,
2270+
},
2271+
{
2272+
"key": "key2",
2273+
"operator": "Equal",
2274+
"value": "value2",
2275+
"effect": "NoSchedule",
2276+
"tolerationSeconds": 3602,
2277+
},
2278+
{
2279+
"key": "key3",
2280+
"operator": "Equal",
2281+
"value": "value3",
2282+
"effect": "NoSchedule",
2283+
"tolerationSeconds": 3603,
2284+
},
2285+
}),
2286+
"param_2": validValueStructOrPanic(map[string]interface{}{
2287+
"key": "key4",
2288+
"operator": "Equal",
2289+
"value": "value4",
2290+
"effect": "NoSchedule",
2291+
"tolerationSeconds": 3604,
2292+
}),
2293+
},
2294+
},
21352295
}
21362296
for _, tt := range tests {
21372297
t.Run(tt.name, func(t *testing.T) {
@@ -2611,6 +2771,18 @@ func Test_extendPodSpecPatch_GenericEphemeralVolume(t *testing.T) {
26112771
}
26122772
}
26132773

2774+
func validListOfStructsOrPanic(data []map[string]interface{}) *structpb.Value {
2775+
var listValues []*structpb.Value
2776+
for _, item := range data {
2777+
s, err := structpb.NewStruct(item)
2778+
if err != nil {
2779+
panic(err)
2780+
}
2781+
listValues = append(listValues, structpb.NewStructValue(s))
2782+
}
2783+
return structpb.NewListValue(&structpb.ListValue{Values: listValues})
2784+
}
2785+
26142786
func validValueStructOrPanic(data map[string]interface{}) *structpb.Value {
26152787
s, err := structpb.NewStruct(data)
26162788
if err != nil {

kubernetes_platform/python/kfp/kubernetes/toleration.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
from typing import Optional, Union
1615

1716
from google.protobuf import json_format
@@ -82,21 +81,46 @@ def add_toleration(
8281

8382

8483
def add_toleration_json(task: PipelineTask,
85-
toleration_json: Union[pipeline_channel.PipelineParameterChannel, dict]
84+
toleration_json: Union[pipeline_channel.PipelineParameterChannel, list, dict]
8685
):
87-
"""Add a Pod Toleration in the form of a JSON to a task.
86+
"""Add a Pod Toleration in the form of a Pipeline Input JSON to a task.
8887
8988
Args:
9089
task:
9190
Pipeline task.
9291
toleration_json:
93-
a toleration provided as dict or input parameter. Takes
94-
precedence over other key, operator, value, effect,
95-
and toleration_seconds.
92+
a toleration that is a pipeline input parameter.
93+
The input parameter must be of type dict or list.
94+
95+
If it is a dict, it must be a single toleration object.
96+
For example a pipeline input parameter in this case could be::
97+
{
98+
"key": "key1",
99+
"operator": "Equal",
100+
"value": "value1",
101+
"effect": "NoSchedule"
102+
}
96103
104+
If it is a list, it must be list of toleration objects.+
105+
For example a pipeline input parameter in this case could be:
106+
[
107+
{
108+
"key": "key1",
109+
"operator": "Equal",
110+
"value": "value1",
111+
"effect": "NoSchedule"
112+
},
113+
{
114+
"key": "key2",
115+
"operator": "Exists",
116+
"effect": "NoExecute"
117+
}
118+
]
97119
Returns:
98120
Task object with added toleration.
99121
"""
122+
if not isinstance(toleration_json, pipeline_channel.PipelineParameterChannel):
123+
raise TypeError("toleration_json must be a Pipeline Input Parameter.")
100124

101125
msg = common.get_existing_kubernetes_config_as_message(task)
102126
toleration = pb.Toleration()

kubernetes_platform/python/test/unit/test_tolerations.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def test_component_pipeline_input_one(self):
173173
# checks that a pipeline input for
174174
# tasks is supported
175175
@dsl.pipeline
176-
def my_pipeline(toleration_input: str):
176+
def my_pipeline(toleration_input: dict):
177177
task = comp()
178178
kubernetes.add_toleration_json(
179179
task,
@@ -204,7 +204,7 @@ def test_component_pipeline_input_two(self):
204204
# checks that multiple pipeline inputs for
205205
# different tasks are supported
206206
@dsl.pipeline
207-
def my_pipeline(toleration_input_1: str, toleration_input_2: str):
207+
def my_pipeline(toleration_input_1: dict, toleration_input_2: list):
208208
t1 = comp()
209209
kubernetes.add_toleration_json(
210210
t1,
@@ -254,6 +254,8 @@ def my_pipeline(toleration_input_1: str, toleration_input_2: str):
254254
}
255255
}
256256

257+
258+
257259
def test_component_upstream_input_one(self):
258260
# checks that upstream task input parameters
259261
# are supported

0 commit comments

Comments
 (0)