Source code for transformer_tools.transformers

# Copyright (c) 2021 Philip May, Deutsche Telekom AG
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT


"""Transformers tools."""
import sklearn
import torch


class LabeledDataset(torch.utils.data.Dataset):
    """Dataset with labels.

    Pairs pre-computed feature encodings with one label per sample. Each item
    is a dict of tensors containing every encoding feature plus a
    ``"labels"`` entry.

    Args:
        encodings: Mapping from feature name to a per-sample indexable
            sequence (presumably tokenizer output such as ``input_ids`` /
            ``attention_mask`` -- not enforced here).
        labels: Indexable sequence with one label per sample. Its length
            defines the length of the dataset.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors plus a ``"labels"`` tensor."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Added last, so it overrides any ``"labels"`` key present in ``encodings``.
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of samples, taken from the number of labels."""
        return len(self.labels)
class KFoldLabeledDataset:
    """Utility to do k-fold cross-validation on ``LabeledDataset``.

    Args:
        n_splits: Number of folds per repetition.
        n_repeats: How many times the k-fold split is repeated.
        random_state: Passed unchanged to the scikit-learn splitter to
            control the randomness of the repeated splits.
    """

    def __init__(self, n_splits=7, n_repeats=1, random_state=None):
        self.n_splits = n_splits
        self.n_repeats = n_repeats
        self.random_state = random_state

    def split(self, labeled_dataset, stratification_labels=None):
        """Generate training / test splits of ``labeled_dataset``.

        Uses ``RepeatedKFold`` when ``stratification_labels`` is ``None`` and
        ``RepeatedStratifiedKFold`` otherwise.

        Args:
            labeled_dataset: Dataset supporting ``len()``; it is wrapped, not
                copied.
            stratification_labels: Optional per-sample labels used to stratify
                the folds. Presumably one entry per sample of
                ``labeled_dataset`` (scikit-learn checks consistent lengths).

        Yields:
            ``(train, test)`` tuples of ``torch.utils.data.Subset`` objects
            over ``labeled_dataset``.
        """
        # ``import sklearn`` alone does not guarantee that the
        # ``model_selection`` submodule is loaded, so import it explicitly.
        # Because this is a generator body, the import still happens lazily,
        # on the first ``next()``, just like the original attribute access.
        from sklearn.model_selection import RepeatedKFold, RepeatedStratifiedKFold

        # The positions only serve as ``X`` for sklearn (it needs the length);
        # the returned fold indices are positions into ``labeled_dataset``.
        idxs = list(range(len(labeled_dataset)))

        if stratification_labels is None:
            # no stratification wanted
            splitter_cls = RepeatedKFold
        else:
            # stratification wanted
            splitter_cls = RepeatedStratifiedKFold
        k_fold = splitter_cls(
            n_splits=self.n_splits,
            n_repeats=self.n_repeats,
            random_state=self.random_state,
        )

        # In the unstratified branch ``stratification_labels`` is ``None``,
        # which is the default of ``RepeatedKFold.split``'s ``y`` argument.
        for train_idxs, test_idxs in k_fold.split(idxs, stratification_labels):
            train = torch.utils.data.Subset(labeled_dataset, train_idxs)
            test = torch.utils.data.Subset(labeled_dataset, test_idxs)
            yield train, test