Whamcloud - gitweb
LU-3289 gss: Add userspace support for GSS null and sk
[fs/lustre-release.git] / lustre / utils / gss / svcgssd_proc.c
1 /*
2   svc_in_gssd_proc.c
3
4   Copyright (c) 2000 The Regents of the University of Michigan.
5   All rights reserved.
6
7   Copyright (c) 2002 Bruce Fields <bfields@UMICH.EDU>
8
9   Copyright (c) 2014, Intel Corporation.
10
11   Redistribution and use in source and binary forms, with or without
12   modification, are permitted provided that the following conditions
13   are met:
14
15   1. Redistributions of source code must retain the above copyright
16      notice, this list of conditions and the following disclaimer.
17   2. Redistributions in binary form must reproduce the above copyright
18      notice, this list of conditions and the following disclaimer in the
19      documentation and/or other materials provided with the distribution.
20   3. Neither the name of the University nor the names of its
21      contributors may be used to endorse or promote products derived
22      from this software without specific prior written permission.
23
24   THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
25   WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
26   MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27   DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
28   FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
29   CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
30   SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
31   BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
32   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
33   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
34   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
36 */
37
38 #include <sys/param.h>
39 #include <sys/stat.h>
40
41 #include <inttypes.h>
42 #include <pwd.h>
43 #include <stdio.h>
44 #include <unistd.h>
45 #include <ctype.h>
46 #include <string.h>
47 #include <fcntl.h>
48 #include <errno.h>
49 #ifdef HAVE_NETDB_H
50 # include <netdb.h>
51 #endif
52
53 #include <stdbool.h>
54 #include <lnet/nidstr.h>
55
56 #include "svcgssd.h"
57 #include "gss_util.h"
58 #include "err_util.h"
59 #include "context.h"
60 #include "cacheio.h"
61 #include "lsupport.h"
62 #include "gss_oids.h"
63 #include "sk_utils.h"
64 #include <lustre/lustre_idl.h>
65
/* Kernel channels this daemon writes to: established contexts go to the
 * context channel (do_svc_downcall()), negotiation replies go to the init
 * channel (send_response()). */
66 #define SVCGSSD_CONTEXT_CHANNEL "/proc/net/rpc/auth.sptlrpc.context/channel"
67 #define SVCGSSD_INIT_CHANNEL    "/proc/net/rpc/auth.sptlrpc.init/channel"
68
/* Size of the request token buffer in handle_channel_request(); the reply
 * buffer built by send_response() is twice this size. */
69 #define TOKEN_BUF_SIZE          8192
70
/*
 * Credentials established for a peer and handed to the kernel by
 * do_svc_downcall().
 *
 * cr_remote     - non-zero when the principal's realm is not the local realm
 * cr_usr_root   - non-zero for the Lustre root principal
 * cr_usr_mds    - non-zero for an MDS service principal
 * cr_usr_oss    - non-zero for an OSS service principal
 * cr_uid        - local uid the principal resolved to, or -1
 * cr_mapped_uid - uid from the principal/NID mapping, or -1
 * cr_gid        - group id, or -1 (a gid, hence gid_t rather than uid_t)
 */
struct svc_cred {
        uint32_t cr_remote;
        uint32_t cr_usr_root;
        uint32_t cr_usr_mds;
        uint32_t cr_usr_oss;
        uid_t    cr_uid;
        uid_t    cr_mapped_uid;
        gid_t    cr_gid;
};
80
/* Per-request negotiation state.  handle_channel_request() fills it from the
 * upcall line and the mechanism handlers (handle_krb(), handle_null(),
 * handle_sk()) produce the reply fields. */
