ensure VM IDs are unique before calling Azure reimage/delete APIs (#1023)

This commit is contained in:
bmc-msft
2021-06-25 11:54:52 -04:00
committed by GitHub
parent 880039a617
commit 883c93aaf4
2 changed files with 15 additions and 13 deletions

View File

@ -5,7 +5,7 @@
import logging import logging
import os import os
from typing import Any, Dict, List, Optional, Union, cast from typing import Any, Dict, List, Optional, Set, Union, cast
from uuid import UUID from uuid import UUID
from azure.core.exceptions import ( from azure.core.exceptions import (
@ -151,18 +151,18 @@ def check_can_update(name: UUID) -> Any:
return vmss return vmss
def reimage_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]: def reimage_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
check_can_update(name) check_can_update(name)
resource_group = get_base_resource_group() resource_group = get_base_resource_group()
logging.info("reimaging scaleset VM - name: %s vm_ids:%s", name, vm_ids) logging.info("reimaging scaleset VM - name: %s vm_ids:%s", name, vm_ids)
compute_client = get_compute_client() compute_client = get_compute_client()
instance_ids = [] instance_ids = set()
machine_to_id = list_instance_ids(name) machine_to_id = list_instance_ids(name)
for vm_id in vm_ids: for vm_id in vm_ids:
if vm_id in machine_to_id: if vm_id in machine_to_id:
instance_ids.append(machine_to_id[vm_id]) instance_ids.add(machine_to_id[vm_id])
else: else:
logging.info("unable to find vm_id for %s:%s", name, vm_id) logging.info("unable to find vm_id for %s:%s", name, vm_id)
@ -170,23 +170,23 @@ def reimage_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]:
compute_client.virtual_machine_scale_sets.begin_reimage_all( compute_client.virtual_machine_scale_sets.begin_reimage_all(
resource_group, resource_group,
str(name), str(name),
VirtualMachineScaleSetVMInstanceIDs(instance_ids=instance_ids), VirtualMachineScaleSetVMInstanceIDs(instance_ids=list(instance_ids)),
) )
return None return None
def delete_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]: def delete_vmss_nodes(name: UUID, vm_ids: Set[UUID]) -> Optional[Error]:
check_can_update(name) check_can_update(name)
resource_group = get_base_resource_group() resource_group = get_base_resource_group()
logging.info("deleting scaleset VM - name: %s vm_ids:%s", name, vm_ids) logging.info("deleting scaleset VM - name: %s vm_ids:%s", name, vm_ids)
compute_client = get_compute_client() compute_client = get_compute_client()
instance_ids = [] instance_ids = set()
machine_to_id = list_instance_ids(name) machine_to_id = list_instance_ids(name)
for vm_id in vm_ids: for vm_id in vm_ids:
if vm_id in machine_to_id: if vm_id in machine_to_id:
instance_ids.append(machine_to_id[vm_id]) instance_ids.add(machine_to_id[vm_id])
else: else:
logging.info("unable to find vm_id for %s:%s", name, vm_id) logging.info("unable to find vm_id for %s:%s", name, vm_id)
@ -194,7 +194,9 @@ def delete_vmss_nodes(name: UUID, vm_ids: List[UUID]) -> Optional[Error]:
compute_client.virtual_machine_scale_sets.begin_delete_instances( compute_client.virtual_machine_scale_sets.begin_delete_instances(
resource_group, resource_group,
str(name), str(name),
VirtualMachineScaleSetVMInstanceRequiredIDs(instance_ids=instance_ids), VirtualMachineScaleSetVMInstanceRequiredIDs(
instance_ids=list(instance_ids)
),
) )
return None return None

View File

@ -537,7 +537,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):
) )
return return
machine_ids = [] machine_ids = set()
for node in nodes: for node in nodes:
if node.debug_keep_node: if node.debug_keep_node:
logging.warning( logging.warning(
@ -547,7 +547,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):
node.machine_id, node.machine_id,
) )
else: else:
machine_ids.append(node.machine_id) machine_ids.add(node.machine_id)
logging.info( logging.info(
SCALESET_LOG_PREFIX + "deleting nodes scaleset_id:%s machine_id:%s", SCALESET_LOG_PREFIX + "deleting nodes scaleset_id:%s machine_id:%s",
@ -585,7 +585,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):
) )
return return
machine_ids = [] machine_ids = set()
for node in nodes: for node in nodes:
if node.debug_keep_node: if node.debug_keep_node:
logging.warning( logging.warning(
@ -595,7 +595,7 @@ class Scaleset(BASE_SCALESET, ORMMixin):
node.machine_id, node.machine_id,
) )
else: else:
machine_ids.append(node.machine_id) machine_ids.add(node.machine_id)
if not machine_ids: if not machine_ids:
logging.info( logging.info(