- Notifications
You must be signed in to change notification settings - Fork 637
/
Copy pathtest_backend_utils.py
124 lines (111 loc) · 4.68 KB
/
test_backend_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
importos
importpathlib
fromunittestimportmock
fromskyimportclouds
fromskyimportskypilot_config
fromsky.backendsimportbackend_utils
fromsky.resourcesimportResources
# Point skypilot_config at the AWS test config via the environment variable
# (restored automatically when the test finishes; see the body).
@mock.patch.object(skypilot_config, '_dict', None)
@mock.patch.object(skypilot_config, '_loaded_config_path', None)
@mock.patch('sky.clouds.service_catalog.instance_type_exists',
            return_value=True)
@mock.patch('sky.clouds.service_catalog.get_accelerators_from_instance_type',
            return_value={'fake-acc': 2})
@mock.patch('sky.clouds.service_catalog.get_image_id_from_tag',
            return_value='fake-image')
@mock.patch.object(clouds.aws, 'DEFAULT_SECURITY_GROUP_NAME', 'fake-default-sg')
@mock.patch('sky.check.get_cloud_credential_file_mounts',
            return_value='~/.aws/credentials')
@mock.patch('sky.backends.backend_utils._get_yaml_path_from_cluster_name',
            return_value='/tmp/fake/path')
@mock.patch('sky.backends.backend_utils._deterministic_cluster_yaml_hash',
            return_value='fake-hash')
@mock.patch('sky.utils.common_utils.fill_template')
def test_write_cluster_config_w_remote_identity(mock_fill_template,
                                                *mocks) -> None:
    """Check security group / remote identity selection by cluster name.

    With ./tests/test_yamls/test_aws_config.yaml loaded, a generic cluster
    name should get the default security group (managed by SkyPilot) and
    LOCAL_CREDENTIALS, while cluster names that match the config's
    per-cluster regex entries should get that entry's security group (not
    managed by SkyPilot) and remote identity.

    Args:
        mock_fill_template: mock replacing common_utils.fill_template; its
            call arguments are what this test inspects.
        *mocks: the remaining injected patch mocks (unused).
    """
    # patch.dict restores os.environ on exit, so the config override does
    # not leak into other tests run in the same process.
    with mock.patch.dict(
            os.environ, {
                skypilot_config.ENV_VAR_SKYPILOT_CONFIG:
                    './tests/test_yamls/test_aws_config.yaml'
            }):
        skypilot_config._reload_config()

        cloud = clouds.AWS()
        region = clouds.Region(name='fake-region')
        zones = [clouds.Zone(name='fake-zone')]
        resource = Resources(cloud=cloud, instance_type='fake-type: 3')
        cluster_config_template = 'aws-ray.yml.j2'

        def _write_config_and_check(cluster_name: str,
                                    expected_subset: dict) -> None:
            """Render config for `cluster_name` and check fill_template args.

            Asserts fill_template was called exactly once, with the expected
            template name and variables that contain `expected_subset`.
            """
            mock_fill_template.reset_mock()
            backend_utils.write_cluster_config(
                to_provision=resource,
                num_nodes=2,
                cluster_config_template=cluster_config_template,
                cluster_name=cluster_name,
                local_wheel_path=pathlib.Path('/tmp/fake'),
                wheel_hash='b1bd84059bc0342f7843fcbe04ab563e',
                region=region,
                zones=zones,
                dryrun=True,
                keep_launch_fields_in_existing_config=True)
            mock_fill_template.assert_called_once()
            call_positional_args = mock_fill_template.call_args[0]
            # NOTE: never write `assert (cond, msg)` — asserting a non-empty
            # tuple always passes, which silently disabled these checks.
            assert call_positional_args[0] == cluster_config_template, (
                'config template incorrect')
            assert (call_positional_args[1].items() >=
                    expected_subset.items()), 'config fill values incorrect'

        # Default: cluster name matches no per-cluster entry in the config.
        default_expected = {
            'instance_type': 'fake-type: 3',
            'custom_resources': '{"fake-acc":2}',
            'region': 'fake-region',
            'zones': 'fake-zone',
            'image_id': 'fake-image',
            'security_group': 'fake-default-sg',
            'security_group_managed_by_skypilot': 'true',
            'vpc_name': 'fake-vpc',
            'remote_identity': 'LOCAL_CREDENTIALS',
            'sky_local_path': '/tmp/fake',
            'sky_wheel_hash': 'b1bd84059bc0342f7843fcbe04ab563e',
        }
        _write_config_and_check('display', default_expected)

        # Cluster name matches the top regex entry in the test config.
        _write_config_and_check(
            'sky-serve-fake1-1234', {
                **default_expected,
                'security_group': 'fake-1-sg',
                'security_group_managed_by_skypilot': 'false',
                'remote_identity': 'fake1-skypilot-role',
            })

        # Cluster name matches a middle regex entry in the test config.
        _write_config_and_check(
            'sky-serve-fake2-1234', {
                **default_expected,
                'security_group': 'fake-2-sg',
                'security_group_managed_by_skypilot': 'false',
                'remote_identity': 'fake2-skypilot-role',
            })