Whamcloud - gitweb
3087d03d29251e4afd2fedaffe4305ae699e567b
[fs/lustre-release.git] / lustre / sec / gss / svcsec_gss.c
1 /* -*- mode: c; c-basic-offset: 8; indent-tabs-mode: nil; -*-
2  * vim:expandtab:shiftwidth=8:tabstop=8:
3  *
4  * Modifications for Lustre
5  * Copyright 2004, Cluster File Systems, Inc.
6  * All rights reserved
7  * Author: Eric Mei <ericm@clusterfs.com>
8  */
9
10 /*
11  * Neil Brown <neilb@cse.unsw.edu.au>
12  * J. Bruce Fields <bfields@umich.edu>
13  * Andy Adamson <andros@umich.edu>
14  * Dug Song <dugsong@monkey.org>
15  *
16  * RPCSEC_GSS server authentication.
17  * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078
18  * (gssapi)
19  *
20  * The RPCSEC_GSS involves three stages:
21  *  1/ context creation
22  *  2/ data exchange
23  *  3/ context destruction
24  *
25  * Context creation is handled largely by upcalls to user-space.
26  *  In particular, GSS_Accept_sec_context is handled by an upcall
27  * Data exchange is handled entirely within the kernel
28  *  In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel.
29  * Context destruction is handled in-kernel
30  *  GSS_Delete_sec_context is in-kernel
31  *
32  * Context creation is initiated by a RPCSEC_GSS_INIT request arriving.
33  * The context handle and gss_token are used as a key into the rpcsec_init cache.
34  * The content of this cache includes some of the outputs of GSS_Accept_sec_context,
35  * being major_status, minor_status, context_handle, reply_token.
36  * These are sent back to the client.
37  * Sequence window management is handled by the kernel.  The window size is
38  * currently a compile time constant.
39  *
40  * When user-space is happy that a context is established, it places an entry
41  * in the rpcsec_context cache. The key for this cache is the context_handle.
42  * The content includes:
43  *   uid/gidlist - for determining access rights
44  *   mechanism type
45  *   mechanism specific information, such as a key
46  *
47  */
48
49 #define DEBUG_SUBSYSTEM S_SEC
50 #ifdef __KERNEL__
51 #include <linux/types.h>
52 #include <linux/init.h>
53 #include <linux/module.h>
54 #include <linux/slab.h>
55 #include <linux/hash.h>
56 #else
57 #include <liblustre.h>
58 #endif
59
60 #include <linux/sunrpc/cache.h>
61
62 #include <libcfs/kp30.h>
63 #include <linux/obd.h>
64 #include <linux/obd_class.h>
65 #include <linux/obd_support.h>
66 #include <linux/lustre_idl.h>
67 #include <linux/lustre_net.h>
68 #include <linux/lustre_import.h>
69 #include <linux/lustre_sec.h>
70                                                                                                                         
71 #include "gss_err.h"
72 #include "gss_internal.h"
73 #include "gss_api.h"
74
75 static inline unsigned long hash_mem(char *buf, int length, int bits)
76 {
77         unsigned long hash = 0;
78         unsigned long l = 0;
79         int len = 0;
80         unsigned char c;
81         do {
82                 if (len == length) {
83                         c = (char)len; len = -1;
84                 } else
85                         c = *buf++;
86                 l = (l << 8) | c;
87                 len++;
88                 if ((len & (BITS_PER_LONG/8-1))==0)
89                         hash = hash_long(hash^l, BITS_PER_LONG);
90         } while (len);
91         return hash >> (BITS_PER_LONG - bits);
92 }
93
94 /* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests
95  * into replies.
96  *
97  * Key is context handle (\x if empty) and gss_token.
98  * Content is major_status minor_status (integers) context_handle, reply_token.
99  *
100  */
101
102 #define RSI_HASHBITS    6
103 #define RSI_HASHMAX     (1<<RSI_HASHBITS)
104 #define RSI_HASHMASK    (RSI_HASHMAX-1)
105
/* An rsi entry represents one RPCSEC_GSS context-init request awaiting
 * (or carrying) a reply from the svcgssd userspace daemon.
 * Cache key: (in_handle, in_token).  The out_handle/out_token and
 * major/minor status are filled in by rsi_parse() from the daemon. */
106 struct rsi {
107         struct cache_head       h;
108         __u32                   lustre_svc;
109         __u32                   naltype;
110         __u32                   netid;
111         __u64                   nid;
112         rawobj_t                in_handle, in_token, in_srv_type;
113         rawobj_t                out_handle, out_token;
114         int                     major_status, minor_status;
115 };
116
/* hash table and cache descriptor for the rpcsec_init cache (the
 * cache_detail itself is defined after the methods below) */
117 static struct cache_head *rsi_table[RSI_HASHMAX];
118 static struct cache_detail rsi_cache;
119
120 static void rsi_free(struct rsi *rsii)
121 {
122         rawobj_free(&rsii->in_handle);
123         rawobj_free(&rsii->in_token);
124         rawobj_free(&rsii->out_handle);
125         rawobj_free(&rsii->out_token);
126 }
127
128 static void rsi_put(struct cache_head *item, struct cache_detail *cd)
129 {
130         struct rsi *rsii = container_of(item, struct rsi, h);
131         LASSERT(atomic_read(&item->refcnt) > 0);
132         if (cache_put(item, cd)) {
133                 LASSERT(item->next == NULL);
134                 rsi_free(rsii);
135                 OBD_FREE(rsii, sizeof(*rsii));
136         }
137 }
138
139 static inline int rsi_hash(struct rsi *item)
140 {
141         return hash_mem((char *)item->in_handle.data, item->in_handle.len, RSI_HASHBITS)
142                 ^ hash_mem((char *)item->in_token.data, item->in_token.len, RSI_HASHBITS);
143 }
144
145 static inline int rsi_match(struct rsi *item, struct rsi *tmp)
146 {
147         return (rawobj_equal(&item->in_handle, &tmp->in_handle) &&
148                 rawobj_equal(&item->in_token, &tmp->in_token));
149 }
150
/* Format the upcall request line read by svcgssd: hex-encoded
 * lustre_svc, naltype, netid, nid, in_handle and in_token, terminated
 * by a newline.  Field order must match what the daemon parses. */
151 static void rsi_request(struct cache_detail *cd,
152                         struct cache_head *h,
153                         char **bpp, int *blen)
154 {
155         struct rsi *rsii = container_of(h, struct rsi, h);
156
157         qword_addhex(bpp, blen, (char *) &rsii->lustre_svc,
158                      sizeof(rsii->lustre_svc));
159         qword_addhex(bpp, blen, (char *) &rsii->naltype, sizeof(rsii->naltype));
160         qword_addhex(bpp, blen, (char *) &rsii->netid, sizeof(rsii->netid));
161         qword_addhex(bpp, blen, (char *) &rsii->nid, sizeof(rsii->nid));
162         qword_addhex(bpp, blen, rsii->in_handle.data, rsii->in_handle.len);
163         qword_addhex(bpp, blen, rsii->in_token.data, rsii->in_token.len);
        /* overwrite the trailing separator written by qword_addhex
         * with the terminating newline */
164         (*bpp)[-1] = '\n';
165 }
166
/* Called from rsi_parse() when svcgssd delivers a context-init reply.
 * Find the placeholder entry (inserted by gssd_upcall()) that matches
 * @item's key, replace it in the hash chain with @item (marked VALID),
 * and drop the placeholder.  Returns 0 on success, -EINVAL if no
 * placeholder is found or it is unexpectedly already VALID. */
