Whamcloud - gitweb
* add a bit in ptlrpc_request to indicate ptlrpcs error, gss recovery
[fs/lustre-release.git] / lustre / sec / gss / svcsec_gss.c
1 /* -*- mode: c; c-basic-offset: 8; indent-tabs-mode: nil; -*-
2  * vim:expandtab:shiftwidth=8:tabstop=8:
3  *
4  * Modifications for Lustre
5  * Copyright 2004, Cluster File Systems, Inc.
6  * All rights reserved
7  * Author: Eric Mei <ericm@clusterfs.com>
8  */
9
10 /*
11  * Neil Brown <neilb@cse.unsw.edu.au>
12  * J. Bruce Fields <bfields@umich.edu>
13  * Andy Adamson <andros@umich.edu>
14  * Dug Song <dugsong@monkey.org>
15  *
16  * RPCSEC_GSS server authentication.
17  * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078
18  * (gssapi)
19  *
20  * The RPCSEC_GSS involves three stages:
21  *  1/ context creation
22  *  2/ data exchange
23  *  3/ context destruction
24  *
25  * Context creation is handled largely by upcalls to user-space.
26  *  In particular, GSS_Accept_sec_context is handled by an upcall
27  * Data exchange is handled entirely within the kernel
28  *  In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel.
29  * Context destruction is handled in-kernel
30  *  GSS_Delete_sec_context is in-kernel
31  *
32  * Context creation is initiated by a RPCSEC_GSS_INIT request arriving.
33  * The context handle and gss_token are used as a key into the rpcsec_init cache.
34  * The content of this cache includes some of the outputs of GSS_Accept_sec_context,
35  * being major_status, minor_status, context_handle, reply_token.
36  * These are sent back to the client.
37  * Sequence window management is handled by the kernel.  The window size is currently
38  * a compile time constant.
39  *
40  * When user-space is happy that a context is established, it places an entry
41  * in the rpcsec_context cache. The key for this cache is the context_handle.
42  * The content includes:
43  *   uid/gidlist - for determining access rights
44  *   mechanism type
45  *   mechanism specific information, such as a key
46  *
47  */
48
49 #define DEBUG_SUBSYSTEM S_SEC
50 #ifdef __KERNEL__
51 #include <linux/types.h>
52 #include <linux/init.h>
53 #include <linux/module.h>
54 #include <linux/slab.h>
55 #include <linux/hash.h>
56 #else
57 #include <liblustre.h>
58 #endif
59
60 #include <linux/sunrpc/cache.h>
61
62 #include <libcfs/kp30.h>
63 #include <linux/obd.h>
64 #include <linux/obd_class.h>
65 #include <linux/obd_support.h>
66 #include <linux/lustre_idl.h>
67 #include <linux/lustre_net.h>
68 #include <linux/lustre_import.h>
69 #include <linux/lustre_sec.h>
70                                                                                                                         
71 #include "gss_err.h"
72 #include "gss_internal.h"
73 #include "gss_api.h"
74
75 static inline unsigned long hash_mem(char *buf, int length, int bits)
76 {
77         unsigned long hash = 0;
78         unsigned long l = 0;
79         int len = 0;
80         unsigned char c;
81         do {
82                 if (len == length) {
83                         c = (char)len; len = -1;
84                 } else
85                         c = *buf++;
86                 l = (l << 8) | c;
87                 len++;
88                 if ((len & (BITS_PER_LONG/8-1))==0)
89                         hash = hash_long(hash^l, BITS_PER_LONG);
90         } while (len);
91         return hash >> (BITS_PER_LONG - bits);
92 }
93
94 /* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests
95  * into replies.
96  *
97  * Key is context handle (\x if empty) and gss_token.
98  * Content is major_status minor_status (integers) context_handle, reply_token.
99  *
100  */
101
102 #define RSI_HASHBITS    6
103 #define RSI_HASHMAX     (1<<RSI_HASHBITS)
104 #define RSI_HASHMASK    (RSI_HASHMAX-1)
105
106 struct rsi {
107         struct cache_head       h;
108         __u32                   naltype;
109         __u32                   netid;
110         __u64                   nid;
111         rawobj_t                in_handle, in_token;
112         rawobj_t                out_handle, out_token;
113         int                     major_status, minor_status;
114 };
115
116 static struct cache_head *rsi_table[RSI_HASHMAX];
117 static struct cache_detail rsi_cache;
118
119 static void rsi_free(struct rsi *rsii)
120 {
121         rawobj_free(&rsii->in_handle);
122         rawobj_free(&rsii->in_token);
123         rawobj_free(&rsii->out_handle);
124         rawobj_free(&rsii->out_token);
125 }
126
127 static void rsi_put(struct cache_head *item, struct cache_detail *cd)
128 {
129         struct rsi *rsii = container_of(item, struct rsi, h);
130         if (cache_put(item, cd)) {
131                 rsi_free(rsii);
132                 OBD_FREE(rsii, sizeof(*rsii));
133         }
134 }
135
136 static inline int rsi_hash(struct rsi *item)
137 {
138         return hash_mem(item->in_handle.data, item->in_handle.len, RSI_HASHBITS)
139               ^ hash_mem(item->in_token.data, item->in_token.len, RSI_HASHBITS);
140 }
141
142 static inline int rsi_match(struct rsi *item, struct rsi *tmp)
143 {
144         return (rawobj_equal(&item->in_handle, &tmp->in_handle) &&
145                 rawobj_equal(&item->in_token, &tmp->in_token));
146 }
147
148 static void rsi_request(struct cache_detail *cd,
149                         struct cache_head *h,
150                         char **bpp, int *blen)
151 {
152         struct rsi *rsii = container_of(h, struct rsi, h);
153
154         qword_addhex(bpp, blen, (char *) &rsii->naltype, sizeof(rsii->naltype));
155         qword_addhex(bpp, blen, (char *) &rsii->netid, sizeof(rsii->netid));
156         qword_addhex(bpp, blen, (char *) &rsii->nid, sizeof(rsii->nid));
157         qword_addhex(bpp, blen, rsii->in_handle.data, rsii->in_handle.len);
158         qword_addhex(bpp, blen, rsii->in_token.data, rsii->in_token.len);
159         (*bpp)[-1] = '\n';
160 }
161
162 static int
163 gssd_reply(struct rsi *item)
164 {
165         struct rsi *tmp;
166         struct cache_head **hp, **head;
167         ENTRY;
168
169         head = &rsi_cache.hash_table[rsi_hash(item)];
170         write_lock(&rsi_cache.hash_lock);
171         for (hp = head; *hp != NULL; hp = &tmp->h.next) {
172                 tmp = container_of(*hp, struct rsi, h);
173                 if (rsi_match(tmp, item)) {
174                         cache_get(&tmp->h);
175                         clear_bit(CACHE_HASHED, &tmp->h.flags);
176                         *hp = tmp->h.next;
177                         tmp->h.next = NULL;
178                         rsi_cache.entries--;
179                         if (test_bit(CACHE_VALID, &tmp->h.flags)) {
180                                 write_unlock(&rsi_cache.hash_lock);
181                                 rsi_put(&tmp->h, &rsi_cache);
182                                 RETURN(-EINVAL);
183                         }
184                         set_bit(CACHE_HASHED, &item->h.flags);
185                         item->h.next = *hp;
186                         *hp = &item->h;
187                         rsi_cache.entries++;
188                         set_bit(CACHE_VALID, &item->h.flags);
189                         item->h.last_refresh = get_seconds();
190                         write_unlock(&rsi_cache.hash_lock);
191                         cache_fresh(&rsi_cache, &tmp->h, 0);
192                         rsi_put(&tmp->h, &rsi_cache);
193                         RETURN(0);
194                 }
195         }
196         write_unlock(&rsi_cache.hash_lock);
197         RETURN(-EINVAL);
198 }
199
200 /* XXX
201  * here we just wait here for its completion or timedout. it's a
202  * hacking but works, and we'll comeup with real fix if we decided
203  * to still stick with NFS4 cache code
204  */
205 static struct rsi *
206 gssd_upcall(struct rsi *item, struct cache_req *chandle)
207 {
208         struct rsi *tmp;
209         struct cache_head **hp, **head;
210         unsigned long starttime;
211         ENTRY;
212
213         head = &rsi_cache.hash_table[rsi_hash(item)];
214         read_lock(&rsi_cache.hash_lock);
215         for (hp = head; *hp != NULL; hp = &tmp->h.next) {
216                 tmp = container_of(*hp, struct rsi, h);
217                 if (rsi_match(tmp, item)) {
218                         LBUG();
219                         if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
220                                 CERROR("found rsi without VALID\n");
221                                 read_unlock(&rsi_cache.hash_lock);
222                                 return NULL;
223                         }
224                         *hp = tmp->h.next;
225                         tmp->h.next = NULL;
226                         rsi_cache.entries--;
227                         cache_get(&tmp->h);
228                         read_unlock(&rsi_cache.hash_lock);
229                         return tmp;
230                 }
231         }
232         // cache_get(&item->h);
233         set_bit(CACHE_HASHED, &item->h.flags);
234         item->h.next = *head;
235         *head = &item->h;
236         rsi_cache.entries++;
237         read_unlock(&rsi_cache.hash_lock);
238         cache_get(&item->h);
239
240         cache_check(&rsi_cache, &item->h, chandle);
241         starttime = get_seconds();
242         do {
243                 yield();
244                 read_lock(&rsi_cache.hash_lock);
245                 for (hp = head; *hp != NULL; hp = &tmp->h.next) {
246                         tmp = container_of(*hp, struct rsi, h);
247                         if (tmp == item)
248                                 continue;
249                         if (rsi_match(tmp, item)) {
250                                 if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
251                                         read_unlock(&rsi_cache.hash_lock);
252                                         return NULL;
253                                 }
254                                 cache_get(&tmp->h);
255                                 clear_bit(CACHE_HASHED, &tmp->h.flags);
256                                 *hp = tmp->h.next;
257                                 tmp->h.next = NULL;
258                                 rsi_cache.entries--;
259                                 read_unlock(&rsi_cache.hash_lock);
260                                 return tmp;
261                         }
262                 }
263                 read_unlock(&rsi_cache.hash_lock);
264         } while ((get_seconds() - starttime) <= 5);
265         CERROR("5s timeout while waiting cache refill\n");
266         return NULL;
267 }
268
269 static int rsi_parse(struct cache_detail *cd,
270                      char *mesg, int mlen)
271 {
272         /* context token expiry major minor context token */
273         char *buf = mesg;
274         char *ep;
275         int len;
276         struct rsi *rsii;
277         time_t expiry;
278         int status = -EINVAL;
279         ENTRY;
280
281         OBD_ALLOC(rsii, sizeof(*rsii));
282         if (!rsii) {
283                 CERROR("failed to alloc rsii\n");
284                 RETURN(-ENOMEM);
285         }
286         cache_init(&rsii->h);
287
288         /* handle */
289         len = qword_get(&mesg, buf, mlen);
290         if (len < 0)
291                 goto out;
292         status = -ENOMEM;
293         if (rawobj_alloc(&rsii->in_handle, buf, len))
294                 goto out;
295
296         /* token */
297         len = qword_get(&mesg, buf, mlen);
298         status = -EINVAL;
299         if (len < 0)
300                 goto out;;
301         status = -ENOMEM;
302         if (rawobj_alloc(&rsii->in_token, buf, len))
303                 goto out;
304
305         /* expiry */
306         expiry = get_expiry(&mesg);
307         status = -EINVAL;
308         if (expiry == 0)
309                 goto out;
310
311         /* major/minor */
312         len = qword_get(&mesg, buf, mlen);
313         if (len < 0)
314                 goto out;
315         if (len == 0) {
316                 goto out;
317         } else {
318                 rsii->major_status = simple_strtoul(buf, &ep, 10);
319                 if (*ep)
320                         goto out;
321                 len = qword_get(&mesg, buf, mlen);
322                 if (len <= 0)
323                         goto out;
324                 rsii->minor_status = simple_strtoul(buf, &ep, 10);
325                 if (*ep)
326                         goto out;
327
328                 /* out_handle */
329                 len = qword_get(&mesg, buf, mlen);
330                 if (len < 0)
331                         goto out;
332                 status = -ENOMEM;
333                 if (rawobj_alloc(&rsii->out_handle, buf, len))
334                         goto out;
335
336                 /* out_token */
337                 len = qword_get(&mesg, buf, mlen);
338                 status = -EINVAL;
339                 if (len < 0)
340                         goto out;
341                 status = -ENOMEM;
342                 if (rawobj_alloc(&rsii->out_token, buf, len))
343                         goto out;
344         }
345         rsii->h.expiry_time = expiry;
346         status = gssd_reply(rsii);
347 out:
348         if (rsii)
349                 rsi_put(&rsii->h, &rsi_cache);
350         RETURN(status);
351 }
352
353 static struct cache_detail rsi_cache = {
354         .hash_size      = RSI_HASHMAX,
355         .hash_table     = rsi_table,
356         .name           = "auth.ptlrpcs.init",
357         .cache_put      = rsi_put,
358         .cache_request  = rsi_request,
359         .cache_parse    = rsi_parse,
360 };
361
362 /*
363  * The rpcsec_context cache is used to store a context that is
364  * used in data exchange.
365  * The key is a context handle. The content is:
366  *  uid, gidlist, mechanism, service-set, mech-specific-data
367  */
368
369 #define RSC_HASHBITS    10
370 #define RSC_HASHMAX     (1<<RSC_HASHBITS)
371 #define RSC_HASHMASK    (RSC_HASHMAX-1)
372
373 #define GSS_SEQ_WIN     512
374
375 struct gss_svc_seq_data {
376         /* highest seq number seen so far: */
377         __u32                   sd_max;
378         /* for i such that sd_max-GSS_SEQ_WIN < i <= sd_max, the i-th bit of
379          * sd_win is nonzero iff sequence number i has been seen already: */
380         unsigned long           sd_win[GSS_SEQ_WIN/BITS_PER_LONG];
381         spinlock_t              sd_lock;
382 };
383
384 struct rsc {
385         struct cache_head       h;
386         rawobj_t                handle;
387         __u32                   remote_realm;
388         struct vfs_cred         cred;
389         uid_t                   mapped_uid;
390         struct gss_svc_seq_data seqdata;
391         struct gss_ctx         *mechctx;
392 };
393
394 static struct cache_head *rsc_table[RSC_HASHMAX];
395 static struct cache_detail rsc_cache;
396
397 static void rsc_free(struct rsc *rsci)
398 {
399         rawobj_free(&rsci->handle);
400         if (rsci->mechctx)
401                 kgss_delete_sec_context(&rsci->mechctx);
402 #if 0
403         if (rsci->cred.vc_ginfo)
404                 put_group_info(rsci->cred.vc_ginfo);
405 #endif
406 }
407
408 static void rsc_put(struct cache_head *item, struct cache_detail *cd)
409 {
410         struct rsc *rsci = container_of(item, struct rsc, h);
411
412         if (cache_put(item, cd)) {
413                 rsc_free(rsci);
414                 OBD_FREE(rsci, sizeof(*rsci));
415         }
416 }
417
418 static inline int
419 rsc_hash(struct rsc *rsci)
420 {
421         return hash_mem(rsci->handle.data, rsci->handle.len, RSC_HASHBITS);
422 }
423
424 static inline int
425 rsc_match(struct rsc *new, struct rsc *tmp)
426 {
427         return rawobj_equal(&new->handle, &tmp->handle);
428 }
429
430 static struct rsc *rsc_lookup(struct rsc *item, int set)
431 {
432         struct rsc *tmp = NULL;
433         struct cache_head **hp, **head;
434         head = &rsc_cache.hash_table[rsc_hash(item)];
435         ENTRY;
436
437         if (set)
438                 write_lock(&rsc_cache.hash_lock);
439         else
440                 read_lock(&rsc_cache.hash_lock);
441         for (hp = head; *hp != NULL; hp = &tmp->h.next) {
442                 tmp = container_of(*hp, struct rsc, h);
443                 if (!rsc_match(tmp, item))
444                         continue;
445                 cache_get(&tmp->h);
446                 if (!set) {
447                         goto out_noset;
448                 }
449                 *hp = tmp->h.next;
450                 tmp->h.next = NULL;
451                 clear_bit(CACHE_HASHED, &tmp->h.flags);
452                 rsc_put(&tmp->h, &rsc_cache);
453                 goto out_set;
454         }
455         /* Didn't find anything */
456         if (!set)
457                 goto out_noset;
458         rsc_cache.entries++;
459 out_set:
460         set_bit(CACHE_HASHED, &item->h.flags);
461         item->h.next = *head;
462         *head = &item->h;
463         write_unlock(&rsc_cache.hash_lock);
464         cache_fresh(&rsc_cache, &item->h, item->h.expiry_time);
465         cache_get(&item->h);
466         RETURN(item);
467 out_noset:
468         read_unlock(&rsc_cache.hash_lock);
469         RETURN(tmp);
470 }
471                                                                                                                         
472 static int rsc_parse(struct cache_detail *cd,
473                      char *mesg, int mlen)
474 {
475         /* contexthandle expiry [ uid gid N <n gids> mechname ...mechdata... ] */
476         char *buf = mesg;
477         int len, rv;
478         struct rsc *rsci, *res = NULL;
479         time_t expiry;
480         int status = -EINVAL;
481
482         OBD_ALLOC(rsci, sizeof(*rsci));
483         if (!rsci) {
484                 CERROR("fail to alloc rsci\n");
485                 return -ENOMEM;
486         }
487         cache_init(&rsci->h);
488
489         /* context handle */
490         len = qword_get(&mesg, buf, mlen);
491         if (len < 0) goto out;
492         status = -ENOMEM;
493         if (rawobj_alloc(&rsci->handle, buf, len))
494                 goto out;
495
496         /* expiry */
497         expiry = get_expiry(&mesg);
498         status = -EINVAL;
499         if (expiry == 0)
500                 goto out;
501
502         /* remote flag */
503         rv = get_int(&mesg, &rsci->remote_realm);
504         if (rv) {
505                 CERROR("fail to get remote flag\n");
506                 goto out;
507         }
508
509         /* mapped uid */
510         rv = get_int(&mesg, &rsci->mapped_uid);
511         if (rv) {
512                 CERROR("fail to get mapped uid\n");
513                 goto out;
514         }
515
516         /* uid, or NEGATIVE */
517         rv = get_int(&mesg, &rsci->cred.vc_uid);
518         if (rv == -EINVAL)
519                 goto out;
520         if (rv == -ENOENT) {
521                 CERROR("NOENT? set rsc entry negative\n");
522                 set_bit(CACHE_NEGATIVE, &rsci->h.flags);
523         } else {
524                 struct gss_api_mech *gm;
525                 rawobj_t tmp_buf;
526                 __u64 ctx_expiry;
527
528                 /* gid */
529                 if (get_int(&mesg, &rsci->cred.vc_gid))
530                         goto out;
531
532                 /* mech name */
533                 len = qword_get(&mesg, buf, mlen);
534                 if (len < 0)
535                         goto out;
536                 gm = kgss_name_to_mech(buf);
537                 status = -EOPNOTSUPP;
538                 if (!gm)
539                         goto out;
540
541                 status = -EINVAL;
542                 /* mech-specific data: */
543                 len = qword_get(&mesg, buf, mlen);
544                 if (len < 0) {
545                         kgss_mech_put(gm);
546                         goto out;
547                 }
548                 tmp_buf.len = len;
549                 tmp_buf.data = buf;
550                 if (kgss_import_sec_context(&tmp_buf, gm, &rsci->mechctx)) {
551                         kgss_mech_put(gm);
552                         goto out;
553                 }
554
555                 /* currently the expiry time passed down from user-space
556                  * is invalid, here we retrive it from mech.
557                  */
558                 if (kgss_inquire_context(rsci->mechctx, &ctx_expiry)) {
559                         CERROR("unable to get expire time, drop it\n");
560                         set_bit(CACHE_NEGATIVE, &rsci->h.flags);
561                         kgss_mech_put(gm);
562                         goto out;
563                 }
564                 expiry = (time_t) ctx_expiry;
565
566                 kgss_mech_put(gm);
567         }
568         rsci->h.expiry_time = expiry;
569         spin_lock_init(&rsci->seqdata.sd_lock);
570         res = rsc_lookup(rsci, 1);
571         rsc_put(&res->h, &rsc_cache);
572         status = 0;
573 out:
574         if (rsci)
575                 rsc_put(&rsci->h, &rsc_cache);
576         return status;
577 }
578
579 /*
580  * flush all entries with @uid. @uid == -1 will match all.
581  * we only know the uid, maybe netid/nid in the future, in all cases
582  * we must search the whole cache
583  */
584 static void rsc_flush(uid_t uid)
585 {
586         struct cache_head **ch;
587         struct rsc *rscp;
588         int n;
589         ENTRY;
590
591         write_lock(&rsc_cache.hash_lock);
592         for (n = 0; n < RSC_HASHMAX; n++) {
593                 for (ch = &rsc_cache.hash_table[n]; *ch;) {
594                         rscp = container_of(*ch, struct rsc, h);
595                         if (uid == -1 || rscp->cred.vc_uid == uid) {
596                                 /* it seems simply set NEGATIVE doesn't work */
597                                 *ch = (*ch)->next;
598                                 rscp->h.next = NULL;
599                                 cache_get(&rscp->h);
600                                 set_bit(CACHE_NEGATIVE, &rscp->h.flags);
601                                 clear_bit(CACHE_HASHED, &rscp->h.flags);
602                                 CWARN("flush rsc %p for uid %u\n",
603                                        rscp, rscp->cred.vc_uid);
604                                 rsc_put(&rscp->h, &rsc_cache);
605                                 rsc_cache.entries--;
606                                 continue;
607                         }
608                         ch = &((*ch)->next);
609                 }
610         }
611         write_unlock(&rsc_cache.hash_lock);
612         EXIT;
613 }
614
615 static struct cache_detail rsc_cache = {
616         .hash_size      = RSC_HASHMAX,
617         .hash_table     = rsc_table,
618         .name           = "auth.ptlrpcs.context",
619         .cache_put      = rsc_put,
620         .cache_parse    = rsc_parse,
621 };
622
623 static struct rsc *
624 gss_svc_searchbyctx(rawobj_t *handle)
625 {
626         struct rsc rsci;
627         struct rsc *found;
628
629         rsci.handle = *handle;
630         found = rsc_lookup(&rsci, 0);
631         if (!found)
632                 return NULL;
633
634         if (cache_check(&rsc_cache, &found->h, NULL))
635                 return NULL;
636
637         return found;
638 }
639
640 struct gss_svc_data {
641         /* decoded gss client cred: */
642         struct rpc_gss_wire_cred        clcred;
643         /* internal used status */
644         unsigned int                    is_init:1,
645                                         is_init_continue:1,
646                                         is_err_notify:1,
647                                         is_fini:1;
648         int                             reserve_len;
649 };
650
651 /* FIXME
652  * again hacking: only try to give the svcgssd a chance to handle
653  * upcalls.
654  */
655 struct cache_deferred_req* my_defer(struct cache_req *req)
656 {
657         yield();
658         return NULL;
659 }
660 static struct cache_req my_chandle = {my_defer};
661
662 /* Implements sequence number algorithm as specified in RFC 2203. */
663 static int
664 gss_check_seq_num(struct gss_svc_seq_data *sd, __u32 seq_num)
665 {
666         int rc = 0;
667
668         spin_lock(&sd->sd_lock);
669         if (seq_num > sd->sd_max) {
670                 if (seq_num >= sd->sd_max + GSS_SEQ_WIN) {
671                         memset(sd->sd_win, 0, sizeof(sd->sd_win));
672                         sd->sd_max = seq_num;
673                 } else {
674                         while(sd->sd_max < seq_num) {
675                                 sd->sd_max++;
676                                 __clear_bit(sd->sd_max % GSS_SEQ_WIN,
677                                             sd->sd_win);
678                         }
679                 }
680                 __set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win);
681                 goto exit;
682         } else if (seq_num + GSS_SEQ_WIN <= sd->sd_max) {
683                 CERROR("seq %u too low: max %u, win %d\n",
684                         seq_num, sd->sd_max, GSS_SEQ_WIN);
685                 rc = 1;
686                 goto exit;
687         }
688
689         if (__test_and_set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win)) {
690                 CERROR("seq %u is replay: max %u, win %d\n",
691                         seq_num, sd->sd_max, GSS_SEQ_WIN);
692                 rc = 1;
693         }
694 exit:
695         spin_unlock(&sd->sd_lock);
696         return rc;
697 }
698
699 static int
700 gss_svc_verify_request(struct ptlrpc_request *req,
701                        struct rsc *rsci,
702                        struct rpc_gss_wire_cred *gc,
703                        __u32 *vp, __u32 vlen)
704 {
705         struct ptlrpcs_wire_hdr *sec_hdr;
706         struct gss_ctx *ctx = rsci->mechctx;
707         __u32 maj_stat;
708         rawobj_t msg;
709         rawobj_t mic;
710         ENTRY;
711
712         sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;
713
714         req->rq_reqmsg = (struct lustre_msg *) (req->rq_reqbuf + sizeof(*sec_hdr));
715         req->rq_reqlen = sec_hdr->msg_len;
716
717         msg.len = sec_hdr->msg_len;
718         msg.data = (__u8 *)req->rq_reqmsg;
719
720         mic.len = le32_to_cpu(*vp++);
721         mic.data = (char *) vp;
722         vlen -= 4;
723
724         if (mic.len > vlen) {
725                 CERROR("checksum len %d, while buffer len %d\n",
726                         mic.len, vlen);
727                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
728         }
729
730         if (mic.len > 256) {
731                 CERROR("invalid mic len %d\n", mic.len);
732                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
733         }
734
735         maj_stat = kgss_verify_mic(ctx, &msg, &mic, NULL);
736         if (maj_stat != GSS_S_COMPLETE) {
737                 CERROR("MIC verification error: major %x\n", maj_stat);
738                 RETURN(maj_stat);
739         }
740
741         if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
742                 CERROR("discard request %p with old seq_num %u\n",
743                         req, gc->gc_seq);
744                 RETURN(GSS_S_DUPLICATE_TOKEN);
745         }
746
747         RETURN(GSS_S_COMPLETE);
748 }
749
750 static int
751 gss_svc_unseal_request(struct ptlrpc_request *req,
752                        struct rsc *rsci,
753                        struct rpc_gss_wire_cred *gc,
754                        __u32 *vp, __u32 vlen)
755 {
756         struct ptlrpcs_wire_hdr *sec_hdr;
757         struct gss_ctx *ctx = rsci->mechctx;
758         rawobj_t cipher_text, plain_text;
759         __u32 major;
760         ENTRY;
761
762         sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;
763
764         if (vlen < 4) {
765                 CERROR("vlen only %u\n", vlen);
766                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
767         }
768
769         cipher_text.len = le32_to_cpu(*vp++);
770         cipher_text.data = (__u8 *) vp;
771         vlen -= 4;
772         
773         if (cipher_text.len > vlen) {
774                 CERROR("cipher claimed %u while buf only %u\n",
775                         cipher_text.len, vlen);
776                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
777         }
778
779         plain_text = cipher_text;
780
781         major = kgss_unwrap(ctx, GSS_C_QOP_DEFAULT, &cipher_text, &plain_text);
782         if (major) {
783                 CERROR("unwrap error 0x%x\n", major);
784                 RETURN(major);
785         }
786
787         if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
788                 CERROR("discard request %p with old seq_num %u\n",
789                         req, gc->gc_seq);
790                 RETURN(GSS_S_DUPLICATE_TOKEN);
791         }
792
793         req->rq_reqmsg = (struct lustre_msg *) (vp);
794         req->rq_reqlen = plain_text.len;
795
796         CDEBUG(D_SEC, "msg len %d\n", req->rq_reqlen);
797
798         RETURN(GSS_S_COMPLETE);
799 }
800
801 static int
802 gss_pack_err_notify(struct ptlrpc_request *req,
803                     __u32 major, __u32 minor)
804 {
805         struct gss_svc_data *svcdata = req->rq_sec_svcdata;
806         __u32 reslen, *resp, *reslenp;
807         char  nidstr[PTL_NALFMT_SIZE];
808         const __u32 secdata_len = 7 * 4;
809         int rc;
810         ENTRY;
811
812         OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_ERR_NOTIFY|OBD_FAIL_ONCE, -EINVAL);
813
814         LASSERT(svcdata);
815         svcdata->is_err_notify = 1;
816         svcdata->reserve_len = 7 * 4;
817
818         rc = lustre_pack_reply(req, 0, NULL, NULL);
819         if (rc) {
820                 CERROR("could not pack reply, err %d\n", rc);
821                 RETURN(rc);
822         }
823
824         LASSERT(req->rq_reply_state);
825         LASSERT(req->rq_reply_state->rs_repbuf);
826         LASSERT(req->rq_reply_state->rs_repbuf_len >= secdata_len);
827         resp = (__u32 *) req->rq_reply_state->rs_repbuf;
828
829         /* header */
830         *resp++ = cpu_to_le32(PTLRPC_SEC_GSS);
831         *resp++ = cpu_to_le32(PTLRPC_SEC_TYPE_NONE);
832         *resp++ = cpu_to_le32(req->rq_replen);
833         reslenp = resp++;
834
835         /* skip lustre msg */
836         resp += req->rq_replen / 4;
837         reslen = svcdata->reserve_len;
838
839         /* gss replay:
840          * version, subflavor, notify, major, minor,
841          * obj1(fake), obj2(fake)
842          */
843         *resp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
844         *resp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
845         *resp++ = cpu_to_le32(PTLRPC_GSS_PROC_ERR);
846         *resp++ = cpu_to_le32(major);
847         *resp++ = cpu_to_le32(minor);
848         *resp++ = 0;
849         *resp++ = 0;
850         reslen -= (4 * 4);
851         /* the actual sec data length */
852         *reslenp = cpu_to_le32(secdata_len);
853
854         req->rq_reply_state->rs_repdata_len += (secdata_len);
855         CWARN("prepare gss error notify(0x%x/0x%x) to %s\n", major, minor,
856                portals_nid2str(req->rq_peer.peer_ni->pni_number,
857                                req->rq_peer.peer_id.nid, nidstr));
858         RETURN(0);
859 }
860
861 static int
862 gss_svcsec_handle_init(struct ptlrpc_request *req,
863                        struct rpc_gss_wire_cred *gc,
864                        __u32 *secdata, __u32 seclen,
865                        enum ptlrpcs_error *res)
866 {
867         struct gss_svc_data *svcdata = req->rq_sec_svcdata;
868         struct rsc          *rsci;
869         struct rsi          *rsikey, *rsip;
870         rawobj_t             tmpobj;
871         __u32 reslen,       *resp, *reslenp;
872         char                 nidstr[PTL_NALFMT_SIZE];
873         int                  rc;
874         ENTRY;
875
876         LASSERT(svcdata);
877
878         CWARN("processing gss init(%d) request from %s\n", gc->gc_proc,
879                portals_nid2str(req->rq_peer.peer_ni->pni_number,
880                                req->rq_peer.peer_id.nid, nidstr));
881
882         *res = PTLRPCS_BADCRED;
883         OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_INIT_REQ|OBD_FAIL_ONCE, SVC_DROP);
884
885         if (gc->gc_proc == RPC_GSS_PROC_INIT &&
886             gc->gc_ctx.len != 0) {
887                 CERROR("proc %d, ctx_len %d: not really init?\n",
888                 gc->gc_proc == RPC_GSS_PROC_INIT, gc->gc_ctx.len);
889                 RETURN(SVC_DROP);
890         }
891
892         OBD_ALLOC(rsikey, sizeof(*rsikey));
893         if (!rsikey) {
894                 CERROR("out of memory\n");
895                 RETURN(SVC_DROP);
896         }
897         cache_init(&rsikey->h);
898
899         if (rawobj_dup(&rsikey->in_handle, &gc->gc_ctx)) {
900                 CERROR("fail to dup context handle\n");
901                 GOTO(out_rsikey, rc = SVC_DROP);
902         }
903         *res = PTLRPCS_BADVERF;
904         if (rawobj_extract(&tmpobj, &secdata, &seclen)) {
905                 CERROR("can't extract token\n");
906                 GOTO(out_rsikey, rc = SVC_DROP);
907         }
908         if (rawobj_dup(&rsikey->in_token, &tmpobj)) {
909                 CERROR("can't duplicate token\n");
910                 GOTO(out_rsikey, rc = SVC_DROP);
911         }
912
913         rsikey->naltype = (__u32) req->rq_peer.peer_ni->pni_number;
914         rsikey->netid = 0;
915         rsikey->nid = (__u64) req->rq_peer.peer_id.nid;
916
917         rsip = gssd_upcall(rsikey, &my_chandle);
918         if (!rsip) {
919                 CERROR("error in gssd_upcall.\n");
920                 GOTO(out_rsikey, rc = SVC_DROP);
921         }
922
923         rsci = gss_svc_searchbyctx(&rsip->out_handle);
924         if (!rsci) {
925                 CERROR("rsci still not mature yet?\n");
926
927                 if (gss_pack_err_notify(req, GSS_S_FAILURE, 0))
928                         rc = SVC_DROP;
929                 else
930                         rc = SVC_COMPLETE;
931
932                 GOTO(out_rsip, rc);
933         }
934         CWARN("svcsec create gss context %p(%u@%s)\n",
935                rsci, rsci->cred.vc_uid,
936                portals_nid2str(req->rq_peer.peer_ni->pni_number,
937                                req->rq_peer.peer_id.nid, nidstr));
938
939         svcdata->is_init = 1;
940         svcdata->reserve_len = 6 * 4 +
941                 size_round4(rsip->out_handle.len) +
942                 size_round4(rsip->out_token.len);
943
944         rc = lustre_pack_reply(req, 0, NULL, NULL);
945         if (rc) {
946                 CERROR("failed to pack reply, rc = %d\n", rc);
947                 GOTO(out, rc = SVC_DROP);
948         }
949
950         /* header */
951         resp = (__u32 *) req->rq_reply_state->rs_repbuf;
952         *resp++ = cpu_to_le32(PTLRPC_SEC_GSS);
953         *resp++ = cpu_to_le32(PTLRPC_SEC_TYPE_NONE);
954         *resp++ = cpu_to_le32(req->rq_replen);
955         reslenp = resp++;
956
957         resp += req->rq_replen / 4;
958         reslen = svcdata->reserve_len;
959
960         /* gss reply:
961          * status, major, minor, seq, out_handle, out_token
962          */
963         *resp++ = cpu_to_le32(PTLRPCS_OK);
964         *resp++ = cpu_to_le32(rsip->major_status);
965         *resp++ = cpu_to_le32(rsip->minor_status);
966         *resp++ = cpu_to_le32(GSS_SEQ_WIN);
967         reslen -= (4 * 4);
968         if (rawobj_serialize(&rsip->out_handle,
969                              &resp, &reslen))
970                 LBUG();
971         if (rawobj_serialize(&rsip->out_token,
972                              &resp, &reslen))
973                 LBUG();
974         /* the actual sec data length */
975         *reslenp = cpu_to_le32(svcdata->reserve_len - reslen);
976
977         req->rq_reply_state->rs_repdata_len += le32_to_cpu(*reslenp);
978         CDEBUG(D_SEC, "req %p: msgsize %d, authsize %d, "
979                "total size %d\n", req, req->rq_replen,
980                le32_to_cpu(*reslenp),
981                req->rq_reply_state->rs_repdata_len);
982
983         *res = PTLRPCS_OK;
984
985         req->rq_auth_uid = rsci->cred.vc_uid;
986         req->rq_remote_realm = rsci->remote_realm;
987         req->rq_mapped_uid = rsci->mapped_uid;
988
989         /* This is simplified since right now we doesn't support
990          * INIT_CONTINUE yet.
991          */
992         if (gc->gc_proc == RPC_GSS_PROC_INIT) {
993                 struct ptlrpcs_wire_hdr *hdr;
994
995                 hdr = buf_to_sec_hdr(req->rq_reqbuf);
996                 req->rq_reqmsg = buf_to_lustre_msg(req->rq_reqbuf);
997                 req->rq_reqlen = hdr->msg_len;
998
999                 rc = SVC_LOGIN;
1000         } else
1001                 rc = SVC_COMPLETE;
1002
1003 out:
1004         rsc_put(&rsci->h, &rsc_cache);
1005 out_rsip:
1006         rsi_put(&rsip->h, &rsi_cache);
1007 out_rsikey:
1008         rsi_put(&rsikey->h, &rsi_cache);
1009
1010         RETURN(rc);
1011 }
1012
1013 static int
1014 gss_svcsec_handle_data(struct ptlrpc_request *req,
1015                        struct rpc_gss_wire_cred *gc,
1016                        __u32 *secdata, __u32 seclen,
1017                        enum ptlrpcs_error *res)
1018 {
1019         struct rsc          *rsci;
1020         char                 nidstr[PTL_NALFMT_SIZE];
1021         __u32                major;
1022         int                  rc;
1023         ENTRY;
1024
1025         *res = PTLRPCS_GSS_CREDPROBLEM;
1026
1027         rsci = gss_svc_searchbyctx(&gc->gc_ctx);
1028         if (!rsci) {
1029                 CWARN("Invalid gss context handle from %s\n",
1030                        portals_nid2str(req->rq_peer.peer_ni->pni_number,
1031                                        req->rq_peer.peer_id.nid, nidstr));
1032                 major = GSS_S_NO_CONTEXT;
1033                 goto notify_err;
1034         }
1035
1036         switch (gc->gc_svc) {
1037         case PTLRPC_GSS_SVC_INTEGRITY:
1038                 major = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
1039                 if (major == GSS_S_COMPLETE)
1040                         break;
1041
1042                 CWARN("fail in verify:0x%x: ctx %p@%s\n", major, rsci,
1043                        portals_nid2str(req->rq_peer.peer_ni->pni_number,
1044                                        req->rq_peer.peer_id.nid, nidstr));
1045                 goto notify_err;
1046         case PTLRPC_GSS_SVC_PRIVACY:
1047                 major = gss_svc_unseal_request(req, rsci, gc, secdata, seclen);
1048                 if (major == GSS_S_COMPLETE)
1049                         break;
1050
1051                 CWARN("fail in decrypt:0x%x: ctx %p@%s\n", major, rsci,
1052                        portals_nid2str(req->rq_peer.peer_ni->pni_number,
1053                                        req->rq_peer.peer_id.nid, nidstr));
1054                 goto notify_err;
1055         default:
1056                 CERROR("unsupported gss service %d\n", gc->gc_svc);
1057                 GOTO(out, rc = SVC_DROP);
1058         }
1059
1060         req->rq_auth_uid = rsci->cred.vc_uid;
1061         req->rq_remote_realm = rsci->remote_realm;
1062         req->rq_mapped_uid = rsci->mapped_uid;
1063
1064         *res = PTLRPCS_OK;
1065         GOTO(out, rc = SVC_OK);
1066
1067 notify_err:
1068         if (gss_pack_err_notify(req, major, 0))
1069                 rc = SVC_DROP;
1070         else
1071                 rc = SVC_COMPLETE;
1072 out:
1073         if (rsci)
1074                 rsc_put(&rsci->h, &rsc_cache);
1075         RETURN(rc);
1076 }
1077
1078 static int
1079 gss_svcsec_handle_destroy(struct ptlrpc_request *req,
1080                           struct rpc_gss_wire_cred *gc,
1081                           __u32 *secdata, __u32 seclen,
1082                           enum ptlrpcs_error *res)
1083 {
1084         struct gss_svc_data *svcdata = req->rq_sec_svcdata;
1085         struct rsc          *rsci;
1086         char                 nidstr[PTL_NALFMT_SIZE];
1087         int                  rc;
1088         ENTRY;
1089
1090         LASSERT(svcdata);
1091         *res = PTLRPCS_GSS_CREDPROBLEM;
1092
1093         rsci = gss_svc_searchbyctx(&gc->gc_ctx);
1094         if (!rsci) {
1095                 CWARN("invalid gss context handle for destroy.\n");
1096                 RETURN(SVC_DROP);
1097         }
1098
1099         if (gc->gc_svc != PTLRPC_GSS_SVC_INTEGRITY) {
1100                 CERROR("service %d is not supported in destroy.\n",
1101                         gc->gc_svc);
1102                 GOTO(out, rc = SVC_DROP);
1103         }
1104
1105         *res = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
1106         if (*res)
1107                 GOTO(out, rc = SVC_DROP);
1108
1109         /* compose reply, which is actually nothing */
1110         svcdata->is_fini = 1;
1111         if (lustre_pack_reply(req, 0, NULL, NULL))
1112                 GOTO(out, rc = SVC_DROP);
1113
1114         CWARN("svcsec destroy gss context %p(%u@%s)\n",
1115                rsci, rsci->cred.vc_uid,
1116                portals_nid2str(req->rq_peer.peer_ni->pni_number,
1117                                req->rq_peer.peer_id.nid, nidstr));
1118
1119         set_bit(CACHE_NEGATIVE, &rsci->h.flags);
1120         *res = PTLRPCS_OK;
1121         rc = SVC_LOGOUT;
1122 out:
1123         rsc_put(&rsci->h, &rsc_cache);
1124         RETURN(rc);
1125 }
1126
1127 /*
1128  * let incomming request go through security check:
1129  *  o context establishment: invoke user space helper
1130  *  o data exchange: verify/decrypt
1131  *  o context destruction: mark context invalid
1132  *
1133  * in most cases, error will result to drop the packet silently.
1134  */
/*
 * Entry point for incoming-request security processing: parse the wire
 * gss header out of the request buffer and dispatch to the
 * init/data/destroy handler according to gc_proc.  On SVC_DROP the
 * per-request svcdata allocated here is freed again.
 */
