mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-17 04:18:07 +00:00
use the primitive types in more places (#514)
This commit is contained in:
@ -16,7 +16,7 @@ from azure.mgmt.subscription import SubscriptionClient
|
||||
from memoization import cached
|
||||
from msrestazure.azure_active_directory import MSIAuthentication
|
||||
from msrestazure.tools import parse_resource_id
|
||||
from onefuzztypes.primitives import Container
|
||||
from onefuzztypes.primitives import Container, Region
|
||||
|
||||
from .monkeypatch import allow_more_workers, reduce_logging
|
||||
|
||||
@ -41,12 +41,12 @@ def get_base_resource_group() -> Any: # should be str
|
||||
|
||||
|
||||
@cached
|
||||
def get_base_region() -> Any: # should be str
|
||||
def get_base_region() -> Region:
|
||||
client = ResourceManagementClient(
|
||||
credential=get_identity(), subscription_id=get_subscription()
|
||||
)
|
||||
group = client.resource_groups.get(get_base_resource_group())
|
||||
return group.location
|
||||
return Region(group.location)
|
||||
|
||||
|
||||
@cached
|
||||
@ -89,11 +89,11 @@ DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
|
||||
@cached(ttl=DAY_IN_SECONDS)
|
||||
def get_regions() -> List[str]:
|
||||
def get_regions() -> List[Region]:
|
||||
subscription = get_subscription()
|
||||
client = SubscriptionClient(credential=get_identity())
|
||||
locations = client.subscriptions.list_locations(subscription)
|
||||
return sorted([x.name for x in locations])
|
||||
return sorted([Region(x.name) for x in locations])
|
||||
|
||||
|
||||
@cached
|
||||
|
@ -49,7 +49,7 @@ def check_container(
|
||||
compare: Compare,
|
||||
expected: int,
|
||||
container_type: ContainerType,
|
||||
containers: Dict[ContainerType, List[str]],
|
||||
containers: Dict[ContainerType, List[Container]],
|
||||
) -> None:
|
||||
actual = len(containers.get(container_type, []))
|
||||
if not check_val(compare, expected, actual):
|
||||
@ -62,7 +62,7 @@ def check_container(
|
||||
def check_containers(definition: TaskDefinition, config: TaskConfig) -> None:
|
||||
checked = set()
|
||||
|
||||
containers: Dict[ContainerType, List[str]] = {}
|
||||
containers: Dict[ContainerType, List[Container]] = {}
|
||||
for container in config.containers:
|
||||
if container.name not in checked:
|
||||
if not container_exists(container.name, StorageType.corpus):
|
||||
|
@ -18,6 +18,7 @@ from onefuzztypes.events import (
|
||||
from onefuzztypes.models import Error
|
||||
from onefuzztypes.models import Task as BASE_TASK
|
||||
from onefuzztypes.models import TaskConfig, TaskVm, UserInfo
|
||||
from onefuzztypes.primitives import PoolName
|
||||
|
||||
from ..azure.image import get_os
|
||||
from ..azure.queue import create_queue, delete_queue
|
||||
@ -165,7 +166,7 @@ class Task(BASE_TASK, ORMMixin):
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
def get_tasks_by_pool_name(cls, pool_name: str) -> List["Task"]:
|
||||
def get_tasks_by_pool_name(cls, pool_name: PoolName) -> List["Task"]:
|
||||
tasks = cls.search_states(states=TaskState.available())
|
||||
if not tasks:
|
||||
return []
|
||||
|
@ -72,7 +72,7 @@ class Node(BASE_NODE, ORMMixin):
|
||||
*,
|
||||
scaleset_id: Optional[UUID] = None,
|
||||
states: Optional[List[NodeState]] = None,
|
||||
pool_name: Optional[str] = None,
|
||||
pool_name: Optional[PoolName] = None,
|
||||
) -> List["Node"]:
|
||||
query: QueryFilter = {}
|
||||
if scaleset_id:
|
||||
@ -89,7 +89,7 @@ class Node(BASE_NODE, ORMMixin):
|
||||
*,
|
||||
scaleset_id: Optional[UUID] = None,
|
||||
states: Optional[List[NodeState]] = None,
|
||||
pool_name: Optional[str] = None,
|
||||
pool_name: Optional[PoolName] = None,
|
||||
exclude_update_scheduled: bool = False,
|
||||
num_results: Optional[int] = None,
|
||||
) -> List["Node"]:
|
||||
|
Reference in New Issue
Block a user