Whamcloud - gitweb
current branches now use lnet from HEAD
[fs/lustre-release.git] / lustre / sec / gss / svcsec_gss.c
1 /* -*- mode: c; c-basic-offset: 8; indent-tabs-mode: nil; -*-
2  * vim:expandtab:shiftwidth=8:tabstop=8:
3  *
4  * Modifications for Lustre
5  * Copyright 2004, Cluster File Systems, Inc.
6  * All rights reserved
7  * Author: Eric Mei <ericm@clusterfs.com>
8  */
9
10 /*
11  * Neil Brown <neilb@cse.unsw.edu.au>
12  * J. Bruce Fields <bfields@umich.edu>
13  * Andy Adamson <andros@umich.edu>
14  * Dug Song <dugsong@monkey.org>
15  *
16  * RPCSEC_GSS server authentication.
17  * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078
18  * (gssapi)
19  *
20  * The RPCSEC_GSS involves three stages:
21  *  1/ context creation
22  *  2/ data exchange
23  *  3/ context destruction
24  *
25  * Context creation is handled largely by upcalls to user-space.
26  *  In particular, GSS_Accept_sec_context is handled by an upcall
27  * Data exchange is handled entirely within the kernel
28  *  In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel.
29  * Context destruction is handled in-kernel
30  *  GSS_Delete_sec_context is in-kernel
31  *
32  * Context creation is initiated by a RPCSEC_GSS_INIT request arriving.
33  * The context handle and gss_token are used as a key into the rpcsec_init cache.
34  * The content of this cache includes some of the outputs of GSS_Accept_sec_context,
35  * being major_status, minor_status, context_handle, reply_token.
36  * These are sent back to the client.
 * Sequence window management is handled by the kernel.  The window size is
 * currently a compile time constant.
39  *
40  * When user-space is happy that a context is established, it places an entry
41  * in the rpcsec_context cache. The key for this cache is the context_handle.
42  * The content includes:
43  *   uid/gidlist - for determining access rights
44  *   mechanism type
45  *   mechanism specific information, such as a key
46  *
47  */
48
49 #define DEBUG_SUBSYSTEM S_SEC
50 #ifdef __KERNEL__
51 #include <linux/types.h>
52 #include <linux/init.h>
53 #include <linux/module.h>
54 #include <linux/slab.h>
55 #include <linux/hash.h>
56 #else
57 #include <liblustre.h>
58 #endif
59
60 #include <linux/sunrpc/cache.h>
61
62 #include <libcfs/kp30.h>
63 #include <linux/obd.h>
64 #include <linux/obd_class.h>
65 #include <linux/obd_support.h>
66 #include <linux/lustre_idl.h>
67 #include <linux/lustre_net.h>
68 #include <linux/lustre_import.h>
69 #include <linux/lustre_sec.h>
70                                                                                                                         
71 #include "gss_err.h"
72 #include "gss_internal.h"
73 #include "gss_api.h"
74
75 static inline unsigned long hash_mem(char *buf, int length, int bits)
76 {
77         unsigned long hash = 0;
78         unsigned long l = 0;
79         int len = 0;
80         unsigned char c;
81         do {
82                 if (len == length) {
83                         c = (char)len; len = -1;
84                 } else
85                         c = *buf++;
86                 l = (l << 8) | c;
87                 len++;
88                 if ((len & (BITS_PER_LONG/8-1))==0)
89                         hash = hash_long(hash^l, BITS_PER_LONG);
90         } while (len);
91         return hash >> (BITS_PER_LONG - bits);
92 }
93
94 /* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests
95  * into replies.
96  *
97  * Key is context handle (\x if empty) and gss_token.
98  * Content is major_status minor_status (integers) context_handle, reply_token.
99  *
100  */
101
102 #define RSI_HASHBITS    6
103 #define RSI_HASHMAX     (1<<RSI_HASHBITS)
104 #define RSI_HASHMASK    (RSI_HASHMAX-1)
105
/* A pending (or just-answered) context-init request; keyed by
 * in_handle + in_token (see rsi_match/rsi_hash). */
struct rsi {
        struct cache_head       h;      /* generic cache entry: refcnt, flags, chain */
        __u32                   lustre_svc;     /* lustre service type, sent up in the upcall */
        __u32                   naltype;        /* peer NAL type (upcall field) */
        __u32                   netid;          /* peer network id (upcall field) */
        __u64                   nid;            /* peer nid (upcall field) */
        /* request side: context handle + gss token form the lookup key */
        rawobj_t                in_handle, in_token, in_srv_type;
        /* reply side: filled in from the svcgssd downcall (rsi_parse) */
        rawobj_t                out_handle, out_token;
        int                     major_status, minor_status;     /* GSS_Accept_sec_context results */
};
116
/* hash buckets for the init cache; rsi_cache itself is defined below */
static struct cache_head *rsi_table[RSI_HASHMAX];
static struct cache_detail rsi_cache;
119
120 static void rsi_free(struct rsi *rsii)
121 {
122         rawobj_free(&rsii->in_handle);
123         rawobj_free(&rsii->in_token);
124         rawobj_free(&rsii->out_handle);
125         rawobj_free(&rsii->out_token);
126 }
127
128 static void rsi_put(struct cache_head *item, struct cache_detail *cd)
129 {
130         struct rsi *rsii = container_of(item, struct rsi, h);
131         LASSERT(atomic_read(&item->refcnt) > 0);
132         if (cache_put(item, cd)) {
133                 LASSERT(item->next == NULL);
134                 rsi_free(rsii);
135                 OBD_FREE(rsii, sizeof(*rsii));
136         }
137 }
138
139 static inline int rsi_hash(struct rsi *item)
140 {
141         return hash_mem((char *)item->in_handle.data, item->in_handle.len, RSI_HASHBITS)
142                 ^ hash_mem((char *)item->in_token.data, item->in_token.len, RSI_HASHBITS);
143 }
144
145 static inline int rsi_match(struct rsi *item, struct rsi *tmp)
146 {
147         return (rawobj_equal(&item->in_handle, &tmp->in_handle) &&
148                 rawobj_equal(&item->in_token, &tmp->in_token));
149 }
150
/* Format the upcall body read by svcgssd: each field hex-encoded by
 * qword_addhex() in a fixed order (lustre_svc, naltype, netid, nid,
 * in_handle, in_token), newline-terminated.  The field order is part of
 * the kernel/user-space protocol and must not change. */
static void rsi_request(struct cache_detail *cd,
                        struct cache_head *h,
                        char **bpp, int *blen)
{
        struct rsi *rsii = container_of(h, struct rsi, h);

        qword_addhex(bpp, blen, (char *) &rsii->lustre_svc,
                     sizeof(rsii->lustre_svc));
        qword_addhex(bpp, blen, (char *) &rsii->naltype, sizeof(rsii->naltype));
        qword_addhex(bpp, blen, (char *) &rsii->netid, sizeof(rsii->netid));
        qword_addhex(bpp, blen, (char *) &rsii->nid, sizeof(rsii->nid));
        qword_addhex(bpp, blen, rsii->in_handle.data, rsii->in_handle.len);
        qword_addhex(bpp, blen, rsii->in_token.data, rsii->in_token.len);
        /* overwrite the trailing field separator with the record terminator */
        (*bpp)[-1] = '\n';
}
166
/*
 * Handle a downcall result: find the pending rsi entry matching @item
 * (same in_handle/in_token), splice @item — which carries the
 * out_handle/out_token/status results — into the hash chain in its
 * place, and mark @item VALID so the poller in gssd_upcall() picks it
 * up.
 *
 * Returns 0 on success, -EINVAL if no matching pending entry exists or
 * the matching entry is already VALID (i.e. nothing is waiting).
 */
static int
gssd_reply(struct rsi *item)
{
        struct rsi *tmp;
        struct cache_head **hp, **head;
        ENTRY;

        head = &rsi_cache.hash_table[rsi_hash(item)];
        write_lock(&rsi_cache.hash_lock);
        for (hp = head; *hp != NULL; hp = &tmp->h.next) {
                tmp = container_of(*hp, struct rsi, h);
                if (rsi_match(tmp, item)) {
                        /* take a ref on tmp so we can safely drop it below */
                        cache_get(&tmp->h);
                        /* unhash tmp; *hp now points at its old successor */
                        clear_bit(CACHE_HASHED, &tmp->h.flags);
                        *hp = tmp->h.next;
                        tmp->h.next = NULL;
                        rsi_cache.entries--;
                        if (test_bit(CACHE_VALID, &tmp->h.flags)) {
                                CERROR("rsi is valid\n");
                                write_unlock(&rsi_cache.hash_lock);
                                rsi_put(&tmp->h, &rsi_cache);
                                RETURN(-EINVAL);
                        }
                        /* insert item at tmp's old position in the chain */
                        set_bit(CACHE_HASHED, &item->h.flags);
                        item->h.next = *hp;
                        *hp = &item->h;
                        rsi_cache.entries++;
                        set_bit(CACHE_VALID, &item->h.flags);
                        item->h.last_refresh = get_seconds();
                        write_unlock(&rsi_cache.hash_lock);
                        /* NOTE(review): cache_fresh(.., 0) presumably completes
                         * any deferred requests hanging off tmp — confirm
                         * against the sunrpc cache implementation */
                        cache_fresh(&rsi_cache, &tmp->h, 0);
                        rsi_put(&tmp->h, &rsi_cache);
                        RETURN(0);
                }
        }
        write_unlock(&rsi_cache.hash_lock);
        RETURN(-EINVAL);
}
205
/* XXX
 * here we simply wait for the upcall's completion, or time out.  it's a
 * hack but it works; we'll come up with a real fix if we decide to
 * stick with the NFSv4 cache code
 */
