--- /dev/null
+From f9b256237b2682ef81847165a9cdf8465e5ebb16 Mon Sep 17 00:00:00 2001
+From: Greg Edwards <gedwards@ddn.com>
+Date: Thu, 29 Oct 2020 15:10:58 -0600
+Subject: [PATCH 4/4] virtio_ring: add a vring_desc reserve mempool
+
+When submitting large IOs under heavy memory fragmentation, the
+higher-order allocation of the indirect vring_desc descriptor
+array may fail.
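+
+A vring_desc is 16 bytes, so the indirect table for a request with
+2048 scatter-gather segments needs a physically contiguous 32 KiB
+(order-3, with 4 KiB pages) allocation, which becomes increasingly
+likely to fail as memory fragments.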
+
+Create a small per-virtqueue reserve mempool of max-sized
+vring_desc descriptor arrays. If we fail to allocate a descriptor
+array via kmalloc(), fall back to taking one from the preallocated
+reserve pool.
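+
+Each pool element is sized for the worst case of 1 + 1 +
+SG_MAX_SEGMENTS + SG_MAX_SEGMENTS descriptors (request, response,
+max data, and max integrity segments). With SG_MAX_SEGMENTS at its
+current value of 2048 and 16-byte descriptors, that is roughly
+64 KiB per element, or about 1 MiB of reserve per virtqueue at the
+default pool size of 16 elements.
+
+The pool size can be tuned via the vring_desc_pool_sz module
+parameter, e.g. "modprobe virtio_ring vring_desc_pool_sz=32", or
+virtio_ring.vring_desc_pool_sz=32 on the kernel command line when
+virtio_ring is built in. A pool size of 0 disables the reserve
+pool and falls back to plain kmalloc() for indirect descriptors.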
+
+Signed-off-by: Greg Edwards <gedwards@ddn.com>
+---
+ drivers/virtio/virtio_ring.c | 90 ++++++++++++++++++++++++++++++++----
+ 1 file changed, 81 insertions(+), 9 deletions(-)
+
+diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
+index 3e968645388b..58c362186049 100644
+--- a/drivers/virtio/virtio_ring.c
++++ b/drivers/virtio/virtio_ring.c
+@@ -16,6 +16,11 @@
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
+ */
++
++#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
++
++#include <linux/mempool.h>
++#include <linux/scatterlist.h>
+ #include <linux/virtio.h>
+ #include <linux/virtio_ring.h>
+ #include <linux/virtio_config.h>
+@@ -26,6 +29,24 @@
+ #include <linux/hrtimer.h>
+ #include <linux/kmemleak.h>
+ #include <linux/dma-mapping.h>
++
++/*
++ * vring_desc reserve mempool
++ *
++ * If higher-order allocations fail in alloc_indirect(), try to grab a
++ * preallocated, max-sized descriptor array from the per-virtqueue mempool.
++ * Each pool element is sized at (req + rsp + max data + max integrity).
++ */
++#define VRING_DESC_POOL_DEFAULT 16
++#define VRING_DESC_POOL_NR_DESC (1 + 1 + SG_MAX_SEGMENTS + SG_MAX_SEGMENTS)
++#define VRING_DESC_POOL_ELEM_SZ (VRING_DESC_POOL_NR_DESC * \
++ sizeof(struct vring_desc))
++
++static unsigned short vring_desc_pool_sz = VRING_DESC_POOL_DEFAULT;
++module_param(vring_desc_pool_sz, ushort, 0444);
++MODULE_PARM_DESC(vring_desc_pool_sz,
++ "Number of elements in indirect descriptor mempool (default: "
++ __stringify(VRING_DESC_POOL_DEFAULT) ")");
+
+ #ifdef DEBUG
+ /* For development, we want to crash whenever the ring is screwed. */
+@@ -59,6 +82,7 @@
+ struct vring_desc_state {
+ void *data; /* Data for callback. */
+ struct vring_desc *indir_desc; /* Indirect descriptor, if any. */
++ bool indir_desc_mempool; /* Allocated from reserve mempool */
+ };
+
+ struct vring_virtqueue {
+@@ -104,6 +128,9 @@ struct vring_virtqueue {
+ ktime_t last_add_time;
+ #endif
+
++ /* Descriptor reserve mempool */
++ mempool_t *vring_desc_pool;
++
+ /* Per-descriptor state. */
+ struct vring_desc_state desc_state[];
+ };
+@@ -231,10 +258,13 @@ static int vring_mapping_error(const struct vring_virtqueue *vq,
+ }
+
+ static struct vring_desc *alloc_indirect(struct virtqueue *_vq,
+- unsigned int total_sg, gfp_t gfp)
++ unsigned int total_sg, gfp_t gfp,
++ int head)
+ {
++ struct vring_virtqueue *vq = to_vvq(_vq);
+ struct vring_desc *desc;
+ unsigned int i;
++ size_t size = total_sg * sizeof(struct vring_desc);
+
+ /*
+ * We require lowmem mappings for the descriptors because
+@@ -242,16 +272,43 @@ static struct vring_desc *alloc_indirect(struct virtqueue *_vq,
+ * virtqueue.
+ */
+ gfp &= ~__GFP_HIGHMEM;
+-
+- desc = kmalloc(total_sg * sizeof(struct vring_desc), gfp);
+- if (!desc)
+- return NULL;
++ gfp |= __GFP_NOWARN;
++
++ desc = kmalloc(size, gfp);
++ if (!desc) {
++ if (vq->vring_desc_pool) {
++ /* try to get a buffer from the reserve pool */
++ if (WARN_ON_ONCE(size > VRING_DESC_POOL_ELEM_SZ))
++ return NULL;
++ desc = mempool_alloc(vq->vring_desc_pool, gfp);
++ if (!desc) {
++ pr_warn_ratelimited(
++ "reserve indirect desc alloc failed\n");
++ return NULL;
++ }
++ vq->desc_state[head].indir_desc_mempool = true;
++ } else {
++ pr_warn_ratelimited("indirect desc alloc failed\n");
++ return NULL;
++ }
++ }
+
+ for (i = 0; i < total_sg; i++)
+ desc[i].next = cpu_to_virtio16(_vq->vdev, i + 1);
+ return desc;
+ }
+
++static void free_indirect(struct vring_virtqueue *vq,
++ struct vring_desc *desc, int head)
++{
++ if (vq->desc_state[head].indir_desc_mempool) {
++ mempool_free(desc, vq->vring_desc_pool);
++ vq->desc_state[head].indir_desc_mempool = false;
++ } else {
++ kfree(desc);
++ }
++}
++
+ static inline int virtqueue_add(struct virtqueue *_vq,
+ struct scatterlist *sgs[],
+ unsigned int total_sg,
+@@ -296,7 +353,7 @@ static inline int virtqueue_add(struct virtqueue *_vq,
+ /* If the host supports indirect descriptor tables, and we have multiple
+ * buffers, then go indirect. FIXME: tune this threshold */
+ if (vq->indirect && total_sg > 1 && vq->vq.num_free)
+- desc = alloc_indirect(_vq, total_sg, gfp);
++ desc = alloc_indirect(_vq, total_sg, gfp, head);
+ else {
+ desc = NULL;
+ WARN_ON_ONCE(total_sg > vq->vring.num && !vq->indirect);
+@@ -324,7 +381,7 @@ static inline int virtqueue_add(struct virtqueue *_vq,
+ if (out_sgs)
+ vq->notify(&vq->vq);
+ if (indirect)
+- kfree(desc);
++ free_indirect(vq, desc, head);
+ END_USE(vq);
+ return -ENOSPC;
+ }
+@@ -407,7 +464,7 @@ unmap_release:
+ vq->vq.num_free += total_sg;
+
+ if (indirect)
+- kfree(desc);
++ free_indirect(vq, desc, head);
+
+ return -EIO;
+ }
+@@ -630,7 +687,7 @@ static void detach_buf(struct vring_virtqueue *vq, unsigned int head)
+ for (j = 0; j < len / sizeof(struct vring_desc); j++)
+ vring_unmap_one(vq, &indir_desc[j]);
+
+- kfree(vq->desc_state[head].indir_desc);
++ free_indirect(vq, vq->desc_state[head].indir_desc, head);
+ vq->desc_state[head].indir_desc = NULL;
+ }
+ }
+@@ -907,6 +964,15 @@ struct virtqueue *__vring_new_virtqueue(unsigned int index,
+ if (!vq)
+ return NULL;
+
++ vq->vring_desc_pool = NULL;
++ if (vring_desc_pool_sz) {
++ vq->vring_desc_pool = mempool_create_node(vring_desc_pool_sz,
++ mempool_kmalloc, mempool_kfree, (void *)VRING_DESC_POOL_ELEM_SZ,
++ GFP_KERNEL, numa_node_id());
++ if (!vq->vring_desc_pool)
++ goto err;
++ }
++
+ vq->vring = vring;
+ vq->vq.callback = callback;
+ vq->vq.vdev = vdev;
+@@ -941,6 +1007,10 @@ struct virtqueue *__vring_new_virtqueue(unsigned int index,
+ memset(vq->desc_state, 0, vring.num * sizeof(struct vring_desc_state));
+
+ return &vq->vq;
++
++err:
++ kfree(vq);
++ return NULL;
+ }
+ EXPORT_SYMBOL_GPL(__vring_new_virtqueue);
+
+@@ -1076,6 +1146,8 @@ void vring_del_virtqueue(struct virtqueue *_vq)
+ vq->vring.desc, vq->queue_dma_addr);
+ }
+ list_del(&_vq->list);
++ if (vq->vring_desc_pool)
++ mempool_destroy(vq->vring_desc_pool);
+ kfree(vq);
+ }
+ EXPORT_SYMBOL_GPL(vring_del_virtqueue);
+--
+2.28.0
+