diff options
Diffstat (limited to 'drivers/vhost/net.c')
-rw-r--r-- | drivers/vhost/net.c | 64 |
1 files changed, 41 insertions, 23 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index 87c216c1e54e..176aa030dc5f 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -64,9 +64,13 @@ enum { VHOST_NET_VQ_MAX = 2, }; +struct vhost_net_virtqueue { + struct vhost_virtqueue vq; +}; + struct vhost_net { struct vhost_dev dev; - struct vhost_virtqueue vqs[VHOST_NET_VQ_MAX]; + struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX]; struct vhost_poll poll[VHOST_NET_VQ_MAX]; /* Number of TX recently submitted. * Protected by tx vq lock. */ @@ -198,7 +202,7 @@ static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success) * read-size critical section for our kind of RCU. */ static void handle_tx(struct vhost_net *net) { - struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX]; + struct vhost_virtqueue *vq = &net->vqs[VHOST_NET_VQ_TX].vq; unsigned out, in, s; int head; struct msghdr msg = { @@ -417,7 +421,7 @@ err: * read-size critical section for our kind of RCU. */ static void handle_rx(struct vhost_net *net) { - struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX]; + struct vhost_virtqueue *vq = &net->vqs[VHOST_NET_VQ_RX].vq; unsigned uninitialized_var(in), log; struct vhost_log *vq_log; struct msghdr msg = { @@ -559,17 +563,26 @@ static int vhost_net_open(struct inode *inode, struct file *f) { struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL); struct vhost_dev *dev; + struct vhost_virtqueue **vqs; int r; if (!n) return -ENOMEM; + vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL); + if (!vqs) { + kfree(n); + return -ENOMEM; + } dev = &n->dev; - n->vqs[VHOST_NET_VQ_TX].handle_kick = handle_tx_kick; - n->vqs[VHOST_NET_VQ_RX].handle_kick = handle_rx_kick; - r = vhost_dev_init(dev, n->vqs, VHOST_NET_VQ_MAX); + vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq; + vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq; + n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick; + n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick; + r = vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX); if (r < 0) { kfree(n); + kfree(vqs); return r; } @@ -584,7 +597,9 @@ static int vhost_net_open(struct inode *inode, struct file *f) static void vhost_net_disable_vq(struct vhost_net *n, struct vhost_virtqueue *vq) { - struct vhost_poll *poll = n->poll + (vq - n->vqs); + struct vhost_net_virtqueue *nvq = + container_of(vq, struct vhost_net_virtqueue, vq); + struct vhost_poll *poll = n->poll + (nvq - n->vqs); if (!vq->private_data) return; vhost_poll_stop(poll); @@ -593,7 +608,9 @@ static void vhost_net_disable_vq(struct vhost_net *n, static int vhost_net_enable_vq(struct vhost_net *n, struct vhost_virtqueue *vq) { - struct vhost_poll *poll = n->poll + (vq - n->vqs); + struct vhost_net_virtqueue *nvq = + container_of(vq, struct vhost_net_virtqueue, vq); + struct vhost_poll *poll = n->poll + (nvq - n->vqs); struct socket *sock; sock = rcu_dereference_protected(vq->private_data, @@ -621,30 +638,30 @@ static struct socket *vhost_net_stop_vq(struct vhost_net *n, static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock, struct socket **rx_sock) { - *tx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_TX); - *rx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_RX); + *tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq); + *rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq); } static void vhost_net_flush_vq(struct vhost_net *n, int index) { vhost_poll_flush(n->poll + index); - vhost_poll_flush(&n->dev.vqs[index].poll); + vhost_poll_flush(&n->vqs[index].vq.poll); } static void vhost_net_flush(struct vhost_net *n) { vhost_net_flush_vq(n, VHOST_NET_VQ_TX); vhost_net_flush_vq(n, VHOST_NET_VQ_RX); - if (n->dev.vqs[VHOST_NET_VQ_TX].ubufs) { - mutex_lock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex); + if (n->vqs[VHOST_NET_VQ_TX].vq.ubufs) { + mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex); n->tx_flush = true; - mutex_unlock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex); + mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex); /* Wait for all lower device DMAs done. */ - vhost_ubuf_put_and_wait(n->dev.vqs[VHOST_NET_VQ_TX].ubufs); - mutex_lock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex); + vhost_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].vq.ubufs); + mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex); n->tx_flush = false; - kref_init(&n->dev.vqs[VHOST_NET_VQ_TX].ubufs->kref); - mutex_unlock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex); + kref_init(&n->vqs[VHOST_NET_VQ_TX].vq.ubufs->kref); + mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex); } } @@ -665,6 +682,7 @@ static int vhost_net_release(struct inode *inode, struct file *f) /* We do an extra flush before freeing memory, * since jobs can re-queue themselves. */ vhost_net_flush(n); + kfree(n->dev.vqs); kfree(n); return 0; } @@ -750,7 +768,7 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd) r = -ENOBUFS; goto err; } - vq = n->vqs + index; + vq = &n->vqs[index].vq; mutex_lock(&vq->mutex); /* Verify that ring has been setup correctly. */ @@ -870,10 +888,10 @@ static int vhost_net_set_features(struct vhost_net *n, u64 features) n->dev.acked_features = features; smp_wmb(); for (i = 0; i < VHOST_NET_VQ_MAX; ++i) { - mutex_lock(&n->vqs[i].mutex); - n->vqs[i].vhost_hlen = vhost_hlen; - n->vqs[i].sock_hlen = sock_hlen; - mutex_unlock(&n->vqs[i].mutex); + mutex_lock(&n->vqs[i].vq.mutex); + n->vqs[i].vq.vhost_hlen = vhost_hlen; + n->vqs[i].vq.sock_hlen = sock_hlen; + mutex_unlock(&n->vqs[i].vq.mutex); } vhost_net_flush(n); mutex_unlock(&n->dev.mutex); |