salad.datasets.da package

Submodules

salad.datasets.da.base module

class salad.datasets.da.base.AugmentationDataset(dataset, transforms, n_samples=2)

Bases: torch.utils.data.dataset.Dataset
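The reference gives only the signature of AugmentationDataset. The sketch below shows what a dataset with this signature plausibly does. It assumes that each item is returned as n_samples independently transformed views of the input followed by the label, and that transforms is a single callable. The class name AugmentationDatasetSketch is hypothetical. The sketch is plain Python and uses the map-style protocol (__getitem__ and __len__) that torch.utils.data.Dataset subclasses implement.

```python
class AugmentationDatasetSketch:
    """Hypothetical stand-in for AugmentationDataset (assumed behaviour).

    Wraps a map-style dataset of (x, y) pairs and returns ``n_samples``
    independently transformed copies of each input, followed by the label.
    """

    def __init__(self, dataset, transforms, n_samples=2):
        self.dataset = dataset
        self.transforms = transforms  # assumption: one callable applied per view
        self.n_samples = n_samples

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        x, y = self.dataset[index]
        # With a random transform, every view would differ.
        views = tuple(self.transforms(x) for _ in range(self.n_samples))
        return views + (y,)


def double(x):
    # Deterministic stand-in for a random augmentation.
    return 2 * x


base = [(1.0, 0), (2.0, 1)]
ds = AugmentationDatasetSketch(base, double, n_samples=2)
print(ds[1])  # (4.0, 4.0, 1)
```

Returning several views of the same input is the usual setup for consistency-based methods such as self-ensembling, where two augmentations of one image are compared.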

class salad.datasets.da.base.JointDataset(*datasets)

Bases: torch.utils.data.dataset.Dataset
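JointDataset takes any number of datasets. The minimal sketch below assumes that index i returns the i-th item of every wrapped dataset, so that one sample per domain is drawn together. It also assumes that the length is truncated to the shortest dataset. The class name is hypothetical and the truncation rule is an assumption.

```python
class JointDatasetSketch:
    """Hypothetical stand-in for JointDataset (assumed behaviour).

    Index i returns the i-th item of every wrapped dataset as a tuple.
    """

    def __init__(self, *datasets):
        if not datasets:
            raise ValueError("need at least one dataset")
        self.datasets = datasets

    def __len__(self):
        # Assumption: truncate to the shortest dataset.
        return min(len(d) for d in self.datasets)

    def __getitem__(self, index):
        return tuple(d[index] for d in self.datasets)


source = [("s0", 0), ("s1", 1), ("s2", 2)]
target = [("t0", 5), ("t1", 6)]
joint = JointDatasetSketch(source, target)
print(len(joint))  # 2
print(joint[1])    # (('s1', 1), ('t1', 6))
```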

class salad.datasets.da.base.JointLoader(*datasets, collate_fn=None)

Bases: object
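JointLoader combines several datasets and accepts an optional collate_fn. A common way to implement such a loader is to iterate the per-domain loaders in lockstep, which the sketch below assumes. It yields one batch per domain at each step and stops when the shortest loader is exhausted. The class name is hypothetical, and plain lists of batches stand in for torch.utils.data.DataLoader objects.

```python
class JointLoaderSketch:
    """Hypothetical stand-in for JointLoader (assumed behaviour).

    Yields one batch from each per-domain loader per step and stops when
    the shortest loader is exhausted.
    """

    def __init__(self, *loaders, collate_fn=None):
        self.loaders = loaders
        self.collate_fn = collate_fn

    def __len__(self):
        return min(len(loader) for loader in self.loaders)

    def __iter__(self):
        for batches in zip(*self.loaders):
            # Optionally merge the per-domain batches into a single object.
            yield self.collate_fn(batches) if self.collate_fn else batches


# Each "loader" here is just a list of pre-built batches.
mnist_batches = [[1, 2], [3, 4], [5, 6]]
svhn_batches = [[10, 20], [30, 40]]
steps = list(JointLoaderSketch(mnist_batches, svhn_batches))
print(steps)  # [([1, 2], [10, 20]), ([3, 4], [30, 40])]
```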

class salad.datasets.da.base.MultiDomainLoader(*args, collate='stack')

Bases: salad.datasets.da.base.JointLoader

Wrapper around JointLoader for multi-domain training.

salad.datasets.da.base.concat_collate(batch)
salad.datasets.da.base.load_dataset(path, train=True, img_size=32, expand=True)
salad.datasets.da.base.load_dataset2(path, train=True, img_size=32, expand=True)
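The loaders in this package take a collate argument with the values 'stack' or 'cat', and base exposes concat_collate. These names suggest they mirror torch.stack and torch.cat, but that is an assumption. The sketch below shows the difference on plain lists. 'stack' keeps a leading domain axis (n_domains batches of B samples). 'cat' merges all domains into one batch of n_domains × B samples. The function name collate_domains is hypothetical.

```python
def collate_domains(batches, mode="stack"):
    """Hypothetical sketch of the two collate modes (assumed semantics).

    ``batches`` holds one batch (a list of samples) per domain.
    - "stack": keep a leading domain axis, analogous to torch.stack.
    - "cat": concatenate along the batch axis, analogous to torch.cat.
    """
    if mode == "stack":
        return [list(batch) for batch in batches]
    if mode == "cat":
        return [sample for batch in batches for sample in batch]
    raise ValueError(f"unknown collate mode: {mode!r}")


per_domain = ([1, 2, 3], [10, 20, 30])
print(collate_domains(per_domain, "stack"))  # [[1, 2, 3], [10, 20, 30]]
print(collate_domains(per_domain, "cat"))    # [1, 2, 3, 10, 20, 30]
```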

salad.datasets.da.digits module

Dataset loader for digit experiments

Digit datasets (MNIST, USPS, SVHN, Synth Digits) are standard benchmarks for unsupervised domain adaptation. Besides providing access to these datasets, this module offers a collection of further DA datasets built from them.

Each loader in this module combines several single datasets and is a direct or indirect subclass of MultiDomainLoader.

class salad.datasets.da.digits.AugmentationLoader(root, dataset_name, transforms, split='train', augment={}, download=True, collate='cat', **kwargs)

Bases: salad.datasets.da.base.MultiDomainLoader

class salad.datasets.da.digits.DigitsLoader(root, keys, split='train', download=True, collate='stack', normalize=False, augment={}, augment_func=<class 'salad.datasets.transforms.ensembling.Augmentation'>, batch_size=1, **kwargs)

Bases: salad.datasets.da.base.MultiDomainLoader

Digits dataset

Four domains available: SVHN, MNIST, SYNTH, USPS

Parameters:
  • root (str) – Root directory where dataset is available or should be downloaded to
  • keys (list of str) – Names of the digit domains to load

See also

torch.utils.data.DataLoader

class salad.datasets.da.digits.HighToLowGaussian

Bases: object

noisemodels = [0.3, 0.25, 0.2, 0.15, 0.1, 0.075, 0.05, 0.025, 0.001]
class salad.datasets.da.digits.HighToLowSaltPepper(*args, **kwargs)

Bases: object

noisemodels = list of 9 noise-model instances: 8 × salad.datasets.transforms.noise.SaltAndPepper, then 1 × salad.datasets.transforms.noise.Gaussian

class salad.datasets.da.digits.LowToHighGaussian

Bases: object

noisemodels = [0.001, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.25, 0.3]
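HighToLowGaussian and LowToHighGaussian list the same nine noise levels in opposite order. The listing does not say what the values mean. The sketch below assumes each value is the standard deviation of additive Gaussian noise on pixel values in [0, 1] and builds one noisy domain per level. The helper names are hypothetical, and plain Python lists stand in for image tensors.

```python
import random

# Values copied from the class attributes listed above.
LOW_TO_HIGH = [0.001, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.25, 0.3]
HIGH_TO_LOW = [0.3, 0.25, 0.2, 0.15, 0.1, 0.075, 0.05, 0.025, 0.001]


def add_gaussian_noise(pixels, std, rng):
    """Assumed noise model: additive N(0, std^2) noise, clipped to [0, 1]."""
    return [min(1.0, max(0.0, p + rng.gauss(0.0, std))) for p in pixels]


def make_noise_domains(pixels, schedule, seed=0):
    """Build one noisy copy of ``pixels`` per noise level (one domain each)."""
    rng = random.Random(seed)
    return [add_gaussian_noise(pixels, std, rng) for std in schedule]


image = [0.0, 0.5, 1.0, 0.25]
domains = make_noise_domains(image, LOW_TO_HIGH)
print(len(domains))  # 9
```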
class salad.datasets.da.digits.LowToHighSaltPepper

Bases: object

noisemodels = list of 9 noise-model instances: 1 × salad.datasets.transforms.noise.Gaussian, then 8 × salad.datasets.transforms.noise.SaltAndPepper

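The two SaltPepper schedules consist mostly of salad.datasets.transforms.noise.SaltAndPepper instances, whose parameters this listing does not show. The sketch below illustrates the standard salt-and-pepper corruption on pixel values in [0, 1]. Each pixel is corrupted with probability amount, and a corrupted pixel becomes 0.0 or 1.0 with equal probability. The function name and the amount parameter are hypothetical, not the library's API.

```python
import random


def salt_and_pepper(pixels, amount, seed=None):
    """Hypothetical salt-and-pepper noise on pixel values in [0, 1].

    Each pixel is replaced with probability ``amount``; a replaced pixel is
    set to 0.0 ("pepper") or 1.0 ("salt") with equal probability.
    """
    rng = random.Random(seed)
    noisy = []
    for p in pixels:
        if rng.random() < amount:
            noisy.append(1.0 if rng.random() < 0.5 else 0.0)
        else:
            noisy.append(p)
    return noisy


image = [0.5] * 1000
noisy = salt_and_pepper(image, amount=0.1, seed=0)
changed = sum(1 for a, b in zip(image, noisy) if a != b)
print(changed)  # number of corrupted pixels, amount * len(image) on average
```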
class salad.datasets.da.digits.NoiseLoader(root, key, noisemodels=[], normalize=True, **kwargs)

Bases: salad.datasets.da.digits.AugmentationLoader

eps = 1.0
class salad.datasets.da.digits.RotationLoader(root, dataset_name, angles=[0, 15, 30, 45, 60, 75], normalize=False, **kwargs)

Bases: salad.datasets.da.digits.AugmentationLoader

eps = 1.0
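RotationLoader takes a list of angles (default 0 to 75 degrees in steps of 15). Presumably each angle defines one domain of rotated digits. The sketch below builds such domains in plain Python. It rotates a square image about its centre by inverse-mapping each output pixel and sampling the nearest source pixel. The real loader's interpolation and padding are not documented here, and the function name rotate_nearest is hypothetical.

```python
import math


def rotate_nearest(image, angle_deg, fill=0):
    """Rotate a square 2D list about its centre (nearest-neighbour sampling)."""
    n = len(image)
    c = (n - 1) / 2.0
    theta = math.radians(angle_deg)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    out = [[fill] * n for _ in range(n)]
    for i in range(n):          # output row
        for j in range(n):      # output column
            # Map the output pixel back into the source image.
            y, x = i - c, j - c
            src_x = cos_t * x + sin_t * y + c
            src_y = -sin_t * x + cos_t * y + c
            si, sj = round(src_y), round(src_x)
            if 0 <= si < n and 0 <= sj < n:
                out[i][j] = image[si][sj]
    return out


ANGLES = [0, 15, 30, 45, 60, 75]  # default `angles` of RotationLoader
digit = [[0, 1, 0], [0, 1, 0], [0, 1, 0]]
domains = [rotate_nearest(digit, a) for a in ANGLES]
print(len(domains))  # 6
```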

salad.datasets.da.toy module

Toy Datasets for domain adaptation experiments

class salad.datasets.da.toy.ToyDatasetLoader(seed=None, augment=False, n_domains=2, download=True, noisemodels=None, collate='stack', **kwargs)

Bases: salad.datasets.da.base.MultiDomainLoader

Toy dataset

Synthetic domains; the number of domains is set by n_domains (default: 2).

salad.datasets.da.toy.make_data(n_samples=50000, n_domains=2, plot=False, noisemodels=None, seed=None)
salad.datasets.da.toy.noise_augment(x)
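The reference lists only the signature of make_data. The sketch below illustrates the kind of data a toy DA generator produces. It draws two classes of 2D Gaussian points and rotates every domain by a different angle, which shifts the input distribution while keeping the labelling rule. The rotation-based shift, the class centres and the name make_toy_domains are assumptions, not the module's implementation.

```python
import math
import random


def make_toy_domains(n_samples=1000, n_domains=2, seed=None):
    """Hypothetical toy DA data: two 2D Gaussian classes, rotated per domain.

    Returns one (points, labels) pair per domain; domain d is rotated by
    d * 30 degrees relative to domain 0.
    """
    rng = random.Random(seed)
    domains = []
    for d in range(n_domains):
        theta = math.radians(30 * d)
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        points, labels = [], []
        for _ in range(n_samples):
            label = rng.randint(0, 1)
            # Class 0 is centred at (-1, 0), class 1 at (+1, 0).
            x = rng.gauss(2 * label - 1, 0.3)
            y = rng.gauss(0.0, 0.3)
            points.append((cos_t * x - sin_t * y, sin_t * x + cos_t * y))
            labels.append(label)
        domains.append((points, labels))
    return domains


data = make_toy_domains(n_samples=500, n_domains=3, seed=0)
print(len(data), len(data[0][0]))  # 3 500
```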