forked from plotly/plotly.py
- Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
382 lines (318 loc) · 11.6 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
"""
Utility Routines for Working with Matplotlib Objects
====================================================
"""
importitertools
importio
importbase64
importnumpyasnp
importwarnings
importmatplotlib
frommatplotlib.colorsimportcolorConverter
frommatplotlib.pathimportPath
frommatplotlib.markersimportMarkerStyle
frommatplotlib.transformsimportAffine2D
frommatplotlibimportticker
def export_color(color):
    """Convert a matplotlib color code to a hex color or an RGBA color string.

    Parameters
    ----------
    color : matplotlib color spec or None

    Returns
    -------
    str
        ``"none"`` when ``color`` is None or fully transparent,
        ``"#RRGGBB"`` (upper-case hex) when fully opaque, and
        ``"rgba(r, g, b, a)"`` otherwise.
    """
    if color is None:
        return "none"
    # Convert once; the original converted the same color up to three times.
    rgba = colorConverter.to_rgba(color)
    alpha = rgba[3]
    if alpha == 0:
        return "none"
    # Round (rather than truncate) to the nearest 0-255 channel value so the
    # hex and rgba branches agree, as matplotlib.colors.to_hex does.
    channels = [int(np.round(val * 255)) for val in rgba[:3]]
    if alpha == 1:
        return "#{0:02X}{1:02X}{2:02X}".format(*channels)
    return "rgba(" + ", ".join(str(val) for val in channels) + ", " + str(alpha) + ")"
def_many_to_one(input_dict):
"""Convert a many-to-one mapping to a one-to-one mapping"""
returndict((key, val) forkeys, valininput_dict.items() forkeyinkeys)
LINESTYLES=_many_to_one(
{
("solid", "-", (None, None)): "none",
("dashed", "--"): "6,6",
("dotted", ":"): "2,2",
("dashdot", "-."): "4,4,2,4",
("", " ", "None", "none"): None,
}
)
defget_dasharray(obj):
"""Get an SVG dash array for the given matplotlib linestyle
Parameters
----------
obj : matplotlib object
The matplotlib line or path object, which must have a get_linestyle()
method which returns a valid matplotlib line code
Returns
-------
dasharray : string
The HTML/SVG dasharray code associated with the object.
"""
ifobj.__dict__.get("_dashSeq", None) isnotNone:
return",".join(map(str, obj._dashSeq))
else:
ls=obj.get_linestyle()
dasharray=LINESTYLES.get(ls, "not found")
ifdasharray=="not found":
warnings.warn(
"line style '{0}' not understood: "
"defaulting to solid line.".format(ls)
)
dasharray=LINESTYLES["solid"]
returndasharray
# Map matplotlib Path codes to the single-letter path commands emitted by
# SVG_path.
# NOTE(review): the SVG quadratic-Bezier command is "Q", yet CURVE3 maps to
# "S" here; left unchanged because downstream renderers may rely on it.
PATH_DICT = dict(
    [
        (Path.LINETO, "L"),
        (Path.MOVETO, "M"),
        (Path.CURVE3, "S"),
        (Path.CURVE4, "C"),
        (Path.CLOSEPOLY, "Z"),
    ]
)
def SVG_path(path, transform=None, simplify=False):
    """Construct the vertices and SVG codes for the path

    Parameters
    ----------
    path : matplotlib.Path object

    transform : matplotlib transform (optional)
        if specified, the path will be transformed before computing the output.

    simplify : bool (optional)
        passed through to ``path.iter_segments``.

    Returns
    -------
    vertices : array
        The shape (M, 2) array of vertices of the Path. Note that some Path
        codes require multiple vertices, so the length of these vertices may
        be longer than the list of path codes.
    path_codes : list
        A length N list of single-character path codes, N <= M. Each code is
        a single character, in ['L','M','S','C','Z']. See the standard SVG
        path specification for a description of these.
    """
    if transform is not None:
        path = path.transformed(transform)

    segment_vertices = []
    segment_codes = []
    for verts, code in path.iter_segments(simplify=simplify):
        # CLOSEPOLY segments contribute a code but no vertices.
        segment_vertices.append([] if code == Path.CLOSEPOLY else verts)
        segment_codes.append(PATH_DICT[code])

    if not segment_codes:
        # an empty path is a special case
        return np.zeros((0, 2)), []

    flat = list(itertools.chain.from_iterable(segment_vertices))
    return np.array(flat).reshape(-1, 2), segment_codes
