Skip to content

Commit 550cd46

Browse files
Adding checkpointing improvement
Signed-off-by: Vishesh Tanksale <vtanksale@nvidia.com>
1 parent 7690628 commit 550cd46

6 files changed

Lines changed: 627 additions & 42 deletions

File tree

cmd/compute-domain-kubelet-plugin/checkpoint.go

Lines changed: 114 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@ import (
2323
"k8s.io/kubernetes/pkg/kubelet/checkpointmanager/checksum"
2424
)
2525

26+
const CheckpointVersion = "v2"
27+
2628
type Checkpoint struct {
29+
// Version records the latest checkpoint version written, allowing readers
30+
// to quickly determine the format without inspecting nested fields.
31+
Version string `json:"version,omitempty"`
2732
// Note: The Checksum below is only associated with the V1 checkpoint
2833
// (because it doesn't have an embedded one). All future versions have
2934
// their checksum directly embedded in them to better support
@@ -32,6 +37,95 @@ type Checkpoint struct {
3237
Checksum checksum.Checksum `json:"checksum"`
3338
V1 *CheckpointV1 `json:"v1,omitempty"`
3439
V2 *CheckpointV2 `json:"v2,omitempty"`
40+
// other holds unknown fields from a newer checkpoint format, preserved
41+
// so that a downgraded driver can round-trip data it does not understand.
42+
other map[string]json.RawMessage
43+
}
44+
45+
// MarshalJSON implements json.Marshaler, merging known fields with any
46+
// unknown fields captured from a newer checkpoint format.
47+
func (cp *Checkpoint) MarshalJSON() ([]byte, error) {
48+
type Alias struct {
49+
Version string `json:"version,omitempty"`
50+
Checksum checksum.Checksum `json:"checksum"`
51+
V1 *CheckpointV1 `json:"v1,omitempty"`
52+
V2 *CheckpointV2 `json:"v2,omitempty"`
53+
}
54+
known, err := json.Marshal(&Alias{
55+
Version: cp.Version,
56+
Checksum: cp.Checksum,
57+
V1: cp.V1,
58+
V2: cp.V2,
59+
})
60+
if err != nil {
61+
return nil, err
62+
}
63+
if len(cp.other) == 0 {
64+
return known, nil
65+
}
66+
var merged map[string]json.RawMessage
67+
if err := json.Unmarshal(known, &merged); err != nil {
68+
return nil, err
69+
}
70+
for k, v := range cp.other {
71+
merged[k] = v
72+
}
73+
return json.Marshal(merged)
74+
}
75+
76+
// UnmarshalJSON implements json.Unmarshaler, populating known fields and
77+
// preserving any unrecognised fields (future versions) in cp.other.
78+
func (cp *Checkpoint) UnmarshalJSON(data []byte) error {
79+
type Alias struct {
80+
Version string `json:"version,omitempty"`
81+
Checksum checksum.Checksum `json:"checksum"`
82+
V1 *CheckpointV1 `json:"v1,omitempty"`
83+
V2 *CheckpointV2 `json:"v2,omitempty"`
84+
}
85+
var alias Alias
86+
if err := json.Unmarshal(data, &alias); err != nil {
87+
return err
88+
}
89+
cp.Version = alias.Version
90+
cp.Checksum = alias.Checksum
91+
cp.V1 = alias.V1
92+
cp.V2 = alias.V2
93+
94+
var all map[string]json.RawMessage
95+
if err := json.Unmarshal(data, &all); err != nil {
96+
return err
97+
}
98+
delete(all, "version")
99+
delete(all, "checksum")
100+
delete(all, "v1")
101+
delete(all, "v2")
102+
if len(all) > 0 {
103+
cp.other = all
104+
} else {
105+
cp.other = nil
106+
}
107+
return nil
108+
}
109+
110+
func (cp *Checkpoint) DeepCopy() *Checkpoint {
111+
if cp == nil {
112+
return nil
113+
}
114+
out := &Checkpoint{
115+
Version: cp.Version,
116+
Checksum: cp.Checksum,
117+
V1: cp.V1.DeepCopy(),
118+
V2: cp.V2.DeepCopy(),
119+
}
120+
if len(cp.other) > 0 {
121+
out.other = make(map[string]json.RawMessage, len(cp.other))
122+
for k, v := range cp.other {
123+
raw := make(json.RawMessage, len(v))
124+
copy(raw, v)
125+
out.other[k] = raw
126+
}
127+
}
128+
return out
35129
}
36130

37131
func (cp *Checkpoint) ToLatestVersion() *Checkpoint {
@@ -52,25 +146,27 @@ func (cp *Checkpoint) ToLatestVersion() *Checkpoint {
52146

53147
func (cp *Checkpoint) MarshalCheckpoint() ([]byte, error) {
54148
cp = cp.ToLatestVersion()
149+
cp.Version = CheckpointVersion
55150
cp.V1 = cp.V2.ToV1()
56151
if err := cp.SetChecksumV1(); err != nil {
57152
return nil, fmt.Errorf("error setting v1 checksum: %v", err)
58153
}
59154
if err := cp.SetChecksumV2(); err != nil {
60155
return nil, fmt.Errorf("error setting v2 checksum: %v", err)
61156
}
62-
return json.Marshal(*cp)
157+
return json.Marshal(cp)
63158
}
64159

160+
// SetChecksumV1 computes and sets the V1 checksum, which covers only the
161+
// V1 view of the checkpoint (Version, V2, and other are excluded so that
162+
// older drivers computing the same checksum get identical JSON).
65163
func (cp *Checkpoint) SetChecksumV1() error {
66-
v2 := cp.V2
67-
cp.V2 = nil
68-
defer func() {
69-
cp.V2 = v2
70-
}()
71-
72-
cp.Checksum = 0
73-
out, err := json.Marshal(*cp)
164+
type v1View struct {
165+
Checksum checksum.Checksum `json:"checksum"`
166+
V1 *CheckpointV1 `json:"v1,omitempty"`
167+
}
168+
view := v1View{V1: cp.V1}
169+
out, err := json.Marshal(view)
74170
if err != nil {
75171
return err
76172
}
@@ -126,22 +222,19 @@ func (cp *Checkpoint) VerifyChecksum() error {
126222
return nil
127223
}
128224

225+
// VerifyChecksumV1 verifies the V1 checksum using the same V1-only view that
226+
// SetChecksumV1 used, ensuring older drivers can also verify successfully.
129227
func (cp *Checkpoint) VerifyChecksumV1() error {
130-
ck := cp.Checksum
131-
v2 := cp.V2
132-
cp.V2 = nil
133-
defer func() {
134-
cp.Checksum = ck
135-
cp.V2 = v2
136-
}()
137-
138-
cp.Checksum = 0
139-
out, err := json.Marshal(*cp)
228+
type v1View struct {
229+
Checksum checksum.Checksum `json:"checksum"`
230+
V1 *CheckpointV1 `json:"v1,omitempty"`
231+
}
232+
view := v1View{V1: cp.V1}
233+
out, err := json.Marshal(view)
140234
if err != nil {
141235
return err
142236
}
143-
144-
return ck.Verify(out)
237+
return cp.Checksum.Verify(out)
145238
}
146239

147240
func (cp *Checkpoint) VerifyChecksumV2() error {

cmd/compute-domain-kubelet-plugin/checkpointv.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,72 @@ type PreparedClaimV1 struct {
6565
PreparedDevices PreparedDevices `json:"preparedDevices,omitempty"`
6666
}
6767

68+
// DeepCopy methods
69+
70+
func (v1 *CheckpointV1) DeepCopy() *CheckpointV1 {
71+
if v1 == nil {
72+
return nil
73+
}
74+
return &CheckpointV1{PreparedClaims: v1.PreparedClaims.DeepCopy()}
75+
}
76+
77+
func (v2 *CheckpointV2) DeepCopy() *CheckpointV2 {
78+
if v2 == nil {
79+
return nil
80+
}
81+
return &CheckpointV2{
82+
Checksum: v2.Checksum,
83+
PreparedClaims: v2.PreparedClaims.DeepCopy(),
84+
}
85+
}
86+
87+
func (m PreparedClaimsByUIDV1) DeepCopy() PreparedClaimsByUIDV1 {
88+
if m == nil {
89+
return nil
90+
}
91+
out := make(PreparedClaimsByUIDV1, len(m))
92+
for k, v := range m {
93+
out[k] = v.DeepCopy()
94+
}
95+
return out
96+
}
97+
98+
func (m PreparedClaimsByUIDV2) DeepCopy() PreparedClaimsByUIDV2 {
99+
if m == nil {
100+
return nil
101+
}
102+
out := make(PreparedClaimsByUIDV2, len(m))
103+
for k, v := range m {
104+
out[k] = v.DeepCopy()
105+
}
106+
return out
107+
}
108+
109+
func (c PreparedClaimV1) DeepCopy() PreparedClaimV1 {
110+
var status resourceapi.ResourceClaimStatus
111+
if s := c.Status.DeepCopy(); s != nil {
112+
status = *s
113+
}
114+
return PreparedClaimV1{
115+
Status: status,
116+
PreparedDevices: c.PreparedDevices.DeepCopy(),
117+
}
118+
}
119+
120+
func (c PreparedClaimV2) DeepCopy() PreparedClaimV2 {
121+
var status resourceapi.ResourceClaimStatus
122+
if s := c.Status.DeepCopy(); s != nil {
123+
status = *s
124+
}
125+
return PreparedClaimV2{
126+
CheckpointState: c.CheckpointState,
127+
Status: status,
128+
PreparedDevices: c.PreparedDevices.DeepCopy(),
129+
Name: c.Name,
130+
Namespace: c.Namespace,
131+
}
132+
}
133+
68134
// Conversion functions
69135

70136
func (v1 *CheckpointV1) ToV2() *CheckpointV2 {
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
Copyright The Kubernetes Authors
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package main
18+
19+
import (
20+
"k8s.io/dynamic-resource-allocation/kubeletplugin"
21+
)
22+
23+
func (d PreparedDevices) DeepCopy() PreparedDevices {
24+
if d == nil {
25+
return nil
26+
}
27+
out := make(PreparedDevices, len(d))
28+
for i, group := range d {
29+
out[i] = group.DeepCopy()
30+
}
31+
return out
32+
}
33+
34+
func (g *PreparedDeviceGroup) DeepCopy() *PreparedDeviceGroup {
35+
if g == nil {
36+
return nil
37+
}
38+
return &PreparedDeviceGroup{
39+
Devices: g.Devices.DeepCopy(),
40+
ConfigState: g.ConfigState.DeepCopy(),
41+
}
42+
}
43+
44+
func (l PreparedDeviceList) DeepCopy() PreparedDeviceList {
45+
if l == nil {
46+
return nil
47+
}
48+
out := make(PreparedDeviceList, len(l))
49+
for i, d := range l {
50+
out[i] = d.DeepCopy()
51+
}
52+
return out
53+
}
54+
55+
func (d PreparedDevice) DeepCopy() PreparedDevice {
56+
return PreparedDevice{
57+
Channel: d.Channel.DeepCopy(),
58+
Daemon: d.Daemon.DeepCopy(),
59+
}
60+
}
61+
62+
func (c *PreparedComputeDomainChannel) DeepCopy() *PreparedComputeDomainChannel {
63+
if c == nil {
64+
return nil
65+
}
66+
return &PreparedComputeDomainChannel{
67+
Info: c.Info.DeepCopy(),
68+
Device: deepCopyDevice(c.Device),
69+
}
70+
}
71+
72+
func (d *PreparedComputeDomainDaemon) DeepCopy() *PreparedComputeDomainDaemon {
73+
if d == nil {
74+
return nil
75+
}
76+
return &PreparedComputeDomainDaemon{
77+
Info: d.Info.DeepCopy(),
78+
Device: deepCopyDevice(d.Device),
79+
}
80+
}
81+
82+
func (d *ComputeDomainChannelInfo) DeepCopy() *ComputeDomainChannelInfo {
83+
if d == nil {
84+
return nil
85+
}
86+
return &ComputeDomainChannelInfo{ID: d.ID}
87+
}
88+
89+
func (d *ComputeDomainDaemonInfo) DeepCopy() *ComputeDomainDaemonInfo {
90+
if d == nil {
91+
return nil
92+
}
93+
return &ComputeDomainDaemonInfo{ID: d.ID}
94+
}
95+
96+
func (d DeviceConfigState) DeepCopy() DeviceConfigState {
97+
return DeviceConfigState{
98+
Type: d.Type,
99+
ComputeDomain: d.ComputeDomain,
100+
}
101+
}
102+
103+
func deepCopyDevice(d *kubeletplugin.Device) *kubeletplugin.Device {
104+
if d == nil {
105+
return nil
106+
}
107+
cp := &kubeletplugin.Device{
108+
PoolName: d.PoolName,
109+
DeviceName: d.DeviceName,
110+
}
111+
if len(d.Requests) > 0 {
112+
cp.Requests = make([]string, len(d.Requests))
113+
copy(cp.Requests, d.Requests)
114+
}
115+
if len(d.CDIDeviceIDs) > 0 {
116+
cp.CDIDeviceIDs = make([]string, len(d.CDIDeviceIDs))
117+
copy(cp.CDIDeviceIDs, d.CDIDeviceIDs)
118+
}
119+
if d.ShareID != nil {
120+
uid := *d.ShareID
121+
cp.ShareID = &uid
122+
}
123+
return cp
124+
}

0 commit comments

Comments
 (0)