From a983812217220f4537bad27742567661fb0cadd1 Mon Sep 17 00:00:00 2001 From: pbrod Date: Fri, 30 Dec 2016 04:03:07 +0100 Subject: [PATCH] Simplified further --- wafo/containers.py | 49 +++++++++++++++++++++++----------------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/wafo/containers.py b/wafo/containers.py index f86f050..c203eba 100644 --- a/wafo/containers.py +++ b/wafo/containers.py @@ -208,19 +208,17 @@ class PlotData(object): return res def _plot_children(self, axis, plotflag, kwds): - tmp = None - if not plotflag and self.children is not None: - axis.hold('on') - tmp = [] - child_args = kwds.pop('plot_args_children', - tuple(self.plot_args_children)) - child_kwds = dict(self.plot_kwds_children).copy() - child_kwds.update(kwds.pop('plot_kwds_children', {})) - child_kwds['axis'] = axis - for child in self.children: - tmp1 = child.plot(*child_args, **child_kwds) - if tmp1 is not None: - tmp.append(tmp1) + axis.hold('on') + tmp = [] + child_args = kwds.pop('plot_args_children', + tuple(self.plot_args_children)) + child_kwds = dict(self.plot_kwds_children).copy() + child_kwds.update(kwds.pop('plot_kwds_children', {})) + child_kwds['axis'] = axis + for child in self.children: + tmp1 = child.plot(*child_args, **child_kwds) + if tmp1 is not None: + tmp.append(tmp1) if tmp: return tmp return None @@ -232,7 +230,9 @@ class PlotData(object): axis = plt.gca() default_plotflag = self.plot_kwds.get('plotflag') plotflag = kwds.get('plotflag', default_plotflag) - tmp = self._plot_children(axis, plotflag, kwds) + tmp = None + if not plotflag and self.children is not None: + tmp = self._plot_children(axis, plotflag, kwds) main_args = args if len(args) else tuple(self.plot_args) main_kwds = dict(self.plot_kwds).copy() main_kwds.update(kwds) @@ -310,6 +310,7 @@ class AxisLabels: def labelfig(self, axis=None): if axis is None: axis = plt.gca() + try: return self._labelfig(axis) except Exception as err: @@ -377,14 +378,15 @@ class Plotter_1d(object): def set_axis(axis, f_max, trans_flag, log_scale): - ax = list(axis.axis()) - if trans_flag == 8 and not log_scale: - ax[3] = 11 * np.log10(f_max) - ax[2] = ax[3] - 40 - else: - ax[3] = 1.15 * f_max - ax[2] = ax[3] * 1e-4 - axis.axis(ax) + if log_scale or (trans_flag == 5 and not log_scale): + ax = list(axis.axis()) + if trans_flag == 8 and not log_scale: + ax[3] = 11 * np.log10(f_max) + ax[2] = ax[3] - 40 + else: + ax[3] = 1.15 * f_max + ax[2] = ax[3] * 1e-4 + axis.axis(ax) def set_plot_scale(axis, f_max, plotflag): @@ -397,8 +399,7 @@ def set_plot_scale(axis, f_max, plotflag): axis.set(**opt) trans_flag = np.mod(plotflag // 10, 10) - if log_scale or (trans_flag == 5 and not log_scale): - set_axis(axis, f_max, trans_flag, log_scale) + set_axis(axis, f_max, trans_flag, log_scale) def plot1d(axis, args, data, dataCI, plotflag, *varargin, **kwds):