167 static int
168 gssd_reply(struct rsi *item)
169 {
170         struct rsi *tmp;
171         struct cache_head **hp, **head;
172         ENTRY;
173
174         head = &rsi_cache.hash_table[rsi_hash(item)];
175         write_lock(&rsi_cache.hash_lock);
176         for (hp = head; *hp != NULL; hp = &tmp->h.next) {
177                 tmp = container_of(*hp, struct rsi, h);
178                 if (rsi_match(tmp, item)) {
                        /* unhash the placeholder, holding a temporary ref */
179                         cache_get(&tmp->h);
180                         clear_bit(CACHE_HASHED, &tmp->h.flags);
181                         *hp = tmp->h.next;
182                         tmp->h.next = NULL;
183                         rsi_cache.entries--;
184                         if (test_bit(CACHE_VALID, &tmp->h.flags)) {
                                /* a placeholder should never be VALID */
185                                 CERROR("rsi is valid\n");
186                                 write_unlock(&rsi_cache.hash_lock);
187                                 rsi_put(&tmp->h, &rsi_cache);
188                                 RETURN(-EINVAL);
189                         }
                        /* splice @item in where the placeholder was;
                         * *hp now points at the old successor */
190                         set_bit(CACHE_HASHED, &item->h.flags);
191                         item->h.next = *hp;
192                         *hp = &item->h;
193                         rsi_cache.entries++;
194                         set_bit(CACHE_VALID, &item->h.flags);
195                         item->h.last_refresh = get_seconds();
196                         write_unlock(&rsi_cache.hash_lock);
                        /* notify waiters on the old entry, then drop
                         * our temporary ref */
197                         cache_fresh(&rsi_cache, &tmp->h, 0);
198                         rsi_put(&tmp->h, &rsi_cache);
199                         RETURN(0);
200                 }
201         }
202         write_unlock(&rsi_cache.hash_lock);
203         RETURN(-EINVAL);
204 }
205
206 /* XXX
207  * Here we simply wait for the upcall's completion or a timeout.  It is
208  * a hack, but it works; we will come up with a real fix if we decide
209  * to stick with the NFS4 cache code.
210  */
211 static struct rsi *
212 gssd_upcall(struct rsi *item, struct cache_req *chandle)
213 {
214         struct rsi *tmp;
215         struct cache_head **hp, **head;
216         unsigned long starttime;
217         ENTRY;
218
219         head = &rsi_cache.hash_table[rsi_hash(item)];
220         read_lock(&rsi_cache.hash_lock);
221         for (hp = head; *hp != NULL; hp = &tmp->h.next) {
222                 tmp = container_of(*hp, struct rsi, h);
223                 if (rsi_match(tmp, item)) {
224                         LBUG();
225                         if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
226                                 CERROR("found rsi without VALID\n");
227                                 read_unlock(&rsi_cache.hash_lock);
228                                 return NULL;
229                         }
230                         *hp = tmp->h.next;
231                         tmp->h.next = NULL;
232                         rsi_cache.entries--;
233                         cache_get(&tmp->h);
234                         read_unlock(&rsi_cache.hash_lock);
235                         return tmp;
236                 }
237         }
238         cache_get(&item->h);
239         set_bit(CACHE_HASHED, &item->h.flags);
240         item->h.next = *head;
241         *head = &item->h;
242         rsi_cache.entries++;
243         read_unlock(&rsi_cache.hash_lock);
244         //cache_get(&item->h);
245
246         cache_check(&rsi_cache, &item->h, chandle);
247         starttime = get_seconds();
248         do {
249                 set_current_state(TASK_UNINTERRUPTIBLE);
250                 schedule_timeout(HZ/2);
251                 read_lock(&rsi_cache.hash_lock);
252                 for (hp = head; *hp != NULL; hp = &tmp->h.next) {
253                         tmp = container_of(*hp, struct rsi, h);
254                         if (tmp == item)
255                                 continue;
256                         if (rsi_match(tmp, item)) {
257                                 if (!test_bit(CACHE_VALID, &tmp->h.flags)) {
258                                         read_unlock(&rsi_cache.hash_lock);
259                                         return NULL;
260                                 }
261                                 cache_get(&tmp->h);
262                                 clear_bit(CACHE_HASHED, &tmp->h.flags);
263                                 *hp = tmp->h.next;
264                                 tmp->h.next = NULL;
265                                 rsi_cache.entries--;
266                                 read_unlock(&rsi_cache.hash_lock);
267                                 return tmp;
268                         }
269                 }
270                 read_unlock(&rsi_cache.hash_lock);
271         } while ((get_seconds() - starttime) <= 15);
272         CERROR("15s timeout while waiting cache refill\n");
273         return NULL;
274 }
275
/* Parse one reply line written by svcgssd on the rpcsec_init channel.
 * Expected fields: in_handle in_token expiry major minor out_handle
 * out_token.  On success the entry is handed to gssd_reply(), which
 * splices it into the cache in place of the waiting placeholder.
 * Returns 0 or a negative errno. */
276 static int rsi_parse(struct cache_detail *cd,
277                      char *mesg, int mlen)
278 {
279         /* context token expiry major minor context token */
280         char *buf = mesg;
281         char *ep;
282         int len;
283         struct rsi *rsii;
284         time_t expiry;
285         int status = -EINVAL;
286         ENTRY;
287
288         OBD_ALLOC(rsii, sizeof(*rsii));
289         if (!rsii)
290                 RETURN(-ENOMEM);
291         cache_init(&rsii->h);
292
293         /* handle */
294         len = qword_get(&mesg, buf, mlen);
295         if (len < 0)
296                 goto out;
297         if (rawobj_alloc(&rsii->in_handle, buf, len)) {
298                 status = -ENOMEM;
299                 goto out;
300         }
301
302         /* token */
303         len = qword_get(&mesg, buf, mlen);
304         if (len < 0)
305                 goto out;
306         if (rawobj_alloc(&rsii->in_token, buf, len)) {
307                 status = -ENOMEM;
308                 goto out;
309         }
310
311         /* expiry: 0 means get_expiry() failed to parse a time value */
312         expiry = get_expiry(&mesg);
313         if (expiry == 0)
314                 goto out;
315
316         /* major */
317         len = qword_get(&mesg, buf, mlen);
318         if (len <= 0)
319                 goto out;
320         rsii->major_status = simple_strtol(buf, &ep, 10);
321         if (*ep)
322                 goto out;
323
324         /* minor */
325         len = qword_get(&mesg, buf, mlen);
326         if (len <= 0)
327                 goto out;
328         rsii->minor_status = simple_strtol(buf, &ep, 10);
329         if (*ep)
330                 goto out;
331
332         /* out_handle */
333         len = qword_get(&mesg, buf, mlen);
334         if (len < 0)
335                 goto out;
336         if (rawobj_alloc(&rsii->out_handle, buf, len)) {
337                 status = -ENOMEM;
338                 goto out;
339         }
340
341         /* out_token */
342         len = qword_get(&mesg, buf, mlen);
343         if (len < 0)
344                 goto out;
345         if (rawobj_alloc(&rsii->out_token, buf, len)) {
346                 status = -ENOMEM;
347                 goto out;
348         }
349
350         rsii->h.expiry_time = expiry;
351         status = gssd_reply(rsii);
352 out:
        /* NOTE(review): gssd_reply() hashes rsii without taking an extra
         * ref before this put; presumably cache_put() keeps hashed
         * entries alive via CACHE_HASHED -- confirm against cache.c */
353         if (rsii)
354                 rsi_put(&rsii->h, &rsi_cache);
355         RETURN(status);
356 }
357
/* Cache descriptor for the rpcsec_init upcall channel: svcgssd reads
 * requests formatted by rsi_request and writes replies parsed by
 * rsi_parse. */
358 static struct cache_detail rsi_cache = {
359         .hash_size      = RSI_HASHMAX,
360         .hash_table     = rsi_table,
361         .name           = "auth.ptlrpcs.init",
362         .cache_put      = rsi_put,
363         .cache_request  = rsi_request,
364         .cache_parse    = rsi_parse,
365 };
366
367 /*
368  * The rpcsec_context cache is used to store a context that is
369  * used in data exchange.
370  * The key is a context handle. The content is:
371  *  uid, gidlist, mechanism, service-set, mech-specific-data
372  */
373
374 #define RSC_HASHBITS    10
375 #define RSC_HASHMAX     (1<<RSC_HASHBITS)
376 #define RSC_HASHMASK    (RSC_HASHMAX-1)
377
/* replay-detection window size, in sequence numbers (see RFC 2203) */
378 #define GSS_SEQ_WIN     512
379
380 struct gss_svc_seq_data {
381         /* highest seq number seen so far: */
382         __u32                   sd_max;
383         /* for i such that sd_max-GSS_SEQ_WIN < i <= sd_max, the i-th bit of
384          * sd_win is nonzero iff sequence number i has been seen already: */
385         unsigned long           sd_win[GSS_SEQ_WIN/BITS_PER_LONG];
386         spinlock_t              sd_lock;
387 };
388
/* An established server-side GSS context (rsc cache entry), keyed by
 * the wire context handle; contents are pushed down by svcgssd and
 * parsed in rsc_parse(). */
389 struct rsc {
390         struct cache_head       h;
391         rawobj_t                handle;
392         __u32                   remote_realm:1,
393                                 auth_usr_mds:1,
394                                 auth_usr_oss:1;
395         struct vfs_cred         cred;
396         uid_t                   mapped_uid;
397         struct gss_svc_seq_data seqdata;
398         struct gss_ctx         *mechctx;
399 };
400
/* hash table and cache descriptor for the rpcsec_context cache */
401 static struct cache_head *rsc_table[RSC_HASHMAX];
402 static struct cache_detail rsc_cache;
403
/* Release resources owned by an rsc entry: the handle buffer and the
 * imported mech context.  The entry itself is freed by rsc_put(). */
404 static void rsc_free(struct rsc *rsci)
405 {
406         rawobj_free(&rsci->handle);
407         if (rsci->mechctx)
408                 kgss_delete_sec_context(&rsci->mechctx);
409 #if 0
410         if (rsci->cred.vc_ginfo)
411                 put_group_info(rsci->cred.vc_ginfo);
412 #endif
413 }
414
415 static void rsc_put(struct cache_head *item, struct cache_detail *cd)
416 {
417         struct rsc *rsci = container_of(item, struct rsc, h);
418
419         LASSERT(atomic_read(&item->refcnt) > 0);
420         if (cache_put(item, cd)) {
421                 LASSERT(item->next == NULL);
422                 rsc_free(rsci);
423                 OBD_FREE(rsci, sizeof(*rsci));
424         }
425 }
426
427 static inline int
428 rsc_hash(struct rsc *rsci)
429 {
430         return hash_mem((char *)rsci->handle.data,
431                         rsci->handle.len, RSC_HASHBITS);
432 }
433
434 static inline int
435 rsc_match(struct rsc *new, struct rsc *tmp)
436 {
437         return rawobj_equal(&new->handle, &tmp->handle);
438 }
439
/* Look up an rsc entry with the same handle as @item.
 * @set == 0: search only (read lock); return a referenced match or NULL.
 * @set != 0: insert @item (write lock), displacing any existing entry
 * with the same handle, and return @item with an extra reference.
 * NOTE(review): in the @set path cache_fresh() runs after the write
 * lock is dropped -- presumably safe because the caller still holds
 * its own ref on @item; confirm before relying on it. */
