Reference

disfor

get(name)

Function to get data paths

This function fetches the paths to the raw data provided in DISFOR. If the data is not available locally, it uses pooch to fetch the data from Hugging Face.

Parameters:

    name (Literal['samples.parquet', 'labels.parquet', 'pixel_data.parquet', 'tiffs', 'train_ids.json', 'val_ids.json', 'classes.json']):
        Name of the file to fetch. Required.

Returns:

    Path: Path to the file on local storage.

Source code in src/disfor/io.py
def get(
    name: Literal[
        "samples.parquet",
        "labels.parquet",
        "pixel_data.parquet",
        "tiffs",
        "train_ids.json",
        "val_ids.json",
        "classes.json",
    ],
) -> Path:
    """
    Function to get data paths

    This function fetches the paths to the raw data provided in DISFOR.
    If the data is not available locally, it uses `pooch` to fetch the data from Hugging Face.

    Args:
        name: Name of the file to fetch.

    Returns:
        Path to the file on local storage
    """
    if name == "tiffs":
        return fetch_s2_chips()
    return Path(_DATA_GETTER.fetch(name))
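The dispatch logic of `get` can be sketched in isolation. Note that `fetch_s2_chips` and `_DATA_GETTER` below are hypothetical stand-ins for the real pooch-backed objects in `disfor.io`, which download and cache files on first access:

```python
from pathlib import Path

# Hypothetical stand-ins for the pooch-backed fetchers in disfor/io.py.
def fetch_s2_chips() -> Path:
    return Path("/cache/disfor/tiffs")

class _Getter:
    def fetch(self, name: str) -> str:
        return f"/cache/disfor/{name}"

_DATA_GETTER = _Getter()

def get(name: str) -> Path:
    # "tiffs" is a directory of image chips with its own fetcher;
    # every other name resolves through the shared registry.
    if name == "tiffs":
        return fetch_s2_chips()
    return Path(_DATA_GETTER.fetch(name))

print(get("samples.parquet"))
```

In the real library the returned path points at the local cache, so repeated calls are cheap.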

datasets

GenericDataset

A generic class which serves to load, filter and pre-process the raw data.

There are two classes which then bring the filtered and pre-processed data into formats that can be used with pytorch (disfor.datasets.MonoTemporalClassification) and sklearn-style classifiers (disfor.datasets.TabularDataset), respectively.

Parameters:

    data_folder (str | None, default None): Path to the root data folder containing pixel_data.parquet, labels.parquet and samples.parquet. If not specified, data will be fetched using disfor.get.
    target_classes (List[Literal[100, 110, 120, 121, 122, 123, 200, 210, 211, 212, 213, 220, 221, 222, 230, 231, 232, 240, 241, 242, 243, 244, 245]] | None, default None): Which classes should be included.
    class_mapping_overrides (Dict[int, int] | None, default None): Map classes to other classes, for example {221: 211, 222: 211} would map both of the salvage classes to clear cut. This remapping happens before filtering of target_classes, which means that the values of the dict need to be included in target_classes, otherwise they will be filtered out.
    confidence (List[Literal['high', 'medium']] | None, default None): Filters the dataset to only include the logged confidence levels of the label interpretation.
    valid_scl_values (List[Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]] | None, default None): List of valid SCL values. Used to filter out cloudy or otherwise unusable observations.
    chip_size (Literal[32, 16, 8, 4], default 32): Size of the image chip. Maximum of 32x32. Used in min_clear_percentage_chip.
    min_clear_percentage_chip (int | None, default None): Minimum percent (0-100) of pixels in the chip that have to be clear (SCL in 4, 5, 6) for an acquisition to be included.
    months (List[Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]] | None, default None): List of months to include acquisitions from. January is 1, December is 12.
    max_days_since_event (int | dict | None, default None): An integer specifying the maximum number of days between an acquisition and the label's start date. This can also be set separately for each target class. For example, if target_classes is [110, 211] (Mature Forest, Clear Cut), a maximum of 90 days after a Clear Cut can be specified by passing the dictionary {211: 90}.
    sample_datasets (List[Literal[1, 2, 3]] | None, default None): Sampling campaigns from which data should be included. Includes data from all campaigns by default (None).
    max_samples_per_event (int | None, default None): Maximum number of acquisitions to include per event. Can be used to reduce the number of samples drawn from segments with long durations, for example to reduce the number of healthy acquisitions.
    random_seed (int | None, default None): Random seed used for reproducible subsampling operations.
    apply_downsampling (bool, default False): Flag indicating whether the majority class should be downsampled.
    target_majority_samples (int | None, default None): How many samples the majority class should have after balancing. If None, the majority class will be reduced to twice the size of the second largest class or to 500 samples, whichever is greater (capped at its original size).
    omit_border (bool, default True): Omit samples which have "border" in the comment. These are usually mixed pixels at the border between classes.
    omit_low_tcd (bool, default True): Omit samples which have "TCD" in the comment. These are usually samples where the forest has a low tree cover density (for example olive plantations).
    bands (List[Literal['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B11', 'B12', 'SCL']] | None, default None): Spectral bands to include.
    remove_outliers (bool, default False): Flag indicating whether outliers should be removed. This is used to remove clouds or other data artifacts which were not masked through the SCL values.
    outlier_method (Literal['iqr', 'zscore', 'modified_zscore'], default 'iqr'): Statistical method used to determine outliers. The statistical measure is calculated for each unique (sample_id, label) group.
    outlier_threshold (float, default 1.5): Threshold to apply; acquisitions whose statistic under outlier_method exceeds this threshold are removed.
    outlier_columns (List[Literal['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B11', 'B12', 'SCL']] | None, default None): Which columns (bands) to search for outliers. If an outlier is detected in any of the bands, the acquisition is removed. Defaults to all bands defined in the parameter bands.
    label_strategy (Literal['LabelEncoder', 'LabelBinarizer', 'Hierarchical'], default 'LabelEncoder'): How the values in target_classes should be encoded. LabelEncoder and LabelBinarizer correspond to the sklearn encoders. Hierarchical is a custom encoding implemented for the hierarchical classes. For more details see disfor.utils.HierarchicalLabelEncoder.
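The interaction between class_mapping_overrides and target_classes can be traced in plain Python (the real implementation uses polars' replace_strict, but the ordering is the same; the label values below are taken from the examples above):

```python
# Remap labels first, then filter by target_classes - mirroring the order
# used by GenericDataset. Remapped values not in target_classes are dropped.
class_mapping_overrides = {221: 211, 222: 211}  # salvage classes -> clear cut
target_classes = [110, 211]                     # mature forest, clear cut

labels = [110, 221, 222, 123]

# Step 1: remapping happens before filtering
remapped = [class_mapping_overrides.get(label, label) for label in labels]
# Step 2: filter to the target classes
kept = [label for label in remapped if label in target_classes]

print(kept)
```

The salvage samples survive only because 211, the value they are mapped to, is itself listed in target_classes; had it been omitted, they would be filtered out.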

Attributes:

    pixel_data: Polars dataframe containing the filtered and pre-processed data.

Source code in src/disfor/datasets/generic.py
class GenericDataset:
    """
    A generic class which serves to load, filter and pre-process the raw data.

    There are two classes which then bring the filtered and pre-processed data into formats which
    can be used with pytorch ([`disfor.datasets.MonoTemporalClassification`][]) and sklearn style classifiers ([`disfor.datasets.TabularDataset`][]) respectively.

    Args:
        data_folder: Path to root data folder containing pixel_data.parquet, labels.parquet and samples.parquet,
            if not specified, data will be fetched using [`disfor.get`][]
        target_classes: Which classes should be included
        class_mapping_overrides: Map classes to other classes, for example {221: 211, 222: 211} would map both of the salvage classes to clear cut.
            This remapping happens before filtering of `target_classes`, which means that the values of the dict need to be included in `target_classes`,
            otherwise they will be filtered out.
        confidence: Filters dataset to only include logged confidence of label interpretation.
        valid_scl_values: List of valid SCL values. Used to filter out cloudy or otherwise unusable observations
        chip_size: Size of the image chip. Maximum of 32x32. Used in `min_clear_percentage_chip`.
        min_clear_percentage_chip: Minimum percent (0-100) of pixels in the chip that have to be clear (SCL in 4,5,6) to be included.
        months: List of months to include acquisitions from. January is 1, December is 12.
        max_days_since_event: An integer specifying the maximum number of days between an acquisition and the label's start date. This can also be set separately for each target class.
            For example, if target_classes is [110, 211] (Mature Forest, Clear Cut), a maximum of 90 days after a Clear Cut can be specified by passing the dictionary
            {211: 90}
        sample_datasets: Sampling campaigns from which data should be included. Includes data from all campaigns by default (None)
        max_samples_per_event: Maximum number of acquisitions to include per event. Can be used to reduce the number of samples
            drawn from segments with long durations, for example to reduce the number of healthy acquisitions
        random_seed: Random seed used for reproducible subsampling operations
        apply_downsampling: Flag indicating whether the majority class should be downsampled.
        target_majority_samples: How many samples the majority class should have after balancing. If None, the majority class will be reduced
            to twice the size of the second largest class or to 500 samples, whichever is greater (capped at its original size).
        omit_border: Omit samples which have "border" in the comment. These are usually mixed pixels at the border between classes
        omit_low_tcd: Omit samples which have "TCD" in the comment. These are usually samples where the forest has a low tree cover density (for example olive plantations)
        bands: Spectral bands to include
        remove_outliers: Flag if outliers should be removed. This is used to remove clouds or other data artifacts
            which were not masked through the SCL values.
        outlier_method: Statistical method used to determine outliers. This statistical measure is calculated for each unique
            `(sample_id, label)` group.
        outlier_threshold: Threshold to apply; acquisitions whose statistic under `outlier_method` exceeds this threshold will be removed.
        outlier_columns: Which columns (bands) to search for outliers. If an outlier is detected in any of the bands
            it will be removed. Default is all bands which are defined in the parameter `bands`
        label_strategy: How the values in `target_classes` should be encoded. LabelEncoder and LabelBinarizer correspond to the sklearn encoders.
            Hierarchical is a custom encoding implemented for the hierarchical classes. For more details see [`disfor.utils.HierarchicalLabelEncoder`][]

    Attributes:
        pixel_data: Polars dataframe containing filtered and pre-processed data.
    """

    def __init__(
        self,
        #
        data_folder: str | None = None,
        # Class selection
        target_classes: List[
            Literal[
                100,
                110,
                120,
                121,
                122,
                123,
                200,
                210,
                211,
                212,
                213,
                220,
                221,
                222,
                230,
                231,
                232,
                240,
                241,
                242,
                243,
                244,
                245,
            ]
        ]
        | None = None,
        class_mapping_overrides: Dict[int, int] | None = None,
        label_strategy: Literal[
            "LabelEncoder", "LabelBinarizer", "Hierarchical"
        ] = "LabelEncoder",
        # Filtering parameters
        confidence: List[Literal["high", "medium"]] | None = None,
        # Cloud masking parameters
        valid_scl_values: List[Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]
        | None = None,
        chip_size: Literal[32, 16, 8, 4] = 32,
        min_clear_percentage_chip: int | None = None,
        months: List[Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]] | None = None,
        max_days_since_event: int | dict | None = None,
        sample_datasets: List[Literal[1, 2, 3]] | None = None,
        # Sampling parameters
        max_samples_per_event: int | None = None,
        random_seed: int | None = None,
        # Balanced sampling parameters
        apply_downsampling: bool = False,
        target_majority_samples: int | None = None,
        # Quality filters
        omit_low_tcd: bool = True,
        omit_border: bool = True,
        # Feature selection
        bands: List[
            Literal[
                "B02",
                "B03",
                "B04",
                "B05",
                "B06",
                "B07",
                "B08",
                "B8A",
                "B11",
                "B12",
                "SCL",
            ]
        ]
        | None = None,
        # Outlier removal parameters
        remove_outliers: bool = False,
        outlier_method: Literal["iqr", "zscore", "modified_zscore"] = "iqr",
        outlier_threshold: float = 1.5,
        outlier_columns: List[
            Literal[
                "B02",
                "B03",
                "B04",
                "B05",
                "B06",
                "B07",
                "B08",
                "B8A",
                "B11",
                "B12",
                "SCL",
            ]
        ]
        | None = None,
    ):
        self.random_seed = random_seed
        self.data_folder = data_folder
        self._load_base_data()
        all_bands = [
            "B02",
            "B03",
            "B04",
            "B05",
            "B06",
            "B07",
            "B08",
            "B8A",
            "B11",
            "B12",
            "SCL",
        ]
        self.bands = bands or all_bands[:-1]
        self.band_idxs = [all_bands.index(band) for band in self.bands]
        self.target_classes = target_classes or list(CLASSES.keys())
        self.valid_scl_values = valid_scl_values or [2, 4, 5, 6]
        self.outlier_method = outlier_method
        self.outlier_threshold = outlier_threshold
        self.outlier_columns = outlier_columns
        self.class_mapping_overrides = class_mapping_overrides or {}
        self.chip_size = chip_size
        self.label_strategy = label_strategy

        # Filters for samples.parquet
        samples_filters = [pl.lit(True)]
        # TODO: sample_ids should be handled in the implementing classes
        if confidence is not None:
            samples_filters.append(pl.col.confidence.is_in(confidence))
        if sample_datasets is not None:
            samples_filters.append(pl.col.dataset.is_in(sample_datasets))
        if omit_low_tcd:
            samples_filters.append(~pl.col.comment.str.contains("TCD"))
        if omit_border:
            samples_filters.append(~pl.col.comment.str.contains("border"))

        # Filters for labels.parquet
        labels_filters = [pl.lit(True)]
        if target_classes is not None:
            labels_filters.append(pl.col("label").is_in(self.target_classes))
        else:
            # This is done because there are some labels with value 999 (artifact) in the labels dataset
            # TODO: Fix this directly in the dataset? i.e. exclude 999
            labels_filters.append(pl.col("label").is_in(CLASSES.keys()))

        # Filters for pixel_data.parquet
        pixel_data_filters = [pl.lit(True)]
        if months is not None:
            pixel_data_filters.append(pl.col("timestamps").dt.month().is_in(months))
        if self.valid_scl_values is not None:
            pixel_data_filters.append(pl.col.SCL.is_in(self.valid_scl_values))
        if min_clear_percentage_chip is not None:
            pixel_data_filters.append(
                pl.col(f"percent_clear_{chip_size}x{chip_size}")
                >= min_clear_percentage_chip
            )
        match max_days_since_event:
            case dict():
                max_duration_filters = []
                for label, days in max_days_since_event.items():
                    if days is None:
                        continue
                    max_duration_filters.append(
                        (
                            (pl.col("timestamps") - pl.col("start"))
                            > pl.duration(days=days)
                        )
                        & (pl.col.label == label)
                    )
                pixel_data_filters.append(~pl.any_horizontal(max_duration_filters))
            case int():
                # Keep only acquisitions within the maximum duration,
                # mirroring the negated filter in the dict case above
                pixel_data_filters.append(
                    (pl.col("timestamps") - pl.col("start"))
                    <= pl.duration(days=max_days_since_event)
                )

        # Load and filter samples data
        samples = pl.read_parquet(
            self.base_data_paths["samples.parquet"],
            columns=["sample_id", "cluster_id", "comment", "dataset", "confidence"],
            use_pyarrow=True,
        ).filter(samples_filters)

        labels = (
            pl.read_parquet(
                self.base_data_paths["labels.parquet"],
                columns=["sample_id", "label", "start"],
            )
            .join(samples, on="sample_id", how="inner")
            .with_columns(
                pl.col.label.replace_strict(
                    self.class_mapping_overrides,
                    return_dtype=pl.UInt16,
                    default=pl.col.label,
                ),
            )
            .filter(labels_filters)
        )

        # Load and filter pixel data
        pixel_data = (
            pl.read_parquet(
                self.base_data_paths["pixel_data.parquet"],
                columns=list(
                    set(
                        [
                            "sample_id",
                            "SCL",
                            "timestamps",
                            "label",
                            f"percent_clear_{chip_size}x{chip_size}",
                        ]
                        + self.bands
                    )
                ),
            )
            .join(labels, on=["sample_id", "label"], how="inner")
            .filter(pixel_data_filters)
        )

        # Outlier removal using statistical measures
        if remove_outliers and len(pixel_data) > 0:
            pixel_data = self._remove_outliers(pixel_data)

        # Apply sub-sampling per event
        if max_samples_per_event is not None and len(pixel_data) > 0:
            if random_seed is not None:
                pl.set_random_seed(random_seed)
            if max_samples_per_event > 0:
                pixel_data = pixel_data.filter(
                    pl.int_range(pl.len()).over(["sample_id", "label"])
                    < max_samples_per_event
                )

        if apply_downsampling and len(pixel_data) > 0:
            pixel_data = self._apply_balanced_sampling(
                pixel_data, target_majority_samples
            )

        match label_strategy:
            case "LabelEncoder":
                from sklearn.preprocessing import LabelEncoder

                self.encoder = LabelEncoder()
            case "LabelBinarizer":
                from sklearn.preprocessing import LabelBinarizer

                self.encoder = LabelBinarizer()
            case "Hierarchical":
                from disfor.utils import HierarchicalLabelEncoder

                self.encoder = HierarchicalLabelEncoder()

        self.encoder.fit(self.target_classes)  # ty:ignore[invalid-argument-type]

        if len(pixel_data) > 0:
            pixel_data = (
                pixel_data.with_columns(
                    label_encoded=pl.Series(
                        self.encoder.transform(pixel_data["label"].to_list())
                    )
                )
                .sort("dataset", "cluster_id")
                .with_columns(
                    cluster_id_encoded=pl.struct("dataset", "cluster_id").rank("dense"),
                )
            )

        self.pixel_data = pixel_data

    def _load_base_data(self):
        """Load base data files"""
        required_data = [
            "classes.json",
            "train_ids.json",
            "val_ids.json",
            "labels.parquet",
            "pixel_data.parquet",
            "samples.parquet",
        ]
        if self.data_folder is None:
            self.base_data_paths = {
                filename: disfor.get(filename) for filename in required_data
            }
        else:
            self.base_data_paths = {
                filename: Path(self.data_folder) / filename
                for filename in required_data
            }

        with open(self.base_data_paths["classes.json"], "r") as f:
            self._class_mapping = {int(k): v for k, v in json.load(f).items()}

        with open(self.base_data_paths["train_ids.json"], "r") as f:
            self._train_ids = json.load(f)

        with open(self.base_data_paths["val_ids.json"], "r") as f:
            self._val_ids = json.load(f)

    def _apply_balanced_sampling(
        self, df: pl.DataFrame, target_majority_samples: int | None = None
    ) -> pl.DataFrame:
        """Apply balanced sampling by downsampling the majority class"""
        counts = df["label"].value_counts(sort=True)[0]

        # Return early if there is only one (or no) class
        if len(counts) < 2:
            return df

        # Find majority class
        max_count = counts["count"][0]
        majority_class = counts["label"][0]

        # Determine target size for majority class.
        # If no target is set, use twice the second largest class or 500, whichever is greater (capped at the majority count)
        if target_majority_samples is None:
            second_largest = counts["count"][1]
            target_majority_samples = min(max_count, max(second_largest * 2, 500))

        # Set random seed if specified
        if self.random_seed is not None:
            pl.set_random_seed(self.random_seed)

        # Split and downsample
        majority_mask = pl.col("label") == majority_class
        majority_samples = df.filter(majority_mask).sample(
            n=min(target_majority_samples, max_count)
        )
        minority_samples = df.filter(~majority_mask)

        return pl.concat([minority_samples, majority_samples])

    def _remove_outliers(self, df: pl.DataFrame) -> pl.DataFrame:
        """Calculate outlier mask using the configured method"""
        outlier_cols = (
            self.outlier_columns if self.outlier_columns is not None else self.bands
        )

        mask = [pl.lit(False)]

        if self.outlier_method == "iqr":
            for col in outlier_cols:
                q1 = pl.col(col).quantile(0.25).over("sample_id", "label")
                q3 = pl.col(col).quantile(0.75).over("sample_id", "label")
                iqr = q3 - q1
                lower_bound = q1 - self.outlier_threshold * iqr
                upper_bound = q3 + self.outlier_threshold * iqr
                mask.append((pl.col(col) < lower_bound) | (pl.col(col) > upper_bound))

        elif self.outlier_method == "zscore":
            for col in outlier_cols:
                mean = pl.col(col).mean().over("sample_id", "label")
                std = pl.col(col).std().over("sample_id", "label")
                z_score = ((pl.col(col) - mean) / std).abs()
                mask.append(z_score > self.outlier_threshold)

        elif self.outlier_method == "modified_zscore":
            for col in outlier_cols:
                median_val = pl.col(col).median().over("sample_id", "label")
                mad = (
                    ((pl.col(col) - pl.col(col).median()).abs())
                    .median()
                    .over("sample_id", "label")
                )
                modified_z = 0.6745 * (pl.col(col) - median_val).abs() / mad
                mask.append(modified_z > self.outlier_threshold)

        else:
            raise ValueError(f"Unknown outlier method: {self.outlier_method}")

        return df.filter(~pl.any_horizontal(mask))
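The default balancing target computed in _apply_balanced_sampling can be traced with concrete numbers (the class counts below are hypothetical):

```python
# Hypothetical class counts, sorted descending as value_counts(sort=True) would return them.
counts = {110: 2000, 211: 180, 221: 90}

max_count = max(counts.values())              # size of the majority class
second_largest = sorted(counts.values())[-2]  # size of the second largest class

# Default target: twice the second largest class or 500, whichever is
# greater, but never more than the majority class actually contains.
target = min(max_count, max(second_largest * 2, 500))
print(target)
```

Here 2 * 180 = 360 falls below the 500 floor, so the majority class is downsampled from 2000 to 500; with a second largest class of, say, 400, the target would be 800 instead.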

TabularDataset

Bases: GenericDataset

Class providing data for sklearn style models

For usage see the dataloaders usage page.

Parameters:

    **kwargs (Unpack[DatasetParams], default {}): Keyword arguments passed to disfor.datasets.GenericDataset.
Source code in src/disfor/datasets/tabular.py
class TabularDataset(GenericDataset):
    """Class providing data for sklearn style models

    For usage see the [dataloaders usage page](../usage/dataloaders).

    Args:
        **kwargs: Keyword arguments passed to [disfor.datasets.GenericDataset][]
    """

    def __init__(self, **kwargs: Unpack[DatasetParams]):
        super().__init__(**kwargs)
        train_df = self.pixel_data.filter(pl.col.sample_id.is_in(self._train_ids))
        test_df = self.pixel_data.filter(pl.col.sample_id.is_in(self._val_ids))

        # Train
        self.X_train = train_df[self.bands].to_numpy(writable=True)
        self.y_train = train_df["label_encoded"].to_numpy(writable=True)
        self.group_train = train_df["cluster_id_encoded"].to_numpy(writable=True)
        # Test
        self.X_test = test_df[self.bands].to_numpy(writable=True)
        self.y_test = test_df["label_encoded"].to_numpy(writable=True)
        self.group_test = test_df["cluster_id_encoded"].to_numpy(writable=True)
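The train/test split performed above hinges on membership of sample_id in the id lists loaded from train_ids.json and val_ids.json. The same logic on plain records (the ids and band values below are hypothetical):

```python
# Hypothetical id lists standing in for train_ids.json / val_ids.json.
train_ids = {1, 2}
val_ids = {3}

pixel_data = [
    {"sample_id": 1, "B04": 600, "label_encoded": 0},
    {"sample_id": 2, "B04": 1500, "label_encoded": 1},
    {"sample_id": 3, "B04": 700, "label_encoded": 0},
]

# Split by id membership, then extract features and labels per split.
X_train = [[row["B04"]] for row in pixel_data if row["sample_id"] in train_ids]
y_train = [row["label_encoded"] for row in pixel_data if row["sample_id"] in train_ids]
X_test = [[row["B04"]] for row in pixel_data if row["sample_id"] in val_ids]

print(len(X_train), len(X_test))
```

In TabularDataset the same split is done with polars filters and the arrays are exported as writable numpy arrays, ready for sklearn's fit/predict interface.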

MonoTemporalClassification

Source code in src/disfor/datasets/__init__.py
class MonoTemporalClassification:
    def __init__(self, *args, **kwargs):
        raise ImportError("Install 'disfor[torch]' to use pytorch datasets.")
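This placeholder follows a common optional-dependency pattern: the torch-backed implementation is imported if available, and otherwise replaced by a stub that fails only when instantiated, so importing disfor.datasets itself never requires torch. A minimal sketch of the pattern (the module name in the try branch is illustrative, not a real disfor module):

```python
# Try the heavy implementation; fall back to a lazily-failing stub so that
# the package imports cleanly without the optional extra installed.
try:
    import illustrative_torch_backend as _impl  # hypothetical module name
    MonoTemporalClassification = _impl.MonoTemporalClassification
except ImportError:
    class MonoTemporalClassification:
        def __init__(self, *args, **kwargs):
            raise ImportError("Install 'disfor[torch]' to use pytorch datasets.")

try:
    MonoTemporalClassification()
except ImportError as err:
    print(err)
```

The benefit is that the error surfaces at the point of use, with an actionable message, rather than as an opaque failure at import time.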

utils

HierarchicalLabelEncoder

Sklearn-style encoder for hierarchical multi-class labels with multi-hot encoding.

Assumes a 3-level hierarchy where:

- Level 1: First digit (e.g., 1xx, 2xx)
- Level 2: First two digits (e.g., 11x, 12x, 21x)
- Level 3: All three digits (e.g., 110, 111, 211)

Source code in src/disfor/utils.py
class HierarchicalLabelEncoder:
    """
    Sklearn-style encoder for hierarchical multi-class labels with multi-hot encoding.

    Assumes a 3-level hierarchy where:
    - Level 1: First digit (e.g., 1xx, 2xx)
    - Level 2: First two digits (e.g., 11x, 12x, 21x)
    - Level 3: All three digits (e.g., 110, 111, 211)
    """

    def __init__(self):
        self.level1_classes_ = []
        self.level2_classes_ = []
        self.level3_classes_ = []
        self.is_fitted_ = False

    def _extract_hierarchy(self, label: int) -> Tuple[int, int, int]:
        """Extract the three hierarchy levels from a label."""
        level1 = label // 100
        level2 = label // 10
        level3 = label
        return level1, level2, level3

    def fit(self, y: List[int]) -> "HierarchicalLabelEncoder":
        """
        Fit the encoder by discovering all unique classes at each hierarchy level.

        Parameters:
        -----------
        y : List[int]
            List of integer class labels

        Returns:
        --------
        self : HierarchicalLabelEncoder
        """
        level1_set = set()
        level2_set = set()
        level3_set = set()

        for label in y:
            l1, l2, l3 = self._extract_hierarchy(label)
            level1_set.add(l1)
            level2_set.add(l2)
            level3_set.add(l3)

        # Sort to ensure consistent ordering
        self.level1_classes_ = sorted(level1_set)
        self.level2_classes_ = sorted(level2_set)
        self.level3_classes_ = sorted(level3_set)

        self.is_fitted_ = True
        return self

    def transform(self, y: List[int]) -> np.ndarray:
        """
        Transform labels to hierarchical multi-hot encoding.

        Parameters:
        -----------
        y : List[int]
            List of integer class labels

        Returns:
        --------
        encoded : np.ndarray
            Multi-hot encoded array of shape (n_samples, n_features)
            where n_features = len(level1) + len(level2) + len(level3)
        """
        if not self.is_fitted_:
            raise ValueError(
                "Encoder must be fitted before transform. Call fit() first."
            )

        n_samples = len(y)
        n_level1 = len(self.level1_classes_)
        n_level2 = len(self.level2_classes_)
        n_level3 = len(self.level3_classes_)
        n_features = n_level1 + n_level2 + n_level3

        # Create mapping dictionaries for faster lookup
        level1_map = {cls: idx for idx, cls in enumerate(self.level1_classes_)}
        level2_map = {cls: idx for idx, cls in enumerate(self.level2_classes_)}
        level3_map = {cls: idx for idx, cls in enumerate(self.level3_classes_)}

        # Initialize output array
        encoded = np.zeros((n_samples, n_features), dtype=int)

        for i, label in enumerate(y):
            l1, l2, l3 = self._extract_hierarchy(label)

            # Set corresponding bits to 1
            if l1 in level1_map:
                encoded[i, level1_map[l1]] = 1
            if l2 in level2_map:
                encoded[i, n_level1 + level2_map[l2]] = 1
            if l3 in level3_map:
                encoded[i, n_level1 + n_level2 + level3_map[l3]] = 1

        return encoded

    def fit_transform(self, y: List[int]) -> np.ndarray:
        """
        Fit the encoder and transform labels in one step.

        Parameters:
        -----------
        y : List[int]
            List of integer class labels

        Returns:
        --------
        encoded : np.ndarray
            Multi-hot encoded array
        """
        return self.fit(y).transform(y)

    def inverse_transform(self, encoded: np.ndarray) -> List[int]:
        """
        Transform multi-hot encoded labels back to original labels.

        Parameters:
        -----------
        encoded : np.ndarray
            Multi-hot encoded array of shape (n_samples, n_features)

        Returns:
        --------
        labels : List[int]
            List of integer class labels
        """
        if not self.is_fitted_:
            raise ValueError("Encoder must be fitted before inverse_transform.")

        n_level1 = len(self.level1_classes_)
        n_level2 = len(self.level2_classes_)

        labels = []
        for row in encoded:
            # Extract active indices for each level
            level3_idx = np.where(row[n_level1 + n_level2 :] == 1)[0]

            if len(level3_idx) > 0:
                # Use the most specific level (level 3)
                label = self.level3_classes_[level3_idx[0]]
            else:
                # Fallback to level 2 or level 1 if level 3 is not set
                level2_idx = np.where(row[n_level1 : n_level1 + n_level2] == 1)[0]
                if len(level2_idx) > 0:
                    label = self.level2_classes_[level2_idx[0]] * 10
                else:
                    level1_idx = np.where(row[:n_level1] == 1)[0]
                    if len(level1_idx) > 0:
                        label = self.level1_classes_[level1_idx[0]] * 100
                    else:
                        label = 0  # Default if no class is set

            labels.append(label)

        return labels

    def get_feature_names(self) -> List[str]:
        """
        Get feature names for the encoded output.

        Returns:
        --------
        names : List[str]
            List of feature names in format "level1_X", "level2_XX", "level3_XXX"
        """
        if not self.is_fitted_:
            raise ValueError("Encoder must be fitted before getting feature names.")

        names = []
        names.extend([f"level1_{cls}" for cls in self.level1_classes_])
        names.extend([f"level2_{cls}" for cls in self.level2_classes_])
        names.extend([f"level3_{cls}" for cls in self.level3_classes_])

        return names

fit(y)

Fit the encoder by discovering all unique classes at each hierarchy level.

Parameters:

Name Type Description Default
y List[int] List of integer class labels required

Returns:

Type Description
HierarchicalLabelEncoder The fitted encoder (self)

Source code in src/disfor/utils.py
def fit(self, y: List[int]) -> "HierarchicalLabelEncoder":
    """
    Fit the encoder by discovering all unique classes at each hierarchy level.

    Parameters:
    -----------
    y : List[int]
        List of integer class labels

    Returns:
    --------
    self : HierarchicalLabelEncoder
    """
    level1_set = set()
    level2_set = set()
    level3_set = set()

    for label in y:
        l1, l2, l3 = self._extract_hierarchy(label)
        level1_set.add(l1)
        level2_set.add(l2)
        level3_set.add(l3)

    # Sort to ensure consistent ordering
    self.level1_classes_ = sorted(level1_set)
    self.level2_classes_ = sorted(level2_set)
    self.level3_classes_ = sorted(level3_set)

    self.is_fitted_ = True
    return self
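A minimal sketch of what `fit` discovers, assuming the label layout implied by `inverse_transform` (which scales level-2 codes by 10 and level-1 codes by 100): for a three-digit label such as 211, level 1 is the leading digit, level 2 the leading two digits, and level 3 the full code. The `_extract_hierarchy` helper is not shown in this section, so the `extract_hierarchy` function below is an inferred stand-in, not the library's implementation.

```python
from typing import List, Tuple

def extract_hierarchy(label: int) -> Tuple[int, int, int]:
    # Assumed digit-splitting, inferred from the *10 / *100 scaling
    # in inverse_transform(); e.g. 211 -> (2, 21, 211).
    return label // 100, label // 10, label

labels = [211, 213, 221, 100]

# fit() collects the unique codes seen at each level, sorted for
# a stable column ordering.
level1 = sorted({extract_hierarchy(lbl)[0] for lbl in labels})
level2 = sorted({extract_hierarchy(lbl)[1] for lbl in labels})
level3 = sorted({extract_hierarchy(lbl)[2] for lbl in labels})

print(level1)  # [1, 2]
print(level2)  # [10, 21, 22]
print(level3)  # [100, 211, 213, 221]
```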

fit_transform(y)

Fit the encoder and transform labels in one step.

Parameters:

Name Type Description Default
y List[int] List of integer class labels required

Returns:

Type Description
np.ndarray Multi-hot encoded array

Source code in src/disfor/utils.py
def fit_transform(self, y: List[int]) -> np.ndarray:
    """
    Fit the encoder and transform labels in one step.

    Parameters:
    -----------
    y : List[int]
        List of integer class labels

    Returns:
    --------
    encoded : np.ndarray
        Multi-hot encoded array
    """
    return self.fit(y).transform(y)

get_feature_names()

Get feature names for the encoded output.

Returns:

Type Description
List[str] List of feature names in the format "level1_X", "level2_XX", "level3_XXX"

Source code in src/disfor/utils.py
def get_feature_names(self) -> List[str]:
    """
    Get feature names for the encoded output.

    Returns:
    --------
    names : List[str]
        List of feature names in format "level1_X", "level2_XX", "level3_XXX"
    """
    if not self.is_fitted_:
        raise ValueError("Encoder must be fitted before getting feature names.")

    names = []
    names.extend([f"level1_{cls}" for cls in self.level1_classes_])
    names.extend([f"level2_{cls}" for cls in self.level2_classes_])
    names.extend([f"level3_{cls}" for cls in self.level3_classes_])

    return names

inverse_transform(encoded)

Transform multi-hot encoded labels back to original labels.

Parameters:

Name Type Description Default
encoded np.ndarray Multi-hot encoded array of shape (n_samples, n_features) required

Returns:

Type Description
List[int] List of integer class labels

Source code in src/disfor/utils.py
def inverse_transform(self, encoded: np.ndarray) -> List[int]:
    """
    Transform multi-hot encoded labels back to original labels.

    Parameters:
    -----------
    encoded : np.ndarray
        Multi-hot encoded array of shape (n_samples, n_features)

    Returns:
    --------
    labels : List[int]
        List of integer class labels
    """
    if not self.is_fitted_:
        raise ValueError("Encoder must be fitted before inverse_transform.")

    n_level1 = len(self.level1_classes_)
    n_level2 = len(self.level2_classes_)

    labels = []
    for row in encoded:
        # Extract active indices for each level
        level3_idx = np.where(row[n_level1 + n_level2 :] == 1)[0]

        if len(level3_idx) > 0:
            # Use the most specific level (level 3)
            label = self.level3_classes_[level3_idx[0]]
        else:
            # Fallback to level 2 or level 1 if level 3 is not set
            level2_idx = np.where(row[n_level1 : n_level1 + n_level2] == 1)[0]
            if len(level2_idx) > 0:
                label = self.level2_classes_[level2_idx[0]] * 10
            else:
                level1_idx = np.where(row[:n_level1] == 1)[0]
                if len(level1_idx) > 0:
                    label = self.level1_classes_[level1_idx[0]] * 100
                else:
                    label = 0  # Default if no class is set

        labels.append(label)

    return labels
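The fallback logic above can be illustrated with a small self-contained sketch. The class sets below are hypothetical (as if the encoder had been fitted on labels 211 and 221), and the digit-based label layout is assumed from the ×10 / ×100 scaling in the source:

```python
import numpy as np

# Hypothetical fitted class sets; columns are [level1 | level2 | level3].
level1 = [2]
level2 = [21, 22]
level3 = [211, 221]
n1, n2 = len(level1), len(level2)

def decode(row: np.ndarray) -> int:
    """Decode one multi-hot row, preferring the most specific level set."""
    l3 = np.where(row[n1 + n2:] == 1)[0]
    if len(l3):
        return level3[l3[0]]
    l2 = np.where(row[n1:n1 + n2] == 1)[0]
    if len(l2):
        return level2[l2[0]] * 10   # coarser code, e.g. 21 -> 210
    l1 = np.where(row[:n1] == 1)[0]
    return level1[l1[0]] * 100 if len(l1) else 0

# A fully set row decodes to the level-3 code; a row with only the
# level-2 bit falls back to the coarser 210.
print(decode(np.array([1, 1, 0, 1, 0])))  # 211
print(decode(np.array([0, 1, 0, 0, 0])))  # 210
```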

transform(y)

Transform labels to hierarchical multi-hot encoding.

Parameters:

Name Type Description Default
y List[int] List of integer class labels required

Returns:

Type Description
np.ndarray Multi-hot encoded array of shape (n_samples, n_features), where n_features = len(level1) + len(level2) + len(level3)

Source code in src/disfor/utils.py
def transform(self, y: List[int]) -> np.ndarray:
    """
    Transform labels to hierarchical multi-hot encoding.

    Parameters:
    -----------
    y : List[int]
        List of integer class labels

    Returns:
    --------
    encoded : np.ndarray
        Multi-hot encoded array of shape (n_samples, n_features)
        where n_features = len(level1) + len(level2) + len(level3)
    """
    if not self.is_fitted_:
        raise ValueError(
            "Encoder must be fitted before transform. Call fit() first."
        )

    n_samples = len(y)
    n_level1 = len(self.level1_classes_)
    n_level2 = len(self.level2_classes_)
    n_level3 = len(self.level3_classes_)
    n_features = n_level1 + n_level2 + n_level3

    # Create mapping dictionaries for faster lookup
    level1_map = {cls: idx for idx, cls in enumerate(self.level1_classes_)}
    level2_map = {cls: idx for idx, cls in enumerate(self.level2_classes_)}
    level3_map = {cls: idx for idx, cls in enumerate(self.level3_classes_)}

    # Initialize output array
    encoded = np.zeros((n_samples, n_features), dtype=int)

    for i, label in enumerate(y):
        l1, l2, l3 = self._extract_hierarchy(label)

        # Set corresponding bits to 1
        if l1 in level1_map:
            encoded[i, level1_map[l1]] = 1
        if l2 in level2_map:
            encoded[i, n_level1 + level2_map[l2]] = 1
        if l3 in level3_map:
            encoded[i, n_level1 + n_level2 + level3_map[l3]] = 1

    return encoded
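A sketch of the column layout `transform` produces, using hypothetical class sets (as if fitted on labels 211 and 221) and the assumed digit-splitting of labels into levels. Each label sets exactly one bit per block of the `[level1 | level2 | level3]` concatenation:

```python
import numpy as np

# Hypothetical fitted class sets.
level1 = [2]          # label // 100
level2 = [21, 22]     # label // 10
level3 = [211, 221]   # full code

def encode(label: int) -> np.ndarray:
    """Multi-hot encode one label as [level1 bits | level2 bits | level3 bits]."""
    row = np.zeros(len(level1) + len(level2) + len(level3), dtype=int)
    row[level1.index(label // 100)] = 1
    row[len(level1) + level2.index(label // 10)] = 1
    row[len(level1) + len(level2) + level3.index(label)] = 1
    return row

print(encode(221))  # [1 0 1 0 1]
```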

generate_folds(n_folds, data_folder='data')

Generate stratified, group-aware cross-validation folds. Samples are grouped so that correlated observations (e.g. samples from the same Evoland cluster) never span a train/validation boundary.

Parameters:

Name Type Description Default
n_folds int Number of folds to generate required
data_folder str Path to the folder containing samples.parquet and labels.parquet 'data'

Returns:

Type Description
dict Mapping from fold index to a dict with "train_ids" and "val_ids" sets of sample ids

Source code in src/disfor/utils.py
def generate_folds(n_folds: int, data_folder="data"):
    from sklearn.model_selection import StratifiedGroupKFold

    groups = pl.read_parquet(
        Path(data_folder) / "samples.parquet",
        columns=["sample_id", "cluster_id", "comment", "dataset", "confidence"],
        use_pyarrow=True,
    )
    # sample_ids in HRVPP not highly correlated -> use sample_id as group
    # Evoland: Group by cluster_id
    # Windthrow: Group by Wind Event
    clusters = groups.with_columns(
        cluster=pl.when(dataset=2)
        .then(pl.col.sample_id.cast(pl.String))
        .otherwise(pl.format("{}{}", pl.col.dataset, pl.col.cluster_id))
    )
    samples_w_clusters = (
        pl.read_parquet(Path(data_folder) / "labels.parquet")
        .join(clusters.select("sample_id", "cluster"), on="sample_id")
        .sort("cluster")
        .with_columns(
            pl.col.label.cast(pl.Int16),
            cluster_int=pl.col("cluster").rle_id(),
        )
    )
    sgkf = StratifiedGroupKFold(n_splits=n_folds)
    splits = sgkf.split(
        X=samples_w_clusters["label"],
        y=samples_w_clusters["label"],
        groups=samples_w_clusters["cluster_int"],
    )
    folds = {}
    sample_ids = samples_w_clusters["sample_id"].to_numpy()
    for i, (train_index, test_index) in enumerate(splits):
        folds[i] = {}
        folds[i]["train_ids"] = set(sample_ids[train_index].tolist())
        folds[i]["val_ids"] = set(sample_ids[test_index].tolist())
    return folds