- Notifications
You must be signed in to change notification settings - Fork 76
/
Copy pathrefresh_utils.py
155 lines (123 loc) · 5 KB
/
refresh_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import annotations

import asyncio
import copy
import datetime
import logging
import random
from typing import Any, Callable

import aiohttp
from google.auth.credentials import Credentials
from google.auth.credentials import Scoped
import google.auth.transport.requests
logger = logging.getLogger(__name__)

# Seconds before a refresh result expires at which the next refresh begins
# (applied when the result has less than an hour of validity remaining).
_refresh_buffer: int = 4 * 60  # 240 seconds == 4 minutes


def _seconds_until_refresh(
    expiration: datetime.datetime,
) -> int:
    """Return how many seconds to wait before the next refresh should start.

    Certificates with at least an hour of validity left are refreshed at the
    halfway point of their remaining lifetime. Shorter-lived ones are
    refreshed ``_refresh_buffer`` seconds before they expire, or right away
    when less than that remains.

    Args:
        expiration (datetime.datetime): The expiration time of the
            certificate. It is compared against a timezone-aware UTC "now",
            so it must be timezone-aware as well.

    Returns:
        int: Time in seconds to wait before performing the next refresh
            (never negative).
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    remaining = int((expiration - now).total_seconds())

    if remaining >= 3600:
        # Plenty of validity left: wait for half of it.
        return remaining // 2

    # Under an hour left: aim for `_refresh_buffer` seconds before expiry,
    # and refresh immediately if that point has already passed.
    if remaining < _refresh_buffer:
        return 0
    return remaining - _refresh_buffer


async def _is_valid(task: asyncio.Task) -> bool:
    """Return whether ``task`` resolved to metadata that has not yet expired.

    The task is awaited, and the result is valid only while the current UTC
    time is before its ``expiration``. Any ``Exception`` raised while
    awaiting the task or evaluating that comparison is suppressed, logged
    at debug level, and treated as "not valid".

    Args:
        task (asyncio.Task): Task whose result exposes an ``expiration``
            datetime (presumably timezone-aware instance metadata — confirm
            against callers).

    Returns:
        bool: True if the metadata is still unexpired, otherwise False.
    """
    try:
        metadata = await task
        current_time = datetime.datetime.now(datetime.timezone.utc)
        # Usable only while the certificate has not yet expired.
        still_fresh = bool(current_time < metadata.expiration)
    except Exception:
        # Suppress any error from the task; it simply means "not valid".
        logger.debug("Current instance metadata is invalid.")
        return False
    return still_fresh


def _downscope_credentials(
    credentials: Credentials,
    scopes: list[str] | None = None,
) -> Credentials:
    """Generate a down-scoped credential.

    Args:
        credentials (google.auth.credentials.Credentials):
            Credentials object used to generate down-scoped credentials.
        scopes (list[str] | None): Google scopes to include in the
            down-scoped credentials object. Defaults to
            ``["https://www.googleapis.com/auth/sqlservice.login"]``.

    Returns:
        google.auth.credentials.Credentials: Down-scoped credentials object,
            already refreshed so it can be used immediately.
    """
    # Build a fresh list on every call. A list literal as the parameter
    # default would be one object shared by all calls, and in the user
    # credentials branch below it is stored directly on the credentials
    # (``_scopes``), so any later mutation would leak into every future call.
    # Copying a caller-supplied sequence avoids aliasing their list as well.
    if scopes is None:
        scopes = ["https://www.googleapis.com/auth/sqlservice.login"]
    else:
        scopes = list(scopes)

    if isinstance(credentials, Scoped):
        # Credentials sourced from a service account or metadata are
        # children of the Scoped class and are capable of being re-scoped.
        scoped_creds = credentials.with_scopes(scopes=scopes)
    else:
        # Authenticated user credentials can not be re-scoped, so create a
        # shallow copy (to not overwrite scopes on the default credentials)
        # and overwrite '_scopes' to down-scope it.
        # Cloud SDK reference: https://github.com/google-cloud-sdk-unofficial/google-cloud-sdk/blob/93920ccb6d2cce0fe6d1ce841e9e33410551d66b/lib/googlecloudsdk/command_lib/sql/generate_login_token_util.py#L116
        scoped_creds = copy.copy(credentials)
        scoped_creds._scopes = scopes

    # Down-scoped credentials are invalid after being re-scoped and require
    # a refresh before use.
    request = google.auth.transport.requests.Request()
    scoped_creds.refresh(request)
    return scoped_creds


def _exponential_backoff(
    attempt: int, *, base: float = 200, multi: float = 1.618
) -> float:
    """Calculate a duration to back off, in milliseconds, for attempt ``attempt``.

    The formula is:

        base * multi^(attempt + 1 + random)

    With the defaults base = 200ms and multi = 1.618, and random = [0.0, 1.0),
    the backoff values fall between the following low and high ends:

        Attempt  Low (ms)  High (ms)
        0         324        524
        1         524        847
        2         847       1371
        3        1371       2218
        4        2218       3588

    A request whose first four attempts fail and whose fifth succeeds backs
    off only after attempts 0 through 3. With the defaults, its theoretical
    worst case is therefore about 5s of waiting in total
    (524 + 847 + 1371 + 2218 = 4960 ms). Backing off after all five attempts
    (0-4) would add up to about 8.5s.

    Args:
        attempt (int): Zero-based retry attempt number.
        base (float): Base duration in milliseconds. Defaults to 200.
        multi (float): Growth factor per attempt. Defaults to 1.618.

    Returns:
        float: Backoff duration in milliseconds.
    """
    # random.random() in [0, 1) jitters the exponent, so the delay for a
    # given attempt varies between the low and high ends above.
    exp = attempt + 1 + random.random()
    return base * pow(multi, exp)


async def retry_50x(
    request_coro: Callable, *args: Any, **kwargs: Any
) -> aiohttp.ClientResponse:
    """Await ``request_coro(*args, **kwargs)``, retrying on HTTP 5xx responses.

    The request is attempted at most 5 times. After every 5xx response
    except the last one, the coroutine sleeps for an exponentially growing,
    jittered delay (computed by ``_exponential_backoff``) before retrying.

    Args:
        request_coro (Callable): Coroutine function that performs the request.
        *args: Positional arguments forwarded to ``request_coro``.
        **kwargs: Keyword arguments forwarded to ``request_coro``.

    Returns:
        aiohttp.ClientResponse: The first non-5xx response, or the response
            from the final attempt if every attempt returned a 5xx.
    """
    max_retries = 5
    for attempt in range(max_retries):
        resp = await request_coro(*args, **kwargs)
        # Only 5xx responses are retried; anything else is returned as-is.
        if resp.status < 500:
            break
        # No backoff after the final attempt: no retry follows it, so
        # sleeping would only delay returning the response.
        if attempt < max_retries - 1:
            # _exponential_backoff returns milliseconds; asyncio.sleep wants seconds.
            backoff = _exponential_backoff(attempt)
            await asyncio.sleep(backoff / 1000)
    return resp