Whamcloud - gitweb
1ac060e759a54cf434fbce1266c4534823f99d24
[fs/lustre-release.git] / lustre / sec / gss / svcsec_gss.c
1 /* -*- mode: c; c-basic-offset: 8; indent-tabs-mode: nil; -*-
2  * vim:expandtab:shiftwidth=8:tabstop=8:
3  *
4  * Modifications for Lustre
5  * Copyright 2004, Cluster File Systems, Inc.
6  * All rights reserved
7  * Author: Eric Mei <ericm@clusterfs.com>
8  */
9
10 /*
11  * Neil Brown <neilb@cse.unsw.edu.au>
12  * J. Bruce Fields <bfields@umich.edu>
13  * Andy Adamson <andros@umich.edu>
14  * Dug Song <dugsong@monkey.org>
15  *
16  * RPCSEC_GSS server authentication.
17  * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078
18  * (gssapi)
19  *
20  * The RPCSEC_GSS involves three stages:
21  *  1/ context creation
22  *  2/ data exchange
23  *  3/ context destruction
24  *
25  * Context creation is handled largely by upcalls to user-space.
26  *  In particular, GSS_Accept_sec_context is handled by an upcall
27  * Data exchange is handled entirely within the kernel
28  *  In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel.
29  * Context destruction is handled in-kernel
30  *  GSS_Delete_sec_context is in-kernel
31  *
32  * Context creation is initiated by a RPCSEC_GSS_INIT request arriving.
33  * The context handle and gss_token are used as a key into the rpcsec_init cache.
34  * The content of this cache includes some of the outputs of GSS_Accept_sec_context,
35  * being major_status, minor_status, context_handle, reply_token.
36  * These are sent back to the client.
 * Sequence window management is handled by the kernel.  The window size is
 * currently a compile time constant.
39  *
40  * When user-space is happy that a context is established, it places an entry
41  * in the rpcsec_context cache. The key for this cache is the context_handle.
42  * The content includes:
43  *   uid/gidlist - for determining access rights
44  *   mechanism type
45  *   mechanism specific information, such as a key
46  *
47  */
48
49 #define DEBUG_SUBSYSTEM S_SEC
50 #ifdef __KERNEL__
51 #include <linux/types.h>
52 #include <linux/init.h>
53 #include <linux/module.h>
54 #include <linux/slab.h>
55 #include <linux/hash.h>
56 #else
57 #include <liblustre.h>
58 #endif
59
60 #include <linux/sunrpc/cache.h>
61
62 #include <libcfs/kp30.h>
63 #include <linux/obd.h>
64 #include <linux/obd_class.h>
65 #include <linux/obd_support.h>
66 #include <linux/lustre_idl.h>
67 #include <linux/lustre_net.h>
68 #include <linux/lustre_import.h>
69 #include <linux/lustre_sec.h>
70                                                                                                                         
71 #include "gss_err.h"
72 #include "gss_internal.h"
73 #include "gss_api.h"
74
75 static inline unsigned long hash_mem(char *buf, int length, int bits)
76 {
77         unsigned long hash = 0;
78         unsigned long l = 0;
79         int len = 0;
80         unsigned char c;
81         do {
82                 if (len == length) {
83                         c = (char)len; len = -1;
84                 } else
85                         c = *buf++;
86                 l = (l << 8) | c;
87                 len++;
88                 if ((len & (BITS_PER_LONG/8-1))==0)
89                         hash = hash_long(hash^l, BITS_PER_LONG);
90         } while (len);
91         return hash >> (BITS_PER_LONG - bits);
92 }
93
94 /* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests
95  * into replies.
96  *
97  * Key is context handle (\x if empty) and gss_token.
98  * Content is major_status minor_status (integers) context_handle, reply_token.
99  *
100  */
101
102 #define RSI_HASHBITS    6
103 #define RSI_HASHMAX     (1<<RSI_HASHBITS)
104 #define RSI_HASHMASK    (RSI_HASHMAX-1)
105
106 struct rsi {
107         struct cache_head       h;
108         rawobj_t                in_handle, in_token;
109         rawobj_t                out_handle, out_token;
110         int                     major_status, minor_status;
111 };
112
113 static struct cache_head *rsi_table[RSI_HASHMAX];
114 static struct cache_detail rsi_cache;
115
116 static void rsi_free(struct rsi *rsii)
117 {
118         rawobj_free(&rsii->in_handle);
119         rawobj_free(&rsii->in_token);
120         rawobj_free(&rsii->out_handle);
121         rawobj_free(&rsii->out_token);
122 }
123
124 static void rsi_put(struct cache_head *item, struct cache_detail *cd)
125 {
126         struct rsi *rsii = container_of(item, struct rsi, h);
127         if (cache_put(item, cd)) {
128                 rsi_free(rsii);
129                 OBD_FREE(rsii, sizeof(*rsii));
130         }
131 }
132
133 static inline int rsi_hash(struct rsi *item)
134 {
135         return hash_mem(item->in_handle.data, item->in_handle.len, RSI_HASHBITS)
136               ^ hash_mem(item->in_token.data, item->in_token.len, RSI_HASHBITS);
137 }
138
139 static inline int rsi_match(struct rsi *item, struct rsi *tmp)
140 {
141         return (rawobj_equal(&item->in_handle, &tmp->in_handle) &&
142                 rawobj_equal(&item->in_token, &tmp->in_token));
143 }
144
145 static void rsi_request(struct cache_detail *cd,
146                         struct cache_head *h,
147                         char **bpp, int *blen)
148 {
149         struct rsi *rsii = container_of(h, struct rsi, h);
150
151         qword_addhex(bpp, blen, rsii->in_handle.data, rsii->in_handle.len);
152         qword_addhex(bpp, blen, rsii->in_token.data, rsii->in_token.len);
153         (*bpp)[-1] = '\n';
154 }
155
156 static int
157 gssd_reply(struct rsi *item)
158 {
159         struct rsi *tmp;
160         struct cache_head **hp, **head;
161         ENTRY;
162
163         head = &rsi_cache.hash_table[rsi_hash(item)];
164         write_lock(&rsi_cache.hash_lock);
165         for (hp = head; *hp != NULL; hp = &tmp->h.next) {
166                 tmp = container_of(*hp, struct rsi, h);
167                 if (rsi_match(tmp, item)) {
168                         cache_get(&tmp->h);
169                         clear_bit(CACHE_HASHED, &tmp->h.flags);
170                         *hp = tmp->h.next;
171                         tmp->h.next = NULL;
172                         rsi_cache.entries--;
173                         if (test_bit(CACHE_VALID, &tmp->h.flags)) {
174                                 write_unlock(&rsi_cache.hash_lock);
175                                 rsi_put(&tmp->h, &rsi_cache);
176                                 RETURN(-EINVAL);
177                         }
178                         set_bit(CACHE_HASHED, &item->h.flags);
179                         item->h.next = *hp;
180                         *hp = &item->h;
181                         rsi_cache.entries++;
182                         set_bit(CACHE_VALID, &item->h.flags);
183                         item->h.last_refresh = get_seconds();
184                         write_unlock(&rsi_cache.hash_lock);
185                         cache_fresh(&rsi_cache, &tmp->h, 0);
186                         rsi_put(&tmp->h, &rsi_cache);
187                         RETURN(0);
188                 }
189         }
190         write_unlock(&rsi_cache.hash_lock);
191         RETURN(-EINVAL);
192 }
193
/* XXX
 * here we just wait for its completion or a timeout. it's a
 * hack but it works, and we'll come up with a real fix if we decide
 * to stick with the NFS4 cache code
 */
