/* lustre/ptlrpc/gss/gss_sk_mech.c */
/*
 * GPL HEADER START
 *
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 only,
 * as published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License version 2 for more details (a copy is included
 * in the LICENSE file that accompanied this code).
 *
 * You should have received a copy of the GNU General Public License
 * version 2 along with this program; If not, see
 * http://www.gnu.org/licenses/gpl-2.0.html
 *
 * GPL HEADER END
 */
/*
 * Copyright (C) 2013, 2015, Trustees of Indiana University
 *
 * Copyright (c) 2014, 2016, Intel Corporation.
 *
 * Author: Jeremy Filizetti <jfilizet@iu.edu>
 * Author: Andrew Korty <ajk@iu.edu>
 */

#define DEBUG_SUBSYSTEM S_SEC
#include <linux/init.h>
#include <linux/module.h>
#include <linux/slab.h>
#include <linux/crypto.h>
#include <linux/mutex.h>
#include <crypto/ctr.h>

#include <obd.h>
#include <obd_class.h>
#include <obd_support.h>

#include "gss_err.h"
#include "gss_crypto.h"
#include "gss_internal.h"
#include "gss_api.h"
#include "gss_asn1.h"

#define SK_INTERFACE_VERSION 1
#define SK_MSG_VERSION 1
#define SK_MIN_SIZE 8
#define SK_IV_SIZE 16

/* Starting number for reverse contexts.  It is critical to security
 * that reverse contexts use a different range of numbers than regular
 * contexts because they are using the same key.  Therefore the IV/nonce
 * combination must be unique for them.  To accomplish this, reverse contexts
 * use the negative range of a 64-bit number and regular contexts use the
 * positive range.  If the same IV/nonce combination were reused it would leak
 * information about the plaintext. */
#define SK_IV_REV_START (1ULL << 63)

struct sk_ctx {
        __u16                   sc_hmac;
        __u16                   sc_crypt;
        __u32                   sc_expire;
        __u32                   sc_host_random;
        __u32                   sc_peer_random;
        atomic64_t              sc_iv;
        rawobj_t                sc_hmac_key;
        struct gss_keyblock     sc_session_kb;
};

struct sk_hdr {
        __u64                   skh_version;
        __u64                   skh_iv;
} __attribute__((packed));

/* The format of SK wire data is similar to that of RFC3686 ESP Payload
 * (section 3) except instead of just an IV there is a struct sk_hdr.
 * ---------------------------------------------------------------------
 * | struct sk_hdr | ciphertext (variable size) | HMAC (variable size) |
 * --------------------------------------------------------------------- */
struct sk_wire {
        rawobj_t                skw_header;
        rawobj_t                skw_cipher;
        rawobj_t                skw_hmac;
};

static struct sk_crypt_type sk_crypt_types[] = {
        [SK_CRYPT_AES256_CTR] = {
                .sct_name = "ctr(aes)",
                .sct_bytes = 32,
        },
};

static struct sk_hmac_type sk_hmac_types[] = {
        [SK_HMAC_SHA256] = {
                .sht_name = "hmac(sha256)",
                .sht_bytes = 32,
        },
        [SK_HMAC_SHA512] = {
                .sht_name = "hmac(sha512)",
                .sht_bytes = 64,
        },
};

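/* Round len up to the next multiple of blocksize; blocksize must be a
 * power of two for the mask below to work. */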
static inline unsigned long sk_block_mask(unsigned long len, int blocksize)
{
        return (len + blocksize - 1) & (~(blocksize - 1));
}

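/* Write the on-wire SK header: the fixed message version and the next value
 * of the per-context IV counter.  Fails once the counter wraps into the
 * range reserved for the opposite direction, since an IV must never be
 * reused with the same key. */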
static int sk_fill_header(struct sk_ctx *skc, struct sk_hdr *skh)
{
        __u64 tmp_iv;

        skh->skh_version = be64_to_cpu(SK_MSG_VERSION);

        /* Always use inc_return so the initial values, which are reserved
         * for detecting reuse, are never used as an IV */
        tmp_iv = atomic64_inc_return(&skc->sc_iv);
        skh->skh_iv = be64_to_cpu(tmp_iv);
        if (tmp_iv == 0 || tmp_iv == SK_IV_REV_START) {
                CERROR("Counter looped, connection must be reset to avoid leaking plaintext information\n");
                return GSS_S_FAILURE;
        }

        return GSS_S_COMPLETE;
}

static int sk_verify_header(struct sk_hdr *skh)
{
        if (cpu_to_be64(skh->skh_version) != SK_MSG_VERSION)
                return GSS_S_DEFECTIVE_TOKEN;

        return GSS_S_COMPLETE;
}

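/* Build the 16-byte RFC 3686 counter-mode IV: a 4-byte nonce, the 8-byte
 * per-message IV from the SK header, and a 4-byte block counter that
 * starts at 1. */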
void sk_construct_rfc3686_iv(__u8 *iv, __u32 nonce, __u64 partial_iv)
{
        __u32 ctr = cpu_to_be32(1);

        memcpy(iv, &nonce, CTR_RFC3686_NONCE_SIZE);
        iv += CTR_RFC3686_NONCE_SIZE;
        memcpy(iv, &partial_iv, CTR_RFC3686_IV_SIZE);
        iv += CTR_RFC3686_IV_SIZE;
        memcpy(iv, &ctr, sizeof(ctr));
}

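/* Instantiate the block cipher transform for the session key; only needed
 * when the context carries a session key (privacy mode). */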
static int sk_init_keys(struct sk_ctx *skc)
{
        return gss_keyblock_init(&skc->sc_session_kb,
                                 sk_crypt_types[skc->sc_crypt].sct_name, 0);
}

static int sk_fill_context(rawobj_t *inbuf, struct sk_ctx *skc)
{
        char *ptr = inbuf->data;
        char *end = inbuf->data + inbuf->len;
        __u32 tmp;

        /* see sk_serialize_kctx() for format from userspace side */
        /*  1. Version */
        if (gss_get_bytes(&ptr, end, &tmp, sizeof(tmp))) {
                CERROR("Failed to read shared key interface version\n");
                return -1;
        }
        if (tmp != SK_INTERFACE_VERSION) {
                CERROR("Invalid shared key interface version: %d\n", tmp);
                return -1;
        }

        /* 2. HMAC type */
        if (gss_get_bytes(&ptr, end, &skc->sc_hmac, sizeof(skc->sc_hmac))) {
                CERROR("Failed to read HMAC algorithm type\n");
                return -1;
        }
        if (skc->sc_hmac <= SK_HMAC_EMPTY || skc->sc_hmac >= SK_HMAC_MAX) {
                CERROR("Invalid hmac type: %d\n", skc->sc_hmac);
                return -1;
        }

        /* 3. crypt type */
        if (gss_get_bytes(&ptr, end, &skc->sc_crypt, sizeof(skc->sc_crypt))) {
                CERROR("Failed to read crypt algorithm type\n");
                return -1;
        }
        if (skc->sc_crypt <= SK_CRYPT_EMPTY || skc->sc_crypt >= SK_CRYPT_MAX) {
                CERROR("Invalid crypt type: %d\n", skc->sc_crypt);
                return -1;
        }

        /* 4. expiration time */
        if (gss_get_bytes(&ptr, end, &tmp, sizeof(tmp))) {
                CERROR("Failed to read context expiration time\n");
                return -1;
        }
        skc->sc_expire = tmp + ktime_get_real_seconds();

        /* 5. host random is used as nonce for encryption */
        if (gss_get_bytes(&ptr, end, &skc->sc_host_random,
                          sizeof(skc->sc_host_random))) {
                CERROR("Failed to read host random\n");
                return -1;
        }

        /* 6. peer random is used as nonce for decryption */
        if (gss_get_bytes(&ptr, end, &skc->sc_peer_random,
                          sizeof(skc->sc_peer_random))) {
                CERROR("Failed to read peer random\n");
                return -1;
        }

        /* 7. HMAC key */
        if (gss_get_rawobj(&ptr, end, &skc->sc_hmac_key)) {
                CERROR("Failed to read HMAC key\n");
                return -1;
        }
        if (skc->sc_hmac_key.len <= SK_MIN_SIZE) {
                CERROR("HMAC key must be larger than %d bytes\n",
                       SK_MIN_SIZE);
                return -1;
        }

        /* 8. Session key, can be empty if not using privacy mode */
        if (gss_get_rawobj(&ptr, end, &skc->sc_session_kb.kb_key)) {
                CERROR("Failed to read session key\n");
                return -1;
        }

        return 0;
}

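/* Free everything hanging off an SK context: the HMAC key, the session
 * keyblock (including its cipher transform), and the context itself. */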
static void sk_delete_context(struct sk_ctx *skc)
{
        if (!skc)
                return;

        rawobj_free(&skc->sc_hmac_key);
        gss_keyblock_free(&skc->sc_session_kb);
        OBD_FREE_PTR(skc);
}

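/* Import a context negotiated in userspace.  The cipher transform is only
 * set up when a session key is present, i.e. for the privacy (skpi) flavor;
 * integrity-only contexts skip key initialization. */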
static
__u32 gss_import_sec_context_sk(rawobj_t *inbuf, struct gss_ctx *gss_context)
{
        struct sk_ctx *skc;
        bool privacy = false;

        if (inbuf == NULL || inbuf->data == NULL)
                return GSS_S_FAILURE;

        OBD_ALLOC_PTR(skc);
        if (!skc)
                return GSS_S_FAILURE;

        atomic64_set(&skc->sc_iv, 0);

        if (sk_fill_context(inbuf, skc))
                goto out_err;

        /* Only privacy mode needs to initialize keys */
        if (skc->sc_session_kb.kb_key.len > 0) {
                privacy = true;
                if (sk_init_keys(skc))
                        goto out_err;
        }

        gss_context->internal_ctx_id = skc;
        CDEBUG(D_SEC, "successfully imported sk%s context\n",
               privacy ? "pi" : "i");

        return GSS_S_COMPLETE;

out_err:
        sk_delete_context(skc);
        return GSS_S_FAILURE;
}

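/* Create the reverse (server-to-client) context from an existing one.  It
 * shares the keys of the original context, so its IV counter starts at
 * SK_IV_REV_START to keep the two directions from ever reusing an IV. */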
static
__u32 gss_copy_reverse_context_sk(struct gss_ctx *gss_context_old,
                                  struct gss_ctx *gss_context_new)
{
        struct sk_ctx *skc_old = gss_context_old->internal_ctx_id;
        struct sk_ctx *skc_new;

        OBD_ALLOC_PTR(skc_new);
        if (!skc_new)
                return GSS_S_FAILURE;

        skc_new->sc_hmac = skc_old->sc_hmac;
        skc_new->sc_crypt = skc_old->sc_crypt;
        skc_new->sc_expire = skc_old->sc_expire;
        skc_new->sc_host_random = skc_old->sc_host_random;
        skc_new->sc_peer_random = skc_old->sc_peer_random;

        atomic64_set(&skc_new->sc_iv, SK_IV_REV_START);

        if (rawobj_dup(&skc_new->sc_hmac_key, &skc_old->sc_hmac_key))
                goto out_err;
        if (gss_keyblock_dup(&skc_new->sc_session_kb, &skc_old->sc_session_kb))
                goto out_err;

        /* Only privacy mode needs to initialize keys */
        if (skc_new->sc_session_kb.kb_key.len > 0)
                if (sk_init_keys(skc_new))
                        goto out_err;

        gss_context_new->internal_ctx_id = skc_new;
        CDEBUG(D_SEC, "successfully copied reverse sk context\n");

        return GSS_S_COMPLETE;

out_err:
        sk_delete_context(skc_new);
        return GSS_S_FAILURE;
}

static
__u32 gss_inquire_context_sk(struct gss_ctx *gss_context,
                             time64_t *endtime)
{
        struct sk_ctx *skc = gss_context->internal_ctx_id;

        *endtime = skc->sc_expire;
        return GSS_S_COMPLETE;
}

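/* Compute an HMAC with the named hash over the rawobj messages followed by
 * the kiov pages, writing the digest into token.  The caller must supply a
 * token at least as large as the digest size of the chosen hash. */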
static
__u32 sk_make_hmac(char *alg_name, rawobj_t *key, int msg_count, rawobj_t *msgs,
                   int iov_count, lnet_kiov_t *iovs, rawobj_t *token)
{
        struct crypto_hash *tfm;
        int rc;

        tfm = crypto_alloc_hash(alg_name, 0, 0);
        if (IS_ERR(tfm))
                return GSS_S_FAILURE;

        rc = GSS_S_FAILURE;
        LASSERT(token->len >= crypto_hash_digestsize(tfm));
        if (!gss_digest_hmac(tfm, key, NULL, msg_count, msgs, iov_count, iovs,
                             token))
                rc = GSS_S_COMPLETE;

        crypto_free_hash(tfm);
        return rc;
}

static
__u32 gss_get_mic_sk(struct gss_ctx *gss_context,
                     int message_count,
                     rawobj_t *messages,
                     int iov_count,
                     lnet_kiov_t *iovs,
                     rawobj_t *token)
{
        struct sk_ctx *skc = gss_context->internal_ctx_id;

        return sk_make_hmac(sk_hmac_types[skc->sc_hmac].sht_name,
                            &skc->sc_hmac_key, message_count, messages,
                            iov_count, iovs, token);
}

static
__u32 sk_verify_hmac(struct sk_hmac_type *sht, rawobj_t *key, int message_count,
                     rawobj_t *messages, int iov_count, lnet_kiov_t *iovs,
                     rawobj_t *token)
{
        rawobj_t checksum = RAWOBJ_EMPTY;
        __u32 rc = GSS_S_FAILURE;

        checksum.len = sht->sht_bytes;
        if (token->len < checksum.len) {
                CDEBUG(D_SEC, "Token received too short, expected %d "
                       "received %d\n", checksum.len, token->len);
                return GSS_S_DEFECTIVE_TOKEN;
        }

        OBD_ALLOC_LARGE(checksum.data, checksum.len);
        if (!checksum.data)
                return rc;

        if (sk_make_hmac(sht->sht_name, key, message_count, messages,
                         iov_count, iovs, &checksum)) {
                CDEBUG(D_SEC, "Failed to create checksum to validate\n");
                goto cleanup;
        }

        if (memcmp(token->data, checksum.data, checksum.len)) {
                CERROR("checksum mismatch\n");
                rc = GSS_S_BAD_SIG;
                goto cleanup;
        }

        rc = GSS_S_COMPLETE;

cleanup:
        OBD_FREE_LARGE(checksum.data, checksum.len);
        return rc;
}

/* sk_verify_bulk_hmac() differs slightly from sk_verify_hmac() because all
 * encrypted pages in the bulk descriptor are populated although we only need
 * to hash the number of bytes actually specified by the sender (bd_nob);
 * otherwise the calculated HMAC would be incorrect. */
static
__u32 sk_verify_bulk_hmac(struct sk_hmac_type *sht, rawobj_t *key,
                          int msgcnt, rawobj_t *msgs, int iovcnt,
                          lnet_kiov_t *iovs, int iov_bytes, rawobj_t *token)
{
        rawobj_t checksum = RAWOBJ_EMPTY;
        struct crypto_hash *tfm;
        struct hash_desc desc = {
                .tfm = NULL,
                .flags = 0,
        };
        struct scatterlist sg[1];
        struct sg_table sgt;
        int bytes;
        int i;
        int rc = GSS_S_FAILURE;

        checksum.len = sht->sht_bytes;
        if (token->len < checksum.len) {
                CDEBUG(D_SEC, "Token received too short, expected %d "
                       "received %d\n", checksum.len, token->len);
                return GSS_S_DEFECTIVE_TOKEN;
        }

        OBD_ALLOC_LARGE(checksum.data, checksum.len);
        if (!checksum.data)
                return rc;

        tfm = crypto_alloc_hash(sht->sht_name, 0, 0);
        if (IS_ERR(tfm))
                goto cleanup;

        desc.tfm = tfm;

        LASSERT(token->len >= crypto_hash_digestsize(tfm));

        rc = crypto_hash_setkey(tfm, key->data, key->len);
        if (rc)
                goto hash_cleanup;

        rc = crypto_hash_init(&desc);
        if (rc)
                goto hash_cleanup;

        for (i = 0; i < msgcnt; i++) {
                if (msgs[i].len == 0)
                        continue;

                rc = gss_setup_sgtable(&sgt, sg, msgs[i].data, msgs[i].len);
                if (rc != 0)
                        goto hash_cleanup;

                rc = crypto_hash_update(&desc, sg, msgs[i].len);
                if (rc) {
                        gss_teardown_sgtable(&sgt);
                        goto hash_cleanup;
                }

                gss_teardown_sgtable(&sgt);
        }

        for (i = 0; i < iovcnt && iov_bytes > 0; i++) {
                if (iovs[i].kiov_len == 0)
                        continue;

                bytes = min_t(int, iov_bytes, iovs[i].kiov_len);
                iov_bytes -= bytes;

                sg_init_table(sg, 1);
                sg_set_page(&sg[0], iovs[i].kiov_page, bytes,
                            iovs[i].kiov_offset);
                rc = crypto_hash_update(&desc, sg, bytes);
                if (rc)
                        goto hash_cleanup;
        }

        crypto_hash_final(&desc, checksum.data);

        if (memcmp(token->data, checksum.data, checksum.len)) {
                rc = GSS_S_BAD_SIG;
                goto hash_cleanup;
        }

        rc = GSS_S_COMPLETE;

hash_cleanup:
        crypto_free_hash(tfm);

cleanup:
        OBD_FREE_LARGE(checksum.data, checksum.len);

        return rc;
}

static
__u32 gss_verify_mic_sk(struct gss_ctx *gss_context,
                        int message_count,
                        rawobj_t *messages,
                        int iov_count,
                        lnet_kiov_t *iovs,
                        rawobj_t *token)
{
        struct sk_ctx *skc = gss_context->internal_ctx_id;

        return sk_verify_hmac(&sk_hmac_types[skc->sc_hmac], &skc->sc_hmac_key,
                              message_count, messages, iov_count, iovs, token);
}

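/* Privacy-mode wrap for a single message: pad the plaintext to the cipher
 * block size, write the SK header into the token, encrypt the message in
 * counter mode keyed by the session key, then append an HMAC covering the
 * SK header, the GSS header and the ciphertext. */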
static
__u32 gss_wrap_sk(struct gss_ctx *gss_context, rawobj_t *gss_header,
                  rawobj_t *message, int message_buffer_length,
                  rawobj_t *token)
{
        struct sk_ctx *skc = gss_context->internal_ctx_id;
        struct sk_hmac_type *sht = &sk_hmac_types[skc->sc_hmac];
        struct sk_wire skw;
        struct sk_hdr skh;
        rawobj_t msgbufs[3];
        __u8 local_iv[SK_IV_SIZE];
        unsigned int blocksize;

        LASSERT(skc->sc_session_kb.kb_tfm);

        blocksize = crypto_blkcipher_blocksize(skc->sc_session_kb.kb_tfm);
        if (gss_add_padding(message, message_buffer_length, blocksize))
                return GSS_S_FAILURE;

        memset(token->data, 0, token->len);

        if (sk_fill_header(skc, &skh) != GSS_S_COMPLETE)
                return GSS_S_FAILURE;

        skw.skw_header.data = token->data;
        skw.skw_header.len = sizeof(skh);
        memcpy(skw.skw_header.data, &skh, sizeof(skh));

        sk_construct_rfc3686_iv(local_iv, skc->sc_host_random, skh.skh_iv);
        skw.skw_cipher.data = skw.skw_header.data + skw.skw_header.len;
        skw.skw_cipher.len = token->len - skw.skw_header.len - sht->sht_bytes;
        if (gss_crypt_rawobjs(skc->sc_session_kb.kb_tfm, local_iv, 1, message,
                              &skw.skw_cipher, 1))
                return GSS_S_FAILURE;

        /* HMAC covers the SK header, GSS header, and ciphertext */
        msgbufs[0] = skw.skw_header;
        msgbufs[1] = *gss_header;
        msgbufs[2] = skw.skw_cipher;

        skw.skw_hmac.data = skw.skw_cipher.data + skw.skw_cipher.len;
        skw.skw_hmac.len = sht->sht_bytes;
        if (sk_make_hmac(sht->sht_name, &skc->sc_hmac_key, 3, msgbufs, 0,
                         NULL, &skw.skw_hmac))
                return GSS_S_FAILURE;

        token->len = skw.skw_header.len + skw.skw_cipher.len + skw.skw_hmac.len;

        return GSS_S_COMPLETE;
}

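/* Reverse of gss_wrap_sk(): check the token layout and SK header, verify
 * the HMAC over the SK header, GSS header and ciphertext, and only then
 * decrypt the ciphertext using the peer's nonce. */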
static
__u32 gss_unwrap_sk(struct gss_ctx *gss_context, rawobj_t *gss_header,
                    rawobj_t *token, rawobj_t *message)
{
        struct sk_ctx *skc = gss_context->internal_ctx_id;
        struct sk_hmac_type *sht = &sk_hmac_types[skc->sc_hmac];
        struct sk_wire skw;
        struct sk_hdr *skh;
        rawobj_t msgbufs[3];
        __u8 local_iv[SK_IV_SIZE];
        unsigned int blocksize;
        int rc;

        LASSERT(skc->sc_session_kb.kb_tfm);

        if (token->len < sizeof(struct sk_hdr) + sht->sht_bytes)
                return GSS_S_DEFECTIVE_TOKEN;

        skw.skw_header.data = token->data;
        skw.skw_header.len = sizeof(struct sk_hdr);
        skw.skw_cipher.data = skw.skw_header.data + skw.skw_header.len;
        skw.skw_cipher.len = token->len - skw.skw_header.len - sht->sht_bytes;
        skw.skw_hmac.data = skw.skw_cipher.data + skw.skw_cipher.len;
        skw.skw_hmac.len = sht->sht_bytes;

        blocksize = crypto_blkcipher_blocksize(skc->sc_session_kb.kb_tfm);
        if (skw.skw_cipher.len % blocksize != 0)
                return GSS_S_DEFECTIVE_TOKEN;

        skh = (struct sk_hdr *)skw.skw_header.data;
        rc = sk_verify_header(skh);
        if (rc != GSS_S_COMPLETE)
                return rc;

        /* HMAC covers the SK header, GSS header, and ciphertext */
        msgbufs[0] = skw.skw_header;
        msgbufs[1] = *gss_header;
        msgbufs[2] = skw.skw_cipher;
        rc = sk_verify_hmac(sht, &skc->sc_hmac_key, 3, msgbufs, 0, NULL,
                            &skw.skw_hmac);
        if (rc)
                return rc;

        sk_construct_rfc3686_iv(local_iv, skc->sc_peer_random, skh->skh_iv);
        message->len = skw.skw_cipher.len;
        if (gss_crypt_rawobjs(skc->sc_session_kb.kb_tfm, local_iv,
                              1, &skw.skw_cipher, message, 0))
                return GSS_S_FAILURE;

        return GSS_S_COMPLETE;
}

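/* Grow the encrypted kiovs of a bulk descriptor so every fragment is
 * rounded up to a full cipher block before the pages are encrypted. */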
static
__u32 gss_prep_bulk_sk(struct gss_ctx *gss_context,
                       struct ptlrpc_bulk_desc *desc)
{
        struct sk_ctx *skc = gss_context->internal_ctx_id;
        int blocksize;
        int i;

        LASSERT(skc->sc_session_kb.kb_tfm);
        blocksize = crypto_blkcipher_blocksize(skc->sc_session_kb.kb_tfm);

        for (i = 0; i < desc->bd_iov_count; i++) {
                if (BD_GET_KIOV(desc, i).kiov_offset & blocksize) {
                        CERROR("offset %d not blocksize aligned\n",
                               BD_GET_KIOV(desc, i).kiov_offset);
                        return GSS_S_FAILURE;
                }

                BD_GET_ENC_KIOV(desc, i).kiov_offset =
                        BD_GET_KIOV(desc, i).kiov_offset;
                BD_GET_ENC_KIOV(desc, i).kiov_len =
                        sk_block_mask(BD_GET_KIOV(desc, i).kiov_len, blocksize);
        }

        return GSS_S_COMPLETE;
}

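/* Encrypt each plaintext kiov of the bulk descriptor into its matching
 * encrypted kiov, rounding every fragment up to the cipher block size.
 * With adj_nob set, bd_nob is updated to the padded total. */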
static __u32 sk_encrypt_bulk(struct crypto_blkcipher *tfm, __u8 *iv,
                             struct ptlrpc_bulk_desc *desc, rawobj_t *cipher,
                             int adj_nob)
{
        struct blkcipher_desc cdesc = {
                .tfm = tfm,
                .info = iv,
                .flags = 0,
        };
        struct scatterlist ptxt;
        struct scatterlist ctxt;
        int blocksize;
        int i;
        int rc;
        int nob = 0;

        blocksize = crypto_blkcipher_blocksize(tfm);

        sg_init_table(&ptxt, 1);
        sg_init_table(&ctxt, 1);

        for (i = 0; i < desc->bd_iov_count; i++) {
                sg_set_page(&ptxt, BD_GET_KIOV(desc, i).kiov_page,
                            sk_block_mask(BD_GET_KIOV(desc, i).kiov_len,
                                          blocksize),
                            BD_GET_KIOV(desc, i).kiov_offset);
                nob += ptxt.length;

                sg_set_page(&ctxt, BD_GET_ENC_KIOV(desc, i).kiov_page,
                            ptxt.length, ptxt.offset);

                BD_GET_ENC_KIOV(desc, i).kiov_offset = ctxt.offset;
                BD_GET_ENC_KIOV(desc, i).kiov_len = ctxt.length;

                rc = crypto_blkcipher_encrypt_iv(&cdesc, &ctxt, &ptxt,
                                                 ptxt.length);
                if (rc) {
                        CERROR("failed to encrypt page: %d\n", rc);
                        return rc;
                }
        }

        if (adj_nob)
                desc->bd_nob = nob;

        return 0;
}

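/* Decrypt the transferred kiovs back into the plaintext kiovs.  Fragments
 * whose plaintext length is not block-aligned are decrypted in place in the
 * encrypted page and then copied out, since the cipher works on whole
 * blocks. */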
static __u32 sk_decrypt_bulk(struct crypto_blkcipher *tfm, __u8 *iv,
                             struct ptlrpc_bulk_desc *desc, rawobj_t *cipher,
                             int adj_nob)
{
        struct blkcipher_desc cdesc = {
                .tfm = tfm,
                .info = iv,
                .flags = 0,
        };
        struct scatterlist ptxt;
        struct scatterlist ctxt;
        int blocksize;
        int i;
        int rc;
        int pnob = 0;
        int cnob = 0;

        sg_init_table(&ptxt, 1);
        sg_init_table(&ctxt, 1);

        blocksize = crypto_blkcipher_blocksize(tfm);
        if (desc->bd_nob_transferred % blocksize != 0) {
                CERROR("Transfer not a multiple of block size: %d\n",
                       desc->bd_nob_transferred);
                return GSS_S_DEFECTIVE_TOKEN;
        }

        for (i = 0; i < desc->bd_iov_count && cnob < desc->bd_nob_transferred;
             i++) {
                lnet_kiov_t *piov = &BD_GET_KIOV(desc, i);
                lnet_kiov_t *ciov = &BD_GET_ENC_KIOV(desc, i);

                if (ciov->kiov_offset % blocksize != 0 ||
                    ciov->kiov_len % blocksize != 0) {
                        CERROR("Invalid bulk descriptor vector\n");
                        return GSS_S_DEFECTIVE_TOKEN;
                }

                /* Must adjust bytes here because we know the actual sizes after
                 * decryption.  Similar to what gss_cli_ctx_unwrap_bulk does for
                 * integrity only mode */
                if (adj_nob) {
                        /* cipher text must not exceed transferred size */
                        if (ciov->kiov_len + cnob > desc->bd_nob_transferred)
                                ciov->kiov_len =
                                        desc->bd_nob_transferred - cnob;

                        piov->kiov_len = ciov->kiov_len;

                        /* plain text must not exceed bulk's size */
                        if (ciov->kiov_len + pnob > desc->bd_nob)
                                piov->kiov_len = desc->bd_nob - pnob;
                } else {
                        /* Taken from krb5_decrypt since it has not been
                         * verified whether LNET guarantees these */
                        if (ciov->kiov_len + cnob > desc->bd_nob_transferred ||
                            piov->kiov_len > ciov->kiov_len) {
                                CERROR("Invalid decrypted length\n");
                                return GSS_S_FAILURE;
                        }
                }

                if (ciov->kiov_len == 0)
                        continue;

                sg_init_table(&ctxt, 1);
                sg_set_page(&ctxt, ciov->kiov_page, ciov->kiov_len,
                            ciov->kiov_offset);
                ptxt = ctxt;

                /* In the event the plain text size is not a multiple
                 * of blocksize we decrypt in place and copy the result
                 * after the decryption */
                if (piov->kiov_len % blocksize == 0)
                        sg_assign_page(&ptxt, piov->kiov_page);

                rc = crypto_blkcipher_decrypt_iv(&cdesc, &ptxt, &ctxt,
                                                 ctxt.length);
                if (rc) {
                        CERROR("Decryption failed for page: %d\n", rc);
                        return GSS_S_FAILURE;
                }

                if (piov->kiov_len % blocksize != 0) {
                        memcpy(page_address(piov->kiov_page) +
                               piov->kiov_offset,
                               page_address(ciov->kiov_page) +
                               ciov->kiov_offset,
                               piov->kiov_len);
                }

                cnob += ciov->kiov_len;
                pnob += piov->kiov_len;
        }

        /* if needed, clear out the remaining unused iovs */
        if (adj_nob)
                while (i < desc->bd_iov_count)
                        BD_GET_KIOV(desc, i++).kiov_len = 0;

        if (unlikely(cnob != desc->bd_nob_transferred)) {
                CERROR("%d cipher text transferred but only %d decrypted\n",
                       desc->bd_nob_transferred, cnob);
                return GSS_S_FAILURE;
        }

        if (unlikely(!adj_nob && pnob != desc->bd_nob)) {
                CERROR("%d plain text expected but only %d received\n",
                       desc->bd_nob, pnob);
                return GSS_S_FAILURE;
        }

        return 0;
}

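/* Bulk privacy wrap: write the SK header into the token, encrypt the
 * descriptor pages into their encrypted kiovs, then append an HMAC over the
 * token's cipher region and the encrypted pages. */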
static
__u32 gss_wrap_bulk_sk(struct gss_ctx *gss_context,
                       struct ptlrpc_bulk_desc *desc, rawobj_t *token,
                       int adj_nob)
{
        struct sk_ctx *skc = gss_context->internal_ctx_id;
        struct sk_hmac_type *sht = &sk_hmac_types[skc->sc_hmac];
        struct sk_wire skw;
        struct sk_hdr skh;
        __u8 local_iv[SK_IV_SIZE];

        LASSERT(skc->sc_session_kb.kb_tfm);

        memset(token->data, 0, token->len);
        if (sk_fill_header(skc, &skh) != GSS_S_COMPLETE)
                return GSS_S_FAILURE;

        skw.skw_header.data = token->data;
        skw.skw_header.len = sizeof(skh);
        memcpy(skw.skw_header.data, &skh, sizeof(skh));

        sk_construct_rfc3686_iv(local_iv, skc->sc_host_random, skh.skh_iv);
        skw.skw_cipher.data = skw.skw_header.data + skw.skw_header.len;
        skw.skw_cipher.len = token->len - skw.skw_header.len - sht->sht_bytes;
        if (sk_encrypt_bulk(skc->sc_session_kb.kb_tfm, local_iv,
                            desc, &skw.skw_cipher, adj_nob))
                return GSS_S_FAILURE;

        skw.skw_hmac.data = skw.skw_cipher.data + skw.skw_cipher.len;
        skw.skw_hmac.len = sht->sht_bytes;
        if (sk_make_hmac(sht->sht_name, &skc->sc_hmac_key, 1, &skw.skw_cipher,
                         desc->bd_iov_count, GET_ENC_KIOV(desc), &skw.skw_hmac))
                return GSS_S_FAILURE;

        return GSS_S_COMPLETE;
}

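/* Bulk privacy unwrap: validate the SK header, verify the HMAC over the
 * token's cipher region and the transferred pages (up to bd_nob), then
 * decrypt the bulk descriptor pages. */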
static
__u32 gss_unwrap_bulk_sk(struct gss_ctx *gss_context,
                         struct ptlrpc_bulk_desc *desc,
                         rawobj_t *token, int adj_nob)
{
        struct sk_ctx *skc = gss_context->internal_ctx_id;
        struct sk_hmac_type *sht = &sk_hmac_types[skc->sc_hmac];
        struct sk_wire skw;
        struct sk_hdr *skh;
        __u8 local_iv[SK_IV_SIZE];
        int rc;

        LASSERT(skc->sc_session_kb.kb_tfm);

        if (token->len < sizeof(struct sk_hdr) + sht->sht_bytes)
                return GSS_S_DEFECTIVE_TOKEN;

        skw.skw_header.data = token->data;
        skw.skw_header.len = sizeof(struct sk_hdr);
        skw.skw_cipher.data = skw.skw_header.data + skw.skw_header.len;
        skw.skw_cipher.len = token->len - skw.skw_header.len - sht->sht_bytes;
        skw.skw_hmac.data = skw.skw_cipher.data + skw.skw_cipher.len;
        skw.skw_hmac.len = sht->sht_bytes;

        skh = (struct sk_hdr *)skw.skw_header.data;
        rc = sk_verify_header(skh);
        if (rc != GSS_S_COMPLETE)
                return rc;

        rc = sk_verify_bulk_hmac(&sk_hmac_types[skc->sc_hmac],
                                 &skc->sc_hmac_key, 1, &skw.skw_cipher,
                                 desc->bd_iov_count, GET_ENC_KIOV(desc),
                                 desc->bd_nob, &skw.skw_hmac);
        if (rc)
                return rc;

        sk_construct_rfc3686_iv(local_iv, skc->sc_peer_random, skh->skh_iv);
        rc = sk_decrypt_bulk(skc->sc_session_kb.kb_tfm, local_iv,
                             desc, &skw.skw_cipher, adj_nob);
        if (rc)
                return rc;

        return GSS_S_COMPLETE;
}

static
void gss_delete_sec_context_sk(void *internal_context)
{
        struct sk_ctx *sk_context = internal_context;

        sk_delete_context(sk_context);
}

int gss_display_sk(struct gss_ctx *gss_context, char *buf, int bufsize)
{
        return snprintf(buf, bufsize, "sk");
}

static struct gss_api_ops gss_sk_ops = {
        .gss_import_sec_context     = gss_import_sec_context_sk,
        .gss_copy_reverse_context   = gss_copy_reverse_context_sk,
        .gss_inquire_context        = gss_inquire_context_sk,
        .gss_get_mic                = gss_get_mic_sk,
        .gss_verify_mic             = gss_verify_mic_sk,
        .gss_wrap                   = gss_wrap_sk,
        .gss_unwrap                 = gss_unwrap_sk,
        .gss_prep_bulk              = gss_prep_bulk_sk,
        .gss_wrap_bulk              = gss_wrap_bulk_sk,
        .gss_unwrap_bulk            = gss_unwrap_bulk_sk,
        .gss_delete_sec_context     = gss_delete_sec_context_sk,
        .gss_display                = gss_display_sk,
};

static struct subflavor_desc gss_sk_sfs[] = {
        {
                .sf_subflavor   = SPTLRPC_SUBFLVR_SKN,
                .sf_qop         = 0,
                .sf_service     = SPTLRPC_SVC_NULL,
                .sf_name        = "skn"
        },
        {
                .sf_subflavor   = SPTLRPC_SUBFLVR_SKA,
                .sf_qop         = 0,
                .sf_service     = SPTLRPC_SVC_AUTH,
                .sf_name        = "ska"
        },
        {
                .sf_subflavor   = SPTLRPC_SUBFLVR_SKI,
                .sf_qop         = 0,
                .sf_service     = SPTLRPC_SVC_INTG,
                .sf_name        = "ski"
        },
        {
                .sf_subflavor   = SPTLRPC_SUBFLVR_SKPI,
                .sf_qop         = 0,
                .sf_service     = SPTLRPC_SVC_PRIV,
                .sf_name        = "skpi"
        },
};

static struct gss_api_mech gss_sk_mech = {
        /* .gm_owner uses default NULL value for THIS_MODULE */
        .gm_name        = "sk",
        .gm_oid         = (rawobj_t) {
                .len = 12,
                .data = "\053\006\001\004\001\311\146\215\126\001\000\001",
        },
        .gm_ops         = &gss_sk_ops,
        .gm_sf_num      = 4,
        .gm_sfs         = gss_sk_sfs,
};

int __init init_sk_module(void)
{
        int status;

        status = lgss_mech_register(&gss_sk_mech);
        if (status)
                CERROR("Failed to register sk gss mechanism!\n");

        return status;
}

void cleanup_sk_module(void)
{
        lgss_mech_unregister(&gss_sk_mech);
}