"""
Axislines includes modified implementation of the Axes class. The
biggest difference is that the artists responsible to draw axis line,
ticks, ticklabel and axis labels are separated out from the mpl's Axis
class, which are much more than artists in the original
mpl. Originally, this change was motivated to support curvlinear
grid. Here are a few reasons that I came up with new axes class.
* "top" and "bottom" x-axis (or "left" and "right" y-axis) can have
different ticks (tick locations and labels). This is not possible
with the current mpl, although some twin axes trick can help.
* Curvelinear grid.
* angled ticks.
In the new axes class, xaxis and yaxis is set to not visible by
default, and new set of artist (AxisArtist) are defined to draw axis
line, ticks, ticklabels and axis label. Axes.axis attribute serves as
a dictionary of these artists, i.e., ax.axis["left"] is a AxisArtist
instance responsible to draw left y-axis. The default Axes.axis contains
"bottom", "left", "top" and "right".
AxisArtist can be considered as a container artist and
has following children artists which will draw ticks, labels, etc.
* line
* major_ticks, major_ticklabels
* minor_ticks, minor_ticklabels
* offsetText
* label
Note that these are separate artists from Axis class of the
original mpl, thus most of tick-related command in the original mpl
won't work, although some effort has made to work with. For example,
color and markerwidth of the ax.axis["bottom"].major_ticks will follow
those of Axes.xaxis unless explicitly specified.
In addition to AxisArtist, the Axes will have *gridlines* attribute,
which obviously draws grid lines. The gridlines needs to be separated
from the axis as some gridlines can never pass any axis.
"""
import matplotlib.axes as maxes
import matplotlib.artist as martist
import matplotlib.text as mtext
import matplotlib.font_manager as font_manager
from matplotlib.path import Path
from matplotlib.transforms import Affine2D,ScaledTranslation,\
IdentityTransform, TransformedPath, Bbox
from matplotlib.collections import LineCollection
from matplotlib import rcParams
from matplotlib.artist import allow_rasterization
import warnings
import numpy as np
import matplotlib.lines as mlines
class BezierPath(mlines.Line2D):
def __init__(self, path, *kl, **kw):
mlines.Line2D.__init__(self, [], [], *kl, **kw)
self._path = path
self._invalid = False
def recache(self):
self._transformed_path = TransformedPath(self._path, self.get_transform())
self._invalid = False
def set_path(self, path):
self._path = path
self._invalid = True
def draw(self, renderer):
if self._invalid:
self.recache()
renderer.open_group('line2d')
if not self._visible: return
gc = renderer.new_gc()
self._set_gc_clip(gc)
gc.set_foreground(self._color)
gc.set_antialiased(self._antialiased)
gc.set_linewidth(self._linewidth)
gc.set_alpha(self._alpha)
if self.is_dashed():
cap = self._dashcapstyle
join = self._dashjoinstyle
else:
cap = self._solidcapstyle
join = self._solidjoinstyle
gc.set_joinstyle(join)
gc.set_capstyle(cap)
funcname = self._lineStyles.get(self._linestyle, '_draw_nothing')
if funcname != '_draw_nothing':
tpath, affine = self._transformed_path.get_transformed_path_and_affine()
lineFunc = getattr(self, funcname)
lineFunc(renderer, gc, tpath, affine.frozen())
gc.restore()
renderer.close_group('line2d')
class UnimplementedException(Exception):
pass
class AxisArtistHelper(object):
"""
AxisArtistHelper should define
following method with given APIs. Note that the first axes argument
will be axes attribute of the caller artist.
# LINE (spinal line?)
def get_line(self, axes):
# path : Path
return path
def get_line_transform(self, axes):
# ...
# trans : transform
return trans
# LABEL
def get_label_pos(self, axes):
# x, y : position
return (x, y), trans
def get_label_offset_transform(self, \
axes,
pad_points, fontprops, renderer,
bboxes,
):
# va : vertical alignment
# ha : horizontal alignment
# a : angle
return trans, va, ha, a
# TICK
def get_tick_transform(self, axes):
return trans
def get_tick_iterators(self, axes):
# iter : iteratoable object that yields (c, angle, l) where
# c, angle, l is position, tick angle, and label
return iter_major, iter_minot
"""
class _Base(object):
"""
Base class for axis helper.
"""
def __init__(self, label_direction):
"""
label direction must be one of
["left", "right", "bottom", "top",
"curved"]
"""
self.label_direction = label_direction
def update_lim(self, axes):
pass
_label_angles = dict(left=90, right=90, bottom=0, top=0)
_ticklabel_angles = dict(left=0, right=0, bottom=0, top=0)
def _get_label_offset_transform(self, pad_points, fontprops,
renderer, bboxes=None):
"""
Returns (offset-transform, vertical-alignment, horiz-alignment)
of (tick or axis) labels appropriate for the label
direction.
The offset-transform represents a required pixel offset
from the reference point. For example, x-axis center will
be the referece point for xlabel.
pad_points : padding from axis line or tick labels (see bboxes)
fontprops : font properties for label
renderer : renderer
bboxes=None : list of bboxes (window extents) of the tick labels.
This only make sense for axis label.
all the above parameters are used to estimate the offset.
"""
if renderer:
pad_pixels = renderer.points_to_pixels(pad_points)
font_size_points = fontprops.get_size_in_points()
font_size_pixels = renderer.points_to_pixels(font_size_points)
else:
pad_pixels = pad_points
font_size_points = fontprops.get_size_in_points()
font_size_pixels = font_size_points
if bboxes:
bbox = Bbox.union(bboxes)
w, h = bbox.width, bbox.height
else:
w, h = 0, 0
tr = Affine2D()
if self.label_direction == "left":
tr.translate(-(pad_pixels+w), 0.)
return tr, "center", "right"
elif self.label_direction == "right":
tr.translate(+(pad_pixels+w), 0.)
return tr, "center", "left"
elif self.label_direction == "bottom":
tr.translate(0, -(pad_pixels+font_size_pixels+h))
return tr, "baseline", "center"
elif self.label_direction == "top":
tr.translate(0, +(pad_pixels+h))
return tr, "baseline", "center"
elif self.label_direction == "curved":
#tr.translate(0, +(pad_pixels+h))
return tr, "baseline", "center"
else:
raise ValueError("Unknown label direction : %s" \
% (self.label_direction,))
def get_label_offset_transform(self, axes,
pad_points, fontprops, renderer,
bboxes,
):
"""
offset transform for axis label.
"""
tr, va, ha = self._get_label_offset_transform( \
pad_points, fontprops, renderer, bboxes)
a = self._label_angles[self.label_direction]
return tr, va, ha, a
def get_ticklabel_offset_transform(self, axes,
pad_points, fontprops,
renderer,
):
"""
offset transform for ticklabels.
"""
tr, va, ha = self._get_label_offset_transform( \
pad_points, fontprops, renderer, None)
a = self._ticklabel_angles[self.label_direction]
return tr, va, ha, a
class Fixed(_Base):
"""
Helper class for a fixed (in the axes coordinate) axis.
"""
_default_passthru_pt = dict(left=(0, 0),
right=(1, 0),
bottom=(0, 0),
top=(0, 1))
def __init__(self,
loc,
label_direction=None):
"""
nth_coord = along which coordinate value varies
in 2d, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
"""
if loc not in ["left", "right", "bottom", "top"]:
raise ValueError("%s" % loc)
#if nth_coord is None:
if loc in ["left", "right"]:
nth_coord = 1
elif loc in ["bottom", "top"]:
nth_coord = 0
self.nth_coord = nth_coord
super(AxisArtistHelper.Fixed, self).__init__(loc)
self.passthru_pt = self._default_passthru_pt[loc]
if label_direction is None:
label_direction = loc
_verts = np.array([[0., 0.],
[1., 1.]])
fixed_coord = 1-nth_coord
_verts[:,fixed_coord] = self.passthru_pt[fixed_coord]
# axis line in transAxes
self._path = Path(_verts)
def get_nth_coord(self):
return self.nth_coord
# LINE
def get_line(self, axes):
return self._path
def get_line_transform(self, axes):
return axes.transAxes
# LABEL
def get_label_pos(self, axes):
"""
label reference position in transAxes.
get_label_transform() returns a transform of (transAxes+offset)
"""
_verts = [0.5, 0.5]
nth_coord = self.nth_coord
fixed_coord = 1-nth_coord
_verts[fixed_coord] = self.passthru_pt[fixed_coord]
return _verts, axes.transAxes
def get_label_offset_transform(self, axes,
pad_points, fontprops, renderer,
bboxes,
):
tr, va, ha = self._get_label_offset_transform( \
pad_points, fontprops, renderer, bboxes,
)
a = self._label_angles[self.label_direction]
return tr, va, ha, a
# TICK
def get_tick_transform(self, axes):
trans_tick = [axes.get_xaxis_transform(),
axes.get_yaxis_transform()][self.nth_coord]
return trans_tick
class Floating(_Base):
def __init__(self, nth_coord,
value, label_direction):
self.nth_coord = nth_coord
self._value = value
super(AxisArtistHelper.Floating,
self).__init__(label_direction)
def get_nth_coord(self):
return self.nth_coord
def get_line(self, axes):
raise RuntimeError("get_line method should be defined by the derived class")
class AxisArtistHelperRectlinear:
class Fixed(AxisArtistHelper.Fixed):
def __init__(self,
axes, loc, #nth_coord=None,
label_direction=None):
"""
nth_coord = along which coordinate value varies
in 2d, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
"""
super(AxisArtistHelperRectlinear.Fixed, self).__init__( \
loc, label_direction)
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
# TICK
def get_tick_iterators(self, axes):
"""tick_loc, tick_angle, tick_label"""
angle = 0 - 90 * self.nth_coord
if self.passthru_pt[1 - self.nth_coord] > 0.5:
angle = 180+angle
# take care the tick direction
if self.nth_coord == 0 and rcParams["xtick.direction"] == "out":
angle += 180
elif self.nth_coord == 1 and rcParams["ytick.direction"] == "out":
angle += 180
major = self.axis.major
majorLocs = major.locator()
major.formatter.set_locs(majorLocs)
majorLabels = [major.formatter(val, i) for i, val in enumerate(majorLocs)]
minor = self.axis.minor
minorLocs = minor.locator()
minor.formatter.set_locs(minorLocs)
minorLabels = [minor.formatter(val, i) for i, val in enumerate(minorLocs)]
trans_tick = self.get_tick_transform(axes)
tr2ax = trans_tick + axes.transAxes.inverted()
def _f(locs, labels):
for x, l in zip(locs, labels):
c = list(self.passthru_pt) # copy
c[self.nth_coord] = x
# check if the tick point is inside axes
c2 = tr2ax.transform_point(c)
delta=0.00001
if 0. -delta<= c2[self.nth_coord] <= 1.+delta:
yield c, angle, l
return _f(majorLocs, majorLabels), _f(minorLocs, minorLabels)
class Floating(AxisArtistHelper.Floating):
def __init__(self, axes, nth_coord,
passingthrough_point, label_direction):
super(AxisArtistHelperRectlinear.Floating, self).__init__( \
nth_coord, passingthrough_point, label_direction)
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
def get_line(self, axes):
_verts = np.array([[0., 0.],
[1., 1.]])
fixed_coord = 1-self.nth_coord
trans_passingthrough_point = axes.transData + axes.transAxes.inverted()
p = trans_passingthrough_point.transform_point([self._value,
self._value])
_verts[:,fixed_coord] = p[fixed_coord]
return Path(_verts)
def get_line_transform(self, axes):
return axes.transAxes
def get_label_pos(self, axes):
_verts = [0.5, 0.5]
fixed_coord = 1-self.nth_coord
trans_passingthrough_point = axes.transData + axes.transAxes.inverted()
p = trans_passingthrough_point.transform_point([self._value,
self._value])
_verts[fixed_coord] = p[fixed_coord]
if not (0. <= _verts[fixed_coord] <= 1.):
return None, None
else:
return _verts, axes.transAxes
def get_label_transform(self, axes,
pad_points, fontprops, renderer,
bboxes,
):
tr, va, ha = self._get_label_offset_transform(pad_points, fontprops,
renderer, bboxes)
a = self._label_angles[self.label_direction]
tr = axes.transAxes + tr
return tr, va, ha, a
def get_tick_transform(self, axes):
return axes.transData
def get_tick_iterators(self, axes):
"""tick_loc, tick_angle, tick_label"""
angle = 0 - 90 * self.nth_coord
major = self.axis.major
majorLocs = major.locator()
major.formatter.set_locs(majorLocs)
majorLabels = [major.formatter(val, i) for i, val in enumerate(majorLocs)]
minor = self.axis.minor
minorLocs = minor.locator()
minor.formatter.set_locs(minorLocs)
minorLabels = [minor.formatter(val, i) for i, val in enumerate(minorLocs)]
tr2ax = axes.transData + axes.transAxes.inverted()
def _f(locs, labels):
for x, l in zip(locs, labels):
c = [self._value, self._value]
c[self.nth_coord] = x
c1, c2 = tr2ax.transform_point(c)
if 0. <= c1 <= 1. and 0. <= c2 <= 1.:
yield c, angle, l
return _f(majorLocs, majorLabels), _f(minorLocs, minorLabels)
class GridHelperBase(object):
def __init__(self):
self._force_update = True
self._old_limits = None
super(GridHelperBase, self).__init__()
def update_lim(self, axes):
x1, x2 = axes.get_xlim()
y1, y2 = axes.get_ylim()
if self._force_update or self._old_limits != (x1, x2, y1, y2):
self._update(x1, x2, y1, y2)
self._force_update = False
self._old_limits = (x1, x2, y1, y2)
def _update(self, x1, x2, y1, y2):
pass
def invalidate(self):
self._force_update = True
def valid(self):
return not self._force_update
def get_gridlines(self):
return []
class GridHelperRectlinear(GridHelperBase):
def __init__(self, axes):
super(GridHelperRectlinear, self).__init__()
self.axes = axes
def new_fixed_axis(self, loc,
nth_coord=None,
tick_direction="in",
label_direction=None,
offset=None,
axes=None,
):
if axes is None:
warnings.warn("'new_fixed_axis' explicitly requires the axes keyword.")
axes = self.axes
_helper = AxisArtistHelperRectlinear.Fixed(axes, loc, nth_coord)
axisline = AxisArtist(axes, _helper, offset=offset)
return axisline
def new_floating_axis(self, nth_coord, value,
tick_direction="in",
label_direction=None,
axes=None,
):
if axes is None:
warnings.warn("'new_floating_axis' explicitly requires the axes keyword.")
axes = self.axes
passthrough_point = (value, value)
transform = axes.transData
_helper = AxisArtistHelperRectlinear.Floating( \
axes, nth_coord, value, label_direction)
axisline = AxisArtist(axes, _helper)
axisline.line.set_clip_on(True)
axisline.line.set_clip_box(axisline.axes.bbox)
return axisline
from matplotlib.lines import Line2D
class Ticks(Line2D):
def __init__(self, ticksize, tick_out=False, **kwargs):
"""
ticksize : ticksize
tick_out : tick is directed outside (rotated by 180 degree) if True. default is False.
"""
self.ticksize = ticksize
self.locs_angles = []
self.set_tick_out(tick_out)
self._axis = kwargs.pop("axis", None)
if self._axis is not None:
if "color" not in kwargs:
kwargs["color"] = "auto"
if ("mew" not in kwargs) and ("markeredgewidth" not in kwargs):
kwargs["markeredgewidth"] = "auto"
super(Ticks, self).__init__([0.], [0.], **kwargs)
self.set_snap(True)
def set_tick_out(self, b):
"""
set True if tick need to be rotated by 180 degree.
"""
self._tick_out = b
def get_tick_out(self):
"""
Return True if the tick will be rotated by 180 degree.
"""
return self._tick_out
def get_color(self):
if self._color == 'auto':
if self._axis is not None:
ticklines = self._axis.get_ticklines()
if ticklines:
color_from_axis = ticklines[0].get_color()
return color_from_axis
return "k"
return super(Ticks, self).get_color()
def get_markeredgecolor(self):
if self._markeredgecolor == 'auto':
return self.get_color()
else:
return self._markeredgecolor
def get_markeredgewidth(self):
if self._markeredgewidth == 'auto':
if self._axis is not None:
ticklines = self._axis.get_ticklines()
if ticklines:
width_from_axis = ticklines[0].get_markeredgewidth()
return width_from_axis
return .5
else:
return self._markeredgewidth
def update_ticks(self, locs_angles_labels, renderer):
self.locs_angles_labels = locs_angles_labels
_tickvert_path = Path([[0., 0.], [0., 1.]])
def draw(self, renderer):
if not self.get_visible():
return
size = self.ticksize
path_trans = self.get_transform()
# set gc : copied from lines.py
# gc = renderer.new_gc()
# self._set_gc_clip(gc)
# gc.set_foreground(self.get_color())
# gc.set_antialiased(self._antialiased)
# gc.set_linewidth(self._linewidth)
# gc.set_alpha(self._alpha)
# if self.is_dashed():
# cap = self._dashcapstyle
# join = self._dashjoinstyle
# else:
# cap = self._solidcapstyle
# join = self._solidjoinstyle
# gc.set_joinstyle(join)
# gc.set_capstyle(cap)
# gc.set_snap(self.get_snap())
gc = renderer.new_gc()
self._set_gc_clip(gc)
gc.set_foreground(self.get_markeredgecolor())
gc.set_linewidth(self.get_markeredgewidth())
gc.set_alpha(self._alpha)
offset = renderer.points_to_pixels(size)
marker_scale = Affine2D().scale(offset, offset)
tick_out = self.get_tick_out()
for loc, angle, _ in self.locs_angles_labels:
if tick_out:
angle += 180
marker_rotation = Affine2D().rotate_deg(angle)
#marker_rotation.clear().rotate_deg(angle)
marker_transform = marker_scale + marker_rotation
locs = path_trans.transform_non_affine(np.array([loc, loc]))
renderer.draw_markers(gc, self._tickvert_path, marker_transform,
Path(locs), path_trans.get_affine())
gc.restore()
class TickLabels(mtext.Text):
def __init__(self, size, **kwargs):
self.locs_angles_labels = []
self._axis = kwargs.pop("axis", None)
if self._axis is not None:
if "color" not in kwargs:
kwargs["color"] = "auto"
super(TickLabels, self).__init__(x=0., y=0., text="",
**kwargs
)
self._rotate_ticklabel = None
def set_rotate_along_line(self, b):
self._rotate_ticklabel = b
def get_rotate_along_line(self):
return self._rotate_ticklabel
def update_ticks(self, locs_angles_labels, renderer):
self.locs_angles_labels = locs_angles_labels
def get_color(self):
if self._color == 'auto':
if self._axis is not None:
ticklabels = self._axis.get_ticklabels()
if ticklabels:
color_from_axis = ticklabels[0].get_color()
return color_from_axis
return "k"
return super(TickLabels, self).get_color()
def draw(self, renderer):
if not self.get_visible(): return
if self.get_rotate_along_line():
# curved axis
# save original and adjust some properties
tr = self.get_transform()
rm = self.get_rotation_mode()
self.set_rotation_mode("anchor")
offset_tr = Affine2D()
self.set_transform(tr+offset_tr)
# estimate pad
dd = 5 + renderer.points_to_pixels(self.get_size())
for (x, y), a, l in self.locs_angles_labels:
theta = (a+90.)/180.*np.pi
dx, dy = dd * np.cos(theta), dd * np.sin(theta)
offset_tr.translate(dx, dy)
self.set_rotation(a-180)
self.set_x(x)
self.set_y(y)
self.set_text(l)
super(TickLabels, self).draw(renderer)
offset_tr.clear()
# restore original properties
self.set_transform(tr)
self.set_rotation_mode(rm)
else:
for (x, y), a, l in self.locs_angles_labels:
self.set_x(x)
self.set_y(y)
self.set_text(l)
super(TickLabels, self).draw(renderer)
def get_window_extents(self, renderer):
bboxes = []
for (x, y), a, l in self.locs_angles_labels:
self.set_x(x)
self.set_y(y)
self.set_text(l)
bboxes.append(self.get_window_extent())
return [b for b in bboxes if b.width!=0 or b.height!=0]
class AxisLabel(mtext.Text):
def __init__(self, *kl, **kwargs):
self._axis = kwargs.pop("axis", None)
if self._axis is not None:
if "color" not in kwargs:
kwargs["color"] = "auto"
super(AxisLabel, self).__init__(*kl, **kwargs)
def get_color(self):
if self._color == 'auto':
if self._axis is not None:
label = self._axis.get_label()
if label:
color_from_axis = label.get_color()
return color_from_axis
return "k"
return super(AxisLabel, self).get_color()
def get_text(self):
t = super(AxisLabel, self).get_text()
if t == "__from_axes__":
return self._axis.get_label().get_text()
return self._text
class GridlinesCollection(LineCollection):
def __init__(self, *kl, **kwargs):
super(GridlinesCollection, self).__init__(*kl, **kwargs)
self.set_grid_helper(None)
def set_grid_helper(self, grid_helper):
self._grid_helper = grid_helper
def draw(self, renderer):
if self._grid_helper is not None:
self._grid_helper.update_lim(self.axes)
gl = self._grid_helper.get_gridlines()
if gl:
self.set_segments([np.transpose(l) for l in gl])
else:
self.set_segments([])
super(GridlinesCollection, self).draw(renderer)
class AxisArtist(martist.Artist):
"""
an artist which draws axis (a line along which the n-th axes coord
is constant) line, ticks, ticklabels, and axis label.
It requires an AxisArtistHelper instance.
"""
LABELPAD = 5
ZORDER=2.5
def __init__(self, axes,
helper,
offset=None,
major_tick_size=None,
major_tick_pad=None,
minor_tick_size=None,
minor_tick_pad=None,
**kw):
"""
axes is also used to follow the axis attribute (tick color, etc).
"""
super(AxisArtist, self).__init__(**kw)
self.axes = axes
self._axis_artist_helper = helper
if offset is None:
offset = (0, 0)
self.dpi_transform = Affine2D()
self.offset_transform = ScaledTranslation(offset[0], offset[1],
self.dpi_transform)
self._label_visible = True
self._majortick_visible = True
self._majorticklabel_visible = True
self._minortick_visible = True
self._minorticklabel_visible = True
if self._axis_artist_helper.label_direction in ["left", "right"]:
axis_name = "ytick"
self.axis = axes.yaxis
else:
axis_name = "xtick"
self.axis = axes.xaxis
if major_tick_size is None:
self.major_tick_size = rcParams['%s.major.size'%axis_name]
if major_tick_pad is None:
self.major_tick_pad = rcParams['%s.major.pad'%axis_name]
if minor_tick_size is None:
self.minor_tick_size = rcParams['%s.minor.size'%axis_name]
if minor_tick_pad is None:
self.minor_tick_pad = rcParams['%s.minor.pad'%axis_name]
self._init_line()
self._init_ticks()
self._init_offsetText(self._axis_artist_helper.label_direction)
self._init_label()
self.set_zorder(self.ZORDER)
self._rotate_label_along_line = False
def set_rotate_label_along_line(self, b):
self._rotate_label_along_line = b
def get_rotate_label_along_line(self):
return self._rotate_label_along_line
def get_transform(self):
return self.axes.transAxes + self.offset_transform
def get_helper(self):
return self._axis_artist_helper
def _init_line(self):
tran = self._axis_artist_helper.get_line_transform(self.axes) \
+ self.offset_transform
self.line = BezierPath(self._axis_artist_helper.get_line(self.axes),
color=rcParams['axes.edgecolor'],
linewidth=rcParams['axes.linewidth'],
transform=tran)
def _draw_line(self, renderer):
self.line.set_path(self._axis_artist_helper.get_line(self.axes))
self.line.draw(renderer)
def _init_ticks(self):
transform=self._axis_artist_helper.get_tick_transform(self.axes) \
+ self.offset_transform
self.major_ticks = Ticks(self.major_tick_size,
axis=self.axis,
transform=transform)
self.minor_ticks = Ticks(self.minor_tick_size,
axis=self.axis,
transform=transform)
size = rcParams['xtick.labelsize']
fontprops = font_manager.FontProperties(size=size)
tvhl = self._axis_artist_helper.get_ticklabel_offset_transform( \
self.axes, self.major_tick_pad,
fontprops=fontprops, renderer=None)
trans, vert, horiz, label_a = tvhl
trans = transform + trans
# ignore ticklabel angle during the drawing time (but respect
# during init). Instead, use angle set by the TickLabel
# artist.
self.major_ticklabels = TickLabels(size, axis=self.axis)
self.minor_ticklabels = TickLabels(size, axis=self.axis)
self.major_ticklabels.set(figure = self.axes.figure,
transform=trans,
va=vert,
ha=horiz,
fontproperties=fontprops)
self.minor_ticklabels.set(figure = self.axes.figure,
transform=trans,
va=vert,
ha=horiz,
fontproperties=fontprops)
_offsetText_pos = dict(left=(0, 1, "bottom", "right"),
right=(1, 1, "bottom", "left"),
bottom=(1, 0, "top", "right"),
top=(1, 1, "bottom", "right"))
def _init_offsetText(self, direction):
x,y,va,ha = self._offsetText_pos[direction]
self.offsetText = mtext.Annotation("",
xy=(x,y), xycoords="axes fraction",
xytext=(0,0), textcoords="offset points",
#fontproperties = fp,
color = rcParams['xtick.color'],
verticalalignment=va,
horizontalalignment=ha,
)
self.offsetText.set_transform(IdentityTransform())
self.axes._set_artist_props(self.offsetText)
def _update_offsetText(self):
self.offsetText.set_text( self.axis.major.formatter.get_offset() )
self.offsetText.set_size(self.major_ticklabels.get_size())
offset = self.major_tick_pad + self.major_ticklabels.get_size() + 2.
self.offsetText.xytext= (0, offset)
def _draw_offsetText(self, renderer):
self._update_offsetText()
self.offsetText.draw(renderer)
def _draw_ticks(self, renderer):
majortick_iter, minortick_iter = \
self._axis_artist_helper.get_tick_iterators(self.axes)
tick_loc_angle_label = list(majortick_iter)
transform=self._axis_artist_helper.get_tick_transform(self.axes) \
+ self.offset_transform
fontprops = font_manager.FontProperties(size=12)
tvhl = self._axis_artist_helper.get_ticklabel_offset_transform( \
self.axes,
self.major_tick_pad,
fontprops=fontprops,
renderer=renderer,
)
trans, va, ha, a = tvhl
trans = transform + trans
# ignore va, ha, angle during the drawing time
self.major_ticklabels.set_transform(trans)
self.major_ticks.update_ticks(tick_loc_angle_label, renderer)
self.major_ticklabels.update_ticks(tick_loc_angle_label, renderer)
self.major_ticks.draw(renderer)
self.major_ticklabels.draw(renderer)
tick_loc_angle_label = list(minortick_iter)
self.minor_ticks.update_ticks(tick_loc_angle_label, renderer)
self.minor_ticklabels.update_ticks(tick_loc_angle_label, renderer)
self.minor_ticks.draw(renderer)
self.minor_ticklabels.draw(renderer)
if (self.major_ticklabels.get_visible() or self.minor_ticklabels.get_visible()):
self._draw_offsetText(renderer)
return self.major_ticklabels.get_window_extents(renderer)
def _init_label(self):
# x in axes coords, y in display coords (to be updated at draw
# time by _update_label_positions)
fontprops = font_manager.FontProperties(size=rcParams['axes.labelsize'])
textprops = dict(fontproperties = fontprops,
color = rcParams['axes.labelcolor'],
)
self.label = AxisLabel(0, 0, "__from_axes__",
color = "auto", #rcParams['axes.labelcolor'],
fontproperties=fontprops,
axis=self.axis,
)
self.label.set_figure(self.axes.figure)
def _draw_label(self, renderer, bboxes):
if not self.label.get_visible():
return
fontprops = font_manager.FontProperties(size=rcParams['axes.labelsize'])
pad_points = self.major_tick_pad
if self.get_rotate_label_along_line():
xy, tr, label_a = self._axis_artist_helper.get_label_pos( \
self.axes, with_angle=True)
if xy is None: return
x, y = xy
offset_tr = Affine2D()
if self.major_ticklabels.get_visible():
dd = renderer.points_to_pixels(self.major_ticklabels.get_size() \
+ pad_points + 2*self.LABELPAD )
else:
dd = renderer.points_to_pixels(pad_points + 2*self.LABELPAD)
theta = label_a - 0.5 * np.pi #(label_a)/180.*np.pi
dx, dy = dd * np.cos(theta), dd * np.sin(theta)
offset_tr.translate(dx, dy)
tr2 = (tr+offset_tr) #+ tr2
self.label.set(x=x, y=y,
rotation_mode="anchor",
transform=tr2,
va="center", ha="center",
rotation=label_a/np.pi*180.)
else:
xy, tr = self._axis_artist_helper.get_label_pos(self.axes)
if xy is None: return
x, y = xy
tr2, va, ha, a = self._axis_artist_helper.get_label_offset_transform(\
self.axes,
pad_points+2*self.LABELPAD, fontprops,
renderer,
bboxes=bboxes,
)
tr2 = (tr+self.offset_transform) + tr2
self.label.set(x=x, y=y,
transform=tr2,
va=va, ha=ha, rotation=a)
self.label.draw(renderer)
def set_label(self, s):
self.label.set_text(s)
@allow_rasterization
def draw(self, renderer):
'Draw the axis lines, tick lines and labels'
if not self.get_visible(): return
renderer.open_group(__name__)
self._axis_artist_helper.update_lim(self.axes)
dpi_cor = renderer.points_to_pixels(1.)
self.dpi_transform.clear().scale(dpi_cor, dpi_cor)
self._draw_line(renderer)
bboxes = self._draw_ticks(renderer)
#self._draw_offsetText(renderer)
self._draw_label(renderer, bboxes)
renderer.close_group(__name__)
def get_ticklabel_extents(self, renderer):
pass
def toggle(self, all=None, ticks=None, ticklabels=None, label=None):
if all:
_ticks, _ticklabels, _label = True, True, True
elif all is not None:
_ticks, _ticklabels, _label = False, False, False
else:
_ticks, _ticklabels, _label = None, None, None
if ticks is not None:
_ticks = ticks
if ticklabels is not None:
_ticklabels = ticklabels
if label is not None:
_label = label
if _ticks is not None:
self.major_ticks.set_visible(_ticks)
self.minor_ticks.set_visible(_ticks)
if _ticklabels is not None:
self.major_ticklabels.set_visible(_ticklabels)
self.minor_ticklabels.set_visible(_ticklabels)
if _label is not None:
self.label.set_visible(_label)
class Axes(maxes.Axes):
class AxisDict(dict):
def __init__(self, axes):
self.axes = axes
super(Axes.AxisDict, self).__init__()
def __call__(self, *v, **kwargs):
return maxes.Axes.axis(self.axes, *v, **kwargs)
def __init__(self, *kl, **kw):
helper = kw.pop("grid_helper", None)
if helper:
self._grid_helper = helper
else:
self._grid_helper = GridHelperRectlinear(self)
self._axisline_on = True
super(Axes, self).__init__(*kl, **kw)
self.toggle_axisline(True)
def toggle_axisline(self, b=None):
if b is None:
b = not self._axisline_on
if b:
self._axisline_on = True
for s in self.spines.values():
s.set_visible(False)
self.xaxis.set_visible(False)
self.yaxis.set_visible(False)
else:
self._axisline_on = False
for s in self.spines.values():
s.set_visible(True)
self.xaxis.set_visible(True)
self.yaxis.set_visible(True)
def _init_axis(self):
super(Axes, self)._init_axis()
def _init_axis_artists(self):
self._axislines = self.AxisDict(self)
new_fixed_axis = self.get_grid_helper().new_fixed_axis
for loc in ["bottom", "top", "left", "right"]:
self._axislines[loc] = new_fixed_axis(loc=loc, axes=self)
for axisline in [self._axislines["top"], self._axislines["right"]]:
axisline.label.set_visible(False)
axisline.major_ticklabels.set_visible(False)
axisline.minor_ticklabels.set_visible(False)
def _get_axislines(self):
return self._axislines
axis = property(_get_axislines)
def _init_gridlines(self, grid_helper=None):
gridlines = GridlinesCollection(None, transform=self.transData,
colors=rcParams['grid.color'],
linestyles=rcParams['grid.linestyle'],
linewidths=rcParams['grid.linewidth'])
self._set_artist_props(gridlines)
if grid_helper is None:
grid_helper = self.get_grid_helper()
gridlines.set_grid_helper(grid_helper)
gridlines.set_clip_on(True)
self.gridlines = gridlines
def cla(self):
# gridlines need to b created before cla() since cla calls grid()
self._init_gridlines()
super(Axes, self).cla()
self._init_axis_artists()
def get_grid_helper(self):
return self._grid_helper
def grid(self, b=None, **kwargs):
if not self._axisline_on:
super(Axes, self).grid(b, **kwargs)
return
if b is None:
b = not self.gridlines.get_visible()
self.gridlines.set_visible(b)
if len(kwargs):
martist.setp(self.gridlines, **kwargs)
def get_children(self):
if self._axisline_on:
children = self._axislines.values()+[self.gridlines]
else:
children = []
children.extend(super(Axes, self).get_children())
return children
def invalidate_grid_helper(self):
self._grid_helper.invalidate()
def new_floating_axis(self, nth_coord, value,
tick_direction="in",
label_direction=None,
):
gh = self.get_grid_helper()
axis = gh.new_floating_axis(nth_coord, value,
tick_direction=tick_direction,
label_direction=label_direction,
axes=self)
return axis
def draw(self, renderer, inframe=False):
if not self._axisline_on:
super(Axes, self).draw(renderer, inframe)
return
orig_artists = self.artists
self.artists = self.artists + list(self._axislines.values()) + [self.gridlines]
super(Axes, self).draw(renderer, inframe)
self.artists = orig_artists
def get_tightbbox(self, renderer):
bb0 = super(Axes, self).get_tightbbox(renderer)
if not self._axisline_on:
return bb0
bb = [bb0]
for axisline in self._axislines.values():
if not axisline.get_visible():
continue
if axisline.label.get_visible():
bb.append(axisline.label.get_window_extent(renderer))
if axisline.major_ticklabels.get_visible():
bb.extend(axisline.major_ticklabels.get_window_extents(renderer))
if axisline.minor_ticklabels.get_visible():
bb.extend(axisline.minor_ticklabels.get_window_extents(renderer))
if axisline.major_ticklabels.get_visible() or \
axisline.minor_ticklabels.get_visible():
bb.append(axisline.offsetText.get_window_extent(renderer))
#bb.extend([c.get_window_extent(renderer) for c in artists \
# if c.get_visible()])
_bbox = Bbox.union([b for b in bb if b.width!=0 or b.height!=0])
return _bbox
Subplot = maxes.subplot_class_factory(Axes)
class AxesZero(Axes):
def __init__(self, *kl, **kw):
super(AxesZero, self).__init__(*kl, **kw)
def _init_axis_artists(self):
super(AxesZero, self)._init_axis_artists()
new_floating_axis = self._grid_helper.new_floating_axis
xaxis_zero = new_floating_axis(nth_coord=0,
value=0.,
tick_direction="in",
label_direction="bottom",
axes=self)
xaxis_zero.line.set_clip_path(self.patch)
xaxis_zero.set_visible(False)
self._axislines["xzero"] = xaxis_zero
yaxis_zero = new_floating_axis(nth_coord=1,
value=0.,
tick_direction="in",
label_direction="left",
axes=self)
yaxis_zero.line.set_clip_path(self.patch)
yaxis_zero.set_visible(False)
self._axislines["yzero"] = yaxis_zero
SubplotZero = maxes.subplot_class_factory(AxesZero)
if __name__ == "__main__":
import matplotlib.pyplot as plt
fig = plt.figure(1, (4,3))
ax = SubplotZero(fig, 1, 1, 1)
fig.add_subplot(ax)
ax.axis["xzero"].set_visible(True)
ax.axis["xzero"].label.set_text("Axis Zero")
for n in ["bottom", "top", "right"]:
ax.axis[n].set_visible(False)
xx = np.arange(0, 2*np.pi, 0.01)
ax.plot(xx, np.sin(xx))
ax.set_ylabel("Test")
plt.draw()
plt.show()
|