81 struct svc_nego_data {
82         /* kernel data*/
83         uint32_t        lustre_svc;     /* LUSTRE_GSS_SVC_* once the mech bits are masked off */
84         lnet_nid_t      nid;            /* NID of the requesting peer */
85         uint64_t        handle_seq;     /* copied into out_handle when negotiation completes */
86         char            nm_name[LUSTRE_NODEMAP_NAME_LENGTH + 1];    /* nodemap name, used by handle_sk() */
87         gss_buffer_desc in_tok;         /* token received from the peer */
88         gss_buffer_desc out_tok;        /* reply token; handlers note it is released in handle_channel_request() */
89         gss_buffer_desc in_handle;      /* value points at a 15-byte buffer in handle_channel_request() */
90         gss_buffer_desc out_handle;     /* value points at a 15-byte buffer in handle_channel_request() */
91         uint32_t        maj_stat;       /* GSS major status of this step */
92         uint32_t        min_stat;       /* GSS minor status of this step */
93
94         /* userspace data */
95         gss_OID                 mech;           /* mechanism OID selected for the request */
96         gss_ctx_id_t            ctx;            /* GSS-API context handle used by handle_krb() */
97         gss_buffer_desc         ctx_token;      /* serialized context passed to do_svc_downcall() */
98 };
99
100 static int
101 do_svc_downcall(gss_buffer_desc *out_handle, struct svc_cred *cred,
102                 gss_OID mechoid, gss_buffer_desc *context_token)
103 {
104         FILE *f;
105         const char *mechname;
106         int err;
107
108         printerr(2, "doing downcall\n");
109         mechname = gss_OID_mech_name(mechoid);
110         if (mechname == NULL)
111                 goto out_err;
112         f = fopen(SVCGSSD_CONTEXT_CHANNEL, "w");
113         if (f == NULL) {
114                 printerr(0, "WARNING: unable to open downcall channel "
115                              "%s: %s\n",
116                              SVCGSSD_CONTEXT_CHANNEL, strerror(errno));
117                 goto out_err;
118         }
119         qword_printhex(f, out_handle->value, out_handle->length);
120         /* XXX are types OK for the rest of this? */
121         qword_printint(f, 3600); /* an hour should be sufficient */
122         qword_printint(f, cred->cr_remote);
123         qword_printint(f, cred->cr_usr_root);
124         qword_printint(f, cred->cr_usr_mds);
125         qword_printint(f, cred->cr_usr_oss);
126         qword_printint(f, cred->cr_mapped_uid);
127         qword_printint(f, cred->cr_uid);
128         qword_printint(f, cred->cr_gid);
129         qword_print(f, mechname);
130         qword_printhex(f, context_token->value, context_token->length);
131         err = qword_eol(f);
132         fclose(f);
133         return err;
134 out_err:
135         printerr(0, "WARNING: downcall failed\n");
136         return -1;
137 }
138
/* RPC verifier: a flavor number plus an opaque body.
 * NOTE(review): neither struct gss_verifier nor RPCSEC_GSS_SEQ_WIN is
 * referenced in the visible part of this file; presumably left over from the
 * NFS svcgssd code this file derives from — confirm before removing. */
139 struct gss_verifier {
140         u_int32_t       flav;
141         gss_buffer_desc body;
142 };
143
144 #define RPCSEC_GSS_SEQ_WIN      5
145
146 static int
147 send_response(FILE *f, gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
148               u_int32_t maj_stat, u_int32_t min_stat,
149               gss_buffer_desc *out_handle, gss_buffer_desc *out_token)
150 {
151         char buf[2 * TOKEN_BUF_SIZE];
152         char *bp = buf;
153         int blen = sizeof(buf);
154         /* XXXARG: */
155         int g;
156
157         printerr(2, "sending reply\n");
158         qword_addhex(&bp, &blen, in_handle->value, in_handle->length);
159         qword_addhex(&bp, &blen, in_token->value, in_token->length);
160         qword_addint(&bp, &blen, 3600); /* an hour should be sufficient */
161         qword_adduint(&bp, &blen, maj_stat);
162         qword_adduint(&bp, &blen, min_stat);
163         qword_addhex(&bp, &blen, out_handle->value, out_handle->length);
164         qword_addhex(&bp, &blen, out_token->value, out_token->length);
165         qword_addeol(&bp, &blen);
166         if (blen <= 0) {
167                 printerr(0, "WARNING: send_response: message too long\n");
168                 return -1;
169         }
170         g = open(SVCGSSD_INIT_CHANNEL, O_WRONLY);
171         if (g == -1) {
172                 printerr(0, "WARNING: open %s failed: %s\n",
173                                 SVCGSSD_INIT_CHANNEL, strerror(errno));
174                 return -1;
175         }
176         *bp = '\0';
177         printerr(3, "writing message: %s", buf);
178         if (write(g, buf, bp - buf) == -1) {
179                 printerr(0, "WARNING: failed to write message\n");
180                 close(g);
181                 return -1;
182         }
183         close(g);
184         return 0;
185 }
186
/* RPC authentication status codes; the values appear to mirror ONC RPC
 * auth_stat (0-5) and the RPCSEC_GSS extensions (13, 14).
 * NOTE(review): none of these are referenced in the visible part of this
 * file. */
187 #define rpc_auth_ok                     0
188 #define rpc_autherr_badcred             1
189 #define rpc_autherr_rejectedcred        2
190 #define rpc_autherr_badverf             3
191 #define rpc_autherr_rejectedverf        4
192 #define rpc_autherr_tooweak             5
193 #define rpcsec_gsserr_credproblem       13
194 #define rpcsec_gsserr_ctxproblem        14
195
/* Disabled legacy code: fills cred->cr_groups with the supplementary groups
 * of a principal via nfs4_gss_princ_to_grouplist(), retrying with a larger
 * static buffer when the first call fails.
 * NOTE(review): it uses cr_ngroups/cr_groups, which struct svc_cred in this
 * file does not have, so it would not compile if re-enabled; the realloc()
 * result is also unchecked. */