/*
 * Issue the init upcall for @item and poll (every HZ/2, up to
 * SVCSEC_UPCALL_TIMEOUT seconds) until gssd_reply() has installed a
 * VALID replacement entry.  That entry is unhashed and returned with a
 * reference held; NULL is returned on failure or timeout.
 *
 * NOTE(review): the hash chain is modified below (inserting @item, and
 * unlinking the reply entry inside the poll loop) while holding only
 * the *read* lock; that looks racy against writers taking the write
 * lock — confirm whether this was intentional.
 */
static struct rsi *
gssd_upcall(struct rsi *item, struct cache_req *chandle)
{
        struct rsi *tmp;
        struct cache_head **hp, **head;
        unsigned long starttime;
        ENTRY;

        head = &rsi_cache.hash_table[rsi_hash(item)];
        read_lock(&rsi_cache.hash_lock);
        for (hp = head; *hp != NULL; hp = &tmp->h.next) {
                tmp = container_of(*hp, struct rsi, h);
                if (rsi_match(tmp, item)) {
                        /* a pre-existing match is treated as impossible here */
                        LBUG();
                        if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
                                CERROR("found rsi without VALID\n");
                                read_unlock(&rsi_cache.hash_lock);
                                return NULL;
                        }
                        *hp = tmp->h.next;
                        tmp->h.next = NULL;
                        rsi_cache.entries--;
                        cache_get(&tmp->h);
                        read_unlock(&rsi_cache.hash_lock);
                        return tmp;
                }
        }
        /* hash item so gssd_reply() can find and replace it */
        cache_get(&item->h);
        set_bit(CACHE_HASHED, &item->h.flags);
        item->h.next = *head;
        *head = &item->h;
        rsi_cache.entries++;
        read_unlock(&rsi_cache.hash_lock);
        //cache_get(&item->h);

        /* triggers the upcall to svcgssd (entry is not yet VALID) */
        cache_check(&rsi_cache, &item->h, chandle);
        starttime = get_seconds();
        do {
                set_current_state(TASK_UNINTERRUPTIBLE);
                schedule_timeout(HZ/2);
                read_lock(&rsi_cache.hash_lock);
                for (hp = head; *hp != NULL; hp = &tmp->h.next) {
                        tmp = container_of(*hp, struct rsi, h);
                        /* skip our own placeholder entry */
                        if (tmp == item)
                                continue;
                        if (rsi_match(tmp, item)) {
                                /* gssd_reply() installed the result entry */
                                if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
                                        read_unlock(&rsi_cache.hash_lock);
                                        return NULL;
                                }
                                cache_get(&tmp->h);
                                clear_bit(CACHE_HASHED, &tmp->h.flags);
                                *hp = tmp->h.next;
                                tmp->h.next = NULL;
                                rsi_cache.entries--;
                                read_unlock(&rsi_cache.hash_lock);
                                return tmp;
                        }
                }
                read_unlock(&rsi_cache.hash_lock);
        } while ((get_seconds() - starttime) <= SVCSEC_UPCALL_TIMEOUT);
        CERROR("%ds timeout while waiting cache refill\n",
               SVCSEC_UPCALL_TIMEOUT);
        return NULL;
}
276
/*
 * Parse a downcall from svcgssd carrying the results of
 * GSS_Accept_sec_context and hand them to gssd_reply(), which wakes
 * the request waiting in gssd_upcall().
 *
 * Message format (whitespace-separated qwords):
 *   in_handle in_token expiry major minor out_handle out_token
 *
 * Returns 0 on success or a negative errno.
 */
static int rsi_parse(struct cache_detail *cd,
                     char *mesg, int mlen)
{
        /* context token expiry major minor context token */
        char *buf = mesg;
        char *ep;
        int len;
        struct rsi *rsii;
        time_t expiry;
        int status = -EINVAL;
        ENTRY;

        OBD_ALLOC(rsii, sizeof(*rsii));
        if (!rsii)
                RETURN(-ENOMEM);
        cache_init(&rsii->h);

        /* handle */
        len = qword_get(&mesg, buf, mlen);
        if (len < 0)
                goto out;
        if (rawobj_alloc(&rsii->in_handle, buf, len)) {
                status = -ENOMEM;
                goto out;
        }

        /* token */
        len = qword_get(&mesg, buf, mlen);
        if (len < 0)
                goto out;
        if (rawobj_alloc(&rsii->in_token, buf, len)) {
                status = -ENOMEM;
                goto out;
        }

        /* expiry */
        expiry = get_expiry(&mesg);
        if (expiry == 0)
                goto out;

        /* major status, decimal */
        len = qword_get(&mesg, buf, mlen);
        if (len <= 0)
                goto out;
        rsii->major_status = simple_strtol(buf, &ep, 10);
        if (*ep)
                goto out;

        /* minor status, decimal */
        len = qword_get(&mesg, buf, mlen);
        if (len <= 0)
                goto out;
        rsii->minor_status = simple_strtol(buf, &ep, 10);
        if (*ep)
                goto out;

        /* out_handle */
        len = qword_get(&mesg, buf, mlen);
        if (len < 0)
                goto out;
        if (rawobj_alloc(&rsii->out_handle, buf, len)) {
                status = -ENOMEM;
                goto out;
        }

        /* out_token */
        len = qword_get(&mesg, buf, mlen);
        if (len < 0)
                goto out;
        if (rawobj_alloc(&rsii->out_token, buf, len)) {
                status = -ENOMEM;
                goto out;
        }

        rsii->h.expiry_time = expiry;
        /* splice the result into the cache, waking the waiter */
        status = gssd_reply(rsii);
out:
        /* drop our (possibly only) ref; gssd_reply keeps it hashed on success */
        if (rsii)
                rsi_put(&rsii->h, &rsi_cache);
        RETURN(status);
}
358
/* The init cache: channel "auth.ptlrpcs.init" between the kernel and
 * svcgssd for context-establishment requests (upcall via rsi_request,
 * downcall via rsi_parse). */
static struct cache_detail rsi_cache = {
        .hash_size      = RSI_HASHMAX,
        .hash_table     = rsi_table,
        .name           = "auth.ptlrpcs.init",
        .cache_put      = rsi_put,
        .cache_request  = rsi_request,
        .cache_parse    = rsi_parse,
};
367
368 /*
369  * The rpcsec_context cache is used to store a context that is
370  * used in data exchange.
371  * The key is a context handle. The content is:
372  *  uid, gidlist, mechanism, service-set, mech-specific-data
373  */
374
375 #define RSC_HASHBITS    10
376 #define RSC_HASHMAX     (1<<RSC_HASHBITS)
377 #define RSC_HASHMASK    (RSC_HASHMAX-1)
378
379 #define GSS_SEQ_WIN     512
380
/* Per-context replay-detection window (RFC 2203 sequence numbers). */
struct gss_svc_seq_data {
        /* highest seq number seen so far: */
        __u32                   sd_max;
        /* for i such that sd_max-GSS_SEQ_WIN < i <= sd_max, the i-th bit of
         * sd_win is nonzero iff sequence number i has been seen already: */
        unsigned long           sd_win[GSS_SEQ_WIN/BITS_PER_LONG];
        /* protects sd_max and sd_win (see gss_check_seq_num) */
        spinlock_t              sd_lock;
};
389
/* A fully-established security context, keyed by context handle. */
struct rsc {
        struct cache_head       h;      /* generic cache entry */
        rawobj_t                handle; /* context handle: the lookup key */
        __u32                   remote_realm:1, /* client principal from a foreign realm */
                                auth_usr_mds:1, /* authenticated as mds svc user (see rsc_parse) */
                                auth_usr_oss:1; /* authenticated as oss svc user (see rsc_parse) */
        struct vfs_cred         cred;           /* uid/gid of the client principal */
        uid_t                   mapped_uid;     /* local uid the principal is mapped to */
        struct gss_svc_seq_data seqdata;        /* per-context replay window */
        struct gss_ctx         *mechctx;        /* imported mechanism-specific context */
};
401
/* hash buckets for the context cache; rsc_cache itself is defined below */
static struct cache_head *rsc_table[RSC_HASHMAX];
static struct cache_detail rsc_cache;
404
/* Free everything owned by a context entry: the handle buffer and the
 * imported GSS mechanism context, if one was ever attached. */
