Whamcloud - gitweb
LU-10391 socklnd: use interface index to track local addr 02/37702/11
authorMr NeilBrown <neilb@suse.de>
Fri, 7 Feb 2020 02:05:57 +0000 (13:05 +1100)
committerOleg Drokin <green@whamcloud.com>
Tue, 14 Apr 2020 08:09:20 +0000 (08:09 +0000)
socklnd currently uses local IP addresses to track the local end of a
route or connection.  This will not work so well for IPv6 where
addresses are generally more dynamic - an interface can have several
addresses and they can come and go.

Even with IPv4, an interface can have multiple addresses, though this
is partially hidden as there we normally use aliases which make it
look like different interfaces: eth0:1 eth0:2 etc.  These are really
all the same interface, just with different addresses.

It is really the local interface, rather than the local address,
which is important.  Choosing the right interface ensures the traffic
goes over the desired network hardware.  Using an address, as the code
currently does, is just a convenient indirection.

With IPv6, using the interface directly will be easier, and it is
quite possible with IPv4.

So: this patch changes ksock_route to store an interface index for the
source end of the route, and adds the index number to ksock_interface.

lnet_sock_listen() loses the local IP address which is never used, and
lnet_sock_connect() now is passed a source interface rather than a
source IP address.
lnet_sock_create() also gets a local source interface.

Signed-off-by: Mr NeilBrown <neilb@suse.de>
Change-Id: Ib3e0f92d10ad6ba4d66782bae22638e9935a1e4e
Reviewed-on: https://review.whamcloud.com/37702
Tested-by: jenkins <devops@whamcloud.com>
Tested-by: Maloo <maloo@whamcloud.com>
Reviewed-by: James Simmons <jsimmons@infradead.org>
Reviewed-by: Aurelien Degremont <degremoa@amazon.com>
Reviewed-by: Oleg Drokin <green@whamcloud.com>
lnet/include/lnet/lib-lnet.h
lnet/klnds/socklnd/socklnd.c
lnet/klnds/socklnd/socklnd.h
lnet/klnds/socklnd/socklnd_cb.c
lnet/lnet/acceptor.c
lnet/lnet/config.c
lnet/lnet/lib-socket.c

index d01439c..7a5b47f 100644 (file)
@@ -90,6 +90,24 @@ extern struct lnet the_lnet;                 /* THE network */
                kernel_getsockname(sock, addr, addrlen)
 #endif
 
+/*
+ * kernel 5.3: commit ef11db3310e272d3d8dbe8739e0770820dd20e52
+ * added in_dev_for_each_ifa_rtnl and in_dev_for_each_ifa_rcu
+ * and removed for_ifa and endfor_ifa.
+ * Use the _rtnl variant as the current locking is rtnl.
+ */
+#ifdef in_dev_for_each_ifa_rtnl
+#define DECLARE_CONST_IN_IFADDR(ifa)           const struct in_ifaddr *ifa
+#define endfor_ifa(in_dev)
+#else
+#define DECLARE_CONST_IN_IFADDR(ifa)
+#define in_dev_for_each_ifa_rtnl(ifa, in_dev)  for_ifa((in_dev))
+#define in_dev_for_each_ifa_rcu(ifa, in_dev)   for_ifa((in_dev))
+#endif
+
+int choose_ipv4_src(__u32 *ret,
+                   int interface, __u32 dst_ipaddr, struct net *ns);
+
 bool lnet_is_route_alive(struct lnet_route *route);
 bool lnet_is_gateway_alive(struct lnet_peer *gw);
 
@@ -775,7 +793,7 @@ unsigned int lnet_get_lnd_timeout(void);
 void lnet_register_lnd(const struct lnet_lnd *lnd);
 void lnet_unregister_lnd(const struct lnet_lnd *lnd);
 
