LU-3289 gss: Fix for SK bulk HMACs
diff --git a/lustre/ptlrpc/gss/gss_sk_mech.c b/lustre/ptlrpc/gss/gss_sk_mech.c
index 1cb3645..39d3e76 100644
@@ -267,16 +267,14 @@ __u32 gss_inquire_context_sk(struct gss_ctx *gss_context,
 }
 
 static
-__u32 sk_make_checksum(char *alg_name, rawobj_t *key,
-                      int msg_count, rawobj_t *msgs,
-                      int iov_count, lnet_kiov_t *iovs,
-                      rawobj_t *token)
+__u32 sk_make_hmac(char *alg_name, rawobj_t *key, int msg_count, rawobj_t *msgs,
+                  int iov_count, lnet_kiov_t *iovs, rawobj_t *token)
 {
        struct crypto_hash *tfm;
        int rc;
 
        tfm = crypto_alloc_hash(alg_name, 0, 0);
-       if (!tfm)
+       if (IS_ERR(tfm))
                return GSS_S_FAILURE;
 
        rc = GSS_S_FAILURE;
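For context, a minimal sketch (not part of the patch; the algorithm name is only illustrative) of the allocation pattern this hunk corrects: the legacy crypto_alloc_hash() reports failure with an ERR_PTR(), never NULL, so a bare NULL test lets the error slip through and IS_ERR() must be used instead.

	#include <linux/crypto.h>
	#include <linux/err.h>

	static int sk_hmac_alloc_example(void)
	{
		struct crypto_hash *tfm;

		tfm = crypto_alloc_hash("hmac(sha256)", 0, 0);
		if (IS_ERR(tfm))	/* failure is ERR_PTR(-errno), not NULL */
			return PTR_ERR(tfm);

		/* ... crypto_hash_setkey()/init()/update()/final(),
		 *     as done in sk_make_hmac() ... */

		crypto_free_hash(tfm);
		return 0;
	}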
@@ -298,18 +296,14 @@ __u32 gss_get_mic_sk(struct gss_ctx *gss_context,
                     rawobj_t *token)
 {
        struct sk_ctx *skc = gss_context->internal_ctx_id;
-       return sk_make_checksum(sk_hmac_types[skc->sc_hmac].sht_name,
-                               &skc->sc_shared_key, message_count, messages,
-                               iov_count, iovs, token);
+       return sk_make_hmac(sk_hmac_types[skc->sc_hmac].sht_name,
+                           &skc->sc_shared_key, message_count, messages,
+                           iov_count, iovs, token);
 }
 
 static
-__u32 sk_verify_checksum(struct sk_hmac_type *sht,
-                        rawobj_t *key,
-                        int message_count,
-                        rawobj_t *messages,
-                        int iov_count,
-                        lnet_kiov_t *iovs,
+__u32 sk_verify_hmac(struct sk_hmac_type *sht, rawobj_t *key, int message_count,
+                        rawobj_t *messages, int iov_count, lnet_kiov_t *iovs,
                         rawobj_t *token)
 {
        rawobj_t checksum = RAWOBJ_EMPTY;
@@ -326,8 +320,8 @@ __u32 sk_verify_checksum(struct sk_hmac_type *sht,
        if (!checksum.data)
                return rc;
 
-       if (sk_make_checksum(sht->sht_name, key, message_count,
-                            messages, iov_count, iovs, &checksum)) {
+       if (sk_make_hmac(sht->sht_name, key, message_count, messages,
+                        iov_count, iovs, &checksum)) {
                CDEBUG(D_SEC, "Failed to create checksum to validate\n");
                goto cleanup;
        }
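As an aside, a rough sketch of the verify pattern sk_verify_hmac() follows (assumed shape with hypothetical names, not the patch's exact body): recompute the HMAC over the message into a scratch buffer with the same legacy crypto API, then memcmp it against the checksum carried in the token.

	#include <linux/crypto.h>
	#include <linux/err.h>
	#include <linux/errno.h>
	#include <linux/scatterlist.h>
	#include <linux/string.h>

	static int hmac_verify_example(const char *alg_name, const u8 *key,
				       unsigned int keylen, const u8 *msg,
				       unsigned int msglen, const u8 *token,
				       unsigned int token_len)
	{
		struct hash_desc desc = { .flags = 0 };
		struct scatterlist sg[1];
		u8 digest[64];		/* large enough for any HMAC used here */
		unsigned int dlen;
		int rc;

		desc.tfm = crypto_alloc_hash(alg_name, 0, 0);
		if (IS_ERR(desc.tfm))
			return PTR_ERR(desc.tfm);

		dlen = crypto_hash_digestsize(desc.tfm);

		rc = crypto_hash_setkey(desc.tfm, key, keylen);
		if (rc)
			goto out;

		sg_init_one(sg, msg, msglen);
		rc = crypto_hash_digest(&desc, sg, msglen, digest);
		if (rc)
			goto out;

		/* token must carry at least a full digest and match exactly */
		rc = (token_len >= dlen && !memcmp(token, digest, dlen)) ?
		     0 : -EBADMSG;
	out:
		crypto_free_hash(desc.tfm);
		return rc;
	}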
@@ -345,6 +339,104 @@ cleanup:
        return rc;
 }
 
+/* sk_verify_bulk_hmac() differs slightly from sk_verify_hmac() because all
+ * encrypted pages in the bulk descriptor are populated, although only the
+ * bytes actually specified by the sender (bd_nob) need to be hashed;
+ * otherwise the calculated HMAC will be incorrect. */
+static
+__u32 sk_verify_bulk_hmac(struct sk_hmac_type *sht, rawobj_t *key,
+                         int msgcnt, rawobj_t *msgs, int iovcnt,
+                         lnet_kiov_t *iovs, int iov_bytes, rawobj_t *token)
+{
+       rawobj_t checksum = RAWOBJ_EMPTY;
+       struct crypto_hash *tfm;
+       struct hash_desc desc = {
+               .tfm = NULL,
+               .flags = 0,
+       };
+       struct scatterlist sg[1];
+       struct sg_table sgt;
+       int bytes;
+       int i;
+       int rc = GSS_S_FAILURE;
+
+       checksum.len = sht->sht_bytes;
+       if (token->len < checksum.len) {
+               CDEBUG(D_SEC, "Token received too short, expected %d "
+                      "received %d\n", checksum.len, token->len);
+               return GSS_S_DEFECTIVE_TOKEN;
+       }
+
+       OBD_ALLOC_LARGE(checksum.data, checksum.len);
+       if (!checksum.data)
+               return rc;
+
+       tfm = crypto_alloc_hash(sht->sht_name, 0, 0);
+       if (IS_ERR(tfm))
+               goto cleanup;
+
+       desc.tfm = tfm;
+
+       LASSERT(token->len >= crypto_hash_digestsize(tfm));
+
+       rc = crypto_hash_setkey(tfm, key->data, key->len);
+       if (rc)
+               goto hash_cleanup;
+
+       rc = crypto_hash_init(&desc);
+       if (rc)
+               goto hash_cleanup;
+
+       for (i = 0; i < msgcnt; i++) {
+               if (msgs[i].len == 0)
+                       continue;
+
+               rc = gss_setup_sgtable(&sgt, sg, msgs[i].data, msgs[i].len);
+               if (rc != 0)
+                       goto hash_cleanup;
+
+               rc = crypto_hash_update(&desc, sg, msgs[i].len);
+               if (rc) {
+                       gss_teardown_sgtable(&sgt);
+                       goto hash_cleanup;
+               }
+
+               gss_teardown_sgtable(&sgt);
+       }
+
+       for (i = 0; i < iovcnt && iov_bytes > 0; i++) {
+               if (iovs[i].kiov_len == 0)
+                       continue;
+
+               bytes = min_t(int, iov_bytes, iovs[i].kiov_len);
+               iov_bytes -= bytes;
+
+               sg_init_table(sg, 1);
+               sg_set_page(&sg[0], iovs[i].kiov_page, bytes,
+                           iovs[i].kiov_offset);
+               rc = crypto_hash_update(&desc, sg, bytes);
+               if (rc)
+                       goto hash_cleanup;
+       }
+
+       crypto_hash_final(&desc, checksum.data);
+
+       if (memcmp(token->data, checksum.data, checksum.len)) {
+               rc = GSS_S_BAD_SIG;
+               goto hash_cleanup;
+       }
+
+       rc = GSS_S_COMPLETE;
+
+hash_cleanup:
+       crypto_free_hash(tfm);
+
+cleanup:
+       OBD_FREE_LARGE(checksum.data, checksum.len);
+
+       return rc;
+}
+
 static
 __u32 gss_verify_mic_sk(struct gss_ctx *gss_context,
                        int message_count,
@@ -354,9 +446,8 @@ __u32 gss_verify_mic_sk(struct gss_ctx *gss_context,
                        rawobj_t *token)
 {
        struct sk_ctx *skc = gss_context->internal_ctx_id;
-       return sk_verify_checksum(&sk_hmac_types[skc->sc_hmac],
-                                 &skc->sc_shared_key, message_count, messages,
-                                 iov_count, iovs, token);
+       return sk_verify_hmac(&sk_hmac_types[skc->sc_hmac], &skc->sc_shared_key,
+                             message_count, messages, iov_count, iovs, token);
 }
 
 static
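To make the bd_nob cap in the new sk_verify_bulk_hmac() concrete, a toy walk-through with made-up numbers (not from the patch): with fully populated 4096-byte encrypted pages and bd_nob == 6000, the HMAC must cover 4096 bytes of the first kiov and only 1904 of the second, which is exactly what the min_t() clamp on iov_bytes achieves.

	#include <linux/kernel.h>	/* min_t() */

	static void bulk_hmac_cap_example(void)
	{
		int iov_bytes = 6000;			/* bd_nob from the sender     */
		int kiov_len[2] = { 4096, 4096 };	/* pages are fully populated  */
		int i;

		for (i = 0; i < 2 && iov_bytes > 0; i++) {
			int bytes = min_t(int, iov_bytes, kiov_len[i]);

			/* crypto_hash_update() is fed 'bytes', not kiov_len:
			 * i == 0 -> 4096 bytes, i == 1 -> 1904 bytes, then stop */
			iov_bytes -= bytes;
		}
	}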
@@ -393,8 +484,8 @@ __u32 gss_wrap_sk(struct gss_ctx *gss_context, rawobj_t *gss_header,
        LASSERT(cipher.len + sht->sht_bytes <= token->len);
        checksum.data = token->data + cipher.len;
        checksum.len = sht->sht_bytes;
-       if (sk_make_checksum(sht->sht_name, &skc->sc_shared_key, 2, msgbufs, 0,
-                            NULL, &checksum))
+       if (sk_make_hmac(sht->sht_name, &skc->sc_shared_key, 2, msgbufs, 0,
+                        NULL, &checksum))
                return GSS_S_FAILURE;
 
        token->len = cipher.len + checksum.len;
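For orientation, the token layout used by gss_wrap_sk() above, inferred from the fields the hunk touches and shown here only as a comment sketch: the ciphertext occupies the front of the token and the HMAC of sht->sht_bytes is appended directly behind it.

	/*
	 *   token->data                        token->data + cipher.len
	 *   |                                  |
	 *   +----------------------------------+------------------------+
	 *   |      ciphertext (cipher.len)     | HMAC (sht->sht_bytes)  |
	 *   +----------------------------------+------------------------+
	 *
	 *   checksum.data = token->data + cipher.len;
	 *   checksum.len  = sht->sht_bytes;
	 *   token->len    = cipher.len + checksum.len;
	 */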
@@ -433,8 +524,8 @@ __u32 gss_unwrap_sk(struct gss_ctx *gss_context, rawobj_t *gss_header,
        msgbufs[0].data = gss_header->data;
        msgbufs[1].len = cipher.len;
        msgbufs[1].data = cipher.data;
-       rc = sk_verify_checksum(sht, &skc->sc_shared_key, 2, msgbufs, 0, NULL,
-                              &checksum);
+       rc = sk_verify_hmac(sht, &skc->sc_shared_key, 2, msgbufs, 0, NULL,
+                           &checksum);
        if (rc)
                return rc;
 
@@ -551,12 +642,13 @@ static __u32 sk_decrypt_bulk(struct crypto_blkcipher *tfm,
                return GSS_S_DEFECTIVE_TOKEN;
        }
 
-       for (i = 0; i < desc->bd_iov_count; i++) {
+       for (i = 0; i < desc->bd_iov_count && cnob < desc->bd_nob_transferred;
+            i++) {
                lnet_kiov_t *piov = &BD_GET_KIOV(desc, i);
                lnet_kiov_t *ciov = &BD_GET_ENC_KIOV(desc, i);
 
-               if (piov->kiov_offset % blocksize != 0 ||
-                   piov->kiov_len % blocksize != 0) {
+               if (ciov->kiov_offset % blocksize != 0 ||
+                   ciov->kiov_len % blocksize != 0) {
                        CERROR("Invalid bulk descriptor vector\n");
                        return GSS_S_DEFECTIVE_TOKEN;
                }
@@ -650,6 +742,7 @@ __u32 gss_wrap_bulk_sk(struct gss_ctx *gss_context,
 
        cipher.data = token->data;
        cipher.len = token->len - sht->sht_bytes;
+       memset(token->data, 0, token->len);
 
        if (sk_encrypt_bulk(skc->sc_session_kb.kb_tfm, desc, &cipher, adj_nob))
                return GSS_S_FAILURE;
@@ -657,8 +750,8 @@ __u32 gss_wrap_bulk_sk(struct gss_ctx *gss_context,
        checksum.data = token->data + cipher.len;
        checksum.len = sht->sht_bytes;
 
-       if (sk_make_checksum(sht->sht_name, &skc->sc_shared_key, 1, &cipher, 0,
-                            NULL, &checksum))
+       if (sk_make_hmac(sht->sht_name, &skc->sc_shared_key, 1, &cipher,
+                        desc->bd_iov_count, GET_ENC_KIOV(desc), &checksum))
                return GSS_S_FAILURE;
 
        return GSS_S_COMPLETE;
@@ -680,9 +773,10 @@ __u32 gss_unwrap_bulk_sk(struct gss_ctx *gss_context,
        checksum.data = token->data + cipher.len;
        checksum.len = sht->sht_bytes;
 
-       rc = sk_verify_checksum(&sk_hmac_types[skc->sc_hmac],
-                               &skc->sc_shared_key, 1, &cipher, 0, NULL,
-                               &checksum);
+       rc = sk_verify_bulk_hmac(&sk_hmac_types[skc->sc_hmac],
+                                &skc->sc_shared_key, 1, &cipher,
+                                desc->bd_iov_count, GET_ENC_KIOV(desc),
+                                desc->bd_nob, &checksum);
        if (rc)
                return rc;