Whamcloud - gitweb
456e19d8f81dc4ff0075dfdfdcac6103f191550a
[fs/lustre-release.git] / lustre / utils / gss / svcgssd_proc.c
1 /*
2  * svc_in_gssd_proc.c
3  *
4  * Copyright (c) 2000 The Regents of the University of Michigan.
5  * All rights reserved.
6  *
7  * Copyright (c) 2002 Bruce Fields <bfields@UMICH.EDU>
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  * 3. Neither the name of the University nor the names of its
19  *    contributors may be used to endorse or promote products derived
20  *    from this software without specific prior written permission.
21  *
22  * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
23  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
24  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
25  * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
26  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
27  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
28  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
29  * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30  * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
31  * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
32  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33  */
34
35 #include <sys/param.h>
36 #include <sys/stat.h>
37
38 #include <inttypes.h>
39 #include <pwd.h>
40 #include <stdio.h>
41 #include <unistd.h>
42 #include <ctype.h>
43 #include <string.h>
44 #include <fcntl.h>
45 #include <errno.h>
46 #ifdef HAVE_NETDB_H
47 # include <netdb.h>
48 #endif
49
50 #include <stdbool.h>
51
52 #include "svcgssd.h"
53 #include "gss_util.h"
54 #include "err_util.h"
55 #include "context.h"
56 #include "cacheio.h"
57 #include "lsupport.h"
58 #include "gss_oids.h"
59 #include <time.h>
60 #include <linux/lustre/lustre_idl.h>
61 #include "sk_utils.h"
62 #include <sys/time.h>
63 #include <gssapi/gssapi_krb5.h>
64 #include <libcfs/util/param.h>
65
66 struct svc_cred {
67         uint32_t cr_remote;
68         uint32_t cr_usr_root;
69         uint32_t cr_usr_mds;
70         uint32_t cr_usr_oss;
71         uid_t    cr_uid;
72         uid_t    cr_mapped_uid;
73         uid_t    cr_gid;
74 };
75
76 struct svc_nego_data {
77         /* kernel data*/
78         uint32_t        lustre_svc;
79         lnet_nid_t      nid;
80         uint64_t        handle_seq;
81         char            nm_name[LUSTRE_NODEMAP_NAME_LENGTH + 1];
82         gss_buffer_desc in_tok;
83         gss_buffer_desc out_tok;
84         gss_buffer_desc in_handle;
85         gss_buffer_desc out_handle;
86         uint32_t        maj_stat;
87         uint32_t        min_stat;
88
89         /* userspace data */
90         gss_OID                 mech;
91         gss_ctx_id_t            ctx;
92         gss_buffer_desc         ctx_token;
93 };
94
95 static int do_svc_downcall(gss_buffer_desc *out_handle, struct svc_cred *cred,
96                            gss_OID mechoid, gss_buffer_desc *ctx_token)
97 {
98         struct rsc_downcall_data *rsc_dd;
99         int blen, fd, size, rc = -1;
100         const char *mechname;
101         glob_t path;
102         char *bp;
103
104         printerr(LL_INFO, "doing downcall\n");
105
106         size = out_handle->length + sizeof(__u32) +
107                 ctx_token->length + sizeof(__u32);
108         blen = size;
109
110         size += offsetof(struct rsc_downcall_data, scd_val[0]);
111         rsc_dd = calloc(1, size);
112         if (!rsc_dd) {
113                 printerr(LL_ERR, "malloc downcall data (%d) failed\n", size);
114                 return -ENOMEM;
115         }
116         rsc_dd->scd_magic = RSC_DOWNCALL_MAGIC;
117         rsc_dd->scd_err = 0;
118
119         rsc_dd->scd_flags =
120                 (cred->cr_remote ? RSC_DATA_FLAG_REMOTE : 0) |
121                 (cred->cr_usr_root ? RSC_DATA_FLAG_ROOT : 0) |
122                 (cred->cr_usr_mds ? RSC_DATA_FLAG_MDS : 0) |
123                 (cred->cr_usr_oss ? RSC_DATA_FLAG_OSS : 0);
124         rsc_dd->scd_mapped_uid = cred->cr_mapped_uid;
125         rsc_dd->scd_uid = cred->cr_uid;
126         rsc_dd->scd_gid = cred->cr_gid;
127         mechname = gss_OID_mech_name(mechoid);
128         if (mechname == NULL)
129                 goto out;
130         if (snprintf(rsc_dd->scd_mechname, sizeof(rsc_dd->scd_mechname),
131                      "%s", mechname) >= sizeof(rsc_dd->scd_mechname))
132                 goto out;
133
134         bp = rsc_dd->scd_val;
135         gss_buffer_write(&bp, &blen, out_handle->value, out_handle->length);
136         gss_buffer_write(&bp, &blen, ctx_token->value, ctx_token->length);
137         if (blen < 0) {
138                 printerr(LL_ERR, "ERROR: %s: message too long > %d\n",
139                          __func__, size);
140                 rc = -EMSGSIZE;
141                 goto out;
142         }
143         rsc_dd->scd_len = bp - rsc_dd->scd_val;
144
145         rc = cfs_get_param_paths(&path, RSC_DOWNCALL_PATH);
146         if (rc != 0) {
147                 rc = -errno;
148                 goto out;
149         }
150
151         fd = open(path.gl_pathv[0], O_WRONLY);
152         if (fd == -1) {
153                 rc = -errno;
154                 printerr(LL_ERR, "ERROR: %s: open %s failed: %s\n",
155                          __func__, RSC_DOWNCALL_PATH, strerror(-rc));
156                 goto out_path;
157         }
158         size = offsetof(struct rsc_downcall_data,
159                         scd_val[bp - rsc_dd->scd_val]);
160         printerr(LL_DEBUG, "writing downcall data, size %d\n", size);
161         if (write(fd, rsc_dd, size) == -1) {
162                 rc = -errno;
163                 printerr(LL_ERR, "ERROR: %s failed: %s\n",
164                          __func__, strerror(-rc));
165         }
166         printerr(LL_DEBUG, "downcall data written ok\n");
167
168         close(fd);
169 out_path:
170         cfs_free_param_data(&path);
171 out:
172         free(rsc_dd);
173         if (rc)
174                 printerr(LL_ERR, "ERROR: downcall failed\n");
175         return rc;
176 }
177
178 #define RPCSEC_GSS_SEQ_WIN      5
179
180 static int send_response(int auth_res, __u64 hash,
181                         gss_buffer_desc *in_handle, gss_buffer_desc *in_token,
182                         u_int32_t maj_stat, u_int32_t min_stat,
183                         gss_buffer_desc *out_handle, gss_buffer_desc *out_token)
184 {
185         struct rsi_downcall_data *rsi_dd;
186         int blen, fd, size, rc = 0;
187         glob_t path;
188         char *bp;
189
190         printerr(LL_INFO, "sending reply\n");
191
192         size = in_handle->length + sizeof(__u32) +
193                 in_token->length + sizeof(__u32) +
194                 sizeof(__u32) + sizeof(__u32);
195         if (!auth_res)
196                 size += out_handle->length + out_token->length;
197         blen = size;
198
199         size += offsetof(struct rsi_downcall_data, sid_val[0]);
200         rsi_dd = calloc(1, size);
201         if (!rsi_dd) {
202                 printerr(LL_ERR, "malloc downcall data (%d) failed\n", size);
203                 return -ENOMEM;
204         }
205         rsi_dd->sid_magic = RSI_DOWNCALL_MAGIC;
206         rsi_dd->sid_hash = hash;
207         rsi_dd->sid_maj_stat = maj_stat;
208         rsi_dd->sid_min_stat = min_stat;
209
210         bp = rsi_dd->sid_val;
211         gss_buffer_write(&bp, &blen, in_handle->value, in_handle->length);
212         gss_buffer_write(&bp, &blen, in_token->value, in_token->length);
213         if (!auth_res) {
214                 gss_buffer_write(&bp, &blen, out_handle->value,
215                                  out_handle->length);
216                 gss_buffer_write(&bp, &blen, out_token->value,
217                                  out_token->length);
218         } else {
219                 rsi_dd->sid_err = -EACCES;
220                 gss_buffer_write(&bp, &blen, NULL, 0);
221                 gss_buffer_write(&bp, &blen, NULL, 0);
222         }
223         if (blen < 0) {
224                 printerr(LL_ERR, "ERROR: %s: message too long > %d\n",
225                          __func__, size);
226                 rc = -EMSGSIZE;
227                 goto out;
228         }
229         rsi_dd->sid_len = bp - rsi_dd->sid_val;
230
231         rc = cfs_get_param_paths(&path, RSI_DOWNCALL_PATH);
232         if (rc != 0) {
233                 rc = -errno;
234                 printerr(LL_ERR, "ERROR: %s: cannot get param path %s: %s\n",
235                          __func__, RSI_DOWNCALL_PATH, strerror(-rc));
236                 goto out;
237         }
238
239         fd = open(path.gl_pathv[0], O_WRONLY);
240         if (fd == -1) {
241                 rc = -errno;
242                 printerr(LL_ERR, "ERROR: %s: open %s failed: %s\n",
243                          __func__, RSI_DOWNCALL_PATH, strerror(-rc));
244                 goto out_path;
245         }
246         size = offsetof(struct rsi_downcall_data,
247                         sid_val[bp - rsi_dd->sid_val]);
248         printerr(LL_DEBUG, "writing response, size %d\n", size);
249         if (write(fd, rsi_dd, size) == -1) {
250                 rc = -errno;
251                 printerr(LL_ERR, "ERROR: %s failed: %s\n",
252                          __func__, strerror(-rc));
253         }
254         printerr(LL_DEBUG, "response written ok\n");
255
256         close(fd);
257 out_path:
258         cfs_free_param_data(&path);
259 out:
260         free(rsi_dd);
261         return rc;
262 }
263
264 #define rpc_auth_ok                     0
265 #define rpc_autherr_badcred             1
266 #define rpc_autherr_rejectedcred        2
267 #define rpc_autherr_badverf             3
268 #define rpc_autherr_rejectedverf        4
269 #define rpc_autherr_tooweak             5
270 #define rpcsec_gsserr_credproblem       13
271 #define rpcsec_gsserr_ctxproblem        14
272
273 static int lookup_localname(gss_name_t client_name, char *princ, lnet_nid_t nid,
274                             uid_t *uid)
275 {
276         u_int32_t maj_stat, min_stat;
277         gss_buffer_desc localname;
278         char *sname;
279         int rc = -1;
280
281         *uid = -1;
282         maj_stat = gss_localname(&min_stat, client_name, GSS_C_NO_OID,
283                                  &localname);
284         if (maj_stat != GSS_S_COMPLETE) {
285                 printerr(LL_INFO, "no local name for %s/%#Lx\n", princ, nid);
286                 return rc;
287         }
288
289         sname = calloc(localname.length + 1, 1);
290         if (!sname) {
291                 printerr(LL_ERR, "%s: error allocating %zu bytes\n",
292                          __func__, localname.length + 1);
293                 goto free;
294         }
295         memcpy(sname, localname.value, localname.length);
296         sname[localname.length] = '\0';
297
298         *uid = parse_uid(sname);
299         free(sname);
300         printerr(LL_WARN, "found local uid: %s ==> %d\n", princ, *uid);
301         rc = 0;
302
303 free:
304         gss_release_buffer(&min_stat, &localname);
305         return rc;
306 }
307
308 static int lookup_id(gss_name_t client_name, char *princ, lnet_nid_t nid,
309                      uid_t *uid)
310 {
311         if (!mapping_empty())
312                 return lookup_mapping(princ, nid, uid);
313
314         return lookup_localname(client_name, princ, nid, uid);
315 }
316
317 static int
318 get_ids(gss_name_t client_name, gss_OID mech, struct svc_cred *cred,
319         lnet_nid_t nid, uint32_t lustre_svc)
320 {
321         u_int32_t       maj_stat, min_stat;
322         gss_buffer_desc name;
323         char            *sname, *host, *realm;
324         const int       namebuf_size = 512;
325         char            namebuf[namebuf_size];
326         int             res = -1;
327         gss_OID         name_type = GSS_C_NO_OID;
328         struct passwd   *pw;
329
330         cred->cr_remote = 0;
331         cred->cr_usr_root = cred->cr_usr_mds = cred->cr_usr_oss = 0;
332         cred->cr_uid = cred->cr_mapped_uid = cred->cr_gid = -1;
333
334         maj_stat = gss_display_name(&min_stat, client_name, &name, &name_type);
335         if (maj_stat != GSS_S_COMPLETE) {
336                 pgsserr("get_ids: gss_display_name",
337                         maj_stat, min_stat, mech);
338                 return -1;
339         }
340         /* be certain name.length+1 doesn't overflow */
341         if (name.length >= 0xffff ||
342             !(sname = calloc(name.length + 1, 1))) {
343                 printerr(LL_ERR,
344                          "ERROR: %s: error allocating %zu bytes for sname\n",
345                          __func__, name.length + 1);
346                 gss_release_buffer(&min_stat, &name);
347                 return -1;
348         }
349         memcpy(sname, name.value, name.length);
350         sname[name.length] = '\0';
351         gss_release_buffer(&min_stat, &name);
352
353         if (lustre_svc == LUSTRE_GSS_SVC_MDS &&
354             lookup_id(client_name, sname, nid, &cred->cr_mapped_uid))
355                 printerr(LL_DEBUG, "no id found for %s\n", sname);
356
357         realm = strchr(sname, '@');
358         if (realm) {
359                 *realm++ = '\0';
360         } else {
361                 printerr(LL_ERR, "ERROR: %s has no realm name\n", sname);
362                 goto out_free;
363         }
364
365         host = strchr(sname, '/');
366         if (host)
367                 *host++ = '\0';
368
369         if (strcmp(sname, GSSD_SERVICE_MGS) == 0) {
370                 printerr(LL_ERR, "forbid %s as a user name\n", sname);
371                 goto out_free;
372         }
373
374         /* 1. check host part */
375         if (host) {
376                 if (lnet_nid2hostname(nid, namebuf, namebuf_size)) {
377                         printerr(LL_ERR,
378                                  "ERROR: failed to resolve hostname for %s/%s@%s from %016llx\n",
379                                  sname, host, realm, nid);
380                         goto out_free;
381                 }
382
383                 if (strcasecmp(host, namebuf)) {
384                         printerr(LL_ERR,
385                                  "ERROR: %s/%s@%s claimed hostname doesn't match %s, nid %016llx\n",
386                                  sname, host, realm,
387                                  namebuf, nid);
388                         goto out_free;
389                 }
390         } else {
391                 if (!strcmp(sname, GSSD_SERVICE_MDS) ||
392                     !strcmp(sname, GSSD_SERVICE_OSS)) {
393                         printerr(LL_ERR,
394                                  "ERROR: %s@%s from %016llx doesn't bind with hostname\n",
395                                  sname, realm, nid);
396                         goto out_free;
397                 }
398         }
399
400         /* 2. check realm and user */
401         switch (lustre_svc) {
402         case LUSTRE_GSS_SVC_MDS:
403                 if (strcasecmp(mds_local_realm, realm) != 0) {
404                         /* Remote realm case */
405                         cred->cr_remote = 1;
406
407                         /* Prevent access to unmapped user from remote realm */
408                         if (cred->cr_mapped_uid == -1) {
409                                 printerr(LL_ERR,
410                                          "ERROR: %s%s%s@%s from %016llx is remote but without mapping\n",
411                                          sname, host ? "/" : "",
412                                          host ? host : "", realm, nid);
413                                 break;
414                         }
415                         goto valid;
416                 }
417
418                 /* Now we know we are dealing with a local realm */
419
420                 if (!strcmp(sname, LUSTRE_ROOT_NAME) ||
421                     !strcmp(sname, GSSD_SERVICE_HOST)) {
422                         cred->cr_uid = 0;
423                         cred->cr_usr_root = 1;
424                         goto valid;
425                 }
426                 if (!strcmp(sname, GSSD_SERVICE_MDS)) {
427                         cred->cr_uid = 0;
428                         cred->cr_usr_mds = 1;
429                         goto valid;
430                 }
431                 if (!strcmp(sname, GSSD_SERVICE_OSS)) {
432                         cred->cr_uid = 0;
433                         cred->cr_usr_oss = 1;
434                         goto valid;
435                 }
436                 if (cred->cr_mapped_uid != -1) {
437                         printerr(LL_INFO,
438                                  "user %s from %016llx is mapped to %u\n",
439                                  sname, nid,
440                                  cred->cr_mapped_uid);
441                         goto valid;
442                 }
443                 pw = getpwnam(sname);
444                 if (pw != NULL) {
445                         cred->cr_uid = pw->pw_uid;
446                         printerr(LL_INFO, "%s resolve to uid %u\n",
447                                  sname, cred->cr_uid);
448                         goto valid;
449                 }
450                 printerr(LL_ERR, "ERROR: invalid user, %s/%s@%s from %016llx\n",
451                          sname, host, realm, nid);
452                 break;
453
454 valid:
455                 res = 0;
456                 break;
457         case LUSTRE_GSS_SVC_MGS:
458                 if (!strcmp(sname, GSSD_SERVICE_OSS)) {
459                         cred->cr_uid = 0;
460                         cred->cr_usr_oss = 1;
461                 }
462                 fallthrough;
463         case LUSTRE_GSS_SVC_OSS:
464                 if (!strcmp(sname, LUSTRE_ROOT_NAME) ||
465                     !strcmp(sname, GSSD_SERVICE_HOST)) {
466                         cred->cr_uid = 0;
467                         cred->cr_usr_root = 1;
468                 } else if (!strcmp(sname, GSSD_SERVICE_MDS)) {
469                         cred->cr_uid = 0;
470                         cred->cr_usr_mds = 1;
471                 }
472                 if (cred->cr_uid == -1) {
473                         printerr(LL_ERR,
474                                  "ERROR: svc %d doesn't accept user %s from %016llx\n",
475                                  lustre_svc, sname, nid);
476                         break;
477                 }
478                 res = 0;
479                 break;
480         default:
481                 assert(0);
482         }
483
484 out_free:
485         if (!res)
486                 printerr(LL_WARN, "%s: authenticated %s%s%s@%s from %016llx\n",
487                          lustre_svc_name[lustre_svc], sname,
488                          host ? "/" : "", host ? host : "", realm, nid);
489         free(sname);
490         return res;
491 }
492
493 typedef struct gss_union_ctx_id_t {
494         gss_OID         mech_type;
495         gss_ctx_id_t    internal_ctx_id;
496 } gss_union_ctx_id_desc, *gss_union_ctx_id_t;
497
498 int handle_sk(struct svc_nego_data *snd)
499 {
500 #ifdef HAVE_OPENSSL_SSK
501         struct sk_cred *skc = NULL;
502         struct svc_cred cred;
503         gss_buffer_desc bufs[SK_INIT_BUFFERS];
504         gss_buffer_desc remote_pub_key = GSS_C_EMPTY_BUFFER;
505         char *target;
506         uint32_t rc = GSS_S_DEFECTIVE_TOKEN;
507         uint32_t version;
508         uint32_t flags;
509         int i;
510         int attempts = 0;
511
512         printerr(LL_DEBUG, "Handling sk request\n");
513         memset(bufs, 0, sizeof(gss_buffer_desc) * SK_INIT_BUFFERS);
514
515         /* See lgss_sk_using_cred() for client side token formation.
516          * Decoding initiator buffers */
517         i = sk_decode_netstring(bufs, SK_INIT_BUFFERS, &snd->in_tok);
518         if (i < SK_INIT_BUFFERS) {
519                 printerr(LL_ERR,
520                          "Invalid netstring token received from peer\n");
521                 goto cleanup_buffers;
522         }
523
524         /* Allowing for a larger length first buffer in the future */
525         if (bufs[SK_INIT_VERSION].length < sizeof(version)) {
526                 printerr(LL_ERR, "Invalid version received (wrong size)\n");
527                 goto cleanup_buffers;
528         }
529         memcpy(&version, bufs[SK_INIT_VERSION].value, sizeof(version));
530         version = be32toh(version);
531         if (version != SK_MSG_VERSION) {
532                 printerr(LL_ERR, "Invalid version received: %d\n", version);
533                 goto cleanup_buffers;
534         }
535
536         rc = GSS_S_FAILURE;
537
538         /* target must be a null terminated string */
539         i = bufs[SK_INIT_TARGET].length - 1;
540         target = bufs[SK_INIT_TARGET].value;
541         if (i >= 0 && target[i] != '\0') {
542                 printerr(LL_ERR, "Invalid target from netstring\n");
543                 goto cleanup_buffers;
544         }
545
546         if (bufs[SK_INIT_FLAGS].length != sizeof(flags)) {
547                 printerr(LL_ERR, "Invalid flags from netstring\n");
548                 goto cleanup_buffers;
549         }
550         memcpy(&flags, bufs[SK_INIT_FLAGS].value, sizeof(flags));
551
552         skc = sk_create_cred(target, snd->nm_name, be32toh(flags));
553         if (!skc) {
554                 printerr(LL_ERR, "Failed to create sk credentials\n");
555                 goto cleanup_buffers;
556         }
557
558         /* Verify that the peer has used a prime size greater or equal to
559          * the size specified in the key file which may contain only zero
560          * fill but the size specifies the mimimum supported size on
561          * servers */
562         if (skc->sc_flags & LGSS_SVC_PRIV &&
563             bufs[SK_INIT_P].length < skc->sc_p.length) {
564                 printerr(LL_ERR,
565                          "Peer DHKE prime does not meet the size required by keyfile: %zd bits\n",
566                          skc->sc_p.length * 8);
567                 goto cleanup_buffers;
568         }
569
570         /* Throw out the p from the server and use the wire data */
571         free(skc->sc_p.value);
572         skc->sc_p.value = NULL;
573         skc->sc_p.length = 0;
574
575         /* Take control of all the allocated buffers from decoding */
576         if (bufs[SK_INIT_RANDOM].length !=
577             sizeof(skc->sc_kctx.skc_peer_random)) {
578                 printerr(LL_ERR, "Invalid size for client random\n");
579                 goto cleanup_buffers;
580         }
581
582         memcpy(&skc->sc_kctx.skc_peer_random, bufs[SK_INIT_RANDOM].value,
583                sizeof(skc->sc_kctx.skc_peer_random));
584         skc->sc_p = bufs[SK_INIT_P];
585         remote_pub_key = bufs[SK_INIT_PUB_KEY];
586         skc->sc_nodemap_hash = bufs[SK_INIT_NODEMAP];
587         skc->sc_hmac = bufs[SK_INIT_HMAC];
588
589         /* Verify HMAC from peer.  Ideally this would happen before anything
590          * else but we don't have enough information to lookup key without the
591          * token (fsname and cluster_hash) so it's done after. */
592         rc = sk_verify_hmac(skc, bufs, SK_INIT_BUFFERS - 1, EVP_sha256(),
593                             &skc->sc_hmac);
594         if (rc != GSS_S_COMPLETE) {
595                 printerr(LL_ERR, "HMAC verification error: 0x%x from peer %s\n",
596                          rc, libcfs_nid2str((lnet_nid_t)snd->nid));
597                 goto cleanup_partial;
598         }
599
600         /* Check that the cluster hash matches the hash of nodemap name */
601         rc = sk_verify_hash(snd->nm_name, EVP_sha256(), &skc->sc_nodemap_hash);
602         if (rc != GSS_S_COMPLETE) {
603                 printerr(LL_ERR, "Cluster hash failed validation: 0x%x\n", rc);
604                 goto cleanup_partial;
605         }
606
607 redo:
608         rc = sk_gen_params(skc, sk_dh_checks);
609         if (rc != GSS_S_COMPLETE) {
610                 printerr(LL_ERR,
611                          "Failed to generate DH params for responder\n");
612                 goto cleanup_partial;
613         }
614         rc = sk_compute_dh_key(skc, &remote_pub_key);
615         if (rc == GSS_S_BAD_QOP && attempts < 2) {
616                 /* GSS_S_BAD_QOP means the generated shared key was shorter
617                  * than expected. Just retry twice before giving up.
618                  */
619                 attempts++;
620                 if (skc->sc_params) {
621                         EVP_PKEY_free(skc->sc_params);
622                         skc->sc_params = NULL;
623                 }
624                 if (skc->sc_pub_key.value) {
625                         free(skc->sc_pub_key.value);
626                         skc->sc_pub_key.value = NULL;
627                 }
628                 skc->sc_pub_key.length = 0;
629                 if (skc->sc_dh_shared_key.value) {
630                         /* erase secret key before freeing memory */
631                         memset(skc->sc_dh_shared_key.value, 0,
632                                skc->sc_dh_shared_key.length);
633                         free(skc->sc_dh_shared_key.value);
634                         skc->sc_dh_shared_key.value = NULL;
635                 }
636                 skc->sc_dh_shared_key.length = 0;
637                 goto redo;
638         } else if (rc != GSS_S_COMPLETE) {
639                 printerr(LL_ERR,
640                          "Failed to compute session key from DH params\n");
641                 goto cleanup_partial;
642         }
643
644         /* Cleanup init buffers we have copied or don't need anymore */
645         free(bufs[SK_INIT_VERSION].value);
646         free(bufs[SK_INIT_RANDOM].value);
647         free(bufs[SK_INIT_TARGET].value);
648         free(bufs[SK_INIT_FLAGS].value);
649
650         /* Server reply contains the servers public key, random,  and HMAC */
651         version = htobe32(SK_MSG_VERSION);
652         bufs[SK_RESP_VERSION].value = &version;
653         bufs[SK_RESP_VERSION].length = sizeof(version);
654         bufs[SK_RESP_RANDOM].value = &skc->sc_kctx.skc_host_random;
655         bufs[SK_RESP_RANDOM].length = sizeof(skc->sc_kctx.skc_host_random);
656         bufs[SK_RESP_PUB_KEY] = skc->sc_pub_key;
657         if (sk_sign_bufs(&skc->sc_kctx.skc_shared_key, bufs,
658                          SK_RESP_BUFFERS - 1, EVP_sha256(),
659                          &skc->sc_hmac)) {
660                 printerr(LL_ERR, "Failed to sign parameters\n");
661                 goto out_err;
662         }
663         bufs[SK_RESP_HMAC] = skc->sc_hmac;
664         if (sk_encode_netstring(bufs, SK_RESP_BUFFERS, &snd->out_tok)) {
665                 printerr(LL_ERR, "Failed to encode netstring for token\n");
666                 goto out_err;
667         }
668         printerr(LL_INFO, "Created netstring of %zd bytes\n",
669                  snd->out_tok.length);
670
671         if (sk_session_kdf(skc, snd->nid, &snd->in_tok, &snd->out_tok)) {
672                 printerr(LL_ERR, "Failed to calculate derived session key\n");
673                 goto out_err;
674         }
675         if (sk_compute_keys(skc)) {
676                 printerr(LL_ERR,
677                          "Failed to compute HMAC and encryption keys\n");
678                 goto out_err;
679         }
680         if (sk_serialize_kctx(skc, &snd->ctx_token)) {
681                 printerr(LL_ERR, "Failed to serialize context for kernel\n");
682                 goto out_err;
683         }
684
685         snd->out_handle.length = sizeof(snd->handle_seq);
686         memcpy(snd->out_handle.value, &snd->handle_seq,
687                sizeof(snd->handle_seq));
688         snd->maj_stat = GSS_S_COMPLETE;
689
690         /* fix credentials */
691         memset(&cred, 0, sizeof(cred));
692         cred.cr_mapped_uid = -1;
693
694         if (skc->sc_flags & LGSS_ROOT_CRED_ROOT)
695                 cred.cr_usr_root = 1;
696         if (skc->sc_flags & LGSS_ROOT_CRED_MDT)
697                 cred.cr_usr_mds = 1;
698         if (skc->sc_flags & LGSS_ROOT_CRED_OST)
699                 cred.cr_usr_oss = 1;
700
701         do_svc_downcall(&snd->out_handle, &cred, snd->mech, &snd->ctx_token);
702
703         /* cleanup ctx_token, out_tok is cleaned up in handle_channel_request */
704         if (remote_pub_key.length != 0) {
705                 free(remote_pub_key.value);
706                 remote_pub_key.value = NULL;
707                 remote_pub_key.length = 0;
708         }
709         if (snd->ctx_token.value) {
710                 free(snd->ctx_token.value);
711                 snd->ctx_token.value = NULL;
712                 snd->ctx_token.length = 0;
713         }
714
715         printerr(LL_DEBUG, "sk returning success\n");
716         return 0;
717
718 cleanup_buffers:
719         for (i = 0; i < SK_INIT_BUFFERS; i++)
720                 free(bufs[i].value);
721         sk_free_cred(skc);
722         snd->maj_stat = rc;
723         return -1;
724
725 cleanup_partial:
726         free(bufs[SK_INIT_VERSION].value);
727         free(bufs[SK_INIT_RANDOM].value);
728         free(bufs[SK_INIT_TARGET].value);
729         free(bufs[SK_INIT_FLAGS].value);
730         if (remote_pub_key.length != 0) {
731                 free(remote_pub_key.value);
732                 remote_pub_key.value = NULL;
733                 remote_pub_key.length = 0;
734         }
735         sk_free_cred(skc);
736         snd->maj_stat = rc;
737         return -1;
738
739 out_err:
740         snd->maj_stat = rc;
741         if (snd->ctx_token.value) {
742                 free(snd->ctx_token.value);
743                 snd->ctx_token.value = NULL;
744                 snd->ctx_token.length = 0;
745         }
746         if (remote_pub_key.length != 0) {
747                 free(remote_pub_key.value);
748                 remote_pub_key.value = NULL;
749                 remote_pub_key.length = 0;
750         }
751         sk_free_cred(skc);
752         printerr(LL_DEBUG, "sk returning failure\n");
753 #else /* !HAVE_OPENSSL_SSK */
754         printerr(LL_ERR, "ERROR: shared key subflavour is not enabled\n");
755 #endif /* HAVE_OPENSSL_SSK */
756         return -1;
757 }
758
759 int handle_null(struct svc_nego_data *snd)
760 {
761         struct svc_cred cred;
762         uint64_t tmp;
763         uint32_t flags;
764
765         /* null just uses the same token as the return token and for
766          * for sending to the kernel.  It is a single uint64_t. */
767         if (snd->in_tok.length != sizeof(uint64_t)) {
768                 snd->maj_stat = GSS_S_DEFECTIVE_TOKEN;
769                 printerr(LL_ERR, "Invalid token size (%zd) received\n",
770                          snd->in_tok.length);
771                 return -1;
772         }
773         snd->out_tok.length = snd->in_tok.length;
774         snd->out_tok.value = malloc(snd->out_tok.length);
775         if (!snd->out_tok.value) {
776                 snd->maj_stat = GSS_S_FAILURE;
777                 printerr(LL_ERR, "Failed to allocate out_tok\n");
778                 return -1;
779         }
780
781         snd->ctx_token.length = snd->in_tok.length;
782         snd->ctx_token.value = malloc(snd->ctx_token.length);
783         if (!snd->ctx_token.value) {
784                 snd->maj_stat = GSS_S_FAILURE;
785                 printerr(LL_ERR, "Failed to allocate ctx_token\n");
786                 return -1;
787         }
788
789         snd->out_handle.length = sizeof(snd->handle_seq);
790         memcpy(snd->out_handle.value, &snd->handle_seq,
791                sizeof(snd->handle_seq));
792         snd->maj_stat = GSS_S_COMPLETE;
793
794         memcpy(&tmp, snd->in_tok.value, sizeof(tmp));
795         tmp = be64toh(tmp);
796         flags = (uint32_t)(tmp & 0x00000000ffffffff);
797         memset(&cred, 0, sizeof(cred));
798         cred.cr_mapped_uid = -1;
799
800         if (flags & LGSS_ROOT_CRED_ROOT)
801                 cred.cr_usr_root = 1;
802         if (flags & LGSS_ROOT_CRED_MDT)
803                 cred.cr_usr_mds = 1;
804         if (flags & LGSS_ROOT_CRED_OST)
805                 cred.cr_usr_oss = 1;
806
807         do_svc_downcall(&snd->out_handle, &cred, snd->mech, &snd->ctx_token);
808
809         /* cleanup ctx_token, out_tok is cleaned up in handle_channel_req */
810         free(snd->ctx_token.value);
811         snd->ctx_token.length = 0;
812
813         return 0;
814 }
815
816 static int handle_krb(struct svc_nego_data *snd)
817 {
818         u_int32_t               ret_flags;
819         gss_name_t              client_name;
820         gss_buffer_desc         ignore_out_tok = {.value = NULL};
821         gss_OID                 mech = GSS_C_NO_OID;
822         gss_cred_id_t           svc_cred;
823         u_int32_t               ignore_min_stat;
824         struct svc_cred         cred;
825
826         svc_cred = gssd_select_svc_cred(snd->lustre_svc);
827         if (!svc_cred) {
828                 printerr(LL_ERR, "no service credential for svc %u\n",
829                          snd->lustre_svc);
830                 goto out_err;
831         }
832
833         snd->maj_stat = gss_accept_sec_context(&snd->min_stat, &snd->ctx,
834                                                svc_cred, &snd->in_tok,
835                                                GSS_C_NO_CHANNEL_BINDINGS,
836                                                &client_name, &mech,
837                                                &snd->out_tok, &ret_flags, NULL,
838                                                NULL);
839
840         if (snd->maj_stat == GSS_S_CONTINUE_NEEDED) {
841                 printerr(LL_WARN,
842                          "gss_accept_sec_context GSS_S_CONTINUE_NEEDED\n");
843
844                 /* Save the context handle for future calls */
845                 snd->out_handle.length = sizeof(snd->ctx);
846                 memcpy(snd->out_handle.value, &snd->ctx, sizeof(snd->ctx));
847                 return 0;
848         } else if (snd->maj_stat != GSS_S_COMPLETE) {
849                 printerr(LL_ERR, "ERROR: gss_accept_sec_context failed\n");
850                 pgsserr("handle_krb: gss_accept_sec_context",
851                         snd->maj_stat, snd->min_stat, mech);
852                 goto out_err;
853         }
854
855         if (get_ids(client_name, mech, &cred, snd->nid, snd->lustre_svc)) {
856                 /* get_ids() prints error msg */
857                 snd->maj_stat = GSS_S_BAD_NAME; /* XXX ? */
858                 gss_release_name(&ignore_min_stat, &client_name);
859                 goto out_err;
860         }
861         gss_release_name(&ignore_min_stat, &client_name);
862
863         /* Context complete. Pass handle_seq in out_handle to use
864          * for context lookup in the kernel. */
865         snd->out_handle.length = sizeof(snd->handle_seq);
866         memcpy(snd->out_handle.value, &snd->handle_seq,
867                sizeof(snd->handle_seq));
868
869         /* kernel needs ctx to calculate verifier on null response, so
870          * must give it context before doing null call: */
871         if (serialize_context_for_kernel(snd->ctx, &snd->ctx_token, mech)) {
872                 printerr(LL_ERR,
873                          "ERROR: %s: serialize_context_for_kernel failed\n",
874                         __func__);
875                 snd->maj_stat = GSS_S_FAILURE;
876                 goto out_err;
877         }
878         /* We no longer need the gss context */
879         gss_delete_sec_context(&ignore_min_stat, &snd->ctx, &ignore_out_tok);
880         do_svc_downcall(&snd->out_handle, &cred, mech, &snd->ctx_token);
881         /* We no longer need the context token */
882         if (snd->ctx_token.value) {
883                 free(snd->ctx_token.value);
884                 snd->ctx_token.value = NULL;
885                 snd->ctx_token.length = 0;
886         }
887         return 0;
888
889 out_err:
890         if (snd->ctx != GSS_C_NO_CONTEXT)
891                 gss_delete_sec_context(&ignore_min_stat, &snd->ctx,
892                                        &ignore_out_tok);
893
894         return 1;
895 }
896
897 int handle_channel_request(int fd)
898 {
899         char in_handle_buf[15];
900         char out_handle_buf[15];
901         uint32_t lustre_mech;
902         static char *lbuf;
903         static int lbuflen;
904         static char *cp;
905         int get_len;
906         int rc;
907         u_int32_t ignore_min_stat;
908         struct svc_nego_data snd = {
909                 .in_tok.value           = NULL,
910                 .in_handle.value        = in_handle_buf,
911                 .out_handle.value       = out_handle_buf,
912                 .maj_stat               = GSS_S_FAILURE,
913                 .ctx                    = GSS_C_NO_CONTEXT,
914         };
915         __u64 hash = 0;
916         __u64 tmp_lustre_svc = 0;
917
918         printerr(LL_INFO, "handling request\n");
919         if (readline(fd, &lbuf, &lbuflen) != 1) {
920                 printerr(LL_ERR, "ERROR: failed reading request\n");
921                 return -1;
922         }
923
924         cp = lbuf;
925
926         /* see rsi_do_upcall() for the format of data being input here */
927         rc = gss_u64_read_string(&cp, &hash);
928         if (rc < 0) {
929                 printerr(LL_ERR, "ERROR: failed parsing request: hash\n");
930                 goto out_err;
931         }
932         rc = gss_u64_read_string(&cp, &tmp_lustre_svc);
933         if (rc < 0) {
934                 printerr(LL_ERR, "ERROR: failed parsing request: lustre svc\n");
935                 goto out_err;
936         }
937         snd.lustre_svc = tmp_lustre_svc;
938         /* lustre_svc is the svc and gss subflavor */
939         lustre_mech = (snd.lustre_svc & LUSTRE_GSS_MECH_MASK) >>
940                 LUSTRE_GSS_MECH_SHIFT;
941         snd.lustre_svc = snd.lustre_svc & LUSTRE_GSS_SVC_MASK;
942         switch (lustre_mech) {
943         case LGSS_MECH_KRB5:
944                 if (!krb_enabled) {
945                         static time_t next_krb;
946
947                         if (time(NULL) > next_krb) {
948                                 printerr(LL_WARN,
949                                          "warning: Request for kerberos but service support not enabled\n");
950                                 next_krb = time(NULL) + 3600;
951                         }
952                         goto ignore;
953                 }
954                 snd.mech = &krb5oid;
955                 break;
956         case LGSS_MECH_NULL:
957                 if (!null_enabled) {
958                         static time_t next_null;
959
960                         if (time(NULL) > next_null) {
961                                 printerr(LL_WARN,
962                                          "warning: Request for gssnull but service support not enabled\n");
963                                 next_null = time(NULL) + 3600;
964                         }
965                         goto ignore;
966                 }
967                 snd.mech = &nulloid;
968                 break;
969         case LGSS_MECH_SK:
970                 if (!sk_enabled) {
971                         static time_t next_ssk;
972
973                         if (time(NULL) > next_ssk) {
974                                 printerr(LL_WARN,
975                                          "warning: Request for SSK but service support not %s\n",
976 #ifdef HAVE_OPENSSL_SSK
977                                          "enabled"
978 #else
979                                          "included"
980 #endif
981                                         );
982                                 next_ssk = time(NULL) + 3600;
983                         }
984
985                         goto ignore;
986                 }
987                 snd.mech = &skoid;
988                 break;
989         default:
990                 printerr(LL_ERR, "WARNING: invalid mechanism recevied: %d\n",
991                          lustre_mech);
992                 goto out_err;
993                 break;
994         }
995
996         rc = gss_u64_read_string(&cp, (__u64 *)&snd.nid);
997         if (rc < 0) {
998                 printerr(LL_ERR, "ERROR: failed parsing request: source nid\n");
999                 goto out_err;
1000         }
1001         rc = gss_u64_read_string(&cp, (__u64 *)&snd.handle_seq);
1002         if (rc < 0) {
1003                 printerr(LL_ERR, "ERROR: failed parsing request: handle seq\n");
1004                 goto out_err;
1005         }
1006         get_len = gss_string_read(&cp, snd.nm_name, sizeof(snd.nm_name), 0);
1007         if (get_len <= 0) {
1008                 printerr(LL_ERR,
1009                          "ERROR: failed parsing request: nodemap name\n");
1010                 goto out_err;
1011         }
1012         snd.nm_name[get_len] = '\0';
1013         printerr(LL_INFO,
1014                  "handling req: svc %u, nid %016llx, idx %"PRIx64" nodemap %s\n",
1015                  snd.lustre_svc, snd.nid, snd.handle_seq, snd.nm_name);
1016
1017         get_len = gss_base64url_decode(&cp, snd.in_handle.value,
1018                                        sizeof(in_handle_buf));
1019         if (get_len < 0) {
1020                 printerr(LL_ERR, "ERROR: failed parsing request: in handle\n");
1021                 goto out_err;
1022         }
1023         snd.in_handle.length = (size_t)get_len;
1024
1025         printerr(LL_DEBUG, "in_handle:\n");
1026         print_hexl(3, snd.in_handle.value, snd.in_handle.length);
1027
1028         snd.in_tok.value = malloc(strlen(cp));
1029         if (!snd.in_tok.value) {
1030                 printerr(LL_ERR, "ERROR: failed alloc for in token\n");
1031                 goto out_err;
1032         }
1033         get_len = gss_base64url_decode(&cp, snd.in_tok.value, strlen(cp));
1034         if (get_len < 0) {
1035                 printerr(LL_ERR, "ERROR: failed parsing request: in token\n");
1036                 goto out_err;
1037         }
1038         snd.in_tok.length = (size_t)get_len;
1039
1040         printerr(LL_DEBUG, "in_tok:\n");
1041         print_hexl(3, snd.in_tok.value, snd.in_tok.length);
1042
1043         if (snd.in_handle.length != 0) { /* CONTINUE_INIT case */
1044                 if (snd.in_handle.length != sizeof(snd.ctx)) {
1045                         printerr(LL_ERR,
1046                                  "ERROR: input handle has unexpected length %zu\n",
1047                                  snd.in_handle.length);
1048                         goto out_err;
1049                 }
1050                 /* in_handle is the context id stored in the out_handle
1051                  * for the GSS_S_CONTINUE_NEEDED case below.  */
1052                 memcpy(&snd.ctx, snd.in_handle.value, snd.in_handle.length);
1053         }
1054
1055         rc = -1;
1056         if (lustre_mech == LGSS_MECH_KRB5)
1057                 rc = handle_krb(&snd);
1058         else if (lustre_mech == LGSS_MECH_SK)
1059                 rc = handle_sk(&snd);
1060         else if (lustre_mech == LGSS_MECH_NULL)
1061                 rc = handle_null(&snd);
1062         else
1063                 printerr(LL_ERR,
1064                          "ERROR: Received or request for subflavor that is not enabled: %d\n",
1065                          lustre_mech);
1066
1067 out_err:
1068         /* Failures send a null token */
1069         rc = send_response(rc, hash, &snd.in_handle, &snd.in_tok,
1070                            snd.maj_stat, snd.min_stat,
1071                            &snd.out_handle, &snd.out_tok);
1072
1073         /* cleanup buffers */
1074         if (snd.in_tok.value)
1075                 free(snd.in_tok.value);
1076         if (snd.out_tok.value != NULL)
1077                 gss_release_buffer(&ignore_min_stat, &snd.out_tok);
1078
1079         /* For junk wire data just ignore */
1080 ignore:
1081         return rc;
1082 }