static void rsc_free(struct rsc *rsci)
{
        rawobj_free(&rsci->handle);
        if (rsci->mechctx)
                kgss_delete_sec_context(&rsci->mechctx);
#if 0
        if (rsci->cred.vc_ginfo)
                put_group_info(rsci->cred.vc_ginfo);
#endif
}
415
416 static void rsc_put(struct cache_head *item, struct cache_detail *cd)
417 {
418         struct rsc *rsci = container_of(item, struct rsc, h);
419
420         LASSERT(atomic_read(&item->refcnt) > 0);
421         if (cache_put(item, cd)) {
422                 LASSERT(item->next == NULL);
423                 rsc_free(rsci);
424                 OBD_FREE(rsci, sizeof(*rsci));
425         }
426 }
427
428 static inline int
429 rsc_hash(struct rsc *rsci)
430 {
431         return hash_mem((char *)rsci->handle.data,
432                         rsci->handle.len, RSC_HASHBITS);
433 }
434
435 static inline int
436 rsc_match(struct rsc *new, struct rsc *tmp)
437 {
438         return rawobj_equal(&new->handle, &tmp->handle);
439 }
440
/*
 * Look up @item in the rsc cache by context handle.
 *
 * @set == 0: read-locked lookup; returns a referenced matching entry or
 *            NULL.
 * @set != 0: write-locked insert; any existing entry with the same
 *            handle is unhashed and dropped, @item takes its place and
 *            is returned with an extra reference.
 */
static struct rsc *rsc_lookup(struct rsc *item, int set)
{
        struct rsc *tmp = NULL;
        struct cache_head **hp, **head;
        head = &rsc_cache.hash_table[rsc_hash(item)];
        ENTRY;

        /* lock mode depends on whether we might modify the chain */
        if (set)
                write_lock(&rsc_cache.hash_lock);
        else
                read_lock(&rsc_cache.hash_lock);
        for (hp = head; *hp != NULL; hp = &tmp->h.next) {
                tmp = container_of(*hp, struct rsc, h);
                if (!rsc_match(tmp, item))
                        continue;
                cache_get(&tmp->h);
                if (!set)
                        goto out_noset;
                /* replace: unhash the old entry and drop our ref on it */
                *hp = tmp->h.next;
                tmp->h.next = NULL;
                clear_bit(CACHE_HASHED, &tmp->h.flags);
                rsc_put(&tmp->h, &rsc_cache);
                goto out_set;
        }
        /* Didn't find anything */
        if (!set)
                goto out_nada;
        rsc_cache.entries++;
out_set:
        /* insert item at the head of the chain and refresh its expiry */
        set_bit(CACHE_HASHED, &item->h.flags);
        item->h.next = *head;
        *head = &item->h;
        write_unlock(&rsc_cache.hash_lock);
        cache_fresh(&rsc_cache, &item->h, item->h.expiry_time);
        cache_get(&item->h);
        RETURN(item);
out_nada:
        tmp = NULL;
out_noset:
        read_unlock(&rsc_cache.hash_lock);
        RETURN(tmp);
}
483
/*
 * Parse a context downcall from svcgssd and install the resulting
 * entry in the rsc cache.
 *
 * Message format (whitespace-separated):
 *   handle expiry remote_flag mds_user_flag oss_user_flag mapped_uid
 *   [ uid gid mechname ...mechdata... ]
 * A uid read failing with -ENOENT marks the entry NEGATIVE instead of
 * importing a mech context.
 *
 * Returns 0 on success or a negative errno.
 */
static int rsc_parse(struct cache_detail *cd,
                     char *mesg, int mlen)
{
        /* contexthandle expiry [ uid gid N <n gids> mechname
         * ...mechdata... ] */
        char *buf = mesg;
        int len, rv, tmp_int;
        struct rsc *rsci, *res = NULL;
        time_t expiry;
        int status = -EINVAL;

        OBD_ALLOC(rsci, sizeof(*rsci));
        if (!rsci) {
                CERROR("fail to alloc rsci\n");
                return -ENOMEM;
        }
        cache_init(&rsci->h);

        /* context handle */
        len = qword_get(&mesg, buf, mlen);
        if (len < 0) goto out;
        status = -ENOMEM;
        if (rawobj_alloc(&rsci->handle, buf, len))
                goto out;

        /* expiry */
        expiry = get_expiry(&mesg);
        status = -EINVAL;
        if (expiry == 0)
                goto out;

        /* remote flag */
        rv = get_int(&mesg, &tmp_int);
        if (rv) {
                CERROR("fail to get remote flag\n");
                goto out;
        }
        rsci->remote_realm = (tmp_int != 0);

        /* mds user flag */
        rv = get_int(&mesg, &tmp_int);
        if (rv) {
                CERROR("fail to get mds user flag\n");
                goto out;
        }
        rsci->auth_usr_mds = (tmp_int != 0);

        /* oss user flag */
        rv = get_int(&mesg, &tmp_int);
        if (rv) {
                CERROR("fail to get oss user flag\n");
                goto out;
        }
        rsci->auth_usr_oss = (tmp_int != 0);

        /* mapped uid */
        rv = get_int(&mesg, (int *)&rsci->mapped_uid);
        if (rv) {
                CERROR("fail to get mapped uid\n");
                goto out;
        }

        /* uid, or NEGATIVE */
        rv = get_int(&mesg, (int *)&rsci->cred.vc_uid);
        if (rv == -EINVAL)
                goto out;
        if (rv == -ENOENT) {
                CERROR("NOENT? set rsc entry negative\n");
                set_bit(CACHE_NEGATIVE, &rsci->h.flags);
        } else {
                struct gss_api_mech *gm;
                rawobj_t tmp_buf;
                __u64 ctx_expiry;

                /* gid */
                if (get_int(&mesg, (int *)&rsci->cred.vc_gid))
                        goto out;

                /* mech name */
                len = qword_get(&mesg, buf, mlen);
                if (len < 0)
                        goto out;
                gm = kgss_name_to_mech(buf);
                status = -EOPNOTSUPP;
                if (!gm)
                        goto out;

                status = -EINVAL;
                /* mech-specific data: */
                len = qword_get(&mesg, buf, mlen);
                if (len < 0) {
                        kgss_mech_put(gm);
                        goto out;
                }
                tmp_buf.len = len;
                tmp_buf.data = (unsigned char *)buf;
                if (kgss_import_sec_context(&tmp_buf, gm, &rsci->mechctx)) {
                        kgss_mech_put(gm);
                        goto out;
                }

                /* currently the expiry time passed down from user-space
                 * is invalid, here we retrieve it from mech.
                 */
                if (kgss_inquire_context(rsci->mechctx, &ctx_expiry)) {
                        CERROR("unable to get expire time, drop it\n");
                        set_bit(CACHE_NEGATIVE, &rsci->h.flags);
                        kgss_mech_put(gm);
                        goto out;
                }
                expiry = (time_t) gss_roundup_expire_time(ctx_expiry);

                kgss_mech_put(gm);
        }
        rsci->h.expiry_time = expiry;
        spin_lock_init(&rsci->seqdata.sd_lock);
        /* set=1: always returns rsci itself with an extra ref; drop it */
        res = rsc_lookup(rsci, 1);
        rsc_put(&res->h, &rsc_cache);
        status = 0;
out:
        if (rsci)
                rsc_put(&rsci->h, &rsc_cache);
        return status;
}
608
/*
 * Flush all context entries belonging to @uid; @uid == -1 matches all.
 * We only know the uid (maybe netid/nid in the future), so in all cases
 * the whole cache must be scanned.
 */
static void rsc_flush(uid_t uid)
{
        struct cache_head **ch;
        struct rsc *rscp;
        int n;
        ENTRY;

        if (uid == -1)
                CWARN("flush all gss contexts\n");

        write_lock(&rsc_cache.hash_lock);
        for (n = 0; n < RSC_HASHMAX; n++) {
                for (ch = &rsc_cache.hash_table[n]; *ch;) {
                        rscp = container_of(*ch, struct rsc, h);

                        if (uid != -1 && rscp->cred.vc_uid != uid) {
                                /* keep: advance to next chain slot */
                                ch = &((*ch)->next);
                                continue;
                        }

                        /* it seems simply set NEGATIVE doesn't work */
                        /* unhash the entry; *ch now points at the successor,
                         * so no advance is needed this iteration */
                        *ch = (*ch)->next;
                        rscp->h.next = NULL;
                        cache_get(&rscp->h);
                        set_bit(CACHE_NEGATIVE, &rscp->h.flags);
                        clear_bit(CACHE_HASHED, &rscp->h.flags);
                        if (uid != -1)
                                CWARN("flush rsc %p(%u) for uid %u\n", rscp,
                                      *((__u32 *) rscp->handle.data),
                                      rscp->cred.vc_uid);
                        /* drop the ref we just took (and the hash ref) */
                        rsc_put(&rscp->h, &rsc_cache);
                        rsc_cache.entries--;
                }
        }
        write_unlock(&rsc_cache.hash_lock);
        EXIT;
}
651
/* The context cache "auth.ptlrpcs.context".  No .cache_request: entries
 * are only pushed down by user-space via rsc_parse, never upcalled. */