def get_path_style(path, fill=True):
    """Get the style dictionary for matplotlib path objects.

    Keys: alpha (1 when the path has none), edgecolor, facecolor ("none"
    when ``fill`` is false), edgewidth, dasharray, zorder.
    """
    alpha = path.get_alpha()
    return {
        "alpha": 1 if alpha is None else alpha,
        "edgecolor": export_color(path.get_edgecolor()),
        "facecolor": export_color(path.get_facecolor()) if fill else "none",
        "edgewidth": path.get_linewidth(),
        "dasharray": get_dasharray(path),
        "zorder": path.get_zorder(),
    }
def get_line_style(line):
    """Get the style dictionary for matplotlib line objects.

    Keys: alpha (1 when the line has none), color, linewidth, dasharray,
    zorder, drawstyle.
    """
    alpha = line.get_alpha()
    return {
        "alpha": 1 if alpha is None else alpha,
        "color": export_color(line.get_color()),
        "linewidth": line.get_linewidth(),
        "dasharray": get_dasharray(line),
        "zorder": line.get_zorder(),
        "drawstyle": line.get_drawstyle(),
    }
def get_marker_style(line):
    """Get the style dictionary for matplotlib marker objects.

    Keys: alpha (1 when the line has none), facecolor, edgecolor, edgewidth,
    marker, markerpath (the ``SVG_path`` output for the scaled marker),
    markersize, zorder.
    """
    alpha = line.get_alpha()
    style = {
        "alpha": 1 if alpha is None else alpha,
        "facecolor": export_color(line.get_markerfacecolor()),
        "edgecolor": export_color(line.get_markeredgecolor()),
        "edgewidth": line.get_markeredgewidth(),
        "marker": line.get_marker(),
    }

    marker = MarkerStyle(line.get_marker())
    size = line.get_markersize()
    # Scale the unit marker by its size; the negative y factor flips it
    # vertically (presumably for a y-down output coordinate system).
    scaled = marker.get_transform() + Affine2D().scale(size, -size)

    style["markerpath"] = SVG_path(marker.get_path(), scaled)
    style["markersize"] = size
    style["zorder"] = line.get_zorder()
    return style
def get_text_style(text):
    """Return the text style dict for a text instance.

    Keys: alpha (1 when the text has none), fontsize, color, halign,
    valign, malign, rotation, zorder.
    """
    alpha = text.get_alpha()
    return {
        "alpha": 1 if alpha is None else alpha,
        "fontsize": text.get_size(),
        "color": export_color(text.get_color()),
        "halign": text.get_horizontalalignment(),  # left, center, right
        "valign": text.get_verticalalignment(),  # baseline, center, top
        # text alignment when '\n' in text (read from a private attribute)
        "malign": text._multialignment,
        "rotation": text.get_rotation(),
        "zorder": text.get_zorder(),
    }
def get_axis_properties(axis):
    """Return the property dictionary for a matplotlib.Axis instance.

    Keys: position, nticks, tickvalues, tickformat, scale, fontsize, grid,
    visible.

    Raises
    ------
    ValueError
        If ``axis`` is neither an XAxis nor a YAxis.
    """
    props = {}
    # label1On selects the bottom (x) / left (y) side; otherwise the opposite.
    label1On = axis._major_tick_kw.get("label1On", True)
    if isinstance(axis, matplotlib.axis.XAxis):
        props["position"] = "bottom" if label1On else "top"
    elif isinstance(axis, matplotlib.axis.YAxis):
        props["position"] = "left" if label1On else "right"
    else:
        raise ValueError("{0} should be an Axis instance".format(axis))

    # Use tick values if appropriate; evaluate the locator only once.
    locator = axis.get_major_locator()
    tick_locs = locator()
    props["nticks"] = len(tick_locs)
    if isinstance(locator, ticker.FixedLocator):
        props["tickvalues"] = list(tick_locs)
    else:
        props["tickvalues"] = None

    # Find tick formats
    props["tickformat"] = _get_tick_format(axis)

    # Get axis scale
    props["scale"] = axis.get_scale()

    # Get major tick label size (assumes that's all we really care about!)
    labels = axis.get_ticklabels()
    props["fontsize"] = labels[0].get_fontsize() if labels else None

    # Get associated grid
    props["grid"] = get_grid_style(axis)

    # get axis visibility
    props["visible"] = axis.get_visible()
    return props