196 #if 0
197 static void
198 add_supplementary_groups(char *secname, char *name, struct svc_cred *cred)
199 {
200         int ret;
201         static gid_t *groups = NULL;
202
203         cred->cr_ngroups = NGROUPS;
204         ret = nfs4_gss_princ_to_grouplist(secname, name,
205                         cred->cr_groups, &cred->cr_ngroups);
206         if (ret < 0) {
207                 groups = realloc(groups, cred->cr_ngroups*sizeof(gid_t));
208                 ret = nfs4_gss_princ_to_grouplist(secname, name,
209                                 groups, &cred->cr_ngroups);
210                 if (ret < 0)
211                         cred->cr_ngroups = 0;
212                 else {
213                         if (cred->cr_ngroups > NGROUPS)
214                                 cred->cr_ngroups = NGROUPS;
215                         memcpy(cred->cr_groups, groups,
216                                         cred->cr_ngroups*sizeof(gid_t));
217                 }
218         }
219 }
220 #endif
221
/* Disabled legacy NFS-style principal mapping: resolves the client name via
 * nfs4_gss_princ_to_ids(), falling back to uid/gid -1 when no mapping exists.
 * NOTE(review): it would not build if re-enabled. It has the same name as the
 * live get_ids() below, and it uses cr_ngroups, which struct svc_cred does
 * not have. */
222 #if 0
223 static int
224 get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred)
225 {
226         u_int32_t       maj_stat, min_stat;
227         gss_buffer_desc name;
228         char            *sname;
229         int             res = -1;
230         uid_t           uid, gid;
231         gss_OID         name_type = GSS_C_NO_OID;
232         char            *secname;
233
234         maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
235         if (maj_stat != GSS_S_COMPLETE) {
236                 pgsserr("get_ids: gss_display_name",
237                         maj_stat, min_stat, mech);
238                 goto out;
239         }
240         if (name.length >= 0xffff || /* be certain name.length+1 doesn't overflow */
241             !(sname = calloc(name.length + 1, 1))) {
242                 printerr(0, "WARNING: get_ids: error allocating %d bytes "
243                         "for sname\n", name.length + 1);
244                 gss_release_buffer(&min_stat, &name);
245                 goto out;
246         }
247         memcpy(sname, name.value, name.length);
248         printerr(1, "sname = %s\n", sname);
249         gss_release_buffer(&min_stat, &name);
250
251         res = -EINVAL;
252         if ((secname = mech2file(mech)) == NULL) {
253                 printerr(0, "WARNING: get_ids: error mapping mech to "
254                         "file for name '%s'\n", sname);
255                 goto out_free;
256         }
257         nfs4_init_name_mapping(NULL); /* XXX: should only do this once */
258         res = nfs4_gss_princ_to_ids(secname, sname, &uid, &gid);
259         if (res < 0) {
260                 /*
261                  * -ENOENT means there was no mapping, any other error
262                  * value means there was an error trying to do the
263                  * mapping.
264                  * If there was no mapping, we send down the value -1
265                  * to indicate that the anonuid/anongid for the export
266                  * should be used.
267                  */
268                 if (res == -ENOENT) {
269                         cred->cr_uid = -1;
270                         cred->cr_gid = -1;
271                         cred->cr_ngroups = 0;
272                         res = 0;
273                         goto out_free;
274                 }
275                 printerr(0, "WARNING: get_ids: failed to map name '%s' "
276                         "to uid/gid: %s\n", sname, strerror(-res));
277                 goto out_free;
278         }
279         cred->cr_uid = uid;
280         cred->cr_gid = gid;
281         add_supplementary_groups(secname, sname, cred);
282         res = 0;
283 out_free:
284         free(sname);
285 out:
286         return res;
287 }
288 #endif
289
/* Disabled debugging helper: dumps @length bytes at @cp through printerr()
 * at priority @pri. Each row covers 16 bytes and shows the offset, the bytes
 * as hex (space after every second byte), then the printable characters
 * ('.' for the rest). */
