Source code for pylimma.batch

# SPDX-License-Identifier: GPL-3.0-or-later
#
# This module is a Python port of code from R limma. Original R copyrights:
#   removeBatchEffect.R                  Copyright (C) 2008-2025 Gordon Smyth,
#                                                                Carolyn de Graaf
#   wsva.R                               Copyright (C) 2015-2017 Yifang Hu,
#                                                                Gordon Smyth
# Python port: Copyright (C) 2026 John Mulvey
"""
Batch-effect removal and surrogate-variable analysis for pylimma.

Faithful ports:

- ``remove_batch_effect`` (``limma/R/removeBatchEffect.R``).
- ``wsva`` (``limma/R/wsva.R``). Weighted surrogate variable
  analysis; has an optional screeplot branch, so it lands in this
  module alongside ``remove_batch_effect`` rather than in
  ``plotting.py``.

Accepts matrix / dict / EList / AnnData via the ``get_eawp`` /
``put_eawp`` dispatchers. RGList / MAList / EListRaw are out of scope
(see ``memory/policy_data_class_wrappers.md``).
"""

from __future__ import annotations

import warnings

import numpy as np
import pandas as pd

from .classes import get_eawp, put_eawp
from .lmfit import lm_fit


def _factor_levels(factor) -> tuple[np.ndarray, np.ndarray]:
    """Determine levels and string-coded factor values following R's
    ``as.factor()`` dispatch.

    Rules (mirrors R's `as.factor`):
    - ``pd.Categorical`` / ``pd.Series`` with categorical dtype: preserve
      ``categories`` order (matches R's factor-input behaviour).
    - Numeric (int/float) array: levels = numeric sort of unique values,
      rendered back to strings (matches R's `as.factor(numeric_vector)`).
    - Boolean: levels = ``["FALSE", "TRUE"]`` (matches R).
    - Character / object: levels = alphabetical sort of unique string values
      (matches R's `as.factor(character_vector)`).

    Returns ``(levels, str_values)`` both as 1-D numpy arrays of strings.
    """
    # Categorical (including pd.Series of categorical dtype)
    if isinstance(factor, pd.Categorical):
        cat = factor
    elif isinstance(factor, pd.Series) and hasattr(factor, "cat"):
        cat = factor.values
    else:
        cat = None
    if cat is not None and isinstance(cat, pd.Categorical):
        levels = np.array([str(c) for c in cat.categories])
        str_values = np.array([str(v) for v in cat])
        return levels, str_values

    arr = np.asarray(factor)

    # Boolean
    if arr.dtype == bool:
        levels = np.array(["FALSE", "TRUE"])
        str_values = np.where(arr, "TRUE", "FALSE")
        return levels, str_values

    # Numeric (integer or float)
    if np.issubdtype(arr.dtype, np.number):
        unique_sorted = np.sort(np.unique(arr))

        # R's `as.character(numeric)` formatting: integers render without
        # decimals, floats use R's default (effectively Python's default too)
        def _render(x):
            if float(x).is_integer():
                return str(int(x))
            return str(x)

        levels = np.array([_render(v) for v in unique_sorted])
        str_values = np.array([_render(v) for v in arr])
        return levels, str_values

    # Character / object: alphabetical
    str_values = np.array([str(v) for v in arr])
    levels = np.array(sorted(set(str_values)))
    return levels, str_values


def _sum_to_zero_design(factor) -> np.ndarray:
    """Sum-to-zero (``contr.sum``) coding of ``factor`` with the intercept dropped.

    Equivalent to R's ``model.matrix(~factor)[, -1]`` after
    ``contrasts(factor) <- contr.sum(levels(factor))``. Level order comes
    from ``_factor_levels``: numeric for int/float input, sorted for
    character input, and preserved for categorical input.

    With K levels the result is an ``n x (K - 1)`` float matrix:
    - a sample at level j < K has 1 in column j and 0 elsewhere;
    - a sample at the last level has -1 in every column.

    Fewer than two levels yields an ``n x 0`` matrix.
    """
    levels, labels = _factor_levels(factor)
    n_levels = len(levels)
    if n_levels < 2:
        return np.zeros((len(labels), 0), dtype=np.float64)
    # One indicator column per non-reference level...
    coded = (labels[:, np.newaxis] == levels[np.newaxis, :-1]).astype(np.float64)
    # ...and the reference (last) level is coded -1 across every column.
    coded[labels == levels[-1]] = -1.0
    return coded


