Skip to content

Commit 0f60df6

Browse files
authored
Support for quota as object (#244)
Fixes #238
1 parent 32bfbb6 commit 0f60df6

File tree

3 files changed

+287
-0
lines changed

3 files changed

+287
-0
lines changed

nftables_test.go

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6077,6 +6077,198 @@ func TestGetRulesObjref(t *testing.T) {
60776077
}
60786078
}
60796079

6080+
func TestAddQuotaObj(t *testing.T) {
6081+
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
6082+
defer nftest.CleanupSystemConn(t, newNS)
6083+
conn.FlushRuleset()
6084+
defer conn.FlushRuleset()
6085+
6086+
table := &nftables.Table{
6087+
Name: "quota_demo",
6088+
Family: nftables.TableFamilyIPv4,
6089+
}
6090+
tr := conn.AddTable(table)
6091+
6092+
c := &nftables.Chain{
6093+
Name: "filter",
6094+
Table: table,
6095+
}
6096+
conn.AddChain(c)
6097+
6098+
o := &nftables.QuotaObj{
6099+
Table: tr,
6100+
Name: "q_test",
6101+
Bytes: 0x06400000,
6102+
Consumed: 0,
6103+
Over: true,
6104+
}
6105+
conn.AddObj(o)
6106+
6107+
if err := conn.Flush(); err != nil {
6108+
t.Errorf("conn.Flush() failed: %v", err)
6109+
}
6110+
6111+
obj, err := conn.GetObj(&nftables.QuotaObj{
6112+
Table: table,
6113+
Name: "q_test",
6114+
})
6115+
if err != nil {
6116+
t.Fatalf("conn.GetObj() failed: %v", err)
6117+
}
6118+
6119+
if got, want := len(obj), 1; got != want {
6120+
t.Fatalf("unexpected object list length: got %d, want %d", got, want)
6121+
}
6122+
6123+
o1, ok := obj[0].(*nftables.QuotaObj)
6124+
if !ok {
6125+
t.Fatalf("unexpected type: got %T, want *QuotaObj", obj[0])
6126+
}
6127+
if got, want := o1.Name, o.Name; got != want {
6128+
t.Fatalf("quota name mismatch: got %s, want %s", got, want)
6129+
}
6130+
if got, want := o1.Bytes, o.Bytes; got != want {
6131+
t.Fatalf("quota bytes mismatch: got %d, want %d", got, want)
6132+
}
6133+
if got, want := o1.Consumed, o.Consumed; got != want {
6134+
t.Fatalf("quota consumed mismatch: got %d, want %d", got, want)
6135+
}
6136+
if got, want := o1.Over, o.Over; got != want {
6137+
t.Fatalf("quota over mismatch: got %v, want %v", got, want)
6138+
}
6139+
}
6140+
6141+
func TestAddQuotaObjRef(t *testing.T) {
6142+
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
6143+
defer nftest.CleanupSystemConn(t, newNS)
6144+
conn.FlushRuleset()
6145+
defer conn.FlushRuleset()
6146+
6147+
table := &nftables.Table{
6148+
Name: "quota_demo",
6149+
Family: nftables.TableFamilyIPv4,
6150+
}
6151+
tr := conn.AddTable(table)
6152+
6153+
c := &nftables.Chain{
6154+
Name: "filter",
6155+
Table: table,
6156+
}
6157+
conn.AddChain(c)
6158+
6159+
o := &nftables.QuotaObj{
6160+
Table: tr,
6161+
Name: "q_test",
6162+
Bytes: 0x06400000,
6163+
Consumed: 0,
6164+
Over: true,
6165+
}
6166+
conn.AddObj(o)
6167+
6168+
r := &nftables.Rule{
6169+
Table: table,
6170+
Chain: c,
6171+
Exprs: []expr.Any{
6172+
&expr.Objref{
6173+
Type: 2,
6174+
Name: "q_test",
6175+
},
6176+
},
6177+
}
6178+
conn.AddRule(r)
6179+
if err := conn.Flush(); err != nil {
6180+
t.Fatalf("failed to flush: %v", err)
6181+
}
6182+
6183+
rules, err := conn.GetRules(table, c)
6184+
if err != nil {
6185+
t.Fatalf("failed to get rules: %v", err)
6186+
}
6187+
6188+
if got, want := len(rules), 1; got != want {
6189+
t.Fatalf("unexpected number of rules: got %d, want %d", got, want)
6190+
}
6191+
if got, want := len(rules[0].Exprs), 1; got != want {
6192+
t.Fatalf("unexpected number of exprs: got %d, want %d", got, want)
6193+
}
6194+
6195+
objref, ok := rules[0].Exprs[0].(*expr.Objref)
6196+
if !ok {
6197+
t.Fatalf("Exprs[0] is type %T, want *expr.Objref", rules[0].Exprs[0])
6198+
}
6199+
if want := r.Exprs[0]; !reflect.DeepEqual(objref, want) {
6200+
t.Errorf("objref expr = %+v, wanted %+v", objref, want)
6201+
}
6202+
}
6203+
6204+
func TestDeleteQuotaObj(t *testing.T) {
6205+
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
6206+
defer nftest.CleanupSystemConn(t, newNS)
6207+
conn.FlushRuleset()
6208+
defer conn.FlushRuleset()
6209+
6210+
table := &nftables.Table{
6211+
Name: "quota_demo",
6212+
Family: nftables.TableFamilyIPv4,
6213+
}
6214+
tr := conn.AddTable(table)
6215+
6216+
c := &nftables.Chain{
6217+
Name: "filter",
6218+
Table: table,
6219+
}
6220+
conn.AddChain(c)
6221+
6222+
o := &nftables.QuotaObj{
6223+
Table: tr,
6224+
Name: "q_test",
6225+
Bytes: 0x06400000,
6226+
Consumed: 0,
6227+
Over: true,
6228+
}
6229+
conn.AddObj(o)
6230+
6231+
if err := conn.Flush(); err != nil {
6232+
t.Fatalf("conn.Flush() failed: %v", err)
6233+
}
6234+
6235+
obj, err := conn.GetObj(&nftables.QuotaObj{
6236+
Table: table,
6237+
Name: "q_test",
6238+
})
6239+
if err != nil {
6240+
t.Fatalf("conn.GetObj() failed: %v", err)
6241+
}
6242+
6243+
if got, want := len(obj), 1; got != want {
6244+
t.Fatalf("unexpected number of objects: got %d, want %d", got, want)
6245+
}
6246+
6247+
if got, want := obj[0], o; !reflect.DeepEqual(got, want) {
6248+
t.Errorf("got = %+v, want = %+v", got, want)
6249+
}
6250+
6251+
conn.DeleteObject(&nftables.QuotaObj{
6252+
Table: tr,
6253+
Name: "q_test",
6254+
})
6255+
6256+
if err := conn.Flush(); err != nil {
6257+
t.Fatalf("conn.Flush() failed: %v", err)
6258+
}
6259+
6260+
obj, err = conn.GetObj(&nftables.QuotaObj{
6261+
Table: table,
6262+
Name: "q_test",
6263+
})
6264+
if err != nil {
6265+
t.Fatalf("conn.GetObj() failed: %v", err)
6266+
}
6267+
if got, want := len(obj), 0; got != want {
6268+
t.Fatalf("unexpected object list length: got %d, want %d", got, want)
6269+
}
6270+
}
6271+
60806272
func TestGetRulesQueue(t *testing.T) {
60816273
// Create a new network namespace to test these operations,
60826274
// and tear down the namespace at test completion.

obj.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,21 @@ func objFromMsg(msg netlink.Message) (Obj, error) {
155155
Name: name,
156156
}
157157

158+
ad.Do(func(b []byte) error {
159+
ad, err := netlink.NewAttributeDecoder(b)
160+
if err != nil {
161+
return err
162+
}
163+
ad.ByteOrder = binary.BigEndian
164+
return o.unmarshal(ad)
165+
})
166+
return &o, ad.Err()
167+
case NFT_OBJECT_QUOTA:
168+
o := QuotaObj{
169+
Table: table,
170+
Name: name,
171+
}
172+
158173
ad.Do(func(b []byte) error {
159174
ad, err := netlink.NewAttributeDecoder(b)
160175
if err != nil {

quota.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright 2023 Google LLC. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package nftables
16+
17+
import (
18+
"github.com/google/nftables/binaryutil"
19+
"github.com/mdlayher/netlink"
20+
"golang.org/x/sys/unix"
21+
)
22+
23+
const (
24+
NFTA_OBJ_USERDATA = 8
25+
NFT_OBJECT_QUOTA = 2
26+
)
27+
28+
type QuotaObj struct {
29+
Table *Table
30+
Name string
31+
Bytes uint64
32+
Consumed uint64
33+
Over bool
34+
}
35+
36+
func (q *QuotaObj) unmarshal(ad *netlink.AttributeDecoder) error {
37+
for ad.Next() {
38+
switch ad.Type() {
39+
case unix.NFTA_QUOTA_BYTES:
40+
q.Bytes = ad.Uint64()
41+
case unix.NFTA_QUOTA_CONSUMED:
42+
q.Consumed = ad.Uint64()
43+
case unix.NFTA_QUOTA_FLAGS:
44+
q.Over = (ad.Uint32() & unix.NFT_QUOTA_F_INV) == 1
45+
}
46+
}
47+
return nil
48+
}
49+
50+
func (q *QuotaObj) marshal(data bool) ([]byte, error) {
51+
flags := uint32(0)
52+
if q.Over {
53+
flags = unix.NFT_QUOTA_F_INV
54+
}
55+
obj, err := netlink.MarshalAttributes([]netlink.Attribute{
56+
{Type: unix.NFTA_QUOTA_BYTES, Data: binaryutil.BigEndian.PutUint64(q.Bytes)},
57+
{Type: unix.NFTA_QUOTA_CONSUMED, Data: binaryutil.BigEndian.PutUint64(q.Consumed)},
58+
{Type: unix.NFTA_QUOTA_FLAGS, Data: binaryutil.BigEndian.PutUint32(flags)},
59+
})
60+
if err != nil {
61+
return nil, err
62+
}
63+
attrs := []netlink.Attribute{
64+
{Type: unix.NFTA_OBJ_TABLE, Data: []byte(q.Table.Name + "\x00")},
65+
{Type: unix.NFTA_OBJ_NAME, Data: []byte(q.Name + "\x00")},
66+
{Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(NFT_OBJECT_QUOTA)},
67+
}
68+
if data {
69+
attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: obj})
70+
}
71+
return netlink.MarshalAttributes(attrs)
72+
}
73+
74+
func (q *QuotaObj) table() *Table {
75+
return q.Table
76+
}
77+
78+
func (q *QuotaObj) family() TableFamily {
79+
return q.Table.Family
80+
}

0 commit comments

Comments
 (0)