def _get_tick_format(axis):
    """Return "", a list of fixed labels, or None describing the tick format."""
    formatter = axis.get_major_formatter()
    if isinstance(formatter, ticker.NullFormatter):
        return ""
    if isinstance(formatter, ticker.FixedFormatter):
        return list(formatter.seq)
    if isinstance(formatter, ticker.FuncFormatter):
        labels = _func_formatter_labels(formatter)
        if labels is not None:
            return labels
        # Any other FuncFormatter falls through to the generic checks below.
    if not any(label.get_visible() for label in axis.get_ticklabels()):
        return ""
    return None


def _func_formatter_labels(formatter):
    """Return labels held by a partial-wrapped FuncFormatter, else None.

    NOTE(review): this assumes ``formatter.func`` is a functools.partial
    whose first positional argument is a mapping of labels; a plain function
    (the common case) has no ``.args`` and previously raised AttributeError.
    """
    try:
        return list(formatter.func.args[0].values())
    except (AttributeError, IndexError, TypeError):
        return None
def get_grid_style(axis):
    """Return the major-grid style of a matplotlib Axis as a dict.

    Returns ``{"gridOn": False}`` when the major grid is off or there are no
    gridlines; otherwise ``gridOn=True`` plus color, dasharray and alpha
    taken from the first major gridline.
    """
    gridlines = axis.get_gridlines()
    # The grid flag is read from _major_tick_kw; a missing key previously
    # raised KeyError. Fall back to the _gridOnMajor attribute (presumably
    # present in older matplotlib), and finally treat the grid as off.
    grid_on = axis._major_tick_kw.get(
        "gridOn", getattr(axis, "_gridOnMajor", False)
    )
    if grid_on and len(gridlines) > 0:
        first = gridlines[0]
        return dict(
            gridOn=True,
            color=export_color(first.get_color()),
            dasharray=get_dasharray(first),
            alpha=first.get_alpha(),
        )
    return {"gridOn": False}
def get_figure_properties(fig):
    """Return the figure's width, height and dpi as a dict."""
    width = fig.get_figwidth()
    height = fig.get_figheight()
    return dict(figwidth=width, figheight=height, dpi=fig.dpi)
def get_axes_properties(ax):
    """Return the property dictionary for a matplotlib Axes instance.

    Besides the general axes properties, for each of "x" and "y" it stores
    ``<x|y>scale`` ("date", "linear" or "log"), ``<x|y>lim`` (the raw axis
    limits) and ``<x|y>domain`` (the limits, converted to date tuples for
    date axes).

    Raises
    ------
    ValueError
        If an axis scale is not one of "date", "linear" or "log".
    """
    props = {
        "axesbg": export_color(ax.patch.get_facecolor()),
        "axesbgalpha": ax.patch.get_alpha(),
        "bounds": ax.get_position().bounds,
        "dynamic": ax.get_navigate(),
        "axison": ax.axison,
        "frame_on": ax.get_frame_on(),
        "patch_visible": ax.patch.get_visible(),
        "axes": [get_axis_properties(ax.xaxis), get_axis_properties(ax.yaxis)],
    }

    for axname in ["x", "y"]:
        axis = getattr(ax, axname + "axis")
        domain = getattr(ax, "get_{0}lim".format(axname))()
        lim = domain
        if isinstance(axis.converter, matplotlib.dates.DateConverter):
            scale = "date"
            pd, PeriodConverter = _import_pandas_period_converter()
            if pd is not None and isinstance(axis.converter, PeriodConverter):
                # Limits are Period ordinals at the axis' frequency.
                _dates = [pd.Period(ordinal=int(d), freq=axis.freq) for d in domain]
                domain = [
                    (d.year, d.month - 1, d.day, d.hour, d.minute, d.second, 0)
                    for d in _dates
                ]
            else:
                # Month is emitted 0-based and microseconds as milliseconds.
                domain = [
                    (
                        d.year,
                        d.month - 1,
                        d.day,
                        d.hour,
                        d.minute,
                        d.second,
                        d.microsecond * 1e-3,
                    )
                    for d in matplotlib.dates.num2date(domain)
                ]
        else:
            scale = axis.get_scale()

        if scale not in ["date", "linear", "log"]:
            raise ValueError("Unknown axis scale: " "{0}".format(axis.get_scale()))

        props[axname + "scale"] = scale
        props[axname + "lim"] = lim
        props[axname + "domain"] = domain

    return props