def remove_batch_effect(
    x,
    batch: np.ndarray | list | None = None,
    batch2: np.ndarray | list | None = None,
    covariates: np.ndarray | None = None,
    design: np.ndarray | None = None,
    group: np.ndarray | list | None = None,
    *,
    out_layer: str = "batch_removed",
    uns_key: str = "remove_batch_effect",
    layer: str | None = None,
    **lmfit_kwargs,
):
    """
    Remove batch effects from a matrix of expression values.

    Faithful port of R limma's ``removeBatchEffect``
    (``limma/R/removeBatchEffect.R``). Fits a linear model against a
    combined design of experimental conditions and batch covariates, then
    subtracts the estimated batch-coefficient contribution from the
    expression matrix.

    Parameters
    ----------
    x : ndarray, EList, AnnData, or dict
        Expression matrix (genes x samples) or wrapper.
    batch : array-like, optional
        Factor of batch labels. Coded with sum-to-zero contrasts before
        entering the design.
    batch2 : array-like, optional
        Second batch factor, treated the same way as ``batch``.
    covariates : ndarray, optional
        Quantitative covariates (samples x p; a 1-D array is a single
        covariate). Mean-centred before entry.
    design : ndarray, optional
        Design matrix (samples x k) for the experimental conditions to be
        preserved. A 1-D array is treated as a single column. If omitted and
        ``group`` is also omitted, a one-group design is assumed and a
        warning is emitted.
    group : array-like, optional
        If given and ``design`` is omitted, sets ``design = one_hot(group)``
        (one indicator column per distinct ``str(value)``, in sorted order).
    out_layer, uns_key : str
        Forwarded to ``put_eawp`` when the corrected matrix is written back
        to the caller's container type.
    layer : str, optional
        Forwarded to ``get_eawp`` when extracting the input matrix.
    **lmfit_kwargs
        Forwarded to ``lm_fit`` (e.g. ``weights``, ``method``).

    Returns
    -------
    Same class as input (matrix -> ndarray, EList -> EList, AnnData -> None).

    Raises
    ------
    ValueError
        If ``design`` or the combined batch terms do not have exactly one row
        per sample (column of the expression matrix).
    """
    original_input = x
    eawp = get_eawp(x, layer=layer)
    X = np.asarray(eawp["exprs"], dtype=np.float64)
    n_samples = X.shape[1]

    def _emit(E: np.ndarray):
        # Wrap a genes x samples matrix back into the caller's container type.
        return put_eawp(
            {"E": E},
            original_input,
            out_layer=out_layer,
            weights_layer=None,
            uns_key=uns_key,
            single_matrix=True,
        )

    # Nothing to remove: hand the matrix back unchanged (R: `return(as.matrix(x))`).
    if batch is None and batch2 is None and covariates is None:
        return _emit(X)

    # At least one batch term is present, so `parts` is never empty here.
    parts = []
    if batch is not None:
        parts.append(_sum_to_zero_design(batch))
    if batch2 is not None:
        parts.append(_sum_to_zero_design(batch2))
    if covariates is not None:
        cov = np.asarray(covariates, dtype=np.float64)
        if cov.ndim == 1:
            cov = cov.reshape(-1, 1)
        parts.append(cov - cov.mean(axis=0, keepdims=True))
    X_batch = np.concatenate(parts, axis=1)
    if X_batch.shape[0] != n_samples:
        raise ValueError(
            f"batch, batch2 and covariates must have one entry per sample: "
            f"got {X_batch.shape[0]} rows for {n_samples} samples"
        )

    if group is not None and design is None:
        # Cell-means coding. It spans the same column space as R's
        # model.matrix(~group), so the batch coefficients are unaffected.
        grp = np.asarray([str(v) for v in group])
        levels = np.array(sorted(set(grp)))
        design = (grp[:, None] == levels[None, :]).astype(np.float64)
    if design is None:
        warnings.warn(
            "design matrix of interest not specified. Assuming a one-group experiment.",
            UserWarning,
        )
        design = np.ones((n_samples, 1), dtype=np.float64)
    design = np.asarray(design, dtype=np.float64)
    if design.ndim == 1:
        # Same convenience as wsva: a vector is a single design column.
        design = design.reshape(-1, 1)
    if design.shape[0] != n_samples:
        raise ValueError(
            f"design has {design.shape[0]} rows but x has {n_samples} samples"
        )

    full_design = np.concatenate([design, X_batch], axis=1)
    fit = lm_fit(X, design=full_design, **lmfit_kwargs)
    coef = np.asarray(fit["coefficients"], dtype=np.float64).copy()

    # R's lmFit QR pivots out the trailing redundant column, so any NA lands
    # in X_batch's slice and `beta[is.na(beta)] <- 0` correctly zeros the
    # wrong-subtraction. pylimma.lm_fit's QR may pivot a preserved-design
    # column instead, leaving a finite but algebraically-arbitrary value in
    # X_batch. Mirror R's intent by zeroing the ENTIRE row across both slices
    # whenever a NaN appears anywhere in that row's coefficients: if part of
    # the linear combination is unidentified, none of the decomposition's
    # batch-attribution can be trusted for that probe.
    nan_row = np.any(np.isnan(coef), axis=1)
    coef[nan_row, :] = 0.0
    beta = coef[:, design.shape[1] :]

    corrected = X - beta @ X_batch.T
    return _emit(corrected)


