Whamcloud - gitweb
LU-17015 gss: support large kerberos token for rpc sec init
[fs/lustre-release.git] / lustre / utils / gss / svcgssd_proc.c
index c2ef07f..a21b81c 100644 (file)
 #include "sk_utils.h"
 #include <sys/time.h>
 #include <gssapi/gssapi_krb5.h>
+#include <libcfs/util/param.h>
 
 #define SVCGSSD_CONTEXT_CHANNEL "/proc/net/rpc/auth.sptlrpc.context/channel"
 #define SVCGSSD_INIT_CHANNEL    "/proc/net/rpc/auth.sptlrpc.init/channel"
 
-#define TOKEN_BUF_SIZE         8192
-
 struct svc_cred {
        uint32_t cr_remote;
        uint32_t cr_usr_root;
@@ -142,46 +141,88 @@ struct gss_verifier {
 
 #define RPCSEC_GSS_SEQ_WIN     5
 
-static int
-send_response(FILE *f, gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
-             u_int32_t maj_stat, u_int32_t min_stat,
-             gss_buffer_desc *out_handle, gss_buffer_desc *out_token)
+static int send_response(int auth_res, uint64_t hash,
+                       gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
+                       u_int32_t maj_stat, u_int32_t min_stat,
+                       gss_buffer_desc *out_handle, gss_buffer_desc *out_token)
 {
-       char buf[2 * TOKEN_BUF_SIZE];
-       char *bp = buf;
-       int blen = sizeof(buf);
-       /* XXXARG: */
-       int g;
+       struct rsi_downcall_data *rsi_dd;
+       int blen, fd, size, rc = 0;
+       glob_t path;
+       char *bp;
 
        printerr(LL_INFO, "sending reply\n");
-       qword_addhex(&bp, &blen, in_handle->value, in_handle->length);
-       qword_addhex(&bp, &blen, in_token->value, in_token->length);
-       qword_addint(&bp, &blen, time(NULL) + 3600);   /* 1 hour should be ok */
-       qword_adduint(&bp, &blen, maj_stat);
-       qword_adduint(&bp, &blen, min_stat);
-       qword_addhex(&bp, &blen, out_handle->value, out_handle->length);
-       qword_addhex(&bp, &blen, out_token->value, out_token->length);
-       qword_addeol(&bp, &blen);
-       if (blen <= 0) {
-               printerr(LL_ERR, "ERROR: %s: message too long\n", __func__);
-               return -1;
+
+       size = in_handle->length + sizeof(__u32) +
+               in_token->length + sizeof(__u32) +
+               sizeof(__u32) + sizeof(__u32);
+       if (!auth_res)
+               size += out_handle->length + out_token->length;
+       blen = size;
+
+       size += offsetof(struct rsi_downcall_data, sid_val[0]);
+       rsi_dd = calloc(1, size);
+       if (!rsi_dd) {
+               printerr(LL_ERR, "malloc downcall data (%d) failed\n", size);
+               return -ENOMEM;
+       }
+       rsi_dd->sid_magic = RSI_DOWNCALL_MAGIC;
+       rsi_dd->sid_hash = hash;
+       rsi_dd->sid_maj_stat = maj_stat;
+       rsi_dd->sid_min_stat = min_stat;
+
+       bp = rsi_dd->sid_val;
+       gss_buffer_write(&bp, &blen, in_handle->value, in_handle->length);
+       gss_buffer_write(&bp, &blen, in_token->value, in_token->length);
+       if (!auth_res) {
+               gss_buffer_write(&bp, &blen, out_handle->value,
+                                out_handle->length);
+               gss_buffer_write(&bp, &blen, out_token->value,
+                                out_token->length);
+       } else {
+               rsi_dd->sid_err = -EACCES;
+               gss_buffer_write(&bp, &blen, NULL, 0);
+               gss_buffer_write(&bp, &blen, NULL, 0);
        }
-       g = open(SVCGSSD_INIT_CHANNEL, O_WRONLY);
-       if (g == -1) {
+       if (blen < 0) {
+               printerr(LL_ERR, "ERROR: %s: message too long > %d\n",
+                        __func__, size);
+               rc = -EMSGSIZE;
+               goto out;
+       }
+       rsi_dd->sid_len = bp - rsi_dd->sid_val;
+
+       rc = cfs_get_param_paths(&path, RSI_DOWNCALL_PATH);
+       if (rc != 0) {
+               rc = -errno;
+               printerr(LL_ERR, "ERROR: %s: cannot get param path %s: %s\n",
+                        __func__, RSI_DOWNCALL_PATH, strerror(-rc));
+               goto out;
+       }
+
+       fd = open(path.gl_pathv[0], O_WRONLY);
+       if (fd == -1) {
+               rc = -errno;
                printerr(LL_ERR, "ERROR: %s: open %s failed: %s\n",
-                        __func__, SVCGSSD_INIT_CHANNEL, strerror(errno));
-               return -1;
+                        __func__, RSI_DOWNCALL_PATH, strerror(-rc));
+               goto out_path;
        }
-       *bp = '\0';
-       printerr(LL_DEBUG, "writing message: %s", buf);
-       if (write(g, buf, bp - buf) == -1) {
-               printerr(LL_ERR, "ERROR: %s: failed to write message\n",
-                        __func__);
-               close(g);
-               return -1;
+       size = offsetof(struct rsi_downcall_data,
+                       sid_val[bp - rsi_dd->sid_val]);
+       printerr(LL_DEBUG, "writing response, size %d\n", size);
+       if (write(fd, rsi_dd, size) == -1) {
+               rc = -errno;
+               printerr(LL_ERR, "ERROR: %s: failed to write message: %s\n",
+                        __func__, strerror(-rc));
        }
-       close(g);
-       return 0;
+       printerr(LL_DEBUG, "response written ok\n");
+
+       close(fd);
+out_path:
+       cfs_free_param_data(&path);
+out:
+       free(rsi_dd);
+       return rc;
 }
 
 #define rpc_auth_ok                    0
@@ -623,10 +664,17 @@ redo:
 
        do_svc_downcall(&snd->out_handle, &cred, snd->mech, &snd->ctx_token);
 
-       /* cleanup ctx_token, out_tok is cleaned up in handle_channel_req */
-       free(remote_pub_key.value);
-       free(snd->ctx_token.value);
-       snd->ctx_token.length = 0;
+       /* cleanup ctx_token, out_tok is cleaned up in handle_channel_request */
+       if (remote_pub_key.length != 0) {
+               free(remote_pub_key.value);
+               remote_pub_key.value = NULL;
+               remote_pub_key.length = 0;
+       }
+       if (snd->ctx_token.value) {
+               free(snd->ctx_token.value);
+               snd->ctx_token.value = NULL;
+               snd->ctx_token.length = 0;
+       }
 
        printerr(LL_DEBUG, "sk returning success\n");
        return 0;
@@ -643,7 +691,11 @@ cleanup_partial:
        free(bufs[SK_INIT_RANDOM].value);
        free(bufs[SK_INIT_TARGET].value);
        free(bufs[SK_INIT_FLAGS].value);
-       free(remote_pub_key.value);
+       if (remote_pub_key.length != 0) {
+               free(remote_pub_key.value);
+               remote_pub_key.value = NULL;
+               remote_pub_key.length = 0;
+       }
        sk_free_cred(skc);
        snd->maj_stat = rc;
        return -1;
@@ -652,10 +704,14 @@ out_err:
        snd->maj_stat = rc;
        if (snd->ctx_token.value) {
                free(snd->ctx_token.value);
-               snd->ctx_token.value = 0;
+               snd->ctx_token.value = NULL;
                snd->ctx_token.length = 0;
        }
-       free(remote_pub_key.value);
+       if (remote_pub_key.length != 0) {
+               free(remote_pub_key.value);
+               remote_pub_key.value = NULL;
+               remote_pub_key.length = 0;
+       }
        sk_free_cred(skc);
        printerr(LL_DEBUG, "sk returning failure\n");
 #else /* !HAVE_OPENSSL_SSK */
@@ -786,7 +842,12 @@ static int handle_krb(struct svc_nego_data *snd)
        /* We no longer need the gss context */
        gss_delete_sec_context(&ignore_min_stat, &snd->ctx, &ignore_out_tok);
        do_svc_downcall(&snd->out_handle, &cred, mech, &snd->ctx_token);
-
+       /* We no longer need the context token */
+       if (snd->ctx_token.value) {
+               free(snd->ctx_token.value);
+               snd->ctx_token.value = NULL;
+               snd->ctx_token.length = 0;
+       }
        return 0;
 
 out_err:
@@ -797,46 +858,48 @@ out_err:
        return 1;
 }
 
-/*
- * return -1 only if we detect error during reading from upcall channel,
- * all other cases return 0.
- */
-int handle_channel_request(FILE *f)
+int handle_channel_request(int fd)
 {
-       char                    in_tok_buf[TOKEN_BUF_SIZE];
-       char                    in_handle_buf[15];
-       char                    out_handle_buf[15];
-       gss_buffer_desc         ctx_token      = {.value = NULL},
-                               null_token     = {.value = NULL};
-       uint32_t                lustre_mech;
-       static char             *lbuf;
-       static int              lbuflen;
-       static char             *cp;
-       int                     get_len;
-       int                     rc = 1;
-       u_int32_t               ignore_min_stat;
-       struct svc_nego_data    snd = {
-               .in_tok.value           = in_tok_buf,
+       char in_handle_buf[15];
+       char out_handle_buf[15];
+       uint32_t lustre_mech;
+       static char *lbuf;
+       static int lbuflen;
+       static char *cp;
+       int get_len;
+       int rc;
+       u_int32_t ignore_min_stat;
+       struct svc_nego_data snd = {
+               .in_tok.value           = NULL,
                .in_handle.value        = in_handle_buf,
                .out_handle.value       = out_handle_buf,
                .maj_stat               = GSS_S_FAILURE,
                .ctx                    = GSS_C_NO_CONTEXT,
        };
+       uint64_t hash = 0;
 
        printerr(LL_INFO, "handling request\n");
-       if (readline(fileno(f), &lbuf, &lbuflen) != 1) {
+       if (readline(fd, &lbuf, &lbuflen) != 1) {
                printerr(LL_ERR, "ERROR: failed reading request\n");
                return -1;
        }
 
        cp = lbuf;
 
-       /* see rsi_request() for the format of data being input here */
-       qword_get(&cp, (char *)&snd.lustre_svc, sizeof(snd.lustre_svc));
-
+       /* see rsi_do_upcall() for the format of data being input here */
+       rc = gss_u64_read_string(&cp, (__u64 *)&hash);
+       if (rc < 0) {
+               printerr(LL_ERR, "ERROR: failed parsing request: hash\n");
+               goto out_err;
+       }
+       rc = gss_u64_read_string(&cp, (__u64 *)&snd.lustre_svc);
+       if (rc < 0) {
+               printerr(LL_ERR, "ERROR: failed parsing request: lustre svc\n");
+               goto out_err;
+       }
        /* lustre_svc is the svc and gss subflavor */
        lustre_mech = (snd.lustre_svc & LUSTRE_GSS_MECH_MASK) >>
-                     LUSTRE_GSS_MECH_SHIFT;
+               LUSTRE_GSS_MECH_SHIFT;
        snd.lustre_svc = snd.lustre_svc & LUSTRE_GSS_SVC_MASK;
        switch (lustre_mech) {
        case LGSS_MECH_KRB5:
@@ -892,16 +955,31 @@ int handle_channel_request(FILE *f)
                break;
        }
 
-       qword_get(&cp, (char *)&snd.nid, sizeof(snd.nid));
-       qword_get(&cp, (char *)&snd.handle_seq, sizeof(snd.handle_seq));
-       qword_get(&cp, snd.nm_name, sizeof(snd.nm_name));
+       rc = gss_u64_read_string(&cp, (__u64 *)&snd.nid);
+       if (rc < 0) {
+               printerr(LL_ERR, "ERROR: failed parsing request: source nid\n");
+               goto out_err;
+       }
+       rc = gss_u64_read_string(&cp, (__u64 *)&snd.handle_seq);
+       if (rc < 0) {
+               printerr(LL_ERR, "ERROR: failed parsing request: handle seq\n");
+               goto out_err;
+       }
+       get_len = gss_string_read(&cp, snd.nm_name, sizeof(snd.nm_name), 0);
+       if (get_len <= 0) {
+               printerr(LL_ERR,
+                        "ERROR: failed parsing request: nodemap name\n");
+               goto out_err;
+       }
+       snd.nm_name[get_len] = '\0';
        printerr(LL_INFO,
                 "handling req: svc %u, nid %016llx, idx %"PRIx64" nodemap %s\n",
                 snd.lustre_svc, snd.nid, snd.handle_seq, snd.nm_name);
 
-       get_len = qword_get(&cp, snd.in_handle.value, sizeof(in_handle_buf));
+       get_len = gss_base64url_decode(&cp, snd.in_handle.value,
+                                      sizeof(in_handle_buf));
        if (get_len < 0) {
-               printerr(LL_ERR, "ERROR: failed parsing request\n");
+               printerr(LL_ERR, "ERROR: failed parsing request: in handle\n");
                goto out_err;
        }
        snd.in_handle.length = (size_t)get_len;
@@ -909,9 +987,14 @@ int handle_channel_request(FILE *f)
        printerr(LL_DEBUG, "in_handle:\n");
        print_hexl(3, snd.in_handle.value, snd.in_handle.length);
 
-       get_len = qword_get(&cp, snd.in_tok.value, sizeof(in_tok_buf));
+       snd.in_tok.value = malloc(strlen(cp));
+       if (!snd.in_tok.value) {
+               printerr(LL_ERR, "ERROR: failed alloc for in token\n");
+               goto out_err;
+       }
+       get_len = gss_base64url_decode(&cp, snd.in_tok.value, strlen(cp));
        if (get_len < 0) {
-               printerr(LL_ERR, "ERROR: failed parsing request\n");
+               printerr(LL_ERR, "ERROR: failed parsing request: in token\n");
                goto out_err;
        }
        snd.in_tok.length = (size_t)get_len;
@@ -931,6 +1014,7 @@ int handle_channel_request(FILE *f)
                memcpy(&snd.ctx, snd.in_handle.value, snd.in_handle.length);
        }
 
+       rc = -1;
        if (lustre_mech == LGSS_MECH_KRB5)
                rc = handle_krb(&snd);
        else if (lustre_mech == LGSS_MECH_SK)
@@ -944,20 +1028,17 @@ int handle_channel_request(FILE *f)
 
 out_err:
        /* Failures send a null token */
-       if (rc == 0)
-               send_response(f, &snd.in_handle, &snd.in_tok, snd.maj_stat,
-                             snd.min_stat, &snd.out_handle, &snd.out_tok);
-       else
-               send_response(f, &snd.in_handle, &snd.in_tok, snd.maj_stat,
-                             snd.min_stat, &null_token, &null_token);
+       rc = send_response(rc, hash, &snd.in_handle, &snd.in_tok,
+                          snd.maj_stat, snd.min_stat,
+                          &snd.out_handle, &snd.out_tok);
 
        /* cleanup buffers */
-       if (snd.ctx_token.value != NULL)
-               free(ctx_token.value);
+       if (snd.in_tok.value)
+               free(snd.in_tok.value);
        if (snd.out_tok.value != NULL)
                gss_release_buffer(&ignore_min_stat, &snd.out_tok);
 
        /* For junk wire data just ignore */
 ignore:
-       return 0;
+       return rc;
 }