440 static struct rsc *rsc_lookup(struct rsc *item, int set)
441 {
442         struct rsc *tmp = NULL;
443         struct cache_head **hp, **head;
444         head = &rsc_cache.hash_table[rsc_hash(item)];
445         ENTRY;
446
447         if (set)
448                 write_lock(&rsc_cache.hash_lock);
449         else
450                 read_lock(&rsc_cache.hash_lock);
451         for (hp = head; *hp != NULL; hp = &tmp->h.next) {
452                 tmp = container_of(*hp, struct rsc, h);
453                 if (!rsc_match(tmp, item))
454                         continue;
455                 cache_get(&tmp->h);
456                 if (!set)
457                         goto out_noset;
                /* replacing: unhash the old entry and drop our ref on it;
                 * entry count is unchanged (one removed, one inserted) */
458                 *hp = tmp->h.next;
459                 tmp->h.next = NULL;
460                 clear_bit(CACHE_HASHED, &tmp->h.flags);
461                 rsc_put(&tmp->h, &rsc_cache);
462                 goto out_set;
463         }
464         /* Didn't find anything */
465         if (!set)
466                 goto out_nada;
467         rsc_cache.entries++;
468 out_set:
469         set_bit(CACHE_HASHED, &item->h.flags);
470         item->h.next = *head;
471         *head = &item->h;
472         write_unlock(&rsc_cache.hash_lock);
473         cache_fresh(&rsc_cache, &item->h, item->h.expiry_time);
474         cache_get(&item->h);
475         RETURN(item);
476 out_nada:
477         tmp = NULL;
478 out_noset:
479         read_unlock(&rsc_cache.hash_lock);
480         RETURN(tmp);
481 }
482
/* Parse one established-context line written by svcgssd:
 *   handle expiry remote_flag mds_user_flag oss_user_flag mapped_uid
 *   [uid gid mechname mechdata]
 * A uid read of -ENOENT marks the entry NEGATIVE; otherwise the mech
 * context is imported in-kernel and the real expiry is taken from the
 * mech, since the value passed down from user-space is invalid.
 * On success the entry is inserted via rsc_lookup(..., 1).
 * Returns 0 or a negative errno. */
483 static int rsc_parse(struct cache_detail *cd,
484                      char *mesg, int mlen)
485 {
486         /* contexthandle expiry [ uid gid N <n gids> mechname
487          * ...mechdata... ] */
488         char *buf = mesg;
489         int len, rv, tmp_int;
490         struct rsc *rsci, *res = NULL;
491         time_t expiry;
492         int status = -EINVAL;
493
494         OBD_ALLOC(rsci, sizeof(*rsci));
495         if (!rsci) {
496                 CERROR("fail to alloc rsci\n");
497                 return -ENOMEM;
498         }
499         cache_init(&rsci->h);
500
501         /* context handle */
502         len = qword_get(&mesg, buf, mlen);
503         if (len < 0) goto out;
504         status = -ENOMEM;
505         if (rawobj_alloc(&rsci->handle, buf, len))
506                 goto out;
507
508         /* expiry: 0 means get_expiry() failed to parse a time value */
509         expiry = get_expiry(&mesg);
510         status = -EINVAL;
511         if (expiry == 0)
512                 goto out;
513
514         /* remote flag */
515         rv = get_int(&mesg, &tmp_int);
516         if (rv) {
517                 CERROR("fail to get remote flag\n");
518                 goto out;
519         }
520         rsci->remote_realm = (tmp_int != 0);
521
522         /* mds user flag */
523         rv = get_int(&mesg, &tmp_int);
524         if (rv) {
525                 CERROR("fail to get mds user flag\n");
526                 goto out;
527         }
528         rsci->auth_usr_mds = (tmp_int != 0);
529
530         /* oss user flag */
531         rv = get_int(&mesg, &tmp_int);
532         if (rv) {
533                 CERROR("fail to get oss user flag\n");
534                 goto out;
535         }
536         rsci->auth_usr_oss = (tmp_int != 0);
537
538         /* mapped uid */
539         rv = get_int(&mesg, (int *)&rsci->mapped_uid);
540         if (rv) {
541                 CERROR("fail to get mapped uid\n");
542                 goto out;
543         }
544
545         /* uid, or NEGATIVE */
546         rv = get_int(&mesg, (int *)&rsci->cred.vc_uid);
547         if (rv == -EINVAL)
548                 goto out;
549         if (rv == -ENOENT) {
550                 CERROR("NOENT? set rsc entry negative\n");
551                 set_bit(CACHE_NEGATIVE, &rsci->h.flags);
552         } else {
553                 struct gss_api_mech *gm;
554                 rawobj_t tmp_buf;
555                 __u64 ctx_expiry;
556
557                 /* gid */
558                 if (get_int(&mesg, (int *)&rsci->cred.vc_gid))
559                         goto out;
560
561                 /* mech name */
562                 len = qword_get(&mesg, buf, mlen);
563                 if (len < 0)
564                         goto out;
565                 gm = kgss_name_to_mech(buf);
566                 status = -EOPNOTSUPP;
567                 if (!gm)
568                         goto out;
569
570                 status = -EINVAL;
571                 /* mech-specific data: */
572                 len = qword_get(&mesg, buf, mlen);
573                 if (len < 0) {
574                         kgss_mech_put(gm);
575                         goto out;
576                 }
577                 tmp_buf.len = len;
578                 tmp_buf.data = (unsigned char *)buf;
579                 if (kgss_import_sec_context(&tmp_buf, gm, &rsci->mechctx)) {
580                         kgss_mech_put(gm);
581                         goto out;
582                 }
583
584                 /* currently the expiry time passed down from user-space
585                  * is invalid, here we retrieve it from mech.
586                  */
587                 if (kgss_inquire_context(rsci->mechctx, &ctx_expiry)) {
588                         CERROR("unable to get expire time, drop it\n");
589                         set_bit(CACHE_NEGATIVE, &rsci->h.flags);
590                         kgss_mech_put(gm);
591                         goto out;
592                 }
593                 expiry = (time_t) ((__u32) ctx_expiry);
594
595                 kgss_mech_put(gm);
596         }
597         rsci->h.expiry_time = expiry;
598         spin_lock_init(&rsci->seqdata.sd_lock);
        /* insert; res is rsci itself unless an entry with the same
         * handle was displaced (rsc_lookup with set always returns
         * the inserted item) */
599         res = rsc_lookup(rsci, 1);
600         // XXX temp debugging
601         {
602                 if (res == rsci) {
603                         CWARN("create ctxt %p(%u), expiry %ld (%lds later)\n",
604                               res, *((__u32 *) res->handle.data),
605                               res->h.expiry_time,
606                               res->h.expiry_time - get_seconds());
607                 } else {
608                         CWARN("create ctxts [%p(%u), ex %ld (%lds later)], "
609                               "[%p(%u), ex %ld (%lds later)]\n",
610                               rsci, *((__u32 *) rsci->handle.data),
611                               rsci->h.expiry_time,
612                               rsci->h.expiry_time - get_seconds(),
613                               res, *((__u32 *) res->handle.data),
614                               res->h.expiry_time,
615                               res->h.expiry_time - get_seconds());
616                 }
617         }
        /* drop the extra ref taken by rsc_lookup(..., 1) */
618         rsc_put(&res->h, &rsc_cache);
619         status = 0;
620 out:
621         if (rsci)
622                 rsc_put(&rsci->h, &rsc_cache);
623         return status;
624 }
625
626 /*
627  * Flush all entries belonging to @uid; @uid == -1 matches everything.
628  * Entries are hashed by context handle, not by uid (in the future we
629  * may also match on netid/nid), so we must scan the whole cache.
630  */
/* Unhash and mark NEGATIVE every context whose uid matches @uid
 * (@uid == -1 matches all).  The whole table is scanned under the
 * write lock; see the comment above for why. */