def _lm_effects_residual( y: np.ndarray, design: np.ndarray, array_weights: np.ndarray | None = None, weights: np.ndarray | None = None, block: np.ndarray | None = None, correlation: float | None = None, ) -> np.ndarray: """Residual-space effects matrix for wsva. Equivalent of R limma's ``.lmEffects(y, design, ...)[, -1]`` (lmEffects.R): the residual block of the effects matrix with the contrast column dropped. Shape ``(ngenes, n - p)``. Supports the same side-channel arguments R's .lmEffects forwards from wsva's ``...``: ``array_weights``, ``weights`` (per-observation matrix), ``block`` + ``correlation``. R's ``gene.weights`` and the ``weights``-as-array-weights alias are not supported here because wsva's documented use case is SV estimation on residual space only. Parameters ---------- y : ndarray Expression matrix (n_genes, n_samples). design : ndarray Design matrix (n_samples, p). array_weights : ndarray, optional Per-sample weights (length n_samples). Pre-scales y and design by sqrt(array_weights), mirroring lmEffects.R:94-99. weights : ndarray, optional Per-observation weights (n_genes, n_samples). Triggers the per-gene QR loop in lmEffects.R:131-150. block : ndarray, optional Block factor (length n_samples) for GLS. correlation : float, optional Within-block correlation; required when ``block`` is given. """ from scipy.linalg import qr, solve_triangular p = design.shape[1] n = design.shape[0] if n <= p: raise ValueError("No residual degrees of freedom") y = np.asarray(y, dtype=np.float64) X = np.asarray(design, dtype=np.float64) # Array weights: divide out via sqrt transform (lmEffects.R:94-99). 
if array_weights is not None: aw = np.asarray(array_weights, dtype=np.float64) if aw.size != n: raise ValueError("Length of array_weights doesn't match number of arrays") if np.any(aw <= 0) or np.any(np.isnan(aw)): raise ValueError("array_weights must be positive") ws = np.sqrt(aw) X = X * ws[:, None] y = y * ws[None, :] # Block / correlation: GLS via Cholesky (lmEffects.R:102-118). R_chol = None if block is not None: if correlation is None: raise ValueError("correlation must be set when block is given") block_arr = np.asarray(block).ravel() if block_arr.size != n: raise ValueError("Length of block does not match number of arrays") ub, inv = np.unique(block_arr, return_inverse=True) Z = (inv[:, None] == np.arange(len(ub))).astype(np.float64) cormatrix = Z @ (float(correlation) * Z.T) np.fill_diagonal(cormatrix, 1.0) R_chol = np.linalg.cholesky(cormatrix).T # upper triangular R if weights is None: # Apply transform y <- solve(R^T, y^T)^T and X <- solve(R^T, X). y = solve_triangular(R_chol.T, y.T, lower=True).T X = solve_triangular(R_chol.T, X, lower=True) # Per-observation weights: per-gene QR loop (lmEffects.R:131-150). if weights is not None: w_mat = np.asarray(weights, dtype=np.float64) if w_mat.shape != (y.shape[0], n): raise ValueError("weights must have same dimensions as y") if np.any(w_mat <= 0) or np.any(np.isnan(w_mat)): raise ValueError("weights must be positive") effects = np.zeros((y.shape[0], n)) for g in range(y.shape[0]): ws_g = np.sqrt(w_mat[g]) wX = X * ws_g[:, None] wy = y[g] * ws_g if R_chol is not None: wy = solve_triangular(R_chol.T, wy, lower=True) wX = solve_triangular(R_chol.T, wX, lower=True) Q_g, _ = qr(wX, mode="full") effects[g] = Q_g.T @ wy return effects[:, p:] # drop contrast columns, keep residual block # Common path: single QR of (possibly transformed) X. Q, _ = qr(X, mode="full") residual = Q[:, p:].T @ y.T return residual.T # (n_genes, n - p)
def wsva(
    y,
    design: np.ndarray,
    n_sv: int = 1,
    weight_by_sd: bool = False,
    plot: bool = False,
    *,
    array_weights: np.ndarray | None = None,
    weights: np.ndarray | None = None,
    block: np.ndarray | None = None,
    correlation: float | None = None,
    **kwargs,
) -> np.ndarray:
    """Weighted surrogate variable analysis.

    Port of R limma's ``wsva`` (Yifang Hu and Gordon Smyth, 2015-2017).
    Surrogate variables are extracted from the residual-space effects of
    ``y`` regressed on ``design``. Each returned column is rescaled to unit
    mean square.

    Parameters
    ----------
    y : ndarray, EList, AnnData, or dict
        Expression data (genes x arrays); unpacked with ``get_eawp``.
    design : ndarray
        Design matrix (arrays x p). A 1-D array is treated as one column.
    n_sv : int
        Number of surrogate variables. Clipped to at least 1 and at most the
        residual df ``n_arrays - p``.
    weight_by_sd : bool
        If True, run ``n_sv`` iterations. Each one weights genes by their
        residual SD, takes the leading singular vector, forms the new
        surrogate variable, and appends it to the design before the next
        iteration.
    plot : bool
        Draw a screeplot of the proportion of variance explained (matplotlib,
        imported lazily). Not available with ``weight_by_sd=True``; a warning
        is issued instead.
    array_weights, weights, block, correlation
        Forwarded to the residual-effects helper, as R's wsva forwards
        ``...`` to ``.lmEffects``. A 1-D ``weights`` of length ``n_arrays``
        is treated as ``array_weights`` when the latter is None
        (lmEffects.R:52-56).
    **kwargs
        Ignored, with a ``UserWarning`` listing their names.

    Returns
    -------
    ndarray
        Surrogate variables, shape ``(n_arrays, n_sv)``.

    Raises
    ------
    ValueError
        ``"No residual df"`` when ``design`` leaves no residual degrees of
        freedom.
    """
    expr = np.asarray(get_eawp(y)["exprs"], dtype=np.float64)
    X = np.asarray(design, dtype=np.float64)
    if X.ndim == 1:
        X = X[:, np.newaxis]
    n_arrays = expr.shape[1]
    n_coef = X.shape[1]
    resid_df = n_arrays - n_coef

    # R's .lmEffects treats a length-n `weights` vector as array.weights.
    if array_weights is None and weights is not None:
        candidate = np.asarray(weights)
        if candidate.ndim == 1 and candidate.size == n_arrays:
            array_weights, weights = candidate, None

    if kwargs:
        warnings.warn(
            f"Extra arguments disregarded: {sorted(kwargs.keys())}",
            UserWarning,
        )

    n_sv = min(max(int(n_sv), 1), resid_df)
    if n_sv <= 0:
        raise ValueError("No residual df")

    effects_opts = {
        "array_weights": array_weights,
        "weights": weights,
        "block": block,
        "correlation": correlation,
    }

    if weight_by_sd:
        if plot:
            warnings.warn("Plot not available with weight_by_sd=True", UserWarning)
        augmented = X
        for _ in range(n_sv):
            resid = _lm_effects_residual(expr, augmented, **effects_opts)
            row_sd = np.sqrt((resid**2).mean(axis=1))
            left, _, _ = np.linalg.svd(
                row_sd[:, np.newaxis] * resid, full_matrices=False
            )
            loading = left[:, 0] * row_sd
            surrogate = (loading[:, np.newaxis] * expr).sum(axis=0)
            # The new surrogate variable is adjusted for in the next pass.
            augmented = np.hstack([augmented, surrogate[:, np.newaxis]])
        sv_rows = augmented[:, n_coef:].T  # (n_sv, n_arrays)
    else:
        resid = _lm_effects_residual(expr, X, **effects_opts)
        left, sing, _ = np.linalg.svd(resid, full_matrices=False)
        sv_rows = left[:, :n_sv].T @ expr  # (n_sv, n_arrays)
        if plot:
            from .plotting import _require_matplotlib

            plt = _require_matplotlib()
            share = sing**2
            share = share / share.sum()
            _, ax = plt.subplots()
            ax.plot(np.arange(1, share.size + 1), share, "o")
            ax.set_xlabel("Surrogate variable number")
            ax.set_ylabel("Proportion variance explained")

    # Scale every surrogate variable to unit mean square across arrays.
    scale = np.sqrt((sv_rows**2).mean(axis=1))
    return (sv_rows / scale[:, np.newaxis]).T  # (n_arrays, n_sv)