使用 Python 实现自定义 Transformer，以增强 PySpark 的 Pipeline

一、示例

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import HasOutputCols, Param, Params
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable

class SelectColsTransformer(
    Transformer, DefaultParamsReadable, DefaultParamsWritable
):
    """Pipeline stage that selects a subset of columns from the input DataFrame.

    Defaults to ``['*']``, i.e. select every column. Mixing in
    ``DefaultParamsReadable``/``DefaultParamsWritable`` makes the stage
    persistable with the standard ``Pipeline.save``/``load`` machinery.
    """

    # Column names are strings, so attach a type converter: non-string
    # inputs are coerced/validated instead of failing later at select time.
    cols = Param(
        Params._dummy(),
        "cols",
        "cols to select",
        typeConverter=TypeConverters.toListString,
    )

    @keyword_only
    def __init__(self, cols: list[str] = ['*']):
        """Create the transformer.

        :param cols: column names to select; ``['*']`` keeps all columns.

        NOTE: the signature default is inert — ``@keyword_only`` records only
        the kwargs the caller actually passed, so the effective default is the
        one registered via ``_setDefault`` below.
        """
        super(SelectColsTransformer, self).__init__()
        self._setDefault(cols=['*'])
        kwargs = self._input_kwargs
        self.setParams(**kwargs)


    @keyword_only
    def setParams(self, cols: list[str] = ['*']):
        """Set params for this transformer; returns ``self`` for chaining."""
        kwargs = self._input_kwargs
        return self._set(**kwargs)


    def _transform(self, dataset):
        """Return *dataset* restricted to the configured columns."""
        return dataset.select(self.getOrDefault(self.cols))

以上即为自定义 Transformer 的完整示例。