static int
gss_svcsec_accept(struct ptlrpc_request *req, enum ptlrpcs_error *res)
{
        struct gss_svc_data *svcdata;
        struct rpc_gss_wire_cred *gc;
        struct ptlrpcs_wire_hdr *sec_hdr;
        __u32 seclen, *secdata, version, subflavor;
        int rc;
        ENTRY;

        CDEBUG(D_SEC, "request %p\n", req);
        LASSERT(req->rq_reqbuf);
        LASSERT(req->rq_reqbuf_len);

        *res = PTLRPCS_BADCRED;

        sec_hdr = buf_to_sec_hdr(req->rq_reqbuf);
        LASSERT(sec_hdr->flavor == PTLRPC_SEC_GSS);

        /* maximum room available for security data after the lustre msg.
         * NOTE(review): sec_hdr->msg_len/sec_len are read without
         * le32_to_cpu here, unlike the reply path which writes
         * cpu_to_le32 — presumably the header was already byte-swapped
         * on receipt; confirm against the request-unpack code. */
        seclen = req->rq_reqbuf_len - sizeof(*sec_hdr) - sec_hdr->msg_len;
        secdata = (__u32 *) buf_to_sec_data(req->rq_reqbuf);

        if (sec_hdr->sec_len > seclen) {
                CERROR("seclen %d, while max buf %d\n",
                        sec_hdr->sec_len, seclen);
                RETURN(SVC_DROP);
        }

        /* need at least the 5-word gss header plus the length word of
         * the context-handle rawobj */
        if (seclen < 6 * 4) {
                CERROR("sec size %d too small\n", seclen);
                RETURN(SVC_DROP);
        }

        /* per-request gss state, freed in cleanup_req (or below on drop) */
        LASSERT(!req->rq_sec_svcdata);
        OBD_ALLOC(svcdata, sizeof(*svcdata));
        if (!svcdata) {
                CERROR("fail to alloc svcdata\n");
                RETURN(SVC_DROP);
        }
        req->rq_sec_svcdata = svcdata;
        gc = &svcdata->clcred;

        /* Now secdata/seclen is what we want to parse
         */
        version = le32_to_cpu(*secdata++);      /* version */
        subflavor = le32_to_cpu(*secdata++);    /* subflavor */
        gc->gc_proc = le32_to_cpu(*secdata++);  /* proc */
        gc->gc_seq = le32_to_cpu(*secdata++);   /* seq */
        gc->gc_svc = le32_to_cpu(*secdata++);   /* service */
        seclen -= 5 * 4;

        CDEBUG(D_SEC, "wire gss_hdr: %u/%u/%u/%u/%u\n",
               version, subflavor, gc->gc_proc, gc->gc_seq, gc->gc_svc);

        if (version != PTLRPC_SEC_GSS_VERSION) {
                CERROR("gss version mismatch: %d - %d\n",
                        version, PTLRPC_SEC_GSS_VERSION);
                GOTO(err_free, rc = SVC_DROP);
        }

        /* gc_ctx points into the request buffer — not a copy, so it must
         * not be freed in cleanup_req */
        if (rawobj_extract(&gc->gc_ctx, &secdata, &seclen)) {
                CERROR("fail to obtain gss context handle\n");
                GOTO(err_free, rc = SVC_DROP);
        }

        *res = PTLRPCS_BADVERF;
        switch(gc->gc_proc) {
        case RPC_GSS_PROC_INIT:
        case RPC_GSS_PROC_CONTINUE_INIT:
                rc = gss_svcsec_handle_init(req, gc, secdata, seclen, res);
                break;
        case RPC_GSS_PROC_DATA:
                rc = gss_svcsec_handle_data(req, gc, secdata, seclen, res);
                break;
        case RPC_GSS_PROC_DESTROY:
                rc = gss_svcsec_handle_destroy(req, gc, secdata, seclen, res);
                break;
        default:
                rc = SVC_DROP;
                LBUG();
        }

err_free:
        /* drop path: release the svcdata allocated above */
        if (rc == SVC_DROP && req->rq_sec_svcdata) {
                OBD_FREE(req->rq_sec_svcdata, sizeof(struct gss_svc_data));
                req->rq_sec_svcdata = NULL;
        }

        RETURN(rc);
}
1225
/*
 * Secure the outgoing reply before it is sent: for the integrity
 * service compute a MIC over the lustre reply message and append it;
 * for the privacy service encrypt (wrap) the whole message into the
 * reply buffer.  Init/fini/err replies were already fully packed by
 * their handlers and are passed through untouched.
 */