290 #if 0
291 void
292 print_hexl(int pri, unsigned char *cp, int length)
293 {
294         int i, j, jm;
295         unsigned char c;
296
297         printerr(pri, "length %d\n",length);
298         printerr(pri, "\n");
299
300         for (i = 0; i < length; i += 0x10) {
301                 printerr(pri, "  %04x: ", (unsigned int)i);
302                 jm = length - i;
303                 jm = jm > 16 ? 16 : jm;
304
305                 for (j = 0; j < jm; j++) {
306                         if ((j % 2) == 1)
307                                 printerr(pri, "%02x ", (unsigned int)cp[i+j]);
308                         else
309                                 printerr(pri, "%02x", (unsigned int)cp[i+j]);
310                 }
                        /* pad a short final row so the text column lines up */
311                 for (; j < 16; j++) {
312                         if ((j % 2) == 1)
313                                 printerr(pri,"   ");
314                         else
315                                 printerr(pri,"  ");
316                 }
317                 printerr(pri," ");
318
319                 for (j = 0; j < jm; j++) {
320                         c = cp[i+j];
321                         c = isprint(c) ? c : '.';
322                         printerr(pri,"%c", c);
323                 }
324                 printerr(pri,"\n");
325         }
326 }
327 #endif
328
329 static int
330 get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred,
331         lnet_nid_t nid, uint32_t lustre_svc)
332 {
333         u_int32_t       maj_stat, min_stat;
334         gss_buffer_desc name;
335         char            *sname, *host, *realm;
336         const int       namebuf_size = 512;
337         char            namebuf[namebuf_size];
338         int             res = -1;
339         gss_OID         name_type = GSS_C_NO_OID;
340         struct passwd   *pw;
341
342         cred->cr_remote = 0;
343         cred->cr_usr_root = cred->cr_usr_mds = cred->cr_usr_oss = 0;
344         cred->cr_uid = cred->cr_mapped_uid = cred->cr_gid = -1;
345
346         maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
347         if (maj_stat != GSS_S_COMPLETE) {
348                 pgsserr("get_ids: gss_display_name",
349                         maj_stat, min_stat, mech);
350                 return -1;
351         }
352         if (name.length >= 0xffff || /* be certain name.length+1 doesn't overflow */
353             !(sname = calloc(name.length + 1, 1))) {
354                 printerr(0, "WARNING: get_ids: error allocating %zu bytes "
355                         "for sname\n", name.length + 1);
356                 gss_release_buffer(&min_stat, &name);
357                 return -1;
358         }
359         memcpy(sname, name.value, name.length);
360         sname[name.length] = '\0';
361         gss_release_buffer(&min_stat, &name);
362
363         if (lustre_svc == LUSTRE_GSS_SVC_MDS)
364                 lookup_mapping(sname, nid, &cred->cr_mapped_uid);
365         else
366                 cred->cr_mapped_uid = -1;
367
368         realm = strchr(sname, '@');
369         if (realm) {
370                 *realm++ = '\0';
371         } else {
372                 printerr(0, "ERROR: %s has no realm name\n", sname);
373                 goto out_free;
374         }
375
376         host = strchr(sname, '/');
377         if (host)
378                 *host++ = '\0';
379
380         if (strcmp(sname, GSSD_SERVICE_MGS) == 0) {
381                 printerr(0, "forbid %s as a user name\n", sname);
382                 goto out_free;
383         }
384
385         /* 1. check host part */
386         if (host) {
387                 if (lnet_nid2hostname(nid, namebuf, namebuf_size)) {
388                         printerr(0, "ERROR: failed to resolve hostname for "
389                                  "%s/%s@%s from %016llx\n",
390                                  sname, host, realm, nid);
391                         goto out_free;
392                 }
393
394                 if (strcasecmp(host, namebuf)) {
395                         printerr(0, "ERROR: %s/%s@%s claimed hostname doesn't "
396                                  "match %s, nid %016llx\n", sname, host, realm,
397                                  namebuf, nid);
398                         goto out_free;
399                 }
400         } else {
401                 if (!strcmp(sname, GSSD_SERVICE_MDS) ||
402                     !strcmp(sname, GSSD_SERVICE_OSS)) {
403                         printerr(0, "ERROR: %s@%s from %016llx doesn't "
404                                  "bind with hostname\n", sname, realm, nid);
405                         goto out_free;
406                 }
407         }
408
409         /* 2. check realm and user */
410         switch (lustre_svc) {
411         case LUSTRE_GSS_SVC_MDS:
412                 if (strcasecmp(mds_local_realm, realm)) {
413                         cred->cr_remote = 1;
414
415                         /* only allow mapped user from remote realm */
416                         if (cred->cr_mapped_uid == -1) {
417                                 printerr(0, "ERROR: %s%s%s@%s from %016llx "
418                                          "is remote but without mapping\n",
419                                          sname, host ? "/" : "",
420                                          host ? host : "", realm, nid);
421                                 break;
422                         }
423                 } else {
424                         if (!strcmp(sname, LUSTRE_ROOT_NAME)) {
425                                 cred->cr_uid = 0;
426                                 cred->cr_usr_root = 1;
427                         } else if (!strcmp(sname, GSSD_SERVICE_MDS)) {
428                                 cred->cr_uid = 0;
429                                 cred->cr_usr_mds = 1;
430                         } else if (!strcmp(sname, GSSD_SERVICE_OSS)) {
431                                 cred->cr_uid = 0;
432                                 cred->cr_usr_oss = 1;
433                         } else {
434                                 pw = getpwnam(sname);
435                                 if (pw != NULL) {
436                                         cred->cr_uid = pw->pw_uid;
437                                         printerr(2, "%s resolve to uid %u\n",
438                                                  sname, cred->cr_uid);
439                                 } else if (cred->cr_mapped_uid != -1) {
440                                         printerr(2, "user %s from %016llx is "
441                                                  "mapped to %u\n", sname, nid,
442                                                  cred->cr_mapped_uid);
443                                 } else {
444                                         printerr(0, "ERROR: invalid user, "
445                                                  "%s/%s@%s from %016llx\n",
446                                                  sname, host, realm, nid);
447                                         break;
448                                 }
449                         }
450                 }
451
452                 res = 0;
453                 break;
454         case LUSTRE_GSS_SVC_MGS:
455                 if (!strcmp(sname, GSSD_SERVICE_OSS)) {
456                         cred->cr_uid = 0;
457                         cred->cr_usr_oss = 1;
458                 }
459                 /* fall through */
460         case LUSTRE_GSS_SVC_OSS:
461                 if (!strcmp(sname, LUSTRE_ROOT_NAME)) {
462                         cred->cr_uid = 0;
463                         cred->cr_usr_root = 1;
464                 } else if (!strcmp(sname, GSSD_SERVICE_MDS)) {
465                         cred->cr_uid = 0;
466                         cred->cr_usr_mds = 1;
467                 }
468                 if (cred->cr_uid == -1) {
469                         printerr(0, "ERROR: svc %d doesn't accept user %s "
470                                  "from %016llx\n", lustre_svc, sname, nid);
471                         break;
472                 }
473                 res = 0;
474                 break;
475         default:
476                 assert(0);
477         }
478
479 out_free:
480         if (!res)
481                 printerr(1, "%s: authenticated %s%s%s@%s from %016llx\n",
482                          lustre_svc_name[lustre_svc], sname,
483                          host ? "/" : "", host ? host : "", realm, nid);
484         free(sname);
485         return res;
486 }
487
/* Pairs a mechanism OID with that mechanism's internal context handle.
 * NOTE(review): not referenced in the visible part of this file; presumably
 * mirrors the union-context layout of a GSS-API mechglue implementation, in
 * which case field order must match it exactly — confirm before relying on it. */