631 static void rsc_flush(uid_t uid)
632 {
633         struct cache_head **ch;
634         struct rsc *rscp;
635         int n;
636         ENTRY;
637
638         if (uid == -1)
639                 CWARN("flush all gss contexts\n");
640
641         write_lock(&rsc_cache.hash_lock);
642         for (n = 0; n < RSC_HASHMAX; n++) {
643                 for (ch = &rsc_cache.hash_table[n]; *ch;) {
644                         rscp = container_of(*ch, struct rsc, h);
645                         if (uid == -1 || rscp->cred.vc_uid == uid) {
646                                 /* it seems simply set NEGATIVE doesn't work */
647                                 *ch = (*ch)->next;
648                                 rscp->h.next = NULL;
649                                 cache_get(&rscp->h);
650                                 set_bit(CACHE_NEGATIVE, &rscp->h.flags);
651                                 clear_bit(CACHE_HASHED, &rscp->h.flags);
652                                 if (uid != -1)
653                                         CWARN("flush rsc %p(%u) for uid %u\n",
654                                               rscp,
655                                               *((__u32 *) rscp->handle.data),
656                                               rscp->cred.vc_uid);
                                /* drop the temporary ref; the cache's own
                                 * ref goes away with the unhash */
657                                 rsc_put(&rscp->h, &rsc_cache);
658                                 rsc_cache.entries--;
                                /* *ch already points at the next entry */
659                                 continue;
660                         }
661                         ch = &((*ch)->next);
662                 }
663         }
664         write_unlock(&rsc_cache.hash_lock);
665         EXIT;
666 }
667
/* Cache descriptor for the rpcsec_context cache.  No .cache_request:
 * entries are pushed down by svcgssd via rsc_parse, never upcalled
 * for from the kernel. */
668 static struct cache_detail rsc_cache = {
669         .hash_size      = RSC_HASHMAX,
670         .hash_table     = rsc_table,
671         .name           = "auth.ptlrpcs.context",
672         .cache_put      = rsc_put,
673         .cache_parse    = rsc_parse,
674 };
675
/* Look up an established context by its wire handle.  Returns a
 * referenced rsc entry that passed cache_check() (valid, not expired),
 * or NULL.
 * NOTE(review): when cache_check() fails, the reference taken by
 * rsc_lookup() is not visibly dropped here -- presumably cache_check()
 * releases it on failure; confirm against the cache implementation. */
676 static struct rsc *
677 gss_svc_searchbyctx(rawobj_t *handle)
678 {
        /* probe entry: only .handle is read by rsc_hash/rsc_match */
679         struct rsc rsci;
680         struct rsc *found;
681
682         rsci.handle = *handle;
683         found = rsc_lookup(&rsci, 0);
684         if (!found)
685                 return NULL;
686
687         if (cache_check(&rsc_cache, &found->h, NULL))
688                 return NULL;
689
690         return found;
691 }
692
693 /* FIXME
694  * again hacking: only try to give the svcgssd a chance to handle
695  * upcalls.
696  */
697 struct cache_deferred_req* my_defer(struct cache_req *req)
698 {
699         yield();
700         return NULL;
701 }
702 static struct cache_req my_chandle = {my_defer};
703
704 /* Implements sequence number algorithm as specified in RFC 2203. */
/* Returns 0 if @seq_num is fresh (and records it in the window),
 * 1 if it falls below the window or is a replay.  All window state
 * is protected by sd->sd_lock. */
705 static int
706 gss_check_seq_num(struct gss_svc_seq_data *sd, __u32 seq_num)
707 {
708         int rc = 0;
709
710         spin_lock(&sd->sd_lock);
711         if (seq_num > sd->sd_max) {
712                 if (seq_num >= sd->sd_max + GSS_SEQ_WIN) {
                        /* jumped past the whole window: reset it */
713                         memset(sd->sd_win, 0, sizeof(sd->sd_win));
714                         sd->sd_max = seq_num;
715                 } else {
                        /* slide the window forward, clearing vacated bits */
716                         while(sd->sd_max < seq_num) {
717                                 sd->sd_max++;
718                                 __clear_bit(sd->sd_max % GSS_SEQ_WIN,
719                                             sd->sd_win);
720                         }
721                 }
722                 __set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win);
723                 goto exit;
724         } else if (seq_num + GSS_SEQ_WIN <= sd->sd_max) {
725                 CERROR("seq %u too low: max %u, win %d\n",
726                         seq_num, sd->sd_max, GSS_SEQ_WIN);
727                 rc = 1;
728                 goto exit;
729         }
730
        /* inside the window: a bit already set means a replay */
731         if (__test_and_set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win)) {
732                 CERROR("seq %u is replay: max %u, win %d\n",
733                         seq_num, sd->sd_max, GSS_SEQ_WIN);
734                 rc = 1;
735         }
736 exit:
737         spin_unlock(&sd->sd_lock);
738         return rc;
739 }
740
741 static int
742 gss_svc_verify_request(struct ptlrpc_request *req,
743                        struct rsc *rsci,
744                        struct rpc_gss_wire_cred *gc,
745                        __u32 *vp, __u32 vlen)
746 {
747         struct ptlrpcs_wire_hdr *sec_hdr;
748         struct gss_ctx *ctx = rsci->mechctx;
749         __u32 maj_stat;
750         rawobj_t msg;
751         rawobj_t mic;
752         ENTRY;
753
754         sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;
755
756         req->rq_reqmsg = (struct lustre_msg *) (req->rq_reqbuf + sizeof(*sec_hdr));
757         req->rq_reqlen = sec_hdr->msg_len;
758
759         msg.len = sec_hdr->msg_len;
760         msg.data = (__u8 *)req->rq_reqmsg;
761
762         mic.len = le32_to_cpu(*vp++);
763         mic.data = (unsigned char *)vp;
764         vlen -= 4;
765
766         if (mic.len > vlen) {
767                 CERROR("checksum len %d, while buffer len %d\n",
768                         mic.len, vlen);
769                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
770         }
771
772         if (mic.len > 256) {
773                 CERROR("invalid mic len %d\n", mic.len);
774                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
775         }
776
777         maj_stat = kgss_verify_mic(ctx, &msg, &mic, NULL);
778         if (maj_stat != GSS_S_COMPLETE) {
779                 CERROR("MIC verification error: major %x\n", maj_stat);
780                 RETURN(maj_stat);
781         }
782
783         if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
784                 CERROR("discard replayed request %p(o%u,x"LPU64",t"LPU64")\n",
785                         req, req->rq_reqmsg->opc, req->rq_xid,
786                         req->rq_reqmsg->transno);
787                 RETURN(GSS_S_DUPLICATE_TOKEN);
788         }
789
790         RETURN(GSS_S_COMPLETE);
791 }
792
/* Decrypt (unwrap) a privacy-protected (rpc_gss_svc_privacy) request
 * and reject replays via the per-context sequence window.
 * @vp/@vlen: a little-endian 32-bit ciphertext length followed by the
 * ciphertext.  On success req->rq_reqmsg/rq_reqlen point at the
 * recovered plaintext (unwrapped in place over the ciphertext buffer).
 * Returns a GSS major status; GSS_S_COMPLETE on success. */
793 static int
794 gss_svc_unseal_request(struct ptlrpc_request *req,
795                        struct rsc *rsci,
796                        struct rpc_gss_wire_cred *gc,
797                        __u32 *vp, __u32 vlen)
798 {
799         struct ptlrpcs_wire_hdr *sec_hdr;
800         struct gss_ctx *ctx = rsci->mechctx;
801         rawobj_t cipher_text, plain_text;
802         __u32 major;
803         ENTRY;
804
805         sec_hdr = (struct ptlrpcs_wire_hdr *) req->rq_reqbuf;
806
807         if (vlen < 4) {
808                 CERROR("vlen only %u\n", vlen);
809                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
810         }
811
812         cipher_text.len = le32_to_cpu(*vp++);
813         cipher_text.data = (__u8 *) vp;
814         vlen -= 4;
815         
816         if (cipher_text.len > vlen) {
817                 CERROR("cipher claimed %u while buf only %u\n",
818                         cipher_text.len, vlen);
819                 RETURN(GSS_S_CALL_BAD_STRUCTURE);
820         }
821
        /* unwrap in place: plaintext shares the ciphertext buffer */
822         plain_text = cipher_text;
823
824         major = kgss_unwrap(ctx, GSS_C_QOP_DEFAULT, &cipher_text, &plain_text);
825         if (major) {
826                 CERROR("unwrap error 0x%x\n", major);
827                 RETURN(major);
828         }
829
        /* replay detection per RFC 2203 */
830         if (gss_check_seq_num(&rsci->seqdata, gc->gc_seq)) {
831                 CERROR("discard replayed request %p(o%u,x"LPU64",t"LPU64")\n",
832                         req, req->rq_reqmsg->opc, req->rq_xid,
833                         req->rq_reqmsg->transno);
834                 RETURN(GSS_S_DUPLICATE_TOKEN);
835         }
836
837         req->rq_reqmsg = (struct lustre_msg *) (vp);
838         req->rq_reqlen = plain_text.len;
839
840         CDEBUG(D_SEC, "msg len %d\n", req->rq_reqlen);
841
842         RETURN(GSS_S_COMPLETE);
843 }
844
/* Build an error-notify reply that carries a GSS major/minor status
 * back to the client (e.g. when its context is invalid), following the
 * RPCSEC_GSS error path.  The reply is packed as a plain (GSS_NONE)
 * wire header plus a 7-word gss error record appended after the lustre
 * message.  Returns 0 or a negative errno. */