-struct socket *lnet_connect(lnet_nid_t peer_nid, __u32 local_ip, __u32 peer_ip,
+struct socket *lnet_connect(lnet_nid_t peer_nid, int interface, __u32 peer_ip,
                            int peer_port, struct net *ns);
 void lnet_connect_console_error(int rc, lnet_nid_t peer_nid,
                                 __u32 peer_ip, int port);
@@ -800,9 +818,9 @@ int lnet_sock_getaddr(struct socket *socket, bool remote, __u32 *ip, int *port);
 int lnet_sock_write(struct socket *sock, void *buffer, int nob, int timeout);
 int lnet_sock_read(struct socket *sock, void *buffer, int nob, int timeout);
 
-struct socket *lnet_sock_listen(__u32 ip, int port, int backlog,
+struct socket *lnet_sock_listen(int port, int backlog,
                                struct net *ns);
-struct socket *lnet_sock_connect(__u32 local_ip, int local_port,
+struct socket *lnet_sock_connect(int interface, int local_port,
                                 __u32 peer_ip, int peer_port, struct net *ns);
 
 int lnet_peers_start_down(void);
index 094e840..5a907fc 100644 (file)
@@ -61,6 +61,58 @@ ksocknal_ip2iface(struct lnet_ni *ni, __u32 ip)
        return NULL;
 }
 
+/*
+ * Look up the ksock_interface of @ni whose ksni_index equals @index.
+ * Only the first ksnn_ninterfaces slots are examined; NULL is returned
+ * when none of them carries that index.
+ */
+static struct ksock_interface *
+ksocknal_index2iface(struct lnet_ni *ni, int index)
+{
+       struct ksock_net *sknet = ni->ni_data;
+       int i = 0;
+
+       while (i < sknet->ksnn_ninterfaces) {
+               struct ksock_interface *candidate;
+
+               LASSERT(i < LNET_INTERFACES_NUM);
+               candidate = &sknet->ksnn_interfaces[i];
+               if (candidate->ksni_index == index)
+                       return candidate;
+
+               i++;
+       }
+
+       return NULL;
+}
+
+/*
+ * Map a local IPv4 address to the index of the interface that holds it.
+ *
+ * @ipaddress is in host byte order (it is compared against
+ * ntohl(ifa->ifa_local)).  Devices in ni->ni_net_ns that are loopback,
+ * not IFF_UP, or without an in_device are ignored.
+ *
+ * Returns the ifindex of the first device carrying the address, or -1
+ * if no eligible device has it.
+ */
+static int ksocknal_ip2index(__u32 ipaddress, struct lnet_ni *ni)
+{
+       struct net_device *dev;
+       int ret = -1;
+       DECLARE_CONST_IN_IFADDR(ifa);
+
+       rcu_read_lock();
+       /* Only the RCU read lock is held here (not rtnl), so the device
+        * list must be walked with the RCU-safe iterator.
+        */
+       for_each_netdev_rcu(ni->ni_net_ns, dev) {
+               int flags = dev_get_flags(dev);
+               struct in_device *in_dev;
+
+               if (flags & IFF_LOOPBACK) /* skip the loopback IF */
+                       continue;
+
+               if (!(flags & IFF_UP))
+                       continue;
+
+               in_dev = __in_dev_get_rcu(dev);
+               if (!in_dev)
+                       continue;
+
+               in_dev_for_each_ifa_rcu(ifa, in_dev) {
+                       if (ntohl(ifa->ifa_local) == ipaddress) {
+                               ret = dev->ifindex;
+                               /* one match is enough for this device */
+                               break;
+                       }
+               }
+               endfor_ifa(in_dev);
+               if (ret >= 0)
+                       break;
+       }
+       rcu_read_unlock();
+
+       return ret;
+}
+
 static struct ksock_route *
 ksocknal_create_route(__u32 ipaddr, int port)
 {
@@ -74,15 +126,16 @@ ksocknal_create_route(__u32 ipaddr, int port)
        route->ksnr_peer = NULL;
        route->ksnr_retry_interval = 0;         /* OK to connect at any time */
        route->ksnr_ipaddr = ipaddr;
-        route->ksnr_port = port;
-        route->ksnr_scheduled = 0;
-        route->ksnr_connecting = 0;
-        route->ksnr_connected = 0;
-        route->ksnr_deleted = 0;
-        route->ksnr_conn_count = 0;
-        route->ksnr_share_count = 0;
-
-        return (route);
+       route->ksnr_myiface = -1;
+       route->ksnr_port = port;
+       route->ksnr_scheduled = 0;
+       route->ksnr_connecting = 0;
+       route->ksnr_connected = 0;
+       route->ksnr_deleted = 0;
+       route->ksnr_conn_count = 0;
+       route->ksnr_share_count = 0;
+
+       return route;
 }
 
 void
@@ -288,12 +341,13 @@ ksocknal_get_peer_info(struct lnet_ni *ni, int index,
                                           ksnr_list);
 
                        *id = peer_ni->ksnp_id;
-                       *myip = route->ksnr_myipaddr;
+                       rc = choose_ipv4_src(myip, route->ksnr_myiface,
+                                            route->ksnr_ipaddr,
+                                            ni->ni_net_ns);
                        *peer_ip = route->ksnr_ipaddr;
                        *port = route->ksnr_port;
                        *conn_count = route->ksnr_conn_count;
                        *share_count = route->ksnr_share_count;
-                       rc = 0;
                        goto out;
                }
        }
