From 6a08b0648286f8ba7beed2dadb121f5bfe36b534 Mon Sep 17 00:00:00 2001 From: colganwi Date: Thu, 11 Dec 2025 20:24:23 -0500 Subject: [PATCH 1/2] made legends compatible with mpl layout functions --- CHANGELOG.md | 2 ++ src/pycea/pl/_legend.py | 57 ++++++++++++++------------------------- src/pycea/tl/_utils.py | 9 +++++++ src/pycea/tl/clades.py | 8 +++++- tests/test_plot_legend.py | 10 +------ 5 files changed, 39 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fec218a..d3b2f5f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,8 +13,10 @@ and this project adheres to [Semantic Versioning][]. ### Added ### Changed +- `pycea.tl.clades` now resets `tdata.uns["clade_colors"]` when number of clades differs from number of colors (#45) ### Fixed +- Legend placement now works with tight and constrained layouts (#45) ## [0.2.0] - 2025-11-14 diff --git a/src/pycea/pl/_legend.py b/src/pycea/pl/_legend.py index 0ce64ea..c6ad276 100644 --- a/src/pycea/pl/_legend.py +++ b/src/pycea/pl/_legend.py @@ -117,8 +117,6 @@ def _place_legend( shared_kwargs: dict[str, Any], at_x: float, at_y: float, - box_width: float | None = None, - expand: bool = False, ) -> mlegend.Legend: """Place a legend on the axes at the specified position. @@ -131,13 +129,9 @@ def _place_legend( shared_kwargs A dictionary of shared keyword arguments for all legends. at_x - The x-coordinate (in axes fraction) to place the legend. + The x offset in pixels from the top-right corner of the axes. at_y - The y-coordinate (in axes fraction) to place the legend. - box_width - The width of the legend box (in axes fraction). - expand - Whether to expand the legend to the box_width. + The y offset in pixels from the top-right corner of the axes. """ handlelength = legend_kwargs.get("handlelength", 2.0) fontsize = shared_kwargs.get("fontsize", mpl.rcParams["legend.fontsize"]) @@ -145,16 +139,13 @@ def _place_legend( fontsize = FontProperties(size=fontsize).get_size_in_points() if handlelength == "dynamic": handlelength = 100 / fontsize - if box_width is not None: - handlelength = (box_width * 325) / fontsize + offset_trans = mtransforms.ScaledTranslation(at_x / ax.figure.dpi, at_y / ax.figure.dpi, ax.figure.dpi_scale_trans) opts: dict[str, Any] = { "handlelength": handlelength, "loc": legend_kwargs.get("loc", "upper left"), - "bbox_to_anchor": (at_x, at_y), + "bbox_to_anchor": (1, 1), + "bbox_transform": ax.transAxes + offset_trans, } - if expand and box_width is not None: - opts["bbox_to_anchor"] = (at_x, at_y, box_width + 0.03, 0) - opts["mode"] = "expand" opts.update({k: v for k, v in legend_kwargs.items() if k not in ("loc", "handlelength")}) opts.update(shared_kwargs) leg: mlegend.Legend = ax.legend(**opts) @@ -188,11 +179,14 @@ def _render_legends( shared_kwargs = {} fig = ax.figure fig.canvas.draw() # make sure transforms are current + ax_height = ax.bbox.height + ax_width = ax.bbox.width + spacing *= ax_height # convert to pixels if not hasattr(ax, "_attrs"): ax._attrs = {} # type: ignore - x_offset = ax._attrs.get("x_offset", anchor_x) # type: ignore - y_offset = ax._attrs.get("y_offset", 1.0) # type: ignore + x_offset = ax._attrs.get("x_offset", (anchor_x - 1) * ax_width) # type: ignore + y_offset = ax._attrs.get("y_offset", 0.0) # type: ignore column_max_width = ax._attrs.get("column_max_width", 0.0) # type: ignore for legend_kwargs in legends: @@ -201,39 +195,28 @@ def _render_legends( ax.add_artist(ax.get_legend()) # 2) place normally to measure its natural size legend = _place_legend(ax, legend_kwargs, shared_kwargs, x_offset, y_offset) - # 3) measure in axes fraction + # 3) measure in pixels renderer = fig.canvas.get_renderer() # type: ignore bbox_disp = legend.get_window_extent(renderer=renderer) - bbox_axes = mtransforms.Bbox(ax.transAxes.inverted().transform(bbox_disp)) - height = bbox_axes.height - width = bbox_axes.width + width = bbox_disp.width + height = bbox_disp.height # 4) if first in column, initialize max width if column_max_width == 0.0: column_max_width = width # 5) if it overflows vertically, start new column - if (height > y_offset) & (y_offset != 1.0): + if (height - y_offset > ax_height) & (y_offset != 0.0): legend.remove() x_offset += column_max_width + spacing - y_offset = 1.0 - column_max_width = 0.0 + y_offset = 0.0 # place again and re-measure legend = _place_legend(ax, legend_kwargs, shared_kwargs, x_offset, y_offset) bbox_disp = legend.get_window_extent(renderer=renderer) - bbox_axes = mtransforms.Bbox(ax.transAxes.inverted().transform(bbox_disp)) - height = bbox_axes.height - width = bbox_axes.width - column_max_width = width - # 6) if this legend is narrower than the column max, re-place with expand - elif width < column_max_width: - legend.remove() - legend = _place_legend( - ax, legend_kwargs, shared_kwargs, x_offset, y_offset, box_width=column_max_width, expand=True - ) - # 7) otherwise, update column max if this one is wider - else: - column_max_width = width - # 8) finalize: update y_offset and save to _attrs + column_max_width = bbox_disp.width + height = bbox_disp.height + # 6) update offsets for next legend y_offset -= height + spacing + if width > column_max_width: + column_max_width = width ax._attrs.update({"y_offset": y_offset}) # type: ignore ax._attrs.update({"x_offset": x_offset}) # type: ignore ax._attrs.update({"column_max_width": column_max_width}) # type: ignore diff --git a/src/pycea/tl/_utils.py b/src/pycea/tl/_utils.py index 8dbf7a0..cdf247f 100755 --- a/src/pycea/tl/_utils.py +++ b/src/pycea/tl/_utils.py @@ -107,3 +107,12 @@ def _remove_attribute(tree, key, nodes=True, edges=True): for u, v in tree.edges: if key in tree.edges[u, v]: del tree.edges[u, v][key] + + +def _check_colors_length(tdata, key: str): + """Remove colors from uns if they do not match the number of unique entries in obs.""" + if f"{key}_colors" not in tdata.uns.keys(): + return + if tdata.obs[key].nunique() != len(tdata.uns[f"{key}_colors"]): + del tdata.uns[f"{key}_colors"] + return diff --git a/src/pycea/tl/clades.py b/src/pycea/tl/clades.py index 9c9d993..7d4e16c 100755 --- a/src/pycea/tl/clades.py +++ b/src/pycea/tl/clades.py @@ -16,7 +16,7 @@ get_trees, ) -from ._utils import _remove_attribute +from ._utils import _check_colors_length, _remove_attribute def _nodes_at_depth(tree, parent, nodes, depth, depth_key): @@ -149,6 +149,11 @@ def clades( * `tdata.obst[tree].nodes[key_added]` : `Object` - Clade assignment for each node. + Modifies the following fields: + + * `tdata.uns[f"{key_added}_colors"]` : `List` + - Removed if its length does not match the number of unique clades. + Examples -------- Mark clades at specified depth @@ -183,5 +188,6 @@ def clades( node_to_clade = get_keyed_node_data(tdata, key_added, tree_keys, slot="obst") node_to_clade.index = node_to_clade.index.droplevel(0) tdata.obs[key_added] = tdata.obs.index.map(node_to_clade[key_added]) + _check_colors_length(tdata, key_added) if copy: return pd.concat(lcas) diff --git a/tests/test_plot_legend.py b/tests/test_plot_legend.py index 133892e..4ec858b 100644 --- a/tests/test_plot_legend.py +++ b/tests/test_plot_legend.py @@ -70,21 +70,13 @@ def test_size_legend(): assert len(legend2["labels"]) == 6 -def test_place_legend_default_and_expand(): +def test_place_legend_default(): fig, ax = plt.subplots() l1 = mlines.Line2D([], [], color="red", label="a") legend_kwargs = {"title": "t1", "handles": [l1], "labels": ["a"]} shared_kwargs = {"fontsize": 10} leg = _place_legend(ax, legend_kwargs, shared_kwargs, at_x=0.5, at_y=0.5) assert isinstance(leg, mlegend.Legend) - # Expand case - fig2, ax2 = plt.subplots() - p1 = mpatches.Patch(color="blue", label="b") - legend_kwargs2 = {"title": "t2", "handles": [p1], "labels": ["b"], "handlelength": 2} - shared_kwargs2 = {"fontsize": 12} - leg2 = _place_legend(ax2, legend_kwargs2, shared_kwargs2, at_x=0.2, at_y=0.8, box_width=0.3, expand=True) - assert isinstance(leg2, mlegend.Legend) - assert hasattr(leg2, "_bbox_to_anchor") def test_render_legends(): From f2b1dfd4e15fb790aa564705788bf623447ecc93 Mon Sep 17 00:00:00 2001 From: colganwi Date: Thu, 11 Dec 2025 20:41:50 -0500 Subject: [PATCH 2/2] green on pre-release failure --- .github/workflows/test.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index ebfea68..7ae59d3 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -43,7 +43,8 @@ jobs: | { name: .key, label: (if (.key | contains("pre")) then .key + " (PRE-RELEASE DEPENDENCIES)" else .key end), - python: .value.python + python: .value.python, + allow_failure: (.key | contains("pre")) } )') echo "envs=${ENVS_JSON}" | tee $GITHUB_OUTPUT @@ -51,7 +52,6 @@ jobs: # Run tests through hatch. Spawns a separate runner for each environment defined in the hatch matrix obtained above. test: needs: get-environments - strategy: fail-fast: false matrix: @@ -61,6 +61,8 @@ jobs: name: ${{ matrix.env.label }} runs-on: ${{ matrix.os }} + continue-on-error: ${{ matrix.env.allow_failure }} + steps: - uses: actions/checkout@v4 with: