+static int get_enlarged_msgsize(struct lustre_msg *msg,
+ int segment, int newsize)
+{
+ int save, newmsg_size;
+
+ LASSERT(newsize >= msg->lm_buflens[segment]);
+
+ save = msg->lm_buflens[segment];
+ msg->lm_buflens[segment] = newsize;
+ newmsg_size = lustre_msg_size_v2(msg->lm_bufcount, msg->lm_buflens);
+ msg->lm_buflens[segment] = save;
+
+ return newmsg_size;
+}
+
+static int get_enlarged_msgsize2(struct lustre_msg *msg,
+ int segment1, int newsize1,
+ int segment2, int newsize2)
+{
+ int save1, save2, newmsg_size;
+
+ LASSERT(newsize1 >= msg->lm_buflens[segment1]);
+ LASSERT(newsize2 >= msg->lm_buflens[segment2]);
+
+ save1 = msg->lm_buflens[segment1];
+ save2 = msg->lm_buflens[segment2];
+ msg->lm_buflens[segment1] = newsize1;
+ msg->lm_buflens[segment2] = newsize2;
+ newmsg_size = lustre_msg_size_v2(msg->lm_bufcount, msg->lm_buflens);
+ msg->lm_buflens[segment1] = save1;
+ msg->lm_buflens[segment2] = save2;
+
+ return newmsg_size;
+}
+
+static inline int msg_last_seglen(struct lustre_msg *msg)
+{
+ return msg->lm_buflens[msg->lm_bufcount - 1];
+}
+
+static
+int gss_enlarge_reqbuf_auth(struct ptlrpc_sec *sec,
+ struct ptlrpc_request *req,
+ int segment, int newsize)
+{
+ struct lustre_msg *newbuf;
+ int txtsize, sigsize, i;
+ int newmsg_size, newbuf_size;
+
+ /*
+ * embedded msg is at seg 1; signature is at the last seg
+ */
+ LASSERT(req->rq_reqbuf);
+ LASSERT(req->rq_reqbuf_len > req->rq_reqlen);
+ LASSERT(req->rq_reqbuf->lm_bufcount >= 2);
+ LASSERT(lustre_msg_buf(req->rq_reqbuf, 1, 0) == req->rq_reqmsg);
+
+ /* compute new embedded msg size */
+ newmsg_size = get_enlarged_msgsize(req->rq_reqmsg, segment, newsize);
+ LASSERT(newmsg_size >= req->rq_reqbuf->lm_buflens[1]);
+
+ /* compute new wrapper msg size */
+ for (txtsize = 0, i = 0; i < req->rq_reqbuf->lm_bufcount; i++)
+ txtsize += req->rq_reqbuf->lm_buflens[i];
+ txtsize += newmsg_size - req->rq_reqbuf->lm_buflens[1];
+
+ sigsize = gss_cli_payload(req->rq_cli_ctx, txtsize, 0);
+ LASSERT(sigsize >= msg_last_seglen(req->rq_reqbuf));
+ newbuf_size = get_enlarged_msgsize2(req->rq_reqbuf, 1, newmsg_size,
+ req->rq_reqbuf->lm_bufcount - 1,
+ sigsize);
+
+ /* request from pool should always have enough buffer */
+ LASSERT(!req->rq_pool || req->rq_reqbuf_len >= newbuf_size);
+
+ if (req->rq_reqbuf_len < newbuf_size) {
+ newbuf_size = size_roundup_power2(newbuf_size);
+
+ OBD_ALLOC(newbuf, newbuf_size);
+ if (newbuf == NULL)
+ RETURN(-ENOMEM);
+
+ memcpy(newbuf, req->rq_reqbuf, req->rq_reqbuf_len);
+
+ OBD_FREE(req->rq_reqbuf, req->rq_reqbuf_len);
+ req->rq_reqbuf = newbuf;
+ req->rq_reqbuf_len = newbuf_size;
+ req->rq_reqmsg = lustre_msg_buf(req->rq_reqbuf, 1, 0);
+ }
+
+ _sptlrpc_enlarge_msg_inplace(req->rq_reqbuf,
+ req->rq_reqbuf->lm_bufcount - 1, sigsize);
+ _sptlrpc_enlarge_msg_inplace(req->rq_reqbuf, 1, newmsg_size);
+ _sptlrpc_enlarge_msg_inplace(req->rq_reqmsg, segment, newsize);
+
+ req->rq_reqlen = newmsg_size;
+ RETURN(0);
+}
+
+static
+int gss_enlarge_reqbuf_priv(struct ptlrpc_sec *sec,
+ struct ptlrpc_request *req,
+ int segment, int newsize)
+{
+ struct lustre_msg *newclrbuf;
+ int newmsg_size, newclrbuf_size, newcipbuf_size;
+ int buflens[3];
+
+ /*
+ * embedded msg is at seg 0 of clear buffer;
+ * cipher text is at seg 2 of cipher buffer;
+ */
+ LASSERT(req->rq_pool ||
+ (req->rq_reqbuf == NULL && req->rq_reqbuf_len == 0));
+ LASSERT(req->rq_reqbuf == NULL ||
+ (req->rq_pool && req->rq_reqbuf->lm_bufcount == 3));
+ LASSERT(req->rq_clrbuf);
+ LASSERT(req->rq_clrbuf_len > req->rq_reqlen);
+ LASSERT(lustre_msg_buf(req->rq_clrbuf, 0, 0) == req->rq_reqmsg);
+
+ /* compute new embedded msg size */
+ newmsg_size = get_enlarged_msgsize(req->rq_reqmsg, segment, newsize);
+
+ /* compute new clear buffer size */
+ newclrbuf_size = get_enlarged_msgsize(req->rq_clrbuf, 0, newmsg_size);
+ newclrbuf_size += GSS_MAX_CIPHER_BLOCK;
+
+ /* compute new cipher buffer size */
+ buflens[0] = PTLRPC_GSS_HEADER_SIZE;
+ buflens[1] = gss_cli_payload(req->rq_cli_ctx, buflens[0], 0);
+ buflens[2] = gss_cli_payload(req->rq_cli_ctx, newclrbuf_size, 1);
+ newcipbuf_size = lustre_msg_size_v2(3, buflens);
+
+ /*
+ * handle the case that we put both clear buf and cipher buf into
+ * pre-allocated single buffer.
+ */
+ if (unlikely(req->rq_pool) &&
+ req->rq_clrbuf >= req->rq_reqbuf &&
+ (char *) req->rq_clrbuf <
+ (char *) req->rq_reqbuf + req->rq_reqbuf_len) {
+ /*
+ * it couldn't be better we still fit into the
+ * pre-allocated buffer.
+ */
+ if (newclrbuf_size + newcipbuf_size <= req->rq_reqbuf_len) {
+ void *src, *dst;
+
+ /* move clear text backward. */
+ src = req->rq_clrbuf;
+ dst = (char *) req->rq_reqbuf + newcipbuf_size;
+
+ memmove(dst, src, req->rq_clrbuf_len);
+
+ req->rq_clrbuf = (struct lustre_msg *) dst;
+ req->rq_clrbuf_len = newclrbuf_size;
+ req->rq_reqmsg = lustre_msg_buf(req->rq_clrbuf, 0, 0);
+ } else {
+ /*
+ * sadly we have to split out the clear buffer
+ */
+ LASSERT(req->rq_reqbuf_len >= newcipbuf_size);
+ LASSERT(req->rq_clrbuf_len < newclrbuf_size);
+ }
+ }
+
+ if (req->rq_clrbuf_len < newclrbuf_size) {
+ newclrbuf_size = size_roundup_power2(newclrbuf_size);
+
+ OBD_ALLOC(newclrbuf, newclrbuf_size);
+ if (newclrbuf == NULL)
+ RETURN(-ENOMEM);
+
+ memcpy(newclrbuf, req->rq_clrbuf, req->rq_clrbuf_len);
+
+ if (req->rq_reqbuf == NULL ||
+ req->rq_clrbuf < req->rq_reqbuf ||
+ (char *) req->rq_clrbuf >=
+ (char *) req->rq_reqbuf + req->rq_reqbuf_len) {
+ OBD_FREE(req->rq_clrbuf, req->rq_clrbuf_len);
+ }
+
+ req->rq_clrbuf = newclrbuf;
+ req->rq_clrbuf_len = newclrbuf_size;
+ req->rq_reqmsg = lustre_msg_buf(req->rq_clrbuf, 0, 0);
+ }
+
+ _sptlrpc_enlarge_msg_inplace(req->rq_clrbuf, 0, newmsg_size);
+ _sptlrpc_enlarge_msg_inplace(req->rq_reqmsg, segment, newsize);
+ req->rq_reqlen = newmsg_size;
+
+ RETURN(0);
+}
+
+static
+int gss_enlarge_reqbuf(struct ptlrpc_sec *sec,
+ struct ptlrpc_request *req,
+ int segment, int newsize)
+{
+ LASSERT(!req->rq_ctx_init && !req->rq_ctx_fini);
+
+ switch (SEC_FLAVOR_SVC(req->rq_sec_flavor)) {
+ case SPTLRPC_SVC_AUTH:
+ return gss_enlarge_reqbuf_auth(sec, req, segment, newsize);
+ case SPTLRPC_SVC_PRIV:
+ return gss_enlarge_reqbuf_priv(sec, req, segment, newsize);
+ default:
+ LASSERTF(0, "bad flavor %x\n", req->rq_sec_flavor);
+ return 0;
+ }
+}
+