845 static int
846 gss_pack_err_notify(struct ptlrpc_request *req,
847                     __u32 major, __u32 minor)
848 {
849         struct gss_svc_data *svcdata = req->rq_svcsec_data;
850         __u32 reslen, *resp, *reslenp;
851         char  nidstr[PTL_NALFMT_SIZE];
852         const __u32 secdata_len = 7 * 4;
853         int rc;
854         ENTRY;
855
856         OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_ERR_NOTIFY|OBD_FAIL_ONCE, -EINVAL);
857
858         LASSERT(svcdata);
859         svcdata->is_err_notify = 1;
860         svcdata->reserve_len = 7 * 4;
861
862         rc = lustre_pack_reply(req, 0, NULL, NULL);
863         if (rc) {
864                 CERROR("could not pack reply, err %d\n", rc);
865                 RETURN(rc);
866         }
867
868         LASSERT(req->rq_reply_state);
869         LASSERT(req->rq_reply_state->rs_repbuf);
870         LASSERT(req->rq_reply_state->rs_repbuf_len >= secdata_len);
871         resp = (__u32 *) req->rq_reply_state->rs_repbuf;
872
873         /* header */
874         *resp++ = cpu_to_le32(PTLRPCS_FLVR_GSS_NONE);
875         *resp++ = cpu_to_le32(PTLRPCS_SVC_NONE);
876         *resp++ = cpu_to_le32(req->rq_replen);
        /* remember where the sec data length goes; fill it in below */
877         reslenp = resp++;
878
879         /* skip lustre msg */
880         resp += req->rq_replen / 4;
881         reslen = svcdata->reserve_len;
882
883         /* gss replay:
884          * version, subflavor, notify, major, minor,
885          * obj1(fake), obj2(fake)
886          */
887         *resp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
888         *resp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);
889         *resp++ = cpu_to_le32(PTLRPCS_GSS_PROC_ERR);
890         *resp++ = cpu_to_le32(major);
891         *resp++ = cpu_to_le32(minor);
892         *resp++ = 0;
893         *resp++ = 0;
        /* NOTE(review): 7 words were written but only 4*4 bytes are
         * subtracted; reslen is not used after this point, so it seems
         * harmless -- confirm before relying on reslen */
894         reslen -= (4 * 4);
895         /* the actual sec data length */
896         *reslenp = cpu_to_le32(secdata_len);
897
898         req->rq_reply_state->rs_repdata_len += (secdata_len);
899         CDEBUG(D_SEC, "prepare gss error notify(0x%x/0x%x) to %s\n",
900                major, minor,
901                portals_nid2str(req->rq_peer.peer_ni->pni_number,
902                                req->rq_peer.peer_id.nid, nidstr));
903         RETURN(0);
904 }
905
906 static void dump_cache_head(struct cache_head *h)
907 {
908         CWARN("ref %d, fl %lx, n %p, t %ld, %ld\n",
909               atomic_read(&h->refcnt), h->flags, h->next,
910               h->expiry_time, h->last_refresh);
911 }
912 static void dump_rsi(struct rsi *rsi)
913 {
914         CWARN("dump rsi %p\n", rsi);
915         dump_cache_head(&rsi->h);
916         CWARN("%x,%x,%llx\n", rsi->naltype, rsi->netid, rsi->nid);
917         CWARN("len %d, d %p\n", rsi->in_handle.len, rsi->in_handle.data);
918         CWARN("len %d, d %p\n", rsi->in_token.len, rsi->in_token.data);
919         CWARN("len %d, d %p\n", rsi->out_handle.len, rsi->out_handle.data);
920         CWARN("len %d, d %p\n", rsi->out_token.len, rsi->out_token.data);
921 }
922
923 static int
924 gss_svcsec_handle_init(struct ptlrpc_request *req,
925                        struct rpc_gss_wire_cred *gc,
926                        __u32 *secdata, __u32 seclen,
927                        enum ptlrpcs_error *res)
928 {
929         struct gss_svc_data *svcdata = req->rq_svcsec_data;
930         struct rsc          *rsci;
931         struct rsi          *rsikey, *rsip;
932         rawobj_t             tmpobj;
933         __u32 reslen,       *resp, *reslenp;
934         char                 nidstr[PTL_NALFMT_SIZE];
935         int                  rc;
936         ENTRY;
937
938         LASSERT(svcdata);
939
940         CDEBUG(D_SEC, "processing gss init(%d) request from %s\n", gc->gc_proc,
941                portals_nid2str(req->rq_peer.peer_ni->pni_number,
942                                req->rq_peer.peer_id.nid, nidstr));
943
944         *res = PTLRPCS_BADCRED;
945         OBD_FAIL_RETURN(OBD_FAIL_SVCGSS_INIT_REQ|OBD_FAIL_ONCE, SVC_DROP);
946
947         if (gc->gc_proc == RPC_GSS_PROC_INIT &&
948             gc->gc_ctx.len != 0) {
949                 CERROR("proc %d, ctx_len %d: not really init?\n",
950                 gc->gc_proc == RPC_GSS_PROC_INIT, gc->gc_ctx.len);
951                 RETURN(SVC_DROP);
952         }
953
954         OBD_ALLOC(rsikey, sizeof(*rsikey));
955         if (!rsikey) {
956                 CERROR("out of memory\n");
957                 RETURN(SVC_DROP);
958         }
959         cache_init(&rsikey->h);
960
961         /* obtain lustre svc type */
962         if (seclen < 4) {
963                 CERROR("sec size %d too small\n", seclen);
964                 GOTO(out_rsikey, rc = SVC_DROP);
965         }
966         rsikey->lustre_svc = le32_to_cpu(*secdata++);
967         seclen -= 4;
968
969         /* duplicate context handle. currently always 0 */
970         if (rawobj_dup(&rsikey->in_handle, &gc->gc_ctx)) {
971                 CERROR("fail to dup context handle\n");
972                 GOTO(out_rsikey, rc = SVC_DROP);
973         }
974
975         /* extract token */
976         *res = PTLRPCS_BADVERF;
977         if (rawobj_extract(&tmpobj, &secdata, &seclen)) {
978                 CERROR("can't extract token\n");
979                 GOTO(out_rsikey, rc = SVC_DROP);
980         }
981         if (rawobj_dup(&rsikey->in_token, &tmpobj)) {
982                 CERROR("can't duplicate token\n");
983                 GOTO(out_rsikey, rc = SVC_DROP);
984         }
985
986         rsikey->naltype = (__u32) req->rq_peer.peer_ni->pni_number;
987         rsikey->netid = 0;
988         rsikey->nid = (__u64) req->rq_peer.peer_id.nid;
989
990         rsip = gssd_upcall(rsikey, &my_chandle);
991         if (!rsip) {
992                 CERROR("error in gssd_upcall.\n");
993
994                 rc = SVC_COMPLETE;
995                 if (gss_pack_err_notify(req, GSS_S_FAILURE, 0))
996                         rc = SVC_DROP;
997
998                 GOTO(out_rsikey, rc);
999         }
1000
1001         rsci = gss_svc_searchbyctx(&rsip->out_handle);
1002         if (!rsci) {
1003                 CERROR("rsci still not mature yet?\n");
1004
1005                 rc = SVC_COMPLETE;
1006                 if (gss_pack_err_notify(req, GSS_S_FAILURE, 0))
1007                         rc = SVC_DROP;
1008
1009                 GOTO(out_rsip, rc);
1010         }
1011         CDEBUG(D_SEC, "svcsec create gss context %p(%u@%s)\n",
1012                rsci, rsci->cred.vc_uid,
1013                portals_nid2str(req->rq_peer.peer_ni->pni_number,
1014                                req->rq_peer.peer_id.nid, nidstr));
1015
1016         svcdata->is_init = 1;
1017         svcdata->reserve_len = 6 * 4 +
1018                 size_round4(rsip->out_handle.len) +
1019                 size_round4(rsip->out_token.len);
1020
1021         rc = lustre_pack_reply(req, 0, NULL, NULL);
1022         if (rc) {
1023                 CERROR("failed to pack reply, rc = %d\n", rc);
1024                 set_bit(CACHE_NEGATIVE, &rsci->h.flags);
1025                 GOTO(out, rc = SVC_DROP);
1026         }
1027
1028         /* header */
1029         resp = (__u32 *) req->rq_reply_state->rs_repbuf;
1030         *resp++ = cpu_to_le32(PTLRPCS_FLVR_GSS_NONE);
1031         *resp++ = cpu_to_le32(PTLRPCS_SVC_NONE);
1032         *resp++ = cpu_to_le32(req->rq_replen);
1033         reslenp = resp++;
1034
1035         resp += req->rq_replen / 4;
1036         reslen = svcdata->reserve_len;
1037
1038         /* gss reply:
1039          * status, major, minor, seq, out_handle, out_token
1040          */
1041         *resp++ = cpu_to_le32(PTLRPCS_OK);
1042         *resp++ = cpu_to_le32(rsip->major_status);
1043         *resp++ = cpu_to_le32(rsip->minor_status);
1044         *resp++ = cpu_to_le32(GSS_SEQ_WIN);
1045         reslen -= (4 * 4);
1046         if (rawobj_serialize(&rsip->out_handle,
1047                              &resp, &reslen)) {
1048                 dump_rsi(rsip);
1049                 dump_rsi(rsikey);
1050                 LBUG();
1051         }
1052         if (rawobj_serialize(&rsip->out_token,
1053                              &resp, &reslen)) {
1054                 dump_rsi(rsip);
1055                 dump_rsi(rsikey);
1056                 LBUG();
1057         }
1058         /* the actual sec data length */
1059         *reslenp = cpu_to_le32(svcdata->reserve_len - reslen);
1060
1061         req->rq_reply_state->rs_repdata_len += le32_to_cpu(*reslenp);
1062         CDEBUG(D_SEC, "req %p: msgsize %d, authsize %d, "
1063                "total size %d\n", req, req->rq_replen,
1064                le32_to_cpu(*reslenp),
1065                req->rq_reply_state->rs_repdata_len);
1066
1067         *res = PTLRPCS_OK;
1068
1069         req->rq_remote_realm = rsci->remote_realm;
1070         req->rq_auth_usr_mds = rsci->auth_usr_mds;
1071         req->rq_auth_usr_oss = rsci->auth_usr_oss;
1072         req->rq_auth_uid = rsci->cred.vc_uid;
1073         req->rq_mapped_uid = rsci->mapped_uid;
1074
1075         if (req->rq_auth_usr_mds) {
1076                 CWARN("usr from %s authenticated as mds svc cred\n",
1077                 portals_nid2str(req->rq_peer.peer_ni->pni_number,
1078                                 req->rq_peer.peer_id.nid, nidstr));
1079         }
1080         if (req->rq_auth_usr_oss) {
1081                 CWARN("usr from %s authenticated as oss svc cred\n",
1082                 portals_nid2str(req->rq_peer.peer_ni->pni_number,
1083                                 req->rq_peer.peer_id.nid, nidstr));
1084         }
1085
1086         /* This is simplified since right now we doesn't support
1087          * INIT_CONTINUE yet.
1088          */
1089         if (gc->gc_proc == RPC_GSS_PROC_INIT) {
1090                 struct ptlrpcs_wire_hdr *hdr;
1091
1092                 hdr = buf_to_sec_hdr(req->rq_reqbuf);
1093                 req->rq_reqmsg = buf_to_lustre_msg(req->rq_reqbuf);
1094                 req->rq_reqlen = hdr->msg_len;
1095
1096                 rc = SVC_LOGIN;
1097         } else
1098                 rc = SVC_COMPLETE;
1099
1100 out:
1101         rsc_put(&rsci->h, &rsc_cache);
1102 out_rsip:
1103         rsi_put(&rsip->h, &rsi_cache);
1104 out_rsikey:
1105         rsi_put(&rsikey->h, &rsi_cache);
1106
1107         RETURN(rc);
1108 }
1109
/*
 * Handle a RPC_GSS_PROC_DATA request: look up the established context
 * by the client-supplied handle, then verify (integrity) or unseal
 * (privacy) the request body in place.
 *
 * Returns SVC_OK with *res = PTLRPCS_OK on success; SVC_COMPLETE when an
 * error notify reply was packed for the client; SVC_DROP otherwise.
 */