static int
gss_svcsec_authorize(struct ptlrpc_request *req)
{
        struct ptlrpc_reply_state *rs = req->rq_reply_state;
        struct gss_svc_data *gsd = (struct gss_svc_data *)req->rq_sec_svcdata;
        struct rpc_gss_wire_cred  *gc = &gsd->clcred;
        struct rsc                *rscp;
        struct ptlrpcs_wire_hdr   *sec_hdr;
        rawobj_buf_t               msg_buf;
        rawobj_t                   cipher_buf;
        __u32                     *vp, *vpsave, major, vlen, seclen;
        rawobj_t                   lmsg, mic;
        int                        ret;
        ENTRY;

        LASSERT(rs);
        LASSERT(rs->rs_repbuf);
        LASSERT(gsd);

        if (gsd->is_init || gsd->is_init_continue ||
            gsd->is_err_notify || gsd->is_fini) {
                /* nothing to do in these cases */
                CDEBUG(D_SEC, "req %p: init/fini/err\n", req);
                RETURN(0);
        }

        if (gc->gc_proc != RPC_GSS_PROC_DATA) {
                CERROR("proc %d not support\n", gc->gc_proc);
                RETURN(-EINVAL);
        }

        rscp = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rscp) {
                CERROR("ctx disapeared under us?\n");
                RETURN(-EINVAL);
        }

        sec_hdr = (struct ptlrpcs_wire_hdr *) rs->rs_repbuf;
        switch (gc->gc_svc) {
        case  PTLRPC_GSS_SVC_INTEGRITY:
                /* prepare various pointers: lustre msg sits right after
                 * the wire header, verifier data after the msg */
                lmsg.len = req->rq_replen;
                lmsg.data = (__u8 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vp = (__u32 *) (lmsg.data + lmsg.len);
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr) - lmsg.len;
                seclen = vlen;

                sec_hdr->flavor = cpu_to_le32(PTLRPC_SEC_GSS);
                sec_hdr->sectype = cpu_to_le32(PTLRPC_SEC_TYPE_AUTH);
                sec_hdr->msg_len = cpu_to_le32(req->rq_replen);

                /* standard gss hdr.
                 * NOTE(review): subflavor is written as KRB5I here for
                 * both services; like the zero ctx handle it appears to
                 * be a placeholder the client ignores — confirm. */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPC_GSS_SVC_INTEGRITY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                /* MIC is generated into the space after the gss header */
                mic.len = vlen;
                mic.data = (char *) vp;

                major = kgss_get_mic(rscp->mechctx, 0, &lmsg, &mic);
                if (major) {
                        CERROR("fail to get MIC: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }
                *vpsave = cpu_to_le32(mic.len);
                /* actual sec payload = header words + real MIC length */
                seclen = seclen - vlen + mic.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        case  PTLRPC_GSS_SVC_PRIVACY:
                /* everything after the wire header is sec data; the
                 * cleartext msg lives in the separately allocated buffer
                 * (see gss_svcsec_alloc_repbuf) */
                vp = (__u32 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr);
                seclen = vlen;

                sec_hdr->flavor = cpu_to_le32(PTLRPC_SEC_GSS);
                sec_hdr->sectype = cpu_to_le32(PTLRPC_SEC_TYPE_PRIV);
                sec_hdr->msg_len = cpu_to_le32(0);

                /* standard gss hdr */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPC_GSS_SVC_PRIVACY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                /* cleartext buffer includes prefix/suffix scratch space
                 * the wrap operation needs around the message */
                msg_buf.buf = (__u8 *) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.buflen = req->rq_replen + GSS_PRIVBUF_PREFIX_LEN +
                                 GSS_PRIVBUF_SUFFIX_LEN;
                msg_buf.dataoff = GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.datalen = req->rq_replen;

                cipher_buf.data = (__u8 *) vp;
                cipher_buf.len = vlen;

                major = kgss_wrap(rscp->mechctx, GSS_C_QOP_DEFAULT,
                                &msg_buf, &cipher_buf);
                if (major) {
                        CERROR("failed to wrap: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }

                *vpsave = cpu_to_le32(cipher_buf.len);
                /* actual sec payload = header words + ciphertext length */
                seclen = seclen - vlen + cipher_buf.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        default:
                CERROR("Unknown service %d\n", gc->gc_svc);
                GOTO(out, ret = -EINVAL);
        }
        ret = 0;
out:
        rsc_put(&rscp->h, &rsc_cache);

        RETURN(ret);
}
1352
1353 static
1354 void gss_svcsec_cleanup_req(struct ptlrpc_svcsec *svcsec,
1355                             struct ptlrpc_request *req)
1356 {
1357         struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_sec_svcdata;
1358
1359         if (!gsd) {
1360                 CDEBUG(D_SEC, "no svc_data present. do nothing\n");
1361                 return;
1362         }
1363
1364         /* gsd->clclred.gc_ctx is NOT allocated, just set pointer
1365          * to the incoming packet buffer, so don't need free it
1366          */
1367         OBD_FREE(gsd, sizeof(*gsd));
1368         req->rq_sec_svcdata = NULL;
1369         return;
1370 }
1371
1372 static
1373 int gss_svcsec_est_payload(struct ptlrpc_svcsec *svcsec,
1374                            struct ptlrpc_request *req,
1375                            int msgsize)
1376 {
1377         struct gss_svc_data *svcdata = req->rq_sec_svcdata;
1378         ENTRY;
1379
1380         /* just return the pre-set reserve_len for init/fini/err cases.
1381          */
1382         LASSERT(svcdata);
1383         if (svcdata->is_init) {
1384                 CDEBUG(D_SEC, "is_init, reserver size %d(%d)\n",
1385                        size_round(svcdata->reserve_len),
1386                        svcdata->reserve_len);
1387                 LASSERT(svcdata->reserve_len);
1388                 LASSERT(svcdata->reserve_len % 4 == 0);
1389                 RETURN(size_round(svcdata->reserve_len));
1390         } else if (svcdata->is_err_notify) {
1391                 CDEBUG(D_SEC, "is_err_notify, reserver size %d(%d)\n",
1392                        size_round(svcdata->reserve_len),
1393                        svcdata->reserve_len);
1394                 RETURN(size_round(svcdata->reserve_len));
1395         } else if (svcdata->is_fini) {
1396                 CDEBUG(D_SEC, "is_fini, reserver size 0\n");
1397                 RETURN(0);
1398         } else {
1399                 if (svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_NONE ||
1400                     svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_INTEGRITY)
1401                         RETURN(size_round(GSS_MAX_AUTH_PAYLOAD));
1402                 else if (svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_PRIVACY)
1403                         RETURN(size_round16(GSS_MAX_AUTH_PAYLOAD + msgsize +
1404                                             GSS_PRIVBUF_PREFIX_LEN +
1405                                             GSS_PRIVBUF_SUFFIX_LEN));
1406                 else {
1407                         CERROR("unknown gss svc %u\n", svcdata->clcred.gc_svc);
1408                         *((int *)0) = 0;
1409                         LBUG();
1410                 }
1411         }
1412         RETURN(0);
1413 }
1414
1415 int gss_svcsec_alloc_repbuf(struct ptlrpc_svcsec *svcsec,
1416                             struct ptlrpc_request *req,
1417                             int msgsize)
1418 {
1419         struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_sec_svcdata;
1420         struct ptlrpc_reply_state *rs;
1421         int msg_payload, sec_payload;
1422         int privacy, rc;
1423         ENTRY;
1424
1425         /* determine the security type: none/auth or priv, we have
1426          * different pack scheme for them.
1427          * init/fini/err will always be treated as none/auth.
1428          */
1429         LASSERT(gsd);
1430         if (!gsd->is_init && !gsd->is_init_continue &&
1431             !gsd->is_fini && !gsd->is_err_notify &&
1432             gsd->clcred.gc_svc == PTLRPC_GSS_SVC_PRIVACY)
1433                 privacy = 1;
1434         else
1435                 privacy = 0;
1436
1437         msg_payload = privacy ? 0 : msgsize;
1438         sec_payload = gss_svcsec_est_payload(svcsec, req, msgsize);
1439
1440         rc = svcsec_alloc_reply_state(req, msg_payload, sec_payload);
1441         if (rc)
1442                 RETURN(rc);
1443
1444         rs = req->rq_reply_state;
1445         LASSERT(rs);
1446         rs->rs_msg_len = msgsize;
1447
1448         if (privacy) {
1449                 /* we can choose to let msg simply point to the rear of the
1450                  * buffer, which lead to buffer overlap when doing encryption.
1451                  * usually it's ok and it indeed passed all existing tests.
1452                  * but not sure if there will be subtle problems in the future.
1453                  * so right now we choose to alloc another new buffer. we'll
1454                  * see how it works.
1455                  */
1456 #if 0
1457                 rs->rs_msg = (struct lustre_msg *)
1458                              (rs->rs_repbuf + rs->rs_repbuf_len -
1459                               msgsize - GSS_PRIVBUF_SUFFIX_LEN);
1460 #endif
1461                 char *msgbuf;
1462
1463                 msgsize += GSS_PRIVBUF_PREFIX_LEN + GSS_PRIVBUF_SUFFIX_LEN;
1464                 OBD_ALLOC(msgbuf, msgsize);
1465                 if (!msgbuf) {
1466                         CERROR("can't alloc %d\n", msgsize);
1467                         svcsec_free_reply_state(rs);
1468                         req->rq_reply_state = NULL;
1469                         RETURN(-ENOMEM);
1470                 }
1471                 rs->rs_msg = (struct lustre_msg *)
1472                                 (msgbuf + GSS_PRIVBUF_PREFIX_LEN);
1473         }
1474
1475         req->rq_repmsg = rs->rs_msg;
1476
1477         RETURN(0);
1478 }
1479
1480 static
1481 void gss_svcsec_free_repbuf(struct ptlrpc_svcsec *svcsec,
1482                             struct ptlrpc_reply_state *rs)
1483 {
1484         unsigned long p1 = (unsigned long) rs->rs_msg;
1485         unsigned long p2 = (unsigned long) rs->rs_buf;
1486
1487         LASSERT(rs->rs_buf);
1488         LASSERT(rs->rs_msg);
1489         LASSERT(rs->rs_msg_len);
1490
1491         if (p1 < p2 || p1 >= p2 + rs->rs_buf_len) {
1492                 char *start = (char*) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
1493                 int size = rs->rs_msg_len + GSS_PRIVBUF_PREFIX_LEN +
1494                            GSS_PRIVBUF_SUFFIX_LEN;
1495                 OBD_FREE(start, size);
1496         }
1497
1498         svcsec_free_reply_state(rs);
1499 }
1500
/* server-side gss security policy: callback table registered with the
 * svcsec framework in gss_svc_init() */
struct ptlrpc_svcsec svcsec_gss = {
        .pss_owner              = THIS_MODULE,
        .pss_name               = "GSS_SVCSEC",
        .pss_flavor             = {PTLRPC_SEC_GSS, 0},
        .accept                 = gss_svcsec_accept,    /* inbound check */
        .authorize              = gss_svcsec_authorize, /* secure reply */
        .alloc_repbuf           = gss_svcsec_alloc_repbuf,
        .free_repbuf            = gss_svcsec_free_repbuf,
        .cleanup_req            = gss_svcsec_cleanup_req,
};
1511
/* XXX hacking */
/* Purge every entry from both server-side gss caches: the init (rsi)
 * upcall cache and the established-context (rsc) cache. */
void lgss_svc_cache_purge_all(void)
{
        cache_purge(&rsi_cache);
        cache_purge(&rsc_cache);
}
EXPORT_SYMBOL(lgss_svc_cache_purge_all);
1519
/* Flush cached gss contexts belonging to the given uid from the rsc
 * cache. */
void lgss_svc_cache_flush(__u32 uid)
{
        rsc_flush(uid);
}
EXPORT_SYMBOL(lgss_svc_cache_flush);
1525
1526 int gss_svc_init(void)
1527 {
1528         int rc;
1529
1530         rc = svcsec_register(&svcsec_gss);
1531         if (!rc) {
1532                 cache_register(&rsc_cache);
1533                 cache_register(&rsi_cache);
1534         }
1535         return rc;
1536 }
1537
1538 void gss_svc_exit(void)
1539 {
1540         int rc;
1541         if ((rc = cache_unregister(&rsi_cache)))
1542                 CERROR("unregister rsi cache: %d\n", rc);
1543         if ((rc = cache_unregister(&rsc_cache)))
1544                 CERROR("unregister rsc cache: %d\n", rc);
1545         if ((rc = svcsec_unregister(&svcsec_gss)))
1546                 CERROR("unregister svcsec_gss: %d\n", rc);
1547 }