static struct cache_detail rsc_cache = {
        .hash_size      = RSC_HASHMAX,
        .hash_table     = rsc_table,
        .name           = "auth.ptlrpcs.context",
        .cache_put      = rsc_put,
        .cache_parse    = rsc_parse,
};
659
660 static struct rsc *
661 gss_svc_searchbyctx(rawobj_t *handle)
662 {
663         struct rsc rsci;
664         struct rsc *found;
665
666         rsci.handle = *handle;
667         found = rsc_lookup(&rsci, 0);
668         if (!found)
669                 return NULL;
670
671         if (cache_check(&rsc_cache, &found->h, NULL))
672                 return NULL;
673
674         return found;
675 }
676
/* FIXME
 * again hacking: only try to give the svcgssd a chance to handle
 * upcalls.
 */
/* Deferral hook handed to cache_check(): rather than actually deferring
 * the request, yield the CPU (so svcgssd may run) and return NULL to
 * report "cannot defer". */
struct cache_deferred_req* my_defer(struct cache_req *req)
{
        yield();
        return NULL;
}
static struct cache_req my_chandle = {my_defer};
687
688 /* Implements sequence number algorithm as specified in RFC 2203. */
689 static inline void __dbg_dump_seqwin(struct gss_svc_seq_data *sd)
690 {
691         char buf[sizeof(sd->sd_win)*2+1];
692         int i;
693
694         for (i = 0; i < sizeof(sd->sd_win); i++)
695                 sprintf(&buf[i+i], "%02x", ((__u8 *) sd->sd_win)[i]);
696         CWARN("dump seqwin: %s\n", buf);
697 }
698
/* Debug helper: report a sequence number far beyond the current window
 * (caller is about to reset the window) and dump the bitmap. */
static inline void __dbg_seq_jump(struct gss_svc_seq_data *sd, __u32 seq_num)
{
        CWARN("seq jump to %u, cur max %u!\n", seq_num, sd->sd_max);
        __dbg_dump_seqwin(sd);
}
704
705 static inline void __dbg_seq_increase(struct gss_svc_seq_data *sd, __u32 seq_num)
706 {
707         int n = seq_num - sd->sd_max;
708         int i, notset=0;
709
710         for (i = 0; i < n; i++) {
711                 if (!test_bit(i, sd->sd_win))
712                         notset++;
713         }
714         if (!notset)
715                 return;
716
717         CWARN("seq increase to %u, cur max %u\n", seq_num, sd->sd_max);
718         __dbg_dump_seqwin(sd);
719 }
720
/*
 * RFC 2203 sequence-window check against the per-context window @sd.
 * Returns 0 if @seq_num is acceptable, 1 if the request must be
 * discarded (sequence too old, or already seen == replay).
 */
static int
gss_check_seq_num(struct gss_svc_seq_data *sd, __u32 seq_num)
{
        int rc = 0;

        spin_lock(&sd->sd_lock);
        if (seq_num > sd->sd_max) {
                if (seq_num >= sd->sd_max + GSS_SEQ_WIN) {
                        /* jumped past the whole window: start fresh */
                        __dbg_seq_jump(sd, seq_num);
                        memset(sd->sd_win, 0, sizeof(sd->sd_win));
                        sd->sd_max = seq_num;
                } else {
                        /* slide the window forward, clearing bits for the
                         * sequence numbers that fall out of it */
                        __dbg_seq_increase(sd, seq_num);
                        while(sd->sd_max < seq_num) {
                                sd->sd_max++;
                                __clear_bit(sd->sd_max % GSS_SEQ_WIN,
                                            sd->sd_win);
                        }
                }
                /* bits are indexed seq % GSS_SEQ_WIN */
                __set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win);
                goto exit;
        } else if (seq_num + GSS_SEQ_WIN <= sd->sd_max) {
                CERROR("seq %u too low: max %u, win %d\n",
                        seq_num, sd->sd_max, GSS_SEQ_WIN);
                rc = 1;
                goto exit;
        }

        /* inside the window: a set bit means we saw this seq before */
        if (__test_and_set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win)) {
                CERROR("seq %u is replay: max %u, win %d\n",
                        seq_num, sd->sd_max, GSS_SEQ_WIN);
                rc = 1;
        }
exit:
        spin_unlock(&sd->sd_lock);
        return rc;
}
758
759 static int
760 gss_svc_verify_request(struct ptlrpc_request *req,
761                        struct rsc *rsci,
762                        struct rpc_gss_wire_cred *gc,
763                        __u32 *vp, __u32 vlen)
764 {
765         struct ptlrpcs_wire_hdr *sec_hdr;
766         struct gss_ctx *ctx = rsci->mechctx;
767         __u32 maj_stat;
768         rawobj_t msg;
769         rawobj_t mic;
770         ENTRY;
771
772         sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;
773
774         req->rq_reqmsg = (struct lustre_msg *) (req->rq_reqbuf + sizeof(*sec_hdr));
775         req->rq_reqlen = sec_hdr->msg_len;
776
777         msg.len = sec_hdr->msg_len;
778         msg.data = (__u8 *)req->rq_reqmsg;
779
780         mic.len = le32_to_cpu(*vp++);
781         mic.data = (unsigned char *)vp;
782         vlen -= 4;
783
784         if (mic.len > vlen) {
785                 CERROR("checksum len %d, while buffer len %d\n",
786                         mic.len, vlen);
787                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
788         }
789
790         if (mic.len > 256) {
791                 CERROR("invalid mic len %d\n", mic.len);
792                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
793         }
794
795         maj_stat = kgss_verify_mic(ctx, &msg, &mic, NULL);
796         if (maj_stat != GSS_S_COMPLETE) {
797                 CERROR("MIC verification error: major %x\n", maj_stat);
798                 RETURN(maj_stat);
799         }
800
801         if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
802                 CERROR("discard replayed request %p(o%u,x"LPU64",t"LPU64")\n",
803                         req, req->rq_reqmsg->opc, req->rq_xid,
804                         req->rq_reqmsg->transno);
805                 RETURN(GSS_S_DUPLICATE_TOKEN);
806         }
807
808         RETURN(GSS_S_COMPLETE);
809 }
810
/*
 * Privacy service: decrypt the request message.
 *
 * On entry @vp points at <32-bit LE cipher length><ciphertext>, with
 * @vlen bytes available.  On success rq_reqmsg/rq_reqlen are pointed at
 * the decrypted message and GSS_S_COMPLETE is returned; otherwise a GSS
 * major status code.
 */
static int
gss_svc_unseal_request(struct ptlrpc_request *req,
                       struct rsc *rsci,
                       struct rpc_gss_wire_cred *gc,
                       __u32 *vp, __u32 vlen)
{
        struct ptlrpcs_wire_hdr *sec_hdr;
        struct gss_ctx *ctx = rsci->mechctx;
        rawobj_t cipher_text, plain_text;
        __u32 major;
        ENTRY;

        sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;

        /* need at least the 4-byte cipher length field */
        if (vlen < 4) {
                CERROR("vlen only %u\n", vlen);
                RETURN(GSS_S_CALL_BAD_STRUCTURE);
        }

        cipher_text.len = le32_to_cpu(*vp++);
        cipher_text.data = (__u8 *) vp;
        vlen -= 4;
        
        if (cipher_text.len > vlen) {
                CERROR("cipher claimed %u while buf only %u\n",
                        cipher_text.len, vlen);
                RETURN(GSS_S_CALL_BAD_STRUCTURE);
        }

        /* output overlays the cipher buffer (in-place unwrap) */
        plain_text = cipher_text;

        major = kgss_unwrap(ctx, GSS_C_QOP_DEFAULT, &cipher_text, &plain_text);
        if (major) {
                CERROR("unwrap error 0x%x\n", major);
                RETURN(major);
        }

        /* replay detection */
        if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
                CERROR("discard replayed request %p(o%u,x"LPU64",t"LPU64")\n",
                        req, req->rq_reqmsg->opc, req->rq_xid,
                        req->rq_reqmsg->transno);
                RETURN(GSS_S_DUPLICATE_TOKEN);
        }

        /* the decrypted lustre message starts where the ciphertext was */
        req->rq_reqmsg = (struct lustre_msg *) (vp);
        req->rq_reqlen = plain_text.len;

        CDEBUG(D_SEC, "msg len %d\n", req->rq_reqlen);

        RETURN(GSS_S_COMPLETE);
}
862
/*
 * Build an error-notify reply for the client: a GSS_NONE-flavored reply
 * whose sec data carries (version, subflavor, PROC_ERR, major, minor,
 * two empty objects) so the client learns the GSS failure.  Returns 0
 * or a negative errno from reply packing.
 */