static int
gss_svcsec_handle_data(struct ptlrpc_request *req,
                       struct rpc_gss_wire_cred *gc,
                       __u32 *secdata, __u32 seclen,
                       enum ptlrpcs_error *res)
{
        struct rsc          *rsci;
        char                 nidstr[PTL_NALFMT_SIZE];
        __u32                major;
        int                  rc;
        ENTRY;

        *res = PTLRPCS_GSS_CREDPROBLEM;

        /* takes a reference on the rsc entry; dropped at out: */
        rsci = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rsci) {
                CWARN("Invalid gss context handle from %s\n",
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                major = GSS_S_NO_CONTEXT;
                goto notify_err;
        }

        switch (gc->gc_svc) {
        case PTLRPCS_GSS_SVC_INTEGRITY:
                major = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
                if (major == GSS_S_COMPLETE)
                        break;

                CWARN("fail in verify:0x%x: ctx %p@%s\n", major, rsci,
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                goto notify_err;
        case PTLRPCS_GSS_SVC_PRIVACY:
                major = gss_svc_unseal_request(req, rsci, gc, secdata, seclen);
                if (major == GSS_S_COMPLETE)
                        break;

                CWARN("fail in decrypt:0x%x: ctx %p@%s\n", major, rsci,
                       portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                       req->rq_peer.peer_id.nid, nidstr));
                goto notify_err;
        default:
                CERROR("unsupported gss service %d\n", gc->gc_svc);
                GOTO(out, rc = SVC_DROP);
        }

        /* propagate per-context authentication attributes to the request */
        req->rq_remote_realm = rsci->remote_realm;
        req->rq_auth_usr_mds = rsci->auth_usr_mds;
        req->rq_auth_usr_oss = rsci->auth_usr_oss;
        req->rq_auth_uid = rsci->cred.vc_uid;
        req->rq_mapped_uid = rsci->mapped_uid;

        *res = PTLRPCS_OK;
        GOTO(out, rc = SVC_OK);

notify_err:
        /* tell the client about the gss failure; drop if even that fails */
        if (gss_pack_err_notify(req, major, 0))
                rc = SVC_DROP;
        else
                rc = SVC_COMPLETE;
out:
        if (rsci)
                rsc_put(&rsci->h, &rsc_cache);
        RETURN(rc);
}
1176
/*
 * Handle a RPC_GSS_PROC_DESTROY request: verify the (integrity-protected)
 * destroy message against the context, then mark the cached context
 * negative so it cannot be used again.
 *
 * Returns SVC_LOGOUT on success, SVC_DROP on any failure.
 */
static int
gss_svcsec_handle_destroy(struct ptlrpc_request *req,
                          struct rpc_gss_wire_cred *gc,
                          __u32 *secdata, __u32 seclen,
                          enum ptlrpcs_error *res)
{
        struct gss_svc_data *svcdata = req->rq_svcsec_data;
        struct rsc          *rsci;
        char                 nidstr[PTL_NALFMT_SIZE];
        int                  rc;
        ENTRY;

        LASSERT(svcdata);
        *res = PTLRPCS_GSS_CREDPROBLEM;

        /* takes a reference on the rsc entry; dropped at out: */
        rsci = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rsci) {
                CWARN("invalid gss context handle for destroy.\n");
                RETURN(SVC_DROP);
        }

        /* destroy messages are only accepted with integrity protection */
        if (gc->gc_svc != PTLRPCS_GSS_SVC_INTEGRITY) {
                CERROR("service %d is not supported in destroy.\n",
                        gc->gc_svc);
                GOTO(out, rc = SVC_DROP);
        }

        /* NOTE(review): this stores a gss major status into the
         * ptlrpcs_error enum slot; relies on GSS_S_COMPLETE == 0 being
         * the only success value — confirm intentional */
        *res = gss_svc_verify_request(req, rsci, gc, secdata, seclen);
        if (*res)
                GOTO(out, rc = SVC_DROP);

        /* compose reply, which is actually nothing */
        svcdata->is_fini = 1;
        if (lustre_pack_reply(req, 0, NULL, NULL))
                GOTO(out, rc = SVC_DROP);

        CDEBUG(D_SEC, "svcsec destroy gss context %p(%u@%s)\n",
               rsci, rsci->cred.vc_uid,
               portals_nid2str(req->rq_peer.peer_ni->pni_number,
                               req->rq_peer.peer_id.nid, nidstr));

        //XXX temp for debugging
        {
                CWARN("destroy ctxt %p(%u/%u)@%s\n",
                      rsci, *((__u32 *) rsci->handle.data),
                      rsci->cred.vc_uid,
                      portals_nid2str(req->rq_peer.peer_ni->pni_number,
                                      req->rq_peer.peer_id.nid, nidstr));
        }
        /* invalidate the cache entry: subsequent lookups won't match */
        set_bit(CACHE_NEGATIVE, &rsci->h.flags);
        *res = PTLRPCS_OK;
        rc = SVC_LOGOUT;
out:
        rsc_put(&rsci->h, &rsc_cache);
        RETURN(rc);
}
1233
1234 /*
1235  * let incomming request go through security check:
1236  *  o context establishment: invoke user space helper
1237  *  o data exchange: verify/decrypt
1238  *  o context destruction: mark context invalid
1239  *
1240  * in most cases, error will result to drop the packet silently.
1241  */
/*
 * let incomming request go through security check:
 *  o context establishment: invoke user space helper
 *  o data exchange: verify/decrypt
 *  o context destruction: mark context invalid
 *
 * in most cases, error will result to drop the packet silently.
 */