199 static struct rsi *
200 gssd_upcall(struct rsi *item, struct cache_req *chandle)
201 {
202         struct rsi *tmp;
203         struct cache_head **hp, **head;
204         unsigned long starttime;
205         ENTRY;
206
207         head = &rsi_cache.hash_table[rsi_hash(item)];
208         read_lock(&rsi_cache.hash_lock);
209         for (hp = head; *hp != NULL; hp = &tmp->h.next) {
210                 tmp = container_of(*hp, struct rsi, h);
211                 if (rsi_match(tmp, item)) {
212                         LBUG();
213                         if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
214                                 CERROR("found rsi without VALID\n");
215                                 read_unlock(&rsi_cache.hash_lock);
216                                 return NULL;
217                         }
218                         *hp = tmp->h.next;
219                         tmp->h.next = NULL;
220                         rsi_cache.entries--;
221                         cache_get(&tmp->h);
222                         read_unlock(&rsi_cache.hash_lock);
223                         return tmp;
224                 }
225         }
226         // cache_get(&item->h);
227         set_bit(CACHE_HASHED, &item->h.flags);
228         item->h.next = *head;
229         *head = &item->h;
230         rsi_cache.entries++;
231         read_unlock(&rsi_cache.hash_lock);
232         cache_get(&item->h);
233
234         cache_check(&rsi_cache, &item->h, chandle);
235         starttime = get_seconds();
236         do {
237                 yield();
238                 read_lock(&rsi_cache.hash_lock);
239                 for (hp = head; *hp != NULL; hp = &tmp->h.next) {
240                         tmp = container_of(*hp, struct rsi, h);
241                         if (tmp == item)
242                                 continue;
243                         if (rsi_match(tmp, item)) {
244                                 if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
245                                         read_unlock(&rsi_cache.hash_lock);
246                                         return NULL;
247                                 }
248                                 cache_get(&tmp->h);
249                                 clear_bit(CACHE_HASHED, &tmp->h.flags);
250                                 *hp = tmp->h.next;
251                                 tmp->h.next = NULL;
252                                 rsi_cache.entries--;
253                                 read_unlock(&rsi_cache.hash_lock);
254                                 return tmp;
255                         }
256                 }
257                 read_unlock(&rsi_cache.hash_lock);
258         } while ((get_seconds() - starttime) <= 5);
259         CERROR("5s timeout while waiting cache refill\n");
260         return NULL;
261 }
262
263 static int rsi_parse(struct cache_detail *cd,
264                      char *mesg, int mlen)
265 {
266         /* context token expiry major minor context token */
267         char *buf = mesg;
268         char *ep;
269         int len;
270         struct rsi *rsii;
271         time_t expiry;
272         int status = -EINVAL;
273         ENTRY;
274
275         OBD_ALLOC(rsii, sizeof(*rsii));
276         if (!rsii) {
277                 CERROR("failed to alloc rsii\n");
278                 RETURN(-ENOMEM);
279         }
280         cache_init(&rsii->h);
281
282         /* handle */
283         len = qword_get(&mesg, buf, mlen);
284         if (len < 0)
285                 goto out;
286         status = -ENOMEM;
287         if (rawobj_alloc(&rsii->in_handle, buf, len))
288                 goto out;
289
290         /* token */
291         len = qword_get(&mesg, buf, mlen);
292         status = -EINVAL;
293         if (len < 0)
294                 goto out;;
295         status = -ENOMEM;
296         if (rawobj_alloc(&rsii->in_token, buf, len))
297                 goto out;
298
299         /* expiry */
300         expiry = get_expiry(&mesg);
301         status = -EINVAL;
302         if (expiry == 0)
303                 goto out;
304
305         /* major/minor */
306         len = qword_get(&mesg, buf, mlen);
307         if (len < 0)
308                 goto out;
309         if (len == 0) {
310                 goto out;
311         } else {
312                 rsii->major_status = simple_strtoul(buf, &ep, 10);
313                 if (*ep)
314                         goto out;
315                 len = qword_get(&mesg, buf, mlen);
316                 if (len <= 0)
317                         goto out;
318                 rsii->minor_status = simple_strtoul(buf, &ep, 10);
319                 if (*ep)
320                         goto out;
321
322                 /* out_handle */
323                 len = qword_get(&mesg, buf, mlen);
324                 if (len < 0)
325                         goto out;
326                 status = -ENOMEM;
327                 if (rawobj_alloc(&rsii->out_handle, buf, len))
328                         goto out;
329
330                 /* out_token */
331                 len = qword_get(&mesg, buf, mlen);
332                 status = -EINVAL;
333                 if (len < 0)
334                         goto out;
335                 status = -ENOMEM;
336                 if (rawobj_alloc(&rsii->out_token, buf, len))
337                         goto out;
338         }
339         rsii->h.expiry_time = expiry;
340         status = gssd_reply(rsii);
341 out:
342         if (rsii)
343                 rsi_put(&rsii->h, &rsi_cache);
344         RETURN(status);
345 }
346
347 static struct cache_detail rsi_cache = {
348         .hash_size      = RSI_HASHMAX,
349         .hash_table     = rsi_table,
350         .name           = "auth.ptlrpcs.init",
351         .cache_put      = rsi_put,
352         .cache_request  = rsi_request,
353         .cache_parse    = rsi_parse,
354 };
355
356 /*
357  * The rpcsec_context cache is used to store a context that is
358  * used in data exchange.
359  * The key is a context handle. The content is:
360  *  uid, gidlist, mechanism, service-set, mech-specific-data
361  */
362
363 #define RSC_HASHBITS    10
364 #define RSC_HASHMAX     (1<<RSC_HASHBITS)
365 #define RSC_HASHMASK    (RSC_HASHMAX-1)
366
367 #define GSS_SEQ_WIN     128
368
369 struct gss_svc_seq_data {
370         /* highest seq number seen so far: */
371         __u32                   sd_max;
372         /* for i such that sd_max-GSS_SEQ_WIN < i <= sd_max, the i-th bit of
373          * sd_win is nonzero iff sequence number i has been seen already: */
374         unsigned long           sd_win[GSS_SEQ_WIN/BITS_PER_LONG];
375         spinlock_t              sd_lock;
376 };
377
378 struct rsc {
379         struct cache_head       h;
380         rawobj_t                handle;
381         __u32                   remote;
382         struct vfs_cred         cred;
383         struct gss_svc_seq_data seqdata;
384         struct gss_ctx         *mechctx;
385 };
386
387 static struct cache_head *rsc_table[RSC_HASHMAX];
388 static struct cache_detail rsc_cache;
389
390 static void rsc_free(struct rsc *rsci)
391 {
392         rawobj_free(&rsci->handle);
393         if (rsci->mechctx)
394                 kgss_delete_sec_context(&rsci->mechctx);
395 #if 0
396         if (rsci->cred.vc_ginfo)
397                 put_group_info(rsci->cred.vc_ginfo);
398 #endif
399 }
400
401 static void rsc_put(struct cache_head *item, struct cache_detail *cd)
402 {
403         struct rsc *rsci = container_of(item, struct rsc, h);
404
405         if (cache_put(item, cd)) {
406                 rsc_free(rsci);
407                 OBD_FREE(rsci, sizeof(*rsci));
408         }
409 }
410
411 static inline int
412 rsc_hash(struct rsc *rsci)
413 {
414         return hash_mem(rsci->handle.data, rsci->handle.len, RSC_HASHBITS);
415 }
416
417 static inline int
418 rsc_match(struct rsc *new, struct rsc *tmp)
419 {
420         return rawobj_equal(&new->handle, &tmp->handle);
421 }
422
423 static struct rsc *rsc_lookup(struct rsc *item, int set)
424 {
425         struct rsc *tmp = NULL;
426         struct cache_head **hp, **head;
427         head = &rsc_cache.hash_table[rsc_hash(item)];
428         ENTRY;
429
430         if (set)
431                 write_lock(&rsc_cache.hash_lock);
432         else
433                 read_lock(&rsc_cache.hash_lock);
434         for (hp = head; *hp != NULL; hp = &tmp->h.next) {
435                 tmp = container_of(*hp, struct rsc, h);
436                 if (!rsc_match(tmp, item))
437                         continue;
438                 cache_get(&tmp->h);
439                 if (!set) {
440                         goto out_noset;
441                 }
442                 *hp = tmp->h.next;
443                 tmp->h.next = NULL;
444                 clear_bit(CACHE_HASHED, &tmp->h.flags);
445                 rsc_put(&tmp->h, &rsc_cache);
446                 goto out_set;
447         }
448         /* Didn't find anything */
449         if (!set)
450                 goto out_noset;
451         rsc_cache.entries++;
452 out_set:
453         set_bit(CACHE_HASHED, &item->h.flags);
454         item->h.next = *head;
455         *head = &item->h;
456         write_unlock(&rsc_cache.hash_lock);
457         cache_fresh(&rsc_cache, &item->h, item->h.expiry_time);
458         cache_get(&item->h);
459         RETURN(item);
460 out_noset:
461         read_unlock(&rsc_cache.hash_lock);
462         RETURN(tmp);
463 }
464                                                                                                                         
465 static int rsc_parse(struct cache_detail *cd,
466                      char *mesg, int mlen)
467 {
468         /* contexthandle expiry [ uid gid N <n gids> mechname ...mechdata... ] */
469         char *buf = mesg;
470         int len, rv;
471         struct rsc *rsci, *res = NULL;
472         time_t expiry;
473         int status = -EINVAL;
474
475         OBD_ALLOC(rsci, sizeof(*rsci));
476         if (!rsci) {
477                 CERROR("fail to alloc rsci\n");
478                 return -ENOMEM;
479         }
480         cache_init(&rsci->h);
481
482         /* context handle */
483         len = qword_get(&mesg, buf, mlen);
484         if (len < 0) goto out;
485         status = -ENOMEM;
486         if (rawobj_alloc(&rsci->handle, buf, len))
487                 goto out;
488
489         /* expiry */
490         expiry = get_expiry(&mesg);
491         status = -EINVAL;
492         if (expiry == 0)
493                 goto out;
494
495         /* remote flag */
496         rv = get_int(&mesg, &rsci->remote);
497         if (rv) {
498                 CERROR("fail to get remote flag\n");
499                 goto out;
500         }
501
502         /* uid, or NEGATIVE */
503         rv = get_int(&mesg, &rsci->cred.vc_uid);
504         if (rv == -EINVAL)
505                 goto out;
506         if (rv == -ENOENT)
507                 set_bit(CACHE_NEGATIVE, &rsci->h.flags);
508         else {
509                 int N, i;
510                 struct gss_api_mech *gm;
511                 rawobj_t tmp_buf;
512                 __u64 ctx_expiry;
513
514                 /* gid */
515                 if (get_int(&mesg, &rsci->cred.vc_gid))
516                         goto out;
517
518                 /* number of additional gid's */
519                 if (get_int(&mesg, &N))
520                         goto out;
521                 status = -ENOMEM;
522 #if 0
523                 rsci->cred.vc_ginfo = groups_alloc(N);
524                 if (rsci->cred.vc_ginfo == NULL)
525                         goto out;
526 #endif
527
528                 /* gid's */
529                 status = -EINVAL;
530                 for (i=0; i<N; i++) {
531                         gid_t gid;
532                         if (get_int(&mesg, &gid))
533                                 goto out;
534 #if 0
535                         GROUP_AT(rsci->cred.vc_ginfo, i) = gid;
536 #endif
537                 }
538
539                 /* mech name */
540                 len = qword_get(&mesg, buf, mlen);
541                 if (len < 0)
542                         goto out;
543                 gm = kgss_name_to_mech(buf);
544                 status = -EOPNOTSUPP;
545                 if (!gm)
546                         goto out;
547
548                 status = -EINVAL;
549                 /* mech-specific data: */
550                 len = qword_get(&mesg, buf, mlen);
551                 if (len < 0) {
552                         kgss_mech_put(gm);
553                         goto out;
554                 }
555                 tmp_buf.len = len;
556                 tmp_buf.data = buf;
557                 if (kgss_import_sec_context(&tmp_buf, gm, &rsci->mechctx)) {
558                         kgss_mech_put(gm);
559                         goto out;
560                 }
561
562                 /* currently the expiry time passed down from user-space
563                  * is invalid, here we retrive it from mech.
564                  */
565                 if (kgss_inquire_context(rsci->mechctx, &ctx_expiry)) {
566                         CERROR("unable to get expire time, drop it\n");
567                         set_bit(CACHE_NEGATIVE, &rsci->h.flags);
568                         kgss_mech_put(gm);
569                         goto out;
570                 }
571                 expiry = (time_t) ctx_expiry;
572
573                 kgss_mech_put(gm);
574         }
575         rsci->h.expiry_time = expiry;
576         spin_lock_init(&rsci->seqdata.sd_lock);
577         res = rsc_lookup(rsci, 1);
578         rsc_put(&res->h, &rsc_cache);
579         status = 0;
580 out:
581         if (rsci)
582                 rsc_put(&rsci->h, &rsc_cache);
583         return status;
584 }
585
586 /*
587  * flush all entries with @uid. @uid == -1 will match all.
588  * we only know the uid, maybe netid/nid in the future, in all cases
589  * we must search the whole cache
590  */
591 static void rsc_flush(uid_t uid)
592 {
593         struct cache_head **ch;
594         struct rsc *rscp;
595         int n;
596         ENTRY;
597
598         write_lock(&rsc_cache.hash_lock);
599         for (n = 0; n < RSC_HASHMAX; n++) {
600                 for (ch = &rsc_cache.hash_table[n]; *ch;) {
601                         rscp = container_of(*ch, struct rsc, h);
602                         if (uid == -1 || rscp->cred.vc_uid == uid) {
603                                 /* it seems simply set NEGATIVE doesn't work */
604                                 *ch = (*ch)->next;
605                                 rscp->h.next = NULL;
606                                 cache_get(&rscp->h);
607                                 set_bit(CACHE_NEGATIVE, &rscp->h.flags);
608                                 clear_bit(CACHE_HASHED, &rscp->h.flags);
609                                 CWARN("flush rsc %p for uid %u\n",
610                                        rscp, rscp->cred.vc_uid);
611                                 rsc_put(&rscp->h, &rsc_cache);
612                                 rsc_cache.entries--;
613                                 continue;
614                         }
615                         ch = &((*ch)->next);
616                 }
617         }
618         write_unlock(&rsc_cache.hash_lock);
619         EXIT;
620 }
621
622 static struct cache_detail rsc_cache = {
623         .hash_size      = RSC_HASHMAX,
624         .hash_table     = rsc_table,
625         .name           = "auth.ptlrpcs.context",
626         .cache_put      = rsc_put,
627         .cache_parse    = rsc_parse,
628 };
629
630 static struct rsc *
631 gss_svc_searchbyctx(rawobj_t *handle)
632 {
633         struct rsc rsci;
634         struct rsc *found;
635
636         rsci.handle = *handle;
637         found = rsc_lookup(&rsci, 0);
638         if (!found)
639                 return NULL;
640
641         if (cache_check(&rsc_cache, &found->h, NULL))
642                 return NULL;
643
644         return found;
645 }
646
647 struct gss_svc_data {
648         /* decoded gss client cred: */
649         struct rpc_gss_wire_cred        clcred;
650         /* internal used status */
651         unsigned int                    is_init:1,
652                                         is_init_continue:1,
653                                         is_err_notify:1,
654                                         is_fini:1;
655         int                             reserve_len;
656 };
657
658 /* FIXME
659  * again hacking: only try to give the svcgssd a chance to handle
660  * upcalls.
661  */
662 struct cache_deferred_req* my_defer(struct cache_req *req)
663 {
664         yield();
665         return NULL;
666 }
667 static struct cache_req my_chandle = {my_defer};
668
669 /* Implements sequence number algorithm as specified in RFC 2203. */
670 static int
671 gss_check_seq_num(struct gss_svc_seq_data *sd, __u32 seq_num)
672 {
673         int rc = 0;
674
675         spin_lock(&sd->sd_lock);
676         if (seq_num > sd->sd_max) {
677                 if (seq_num >= sd->sd_max + GSS_SEQ_WIN) {
678                         memset(sd->sd_win, 0, sizeof(sd->sd_win));
679                         sd->sd_max = seq_num;
680                 } else {
681                         while(sd->sd_max < seq_num) {
682                                 sd->sd_max++;
683                                 __clear_bit(sd->sd_max % GSS_SEQ_WIN,
684                                             sd->sd_win);
685                         }
686                 }
687                 __set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win);
688                 goto exit;
689         } else if (seq_num + GSS_SEQ_WIN <= sd->sd_max) {
690                 rc = 1;
691                 goto exit;
692         }
693
694         if (__test_and_set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win))
695                 rc = 1;
696 exit:
697         spin_unlock(&sd->sd_lock);
698         return rc;
699 }
700
701 static int
702 gss_svc_verify_request(struct ptlrpc_request *req,
703                        struct rsc *rsci,
704                        struct rpc_gss_wire_cred *gc,
705                        __u32 *vp, __u32 vlen)
706 {
707         struct ptlrpcs_wire_hdr *sec_hdr;
708         struct gss_ctx *ctx = rsci->mechctx;
709         __u32 maj_stat;
710         rawobj_t msg;
711         rawobj_t mic;
712         ENTRY;
713
714         sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;
715
716         req->rq_reqmsg = (struct lustre_msg *) (req->rq_reqbuf + sizeof(*sec_hdr));
717         req->rq_reqlen = sec_hdr->msg_len;
718
719         msg.len = sec_hdr->msg_len;
720         msg.data = (__u8 *)req->rq_reqmsg;
721
722         mic.len = le32_to_cpu(*vp++);
723         mic.data = (char *) vp;
724         vlen -= 4;
725
726         if (mic.len > vlen) {
727                 CERROR("checksum len %d, while buffer len %d\n",
728                         mic.len, vlen);
729                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
730         }
731
732         if (mic.len > 256) {
733                 CERROR("invalid mic len %d\n", mic.len);
734                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
735         }
736
737         maj_stat = kgss_verify_mic(ctx, &msg, &mic, NULL);
738         if (maj_stat != GSS_S_COMPLETE) {
739                 CERROR("MIC verification error: major %x\n", maj_stat);
740                 RETURN(maj_stat);
741         }
742
743         if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
744                 CERROR("discard request %p with old seq_num %u\n",
745                         req, gc->gc_seq);
746                 RETURN(GSS_S_DUPLICATE_TOKEN);
747         }
748
749         RETURN(GSS_S_COMPLETE);
750 }
751
752 static int
753 gss_svc_unseal_request(struct ptlrpc_request *req,
754                        struct rsc *rsci,
755                        struct rpc_gss_wire_cred *gc,
756                        __u32 *vp, __u32 vlen)
757 {
758         struct ptlrpcs_wire_hdr *sec_hdr;
759         struct gss_ctx *ctx = rsci->mechctx;
760         rawobj_t cipher_text, plain_text;
761         __u32 major;
762         ENTRY;
763
764         sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;
765
766         if (vlen < 4) {
767                 CERROR("vlen only %u\n", vlen);
768                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
769         }
770
771         cipher_text.len = le32_to_cpu(*vp++);
772         cipher_text.data = (__u8 *) vp;
773         vlen -= 4;
774         
775         if (cipher_text.len > vlen) {
776                 CERROR("cipher claimed %u while buf only %u\n",
777                         cipher_text.len, vlen);
778                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
779         }
780
781         plain_text = cipher_text;
782
783         major = kgss_unwrap(ctx, GSS_C_QOP_DEFAULT, &cipher_text, &plain_text);
784         if (major) {
785                 CERROR("unwrap error 0x%x\n", major);
786                 RETURN(major);
787         }
788
789         if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
790                 CERROR("discard request %p with old seq_num %u\n",
791                         req, gc->gc_seq);
792                 RETURN(GSS_S_DUPLICATE_TOKEN);
793         }
794
795         req->rq_reqmsg = (struct lustre_msg *) (vp);
796         req->rq_reqlen = plain_text.len;
797
798         CDEBUG(D_SEC, "msg len %d\n", req->rq_reqlen);
799
800         RETURN(GSS_S_COMPLETE);
801 }
802
803 static int
804 gss_pack_err_notify(struct ptlrpc_request *req,
805                     __u32 major, __u32 minor)
806 {
807         struct gss_svc_data *svcdata = req->rq_sec_svcdata;
808         __u32 reslen, *resp, *reslenp;
809         char  nidstr[PTL_NALFMT_SIZE];
810         const __u32 secdata_len = 7 * 4;
811         int rc;
812         ENTRY;
813
814         OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_ERR_NOTIFY|OBD_FAIL_ONCE, -EINVAL);
815
816         LASSERT(svcdata);
817         svcdata->is_err_notify = 1;
818         svcdata->reserve_len = 7 * 4;
819
820         rc = lustre_pack_reply(req, 0, NULL, NULL);
821         if (rc) {
822                 CERROR("could not pack reply, err %d\n", rc);
823                 RETURN(rc);
824         }
825
826         LASSERT(req->rq_reply_state);
827         LASSERT(req->rq_reply_state->rs_repbuf);
828         LASSERT(req->rq_reply_state->rs_repbuf_len >= secdata_len);
829         resp = (__u32 *) req->rq_reply_state->rs_repbuf;
830
831         /* header */
832         *resp++ = cpu_to_le32(PTLRPC_SEC_GSS);
833         *resp++ = cpu_to_le32(PTLRPC_SEC_TYPE_NONE);
834         *resp++ = cpu_to_le32(req->rq_replen);
835         reslenp = resp++;
836
837         /* skip lustre msg */
838         resp += req->rq_replen / 4;
839         reslen = svcdata->reserve_len;
840
841         /* gss replay:
842          * version, subflavor, notify, major, minor,
843          * obj1(fake), obj2(fake)
844          */
845         *resp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
846         *resp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
847         *resp++ = cpu_to_le32(PTLRPC_GSS_PROC_ERR);
848         *resp++ = cpu_to_le32(major);
849         *resp++ = cpu_to_le32(minor);
850         *resp++ = 0;
851         *resp++ = 0;
852         reslen -= (4 * 4);
853         /* the actual sec data length */
854         *reslenp = cpu_to_le32(secdata_len);
855
856         req->rq_reply_state->rs_repdata_len += (secdata_len);
857         CWARN("prepare gss error notify(0x%x/0x%x) to %s\n", major, minor,
858                portals_nid2str(req->rq_peer.peer_ni->pni_number,
859                                req->rq_peer.peer_id.nid, nidstr));
860         RETURN(0);
861 }
862
863 static int
864 gss_svcsec_handle_init(struct ptlrpc_request *req,
865                        struct rpc_gss_wire_cred *gc,
866                        __u32 *secdata, __u32 seclen,
867                        enum ptlrpcs_error *res)
868 {
869         struct gss_svc_data *svcdata = req->rq_sec_svcdata;
870         struct rsc          *rsci;
871         struct rsi          *rsikey, *rsip;
872         rawobj_t             tmpobj;
873         __u32 reslen,       *resp, *reslenp;
874         char                 nidstr[PTL_NALFMT_SIZE];
875         int                  rc;
876         ENTRY;
877
878         LASSERT(svcdata);
879
880         CWARN("processing gss init(%d) request from %s\n", gc->gc_proc,
881                portals_nid2str(req->rq_peer.peer_ni->pni_number,
882                                req->rq_peer.peer_id.nid, nidstr));
883
884         *res = PTLRPCS_BADCRED;
885         OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_INIT_REQ|OBD_FAIL_ONCE, SVC_DROP);
886
887         if (gc->gc_proc == RPC_GSS_PROC_INIT &&
888             gc->gc_ctx.len != 0) {
889                 CERROR("proc %d, ctx_len %d: not really init?\n",
890                 gc->gc_proc == RPC_GSS_PROC_INIT, gc->gc_ctx.len);
891                 RETURN(SVC_DROP);
892         }
893
894         OBD_ALLOC(rsikey, sizeof(*rsikey));
895         if (!rsikey) {
896                 CERROR("out of memory\n");
897                 RETURN(SVC_DROP);
898         }
899         cache_init(&rsikey->h);
900
901         if (rawobj_dup(&rsikey->in_handle, &gc->gc_ctx)) {
902                 CERROR("fail to dup context handle\n");
903                 GOTO(out_rsikey, rc = SVC_DROP);
904         }
905         *res = PTLRPCS_BADVERF;
906         if (rawobj_extract(&tmpobj, &secdata, &seclen)) {
907                 CERROR("can't extract token\n");
908                 GOTO(out_rsikey, rc = SVC_DROP);
909         }
910         if (rawobj_dup(&rsikey->in_token, &tmpobj)) {
911                 CERROR("can't duplicate token\n");
912                 GOTO(out_rsikey, rc = SVC_DROP);
913         }
914
915         rsip = gssd_upcall(rsikey, &my_chandle);
916         if (!rsip) {
917                 CERROR("error in gssd_upcall.\n");
918                 GOTO(out_rsikey, rc = SVC_DROP);
919         }
920
921         rsci = gss_svc_searchbyctx(&rsip->out_handle);
922         if (!rsci) {
923                 CERROR("rsci still not mature yet?\n");
924                 GOTO(out_rsip, rc = SVC_DROP);
925         }
926         CWARN("svcsec create gss context %p(%u@%s)\n",
927                rsci, rsci->cred.vc_uid,
928                portals_nid2str(req->rq_peer.peer_ni->pni_number,
929                                req->rq_peer.peer_id.nid, nidstr));
930
931         svcdata->is_init = 1;
932         svcdata->reserve_len = 6 * 4 +
933                 size_round4(rsip->out_handle.len) +
934                 size_round4(rsip->out_token.len);
935
936         rc = lustre_pack_reply(req, 0, NULL, NULL);
937         if (rc) {
938                 CERROR("failed to pack reply, rc = %d\n", rc);
939                 GOTO(out, rc = SVC_DROP);
940         }
941
942         /* header */
943         resp = (__u32 *) req->rq_reply_state->rs_repbuf;
944         *resp++ = cpu_to_le32(PTLRPC_SEC_GSS);
945         *resp++ = cpu_to_le32(PTLRPC_SEC_TYPE_NONE);
946         *resp++ = cpu_to_le32(req->rq_replen);
947         reslenp = resp++;
948
949         resp += req->rq_replen / 4;
950         reslen = svcdata->reserve_len;
951
952         /* gss reply:
953          * status, major, minor, seq, out_handle, out_token
954          */
955         *resp++ = cpu_to_le32(PTLRPCS_OK);
956         *resp++ = cpu_to_le32(rsip->major_status);
957         *resp++ = cpu_to_le32(rsip->minor_status);
958         *resp++ = cpu_to_le32(GSS_SEQ_WIN);
959         reslen -= (4 * 4);
960         if (rawobj_serialize(&rsip->out_handle,
961                              &resp, &reslen))
962                 LBUG();
963         if (rawobj_serialize(&rsip->out_token,
964                              &resp, &reslen))
965                 LBUG();
966         /* the actual sec data length */
967         *reslenp = cpu_to_le32(svcdata->reserve_len - reslen);
968
969         req->rq_reply_state->rs_repdata_len += le32_to_cpu(*reslenp);
970         CDEBUG(D_SEC, "req %p: msgsize %d, authsize %d, "
971                "total size %d\n", req, req->rq_replen,
972                le32_to_cpu(*reslenp),
973                req->rq_reply_state->rs_repdata_len);
974
975         *res = PTLRPCS_OK;
976
977         /* This is simplified since right now we doesn't support
978          * INIT_CONTINUE yet.
979          */
980         if (gc->gc_proc == RPC_GSS_PROC_INIT) {
981                 struct ptlrpcs_wire_hdr *hdr;
982
983                 hdr = buf_to_sec_hdr(req->rq_reqbuf);
984                 req->rq_reqmsg = buf_to_lustre_msg(req->rq_reqbuf);
985                 req->rq_reqlen = hdr->msg_len;
986
987                 rc = SVC_LOGIN;
988         } else
989                 rc = SVC_COMPLETE;
990
991 out:
992         rsc_put(&rsci->h, &rsc_cache);
993 out_rsip:
994         rsi_put(&rsip->h, &rsi_cache);
995 out_rsikey:
996         rsi_put(&rsikey->h, &rsi_cache);
997
998         RETURN(rc);
999 }
1000
/*
 * Handle a RPCSEC_GSS data request: look up the established context by
 * the client-supplied handle, then verify (integrity service) or unseal
 * (privacy service) the request.  On GSS-level failure an error notify
 * reply is prepared for the client instead of silently dropping.
 *
 * Returns an SVC_* disposition code; *res carries the ptlrpcs status.
 */
