Whamcloud - gitweb
58946749b2eeaa365de7765a754b1a96799d25a4
[fs/lustre-release.git] / lustre / ptlrpc / gss / gss_svc_upcall.c
1 /*
2  * Modifications for Lustre
3  *
4  * Copyright (c) 2007, 2010, Oracle and/or its affiliates. All rights reserved.
5  *
6  * Copyright (c) 2012, 2014, Intel Corporation.
7  *
8  * Author: Eric Mei <ericm@clusterfs.com>
9  */
10
11 /*
12  * Neil Brown <neilb@cse.unsw.edu.au>
13  * J. Bruce Fields <bfields@umich.edu>
14  * Andy Adamson <andros@umich.edu>
15  * Dug Song <dugsong@monkey.org>
16  *
17  * RPCSEC_GSS server authentication.
18  * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078
19  * (gssapi)
20  *
21  * The RPCSEC_GSS involves three stages:
22  *  1/ context creation
23  *  2/ data exchange
24  *  3/ context destruction
25  *
26  * Context creation is handled largely by upcalls to user-space.
27  *  In particular, GSS_Accept_sec_context is handled by an upcall
28  * Data exchange is handled entirely within the kernel
29  *  In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel.
30  * Context destruction is handled in-kernel
31  *  GSS_Delete_sec_context is in-kernel
32  *
33  * Context creation is initiated by a RPCSEC_GSS_INIT request arriving.
34  * The context handle and gss_token are used as a key into the rpcsec_init cache.
35  * The content of this cache includes some of the outputs of GSS_Accept_sec_context,
36  * being major_status, minor_status, context_handle, reply_token.
37  * These are sent back to the client.
38  * Sequence window management is handled by the kernel.  The window size if currently
39  * a compile time constant.
40  *
41  * When user-space is happy that a context is established, it places an entry
42  * in the rpcsec_context cache. The key for this cache is the context_handle.
43  * The content includes:
44  *   uid/gidlist - for determining access rights
45  *   mechanism type
46  *   mechanism specific information, such as a key
47  *
48  */
49
50 #define DEBUG_SUBSYSTEM S_SEC
51 #include <linux/types.h>
52 #include <linux/init.h>
53 #include <linux/module.h>
54 #include <linux/random.h>
55 #include <linux/slab.h>
56 #include <linux/mutex.h>
57 #include <linux/sunrpc/cache.h>
58 #include <linux/binfmts.h>
59 #include <net/sock.h>
60 #include <linux/un.h>
61
62 #include <obd.h>
63 #include <obd_class.h>
64 #include <obd_support.h>
65 #include <lustre_import.h>
66 #include <lustre_net.h>
67 #include <lustre_nodemap.h>
68 #include <lustre_sec.h>
69 #include <libcfs/linux/linux-hash.h>
70
71 #include "gss_err.h"
72 #include "gss_internal.h"
73 #include "gss_api.h"
74 #include "gss_crypto.h"
75
76 #ifndef HAVE_GET_EXPIRY_2ARGS
77 static inline int __get_expiry2(char **bpp, time64_t *rvp)
78 {
79         *rvp = get_expiry(bpp);
80         return *rvp ? 0 : -EINVAL;
81 }
82 #define get_expiry(ps, pt)      __get_expiry2((ps), (pt))
83 #endif
84
85 #define GSS_SVC_UPCALL_TIMEOUT  (20)
86
87 static DEFINE_SPINLOCK(__ctx_index_lock);
88 static __u64 __ctx_index;
89
90 unsigned int krb5_allow_old_client_csum;
91
92 __u64 gss_get_next_ctx_index(void)
93 {
94         __u64 idx;
95
96         spin_lock(&__ctx_index_lock);
97         idx = __ctx_index++;
98         spin_unlock(&__ctx_index_lock);
99
100         return idx;
101 }
102
103 static inline unsigned long hash_mem(char *buf, int length, int bits)
104 {
105         unsigned long hash = 0;
106         unsigned long l = 0;
107         int len = 0;
108         unsigned char c;
109
110         do {
111                 if (len == length) {
112                         c = (char) len;
113                         len = -1;
114                 } else
115                         c = *buf++;
116
117                 l = (l << 8) | c;
118                 len++;
119
120                 if ((len & (BITS_PER_LONG/8-1)) == 0)
121                         hash = cfs_hash_long(hash^l, BITS_PER_LONG);
122         } while (len);
123
124         return hash >> (BITS_PER_LONG - bits);
125 }
126
127 /* This is a little bit of a concern but we need to make our own hash64 function
128  * as the one from the kernel seems to be buggy by returning a u32:
129  * static __always_inline u32 hash_64_generic(u64 val, unsigned int bits)
130  */
131 #if BITS_PER_LONG == 64
132 static __always_inline __u64 gss_hash_64(__u64 val, unsigned int bits)
133 {
134         __u64 hash = val;
135         /*  Sigh, gcc can't optimise this alone like it does for 32 bits. */
136         __u64 n = hash;
137
138         n <<= 18;
139         hash -= n;
140         n <<= 33;
141         hash -= n;
142         n <<= 3;
143         hash += n;
144         n <<= 3;
145         hash -= n;
146         n <<= 4;
147         hash += n;
148         n <<= 2;
149         hash += n;
150
151         /* High bits are more random, so use them. */
152         return hash >> (64 - bits);
153 }
154
155 static inline unsigned long hash_mem_64(char *buf, int length, int bits)
156 {
157         unsigned long hash = 0;
158         unsigned long l = 0;
159         int len = 0;
160         unsigned char c;
161
162         do {
163                 if (len == length) {
164                         c = (char) len;
165                         len = -1;
166                 } else
167                         c = *buf++;
168
169                 l = (l << 8) | c;
170                 len++;
171
172                 if ((len & (BITS_PER_LONG/8-1)) == 0)
173                         hash = gss_hash_64(hash^l, BITS_PER_LONG);
174         } while (len);
175
176         return hash >> (BITS_PER_LONG - bits);
177 }
178 #endif /* BITS_PER_LONG == 64 */
179
180 /****************************************
181  * rpc sec init (rsi) cache             *
182  ****************************************/
183
184 #define RSI_HASHBITS    (6)
185 #define RSI_HASHMAX     (1 << RSI_HASHBITS)
186 #define RSI_HASHMASK    (RSI_HASHMAX - 1)
187
188 static void rsi_entry_init(struct upcall_cache_entry *entry,
189                            void *args)
190 {
191         struct gss_rsi *rsi = &entry->u.rsi;
192         struct gss_rsi *tmp = args;
193
194         rsi->si_uc_entry = entry;
195         rawobj_dup(&rsi->si_in_handle, &tmp->si_in_handle);
196         rawobj_dup(&rsi->si_in_token, &tmp->si_in_token);
197         rsi->si_out_handle = RAWOBJ_EMPTY;
198         rsi->si_out_token = RAWOBJ_EMPTY;
199
200         rsi->si_lustre_svc = tmp->si_lustre_svc;
201         rsi->si_nid4 = tmp->si_nid4;
202         memcpy(rsi->si_nm_name, tmp->si_nm_name, sizeof(tmp->si_nm_name));
203 }
204
205 static void __rsi_free(struct gss_rsi *rsi)
206 {
207         rawobj_free(&rsi->si_in_handle);
208         rawobj_free(&rsi->si_in_token);
209         rawobj_free(&rsi->si_out_handle);
210         rawobj_free(&rsi->si_out_token);
211 }
212
213 static void rsi_entry_free(struct upcall_cache *cache,
214                            struct upcall_cache_entry *entry)
215 {
216         struct gss_rsi *rsi = &entry->u.rsi;
217
218         __rsi_free(rsi);
219 }
220
221 static inline int rsi_entry_hash(struct gss_rsi *rsi)
222 {
223 #if BITS_PER_LONG == 64
224         return hash_mem_64((char *)rsi->si_in_handle.data,
225                            rsi->si_in_handle.len, RSI_HASHBITS) ^
226                 hash_mem_64((char *)rsi->si_in_token.data,
227                             rsi->si_in_token.len, RSI_HASHBITS);
228 #else
229         return hash_mem((char *)rsi->si_in_handle.data, rsi->si_in_handle.len,
230                         RSI_HASHBITS) ^
231                 hash_mem((char *)rsi->si_in_token.data, rsi->si_in_token.len,
232                          RSI_HASHBITS);
233 #endif
234 }
235
236 static inline int __rsi_entry_match(rawobj_t *h1, rawobj_t *h2,
237                                     rawobj_t *t1, rawobj_t *t2)
238 {
239         return !(rawobj_equal(h1, h2) && rawobj_equal(t1, t2));
240 }
241
242 static inline int rsi_entry_match(struct gss_rsi *rsi, struct gss_rsi *tmp)
243 {
244         return __rsi_entry_match(&rsi->si_in_handle, &tmp->si_in_handle,
245                                  &rsi->si_in_token, &tmp->si_in_token);
246 }
247
248 /* Returns 0 to tell this is a match */
249 static inline int rsi_upcall_compare(struct upcall_cache *cache,
250                                      struct upcall_cache_entry *entry,
251                                      __u64 key, void *args)
252 {
253         struct gss_rsi *rsi1 = &entry->u.rsi;
254         struct gss_rsi *rsi2 = args;
255
256         return rsi_entry_match(rsi1, rsi2);
257 }
258
259 /* See handle_channel_request() userspace for where the upcall data is read */
260 static int rsi_do_upcall(struct upcall_cache *cache,
261                          struct upcall_cache_entry *entry)
262 {
263         int size, len, *blen;
264         char *buffer, *bp, **bpp;
265         char *argv[] = {
266                 [0] = cache->uc_upcall,
267                 [1] = "-c",
268                 [2] = cache->uc_name,
269                 [3] = "-r",
270                 [4] = NULL,
271                 [5] = NULL
272         };
273         char *envp[] = {
274                 [0] = "HOME=/",
275                 [1] = "PATH=/sbin:/usr/sbin",
276                 [2] = NULL
277         };
278         ktime_t start, end;
279         struct gss_rsi *rsi = &entry->u.rsi;
280         __u64 index = 0;
281         int rc;
282
283         ENTRY;
284         CDEBUG(D_SEC, "rsi upcall '%s' on '%s'\n",
285                cache->uc_upcall, cache->uc_name);
286
287         size = 24 + 1 + /* ue_key is uint64_t */
288                 12 + 1 + /* si_lustre_svc is __u32*/
289                 18 + 1 + /* si_nid4 is lnet_nid_t, hex with leading 0x */
290                 18 + 1 + /* index is __u64, hex with leading 0x */
291                 strlen(rsi->si_nm_name) + 1 +
292                 BASE64URL_CHARS(rsi->si_in_handle.len) + 1 +
293                 BASE64URL_CHARS(rsi->si_in_token.len) + 1 +
294                 1 + 1; /* eol */
295         if (size > MAX_ARG_STRLEN)
296                 RETURN(-E2BIG);
297         OBD_ALLOC_LARGE(buffer, size);
298         if (!buffer)
299                 RETURN(-ENOMEM);
300
301         bp = buffer;
302         bpp = &bp;
303         len = size;
304         blen = &len;
305
306         /* if in_handle is null, provide kernel suggestion */
307         if (rsi->si_in_handle.len == 0)
308                 index = gss_get_next_ctx_index();
309
310         /* entry->ue_key is put into args sent via upcall, so that it can be
311          * returned by userspace. This will help find cache entry at downcall,
312          * without unnecessary recomputation of the hash.
313          */
314         gss_u64_write_string(bpp, blen, entry->ue_key);
315         gss_u64_write_string(bpp, blen, rsi->si_lustre_svc);
316         gss_u64_write_hex_string(bpp, blen, rsi->si_nid4);
317         gss_u64_write_hex_string(bpp, blen, index);
318         gss_string_write(bpp, blen, (char *) rsi->si_nm_name);
319         gss_base64url_encode(bpp, blen, rsi->si_in_handle.data,
320                              rsi->si_in_handle.len);
321         gss_base64url_encode(bpp, blen, rsi->si_in_token.data,
322                              rsi->si_in_token.len);
323         (*bpp)[-1] = '\n';
324         (*bpp)[0] = '\0';
325
326         argv[4] = buffer;
327         down_read(&cache->uc_upcall_rwsem);
328         start = ktime_get();
329         rc = call_usermodehelper(argv[0], argv, envp, UMH_WAIT_EXEC);
330         end = ktime_get();
331         up_read(&cache->uc_upcall_rwsem);
332         if (rc < 0) {
333                 CERROR("%s: error invoking upcall %s %s (time %ldus): rc = %d\n",
334                        cache->uc_name, argv[0], argv[2],
335                        (long)ktime_us_delta(end, start), rc);
336         } else {
337                 CDEBUG(D_SEC, "%s: invoked upcall %s %s (time %ldus)\n",
338                        cache->uc_name, argv[0], argv[2],
339                        (long)ktime_us_delta(end, start));
340                 rc = 0;
341         }
342
343         OBD_FREE_LARGE(buffer, size);
344         RETURN(rc);
345 }
346
347 static inline int rsi_downcall_compare(struct upcall_cache *cache,
348                                        struct upcall_cache_entry *entry,
349                                        __u64 key, void *args)
350 {
351         struct gss_rsi *rsi = &entry->u.rsi;
352         struct rsi_downcall_data *sid = args;
353         char *mesg = sid->sid_val;
354         rawobj_t handle, token;
355         char *p = mesg;
356         int len;
357
358         /* sid_val starts with handle and token */
359
360         /* First, handle */
361         len = gss_buffer_get(&mesg, &handle.len, &handle.data);
362         sid->sid_offset = mesg - p;
363         p = mesg;
364
365         /* Second, token */
366         len = gss_buffer_get(&mesg, &token.len, &token.data);
367         sid->sid_offset += mesg - p;
368
369         return __rsi_entry_match(&rsi->si_in_handle, &handle,
370                                  &rsi->si_in_token, &token);
371 }
372
373 static int rsi_parse_downcall(struct upcall_cache *cache,
374                               struct upcall_cache_entry *entry,
375                               void *args)
376 {
377         struct gss_rsi *rsi = &entry->u.rsi;
378         struct rsi_downcall_data *sid = args;
379         int mlen = sid->sid_len;
380         char *mesg = sid->sid_val + sid->sid_offset;
381         char *buf = sid->sid_val;
382         int status = -EINVAL;
383         int len;
384
385         ENTRY;
386
387         if (mlen <= 0)
388                 goto out;
389
390         rsi->si_major_status = sid->sid_maj_stat;
391         rsi->si_minor_status = sid->sid_min_stat;
392
393         /* in_handle and in_token have already been consumed in
394          * rsi_downcall_compare(). sid_offset gives next field.
395          */
396
397         /* out_handle */
398         len = gss_buffer_read(&mesg, buf, mlen);
399         if (len < 0)
400                 goto out;
401         if (rawobj_alloc(&rsi->si_out_handle, buf, len)) {
402                 status = -ENOMEM;
403                 goto out;
404         }
405
406         /* out_token */
407         len = gss_buffer_read(&mesg, buf, mlen);
408         if (len < 0)
409                 goto out;
410         if (rawobj_alloc(&rsi->si_out_token, buf, len)) {
411                 status = -ENOMEM;
412                 goto out;
413         }
414
415         entry->ue_expire = 0;
416         status = 0;
417
418 out:
419         CDEBUG(D_OTHER, "rsi parse %p: %d\n", rsi, status);
420         RETURN(status);
421 }
422
423 struct gss_rsi *rsi_entry_get(struct upcall_cache *cache, struct gss_rsi *rsi)
424 {
425         struct upcall_cache_entry *entry;
426         int hash = rsi_entry_hash(rsi);
427
428         if (!cache)
429                 return ERR_PTR(-ENOENT);
430
431         entry = upcall_cache_get_entry(cache, (__u64)hash, rsi);
432         if (unlikely(!entry))
433                 return ERR_PTR(-ENOENT);
434         if (IS_ERR(entry))
435                 return ERR_CAST(entry);
436
437         return &entry->u.rsi;
438 }
439
440 void rsi_entry_put(struct upcall_cache *cache, struct gss_rsi *rsi)
441 {
442         if (!cache || !rsi)
443                 return;
444
445         upcall_cache_put_entry(cache, rsi->si_uc_entry);
446 }
447
448 void rsi_flush(struct upcall_cache *cache, int hash)
449 {
450         if (hash < 0)
451                 upcall_cache_flush_idle(cache);
452         else
453                 upcall_cache_flush_one(cache, (__u64)hash, NULL);
454 }
455
456 struct upcall_cache_ops rsi_upcall_cache_ops = {
457         .init_entry       = rsi_entry_init,
458         .free_entry       = rsi_entry_free,
459         .upcall_compare   = rsi_upcall_compare,
460         .downcall_compare = rsi_downcall_compare,
461         .do_upcall        = rsi_do_upcall,
462         .parse_downcall   = rsi_parse_downcall,
463 };
464
465 struct upcall_cache *rsicache;
466
467 struct rsi {
468         struct cache_head       h;
469         __u32                   lustre_svc;
470         lnet_nid_t              nid4; /* FIXME Support larger NID */
471         char                    nm_name[LUSTRE_NODEMAP_NAME_LENGTH + 1];
472         wait_queue_head_t       waitq;
473         rawobj_t                in_handle, in_token;
474         rawobj_t                out_handle, out_token;
475         int                     major_status, minor_status;
476 #ifdef HAVE_CACHE_HASH_SPINLOCK
477         struct rcu_head         rcu_head;
478 #endif
479 };
480
481 #ifdef HAVE_CACHE_HEAD_HLIST
482 static struct hlist_head rsi_table[RSI_HASHMAX];
483 #else
484 static struct cache_head *rsi_table[RSI_HASHMAX];
485 #endif
486 static struct cache_detail rsi_cache;
487 static struct rsi *rsi_update(struct rsi *new, struct rsi *old);
488 static struct rsi *rsi_lookup(struct rsi *item);
489
490 #ifdef HAVE_CACHE_DETAIL_WRITERS
491 static inline int channel_users(struct cache_detail *cd)
492 {
493         return atomic_read(&cd->writers);
494 }
495 #else
496 static inline int channel_users(struct cache_detail *cd)
497 {
498         return atomic_read(&cd->readers);
499 }
500 #endif
501
502 static inline int rsi_hash(struct rsi *item)
503 {
504         return hash_mem((char *)item->in_handle.data, item->in_handle.len,
505                         RSI_HASHBITS) ^
506                hash_mem((char *)item->in_token.data, item->in_token.len,
507                         RSI_HASHBITS);
508 }
509
510 static inline int __rsi_match(struct rsi *item, struct rsi *tmp)
511 {
512         return (rawobj_equal(&item->in_handle, &tmp->in_handle) &&
513                 rawobj_equal(&item->in_token, &tmp->in_token));
514 }
515
516 static void rsi_free(struct rsi *rsi)
517 {
518         rawobj_free(&rsi->in_handle);
519         rawobj_free(&rsi->in_token);
520         rawobj_free(&rsi->out_handle);
521         rawobj_free(&rsi->out_token);
522 }
523
524 /* See handle_channel_req() userspace for where the upcall data is read */
525 static void rsi_request(struct cache_detail *cd,
526                         struct cache_head *h,
527                         char **bpp, int *blen)
528 {
529         struct rsi *rsi = container_of(h, struct rsi, h);
530         __u64 index = 0;
531
532         /* if in_handle is null, provide kernel suggestion */
533         if (rsi->in_handle.len == 0)
534                 index = gss_get_next_ctx_index();
535
536         qword_addhex(bpp, blen, (char *) &rsi->lustre_svc,
537                         sizeof(rsi->lustre_svc));
538         qword_addhex(bpp, blen, (char *) &rsi->nid4, sizeof(rsi->nid4));
539         qword_addhex(bpp, blen, (char *) &index, sizeof(index));
540         qword_addhex(bpp, blen, (char *) rsi->nm_name,
541                      strlen(rsi->nm_name) + 1);
542         qword_addhex(bpp, blen, rsi->in_handle.data, rsi->in_handle.len);
543         qword_addhex(bpp, blen, rsi->in_token.data, rsi->in_token.len);
544         (*bpp)[-1] = '\n';
545 }
546
547 static inline void __rsi_init(struct rsi *new, struct rsi *item)
548 {
549         new->out_handle = RAWOBJ_EMPTY;
550         new->out_token = RAWOBJ_EMPTY;
551
552         new->in_handle = item->in_handle;
553         item->in_handle = RAWOBJ_EMPTY;
554         new->in_token = item->in_token;
555         item->in_token = RAWOBJ_EMPTY;
556
557         new->lustre_svc = item->lustre_svc;
558         new->nid4 = item->nid4;
559         memcpy(new->nm_name, item->nm_name, sizeof(item->nm_name));
560         init_waitqueue_head(&new->waitq);
561 }
562
563 static inline void __rsi_update(struct rsi *new, struct rsi *item)
564 {
565         LASSERT(new->out_handle.len == 0);
566         LASSERT(new->out_token.len == 0);
567
568         new->out_handle = item->out_handle;
569         item->out_handle = RAWOBJ_EMPTY;
570         new->out_token = item->out_token;
571         item->out_token = RAWOBJ_EMPTY;
572
573         new->major_status = item->major_status;
574         new->minor_status = item->minor_status;
575 }
576
577 #ifdef HAVE_CACHE_HASH_SPINLOCK
578 static void rsi_free_rcu(struct rcu_head *head)
579 {
580         struct rsi *rsi = container_of(head, struct rsi, rcu_head);
581
582 #ifdef HAVE_CACHE_HEAD_HLIST
583         LASSERT(hlist_unhashed(&rsi->h.cache_list));
584 #else
585         LASSERT(rsi->h.next == NULL);
586 #endif
587         rsi_free(rsi);
588         OBD_FREE_PTR(rsi);
589 }
590
591 static void rsi_put(struct kref *ref)
592 {
593         struct rsi *rsi = container_of(ref, struct rsi, h.ref);
594
595         call_rcu(&rsi->rcu_head, rsi_free_rcu);
596 }
597 #else /* !HAVE_CACHE_HASH_SPINLOCK */
598 static void rsi_put(struct kref *ref)
599 {
600         struct rsi *rsi = container_of(ref, struct rsi, h.ref);
601
602 #ifdef HAVE_CACHE_HEAD_HLIST
603         LASSERT(hlist_unhashed(&rsi->h.cache_list));
604 #else
605         LASSERT(rsi->h.next == NULL);
606 #endif
607         rsi_free(rsi);
608         OBD_FREE_PTR(rsi);
609 }
610 #endif /* HAVE_CACHE_HASH_SPINLOCK */
611
612 static int rsi_match(struct cache_head *a, struct cache_head *b)
613 {
614         struct rsi *item = container_of(a, struct rsi, h);
615         struct rsi *tmp = container_of(b, struct rsi, h);
616
617         return __rsi_match(item, tmp);
618 }
619
620 static void rsi_init(struct cache_head *cnew, struct cache_head *citem)
621 {
622         struct rsi *new = container_of(cnew, struct rsi, h);
623         struct rsi *item = container_of(citem, struct rsi, h);
624
625         __rsi_init(new, item);
626 }
627
628 static void update_rsi(struct cache_head *cnew, struct cache_head *citem)
629 {
630         struct rsi *new = container_of(cnew, struct rsi, h);
631         struct rsi *item = container_of(citem, struct rsi, h);
632
633         __rsi_update(new, item);
634 }
635
636 static struct cache_head *rsi_alloc(void)
637 {
638         struct rsi *rsi;
639
640         OBD_ALLOC_PTR(rsi);
641         if (rsi) 
642                 return &rsi->h;
643         else
644                 return NULL;
645 }
646
647 static int rsi_parse(struct cache_detail *cd, char *mesg, int mlen)
648 {
649         char *buf = mesg;
650         int len;
651         struct rsi rsii, *rsip = NULL;
652         time64_t expiry;
653         int status = -EINVAL;
654         ENTRY;
655
656         memset(&rsii, 0, sizeof(rsii));
657
658         /* handle */
659         len = qword_get(&mesg, buf, mlen);
660         if (len < 0)
661                 goto out;
662         if (rawobj_alloc(&rsii.in_handle, buf, len)) {
663                 status = -ENOMEM;
664                 goto out;
665         }
666
667         /* token */
668         len = qword_get(&mesg, buf, mlen);
669         if (len < 0)
670                 goto out;
671         if (rawobj_alloc(&rsii.in_token, buf, len)) {
672                 status = -ENOMEM;
673                 goto out;
674         }
675
676         rsip = rsi_lookup(&rsii);
677         if (!rsip)
678                 goto out;
679         if (!test_bit(CACHE_PENDING, &rsip->h.flags)) {
680                 /* If this is not a pending request, it probably means
681                  * someone wrote arbitrary data to the init channel.
682                  * Directly return -EINVAL in this case.
683                  */
684                 status = -EINVAL;
685                 goto out;
686         }
687
688         rsii.h.flags = 0;
689         /* expiry */
690         status = get_expiry(&mesg, &expiry);
691         if (status)
692                 goto out;
693
694         len = qword_get(&mesg, buf, mlen);
695         if (len <= 0)
696                 goto out;
697
698         /* major */
699         status = kstrtoint(buf, 10, &rsii.major_status);
700         if (status)
701                 goto out;
702
703         /* minor */
704         len = qword_get(&mesg, buf, mlen);
705         if (len <= 0) {
706                 status = -EINVAL;
707                 goto out;
708         }
709
710         status = kstrtoint(buf, 10, &rsii.minor_status);
711         if (status)
712                 goto out;
713
714         /* out_handle */
715         len = qword_get(&mesg, buf, mlen);
716         if (len < 0)
717                 goto out;
718         if (rawobj_alloc(&rsii.out_handle, buf, len)) {
719                 status = -ENOMEM;
720                 goto out;
721         }
722
723         /* out_token */
724         len = qword_get(&mesg, buf, mlen);
725         if (len < 0)
726                 goto out;
727         if (rawobj_alloc(&rsii.out_token, buf, len)) {
728                 status = -ENOMEM;
729                 goto out;
730         }
731
732         rsii.h.expiry_time = expiry;
733         rsip = rsi_update(&rsii, rsip);
734         status = 0;
735 out:
736         rsi_free(&rsii);
737         if (rsip) {
738                 wake_up(&rsip->waitq);
739                 cache_put(&rsip->h, &rsi_cache);
740         } else {
741                 status = -ENOMEM;
742         }
743
744         if (status)
745                 CERROR("rsi parse error %d\n", status);
746         RETURN(status);
747 }
748
749 static struct cache_detail rsi_cache = {
750         .hash_size      = RSI_HASHMAX,
751         .hash_table     = rsi_table,
752         .name           = "auth.sptlrpc.init",
753         .cache_put      = rsi_put,
754         .cache_request  = rsi_request,
755         .cache_upcall   = sunrpc_cache_pipe_upcall,
756         .cache_parse    = rsi_parse,
757         .match          = rsi_match,
758         .init           = rsi_init,
759         .update         = update_rsi,
760         .alloc          = rsi_alloc,
761 };
762
763 static struct rsi *rsi_lookup(struct rsi *item)
764 {
765         struct cache_head *ch;
766         int hash = rsi_hash(item);
767
768         ch = sunrpc_cache_lookup(&rsi_cache, &item->h, hash);
769         if (ch)
770                 return container_of(ch, struct rsi, h);
771         else
772                 return NULL;
773 }
774
775 static struct rsi *rsi_update(struct rsi *new, struct rsi *old)
776 {
777         struct cache_head *ch;
778         int hash = rsi_hash(new);
779
780         ch = sunrpc_cache_update(&rsi_cache, &new->h, &old->h, hash);
781         if (ch)
782                 return container_of(ch, struct rsi, h);
783         else
784                 return NULL;
785 }
786
787 /****************************************
788  * rpc sec context (rsc) cache          *
789  ****************************************/
790
791 #define RSC_HASHBITS    (10)
792 #define RSC_HASHMAX     (1 << RSC_HASHBITS)
793 #define RSC_HASHMASK    (RSC_HASHMAX - 1)
794
795 static void rsc_entry_init(struct upcall_cache_entry *entry,
796                            void *args)
797 {
798         struct gss_rsc *rsc = &entry->u.rsc;
799         struct gss_rsc *tmp = args;
800
801         rsc->sc_uc_entry = entry;
802         rawobj_dup(&rsc->sc_handle, &tmp->sc_handle);
803
804         rsc->sc_target = NULL;
805         memset(&rsc->sc_ctx, 0, sizeof(rsc->sc_ctx));
806         rsc->sc_ctx.gsc_rvs_hdl = RAWOBJ_EMPTY;
807
808         memset(&rsc->sc_ctx.gsc_seqdata, 0, sizeof(rsc->sc_ctx.gsc_seqdata));
809         spin_lock_init(&rsc->sc_ctx.gsc_seqdata.ssd_lock);
810 }
811
812 void __rsc_free(struct gss_rsc *rsc)
813 {
814         rawobj_free(&rsc->sc_handle);
815         rawobj_free(&rsc->sc_ctx.gsc_rvs_hdl);
816         lgss_delete_sec_context(&rsc->sc_ctx.gsc_mechctx);
817 }
818
819 static void rsc_entry_free(struct upcall_cache *cache,
820                            struct upcall_cache_entry *entry)
821 {
822         struct gss_rsc *rsc = &entry->u.rsc;
823
824         __rsc_free(rsc);
825 }
826
827 static inline int rsc_entry_hash(struct gss_rsc *rsc)
828 {
829 #if BITS_PER_LONG == 64
830         return hash_mem_64((char *)rsc->sc_handle.data,
831                            rsc->sc_handle.len, RSC_HASHBITS);
832 #else
833         return hash_mem((char *)rsc->sc_handle.data,
834                         rsc->sc_handle.len, RSC_HASHBITS);
835 #endif
836 }
837
838 static inline int __rsc_entry_match(rawobj_t *h1, rawobj_t *h2)
839 {
840         return !(rawobj_equal(h1, h2));
841 }
842
843 static inline int rsc_entry_match(struct gss_rsc *rsc, struct gss_rsc *tmp)
844 {
845         return __rsc_entry_match(&rsc->sc_handle, &tmp->sc_handle);
846 }
847
848 /* Returns 0 to tell this is a match */
849 static inline int rsc_upcall_compare(struct upcall_cache *cache,
850                                      struct upcall_cache_entry *entry,
851                                      __u64 key, void *args)
852 {
853         struct gss_rsc *rsc1 = &entry->u.rsc;
854         struct gss_rsc *rsc2 = args;
855
856         return rsc_entry_match(rsc1, rsc2);
857 }
858
859 /* rsc upcall is a no-op, we just need a valid entry */
860 static inline int rsc_do_upcall(struct upcall_cache *cache,
861                                 struct upcall_cache_entry *entry)
862 {
863         upcall_cache_update_entry(cache, entry,
864                                   ktime_get_seconds() + cache->uc_entry_expire,
865                                   0);
866         return 0;
867 }
868
869 static inline int rsc_downcall_compare(struct upcall_cache *cache,
870                                        struct upcall_cache_entry *entry,
871                                        __u64 key, void *args)
872 {
873         struct gss_rsc *rsc = &entry->u.rsc;
874         struct rsc_downcall_data *scd = args;
875         char *mesg = scd->scd_val;
876         rawobj_t handle;
877         int len;
878
879         /* scd_val starts with handle */
880         len = gss_buffer_get(&mesg, &handle.len, &handle.data);
881         scd->scd_offset = mesg - scd->scd_val;
882
883         return __rsc_entry_match(&rsc->sc_handle, &handle);
884 }
885
886 static int rsc_parse_downcall(struct upcall_cache *cache,
887                               struct upcall_cache_entry *entry,
888                               void *args)
889 {
890         struct gss_api_mech *gm = NULL;
891         struct gss_rsc *rsc = &entry->u.rsc;
892         struct rsc_downcall_data *scd = args;
893         int mlen = scd->scd_len;
894         char *mesg = scd->scd_val + scd->scd_offset;
895         char *buf = scd->scd_val;
896         int status = -EINVAL;
897         time64_t ctx_expiry;
898         rawobj_t tmp_buf;
899         int len;
900
901         ENTRY;
902
903         if (mlen <= 0)
904                 goto out;
905
906         rsc->sc_ctx.gsc_remote = !!(scd->scd_flags & RSC_DATA_FLAG_REMOTE);
907         rsc->sc_ctx.gsc_usr_root = !!(scd->scd_flags & RSC_DATA_FLAG_ROOT);
908         rsc->sc_ctx.gsc_usr_mds = !!(scd->scd_flags & RSC_DATA_FLAG_MDS);
909         rsc->sc_ctx.gsc_usr_oss = !!(scd->scd_flags & RSC_DATA_FLAG_OSS);
910         rsc->sc_ctx.gsc_mapped_uid = scd->scd_mapped_uid;
911         rsc->sc_ctx.gsc_uid = scd->scd_uid;
912
913         rsc->sc_ctx.gsc_gid = scd->scd_gid;
914         gm = lgss_name_to_mech(scd->scd_mechname);
915         if (!gm) {
916                 status = -EOPNOTSUPP;
917                 goto out;
918         }
919
920         /* handle has already been consumed in rsc_downcall_compare().
921          * scd_offset gives next field.
922          */
923
924         /* context token */
925         len = gss_buffer_read(&mesg, buf, mlen);
926         if (len < 0)
927                 goto out;
928         tmp_buf.len = len;
929         tmp_buf.data = (unsigned char *)buf;
930         if (lgss_import_sec_context(&tmp_buf, gm,
931                                     &rsc->sc_ctx.gsc_mechctx))
932                 goto out;
933
934         if (lgss_inquire_context(rsc->sc_ctx.gsc_mechctx, &ctx_expiry))
935                 goto out;
936
937         /* ctx_expiry is the number of seconds since Jan 1 1970.
938          * We just want the number of seconds into the future.
939          */
940         entry->ue_expire = ktime_get_seconds() +
941                 (ctx_expiry - ktime_get_real_seconds());
942         status = 0;
943
944 out:
945         if (gm)
946                 lgss_mech_put(gm);
947         CDEBUG(D_OTHER, "rsc parse %p: %d\n", rsc, status);
948         RETURN(status);
949 }
950
951 struct gss_rsc *rsc_entry_get(struct upcall_cache *cache, struct gss_rsc *rsc)
952 {
953         struct upcall_cache_entry *entry;
954         int hash = rsc_entry_hash(rsc);
955
956         if (!cache)
957                 return ERR_PTR(-ENOENT);
958
959         entry = upcall_cache_get_entry(cache, (__u64)hash, rsc);
960         if (unlikely(!entry))
961                 return ERR_PTR(-ENOENT);
962         if (IS_ERR(entry))
963                 return ERR_CAST(entry);
964
965         return &entry->u.rsc;
966 }
967
968 void rsc_entry_put(struct upcall_cache *cache, struct gss_rsc *rsc)
969 {
970         if (!cache || !rsc)
971                 return;
972
973         upcall_cache_put_entry(cache, rsc->sc_uc_entry);
974 }
975
976 void rsc_flush(struct upcall_cache *cache, int hash)
977 {
978         if (hash < 0)
979                 upcall_cache_flush_idle(cache);
980         else
981                 upcall_cache_flush_one(cache, (__u64)hash, NULL);
982 }
983
984 struct upcall_cache_ops rsc_upcall_cache_ops = {
985         .init_entry       = rsc_entry_init,
986         .free_entry       = rsc_entry_free,
987         .upcall_compare   = rsc_upcall_compare,
988         .downcall_compare = rsc_downcall_compare,
989         .do_upcall        = rsc_do_upcall,
990         .parse_downcall   = rsc_parse_downcall,
991 };
992
993 struct upcall_cache *rsccache;
994
995 struct rsc {
996         struct cache_head       h;
997         struct obd_device      *target;
998         rawobj_t                handle;
999         struct gss_svc_ctx      ctx;
1000 #ifdef HAVE_CACHE_HASH_SPINLOCK
1001         struct rcu_head         rcu_head;
1002 #endif
1003 };
1004
1005 #ifdef HAVE_CACHE_HEAD_HLIST
1006 static struct hlist_head rsc_table[RSC_HASHMAX];
1007 #else
1008 static struct cache_head *rsc_table[RSC_HASHMAX];
1009 #endif
1010 static struct cache_detail rsc_cache;
1011 static struct rsc *rsc_update(struct rsc *new, struct rsc *old);
1012 static struct rsc *rsc_lookup(struct rsc *item);
1013
1014 static void rsc_free(struct rsc *rsci)
1015 {
1016         rawobj_free(&rsci->handle);
1017         rawobj_free(&rsci->ctx.gsc_rvs_hdl);
1018         lgss_delete_sec_context(&rsci->ctx.gsc_mechctx);
1019 }
1020
1021 static inline int rsc_hash(struct rsc *rsci)
1022 {
1023         return hash_mem((char *)rsci->handle.data,
1024                         rsci->handle.len, RSC_HASHBITS);
1025 }
1026
1027 static inline int __rsc_match(struct rsc *new, struct rsc *tmp)
1028 {
1029         return rawobj_equal(&new->handle, &tmp->handle);
1030 }
1031
1032 static inline void __rsc_init(struct rsc *new, struct rsc *tmp)
1033 {
1034         new->handle = tmp->handle;
1035         tmp->handle = RAWOBJ_EMPTY;
1036
1037         new->target = NULL;
1038         memset(&new->ctx, 0, sizeof(new->ctx));
1039         new->ctx.gsc_rvs_hdl = RAWOBJ_EMPTY;
1040 }
1041
1042 static inline void __rsc_update(struct rsc *new, struct rsc *tmp)
1043 {
1044         new->ctx = tmp->ctx;
1045         memset(&tmp->ctx, 0, sizeof(tmp->ctx));
1046         tmp->ctx.gsc_rvs_hdl = RAWOBJ_EMPTY;
1047         tmp->ctx.gsc_mechctx = NULL;
1048         tmp->target = NULL;
1049
1050         memset(&new->ctx.gsc_seqdata, 0, sizeof(new->ctx.gsc_seqdata));
1051         spin_lock_init(&new->ctx.gsc_seqdata.ssd_lock);
1052 }
1053
1054 #ifdef HAVE_CACHE_HASH_SPINLOCK
1055 static void rsc_free_rcu(struct rcu_head *head)
1056 {
1057         struct rsc *rsci = container_of(head, struct rsc, rcu_head);
1058
1059 #ifdef HAVE_CACHE_HEAD_HLIST
1060         LASSERT(hlist_unhashed(&rsci->h.cache_list));
1061 #else
1062         LASSERT(rsci->h.next == NULL);
1063 #endif
1064         rawobj_free(&rsci->handle);
1065         OBD_FREE_PTR(rsci);
1066 }
1067
1068 static void rsc_put(struct kref *ref)
1069 {
1070         struct rsc *rsci = container_of(ref, struct rsc, h.ref);
1071
1072         rawobj_free(&rsci->ctx.gsc_rvs_hdl);
1073         lgss_delete_sec_context(&rsci->ctx.gsc_mechctx);
1074         call_rcu(&rsci->rcu_head, rsc_free_rcu);
1075 }
1076 #else /* !HAVE_CACHE_HASH_SPINLOCK */
1077 static void rsc_put(struct kref *ref)
1078 {
1079         struct rsc *rsci = container_of(ref, struct rsc, h.ref);
1080
1081 #ifdef HAVE_CACHE_HEAD_HLIST
1082         LASSERT(hlist_unhashed(&rsci->h.cache_list));
1083 #else
1084         LASSERT(rsci->h.next == NULL);
1085 #endif
1086         rsc_free(rsci);
1087         OBD_FREE_PTR(rsci);
1088 }
1089 #endif /* HAVE_CACHE_HASH_SPINLOCK */
1090
1091 static int rsc_match(struct cache_head *a, struct cache_head *b)
1092 {
1093         struct rsc *new = container_of(a, struct rsc, h);
1094         struct rsc *tmp = container_of(b, struct rsc, h);
1095
1096         return __rsc_match(new, tmp);
1097 }
1098
1099 static void rsc_init(struct cache_head *cnew, struct cache_head *ctmp)
1100 {
1101         struct rsc *new = container_of(cnew, struct rsc, h);
1102         struct rsc *tmp = container_of(ctmp, struct rsc, h);
1103
1104         __rsc_init(new, tmp);
1105 }
1106
1107 static void update_rsc(struct cache_head *cnew, struct cache_head *ctmp)
1108 {
1109         struct rsc *new = container_of(cnew, struct rsc, h);
1110         struct rsc *tmp = container_of(ctmp, struct rsc, h);
1111
1112         __rsc_update(new, tmp);
1113 }
1114
1115 static struct cache_head * rsc_alloc(void)
1116 {
1117         struct rsc *rsc;
1118
1119         OBD_ALLOC_PTR(rsc);
1120         if (rsc)
1121                 return &rsc->h;
1122         else
1123                 return NULL;
1124 }
1125
1126 static int rsc_parse(struct cache_detail *cd, char *mesg, int mlen)
1127 {
1128         char *buf = mesg;
1129         int len, rv, tmp_int;
1130         struct rsc rsci, *rscp = NULL;
1131         time64_t expiry;
1132         int status = -EINVAL;
1133         struct gss_api_mech *gm = NULL;
1134
1135         memset(&rsci, 0, sizeof(rsci));
1136
1137         /* context handle */
1138         len = qword_get(&mesg, buf, mlen);
1139         if (len < 0)
1140                 goto out;
1141
1142         status = -ENOMEM;
1143         if (rawobj_alloc(&rsci.handle, buf, len))
1144                 goto out;
1145
1146         rsci.h.flags = 0;
1147         /* expiry */
1148         status = get_expiry(&mesg, &expiry);
1149         if (status)
1150                 goto out;
1151
1152         status = -EINVAL;
1153         /* remote flag */
1154         rv = get_int(&mesg, &tmp_int);
1155         if (rv) {
1156                 CERROR("fail to get remote flag\n");
1157                 goto out;
1158         }
1159         rsci.ctx.gsc_remote = (tmp_int != 0);
1160
1161         /* root user flag */
1162         rv = get_int(&mesg, &tmp_int);
1163         if (rv) {
1164                 CERROR("fail to get root user flag\n");
1165                 goto out;
1166         }
1167         rsci.ctx.gsc_usr_root = (tmp_int != 0);
1168
1169         /* mds user flag */
1170         rv = get_int(&mesg, &tmp_int);
1171         if (rv) {
1172                 CERROR("fail to get mds user flag\n");
1173                 goto out;
1174         }
1175         rsci.ctx.gsc_usr_mds = (tmp_int != 0);
1176
1177         /* oss user flag */
1178         rv = get_int(&mesg, &tmp_int);
1179         if (rv) {
1180                 CERROR("fail to get oss user flag\n");
1181                 goto out;
1182         }
1183         rsci.ctx.gsc_usr_oss = (tmp_int != 0);
1184
1185         /* mapped uid */
1186         rv = get_int(&mesg, (int *) &rsci.ctx.gsc_mapped_uid);
1187         if (rv) {
1188                 CERROR("fail to get mapped uid\n");
1189                 goto out;
1190         }
1191
1192         rscp = rsc_lookup(&rsci);
1193         if (!rscp)
1194                 goto out;
1195
1196         /* uid, or NEGATIVE */
1197         rv = get_int(&mesg, (int *) &rsci.ctx.gsc_uid);
1198         if (rv == -EINVAL)
1199                 goto out;
1200         if (rv == -ENOENT) {
1201                 CERROR("NOENT? set rsc entry negative\n");
1202                 set_bit(CACHE_NEGATIVE, &rsci.h.flags);
1203         } else {
1204                 rawobj_t tmp_buf;
1205                 time64_t ctx_expiry;
1206
1207                 /* gid */
1208                 if (get_int(&mesg, (int *) &rsci.ctx.gsc_gid))
1209                         goto out;
1210
1211                 /* mech name */
1212                 len = qword_get(&mesg, buf, mlen);
1213                 if (len < 0)
1214                         goto out;
1215                 gm = lgss_name_to_mech(buf);
1216                 status = -EOPNOTSUPP;
1217                 if (!gm)
1218                         goto out;
1219
1220                 status = -EINVAL;
1221                 /* mech-specific data: */
1222                 len = qword_get(&mesg, buf, mlen);
1223                 if (len < 0)
1224                         goto out;
1225
1226                 tmp_buf.len = len;
1227                 tmp_buf.data = (unsigned char *)buf;
1228                 if (lgss_import_sec_context(&tmp_buf, gm,
1229                                             &rsci.ctx.gsc_mechctx))
1230                         goto out;
1231
1232                 /* set to seconds since machine booted */
1233                 expiry = ktime_get_seconds();
1234
1235                 /* currently the expiry time passed down from user-space
1236                  * is invalid, here we retrive it from mech.
1237                  */
1238                 if (lgss_inquire_context(rsci.ctx.gsc_mechctx, &ctx_expiry)) {
1239                         CERROR("unable to get expire time, drop it\n");
1240                         goto out;
1241                 }
1242
1243                 /* ctx_expiry is the number of seconds since Jan 1 1970.
1244                  * We want just the  number of seconds into the future.
1245                  */
1246                 expiry += ctx_expiry - ktime_get_real_seconds();
1247         }
1248
1249         rsci.h.expiry_time = expiry;
1250         rscp = rsc_update(&rsci, rscp);
1251         status = 0;
1252 out:
1253         if (gm)
1254                 lgss_mech_put(gm);
1255         rsc_free(&rsci);
1256         if (rscp)
1257                 cache_put(&rscp->h, &rsc_cache);
1258         else
1259                 status = -ENOMEM;
1260
1261         if (status)
1262                 CERROR("parse rsc error %d\n", status);
1263         return status;
1264 }
1265
1266 static struct cache_detail rsc_cache = {
1267         .hash_size      = RSC_HASHMAX,
1268         .hash_table     = rsc_table,
1269         .name           = "auth.sptlrpc.context",
1270         .cache_put      = rsc_put,
1271         .cache_parse    = rsc_parse,
1272         .match          = rsc_match,
1273         .init           = rsc_init,
1274         .update         = update_rsc,
1275         .alloc          = rsc_alloc,
1276 };
1277
1278 static struct rsc *rsc_lookup(struct rsc *item)
1279 {
1280         struct cache_head *ch;
1281         int                hash = rsc_hash(item);
1282
1283         ch = sunrpc_cache_lookup(&rsc_cache, &item->h, hash);
1284         if (ch)
1285                 return container_of(ch, struct rsc, h);
1286         else
1287                 return NULL;
1288 }
1289
1290 static struct rsc *rsc_update(struct rsc *new, struct rsc *old)
1291 {
1292         struct cache_head *ch;
1293         int                hash = rsc_hash(new);
1294
1295         ch = sunrpc_cache_update(&rsc_cache, &new->h, &old->h, hash);
1296         if (ch)
1297                 return container_of(ch, struct rsc, h);
1298         else
1299                 return NULL;
1300 }
1301
1302 #define COMPAT_RSC_PUT(item, cd)        cache_put((item), (cd))
1303
1304 /****************************************
1305  * rsc cache flush                      *
1306  ****************************************/
1307
1308 static struct gss_rsc *gss_svc_searchbyctx(rawobj_t *handle)
1309 {
1310         struct gss_rsc rsc;
1311         struct gss_rsc *found;
1312
1313         memset(&rsc, 0, sizeof(rsc));
1314         if (rawobj_dup(&rsc.sc_handle, handle))
1315                 return NULL;
1316
1317         found = rsc_entry_get(rsccache, &rsc);
1318         __rsc_free(&rsc);
1319         if (IS_ERR_OR_NULL(found))
1320                 return found;
1321         if (!found->sc_ctx.gsc_mechctx) {
1322                 rsc_entry_put(rsccache, found);
1323                 return ERR_PTR(-ENOENT);
1324         }
1325         return found;
1326 }
1327
1328 int gss_svc_upcall_install_rvs_ctx(struct obd_import *imp,
1329                                    struct gss_sec *gsec,
1330                                    struct gss_cli_ctx *gctx)
1331 {
1332         struct gss_rsc rsc, *rscp = NULL;
1333         time64_t ctx_expiry;
1334         __u32 major;
1335         int rc;
1336
1337         ENTRY;
1338         memset(&rsc, 0, sizeof(rsc));
1339
1340         if (!imp || !imp->imp_obd) {
1341                 CERROR("invalid imp, drop\n");
1342                 RETURN(-EPROTO);
1343         }
1344
1345         if (rawobj_alloc(&rsc.sc_handle, (char *)&gsec->gs_rvs_hdl,
1346                          sizeof(gsec->gs_rvs_hdl)))
1347                 GOTO(out, rc = -ENOMEM);
1348
1349         rscp = rsc_entry_get(rsccache, &rsc);
1350         __rsc_free(&rsc);
1351         if (IS_ERR_OR_NULL(rscp))
1352                 GOTO(out, rc = -ENOMEM);
1353
1354         major = lgss_copy_reverse_context(gctx->gc_mechctx,
1355                                           &rscp->sc_ctx.gsc_mechctx);
1356         if (major != GSS_S_COMPLETE)
1357                 GOTO(out, rc = -ENOMEM);
1358
1359         if (lgss_inquire_context(rscp->sc_ctx.gsc_mechctx, &ctx_expiry)) {
1360                 CERROR("%s: unable to get expire time, drop\n",
1361                        imp->imp_obd->obd_name);
1362                 GOTO(out, rc = -EINVAL);
1363         }
1364         rscp->sc_uc_entry->ue_expire = ktime_get_seconds() +
1365                 (ctx_expiry - ktime_get_real_seconds());
1366
1367         switch (imp->imp_obd->u.cli.cl_sp_to) {
1368         case LUSTRE_SP_MDT:
1369                 rscp->sc_ctx.gsc_usr_mds = 1;
1370                 break;
1371         case LUSTRE_SP_OST:
1372                 rscp->sc_ctx.gsc_usr_oss = 1;
1373                 break;
1374         case LUSTRE_SP_CLI:
1375                 rscp->sc_ctx.gsc_usr_root = 1;
1376                 break;
1377         case LUSTRE_SP_MGS:
1378                 /* by convention, all 3 set to 1 means MGS */
1379                 rscp->sc_ctx.gsc_usr_mds = 1;
1380                 rscp->sc_ctx.gsc_usr_oss = 1;
1381                 rscp->sc_ctx.gsc_usr_root = 1;
1382                 break;
1383         default:
1384                 break;
1385         }
1386
1387         rscp->sc_target = imp->imp_obd;
1388         rawobj_dup(&gctx->gc_svc_handle, &rscp->sc_handle);
1389
1390         CDEBUG(D_SEC, "%s: create reverse svc ctx %p to %s: idx %#llx\n",
1391                imp->imp_obd->obd_name, &rscp->sc_ctx, obd2cli_tgt(imp->imp_obd),
1392                gsec->gs_rvs_hdl);
1393         rc = 0;
1394 out:
1395         if (!IS_ERR_OR_NULL(rscp))
1396                 rsc_entry_put(rsccache, rscp);
1397         if (rc)
1398                 CERROR("%s: can't create reverse svc ctx idx %#llx: rc = %d\n",
1399                        imp->imp_obd->obd_name, gsec->gs_rvs_hdl, rc);
1400         RETURN(rc);
1401 }
1402
1403 int gss_svc_upcall_expire_rvs_ctx(rawobj_t *handle)
1404 {
1405         const time64_t expire = 20;
1406         struct gss_rsc *rscp;
1407
1408         rscp = gss_svc_searchbyctx(handle);
1409         if (!IS_ERR_OR_NULL(rscp)) {
1410                 CDEBUG(D_SEC,
1411                        "reverse svcctx %p (rsc %p) expire in %lld seconds\n",
1412                        &rscp->sc_ctx, rscp, expire);
1413
1414                 rscp->sc_uc_entry->ue_expire = ktime_get_seconds() + expire;
1415                 rsc_entry_put(rsccache, rscp);
1416         }
1417         return 0;
1418 }
1419
1420 int gss_svc_upcall_dup_handle(rawobj_t *handle, struct gss_svc_ctx *ctx)
1421 {
1422         struct gss_rsc *rscp = container_of(ctx, struct gss_rsc, sc_ctx);
1423
1424         return rawobj_dup(handle, &rscp->sc_handle);
1425 }
1426
1427 int gss_svc_upcall_update_sequence(rawobj_t *handle, __u32 seq)
1428 {
1429         struct gss_rsc *rscp;
1430
1431         rscp = gss_svc_searchbyctx(handle);
1432         if (!IS_ERR_OR_NULL(rscp)) {
1433                 CDEBUG(D_SEC, "reverse svcctx %p (rsc %p) update seq to %u\n",
1434                        &rscp->sc_ctx, rscp, seq + 1);
1435
1436                 rscp->sc_ctx.gsc_rvs_seq = seq + 1;
1437                 rsc_entry_put(rsccache, rscp);
1438         }
1439         return 0;
1440 }
1441
1442 int gss_svc_upcall_handle_init(struct ptlrpc_request *req,
1443                                struct gss_svc_reqctx *grctx,
1444                                struct gss_wire_ctx *gw,
1445                                struct obd_device *target,
1446                                __u32 lustre_svc,
1447                                rawobj_t *rvs_hdl,
1448                                rawobj_t *in_token)
1449 {
1450         struct gss_rsi rsi = { 0 }, *rsip = NULL;
1451         struct ptlrpc_reply_state *rs;
1452         struct gss_rsc *rscp = NULL;
1453         int replen = sizeof(struct ptlrpc_body);
1454         struct gss_rep_header *rephdr;
1455         int rc, rc2;
1456
1457         ENTRY;
1458
1459         rsi.si_lustre_svc = lustre_svc;
1460         /* In case of MR, rq_peer is not the NID from which request is received,
1461          * but primary NID of peer.
1462          * So we need LNetPrimaryNID(rq_source) to match what the clients uses.
1463          */
1464         LNetPrimaryNID(&req->rq_source.nid);
1465         rsi.si_nid4 = lnet_nid_to_nid4(&req->rq_source.nid);
1466         nodemap_test_nid(lnet_nid_to_nid4(&req->rq_peer.nid), rsi.si_nm_name,
1467                          sizeof(rsi.si_nm_name));
1468
1469         /* Note that context handle is always 0 for for INIT. */
1470         rc2 = rawobj_dup(&rsi.si_in_handle, &gw->gw_handle);
1471         if (rc2) {
1472                 CERROR("%s: failed to duplicate context handle: rc = %d\n",
1473                        target->obd_name, rc2);
1474                 GOTO(out, rc = SECSVC_DROP);
1475         }
1476
1477         rc2 = rawobj_dup(&rsi.si_in_token, in_token);
1478         if (rc2) {
1479                 CERROR("%s: failed to duplicate token: rc = %d\n",
1480                        target->obd_name, rc2);
1481                 rawobj_free(&rsi.si_in_handle);
1482                 GOTO(out, rc = SECSVC_DROP);
1483         }
1484
1485         rsip = rsi_entry_get(rsicache, &rsi);
1486         __rsi_free(&rsi);
1487         if (IS_ERR_OR_NULL(rsip)) {
1488                 if (IS_ERR(rsip))
1489                         rc2 = PTR_ERR(rsip);
1490                 else
1491                         rc2 = -EINVAL;
1492                 CERROR("%s: failed to get entry from rsi cache (nid %s): rc = %d\n",
1493                        target->obd_name,
1494                        libcfs_nid2str(lnet_nid_to_nid4(&req->rq_source.nid)),
1495                        rc2);
1496
1497                 if (!gss_pack_err_notify(req, GSS_S_FAILURE, 0))
1498                         rc = SECSVC_COMPLETE;
1499                 else
1500                         rc = SECSVC_DROP;
1501
1502                 GOTO(out, rc);
1503         }
1504
1505         rscp = gss_svc_searchbyctx(&rsip->si_out_handle);
1506         if (IS_ERR_OR_NULL(rscp)) {
1507                 /* gss mechanism returned major and minor code so we return
1508                  * those in error message */
1509
1510                 if (!gss_pack_err_notify(req, rsip->si_major_status,
1511                                          rsip->si_minor_status))
1512                         rc = SECSVC_COMPLETE;
1513                 else
1514                         rc = SECSVC_DROP;
1515
1516                 CERROR("%s: authentication failed: rc = %d\n",
1517                        target->obd_name, rc);
1518                 GOTO(out, rc);
1519         } else {
1520                 /* we need to take an extra ref on the cache entry,
1521                  * as a pointer to sc_ctx is stored in grctx
1522                  */
1523                 upcall_cache_get_entry_raw(rscp->sc_uc_entry);
1524                 grctx->src_ctx = &rscp->sc_ctx;
1525         }
1526
1527         if (gw->gw_flags & LUSTRE_GSS_PACK_KCSUM) {
1528                 grctx->src_ctx->gsc_mechctx->hash_func = gss_digest_hash;
1529         } else if (!strcmp(grctx->src_ctx->gsc_mechctx->mech_type->gm_name,
1530                            "krb5") &&
1531                    !krb5_allow_old_client_csum) {
1532                 CWARN("%s: deny connection from '%s' due to missing 'krb_csum' feature, set 'sptlrpc.gss.krb5_allow_old_client_csum=1' to allow, but recommend client upgrade: rc = %d\n",
1533                       target->obd_name, libcfs_nidstr(&req->rq_peer.nid),
1534                       -EPROTO);
1535                 GOTO(out, rc = SECSVC_DROP);
1536         } else {
1537                 grctx->src_ctx->gsc_mechctx->hash_func =
1538                         gss_digest_hash_compat;
1539         }
1540
1541         if (rawobj_dup(&rscp->sc_ctx.gsc_rvs_hdl, rvs_hdl)) {
1542                 CERROR("%s: failed duplicate reverse handle\n",
1543                        target->obd_name);
1544                 GOTO(out, rc = SECSVC_DROP);
1545         }
1546
1547         rscp->sc_target = target;
1548
1549         CDEBUG(D_SEC, "%s: server create rsc %p(%u->%s)\n",
1550                target->obd_name, rscp, rscp->sc_ctx.gsc_uid,
1551                libcfs_nidstr(&req->rq_peer.nid));
1552
1553         if (rsip->si_out_handle.len > PTLRPC_GSS_MAX_HANDLE_SIZE) {
1554                 CERROR("%s: handle size %u too large\n",
1555                        target->obd_name, rsip->si_out_handle.len);
1556                 GOTO(out, rc = SECSVC_DROP);
1557         }
1558
1559         grctx->src_init = 1;
1560         grctx->src_reserve_len = round_up(rsip->si_out_token.len, 4);
1561
1562         rc = lustre_pack_reply_v2(req, 1, &replen, NULL, 0);
1563         if (rc) {
1564                 CERROR("%s: failed to pack reply: rc = %d\n",
1565                        target->obd_name, rc);
1566                 GOTO(out, rc = SECSVC_DROP);
1567         }
1568
1569         rs = req->rq_reply_state;
1570         LASSERT(rs->rs_repbuf->lm_bufcount == 3);
1571         LASSERT(rs->rs_repbuf->lm_buflens[0] >=
1572                 sizeof(*rephdr) + rsip->si_out_handle.len);
1573         LASSERT(rs->rs_repbuf->lm_buflens[2] >= rsip->si_out_token.len);
1574
1575         rephdr = lustre_msg_buf(rs->rs_repbuf, 0, 0);
1576         rephdr->gh_version = PTLRPC_GSS_VERSION;
1577         rephdr->gh_flags = 0;
1578         rephdr->gh_proc = PTLRPC_GSS_PROC_ERR;
1579         rephdr->gh_major = rsip->si_major_status;
1580         rephdr->gh_minor = rsip->si_minor_status;
1581         rephdr->gh_seqwin = GSS_SEQ_WIN;
1582         rephdr->gh_handle.len = rsip->si_out_handle.len;
1583         memcpy(rephdr->gh_handle.data, rsip->si_out_handle.data,
1584                rsip->si_out_handle.len);
1585
1586         memcpy(lustre_msg_buf(rs->rs_repbuf, 2, 0), rsip->si_out_token.data,
1587                rsip->si_out_token.len);
1588
1589         rs->rs_repdata_len = lustre_shrink_msg(rs->rs_repbuf, 2,
1590                                                rsip->si_out_token.len, 0);
1591
1592         rc = SECSVC_OK;
1593
1594 out:
1595         if (!IS_ERR_OR_NULL(rsip))
1596                 rsi_entry_put(rsicache, rsip);
1597         if (!IS_ERR_OR_NULL(rscp)) {
1598                 /* if anything went wrong, we don't keep the context too */
1599                 if (rc != SECSVC_OK)
1600                         UC_CACHE_SET_INVALID(rscp->sc_uc_entry);
1601                 else
1602                         CDEBUG(D_SEC, "%s: create rsc with idx %#llx\n",
1603                                target->obd_name,
1604                                gss_handle_to_u64(&rscp->sc_handle));
1605
1606                 rsc_entry_put(rsccache, rscp);
1607         }
1608         RETURN(rc);
1609 }
1610
1611 struct gss_svc_ctx *gss_svc_upcall_get_ctx(struct ptlrpc_request *req,
1612                                            struct gss_wire_ctx *gw)
1613 {
1614         struct gss_rsc *rscp;
1615
1616         rscp = gss_svc_searchbyctx(&gw->gw_handle);
1617         if (IS_ERR_OR_NULL(rscp)) {
1618                 CWARN("Invalid gss ctx idx %#llx from %s\n",
1619                       gss_handle_to_u64(&gw->gw_handle),
1620                       libcfs_nidstr(&req->rq_peer.nid));
1621                 return NULL;
1622         }
1623
1624         return &rscp->sc_ctx;
1625 }
1626
1627 void gss_svc_upcall_put_ctx(struct gss_svc_ctx *ctx)
1628 {
1629         struct gss_rsc *rscp = container_of(ctx, struct gss_rsc, sc_ctx);
1630
1631         rsc_entry_put(rsccache, rscp);
1632 }
1633
1634 void gss_svc_upcall_destroy_ctx(struct gss_svc_ctx *ctx)
1635 {
1636         struct gss_rsc *rscp = container_of(ctx, struct gss_rsc, sc_ctx);
1637
1638         UC_CACHE_SET_INVALID(rscp->sc_uc_entry);
1639         rscp->sc_uc_entry->ue_expire = 1;
1640 }
1641
1642 /* Wait for userspace daemon to open socket, approx 1.5 s.
1643  * If socket is not open, upcall requests might fail.
1644  */
1645 static int check_gssd_socket(void)
1646 {
1647         struct sockaddr_un *sun;
1648         struct socket *sock;
1649         int tries = 0;
1650         int err;
1651
1652 #ifdef HAVE_SOCK_CREATE_KERN_USE_NET
1653         err = sock_create_kern(current->nsproxy->net_ns,
1654                                AF_UNIX, SOCK_STREAM, 0, &sock);
1655 #else
1656         err = sock_create_kern(AF_UNIX, SOCK_STREAM, 0, &sock);
1657 #endif
1658         if (err < 0) {
1659                 CDEBUG(D_SEC, "Failed to create socket: %d\n", err);
1660                 return err;
1661         }
1662
1663         OBD_ALLOC(sun, sizeof(*sun));
1664         if (!sun) {
1665                 sock_release(sock);
1666                 return -ENOMEM;
1667         }
1668         memset(sun, 0, sizeof(*sun));
1669         sun->sun_family = AF_UNIX;
1670         strncpy(sun->sun_path, GSS_SOCKET_PATH, sizeof(sun->sun_path));
1671
1672         /* Try to connect to the socket */
1673         while (tries++ < 6) {
1674                 err = kernel_connect(sock, (struct sockaddr *)sun,
1675                                      sizeof(*sun), 0);
1676                 if (!err)
1677                         break;
1678                 schedule_timeout_uninterruptible(cfs_time_seconds(1) / 4);
1679         }
1680         if (err < 0)
1681                 CDEBUG(D_SEC, "Failed to connect to socket: %d\n", err);
1682         else
1683                 kernel_sock_shutdown(sock, SHUT_RDWR);
1684
1685         sock_release(sock);
1686         OBD_FREE(sun, sizeof(*sun));
1687         return err;
1688 }
1689
1690 int __init gss_init_svc_upcall(void)
1691 {
1692         int rc;
1693
1694         /*
1695          * this helps reducing context index confliction. after server reboot,
1696          * conflicting request from clients might be filtered out by initial
1697          * sequence number checking, thus no chance to sent error notification
1698          * back to clients.
1699          */
1700         get_random_bytes(&__ctx_index, sizeof(__ctx_index));
1701
1702 #ifdef HAVE_CACHE_HEAD_HLIST
1703         for (rc = 0; rc < rsi_cache.hash_size; rc++)
1704                 INIT_HLIST_HEAD(&rsi_cache.hash_table[rc]);
1705 #endif
1706         rc = cache_register_net(&rsi_cache, &init_net);
1707         if (rc != 0)
1708                 return rc;
1709
1710 #ifdef HAVE_CACHE_HEAD_HLIST
1711         for (rc = 0; rc < rsc_cache.hash_size; rc++)
1712                 INIT_HLIST_HEAD(&rsc_cache.hash_table[rc]);
1713 #endif
1714         rc = cache_register_net(&rsc_cache, &init_net);
1715         if (rc != 0) {
1716                 cache_unregister_net(&rsi_cache, &init_net);
1717                 return rc;
1718         }
1719
1720         rsicache = upcall_cache_init(RSI_CACHE_NAME, RSI_UPCALL_PATH,
1721                                      UC_RSICACHE_HASH_SIZE,
1722                                      3600, /* entry expire: 1 h */
1723                                      30, /* acquire expire: 30 s */
1724                                      false, /* can't replay acquire */
1725                                      &rsi_upcall_cache_ops);
1726         if (IS_ERR(rsicache)) {
1727                 rc = PTR_ERR(rsicache);
1728                 rsicache = NULL;
1729                 return rc;
1730         }
1731         rsccache = upcall_cache_init(RSC_CACHE_NAME, RSC_UPCALL_PATH,
1732                                      UC_RSCCACHE_HASH_SIZE,
1733                                      3600, /* replaced with one from mech */
1734                                      100, /* arbitrary, not used */
1735                                      false, /* can't replay acquire */
1736                                      &rsc_upcall_cache_ops);
1737         if (IS_ERR(rsccache)) {
1738                 upcall_cache_cleanup(rsicache);
1739                 rsicache = NULL;
1740                 rc = PTR_ERR(rsccache);
1741                 rsccache = NULL;
1742                 return rc;
1743         }
1744
1745         if (check_gssd_socket())
1746                 CDEBUG(D_SEC,
1747                        "Init channel not opened by lsvcgssd, GSS might not work on server side until daemon is active\n");
1748
1749         return 0;
1750 }
1751
1752 void gss_exit_svc_upcall(void)
1753 {
1754         cache_purge(&rsi_cache);
1755         cache_unregister_net(&rsi_cache, &init_net);
1756
1757         cache_purge(&rsc_cache);
1758         cache_unregister_net(&rsc_cache, &init_net);
1759
1760         upcall_cache_cleanup(rsicache);
1761         upcall_cache_cleanup(rsccache);
1762 }