488 typedef struct gss_union_ctx_id_t {
489         gss_OID         mech_type;
490         gss_ctx_id_t    internal_ctx_id;
491 } gss_union_ctx_id_desc, *gss_union_ctx_id_t;
492
493 int handle_sk(struct svc_nego_data *snd)
494 {
495         struct sk_cred *skc = NULL;
496         struct svc_cred cred;
497         gss_buffer_desc bufs[7];
498         gss_buffer_desc remote_pub_key = GSS_C_EMPTY_BUFFER;
499         char *target;
500         uint32_t rc = GSS_S_FAILURE;
501         uint32_t flags;
502         int numbufs = 7;
503         int i;
504
505         printerr(3, "Handling sk request\n");
506
507         /* See lgss_sk_using_cred() for client side token
508          * bufs returned are in this order:
509          * bufs[0] - iv
510          * bufs[1] - p
511          * bufs[2] - remote_pub_key
512          * bufs[3] - target
513          * bufs[4] - nodemap_hash
514          * bufs[5] - flags
515          * bufs[6] - hmac */
516         i = sk_decode_netstring(bufs, numbufs, &snd->in_tok);
517         if (i < numbufs) {
518                 printerr(0, "Invalid netstring token received from peer\n");
519                 rc = GSS_S_DEFECTIVE_TOKEN;
520                 goto out_err;
521         }
522
523         /* target must be a null terminated string */
524         i = bufs[3].length - 1;
525         target = bufs[3].value;
526         if (i >= 0 && target[i] != '\0') {
527                 printerr(0, "Invalid target from netstring\n");
528                 for (i = 0; i < numbufs; i++)
529                         free(bufs[i].value);
530                 goto out_err;
531         }
532
533         memcpy(&flags, bufs[5].value, sizeof(flags));
534         skc = sk_create_cred(target, snd->nm_name, be32_to_cpu(flags));
535         if (!skc) {
536                 printerr(0, "Failed to create sk credentials\n");
537                 for (i = 0; i < numbufs; i++)
538                         free(bufs[i].value);
539                 goto out_err;
540         }
541
542         /* Take control of all the allocated buffers from decoding */
543         skc->sc_kctx.skc_iv = bufs[0];
544         skc->sc_p = bufs[1];
545         remote_pub_key = bufs[2];
546         skc->sc_nodemap_hash = bufs[4];
547         skc->sc_hmac = bufs[6];
548
549         /* Verify that the peer has used a key size greater to or equal
550          * the size specified by the key file */
551         if (skc->sc_flags & LGSS_SVC_PRIV &&
552             skc->sc_p.length < skc->sc_session_keylen) {
553                 printerr(0, "Peer DH parameters do not meet the size required "
554                          "by keyfile\n");
555                 goto out_err;
556         }
557
558         /* Verify HMAC from peer.  Ideally this would happen before anything
559          * else but we don't have enough information to lookup key without the
560          * token (fsname and cluster_hash) so it's done shortly after. */
561         rc = sk_verify_hmac(skc, bufs, numbufs - 1, EVP_sha256(),
562                             &skc->sc_hmac);
563         free(bufs[3].value);
564         free(bufs[5].value);
565         if (rc != GSS_S_COMPLETE) {
566                 printerr(0, "HMAC verification error: 0x%x from peer %s\n",
567                          rc, libcfs_nid2str((lnet_nid_t) snd->nid));
568                 goto out_err;
569         }
570
571         /* Check that the cluster hash matches the hash of nodemap name */
572         rc = sk_verify_hash(snd->nm_name, EVP_sha256(), &skc->sc_nodemap_hash);
573         if (rc != GSS_S_COMPLETE) {
574                 printerr(0, "Cluster hash failed validation: 0x%x\n", rc);
575                 goto out_err;
576         }
577
578         rc = sk_gen_params(skc, false);
579         if (rc != GSS_S_COMPLETE) {
580                 printerr(0, "Failed to generate DH params for responder\n");
581                 goto out_err;
582         }
583         if (sk_compute_key(skc, &remote_pub_key)) {
584                 printerr(0, "Failed to compute session key from DH params\n");
585                 goto out_err;
586         }
587         if (sk_kdf(skc, snd->nid, &snd->in_tok)) {
588                 printerr(0, "Failed to calulate derviced session key\n");
589                 goto out_err;
590         }
591         if (sk_serialize_kctx(skc, &snd->ctx_token)) {
592                 printerr(0, "Failed to serialize context for kernel\n");
593                 goto out_err;
594         }
595
596         /* Server reply only contains the servers public key and HMAC */
597         bufs[0] = skc->sc_pub_key;
598         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs, 1, EVP_sha256(),
599                          &skc->sc_hmac)) {
600                 printerr(0, "Failed to sign parameters\n");
601                 goto out_err;
602         }
603         bufs[1] = skc->sc_hmac;
604         if (sk_encode_netstring(bufs, 2, &snd->out_tok)) {
605                 printerr(0, "Failed to encode netstring for token\n");
606                 goto out_err;
607         }
608
609         printerr(2, "Created netstring of %zd bytes\n", snd->out_tok.length);
610
611         snd->out_handle.length = sizeof(snd->handle_seq);
612         memcpy(snd->out_handle.value, &snd->handle_seq,
613                sizeof(snd->handle_seq));
614         snd->maj_stat = GSS_S_COMPLETE;
615
616         /* fix credentials */
617         memset(&cred, 0, sizeof(cred));
618         cred.cr_mapped_uid = -1;
619
620         if (skc->sc_flags & LGSS_ROOT_CRED_ROOT)
621                 cred.cr_usr_root = 1;
622         if (skc->sc_flags & LGSS_ROOT_CRED_MDT)
623                 cred.cr_usr_mds = 1;
624         if (skc->sc_flags & LGSS_ROOT_CRED_OST)
625                 cred.cr_usr_oss = 1;
626
627         do_svc_downcall(&snd->out_handle, &cred, snd->mech, &snd->ctx_token);
628
629         /* cleanup ctx_token, out_tok is cleaned up in handle_channel_req */
630         free(remote_pub_key.value);
631         free(snd->ctx_token.value);
632         snd->ctx_token.length = 0;
633
634         printerr(3, "sk returning success\n");
635         return 0;
636
637 out_err:
638         snd->maj_stat = rc;
639         if (remote_pub_key.value)
640                 free(remote_pub_key.value);
641         if (snd->ctx_token.value)
642                 free(snd->ctx_token.value);
643         snd->ctx_token.length = 0;
644
645         if (skc)
646                 sk_free_cred(skc);
647         printerr(3, "sk returning failure\n");
648         return -1;
649 }
650
651 int handle_null(struct svc_nego_data *snd)
652 {
653         struct svc_cred cred;
654         uint64_t tmp;
655         uint32_t flags;
656
657         /* null just uses the same token as the return token and for
658          * for sending to the kernel.  It is a single uint64_t. */
659         if (snd->in_tok.length != sizeof(uint64_t)) {
660                 snd->maj_stat = GSS_S_DEFECTIVE_TOKEN;
661                 printerr(0, "Invalid token size (%zd) received\n",
662                          snd->in_tok.length);
663                 return -1;
664         }
665         snd->out_tok.length = snd->in_tok.length;
666         snd->out_tok.value = malloc(snd->out_tok.length);
667         if (!snd->out_tok.value) {
668                 snd->maj_stat = GSS_S_FAILURE;
669                 printerr(0, "Failed to allocate out_tok\n");
670                 return -1;
671         }
672
673         snd->ctx_token.length = snd->in_tok.length;
674         snd->ctx_token.value = malloc(snd->ctx_token.length);
675         if (!snd->ctx_token.value) {
676                 snd->maj_stat = GSS_S_FAILURE;
677                 printerr(0, "Failed to allocate ctx_token\n");
678                 return -1;
679         }
680
681         snd->out_handle.length = sizeof(snd->handle_seq);
682         memcpy(snd->out_handle.value, &snd->handle_seq,
683                sizeof(snd->handle_seq));
684         snd->maj_stat = GSS_S_COMPLETE;
685
686         memcpy(&tmp, snd->in_tok.value, sizeof(tmp));
687         tmp = be64_to_cpu(tmp);
688         flags = (uint32_t)(tmp & 0x00000000ffffffff);
689         memset(&cred, 0, sizeof(cred));
690         cred.cr_mapped_uid = -1;
691
692         if (flags & LGSS_ROOT_CRED_ROOT)
693                 cred.cr_usr_root = 1;
694         if (flags & LGSS_ROOT_CRED_MDT)
695                 cred.cr_usr_mds = 1;
696         if (flags & LGSS_ROOT_CRED_OST)
697                 cred.cr_usr_oss = 1;
698
699         do_svc_downcall(&snd->out_handle, &cred, snd->mech, &snd->ctx_token);
700
701         /* cleanup ctx_token, out_tok is cleaned up in handle_channel_req */
702         free(snd->ctx_token.value);
703         snd->ctx_token.length = 0;
704
705         return 0;
706 }
707
708 static int handle_krb(struct svc_nego_data *snd)
709 {
710         u_int32_t               ret_flags;
711         gss_name_t              client_name;
712         gss_buffer_desc         ignore_out_tok = {.value = NULL};
713         gss_OID                 mech = GSS_C_NO_OID;
714         gss_cred_id_t           svc_cred;
715         u_int32_t               ignore_min_stat;
716         struct svc_cred         cred;
717
718         svc_cred = gssd_select_svc_cred(snd->lustre_svc);
719         if (!svc_cred) {
720                 printerr(0, "no service credential for svc %u\n",
721                          snd->lustre_svc);
722                 goto out_err;
723         }
724
725         snd->maj_stat = gss_accept_sec_context(&snd->min_stat, &snd->ctx,
726                                                svc_cred, &snd->in_tok,
727                                                GSS_C_NO_CHANNEL_BINDINGS,
728                                                &client_name, &mech,
729                                                &snd->out_tok, &ret_flags, NULL,
730                                                NULL);
731
732         if (snd->maj_stat == GSS_S_CONTINUE_NEEDED) {
733                 printerr(1, "gss_accept_sec_context GSS_S_CONTINUE_NEEDED\n");
734
735                 /* Save the context handle for future calls */
736                 snd->out_handle.length = sizeof(snd->ctx);
737                 memcpy(snd->out_handle.value, &snd->ctx, sizeof(snd->ctx));
738                 return 0;
739         } else if (snd->maj_stat != GSS_S_COMPLETE) {
740                 printerr(0, "WARNING: gss_accept_sec_context failed\n");
741                 pgsserr("handle_krb: gss_accept_sec_context",
742                         snd->maj_stat, snd->min_stat, mech);
743                 goto out_err;
744         }
745
746         if (get_ids(client_name, mech, &cred, snd->nid, snd->lustre_svc)) {
747                 /* get_ids() prints error msg */
748                 snd->maj_stat = GSS_S_BAD_NAME; /* XXX ? */
749                 gss_release_name(&ignore_min_stat, &client_name);
750                 goto out_err;
751         }
752         gss_release_name(&ignore_min_stat, &client_name);
753
754         /* Context complete. Pass handle_seq in out_handle to use
755          * for context lookup in the kernel. */
756         snd->out_handle.length = sizeof(snd->handle_seq);
757         memcpy(snd->out_handle.value, &snd->handle_seq,
758                sizeof(snd->handle_seq));
759
760         /* kernel needs ctx to calculate verifier on null response, so
761          * must give it context before doing null call: */
762         if (serialize_context_for_kernel(snd->ctx, &snd->ctx_token, mech)) {
763                 printerr(0, "WARNING: handle_krb: "
764                          "serialize_context_for_kernel failed\n");
765                 snd->maj_stat = GSS_S_FAILURE;
766                 goto out_err;
767         }
768         /* We no longer need the gss context */
769         gss_delete_sec_context(&ignore_min_stat, &snd->ctx, &ignore_out_tok);
770         do_svc_downcall(&snd->out_handle, &cred, mech, &snd->ctx_token);
771
772         return 0;
773
774 out_err:
775         if (snd->ctx != GSS_C_NO_CONTEXT)
776                 gss_delete_sec_context(&ignore_min_stat, &snd->ctx,
777                                        &ignore_out_tok);
778
779         return 1;
780 }
781
782 /*
783  * return -1 only if we detect error during reading from upcall channel,
784  * all other cases return 0.
785  */
786 int handle_channel_request(FILE *f)
787 {
788         char                    in_tok_buf[TOKEN_BUF_SIZE];
789         char                    in_handle_buf[15];
790         char                    out_handle_buf[15];
791         gss_buffer_desc         ctx_token      = {.value = NULL},
792                                 null_token     = {.value = NULL};
793         uint32_t                lustre_mech;
794         static char             *lbuf;
795         static int              lbuflen;
796         static char             *cp;
797         int                     get_len;
798         int                     rc = 1;
799         u_int32_t               ignore_min_stat;
800         struct svc_nego_data    snd = {
801                 .in_tok.value           = in_tok_buf,
802                 .in_handle.value        = in_handle_buf,
803                 .out_handle.value       = out_handle_buf,
804                 .maj_stat               = GSS_S_FAILURE,
805                 .ctx                    = GSS_C_NO_CONTEXT,
806         };
807
808         printerr(2, "handling request\n");
809         if (readline(fileno(f), &lbuf, &lbuflen) != 1) {
810                 printerr(0, "WARNING: failed reading request\n");
811                 return -1;
812         }
813
814         cp = lbuf;
815
816         /* see rsi_request() for the format of data being input here */
817         qword_get(&cp, (char *)&snd.lustre_svc, sizeof(snd.lustre_svc));
818
819         /* lustre_svc is the svc and gss subflavor */
820         lustre_mech = (snd.lustre_svc & LUSTRE_GSS_MECH_MASK) >>
821                       LUSTRE_GSS_MECH_SHIFT;
822         snd.lustre_svc = snd.lustre_svc & LUSTRE_GSS_SVC_MASK;
823         switch (lustre_mech) {
824         case LGSS_MECH_KRB5:
825                 if (!krb_enabled) {
826                         printerr(1, "WARNING: Request for kerberos but service "
827                                  "support not enabled\n");
828                         goto ignore;
829                 }
830                 snd.mech = &krb5oid;
831                 break;
832         case LGSS_MECH_NULL:
833                 if (!null_enabled) {
834                         printerr(1, "WARNING: Request for gssnull but service "
835                                  "support not enabled\n");
836                         goto ignore;
837                 }
838                 snd.mech = &nulloid;
839                 break;
840         case LGSS_MECH_SK:
841                 if (!sk_enabled) {
842                         printerr(1, "WARNING: Request for sk but service "
843                                  "support not enabled\n");
844                         goto ignore;
845                 }
846                 snd.mech = &skoid;
847                 break;
848         default:
849                 printerr(0, "WARNING: invalid mechanism recevied: %d\n",
850                          lustre_mech);
851                 goto out_err;
852                 break;
853         }
854
855         qword_get(&cp, (char *)&snd.nid, sizeof(snd.nid));
856         qword_get(&cp, (char *)&snd.handle_seq, sizeof(snd.handle_seq));
857         qword_get(&cp, snd.nm_name, sizeof(snd.nm_name));
858         printerr(2, "handling req: svc %u, nid %016llx, idx %"PRIx64" nodemap "
859                  "%s\n", snd.lustre_svc, snd.nid, snd.handle_seq, snd.nm_name);
860
861         get_len = qword_get(&cp, snd.in_handle.value, sizeof(in_handle_buf));
862         if (get_len < 0) {
863                 printerr(0, "WARNING: failed parsing request\n");
864                 goto out_err;
865         }
866         snd.in_handle.length = (size_t)get_len;
867
868         printerr(3, "in_handle:\n");
869         print_hexl(3, snd.in_handle.value, snd.in_handle.length);
870
871         get_len = qword_get(&cp, snd.in_tok.value, sizeof(in_tok_buf));
872         if (get_len < 0) {
873                 printerr(0, "WARNING: failed parsing request\n");
874                 goto out_err;
875         }
876         snd.in_tok.length = (size_t)get_len;
877
878         printerr(3, "in_tok:\n");
879         print_hexl(3, snd.in_tok.value, snd.in_tok.length);
880
881         if (snd.in_handle.length != 0) { /* CONTINUE_INIT case */
882                 if (snd.in_handle.length != sizeof(snd.ctx)) {
883                         printerr(0, "WARNING: input handle has unexpected "
884                                  "length %zu\n", snd.in_handle.length);
885                         goto out_err;
886                 }
887                 /* in_handle is the context id stored in the out_handle
888                  * for the GSS_S_CONTINUE_NEEDED case below.  */
889                 memcpy(&snd.ctx, snd.in_handle.value, snd.in_handle.length);
890         }
891
892         if (lustre_mech == LGSS_MECH_KRB5)
893                 rc = handle_krb(&snd);
894         else if (lustre_mech == LGSS_MECH_SK)
895                 rc = handle_sk(&snd);
896         else if (lustre_mech == LGSS_MECH_NULL)
897                 rc = handle_null(&snd);
898         else
899                 printerr(0, "WARNING: Received or request for"
900                          "subflavor that is not enabled: %d\n", lustre_mech);
901
902 out_err:
903         /* Failures send a null token */
904         if (rc == 0)
905                 send_response(f, &snd.in_handle, &snd.in_tok, snd.maj_stat,
906                               snd.min_stat, &snd.out_handle, &snd.out_tok);
907         else
908                 send_response(f, &snd.in_handle, &snd.in_tok, snd.maj_stat,
909                               snd.min_stat, &null_token, &null_token);
910
911         /* cleanup buffers */
912         if (snd.ctx_token.value != NULL)
913                 free(ctx_token.value);
914         if (snd.out_tok.value != NULL)
915                 gss_release_buffer(&ignore_min_stat, &snd.out_tok);
916
917         /* For junk wire data just ignore */
918 ignore:
919         return 0;
920 }