static int
gss_svcsec_accept(struct ptlrpc_request *req, enum ptlrpcs_error *res)
{
        struct gss_svc_data *svcdata;
        struct rpc_gss_wire_cred *gc;
        struct ptlrpcs_wire_hdr *sec_hdr;
        __u32 subflavor, seclen, *secdata, version;
        int rc;
        ENTRY;

        CDEBUG(D_SEC, "request %p\n", req);
        LASSERT(req->rq_reqbuf);
        LASSERT(req->rq_reqbuf_len);

        *res = PTLRPCS_BADCRED;

        sec_hdr = buf_to_sec_hdr(req->rq_reqbuf);
        LASSERT(SEC_FLAVOR_MAJOR(sec_hdr->flavor) == PTLRPCS_FLVR_MAJOR_GSS);

        /* security payload is whatever follows the wire header + message */
        seclen = req->rq_reqbuf_len - sizeof(*sec_hdr) - sec_hdr->msg_len;
        secdata = (__u32 *) buf_to_sec_data(req->rq_reqbuf);

        /* claimed sec length must fit the remaining buffer */
        if (sec_hdr->sec_len > seclen) {
                CERROR("seclen %d, while max buf %d\n",
                        sec_hdr->sec_len, seclen);
                RETURN(SVC_DROP);
        }

        /* need at least the 5-word gss header + a 1-word ctx handle len */
        if (seclen < 6 * 4) {
                CERROR("sec size %d too small\n", seclen);
                RETURN(SVC_DROP);
        }

        /* svcdata lives on req->rq_svcsec_data until cleanup_req();
         * freed early below only on SVC_DROP */
        LASSERT(!req->rq_svcsec_data);
        OBD_ALLOC(svcdata, sizeof(*svcdata));
        if (!svcdata) {
                CERROR("fail to alloc svcdata\n");
                RETURN(SVC_DROP);
        }
        req->rq_svcsec_data = svcdata;
        gc = &svcdata->clcred;

        /* Now secdata/seclen is what we want to parse
         */
        version = le32_to_cpu(*secdata++);      /* version */
        subflavor = le32_to_cpu(*secdata++);    /* subflavor */
        gc->gc_proc = le32_to_cpu(*secdata++);  /* proc */
        gc->gc_seq = le32_to_cpu(*secdata++);   /* seq */
        gc->gc_svc = le32_to_cpu(*secdata++);   /* service */
        seclen -= 5 * 4;

        CDEBUG(D_SEC, "wire gss_hdr: %u/%u/%u/%u/%u\n",
               version, subflavor, gc->gc_proc,
               gc->gc_seq, gc->gc_svc);

        if (version != PTLRPC_SEC_GSS_VERSION) {
                CERROR("gss version mismatch: %d - %d\n",
                        version, PTLRPC_SEC_GSS_VERSION);
                GOTO(err_free, rc = SVC_DROP);
        }

        /* gc_ctx points into the request buffer itself — not allocated,
         * so it must not be freed (see cleanup_req) */
        if (rawobj_extract(&gc->gc_ctx, &secdata, &seclen)) {
                CERROR("fail to obtain gss context handle\n");
                GOTO(err_free, rc = SVC_DROP);
        }

        /* dispatch on the gss procedure */
        *res = PTLRPCS_BADVERF;
        switch(gc->gc_proc) {
        case RPC_GSS_PROC_INIT:
        case RPC_GSS_PROC_CONTINUE_INIT:
                rc = gss_svcsec_handle_init(req, gc, secdata, seclen, res);
                break;
        case RPC_GSS_PROC_DATA:
                rc = gss_svcsec_handle_data(req, gc, secdata, seclen, res);
                break;
        case RPC_GSS_PROC_DESTROY:
                rc = gss_svcsec_handle_destroy(req, gc, secdata, seclen, res);
                break;
        default:
                rc = SVC_DROP;
                LBUG();
        }

err_free:
        /* dropped requests get no cleanup_req() call, so free here */
        if (rc == SVC_DROP && req->rq_svcsec_data) {
                OBD_FREE(req->rq_svcsec_data, sizeof(struct gss_svc_data));
                req->rq_svcsec_data = NULL;
        }

        RETURN(rc);
}
1333
/*
 * Protect the reply before it goes on the wire: sign it (integrity) or
 * encrypt it (privacy), according to the service the client requested.
 * init/fini/error-notify replies were already fully packed and are
 * passed through untouched.
 *
 * Returns 0 on success, -EINVAL on any failure.
 */