@@ -303,47 +357,52 @@ out:
 }
 
 static void
-ksocknal_associate_route_conn_locked(struct ksock_route *route, struct ksock_conn *conn)
+ksocknal_associate_route_conn_locked(struct ksock_route *route,
+                                    struct ksock_conn *conn)
 {
        struct ksock_peer_ni *peer_ni = route->ksnr_peer;
        int type = conn->ksnc_type;
        struct ksock_interface *iface;
+       int conn_iface = ksocknal_ip2index(conn->ksnc_myipaddr,
+                                          route->ksnr_peer->ksnp_ni);
 
        conn->ksnc_route = route;
        ksocknal_route_addref(route);
 
-       if (route->ksnr_myipaddr != conn->ksnc_myipaddr) {
-               if (route->ksnr_myipaddr == 0) {
+       if (route->ksnr_myiface != conn_iface) {
+               if (route->ksnr_myiface < 0) {
                        /* route wasn't bound locally yet (the initial route) */
-                       CDEBUG(D_NET, "Binding %s %pI4h to %pI4h\n",
+                       CDEBUG(D_NET, "Binding %s %pI4h to interface %d\n",
                               libcfs_id2str(peer_ni->ksnp_id),
                               &route->ksnr_ipaddr,
-                              &conn->ksnc_myipaddr);
+                              conn_iface);
                } else {
-                       CDEBUG(D_NET, "Rebinding %s %pI4h from %pI4h "
-                              "to %pI4h\n", libcfs_id2str(peer_ni->ksnp_id),
+                       CDEBUG(D_NET,
+                              "Rebinding %s %pI4h from interface %d to %d\n",
+                              libcfs_id2str(peer_ni->ksnp_id),
                               &route->ksnr_ipaddr,
-                              &route->ksnr_myipaddr,
-                              &conn->ksnc_myipaddr);
+                              route->ksnr_myiface,
+                              conn_iface);
 
-                        iface = ksocknal_ip2iface(route->ksnr_peer->ksnp_ni,
-                                                  route->ksnr_myipaddr);
-                        if (iface != NULL)
-                                iface->ksni_nroutes--;
-                }
-                route->ksnr_myipaddr = conn->ksnc_myipaddr;
-                iface = ksocknal_ip2iface(route->ksnr_peer->ksnp_ni,
-                                          route->ksnr_myipaddr);
-                if (iface != NULL)
-                        iface->ksni_nroutes++;
-        }
+                       iface = ksocknal_index2iface(route->ksnr_peer->ksnp_ni,
+                                                    route->ksnr_myiface);
+                       if (iface)
+                               iface->ksni_nroutes--;
+               }
+               route->ksnr_myiface = conn_iface;
+               iface = ksocknal_index2iface(route->ksnr_peer->ksnp_ni,
+                                            route->ksnr_myiface);
+               if (iface)
+                       iface->ksni_nroutes++;
+       }
 
-        route->ksnr_connected |= (1<<type);
-        route->ksnr_conn_count++;
+       route->ksnr_connected |= (1<<type);
+       route->ksnr_conn_count++;
 
-        /* Successful connection => further attempts can
-         * proceed immediately */
-        route->ksnr_retry_interval = 0;
+       /* Successful connection => further attempts can
+        * proceed immediately
+        */
+       route->ksnr_retry_interval = 0;
 }
 
 static void