static int
gss_pack_err_notify(struct ptlrpc_request *req,
                    __u32 major, __u32 minor)
{
        struct gss_svc_data *svcdata = req->rq_svcsec_data;
        __u32 reslen, *resp, *reslenp;
        char  nidstr[PTL_NALFMT_SIZE];
        const __u32 secdata_len = 7 * 4;        /* 7 32-bit words below */
        int rc;
        ENTRY;

        OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_ERR_NOTIFY|OBD_FAIL_ONCE, -EINVAL);

        LASSERT(svcdata);
        svcdata->is_err_notify = 1;
        svcdata->reserve_len = 7 * 4;

        rc = lustre_pack_reply(req, 0, NULL, NULL);
        if (rc) {
                CERROR("could not pack reply, err %d\n", rc);
                RETURN(rc);
        }

        LASSERT(req->rq_reply_state);
        LASSERT(req->rq_reply_state->rs_repbuf);
        LASSERT(req->rq_reply_state->rs_repbuf_len >= secdata_len);
        resp = (__u32 *) req->rq_reply_state->rs_repbuf;

        /* header */
        *resp++ = cpu_to_le32(PTLRPCS_FLVR_GSS_NONE);
        *resp++ = cpu_to_le32(PTLRPCS_SVC_NONE);
        *resp++ = cpu_to_le32(req->rq_replen);
        reslenp = resp++;       /* sec data length, filled in below */

        /* skip lustre msg */
        resp += req->rq_replen / 4;
        reslen = svcdata->reserve_len;

        /* gss replay:
         * version, subflavor, notify, major, minor,
         * obj1(fake), obj2(fake)
         */
        *resp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
        *resp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);
        *resp++ = cpu_to_le32(PTLRPCS_GSS_PROC_ERR);
        *resp++ = cpu_to_le32(major);
        *resp++ = cpu_to_le32(minor);
        *resp++ = 0;
        *resp++ = 0;
        /* NOTE(review): 7 words were written but reslen is only reduced
         * by 16; reslen is not used after this point anyway */
        reslen -= (4 * 4);
        /* the actual sec data length */
        *reslenp = cpu_to_le32(secdata_len);

        req->rq_reply_state->rs_repdata_len += (secdata_len);
        CDEBUG(D_SEC, "prepare gss error notify(0x%x/0x%x) to %s\n",
               major, minor,
               portals_nid2str(req->rq_peer.peer_ni->pni_number,
                               req->rq_peer.peer_id.nid, nidstr));
        RETURN(0);
}
923
924 static void dump_cache_head(struct cache_head *h)
925 {
926         CWARN("ref %d, fl %lx, n %p, t %ld, %ld\n",
927               atomic_read(&h->refcnt), h->flags, h->next,
928               h->expiry_time, h->last_refresh);
929 }
930 static void dump_rsi(struct rsi *rsi)
931 {
932         CWARN("dump rsi %p\n", rsi);
933         dump_cache_head(&rsi->h);
934         CWARN("%x,%x,%llx\n", rsi->naltype, rsi->netid, rsi->nid);
935         CWARN("len %d, d %p\n", rsi->in_handle.len, rsi->in_handle.data);
936         CWARN("len %d, d %p\n", rsi->in_token.len, rsi->in_token.data);
937         CWARN("len %d, d %p\n", rsi->out_handle.len, rsi->out_handle.data);
938         CWARN("len %d, d %p\n", rsi->out_token.len, rsi->out_token.data);
939 }
940
941 static int
942 gss_svcsec_handle_init(struct ptlrpc_request *req,
943                        struct rpc_gss_wire_cred *gc,
944                        __u32 *secdata, __u32 seclen,
945                        enum ptlrpcs_error *res)
946 {
947         struct gss_svc_data *svcdata = req->rq_svcsec_data;
948         struct rsc          *rsci;
949         struct rsi          *rsikey, *rsip;
950         rawobj_t             tmpobj;
951         __u32 reslen,       *resp, *reslenp;
952         char                 nidstr[PTL_NALFMT_SIZE];
953         int                  rc;
954         ENTRY;
955
956         LASSERT(svcdata);
957
958         CDEBUG(D_SEC, "processing gss init(%d) request from %s\n", gc->gc_proc,
959                portals_nid2str(req->rq_peer.peer_ni->pni_number,
960                                req->rq_peer.peer_id.nid, nidstr));
961
962         *res = PTLRPCS_BADCRED;
963         OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_INIT_REQ|OBD_FAIL_ONCE, SVC_DROP);
964
965         if (gc->gc_proc == RPC_GSS_PROC_INIT &&
966             gc->gc_ctx.len != 0) {
967                 CERROR("proc %d, ctx_len %d: not really init?\n",
968                 gc->gc_proc == RPC_GSS_PROC_INIT, gc->gc_ctx.len);
969                 RETURN(SVC_DROP);
970         }
971
972         OBD_ALLOC(rsikey, sizeof(*rsikey));
973         if (!rsikey) {
974                 CERROR("out of memory\n");
975                 RETURN(SVC_DROP);
976         }
977         cache_init(&rsikey->h);
978
979         /* obtain lustre svc type */
980         if (seclen < 4) {
981                 CERROR("sec size %d too small\n", seclen);
982                 GOTO(out_rsikey, rc = SVC_DROP);
983         }
984         rsikey->lustre_svc = le32_to_cpu(*secdata++);
985         seclen -= 4;
986
987         /* duplicate context handle. currently always 0 */
988         if (rawobj_dup(&rsikey->in_handle, &gc->gc_ctx)) {
989                 CERROR("fail to dup context handle\n");
990                 GOTO(out_rsikey, rc = SVC_DROP);
991         }
992
993         /* extract token */
994         *res = PTLRPCS_BADVERF;
995         if (rawobj_extract(&tmpobj, &secdata, &seclen)) {
996                 CERROR("can't extract token\n");
997                 GOTO(out_rsikey, rc = SVC_DROP);
998         }
999         if (rawobj_dup(&rsikey->in_token, &tmpobj)) {
1000                 CERROR("can't duplicate token\n");
1001                 GOTO(out_rsikey, rc = SVC_DROP);
1002         }
1003
1004         rsikey->naltype = (__u32) req->rq_peer.peer_ni->pni_number;
1005         rsikey->netid = 0;
1006         rsikey->nid = (__u64) req->rq_peer.peer_id.nid;
1007
1008         rsip = gssd_upcall(rsikey, &my_chandle);
1009         if (!rsip) {
1010                 CERROR("error in gssd_upcall.\n");
1011
1012                 rc = SVC_COMPLETE;
1013                 if (gss_pack_err_notify(req, GSS_S_FAILURE, 0))
1014                         rc = SVC_DROP;
1015
1016                 GOTO(out_rsikey, rc);
1017         }
1018
1019         rsci = gss_svc_searchbyctx(&rsip->out_handle);
1020         if (!rsci) {
1021                 CERROR("rsci still not mature yet?\n");
1022
1023                 rc = SVC_COMPLETE;
1024                 if (gss_pack_err_notify(req, GSS_S_FAILURE, 0))
1025                         rc = SVC_DROP;
1026
1027                 GOTO(out_rsip, rc);
1028         }
1029         CDEBUG(D_SEC, "svcsec create gss context %p(%u@%s)\n",
1030                rsci, rsci->cred.vc_uid,
1031                portals_nid2str(req->rq_peer.peer_ni->pni_number,
1032                                req->rq_peer.peer_id.nid, nidstr));
1033
1034         svcdata->is_init = 1;
1035         svcdata->reserve_len = 7 * 4 +
1036                 size_round4(rsip->out_handle.len) +
1037                 size_round4(rsip->out_token.len);
1038
1039         rc = lustre_pack_reply(req, 0, NULL, NULL);
1040         if (rc) {
1041                 CERROR("failed to pack reply, rc = %d\n", rc);
1042                 set_bit(CACHE_NEGATIVE, &rsci->h.flags);
1043                 GOTO(out, rc = SVC_DROP);
1044         }
1045
1046         /* header */
1047         resp = (__u32 *) req->rq_reply_state->rs_repbuf;
1048         *resp++ = cpu_to_le32(PTLRPCS_FLVR_GSS_NONE);
1049         *resp++ = cpu_to_le32(PTLRPCS_SVC_NONE);
1050         *resp++ = cpu_to_le32(req->rq_replen);
1051         reslenp = resp++;
1052
1053         resp += req->rq_replen / 4;
1054         reslen = svcdata->reserve_len;
1055
1056         /* gss reply: (conform to err notify format)
1057          * x, x, seq, major, minor, handle, token
1058          */
1059         *resp++ = 0;
1060         *resp++ = 0;
1061         *resp++ = cpu_to_le32(GSS_SEQ_WIN);
1062         *resp++ = cpu_to_le32(rsip->major_status);
1063         *resp++ = cpu_to_le32(rsip->minor_status);
1064         reslen -= (5 * 4);
1065         if (rawobj_serialize(&rsip->out_handle,
1066                              &resp, &reslen)) {
1067                 dump_rsi(rsip);
1068                 dump_rsi(rsikey);
1069                 LBUG();
1070         }
1071         if (rawobj_serialize(&rsip->out_token,
1072                              &resp, &reslen)) {
1073                 dump_rsi(rsip);
1074                 dump_rsi(rsikey);
1075                 LBUG();
1076         }
1077         /* the actual sec data length */
1078         *reslenp = cpu_to_le32(svcdata->reserve_len - reslen);
1079
1080         req->rq_reply_state->rs_repdata_len += le32_to_cpu(*reslenp);
1081         CDEBUG(D_SEC, "req %p: msgsize %d, authsize %d, "
1082                "total size %d\n", req, req->rq_replen,
1083                le32_to_cpu(*reslenp),
1084                req->rq_reply_state->rs_repdata_len);
1085
1086         *res = PTLRPCS_OK;
1087
1088         req->rq_remote_realm = rsci->remote_realm;
1089         req->rq_auth_usr_mds = rsci->auth_usr_mds;
1090         req->rq_auth_usr_oss = rsci->auth_usr_oss;
1091         req->rq_auth_uid = rsci->cred.vc_uid;
1092         req->rq_mapped_uid = rsci->mapped_uid;
1093
1094         if (req->rq_auth_usr_mds) {
1095                 CWARN("usr from %s authenticated as mds svc cred\n",
1096                 portals_nid2str(req->rq_peer.peer_ni->pni_number,
1097                                 req->rq_peer.peer_id.nid, nidstr));
1098         }
1099         if (req->rq_auth_usr_oss) {
1100                 CWARN("usr from %s authenticated as oss svc cred\n",
1101                 portals_nid2str(req->rq_peer.peer_ni->pni_number,
1102                                 req->rq_peer.peer_id.nid, nidstr));
1103         }
1104
1105         /* This is simplified since right now we doesn't support
1106          * INIT_CONTINUE yet.
1107          */
1108         if (gc->gc_proc == RPC_GSS_PROC_INIT) {
1109                 struct ptlrpcs_wire_hdr *hdr;
1110
1111                 hdr = buf_to_sec_hdr(req->rq_reqbuf);
1112                 req->rq_reqmsg = buf_to_lustre_msg(req->rq_reqbuf);
1113                 req->rq_reqlen = hdr->msg_len;
1114
1115                 rc = SVC_LOGIN;
1116         } else
1117                 rc = SVC_COMPLETE;
1118
1119 out:
1120         rsc_put(&rsci->h, &rsc_cache);
1121 out_rsip:
1122         rsi_put(&rsip->h, &rsi_cache);
1123 out_rsikey:
1124         rsi_put(&rsikey->h, &rsi_cache);
1125
1126         RETURN(rc);
1127 }
1128
/*
 * Handle an RPC_GSS_PROC_DATA request: look up the established context
 * by the wire handle and verify (integrity mode) or decrypt (privacy
 * mode) the request in place.
 *
 * On gss failure an error notify reply is packed so the client can
 * refresh its context; unsupported services drop the request.
 *
 * Returns SVC_OK on success, SVC_COMPLETE when an error notify reply
 * was packed, SVC_DROP otherwise.
 */