static int
gss_svcsec_authorize(struct ptlrpc_request *req)
{
        struct ptlrpc_reply_state *rs = req->rq_reply_state;
        struct gss_svc_data *gsd = (struct gss_svc_data *)req->rq_svcsec_data;
        struct rpc_gss_wire_cred  *gc = &gsd->clcred;
        struct rsc                *rscp;
        struct ptlrpcs_wire_hdr   *sec_hdr;
        rawobj_buf_t               msg_buf;
        rawobj_t                   cipher_buf;
        __u32                     *vp, *vpsave, major, vlen, seclen;
        rawobj_t                   lmsg, mic;
        int                        ret;
        ENTRY;

        LASSERT(rs);
        LASSERT(rs->rs_repbuf);
        LASSERT(gsd);

        if (gsd->is_init || gsd->is_init_continue ||
            gsd->is_err_notify || gsd->is_fini) {
                /* nothing to do in these cases */
                CDEBUG(D_SEC, "req %p: init/fini/err\n", req);
                RETURN(0);
        }

        if (gc->gc_proc != RPC_GSS_PROC_DATA) {
                CERROR("proc %d not support\n", gc->gc_proc);
                RETURN(-EINVAL);
        }

        /* takes a reference on the rsc entry; dropped at out: */
        rscp = gss_svc_searchbyctx(&gc->gc_ctx);
        if (!rscp) {
                CERROR("ctx %u disapeared under us\n",
                       *((__u32 *) gc->gc_ctx.data));
                RETURN(-EINVAL);
        }

        sec_hdr = (struct ptlrpcs_wire_hdr *) rs->rs_repbuf;
        switch (gc->gc_svc) {
        case  PTLRPCS_GSS_SVC_INTEGRITY:
                /* prepare various pointers: lustre msg stays in clear,
                 * followed by the gss header and the MIC */
                lmsg.len = req->rq_replen;
                lmsg.data = (__u8 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vp = (__u32 *) (lmsg.data + lmsg.len);
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr) - lmsg.len;
                seclen = vlen;

                sec_hdr->flavor = cpu_to_le32(PTLRPCS_FLVR_GSS_AUTH);
                sec_hdr->msg_len = cpu_to_le32(req->rq_replen);

                /* standard gss hdr */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPCS_GSS_SVC_INTEGRITY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                mic.len = vlen;
                mic.data = (unsigned char *)vp;

                /* compute the MIC over the clear lustre msg */
                major = kgss_get_mic(rscp->mechctx, 0, &lmsg, &mic);
                if (major) {
                        CERROR("fail to get MIC: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }
                *vpsave = cpu_to_le32(mic.len);
                /* actual sec section = fixed hdr + real MIC length */
                seclen = seclen - vlen + mic.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        case  PTLRPCS_GSS_SVC_PRIVACY:
                /* whole lustre msg is encrypted; msg_len on the wire is 0 */
                vp = (__u32 *) (rs->rs_repbuf + sizeof(*sec_hdr));
                vlen = rs->rs_repbuf_len - sizeof(*sec_hdr);
                seclen = vlen;

                sec_hdr->flavor = cpu_to_le32(PTLRPCS_FLVR_GSS_PRIV);
                sec_hdr->msg_len = cpu_to_le32(0);

                /* standard gss hdr */
                LASSERT(vlen >= 7 * 4);
                *vp++ = cpu_to_le32(PTLRPC_SEC_GSS_VERSION);
                *vp++ = cpu_to_le32(PTLRPCS_FLVR_KRB5I);
                *vp++ = cpu_to_le32(RPC_GSS_PROC_DATA);
                *vp++ = cpu_to_le32(gc->gc_seq);
                *vp++ = cpu_to_le32(PTLRPCS_GSS_SVC_PRIVACY);
                *vp++ = 0;      /* fake ctx handle */
                vpsave = vp++;  /* reserve size */
                vlen -= 7 * 4;

                /* rs_msg was allocated with prefix/suffix slack by
                 * gss_svcsec_alloc_repbuf() exactly for in-place wrap */
                msg_buf.buf = (__u8 *) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.buflen = req->rq_replen + GSS_PRIVBUF_PREFIX_LEN +
                                 GSS_PRIVBUF_SUFFIX_LEN;
                msg_buf.dataoff = GSS_PRIVBUF_PREFIX_LEN;
                msg_buf.datalen = req->rq_replen;

                cipher_buf.data = (__u8 *) vp;
                cipher_buf.len = vlen;

                major = kgss_wrap(rscp->mechctx, GSS_C_QOP_DEFAULT,
                                &msg_buf, &cipher_buf);
                if (major) {
                        CERROR("failed to wrap: 0x%x\n", major);
                        GOTO(out, ret = -EINVAL);
                }

                *vpsave = cpu_to_le32(cipher_buf.len);
                /* actual sec section = fixed hdr + real ciphertext length */
                seclen = seclen - vlen + cipher_buf.len;
                sec_hdr->sec_len = cpu_to_le32(seclen);
                rs->rs_repdata_len += size_round(seclen);
                break;
        default:
                CERROR("Unknown service %d\n", gc->gc_svc);
                GOTO(out, ret = -EINVAL);
        }
        ret = 0;
out:
        rsc_put(&rscp->h, &rsc_cache);

        RETURN(ret);
}
1459
1460 static
1461 void gss_svcsec_cleanup_req(struct ptlrpc_svcsec *svcsec,
1462                             struct ptlrpc_request *req)
1463 {
1464         struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_svcsec_data;
1465
1466         if (!gsd) {
1467                 CDEBUG(D_SEC, "no svc_data present. do nothing\n");
1468                 return;
1469         }
1470
1471         /* gsd->clclred.gc_ctx is NOT allocated, just set pointer
1472          * to the incoming packet buffer, so don't need free it
1473          */
1474         OBD_FREE(gsd, sizeof(*gsd));
1475         req->rq_svcsec_data = NULL;
1476         return;
1477 }
1478
1479 static
1480 int gss_svcsec_est_payload(struct ptlrpc_svcsec *svcsec,
1481                            struct ptlrpc_request *req,
1482                            int msgsize)
1483 {
1484         struct gss_svc_data *svcdata = req->rq_svcsec_data;
1485         ENTRY;
1486
1487         /* just return the pre-set reserve_len for init/fini/err cases.
1488          */
1489         LASSERT(svcdata);
1490         if (svcdata->is_init) {
1491                 CDEBUG(D_SEC, "is_init, reserver size %d(%d)\n",
1492                        size_round(svcdata->reserve_len),
1493                        svcdata->reserve_len);
1494                 LASSERT(svcdata->reserve_len);
1495                 LASSERT(svcdata->reserve_len % 4 == 0);
1496                 RETURN(size_round(svcdata->reserve_len));
1497         } else if (svcdata->is_err_notify) {
1498                 CDEBUG(D_SEC, "is_err_notify, reserver size %d(%d)\n",
1499                        size_round(svcdata->reserve_len),
1500                        svcdata->reserve_len);
1501                 RETURN(size_round(svcdata->reserve_len));
1502         } else if (svcdata->is_fini) {
1503                 CDEBUG(D_SEC, "is_fini, reserver size 0\n");
1504                 RETURN(0);
1505         } else {
1506                 if (svcdata->clcred.gc_svc == PTLRPCS_GSS_SVC_NONE ||
1507                     svcdata->clcred.gc_svc == PTLRPCS_GSS_SVC_INTEGRITY)
1508                         RETURN(size_round(GSS_MAX_AUTH_PAYLOAD));
1509                 else if (svcdata->clcred.gc_svc == PTLRPCS_GSS_SVC_PRIVACY)
1510                         RETURN(size_round16(GSS_MAX_AUTH_PAYLOAD + msgsize +
1511                                             GSS_PRIVBUF_PREFIX_LEN +
1512                                             GSS_PRIVBUF_SUFFIX_LEN));
1513                 else {
1514                         CERROR("unknown gss svc %u\n", svcdata->clcred.gc_svc);
1515                         *((int *)0) = 0;
1516                         LBUG();
1517                 }
1518         }
1519         RETURN(0);
1520 }
1521
/*
 * Allocate the reply buffer for a request, sized for the lustre message
 * plus the estimated security overhead.
 *
 * For privacy-protected data replies the clear-text message gets its
 * own separate buffer (with wrap prefix/suffix slack) instead of living
 * inside rs_repbuf; gss_svcsec_free_repbuf() detects and frees it.
 *
 * Returns 0 on success or a negative errno.
 */
int gss_svcsec_alloc_repbuf(struct ptlrpc_svcsec *svcsec,
                            struct ptlrpc_request *req,
                            int msgsize)
{
        struct gss_svc_data *gsd = (struct gss_svc_data *) req->rq_svcsec_data;
        struct ptlrpc_reply_state *rs;
        int msg_payload, sec_payload;
        int privacy, rc;
        ENTRY;

        /* determine the security type: none/auth or priv, we have
         * different pack scheme for them.
         * init/fini/err will always be treated as none/auth.
         */
        LASSERT(gsd);
        if (!gsd->is_init && !gsd->is_init_continue &&
            !gsd->is_fini && !gsd->is_err_notify &&
            gsd->clcred.gc_svc == PTLRPCS_GSS_SVC_PRIVACY)
                privacy = 1;
        else
                privacy = 0;

        /* privacy: the message is carried (encrypted) inside the sec
         * payload, so no separate msg payload in rs_repbuf */
        msg_payload = privacy ? 0 : msgsize;
        sec_payload = gss_svcsec_est_payload(svcsec, req, msgsize);

        rc = svcsec_alloc_reply_state(req, msg_payload, sec_payload);
        if (rc)
                RETURN(rc);

        rs = req->rq_reply_state;
        LASSERT(rs);
        rs->rs_msg_len = msgsize;

        if (privacy) {
                /* we can choose to let msg simply point to the rear of the
                 * buffer, which lead to buffer overlap when doing encryption.
                 * usually it's ok and it indeed passed all existing tests.
                 * but not sure if there will be subtle problems in the future.
                 * so right now we choose to alloc another new buffer. we'll
                 * see how it works.
                 */
#if 0
                rs->rs_msg = (struct lustre_msg *)
                             (rs->rs_repbuf + rs->rs_repbuf_len -
                              msgsize - GSS_PRIVBUF_SUFFIX_LEN);
#endif
                char *msgbuf;

                /* over-allocate so kgss_wrap() can work in place */
                msgsize += GSS_PRIVBUF_PREFIX_LEN + GSS_PRIVBUF_SUFFIX_LEN;
                OBD_ALLOC(msgbuf, msgsize);
                if (!msgbuf) {
                        CERROR("can't alloc %d\n", msgsize);
                        svcsec_free_reply_state(rs);
                        req->rq_reply_state = NULL;
                        RETURN(-ENOMEM);
                }
                /* rs_msg points past the prefix; free_repbuf reverses this */
                rs->rs_msg = (struct lustre_msg *)
                                (msgbuf + GSS_PRIVBUF_PREFIX_LEN);
        }

        req->rq_repmsg = rs->rs_msg;

        RETURN(0);
}
1586
1587 static
1588 void gss_svcsec_free_repbuf(struct ptlrpc_svcsec *svcsec,
1589                             struct ptlrpc_reply_state *rs)
1590 {
1591         unsigned long p1 = (unsigned long) rs->rs_msg;
1592         unsigned long p2 = (unsigned long) rs->rs_buf;
1593
1594         LASSERT(rs->rs_buf);
1595         LASSERT(rs->rs_msg);
1596         LASSERT(rs->rs_msg_len);
1597
1598         if (p1 < p2 || p1 >= p2 + rs->rs_buf_len) {
1599                 char *start = (char*) rs->rs_msg - GSS_PRIVBUF_PREFIX_LEN;
1600                 int size = rs->rs_msg_len + GSS_PRIVBUF_PREFIX_LEN +
1601                            GSS_PRIVBUF_SUFFIX_LEN;
1602                 OBD_FREE(start, size);
1603         }
1604
1605         svcsec_free_reply_state(rs);
1606 }
1607
/* RPCSEC_GSS server-side security policy: plugs the gss accept/
 * authorize/buffer handlers into the generic ptlrpc svcsec framework. */
struct ptlrpc_svcsec svcsec_gss = {
        .pss_owner              = THIS_MODULE,
        .pss_name               = "svcsec.gss",
        .pss_flavor             = PTLRPCS_FLVR_MAJOR_GSS,
        .accept                 = gss_svcsec_accept,
        .authorize              = gss_svcsec_authorize,
        .alloc_repbuf           = gss_svcsec_alloc_repbuf,
        .free_repbuf            = gss_svcsec_free_repbuf,
        .cleanup_req            = gss_svcsec_cleanup_req,
};
1618
/* XXX hacking */
/* Purge every entry from both the init-request (rsi) and established-
 * context (rsc) caches, wiping all gss state on this server. */
void lgss_svc_cache_purge_all(void)
{
        cache_purge(&rsi_cache);
        cache_purge(&rsc_cache);
}
EXPORT_SYMBOL(lgss_svc_cache_purge_all);
1626
/* Flush cached gss contexts belonging to a single user id. */
void lgss_svc_cache_flush(__u32 uid)
{
        rsc_flush(uid);
}
EXPORT_SYMBOL(lgss_svc_cache_flush);
1632
1633 int gss_svc_init(void)
1634 {
1635         int rc;
1636
1637         rc = svcsec_register(&svcsec_gss);
1638         if (!rc) {
1639                 cache_register(&rsc_cache);
1640                 cache_register(&rsi_cache);
1641         }
1642         return rc;
1643 }
1644
1645 void gss_svc_exit(void)
1646 {
1647         int rc;
1648         if ((rc = cache_unregister(&rsi_cache)))
1649                 CERROR("unregister rsi cache: %d\n", rc);
1650         if ((rc = cache_unregister(&rsc_cache)))
1651                 CERROR("unregister rsc cache: %d\n", rc);
1652         if ((rc = svcsec_unregister(&svcsec_gss)))
1653                 CERROR("unregister svcsec_gss: %d\n", rc);
1654 }