@@ -408,10 +467,10 @@ ksocknal_del_route_locked(struct ksock_route *route)
                ksocknal_close_conn_locked(conn, 0);
        }
 
-       if (route->ksnr_myipaddr != 0) {
-               iface = ksocknal_ip2iface(route->ksnr_peer->ksnp_ni,
-                                         route->ksnr_myipaddr);
-               if (iface != NULL)
+       if (route->ksnr_myiface >= 0) {
+               iface = ksocknal_index2iface(route->ksnr_peer->ksnp_ni,
+                                            route->ksnr_myiface);
+               if (iface)
                        iface->ksni_nroutes--;
        }
 
@@ -905,7 +964,7 @@ ksocknal_create_routes(struct ksock_peer_ni *peer_ni, int port,
                                route = list_entry(rtmp, struct ksock_route,
                                                   ksnr_list);
 
-                               if (route->ksnr_myipaddr == iface->ksni_ipaddr)
+                               if (route->ksnr_myiface == iface->ksni_index)
                                        break;
 
                                route = NULL;
@@ -913,34 +972,34 @@ ksocknal_create_routes(struct ksock_peer_ni *peer_ni, int port,
                        if (route != NULL)
                                continue;
 
-                        this_netmatch = (((iface->ksni_ipaddr ^
-                                           newroute->ksnr_ipaddr) &
-                                           iface->ksni_netmask) == 0) ? 1 : 0;
+                       this_netmatch = (((iface->ksni_ipaddr ^
+                                          newroute->ksnr_ipaddr) &
+                                         iface->ksni_netmask) == 0) ? 1 : 0;
 
-                        if (!(best_iface == NULL ||
-                              best_netmatch < this_netmatch ||
-                              (best_netmatch == this_netmatch &&
-                               best_nroutes > iface->ksni_nroutes)))
-                                continue;
+                       if (!(best_iface == NULL ||
+                             best_netmatch < this_netmatch ||
+                             (best_netmatch == this_netmatch &&
+                              best_nroutes > iface->ksni_nroutes)))
+                               continue;
 
-                        best_iface = iface;
-                        best_netmatch = this_netmatch;
-                        best_nroutes = iface->ksni_nroutes;
-                }
+                       best_iface = iface;
+                       best_netmatch = this_netmatch;
+                       best_nroutes = iface->ksni_nroutes;
+               }
 
-                if (best_iface == NULL)
-                        continue;
+               if (best_iface == NULL)
+                       continue;
 
-                newroute->ksnr_myipaddr = best_iface->ksni_ipaddr;
-                best_iface->ksni_nroutes++;
+               newroute->ksnr_myiface = best_iface->ksni_index;
+               best_iface->ksni_nroutes++;
 
-                ksocknal_add_route_locked(peer_ni, newroute);
-                newroute = NULL;
-        }
+               ksocknal_add_route_locked(peer_ni, newroute);
+               newroute = NULL;
+       }
 
        write_unlock_bh(global_lock);
-        if (newroute != NULL)
-                ksocknal_route_decref(newroute);
+       if (newroute != NULL)
+               ksocknal_route_decref(newroute);
 }
 
 int