static int
gss_svcsec_handle_data(struct ptlrpc_request *req,
                       struct rpc_gss_wire_cred *gc,
                       __u32 *secdata, __u32 seclen,
                       enum ptlrpcs_error *res)
{
        struct rsc          *rsci;
        char                 nidstr[PTL_NALFMT_SIZE];
        __u32                major;
        int                  rc;
        ENTRY;

        *res = PTLRPCS_GSS_CREDPROBLEM;

        rsci = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rsci) {
                /* unknown/expired context: notify the client so it can
                 * re-establish rather than retrying forever */
                CWARN("Invalid gss context handle from %s\n",
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                major = GSS_S_NO_CONTEXT;
                goto notify_err;
        }

        switch (gc->gc_svc) {
        case PTLRPC_GSS_SVC_INTEGRITY:
                major = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
                if (major == GSS_S_COMPLETE)
                        break;

                CWARN("fail in verify:0x%x: ctx %p@%s\n", major, rsci,
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                goto notify_err;
        case PTLRPC_GSS_SVC_PRIVACY:
                major = gss_svc_unseal_request(req, rsci, gc, secdata, seclen);
                if (major == GSS_S_COMPLETE)
                        break;

                CWARN("fail in decrypt:0x%x: ctx %p@%s\n", major, rsci,
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                goto notify_err;
        default:
                CERROR("unsupported gss service %d\n", gc->gc_svc);
                GOTO(out, rc = SVC_DROP);
        }

        /* request verified: record the authenticated identity on it */
        req->rq_auth_uid = rsci->cred.vc_uid;
        req->rq_remote = rsci->remote;

        *res = PTLRPCS_OK;
        GOTO(out, rc = SVC_OK);

notify_err:
        /* try to tell the client about the GSS error; if even that
         * fails, drop the request silently */
        if (gss_pack_err_notify(req, major, 0))
                rc = SVC_DROP;
        else
                rc = SVC_COMPLETE;
out:
        /* rsci is NULL when the context lookup above failed */
        if (rsci)
                rsc_put(&rsci->h, &rsc_cache);
        RETURN(rc);
}
1064
/*
 * Handle a RPCSEC_GSS context-destroy request: verify the request under
 * the context being destroyed (integrity service only), then mark the
 * cached context negative so it can no longer be used.
 *
 * Returns an SVC_* disposition code; *res carries the ptlrpcs status.
 */
static int
gss_svcsec_handle_destroy(struct ptlrpc_request *req,
                          struct rpc_gss_wire_cred *gc,
                          __u32 *secdata, __u32 seclen,
                          enum ptlrpcs_error *res)
{
        struct gss_svc_data *svcdata = req->rq_sec_svcdata;
        struct rsc          *rsci;
        char                 nidstr[PTL_NALFMT_SIZE];
        int                  rc;
        ENTRY;

        LASSERT(svcdata);
        *res = PTLRPCS_GSS_CREDPROBLEM;

        rsci = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rsci) {
                CWARN("invalid gss context handle for destroy.\n");
                RETURN(SVC_DROP);
        }

        /* only integrity-protected destroy requests are accepted */
        if (gc->gc_svc != PTLRPC_GSS_SVC_INTEGRITY) {
                CERROR("service %d is not supported in destroy.\n",
                        gc->gc_svc);
                GOTO(out, rc = SVC_DROP);
        }

        /* the destroy request itself must verify under the context */
        *res = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
        if (*res)
                GOTO(out, rc = SVC_DROP);

        /* compose reply, which is actually nothing */
        svcdata->is_fini = 1;
        if (lustre_pack_reply(req, 0, NULL, NULL))
                GOTO(out, rc = SVC_DROP);

        CWARN("svcsec destroy gss context %p(%u@%s)\n",
               rsci, rsci->cred.vc_uid,
               portals_nid2str(req->rq_peer.peer_ni->pni_number,
                               req->rq_peer.peer_id.nid, nidstr));

        /* negative cache entry == context is dead */
        set_bit(CACHE_NEGATIVE, &rsci->h.flags);
        *res = PTLRPCS_OK;
        rc = SVC_LOGOUT;
out:
        rsc_put(&rsci->h, &rsc_cache);
        RETURN(rc);
}
1113
1114 /*
1115  * let incomming request go through security check:
1116  *  o context establishment: invoke user space helper
1117  *  o data exchange: verify/decrypt
1118  *  o context destruction: mark context invalid
1119  *
1120  * in most cases, error will result to drop the packet silently.
1121  */
static int
gss_svcsec_accept(struct ptlrpc_request *req, enum ptlrpcs_error *res)
{
        struct gss_svc_data *svcdata;
        struct rpc_gss_wire_cred *gc;
        struct ptlrpcs_wire_hdr *sec_hdr;
        __u32 seclen, *secdata, version, subflavor;
        int rc;
        ENTRY;

        CDEBUG(D_SEC, "request %p\n", req);
        LASSERT(req->rq_reqbuf);
        LASSERT(req->rq_reqbuf_len);

        *res = PTLRPCS_BADCRED;

        sec_hdr = buf_to_sec_hdr(req->rq_reqbuf);
        LASSERT(sec_hdr->flavor == PTLRPC_SEC_GSS);

        /* security data follows the wire header and the lustre msg */
        seclen = req->rq_reqbuf_len - sizeof(*sec_hdr) - sec_hdr->msg_len;
        secdata = (__u32 *) buf_to_sec_data(req->rq_reqbuf);

        /* claimed sec length must fit within the remaining buffer */
        if (sec_hdr->sec_len > seclen) {
                CERROR("seclen %d, while max buf %d\n",
                        sec_hdr->sec_len, seclen);
                RETURN(SVC_DROP);
        }

        /* need at least the 5-word gss header plus a handle length */
        if (seclen < 6 * 4) {
                CERROR("sec size %d too small\n", seclen);
                RETURN(SVC_DROP);
        }

        /* per-request gss state, freed in cleanup_req (or below on drop) */
        LASSERT(!req->rq_sec_svcdata);
        OBD_ALLOC(svcdata, sizeof(*svcdata));
        if (!svcdata) {
                CERROR("fail to alloc svcdata\n");
                RETURN(SVC_DROP);
        }
        req->rq_sec_svcdata = svcdata;
        gc = &svcdata->clcred;

        /* Now secdata/seclen is what we want to parse
         */
        version = le32_to_cpu(*secdata++);      /* version */
        subflavor = le32_to_cpu(*secdata++);    /* subflavor */
        gc->gc_proc = le32_to_cpu(*secdata++);  /* proc */
        gc->gc_seq = le32_to_cpu(*secdata++);   /* seq */
        gc->gc_svc = le32_to_cpu(*secdata++);   /* service */
        seclen -= 5 * 4;

        CDEBUG(D_SEC, "wire gss_hdr: %u/%u/%u/%u/%u\n",
               version, subflavor, gc->gc_proc, gc->gc_seq, gc->gc_svc);

        if (version != PTLRPC_SEC_GSS_VERSION) {
                CERROR("gss version mismatch: %d - %d\n",
                        version, PTLRPC_SEC_GSS_VERSION);
                GOTO(err_free, rc = SVC_DROP);
        }

        /* gc_ctx points into the request buffer, not a copy */
        if (rawobj_extract(&gc->gc_ctx, &secdata, &seclen)) {
                CERROR("fail to obtain gss context handle\n");
                GOTO(err_free, rc = SVC_DROP);
        }

        /* dispatch on the gss procedure */
        *res = PTLRPCS_BADVERF;
        switch(gc->gc_proc) {
        case RPC_GSS_PROC_INIT:
        case RPC_GSS_PROC_CONTINUE_INIT:
                rc = gss_svcsec_handle_init(req, gc, secdata, seclen, res);
                break;
        case RPC_GSS_PROC_DATA:
                rc = gss_svcsec_handle_data(req, gc, secdata, seclen, res);
                break;
        case RPC_GSS_PROC_DESTROY:
                rc = gss_svcsec_handle_destroy(req, gc, secdata, seclen, res);
                break;
        default:
                rc = SVC_DROP;
                LBUG();
        }

err_free:
        /* dropping the request: no later cleanup_req will run, so the
         * per-request state must be released here */
        if (rc == SVC_DROP && req->rq_sec_svcdata) {
                OBD_FREE(req->rq_sec_svcdata, sizeof(struct gss_svc_data));
                req->rq_sec_svcdata = NULL;
        }

        RETURN(rc);
}
1212
/*
 * Secure the outgoing reply: for integrity service append a MIC over
 * the lustre message; for privacy service wrap (encrypt) the message
 * into the reply buffer.  Init/fini/error-notify replies were already
 * fully packed by the accept-path handlers and need nothing here.
 */
static int
gss_svcsec_authorize(struct ptlrpc_request *req)
{
        struct ptlrpc_reply_state *rs = req->rq_reply_state;
        struct gss_svc_data *gsd = (struct gss_svc_data *)req->rq_sec_svcdata;
        struct rpc_gss_wire_cred  *gc = &gsd->clcred;
        struct rsc                *rscp;
        struct ptlrpcs_wire_hdr   *sec_hdr;
        rawobj_buf_t               msg_buf;
        rawobj_t                   cipher_buf;
        __u32                     *vp, *vpsave, major, vlen, seclen;
        rawobj_t                   lmsg, mic;
        int                        ret;
        ENTRY;

        LASSERT(rs);
        LASSERT(rs->rs_repbuf);
        LASSERT(gsd);

        if (gsd->is_init || gsd->is_init_continue ||
            gsd->is_err_notify || gsd->is_fini) {
                /* nothing to do in these cases */
                CDEBUG(D_SEC, "req %p: init/fini/err\n", req);
                RETURN(0);
        }

        if (gc->gc_proc != RPC_GSS_PROC_DATA) {
                CERROR("proc %d not support\n", gc->gc_proc);
                RETURN(-EINVAL);
        }

        rscp = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rscp) {
                CERROR("ctx disapeared under us?\n");
                RETURN(-EINVAL);
        }

        sec_hdr = (struct ptlrpcs_wire_hdr *) rs->rs_repbuf;
        switch (gc->gc_svc) {
        case  PTLRPC_GSS_SVC_INTEGRITY:
                /* prepare various pointers: lmsg covers the lustre
                 * message, vp/vlen the verifier area after it */
                lmsg.len = req->rq_replen;
                lmsg.data = (__u8 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vp = (__u32 *) (lmsg.data + lmsg.len);
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr) - lmsg.len;
                seclen = vlen;

                sec_hdr->flavor = cpu_to_le32(PTLRPC_SEC_GSS);
                sec_hdr->sectype = cpu_to_le32(PTLRPC_SEC_TYPE_AUTH);
                sec_hdr->msg_len = cpu_to_le32(req->rq_replen);

                /* standard gss hdr */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPC_GSS_SVC_INTEGRITY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                mic.len = vlen;
                mic.data = (char *) vp;

                /* compute the MIC over the lustre message in place */
                major = kgss_get_mic(rscp->mechctx, 0, &lmsg, &mic);
                if (major) {
                        CERROR("fail to get MIC: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }
                *vpsave = cpu_to_le32(mic.len);
                /* seclen shrinks by however much of vlen went unused */
                seclen = seclen - vlen + mic.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        case  PTLRPC_GSS_SVC_PRIVACY:
                /* whole payload is ciphertext; msg_len on the wire is 0 */
                vp = (__u32 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr);
                seclen = vlen;

                sec_hdr->flavor = cpu_to_le32(PTLRPC_SEC_GSS);
                sec_hdr->sectype = cpu_to_le32(PTLRPC_SEC_TYPE_PRIV);
                sec_hdr->msg_len = cpu_to_le32(0);

                /* standard gss hdr */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPC_GSS_SVC_PRIVACY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                /* rs_msg was allocated with prefix/suffix headroom in
                 * alloc_repbuf for in-place wrapping */
                msg_buf.buf = (__u8 *) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.buflen = req->rq_replen + GSS_PRIVBUF_PREFIX_LEN +
                                 GSS_PRIVBUF_SUFFIX_LEN;
                msg_buf.dataoff = GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.datalen = req->rq_replen;

                cipher_buf.data = (__u8 *) vp;
                cipher_buf.len = vlen;

                major = kgss_wrap(rscp->mechctx, GSS_C_QOP_DEFAULT,
                                &msg_buf, &cipher_buf);
                if (major) {
                        CERROR("failed to wrap: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }

                *vpsave = cpu_to_le32(cipher_buf.len);
                seclen = seclen - vlen + cipher_buf.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        default:
                CERROR("Unknown service %d\n", gc->gc_svc);
                GOTO(out, ret = -EINVAL);
        }
        ret = 0;
out:
        rsc_put(&rscp->h, &rsc_cache);

        RETURN(ret);
}
1339
1340 static
1341 void gss_svcsec_cleanup_req(struct ptlrpc_svcsec *svcsec,
1342                             struct ptlrpc_request *req)
1343 {
1344         struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_sec_svcdata;
1345
1346         if (!gsd) {
1347                 CDEBUG(D_SEC, "no svc_data present. do nothing\n");
1348                 return;
1349         }
1350
1351         /* gsd->clclred.gc_ctx is NOT allocated, just set pointer
1352          * to the incoming packet buffer, so don't need free it
1353          */
1354         OBD_FREE(gsd, sizeof(*gsd));
1355         req->rq_sec_svcdata = NULL;
1356         return;
1357 }
1358
1359 static
1360 int gss_svcsec_est_payload(struct ptlrpc_svcsec *svcsec,
1361                            struct ptlrpc_request *req,
1362                            int msgsize)
1363 {
1364         struct gss_svc_data *svcdata = req->rq_sec_svcdata;
1365         ENTRY;
1366
1367         /* just return the pre-set reserve_len for init/fini/err cases.
1368          */
1369         LASSERT(svcdata);
1370         if (svcdata->is_init) {
1371                 CDEBUG(D_SEC, "is_init, reserver size %d(%d)\n",
1372                        size_round(svcdata->reserve_len),
1373                        svcdata->reserve_len);
1374                 LASSERT(svcdata->reserve_len);
1375                 LASSERT(svcdata->reserve_len % 4 == 0);
1376                 RETURN(size_round(svcdata->reserve_len));
1377         } else if (svcdata->is_err_notify) {
1378                 CDEBUG(D_SEC, "is_err_notify, reserver size %d(%d)\n",
1379                        size_round(svcdata->reserve_len),
1380                        svcdata->reserve_len);
1381                 RETURN(size_round(svcdata->reserve_len));
1382         } else if (svcdata->is_fini) {
1383                 CDEBUG(D_SEC, "is_fini, reserver size 0\n");
1384                 RETURN(0);
1385         } else {
1386                 if (svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_NONE ||
1387                     svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_INTEGRITY)
1388                         RETURN(size_round(GSS_MAX_AUTH_PAYLOAD));
1389                 else if (svcdata->clcred.gc_svc == PTLRPC_GSS_SVC_PRIVACY)
1390                         RETURN(size_round16(GSS_MAX_AUTH_PAYLOAD + msgsize +
1391                                             GSS_PRIVBUF_PREFIX_LEN +
1392                                             GSS_PRIVBUF_SUFFIX_LEN));
1393                 else {
1394                         CERROR("unknown gss svc %u\n", svcdata->clcred.gc_svc);
1395                         *((int *)0) = 0;
1396                         LBUG();
1397                 }
1398         }
1399         RETURN(0);
1400 }
1401
/*
 * Allocate the reply buffer (reply state) for a request.  For the
 * privacy service the lustre message lives in a separately allocated
 * buffer (with gss prefix/suffix headroom for in-place wrapping) and
 * only ciphertext goes into the shared reply buffer; otherwise the
 * message is packed directly into the reply buffer.
 * The private buffer is released by gss_svcsec_free_repbuf().
 */
int gss_svcsec_alloc_repbuf(struct ptlrpc_svcsec *svcsec,
                            struct ptlrpc_request *req,
                            int msgsize)
{
        struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_sec_svcdata;
        struct ptlrpc_reply_state *rs;
        int msg_payload, sec_payload;
        int privacy, rc;
        ENTRY;

        /* determine the security type: none/auth or priv, we have
         * different pack scheme for them.
         * init/fini/err will always be treated as none/auth.
         */
        LASSERT(gsd);
        if (!gsd->is_init && !gsd->is_init_continue &&
            !gsd->is_fini && !gsd->is_err_notify &&
            gsd->clcred.gc_svc == PTLRPC_GSS_SVC_PRIVACY)
                privacy = 1;
        else
                privacy = 0;

        /* for privacy the message goes in its own buffer, so the
         * shared reply buffer carries no message payload */
        msg_payload = privacy ? 0 : msgsize;
        sec_payload = gss_svcsec_est_payload(svcsec, req, msgsize);

        rc = svcsec_alloc_reply_state(req, msg_payload, sec_payload);
        if (rc)
                RETURN(rc);

        rs = req->rq_reply_state;
        LASSERT(rs);
        rs->rs_msg_len = msgsize;

        if (privacy) {
                /* we can choose to let msg simply point to the rear of the
                 * buffer, which lead to buffer overlap when doing encryption.
                 * usually it's ok and it indeed passed all existing tests.
                 * but not sure if there will be subtle problems in the future.
                 * so right now we choose to alloc another new buffer. we'll
                 * see how it works.
                 */
#if 0
                rs->rs_msg = (struct lustre_msg *)
                             (rs->rs_repbuf + rs->rs_repbuf_len -
                              msgsize - GSS_PRIVBUF_SUFFIX_LEN);
#endif
                char *msgbuf;

                /* include headroom for the gss wrap prefix/suffix */
                msgsize += GSS_PRIVBUF_PREFIX_LEN + GSS_PRIVBUF_SUFFIX_LEN;
                OBD_ALLOC(msgbuf, msgsize);
                if (!msgbuf) {
                        CERROR("can't alloc %d\n", msgsize);
                        svcsec_free_reply_state(rs);
                        req->rq_reply_state = NULL;
                        RETURN(-ENOMEM);
                }
                /* rs_msg points past the prefix; free_repbuf recovers
                 * the original allocation start from this */
                rs->rs_msg = (struct lustre_msg *)
                                (msgbuf + GSS_PRIVBUF_PREFIX_LEN);
        }

        req->rq_repmsg = rs->rs_msg;

        RETURN(0);
}
1466
1467 static
1468 void gss_svcsec_free_repbuf(struct ptlrpc_svcsec *svcsec,
1469                             struct ptlrpc_reply_state *rs)
1470 {
1471         unsigned long p1 = (unsigned long) rs->rs_msg;
1472         unsigned long p2 = (unsigned long) rs->rs_buf;
1473
1474         LASSERT(rs->rs_buf);
1475         LASSERT(rs->rs_msg);
1476         LASSERT(rs->rs_msg_len);
1477
1478         if (p1 < p2 || p1 >= p2 + rs->rs_buf_len) {
1479                 char *start = (char*) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
1480                 int size = rs->rs_msg_len + GSS_PRIVBUF_PREFIX_LEN +
1481                            GSS_PRIVBUF_SUFFIX_LEN;
1482                 OBD_FREE(start, size);
1483         }
1484
1485         svcsec_free_reply_state(rs);
1486 }
1487
/* GSS server-side security policy: wire flavor PTLRPC_SEC_GSS,
 * dispatching into the handlers defined above. */
struct ptlrpc_svcsec svcsec_gss = {
        .pss_owner              = THIS_MODULE,
        .pss_name               = "GSS_SVCSEC",
        .pss_flavor             = {PTLRPC_SEC_GSS, 0},
        .accept                 = gss_svcsec_accept,
        .authorize              = gss_svcsec_authorize,
        .alloc_repbuf           = gss_svcsec_alloc_repbuf,
        .free_repbuf            = gss_svcsec_free_repbuf,
        .cleanup_req            = gss_svcsec_cleanup_req,
};
1498
/* XXX hacking */
/* Drop every entry from both the init (rsi) and context (rsc) caches. */
void lgss_svc_cache_purge_all(void)
{
        cache_purge(&rsi_cache);
        cache_purge(&rsc_cache);
}
EXPORT_SYMBOL(lgss_svc_cache_purge_all);
1506
/* Invalidate all cached gss contexts belonging to the given user id. */
void lgss_svc_cache_flush(__u32 uid)
{
        rsc_flush(uid);
}
EXPORT_SYMBOL(lgss_svc_cache_flush);
1512
1513 int gss_svc_init(void)
1514 {
1515         int rc;
1516
1517         rc = svcsec_register(&svcsec_gss);
1518         if (!rc) {
1519                 cache_register(&rsc_cache);
1520                 cache_register(&rsi_cache);
1521         }
1522         return rc;
1523 }
1524
1525 void gss_svc_exit(void)
1526 {
1527         int rc;
1528         if ((rc = cache_unregister(&rsi_cache)))
1529                 CERROR("unregister rsi cache: %d\n", rc);
1530         if ((rc = cache_unregister(&rsc_cache)))
1531                 CERROR("unregister rsc cache: %d\n", rc);
1532         if ((rc = svcsec_unregister(&svcsec_gss)))
1533                 CERROR("unregister svcsec_gss: %d\n", rc);
1534 }