Source code for wsipipe.preprocess.sample.sampler

"""
Samplers apply different sampling policies to patchsets.
"""

from typing import Callable

import numpy as np
import pandas as pd

from wsipipe.preprocess.patching import PatchSet


def simple_random(class_df: pd.DataFrame, sum_totals: int) -> pd.DataFrame:
    """Draw ``sum_totals`` distinct rows uniformly at random from one class.

    Sampling is without replacement, so pandas raises ``ValueError`` when
    ``sum_totals`` exceeds ``len(class_df)``.
    """
    return class_df.sample(sum_totals, replace=False, axis="index")
def simple_random_replacement(class_df: pd.DataFrame, sum_totals: int) -> pd.DataFrame:
    """Draw ``sum_totals`` rows uniformly at random, with replacement, from one class.

    Because rows may repeat, ``sum_totals`` may be larger than ``len(class_df)``.
    """
    drawn = class_df.sample(n=sum_totals, replace=True, axis="index")
    return drawn
def slide_weighted_random(class_df: pd.DataFrame, sum_totals: int) -> pd.DataFrame:
    """Sample ``sum_totals`` rows with replacement, weighted per slide.

    Each row is weighted by ``1 / (number of rows sharing its slide_idx)``,
    so every slide carries the same total weight. The sample should therefore
    contain roughly the same number of patches per slide, even when some
    slides have many more patches than others.

    The returned frame carries two extra helper columns: ``freq`` (row count
    of the row's slide) and ``weights`` (``1 / freq``).
    """
    # Positional assignment (via to_numpy) mirrors assigning a plain list.
    rows_per_slide = class_df.groupby('slide_idx')['slide_idx'].transform('count').to_numpy()
    weighted = class_df.assign(freq=rows_per_slide)
    weighted = weighted.assign(weights=1 / weighted["freq"])
    return weighted.sample(
        n=sum_totals,
        axis=0,
        replace=True,
        weights=weighted["weights"],
    )
def balanced_sample(
    patches: PatchSet,
    num_samples: int,
    floor_samples: int = 1000,
    sampling_policy: Callable[[pd.DataFrame, int], pd.DataFrame] = simple_random,
) -> PatchSet:
    """Create a balanced sample with the same number of patches per class.

    The per-class target ``n_patches`` is derived from the class counts
    returned by ``patches.description()``:

    1. Start from the size of the smallest class, so classes are balanced.
    2. Raise it to ``floor_samples``, so that one very small class does not
       limit every other class. Classes with fewer patches than the target
       simply contribute all of their patches.
    3. Cap it at ``num_samples``, the requested number of patches per class.
       The floor never pushes the target above the request.

    For example, with ``num_samples=5000`` and ``floor_samples=1000``, suppose
    one class has only 50 patches and all the others have more than 1000.
    Every class then returns 1000 patches, apart from the small class, which
    returns 50. Without the floor, every class would be limited to 50 patches.

    Each class is then sampled independently with ``sampling_policy``, for
    example random, random with replacement, or slide-weighted random.

    Note:
        When working with several datasets, make sure label indexes differ for
        classes that differ across the datasets (e.g. specific pathologies)
        and are shared for classes that are the same (e.g. tumor, normal).

    Args:
        patches (PatchSet): A PatchSet.
        num_samples (int): The requested (maximum) number of patches per class.
        floor_samples (int, optional): The minimum per-class target used when
            the smallest class is tiny; it is itself capped at
            ``num_samples``. Defaults to 1000.
        sampling_policy (Callable, optional): Function ``(class_df, n)`` that
            returns ``n`` sampled rows of one class. Defaults to simple_random.

    Returns:
        (PatchSet): A patchset containing a balanced sample of patches, with the
        same settings as ``patches``. It is empty if ``patches`` reports no labels.
    """
    # Work out how many patches of each label are present.
    labels, sum_totals = patches.description()
    if len(labels) == 0:
        # Nothing to balance; min() below would raise on an empty sequence.
        return PatchSet(patches.df.iloc[0:0].copy(), patches.settings)

    # Smallest class -> lifted to the floor -> capped at the requested count.
    n_patches = min(max(min(sum_totals), floor_samples), num_samples)
    # Classes smaller than n_patches keep their own (smaller) count.
    per_class = np.minimum(sum_totals, n_patches)

    # Build all per-class samples first, then concatenate once. Growing an
    # initially empty DataFrame in a loop upcasts columns to object dtype,
    # triggers pandas' empty-concat FutureWarning, and copies quadratically.
    class_samples = [
        sampling_policy(patches.df[patches.df.label == label], int(n_class))
        for label, n_class in zip(labels, per_class)
    ]
    sampled_patches = pd.concat(class_samples, axis=0)
    return PatchSet(sampled_patches, patches.settings)