@@ -1889,6 +1948,7 @@ ksocknal_add_interface(struct lnet_ni *ni, __u32 ipaddress, __u32 netmask)
        } else {
                iface = &net->ksnn_interfaces[net->ksnn_ninterfaces++];
 
+               iface->ksni_index = ksocknal_ip2index(ipaddress, ni);
                iface->ksni_ipaddr = ipaddress;
                iface->ksni_netmask = netmask;
                iface->ksni_nroutes = 0;
@@ -1904,7 +1964,8 @@ ksocknal_add_interface(struct lnet_ni *ni, __u32 ipaddress, __u32 netmask)
                                                   struct ksock_route,
                                                   ksnr_list);
 
-                               if (route->ksnr_myipaddr == ipaddress)
+                               if (route->ksnr_myiface ==
+                                           iface->ksni_index)
                                        iface->ksni_nroutes++;
                        }
                }
@@ -1921,7 +1982,8 @@ ksocknal_add_interface(struct lnet_ni *ni, __u32 ipaddress, __u32 netmask)
 }
 
 static void
-ksocknal_peer_del_interface_locked(struct ksock_peer_ni *peer_ni, __u32 ipaddr)
+ksocknal_peer_del_interface_locked(struct ksock_peer_ni *peer_ni,
+                                  __u32 ipaddr, int index)
 {
        struct list_head *tmp;
        struct list_head *nxt;
@@ -1942,16 +2004,16 @@ ksocknal_peer_del_interface_locked(struct ksock_peer_ni *peer_ni, __u32 ipaddr)
        list_for_each_safe(tmp, nxt, &peer_ni->ksnp_routes) {
                route = list_entry(tmp, struct ksock_route, ksnr_list);
 
-                if (route->ksnr_myipaddr != ipaddr)
-                        continue;
+               if (route->ksnr_myiface != index)
+                       continue;
 
-                if (route->ksnr_share_count != 0) {
-                        /* Manually created; keep, but unbind */
-                        route->ksnr_myipaddr = 0;
-                } else {
-                        ksocknal_del_route_locked(route);
-                }
-        }
+               if (route->ksnr_share_count != 0) {
+                       /* Manually created; keep, but unbind */
+                       route->ksnr_myiface = -1;
+               } else {
+                       ksocknal_del_route_locked(route);
+               }
+       }
 
        list_for_each_safe(tmp, nxt, &peer_ni->ksnp_conns) {
                conn = list_entry(tmp, struct ksock_conn, ksnc_list);
@@ -1969,9 +2031,12 @@ ksocknal_del_interface(struct lnet_ni *ni, __u32 ipaddress)
        struct hlist_node *nxt;
        struct ksock_peer_ni *peer_ni;
        u32 this_ip;
+       int index;
        int i;
        int j;
 
+       index = ksocknal_ip2index(ipaddress, ni);
+
        write_lock_bh(&ksocknal_data.ksnd_global_lock);
 
        for (i = 0; i < net->ksnn_ninterfaces; i++) {
@@ -1994,7 +2059,8 @@ ksocknal_del_interface(struct lnet_ni *ni, __u32 ipaddress)
                        if (peer_ni->ksnp_ni != ni)
                                continue;
 
-                       ksocknal_peer_del_interface_locked(peer_ni, this_ip);
+                       ksocknal_peer_del_interface_locked(peer_ni,
+                                                          this_ip, index);
                }
        }
 