def _import_pandas_period_converter():
    """Return ``(pandas, PeriodConverter)``, or ``(None, None)`` if unavailable.

    ``pandas.tseries.converter`` was removed from newer pandas, which made the
    original import always fail there. The converter is then looked up in
    ``pandas.plotting._matplotlib.converter`` (a private module; location
    assumed for newer pandas — verify against the installed version).
    """
    try:
        import pandas as pd
    except ImportError:
        return None, None
    try:
        from pandas.tseries.converter import PeriodConverter
    except ImportError:
        try:
            from pandas.plotting._matplotlib.converter import PeriodConverter
        except ImportError:
            return None, None
    return pd, PeriodConverter
def iter_all_children(obj, skipContainers=False):
    """
    Returns an iterator over all children and nested children using
    obj's get_children() method.

    Each descendant is yielded exactly once, in depth-first order. If
    skipContainers is true, only childless objects are returned. An object
    without children yields just itself.
    """
    children = _get_children(obj)
    if not children:
        yield obj
        return
    yield from _iter_descendants(children, skipContainers)


def _get_children(obj):
    """Return obj.get_children() as a list, or [] if obj has no such method."""
    if hasattr(obj, "get_children"):
        return list(obj.get_children())
    return []


def _iter_descendants(children, skipContainers):
    """Depth-first walk over ``children``, yielding each node at most once."""
    for child in children:
        grandchildren = _get_children(child)
        if grandchildren:
            if not skipContainers:
                yield child
            yield from _iter_descendants(grandchildren, skipContainers)
        else:
            # Leaves are yielded here only; the previous version also
            # yielded them again from the recursive call.
            yield child
def get_legend_properties(ax, legend):
    """Return the axes' legend handles and labels plus the legend's visibility."""
    legend_handles, legend_labels = ax.get_legend_handles_labels()
    return dict(
        handles=legend_handles,
        labels=legend_labels,
        visible=legend.get_visible(),
    )
def image_to_base64(image):
    """
    Convert a matplotlib image to a base64 png representation

    Parameters
    ----------
    image : matplotlib image object
        The image to be converted.

    Returns
    -------
    image_base64 : string
        The UTF8-encoded base64 string representation of the png image.
    """
    ax = image.axes
    binary_buffer = io.BytesIO()

    # image is saved in axes coordinates: we need to temporarily
    # set the correct limits to get the correct image
    lim = ax.axis()
    ax.axis(image.get_extent())
    try:
        _write_image_png(image, binary_buffer)
    finally:
        # Restore the caller's view limits even if writing the PNG fails.
        ax.axis(lim)

    return base64.b64encode(binary_buffer.getvalue()).decode("utf-8")


def _write_image_png(image, buffer):
    """Write ``image``'s pixel data to the binary ``buffer`` as a PNG."""
    if hasattr(image, "write_png"):
        image.write_png(buffer)
        return
    # NOTE(review): AxesImage.write_png is missing from newer matplotlib
    # releases. This fallback is assumed to match what that method did (flip
    # rows for origin="lower", map to RGBA bytes, save as PNG); confirm
    # against the installed matplotlib.
    import matplotlib.image as mpimg

    data = image.get_array()
    if image.origin == "lower":
        data = data[::-1]
    rgba = image.to_rgba(data, bytes=True, norm=True)
    mpimg.imsave(buffer, rgba, format="png", origin="upper")