103
103
#include <linux/seq_file.h>
104
104
#include <net/net_namespace.h>
105
105
#include <net/icmp.h>
106
+ #include <net/inet_hashtables.h>
106
107
#include <net/route.h>
107
108
#include <net/checksum.h>
108
109
#include <net/xfrm.h>
@@ -565,6 +566,26 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
565
566
}
566
567
EXPORT_SYMBOL_GPL (udp4_lib_lookup );
567
568
569
+ static inline bool __udp_is_mcast_sock (struct net * net , struct sock * sk ,
570
+ __be16 loc_port , __be32 loc_addr ,
571
+ __be16 rmt_port , __be32 rmt_addr ,
572
+ int dif , unsigned short hnum )
573
+ {
574
+ struct inet_sock * inet = inet_sk (sk );
575
+
576
+ if (!net_eq (sock_net (sk ), net ) ||
577
+ udp_sk (sk )-> udp_port_hash != hnum ||
578
+ (inet -> inet_daddr && inet -> inet_daddr != rmt_addr ) ||
579
+ (inet -> inet_dport != rmt_port && inet -> inet_dport ) ||
580
+ (inet -> inet_rcv_saddr && inet -> inet_rcv_saddr != loc_addr ) ||
581
+ ipv6_only_sock (sk ) ||
582
+ (sk -> sk_bound_dev_if && sk -> sk_bound_dev_if != dif ))
583
+ return false;
584
+ if (!ip_mc_sf_allow (sk , loc_addr , rmt_addr , dif ))
585
+ return false;
586
+ return true;
587
+ }
588
+
568
589
static inline struct sock * udp_v4_mcast_next (struct net * net , struct sock * sk ,
569
590
__be16 loc_port , __be32 loc_addr ,
570
591
__be16 rmt_port , __be32 rmt_addr ,
@@ -575,20 +596,11 @@ static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk,
575
596
unsigned short hnum = ntohs (loc_port );
576
597
577
598
sk_nulls_for_each_from (s , node ) {
578
- struct inet_sock * inet = inet_sk (s );
579
-
580
- if (!net_eq (sock_net (s ), net ) ||
581
- udp_sk (s )-> udp_port_hash != hnum ||
582
- (inet -> inet_daddr && inet -> inet_daddr != rmt_addr ) ||
583
- (inet -> inet_dport != rmt_port && inet -> inet_dport ) ||
584
- (inet -> inet_rcv_saddr &&
585
- inet -> inet_rcv_saddr != loc_addr ) ||
586
- ipv6_only_sock (s ) ||
587
- (s -> sk_bound_dev_if && s -> sk_bound_dev_if != dif ))
588
- continue ;
589
- if (!ip_mc_sf_allow (s , loc_addr , rmt_addr , dif ))
590
- continue ;
591
- goto found ;
599
+ if (__udp_is_mcast_sock (net , s ,
600
+ loc_port , loc_addr ,
601
+ rmt_port , rmt_addr ,
602
+ dif , hnum ))
603
+ goto found ;
592
604
}
593
605
s = NULL ;
594
606
found :
@@ -1581,6 +1593,14 @@ static void flush_stack(struct sock **stack, unsigned int count,
1581
1593
kfree_skb (skb1 );
1582
1594
}
1583
1595
1596
+ static void udp_sk_rx_dst_set (struct sock * sk , const struct sk_buff * skb )
1597
+ {
1598
+ struct dst_entry * dst = skb_dst (skb );
1599
+
1600
+ dst_hold (dst );
1601
+ sk -> sk_rx_dst = dst ;
1602
+ }
1603
+
1584
1604
/*
1585
1605
* Multicasts and broadcasts go to each listener.
1586
1606
*
@@ -1709,11 +1729,28 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
1709
1729
if (udp4_csum_init (skb , uh , proto ))
1710
1730
goto csum_error ;
1711
1731
1712
- if (rt -> rt_flags & ( RTCF_BROADCAST | RTCF_MULTICAST ))
1713
- return __udp4_lib_mcast_deliver ( net , skb , uh ,
1714
- saddr , daddr , udptable ) ;
1732
+ if (skb -> sk ) {
1733
+ int ret ;
1734
+ sk = skb -> sk ;
1715
1735
1716
- sk = __udp4_lib_lookup_skb (skb , uh -> source , uh -> dest , udptable );
1736
+ if (unlikely (sk -> sk_rx_dst == NULL ))
1737
+ udp_sk_rx_dst_set (sk , skb );
1738
+
1739
+ ret = udp_queue_rcv_skb (sk , skb );
1740
+
1741
+ /* a return value > 0 means to resubmit the input, but
1742
+ * it wants the return to be -protocol, or 0
1743
+ */
1744
+ if (ret > 0 )
1745
+ return - ret ;
1746
+ return 0 ;
1747
+ } else {
1748
+ if (rt -> rt_flags & (RTCF_BROADCAST |RTCF_MULTICAST ))
1749
+ return __udp4_lib_mcast_deliver (net , skb , uh ,
1750
+ saddr , daddr , udptable );
1751
+
1752
+ sk = __udp4_lib_lookup_skb (skb , uh -> source , uh -> dest , udptable );
1753
+ }
1717
1754
1718
1755
if (sk != NULL ) {
1719
1756
int ret ;
@@ -1771,6 +1808,135 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
1771
1808
return 0 ;
1772
1809
}
1773
1810
1811
+ /* We can only early demux multicast if there is a single matching socket.
1812
+ * If more than one socket found returns NULL
1813
+ */
1814
+ static struct sock * __udp4_lib_mcast_demux_lookup (struct net * net ,
1815
+ __be16 loc_port , __be32 loc_addr ,
1816
+ __be16 rmt_port , __be32 rmt_addr ,
1817
+ int dif )
1818
+ {
1819
+ struct sock * sk , * result ;
1820
+ struct hlist_nulls_node * node ;
1821
+ unsigned short hnum = ntohs (loc_port );
1822
+ unsigned int count , slot = udp_hashfn (net , hnum , udp_table .mask );
1823
+ struct udp_hslot * hslot = & udp_table .hash [slot ];
1824
+
1825
+ rcu_read_lock ();
1826
+ begin :
1827
+ count = 0 ;
1828
+ result = NULL ;
1829
+ sk_nulls_for_each_rcu (sk , node , & hslot -> head ) {
1830
+ if (__udp_is_mcast_sock (net , sk ,
1831
+ loc_port , loc_addr ,
1832
+ rmt_port , rmt_addr ,
1833
+ dif , hnum )) {
1834
+ result = sk ;
1835
+ ++ count ;
1836
+ }
1837
+ }
1838
+ /*
1839
+ * if the nulls value we got at the end of this lookup is
1840
+ * not the expected one, we must restart lookup.
1841
+ * We probably met an item that was moved to another chain.
1842
+ */
1843
+ if (get_nulls_value (node ) != slot )
1844
+ goto begin ;
1845
+
1846
+ if (result ) {
1847
+ if (count != 1 ||
1848
+ unlikely (!atomic_inc_not_zero_hint (& result -> sk_refcnt , 2 )))
1849
+ result = NULL ;
1850
+ else if (unlikely (!__udp_is_mcast_sock (net , sk ,
1851
+ loc_port , loc_addr ,
1852
+ rmt_port , rmt_addr ,
1853
+ dif , hnum ))) {
1854
+ sock_put (result );
1855
+ result = NULL ;
1856
+ }
1857
+ }
1858
+ rcu_read_unlock ();
1859
+ return result ;
1860
+ }
1861
+
1862
+ /* For unicast we should only early demux connected sockets or we can
1863
+ * break forwarding setups. The chains here can be long so only check
1864
+ * if the first socket is an exact match and if not move on.
1865
+ */
1866
+ static struct sock * __udp4_lib_demux_lookup (struct net * net ,
1867
+ __be16 loc_port , __be32 loc_addr ,
1868
+ __be16 rmt_port , __be32 rmt_addr ,
1869
+ int dif )
1870
+ {
1871
+ struct sock * sk , * result ;
1872
+ struct hlist_nulls_node * node ;
1873
+ unsigned short hnum = ntohs (loc_port );
1874
+ unsigned int hash2 = udp4_portaddr_hash (net , loc_addr , hnum );
1875
+ unsigned int slot2 = hash2 & udp_table .mask ;
1876
+ struct udp_hslot * hslot2 = & udp_table .hash2 [slot2 ];
1877
+ INET_ADDR_COOKIE (acookie , rmt_addr , loc_addr )
1878
+ const __portpair ports = INET_COMBINED_PORTS (rmt_port , hnum );
1879
+
1880
+ rcu_read_lock ();
1881
+ result = NULL ;
1882
+ udp_portaddr_for_each_entry_rcu (sk , node , & hslot2 -> head ) {
1883
+ if (INET_MATCH (sk , net , acookie ,
1884
+ rmt_addr , loc_addr , ports , dif ))
1885
+ result = sk ;
1886
+ /* Only check first socket in chain */
1887
+ break ;
1888
+ }
1889
+
1890
+ if (result ) {
1891
+ if (unlikely (!atomic_inc_not_zero_hint (& result -> sk_refcnt , 2 )))
1892
+ result = NULL ;
1893
+ else if (unlikely (!INET_MATCH (sk , net , acookie ,
1894
+ rmt_addr , loc_addr ,
1895
+ ports , dif ))) {
1896
+ sock_put (result );
1897
+ result = NULL ;
1898
+ }
1899
+ }
1900
+ rcu_read_unlock ();
1901
+ return result ;
1902
+ }
1903
+
1904
+ void udp_v4_early_demux (struct sk_buff * skb )
1905
+ {
1906
+ const struct iphdr * iph = ip_hdr (skb );
1907
+ const struct udphdr * uh = udp_hdr (skb );
1908
+ struct sock * sk ;
1909
+ struct dst_entry * dst ;
1910
+ struct net * net = dev_net (skb -> dev );
1911
+ int dif = skb -> dev -> ifindex ;
1912
+
1913
+ /* validate the packet */
1914
+ if (!pskb_may_pull (skb , skb_transport_offset (skb ) + sizeof (struct udphdr )))
1915
+ return ;
1916
+
1917
+ if (skb -> pkt_type == PACKET_BROADCAST ||
1918
+ skb -> pkt_type == PACKET_MULTICAST )
1919
+ sk = __udp4_lib_mcast_demux_lookup (net , uh -> dest , iph -> daddr ,
1920
+ uh -> source , iph -> saddr , dif );
1921
+ else if (skb -> pkt_type == PACKET_HOST )
1922
+ sk = __udp4_lib_demux_lookup (net , uh -> dest , iph -> daddr ,
1923
+ uh -> source , iph -> saddr , dif );
1924
+ else
1925
+ return ;
1926
+
1927
+ if (!sk )
1928
+ return ;
1929
+
1930
+ skb -> sk = sk ;
1931
+ skb -> destructor = sock_edemux ;
1932
+ dst = sk -> sk_rx_dst ;
1933
+
1934
+ if (dst )
1935
+ dst = dst_check (dst , 0 );
1936
+ if (dst )
1937
+ skb_dst_set_noref (skb , dst );
1938
+ }
1939
+
1774
1940
int udp_rcv (struct sk_buff * skb )
1775
1941
{
1776
1942
return __udp4_lib_rcv (skb , & udp_table , IPPROTO_UDP );
0 commit comments