static int
gss_svcsec_handle_data(struct ptlrpc_request *req,
                       struct rpc_gss_wire_cred *gc,
                       __u32 *secdata, __u32 seclen,
                       enum ptlrpcs_error *res)
{
        struct rsc          *rsci;
        char                 nidstr[PTL_NALFMT_SIZE];
        __u32                major;
        int                  rc;
        ENTRY;

        *res = PTLRPCS_GSS_CREDPROBLEM;

        rsci = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rsci) {
                /* context expired or never established; tell the client */
                CWARN("Invalid gss context handle from %s\n",
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                major = GSS_S_NO_CONTEXT;
                goto notify_err;
        }

        switch (gc->gc_svc) {
        case PTLRPCS_GSS_SVC_INTEGRITY:
                major = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
                if (major == GSS_S_COMPLETE)
                        break;

                CWARN("fail in verify:0x%x: ctx %p@%s\n", major, rsci,
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                goto notify_err;
        case PTLRPCS_GSS_SVC_PRIVACY:
                major = gss_svc_unseal_request(req, rsci, gc, secdata, seclen);
                if (major == GSS_S_COMPLETE)
                        break;

                CWARN("fail in decrypt:0x%x: ctx %p@%s\n", major, rsci,
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                goto notify_err;
        default:
                CERROR("unsupported gss service %d\n", gc->gc_svc);
                GOTO(out, rc = SVC_DROP);
        }

        /* request verified: propagate context identity into the request */
        req->rq_remote_realm = rsci->remote_realm;
        req->rq_auth_usr_mds = rsci->auth_usr_mds;
        req->rq_auth_usr_oss = rsci->auth_usr_oss;
        req->rq_auth_uid = rsci->cred.vc_uid;
        req->rq_mapped_uid = rsci->mapped_uid;

        *res = PTLRPCS_OK;
        GOTO(out, rc = SVC_OK);

notify_err:
        /* pack a gss error reply; drop only if even that fails */
        if (gss_pack_err_notify(req, major, 0))
                rc = SVC_DROP;
        else
                rc = SVC_COMPLETE;
out:
        if (rsci)
                rsc_put(&rsci->h, &rsc_cache);
        RETURN(rc);
}
1195
/*
 * Handle an RPC_GSS_PROC_DESTROY request: verify it under integrity
 * protection, then mark the context negative in the rsc cache so it can
 * no longer be used.  The reply carries no payload.
 *
 * Returns SVC_LOGOUT on success, SVC_DROP on any failure.
 */
static int
gss_svcsec_handle_destroy(struct ptlrpc_request *req,
                          struct rpc_gss_wire_cred *gc,
                          __u32 *secdata, __u32 seclen,
                          enum ptlrpcs_error *res)
{
        struct gss_svc_data *svcdata = req->rq_svcsec_data;
        struct rsc          *rsci;
        char                 nidstr[PTL_NALFMT_SIZE];
        int                  rc;
        ENTRY;

        LASSERT(svcdata);
        *res = PTLRPCS_GSS_CREDPROBLEM;

        rsci = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rsci) {
                CWARN("invalid gss context handle for destroy.\n");
                RETURN(SVC_DROP);
        }

        /* destroy requests are only accepted under integrity protection */
        if (gc->gc_svc != PTLRPCS_GSS_SVC_INTEGRITY) {
                CERROR("service %d is not supported in destroy.\n",
                        gc->gc_svc);
                GOTO(out, rc = SVC_DROP);
        }

        /* NOTE(review): this stores the gss major status into *res, which
         * is an enum ptlrpcs_error; it only matters that GSS_S_COMPLETE
         * is 0 here -- verify if the wire value is ever sent as-is */
        *res = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
        if (*res)
                GOTO(out, rc = SVC_DROP);

        /* compose reply, which is actually nothing */
        svcdata->is_fini = 1;
        if (lustre_pack_reply(req, 0, NULL, NULL))
                GOTO(out, rc = SVC_DROP);

        CDEBUG(D_SEC, "svcsec destroy gss context %p(%u@%s)\n",
               rsci, rsci->cred.vc_uid,
               portals_nid2str(req->rq_peer.peer_ni->pni_number,
                               req->rq_peer.peer_id.nid, nidstr));

        /* mark the cache entry dead so later lookups fail */
        set_bit(CACHE_NEGATIVE, &rsci->h.flags);
        *res = PTLRPCS_OK;
        rc = SVC_LOGOUT;
out:
        rsc_put(&rsci->h, &rsc_cache);
        RETURN(rc);
}
1244
1245 /*
1246  * let incomming request go through security check:
1247  *  o context establishment: invoke user space helper
1248  *  o data exchange: verify/decrypt
1249  *  o context destruction: mark context invalid
1250  *
1251  * in most cases, error will result to drop the packet silently.
1252  */
1253 static int
1254 gss_svcsec_accept(struct ptlrpc_request *req, enum ptlrpcs_error *res)
1255 {
1256         struct gss_svc_data *svcdata;
1257         struct rpc_gss_wire_cred *gc;
1258         struct ptlrpcs_wire_hdr *sec_hdr;
1259         __u32 subflavor, seclen, *secdata, version;
1260         int rc;
1261         ENTRY;
1262
1263         CDEBUG(D_SEC, "request %p\n", req);
1264         LASSERT(req->rq_reqbuf);
1265         LASSERT(req->rq_reqbuf_len);
1266
1267         *res = PTLRPCS_BADCRED;
1268
1269         sec_hdr = buf_to_sec_hdr(req->rq_reqbuf);
1270         LASSERT(SEC_FLAVOR_MAJOR(sec_hdr->flavor) == PTLRPCS_FLVR_MAJOR_GSS);
1271
1272         seclen = req->rq_reqbuf_len - sizeof(*sec_hdr) - sec_hdr->msg_len;
1273         secdata = (__u32 *) buf_to_sec_data(req->rq_reqbuf);
1274
1275         if (sec_hdr->sec_len > seclen) {
1276                 CERROR("seclen %d, while max buf %d\n",
1277                         sec_hdr->sec_len, seclen);
1278                 RETURN(SVC_DROP);
1279         }
1280
1281         if (seclen < 6 * 4) {
1282                 CERROR("sec size %d too small\n", seclen);
1283                 RETURN(SVC_DROP);
1284         }
1285
1286         LASSERT(!req->rq_svcsec_data);
1287         OBD_ALLOC(svcdata, sizeof(*svcdata));
1288         if (!svcdata) {
1289                 CERROR("fail to alloc svcdata\n");
1290                 RETURN(SVC_DROP);
1291         }
1292         req->rq_svcsec_data = svcdata;
1293         gc = &svcdata->clcred;
1294
1295         /* Now secdata/seclen is what we want to parse
1296          */
1297         version = le32_to_cpu(*secdata++);      /* version */
1298         subflavor = le32_to_cpu(*secdata++);    /* subflavor */
1299         gc->gc_proc = le32_to_cpu(*secdata++);  /* proc */
1300         gc->gc_seq = le32_to_cpu(*secdata++);   /* seq */
1301         gc->gc_svc = le32_to_cpu(*secdata++);   /* service */
1302         seclen -= 5 * 4;
1303
1304         CDEBUG(D_SEC, "wire gss_hdr: %u/%u/%u/%u/%u\n",
1305                version, subflavor, gc->gc_proc,
1306                gc->gc_seq, gc->gc_svc);
1307
1308         if (version != PTLRPC_SEC_GSS_VERSION) {
1309                 CERROR("gss version mismatch: %d - %d\n",
1310                         version, PTLRPC_SEC_GSS_VERSION);
1311                 GOTO(err_free, rc = SVC_DROP);
1312         }
1313
1314         /* We _must_ alloc new storage for gc_ctx. In case of recovery
1315          * request will be saved to delayed handling, at that time the
1316          * incoming buffer might have already been released.
1317          */
1318         if (rawobj_extract_alloc(&gc->gc_ctx, &secdata, &seclen)) {
1319                 CERROR("fail to obtain gss context handle\n");
1320                 GOTO(err_free, rc = SVC_DROP);
1321         }
1322
1323         *res = PTLRPCS_BADVERF;
1324         switch(gc->gc_proc) {
1325         case RPC_GSS_PROC_INIT:
1326         case RPC_GSS_PROC_CONTINUE_INIT:
1327                 rc = gss_svcsec_handle_init(req, gc, secdata, seclen, res);
1328                 break;
1329         case RPC_GSS_PROC_DATA:
1330                 rc = gss_svcsec_handle_data(req, gc, secdata, seclen, res);
1331                 break;
1332         case RPC_GSS_PROC_DESTROY:
1333                 rc = gss_svcsec_handle_destroy(req, gc, secdata, seclen, res);
1334                 break;
1335         default:
1336                 rc = SVC_DROP;
1337                 LBUG();
1338         }
1339
1340 err_free:
1341         if (rc == SVC_DROP && req->rq_svcsec_data) {
1342                 OBD_FREE(req->rq_svcsec_data, sizeof(struct gss_svc_data));
1343                 req->rq_svcsec_data = NULL;
1344         }
1345
1346         RETURN(rc);
1347 }
1348
/*
 * Apply gss protection to the reply before it is sent.
 *
 * init/fini/err replies were already fully packed by the accept path
 * and pass through untouched.  For data replies:
 *  - integrity: compute a MIC over the lustre msg and append it after
 *    the 7-word gss header;
 *  - privacy: encrypt the (separately allocated) lustre msg into the
 *    reply buffer after the 7-word gss header.
 *
 * Returns 0 on success, -EINVAL on failure.
 */
static int
gss_svcsec_authorize(struct ptlrpc_request *req)
{
        struct ptlrpc_reply_state *rs = req->rq_reply_state;
        struct gss_svc_data *gsd = (struct gss_svc_data *)req->rq_svcsec_data;
        struct rpc_gss_wire_cred  *gc = &gsd->clcred;
        struct rsc                *rscp;
        struct ptlrpcs_wire_hdr   *sec_hdr;
        rawobj_buf_t               msg_buf;
        rawobj_t                   cipher_buf;
        __u32                     *vp, *vpsave, major, vlen, seclen;
        rawobj_t                   lmsg, mic;
        int                        ret;
        ENTRY;

        LASSERT(rs);
        LASSERT(rs->rs_repbuf);
        LASSERT(gsd);

        if (gsd->is_init || gsd->is_init_continue ||
            gsd->is_err_notify || gsd->is_fini) {
                /* nothing to do in these cases */
                CDEBUG(D_SEC, "req %p: init/fini/err\n", req);
                RETURN(0);
        }

        if (gc->gc_proc != RPC_GSS_PROC_DATA) {
                CERROR("proc %d not support\n", gc->gc_proc);
                RETURN(-EINVAL);
        }

        /* the context could have been flushed between accept and here */
        rscp = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rscp) {
                CERROR("ctx %u disapeared under us\n",
                       *((__u32 *) gc->gc_ctx.data));
                RETURN(-EINVAL);
        }

        sec_hdr = (struct ptlrpcs_wire_hdr *) rs->rs_repbuf;
        switch (gc->gc_svc) {
        case  PTLRPCS_GSS_SVC_INTEGRITY:
                /* prepare various pointers */
                lmsg.len = req->rq_replen;
                lmsg.data = (__u8 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vp = (__u32 *) (lmsg.data + lmsg.len);
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr) - lmsg.len;
                seclen = vlen;

                sec_hdr->flavor = cpu_to_le32(PTLRPCS_FLVR_GSS_AUTH);
                sec_hdr->msg_len = cpu_to_le32(req->rq_replen);

                /* standard gss hdr */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPCS_GSS_SVC_INTEGRITY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                /* the MIC is written into the space left after the hdr */
                mic.len = vlen;
                mic.data = (unsigned char *)vp;

                major = kgss_get_mic(rscp->mechctx, 0, &lmsg, &mic);
                if (major) {
                        CERROR("fail to get MIC: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }
                /* record the MIC size, then shrink seclen to what was
                 * actually consumed */
                *vpsave = cpu_to_le32(mic.len);
                seclen = seclen - vlen + mic.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        case  PTLRPCS_GSS_SVC_PRIVACY:
                vp = (__u32 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr);
                seclen = vlen;

                /* msg_len is 0: the whole msg travels encrypted in the
                 * security section */
                sec_hdr->flavor = cpu_to_le32(PTLRPCS_FLVR_GSS_PRIV);
                sec_hdr->msg_len = cpu_to_le32(0);

                /* standard gss hdr */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPCS_GSS_SVC_PRIVACY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                /* rs_msg was allocated with prefix/suffix padding by
                 * gss_svcsec_alloc_repbuf() for in-place wrapping */
                msg_buf.buf = (__u8 *) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.buflen = req->rq_replen + GSS_PRIVBUF_PREFIX_LEN +
                                 GSS_PRIVBUF_SUFFIX_LEN;
                msg_buf.dataoff = GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.datalen = req->rq_replen;

                cipher_buf.data = (__u8 *) vp;
                cipher_buf.len = vlen;

                major = kgss_wrap(rscp->mechctx, GSS_C_QOP_DEFAULT,
                                &msg_buf, &cipher_buf);
                if (major) {
                        CERROR("failed to wrap: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }

                /* record ciphertext size and the consumed sec length */
                *vpsave = cpu_to_le32(cipher_buf.len);
                seclen = seclen - vlen + cipher_buf.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        default:
                CERROR("Unknown service %d\n", gc->gc_svc);
                GOTO(out, ret = -EINVAL);
        }
        ret = 0;
out:
        rsc_put(&rscp->h, &rsc_cache);

        RETURN(ret);
}
1474
1475 static
1476 void gss_svcsec_cleanup_req(struct ptlrpc_svcsec *svcsec,
1477                             struct ptlrpc_request *req)
1478 {
1479         struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_svcsec_data;
1480
1481         if (!gsd) {
1482                 CDEBUG(D_SEC, "no svc_data present. do nothing\n");
1483                 return;
1484         }
1485
1486         /* gc_ctx is allocated, see gss_svcsec_accept() */
1487         rawobj_free(&gsd->clcred.gc_ctx);
1488
1489         OBD_FREE(gsd, sizeof(*gsd));
1490         req->rq_svcsec_data = NULL;
1491         return;
1492 }
1493
1494 static
1495 int gss_svcsec_est_payload(struct ptlrpc_svcsec *svcsec,
1496                            struct ptlrpc_request *req,
1497                            int msgsize)
1498 {
1499         struct gss_svc_data *svcdata = req->rq_svcsec_data;
1500         ENTRY;
1501
1502         /* just return the pre-set reserve_len for init/fini/err cases.
1503          */
1504         LASSERT(svcdata);
1505         if (svcdata->is_init) {
1506                 CDEBUG(D_SEC, "is_init, reserver size %d(%d)\n",
1507                        size_round(svcdata->reserve_len),
1508                        svcdata->reserve_len);
1509                 LASSERT(svcdata->reserve_len);
1510                 LASSERT(svcdata->reserve_len % 4 == 0);
1511                 RETURN(size_round(svcdata->reserve_len));
1512         } else if (svcdata->is_err_notify) {
1513                 CDEBUG(D_SEC, "is_err_notify, reserver size %d(%d)\n",
1514                        size_round(svcdata->reserve_len),
1515                        svcdata->reserve_len);
1516                 RETURN(size_round(svcdata->reserve_len));
1517         } else if (svcdata->is_fini) {
1518                 CDEBUG(D_SEC, "is_fini, reserver size 0\n");
1519                 RETURN(0);
1520         } else {
1521                 if (svcdata->clcred.gc_svc == PTLRPCS_GSS_SVC_NONE ||
1522                     svcdata->clcred.gc_svc == PTLRPCS_GSS_SVC_INTEGRITY)
1523                         RETURN(size_round(GSS_MAX_AUTH_PAYLOAD));
1524                 else if (svcdata->clcred.gc_svc == PTLRPCS_GSS_SVC_PRIVACY)
1525                         RETURN(size_round16(GSS_MAX_AUTH_PAYLOAD + msgsize +
1526                                             GSS_PRIVBUF_PREFIX_LEN +
1527                                             GSS_PRIVBUF_SUFFIX_LEN));
1528                 else {
1529                         CERROR("unknown gss svc %u\n", svcdata->clcred.gc_svc);
1530                         *((int *)0) = 0;
1531                         LBUG();
1532                 }
1533         }
1534         RETURN(0);
1535 }
1536
/*
 * Allocate the reply state and buffer for this request.
 *
 * none/auth (and init/fini/err) replies keep the lustre msg inside the
 * reply buffer; privacy replies keep the cleartext msg in a separately
 * allocated buffer (with gss prefix/suffix padding) that is encrypted
 * into the reply buffer later by gss_svcsec_authorize().
 *
 * Returns 0 on success or a negative errno.
 */
int gss_svcsec_alloc_repbuf(struct ptlrpc_svcsec *svcsec,
                            struct ptlrpc_request *req,
                            int msgsize)
{
        struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_svcsec_data;
        struct ptlrpc_reply_state *rs;
        int msg_payload, sec_payload;
        int privacy, rc;
        ENTRY;

        /* determine the security type: none/auth or priv, we have
         * different pack scheme for them.
         * init/fini/err will always be treated as none/auth.
         */
        LASSERT(gsd);
        if (!gsd->is_init && !gsd->is_init_continue &&
            !gsd->is_fini && !gsd->is_err_notify &&
            gsd->clcred.gc_svc == PTLRPCS_GSS_SVC_PRIVACY)
                privacy = 1;
        else
                privacy = 0;

        /* in privacy mode the msg lives outside the reply buffer */
        msg_payload = privacy ? 0 : msgsize;
        sec_payload = gss_svcsec_est_payload(svcsec, req, msgsize);

        rc = svcsec_alloc_reply_state(req, msg_payload, sec_payload);
        if (rc)
                RETURN(rc);

        rs = req->rq_reply_state;
        LASSERT(rs);
        rs->rs_msg_len = msgsize;

        if (privacy) {
                /* we can choose to let msg simply point to the rear of the
                 * buffer, which lead to buffer overlap when doing encryption.
                 * usually it's ok and it indeed passed all existing tests.
                 * but not sure if there will be subtle problems in the future.
                 * so right now we choose to alloc another new buffer. we'll
                 * see how it works.
                 */
#if 0
                rs->rs_msg = (struct lustre_msg *)
                             (rs->rs_repbuf + rs->rs_repbuf_len -
                              msgsize - GSS_PRIVBUF_SUFFIX_LEN);
#endif
                char *msgbuf;

                msgsize += GSS_PRIVBUF_PREFIX_LEN + GSS_PRIVBUF_SUFFIX_LEN;
                OBD_ALLOC(msgbuf, msgsize);
                if (!msgbuf) {
                        CERROR("can't alloc %d\n", msgsize);
                        /* unwind the reply state allocated above */
                        svcsec_free_reply_state(rs);
                        req->rq_reply_state = NULL;
                        RETURN(-ENOMEM);
                }
                /* rs_msg points past the prefix padding; the padded
                 * buffer is freed in gss_svcsec_free_repbuf() */
                rs->rs_msg = (struct lustre_msg *)
                                (msgbuf + GSS_PRIVBUF_PREFIX_LEN);
        }

        req->rq_repmsg = rs->rs_msg;

        RETURN(0);
}
1601
1602 static
1603 void gss_svcsec_free_repbuf(struct ptlrpc_svcsec *svcsec,
1604                             struct ptlrpc_reply_state *rs)
1605 {
1606         unsigned long p1 = (unsigned long) rs->rs_msg;
1607         unsigned long p2 = (unsigned long) rs->rs_buf;
1608
1609         LASSERT(rs->rs_buf);
1610         LASSERT(rs->rs_msg);
1611         LASSERT(rs->rs_msg_len);
1612
1613         if (p1 < p2 || p1 >= p2 + rs->rs_buf_len) {
1614                 char *start = (char*) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
1615                 int size = rs->rs_msg_len + GSS_PRIVBUF_PREFIX_LEN +
1616                            GSS_PRIVBUF_SUFFIX_LEN;
1617                 OBD_FREE(start, size);
1618         }
1619
1620         svcsec_free_reply_state(rs);
1621 }
1622
/* server-side security policy for the gss wire flavor: accept/verify
 * incoming requests, protect replies, and manage per-request buffers */
struct ptlrpc_svcsec svcsec_gss = {
        .pss_owner              = THIS_MODULE,
        .pss_name               = "svcsec.gss",
        .pss_flavor             = PTLRPCS_FLVR_MAJOR_GSS,
        .accept                 = gss_svcsec_accept,
        .authorize              = gss_svcsec_authorize,
        .alloc_repbuf           = gss_svcsec_alloc_repbuf,
        .free_repbuf            = gss_svcsec_free_repbuf,
        .cleanup_req            = gss_svcsec_cleanup_req,
};
1633
/* XXX hacking */
/* drop every cached entry from both the init (rsi) and established
 * context (rsc) caches */
void lgss_svc_cache_purge_all(void)
{
        cache_purge(&rsi_cache);
        cache_purge(&rsc_cache);
}
EXPORT_SYMBOL(lgss_svc_cache_purge_all);
1641
/* flush established contexts belonging to a single uid from the rsc cache */
void lgss_svc_cache_flush(__u32 uid)
{
        rsc_flush(uid);
}
EXPORT_SYMBOL(lgss_svc_cache_flush);
1647
1648 int gss_svc_init(void)
1649 {
1650         int rc;
1651
1652         rc = svcsec_register(&svcsec_gss);
1653         if (!rc) {
1654                 cache_register(&rsc_cache);
1655                 cache_register(&rsi_cache);
1656         }
1657         return rc;
1658 }
1659
1660 void gss_svc_exit(void)
1661 {
1662         int rc;
1663
1664         /* XXX rsi didn't take module refcount. without really
1665          * cleanup it we can't simply go, later user-space operations
1666          * will certainly cause oops.
1667          * use space might slow or stuck on something, wait it for
1668          * a bit -- bad hack.
1669          */
1670         while ((rc = cache_unregister(&rsi_cache))) {
1671                 CERROR("unregister rsi cache: %d. Try again\n", rc);
1672                 schedule_timeout(2 * HZ);
1673                 cache_purge(&rsi_cache);
1674         }
1675
1676         if ((rc = cache_unregister(&rsc_cache)))
1677                 CERROR("unregister rsc cache: %d\n", rc);
1678         if ((rc = svcsec_unregister(&svcsec_gss)))
1679                 CERROR("unregister svcsec_gss: %d\n", rc);
1680 }