index 2b37222..a6b760c 100644 (file)
@@ -108,6 +108,7 @@ struct ksock_sched {
 #define KSOCK_THREAD_SID(id)           ((id) & ((1UL << KSOCK_CPT_SHIFT) - 1))
 
 struct ksock_interface {                       /* in-use interface */
+       int             ksni_index;             /* Linux interface index */
        __u32           ksni_ipaddr;            /* interface's IP address */
        __u32           ksni_netmask;           /* interface's network mask */
        int             ksni_nroutes;           /* # routes using (active) */
@@ -364,21 +365,23 @@ struct ksock_conn {
 };
 
 struct ksock_route {
-       struct list_head   ksnr_list;           /* chain on peer_ni route list */
-       struct list_head   ksnr_connd_list;     /* chain on ksnr_connd_routes */
-       struct ksock_peer_ni *ksnr_peer;        /* owning peer_ni */
-       atomic_t           ksnr_refcount;       /* # users */
-       time64_t           ksnr_timeout;        /* when (in secs) reconnection can happen next */
-       time64_t           ksnr_retry_interval; /* how long between retries */
-        __u32                 ksnr_myipaddr;    /* my IP */
-        __u32                 ksnr_ipaddr;      /* IP address to connect to */
-        int                   ksnr_port;        /* port to connect to */
-        unsigned int          ksnr_scheduled:1; /* scheduled for attention */
-        unsigned int          ksnr_connecting:1;/* connection establishment in progress */
-        unsigned int          ksnr_connected:4; /* connections established by type */
-        unsigned int          ksnr_deleted:1;   /* been removed from peer_ni? */
-        unsigned int          ksnr_share_count; /* created explicitly? */
-        int                   ksnr_conn_count;  /* # conns established by this route */
+       struct list_head        ksnr_list;      /* chain on peer_ni route list*/
+       struct list_head        ksnr_connd_list;/* chain on ksnr_connd_routes */
+       struct ksock_peer_ni   *ksnr_peer;      /* owning peer_ni */
+       atomic_t                ksnr_refcount;  /* # users */
+       time64_t                ksnr_timeout;   /* when (in secs) reconnection
+                                                * can happen next
+                                                */
+       time64_t                ksnr_retry_interval;/* secs between retries */
+       int                     ksnr_myiface;   /* interface index */
+       __u32                   ksnr_ipaddr;    /* IP address to connect to */
+       int                     ksnr_port;      /* port to connect to */
+       unsigned int            ksnr_scheduled:1;/* scheduled for attention */
+       unsigned int            ksnr_connecting:1;/* connection in progress */
+       unsigned int            ksnr_connected:4;/* connections by type */
+       unsigned int            ksnr_deleted:1; /* been removed from peer_ni? */
+       unsigned int            ksnr_share_count;/* created explicitly? */
+       int                     ksnr_conn_count;/* # conns for this route */
 };
 
 #define SOCKNAL_KEEPALIVE_PING          1       /* cookie for keepalive ping */
index e89b911..8439652 100644 (file)
@@ -1986,7 +1986,7 @@ ksocknal_connect(struct ksock_route *route)
                 }
 
                sock = lnet_connect(peer_ni->ksnp_id.nid,
-                                   route->ksnr_myipaddr,
+                                   route->ksnr_myiface,
                                    route->ksnr_ipaddr, route->ksnr_port,
                                    peer_ni->ksnp_ni->ni_net_ns);
                if (IS_ERR(sock)) {
index 09d016f..b11e09d 100644 (file)
@@ -145,7 +145,7 @@ lnet_connect_console_error (int rc, lnet_nid_t peer_nid,
 EXPORT_SYMBOL(lnet_connect_console_error);
 
 struct socket *
-lnet_connect(lnet_nid_t peer_nid, __u32 local_ip, __u32 peer_ip,
+lnet_connect(lnet_nid_t peer_nid, int interface, __u32 peer_ip,
             int peer_port, struct net *ns)
 {
        struct lnet_acceptor_connreq cr;
@@ -160,7 +160,7 @@ lnet_connect(lnet_nid_t peer_nid, __u32 local_ip, __u32 peer_ip,
             --port) {
                /* Iterate through reserved ports. */
 
-               sock = lnet_sock_connect(local_ip, port,
+               sock = lnet_sock_connect(interface, port,
                                         peer_ip, peer_port, ns);
                if (IS_ERR(sock)) {
                        rc = PTR_ERR(sock);
@@ -356,7 +356,7 @@ lnet_acceptor(void *arg)
        LASSERT(lnet_acceptor_state.pta_sock == NULL);
 
        lnet_acceptor_state.pta_sock =
-               lnet_sock_listen(0, accept_port, accept_backlog,
+               lnet_sock_listen(accept_port, accept_backlog,
                                 lnet_acceptor_state.pta_ns);
        if (IS_ERR(lnet_acceptor_state.pta_sock)) {
                rc = PTR_ERR(lnet_acceptor_state.pta_sock);
index 7ffa2e1..7fea399 100644 (file)
@@ -1572,19 +1572,6 @@ lnet_match_networks (char **networksp, char *ip2nets, __u32 *ipaddrs, int nip)
        *networksp = networks;
        return count;
 }
-/*
- * kernel 5.3: commit ef11db3310e272d3d8dbe8739e0770820dd20e52
- * added in_dev_for_each_ifa_rtnl and in_dev_for_each_ifa_rcu
- * and removed for_ifa and endfor_ifa.
- * Use the _rntl variant as the current locking is rtnl.
- */
-#ifdef in_dev_for_each_ifa_rtnl
-#define DECLARE_CONST_IN_IFADDR(ifa)           const struct in_ifaddr *ifa
-#define endfor_ifa(in_dev)
-#else
-#define DECLARE_CONST_IN_IFADDR(ifa)
-#define in_dev_for_each_ifa_rtnl(ifa, in_dev)  for_ifa((in_dev))
-#endif
 
 int lnet_inet_enumerate(struct lnet_inetdev **dev_list, struct net *ns)
 {
index 7c7406c..b801b47 100644 (file)
@@ -39,6 +39,7 @@
 /* For sys_open & sys_close */
 #include <linux/syscalls.h>
 #include <net/sock.h>
+#include <linux/inetdevice.h>
 
 #include <libcfs/libcfs.h>
 #include <lnet/lib-lnet.h>
@@ -175,10 +176,44 @@ lnet_sock_read(struct socket *sock, void *buffer, int nob, int timeout)
 }
 EXPORT_SYMBOL(lnet_sock_read);
 
+/**
+ * choose_ipv4_src() - pick a local IPv4 source address on an interface
+ * @ret:        output; on success holds the chosen address in host byte
+ *              order.  It is never read, so the caller need not
+ *              initialise it.
+ * @interface:  index of the network device to take the address from
+ * @dst_ipaddr: destination address in host byte order
+ * @ns:         network namespace in which @interface is looked up
+ *
+ * The first address found on the device is always accepted.  A later
+ * address replaces it only if @dst_ipaddr lies in that address's subnet
+ * (local address and destination agree under ifa_mask).
+ *
+ * Return: 0 on success; -EINVAL if the device does not exist, is not
+ * IFF_UP, or has no in_device; -ENOENT if the device has no addresses.
+ */
+int choose_ipv4_src(__u32 *ret, int interface, __u32 dst_ipaddr, struct net *ns)
+{
+       struct net_device *dev;
+       struct in_device *in_dev;
+       int err;
+       DECLARE_CONST_IN_IFADDR(ifa);
+
+       rcu_read_lock();
+       dev = dev_get_by_index_rcu(ns, interface);
+       err = -EINVAL;
+       if (!dev || !(dev->flags & IFF_UP))
+               goto out;
+       in_dev = __in_dev_get_rcu(dev);
+       if (!in_dev)
+               goto out;
+       err = -ENOENT;
+       in_dev_for_each_ifa_rcu(ifa, in_dev) {
+               /* 'err' is still non-zero until an address has been
+                * chosen; test that rather than *ret, which the caller
+                * may have left uninitialised.
+                */
+               if (err ||
+                   ((dst_ipaddr ^ ntohl(ifa->ifa_local))
+                    & ntohl(ifa->ifa_mask)) == 0) {
+                       /* This address at least as good as what we
+                        * already have
+                        */
+                       *ret = ntohl(ifa->ifa_local);
+                       err = 0;
+               }
+       }
+       endfor_ifa(in_dev);
+out:
+       rcu_read_unlock();
+       return err;
+}
+EXPORT_SYMBOL(choose_ipv4_src);
+
 static struct socket *
-lnet_sock_create(__u32 local_ip, int local_port, struct net *ns)
+lnet_sock_create(int interface, struct sockaddr *remaddr,
+                int local_port, struct net *ns)
 {
-       struct sockaddr_in  locaddr;
        struct socket      *sock;
        int                 rc;
        int                 option;
@@ -201,12 +236,25 @@ lnet_sock_create(__u32 local_ip, int local_port, struct net *ns)
                goto failed;
        }
 
-       if (local_ip != 0 || local_port != 0) {
-               memset(&locaddr, 0, sizeof(locaddr));
+       if (interface >= 0 || local_port != 0) {
+               struct sockaddr_in locaddr = {};
+
                locaddr.sin_family = AF_INET;
+               locaddr.sin_addr.s_addr = INADDR_ANY;
+               if (interface >= 0) {
+                       struct sockaddr_in *sin = (void *)remaddr;
+                       __u32 ip;
+
+                       rc = choose_ipv4_src(&ip,
+                                            interface,
+                                            ntohl(sin->sin_addr.s_addr),
+                                            ns);
+                       if (rc)
+                               goto failed;
+                       locaddr.sin_addr.s_addr = htonl(ip);
+               }
+
                locaddr.sin_port = htons(local_port);
-               locaddr.sin_addr.s_addr = (local_ip == 0) ?
-                                         INADDR_ANY : htonl(local_ip);
 
                rc = kernel_bind(sock, (struct sockaddr *)&locaddr,
                                 sizeof(locaddr));
@@ -303,12 +351,12 @@ lnet_sock_getbuf(struct socket *sock, int *txbufsize, int *rxbufsize)
 EXPORT_SYMBOL(lnet_sock_getbuf);
 
 struct socket *
-lnet_sock_listen(__u32 local_ip, int local_port, int backlog, struct net *ns)
+lnet_sock_listen(int local_port, int backlog, struct net *ns)
 {
        struct socket *sock;
        int rc;
 
-       sock = lnet_sock_create(local_ip, local_port, ns);
+       sock = lnet_sock_create(-1, NULL, local_port, ns);
        if (IS_ERR(sock)) {
                rc = PTR_ERR(sock);
                if (rc == -EADDRINUSE)
@@ -327,7 +375,7 @@ lnet_sock_listen(__u32 local_ip, int local_port, int backlog, struct net *ns)
 }
 
 struct socket *
-lnet_sock_connect(__u32 local_ip, int local_port,
+lnet_sock_connect(int interface, int local_port,
                  __u32 peer_ip, int peer_port,
                  struct net *ns)
 {
@@ -335,15 +383,16 @@ lnet_sock_connect(__u32 local_ip, int local_port,
        struct sockaddr_in srvaddr;
        int rc;
 
-       sock = lnet_sock_create(local_ip, local_port, ns);
-       if (IS_ERR(sock))
-               return sock;
-
        memset(&srvaddr, 0, sizeof(srvaddr));
        srvaddr.sin_family = AF_INET;
        srvaddr.sin_port = htons(peer_port);
        srvaddr.sin_addr.s_addr = htonl(peer_ip);
 
+       sock = lnet_sock_create(interface, (struct sockaddr *)&srvaddr,
+                               local_port, ns);
+       if (IS_ERR(sock))
+               return sock;
+
        rc = kernel_connect(sock, (struct sockaddr *)&srvaddr,
                            sizeof(srvaddr), 0);
        if (rc == 0)
@@ -355,8 +404,8 @@ lnet_sock_connect(__u32 local_ip, int local_port,
         * port... */
 
        CDEBUG_LIMIT(rc == -EADDRNOTAVAIL ? D_NET : D_NETERROR,
-                    "Error %d connecting %pI4h/%d -> %pI4h/%d\n", rc,
-                    &local_ip, local_port, &peer_ip, peer_port);
+                    "Error %d connecting %d -> %pI4h/%d\n", rc,
+                    local_port, &peer_ip, peer_port);
 
        sock_release(sock);
        return ERR_PTR(rc);