diff --git a/compute/src/main/java/org/zstack/compute/vm/VmInstanceBase.java b/compute/src/main/java/org/zstack/compute/vm/VmInstanceBase.java index ed98aef34e0..0d49674f152 100755 --- a/compute/src/main/java/org/zstack/compute/vm/VmInstanceBase.java +++ b/compute/src/main/java/org/zstack/compute/vm/VmInstanceBase.java @@ -2396,6 +2396,14 @@ public void run(FlowTrigger trigger, Map data) { int deviceId = deviceIdBitmap.nextClearBit(0); deviceIdBitmap.set(deviceId); String internalName = VmNicVO.generateNicInternalName(spec.getVmInventory().getInternalId(), deviceId); + String driverType = vmNicVO.getDriverType(); + if (StringUtils.isEmpty(driverType)) { + VmNicInventory nicInv = VmNicInventory.valueOf(vmNicVO); + nicManager.setNicDriverType(nicInv, VmSystemTags.VIRTIO.hasTag(self.getUuid()), + ImagePlatform.valueOf(spec.getVmInventory().getPlatform()).isParaVirtualization(), + spec.getVmInventory()); + driverType = nicInv.getDriverType(); + } UpdateQuery.New(VmNicVO.class) .eq(VmNicVO_.uuid, vmNicUuid) @@ -2403,14 +2411,14 @@ public void run(FlowTrigger trigger, Map data) { .set(VmNicVO_.deviceId, deviceId) .set(VmNicVO_.internalName, internalName) .set(VmNicVO_.hypervisorType, spec.getVmInventory().getHypervisorType()) + .set(VmNicVO_.driverType, driverType) .update(); vmNicVO.setVmInstanceUuid(self.getUuid()); vmNicVO.setDeviceId(deviceId); vmNicVO.setInternalName(internalName); vmNicVO.setHypervisorType(spec.getVmInventory().getHypervisorType()); - vmNicVO.setDriverType(VmSystemTags.VIRTIO.hasTag(self.getUuid()) ? - nicManager.getDefaultPVNicDriver() : nicManager.getDefaultNicDriver()); + vmNicVO.setDriverType(driverType); spec.getDestNics().add(0, VmNicInventory.valueOf(vmNicVO)); trigger.next(); @@ -8984,4 +8992,3 @@ public void run(MessageReply reply) { }); } } - diff --git a/compute/src/main/java/org/zstack/compute/vm/VmNicManagerImpl.java b/compute/src/main/java/org/zstack/compute/vm/VmNicManagerImpl.java index df2bc6bcaf7..4fa8352e42b 100644 --- a/compute/src/main/java/org/zstack/compute/vm/VmNicManagerImpl.java +++ b/compute/src/main/java/org/zstack/compute/vm/VmNicManagerImpl.java @@ -1,6 +1,5 @@ package org.zstack.compute.vm; -import com.google.common.collect.Maps; import org.apache.commons.collections.CollectionUtils; import org.springframework.beans.factory.annotation.Autowire; import org.springframework.beans.factory.annotation.Autowired; @@ -119,39 +118,40 @@ public void afterDelIpAddress(String vmNicUUid, String usedIpUuid) { @Override public void prepareDbInitialValue() { - List nics = Q.New(VmNicVO.class).notNull(VmNicVO_.vmInstanceUuid).list(); - List ns = nics.stream() - .filter(v -> v.getDriverType() == null - && v.getType().equals(VmInstanceConstant.VIRTUAL_NIC_TYPE) - && v.getVmInstanceUuid() != null) - .collect(Collectors.toList()); + List nics = Q.New(VmNicVO.class) + .isNull(VmNicVO_.driverType) + .eq(VmNicVO_.type, VmInstanceConstant.VIRTUAL_NIC_TYPE) + .notNull(VmNicVO_.vmInstanceUuid) + .list(); - if (CollectionUtils.isEmpty(ns)) { + if (CollectionUtils.isEmpty(nics)) { return; } - List vmUuids = ns.stream() + List vmUuids = nics.stream() .map(VmNicVO::getVmInstanceUuid) + .distinct() .collect(Collectors.toList()); + Set virtioVmUuids = new HashSet<>(VmSystemTags.VIRTIO.filterResourceHasTag(vmUuids)); List tupleList = Q.New(VmInstanceVO.class) .select(VmInstanceVO_.uuid, VmInstanceVO_.platform) .in(VmInstanceVO_.uuid, vmUuids) .listTuple(); - Map vmPlatforms = Maps.newHashMap(); + Map vmDrivers = new HashMap<>(); for (Tuple vmTuple : tupleList) { String vmUuid = vmTuple.get(0, String.class); String vmPlatform = vmTuple.get(1, String.class); - vmPlatforms.put(vmUuid, ImagePlatform.valueOf(vmPlatform).isParaVirtualization() ? + vmDrivers.put(vmUuid, virtioVmUuids.contains(vmUuid) || ImagePlatform.valueOf(vmPlatform).isParaVirtualization() ? defaultPVNicDriver : defaultNicDriver); } Map> nicGroups = nics.stream() - .filter(v -> vmPlatforms.containsKey(v.getVmInstanceUuid())) + .filter(v -> vmDrivers.containsKey(v.getVmInstanceUuid())) .collect( Collectors.groupingBy( - v -> vmPlatforms.get(v.getVmInstanceUuid()).equals(defaultPVNicDriver), + v -> vmDrivers.get(v.getVmInstanceUuid()).equals(defaultPVNicDriver), Collectors.mapping(VmNicVO::getUuid, Collectors.toList())) ); diff --git a/test/src/test/groovy/org/zstack/test/integration/kvm/nic/ChangeWindowsVmNicDriverCase.groovy b/test/src/test/groovy/org/zstack/test/integration/kvm/nic/ChangeWindowsVmNicDriverCase.groovy index 86b6bab37c1..f2d5806aaf4 100644 --- a/test/src/test/groovy/org/zstack/test/integration/kvm/nic/ChangeWindowsVmNicDriverCase.groovy +++ b/test/src/test/groovy/org/zstack/test/integration/kvm/nic/ChangeWindowsVmNicDriverCase.groovy @@ -2,10 +2,13 @@ package org.zstack.test.integration.kvm.nic import org.springframework.beans.factory.annotation.Autowired import org.zstack.compute.vm.VmSystemTags +import org.zstack.compute.vm.VmNicManagerImpl import org.zstack.header.image.ImagePlatform import org.zstack.header.tag.SystemTagVO import org.zstack.header.tag.SystemTagVO_ import org.zstack.sdk.VmInstanceInventory +import org.zstack.sdk.L3NetworkInventory +import org.zstack.sdk.VmNicInventory import org.zstack.header.vm.VmNicVO import org.zstack.header.vm.VmNicVO_ import org.zstack.tag.SystemTag @@ -16,6 +19,7 @@ import org.zstack.test.integration.kvm.KvmTest import org.zstack.testlib.EnvSpec import org.zstack.testlib.SubCase import org.zstack.core.db.Q +import org.zstack.core.db.SQL class ChangeWindowsVmNicDriverCase extends SubCase { EnvSpec env @@ -39,6 +43,7 @@ class ChangeWindowsVmNicDriverCase extends SubCase { void test() { env.create { testChangeWindowsVmNicDriver() + testPrepareDbInitialValueForVirtioTaggedWindowsVm() } } @@ -88,4 +93,38 @@ class ChangeWindowsVmNicDriverCase extends SubCase { assert VmSystemTags.VIRTIO.hasTag(vm.uuid) assert Q.New(VmNicVO.class).eq(VmNicVO_.vmInstanceUuid, vm.uuid).select(VmNicVO_.driverType).findValue().equals("virtio") } -} \ No newline at end of file + + void testPrepareDbInitialValueForVirtioTaggedWindowsVm() { + VmInstanceInventory vm = env.inventoryByName("vm") as VmInstanceInventory + L3NetworkInventory pubL3 = env.inventoryByName("pubL3") as L3NetworkInventory + VmNicVO originalNic = Q.New(VmNicVO.class) + .eq(VmNicVO_.vmInstanceUuid, vm.uuid) + .find() + + assert VmSystemTags.VIRTIO.hasTag(vm.uuid) + assert originalNic.driverType == "virtio" + + VmNicInventory newNic = createVmNic { + l3NetworkUuid = pubL3.uuid + } + + attachVmNicToVm { + vmInstanceUuid = vm.uuid + vmNicUuid = newNic.uuid + } + + SQL.New(VmNicVO.class) + .eq(VmNicVO_.uuid, newNic.uuid) + .set(VmNicVO_.driverType, null) + .update() + assert Q.New(VmNicVO.class) + .eq(VmNicVO_.uuid, newNic.uuid) + .select(VmNicVO_.driverType) + .findValue() == null + + bean(VmNicManagerImpl.class).prepareDbInitialValue() + + assert Q.New(VmNicVO.class).eq(VmNicVO_.uuid, originalNic.uuid).select(VmNicVO_.driverType).findValue() == "virtio" + assert Q.New(VmNicVO.class).eq(VmNicVO_.uuid, newNic.uuid).select(VmNicVO_.driverType).findValue() == "virtio" + } +} diff --git a/test/src/test/groovy/org/zstack/test/integration/kvm/nic/VmNicBasicCase.groovy b/test/src/test/groovy/org/zstack/test/integration/kvm/nic/VmNicBasicCase.groovy index 43e6d80c840..8ee2a8bd3ff 100644 --- a/test/src/test/groovy/org/zstack/test/integration/kvm/nic/VmNicBasicCase.groovy +++ b/test/src/test/groovy/org/zstack/test/integration/kvm/nic/VmNicBasicCase.groovy @@ -82,6 +82,7 @@ class VmNicBasicCase extends SubCase { assert nic.ip != null assert nic.mac != null assert nic.usedIps.size() != 0 + assert dbFindByUuid(nic.uuid, VmNicVO.class).driverType == null IpRangeInventory ipRangeInventory = pubL3.ipRanges.get(0) assert nic.gateway == ipRangeInventory.gateway @@ -101,6 +102,7 @@ class VmNicBasicCase extends SubCase { void testAttachVmNicToVm () { L3NetworkInventory l3 = env.inventoryByName("l3") VmInstanceInventory vm = env.inventoryByName("vm") + String expectedDriverType = vm.vmNics[0].driverType changeL3NetworkState { uuid = nic.l3NetworkUuid @@ -149,6 +151,8 @@ class VmNicBasicCase extends SubCase { assert vmNicVO.deviceId == 1 assert vmNicVO.internalName == VmNicVO.generateNicInternalName(vmInstanceVO.getInternalId(), 1) assert vmNicVO.vmInstanceUuid == vm.getUuid() + assert vmNicVO.driverType == expectedDriverType + assert vm.vmNics.find { it.uuid == nic.uuid }.driverType == expectedDriverType usedIpUuid = vmNicVO.usedIpUuid