salad.datasets.da package

Submodules

salad.datasets.da.base module

class salad.datasets.da.base.AugmentationDataset(dataset, transforms, n_samples=2)

Bases: torch.utils.data.dataset.Dataset
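The entry above gives only the signature. As a rough illustration of what a dataset with `transforms` and `n_samples` parameters could return, here is a minimal pure-Python sketch. It assumes (this page does not confirm it) that each item of `dataset` is an `(x, y)` pair and that `transforms` is one callable applied independently `n_samples` times. `AugmentationDatasetSketch` and `jitter` are hypothetical names, not part of salad.

```python
import random
from typing import Any, Callable, List, Sequence, Tuple


class AugmentationDatasetSketch:
    """Hypothetical stand-in for AugmentationDataset; not salad's code.

    Assumes every item of `dataset` is an (x, y) pair and that `transforms`
    is a single callable applied independently to produce each view.
    """

    def __init__(self, dataset: Sequence[Tuple[Any, Any]],
                 transforms: Callable[[Any], Any], n_samples: int = 2):
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")
        self.dataset = dataset
        self.transforms = transforms
        self.n_samples = n_samples

    def __len__(self) -> int:
        # Augmentation does not change the number of underlying samples.
        return len(self.dataset)

    def __getitem__(self, index: int) -> Tuple[List[Any], Any]:
        x, y = self.dataset[index]
        # Draw n_samples independently augmented views of the same input.
        views = [self.transforms(x) for _ in range(self.n_samples)]
        return views, y


rng = random.Random(0)


def jitter(x: List[float]) -> List[float]:
    """Hypothetical augmentation: add small uniform noise to each value."""
    return [v + rng.uniform(-0.1, 0.1) for v in x]


ds = AugmentationDatasetSketch([([0.0, 1.0], 3)], jitter, n_samples=2)
views, label = ds[0]  # two jittered copies of [0.0, 1.0], label 3
```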

class salad.datasets.da.base.JointDataset(*datasets)

Bases: torch.utils.data.dataset.Dataset
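`JointDataset(*datasets)` combines several datasets into one. The sketch below shows one plausible behaviour, assuming that index `i` returns the `i`-th item of every dataset and that the joint length is that of the shortest dataset. Both are assumptions: the page does not say how datasets of unequal length are handled. `JointDatasetSketch` is a hypothetical name.

```python
from typing import Any, Sequence, Tuple


class JointDatasetSketch:
    """Hypothetical: item i is the tuple of item i from every dataset."""

    def __init__(self, *datasets: Sequence[Any]):
        if not datasets:
            raise ValueError("at least one dataset is required")
        self.datasets = datasets

    def __len__(self) -> int:
        # Assumption: truncate to the shortest dataset so every index is valid.
        return min(len(d) for d in self.datasets)

    def __getitem__(self, index: int) -> Tuple[Any, ...]:
        # Raising IndexError also lets plain `for` loops stop at the end.
        if not 0 <= index < len(self):
            raise IndexError(index)
        return tuple(d[index] for d in self.datasets)


source = [("s0", 0), ("s1", 1), ("s2", 2)]
target = [("t0", 5), ("t1", 6)]
joint = JointDatasetSketch(source, target)  # length 2, the shorter dataset
```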

class salad.datasets.da.base.JointLoader(*datasets, collate_fn=None)

Bases: object

class salad.datasets.da.base.MultiDomainLoader(*args, collate='stack')

Bases: salad.datasets.da.base.JointLoader

Wrapper around JointLoader for multi-domain training.
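For multi-domain training, a joint loader typically yields one batch per domain at each step. The sketch below shows that idea with plain Python iterables. It assumes iteration stops when the shortest loader is exhausted, which is one common choice (another is to cycle the shorter loaders). `JointLoaderSketch` is hypothetical and does not reproduce salad's implementation.

```python
from typing import Any, Iterable, Iterator, Tuple


class JointLoaderSketch:
    """Hypothetical: yield one batch from every loader per training step."""

    def __init__(self, *loaders: Iterable[Any]):
        self.loaders = loaders

    def __iter__(self) -> Iterator[Tuple[Any, ...]]:
        # zip() stops as soon as the shortest loader is exhausted, so every
        # step contains exactly one batch per domain.
        return zip(*self.loaders)


source_batches = [[1, 2], [3, 4], [5, 6]]
target_batches = [[7, 8], [9, 10]]

for step, (src, tgt) in enumerate(JointLoaderSketch(source_batches, target_batches)):
    print(step, src, tgt)
# 0 [1, 2] [7, 8]
# 1 [3, 4] [9, 10]
```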

salad.datasets.da.base.concat_collate(batch)
salad.datasets.da.base.load_dataset(path, train=True, img_size=32, expand=True)
salad.datasets.da.base.load_dataset2(path, train=True, img_size=32, expand=True)
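`MultiDomainLoader` takes `collate='stack'`, `AugmentationLoader` defaults to `collate='cat'`, and this module also exposes `concat_collate`. Assuming these options correspond to PyTorch's `torch.stack` (new leading domain axis) and `torch.cat` (join along the batch axis), the difference can be illustrated with nested lists. `stack_domains` and `concat_domains` are hypothetical helpers, not salad functions.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def stack_domains(batches: Sequence[Sequence[T]]) -> List[List[T]]:
    """Keep a leading domain axis: result[d][i] is sample i of domain d.

    Like torch.stack, this requires every domain batch to have the same size.
    """
    sizes = {len(b) for b in batches}
    if len(sizes) > 1:
        raise ValueError(f"stacking needs equal batch sizes, got {sorted(sizes)}")
    return [list(b) for b in batches]


def concat_domains(batches: Sequence[Sequence[T]]) -> List[T]:
    """Merge all domains into a single batch axis (like torch.cat along dim 0)."""
    return [sample for batch in batches for sample in batch]


source_batch = [1, 2]
target_batch = [3, 4]
stacked = stack_domains([source_batch, target_batch])   # [[1, 2], [3, 4]]
merged = concat_domains([source_batch, target_batch])   # [1, 2, 3, 4]
```

Stacking keeps track of which samples came from which domain, while concatenation yields one larger batch that can go through a model in a single forward pass.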

salad.datasets.da.digits module

Dataset loader for digit experiments

Digit datasets (MNIST, USPS, SVHN, Synth Digits) are standard benchmarks for unsupervised domain adaptation. In addition to providing access to these datasets, this module provides a collection of further DA datasets derived from them.

Each dataset here is a collection of single-domain datasets and is a subclass of MultiDomainLoader.

class salad.datasets.da.digits.AugmentationLoader(root, dataset_name, transforms, split='train', augment={}, download=True, collate='cat', **kwargs)

Bases: salad.datasets.da.base.MultiDomainLoader

class salad.datasets.da.digits.DigitsLoader(root, keys, split='train', download=True, collate='stack', normalize=False, augment={}, augment_func=<class 'salad.datasets.transforms.ensembling.Augmentation'>, batch_size=1, **kwargs)

Bases: salad.datasets.da.base.MultiDomainLoader

Digits dataset

Four domains available: SVHN, MNIST, SYNTH, USPS

Parameters:
  • root (str) – Root directory where dataset is available or should be downloaded to
  • keys (list of str) – Names of the domains to load, chosen from the four domains listed above

See also

torch.utils.data.DataLoader

class salad.datasets.da.digits.HighToLowGaussian

Bases: object

noisemodels = [0.3, 0.25, 0.2, 0.15, 0.1, 0.075, 0.05, 0.025, 0.001]
class salad.datasets.da.digits.HighToLowSaltPepper(*args, **kwargs)

Bases: object

noisemodels = [SaltAndPepper, SaltAndPepper, SaltAndPepper, SaltAndPepper, SaltAndPepper, SaltAndPepper, SaltAndPepper, SaltAndPepper, Gaussian] (instances of salad.datasets.transforms.noise.SaltAndPepper and salad.datasets.transforms.noise.Gaussian)
class salad.datasets.da.digits.LowToHighGaussian

Bases: object

noisemodels = [0.001, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.25, 0.3]
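The two Gaussian schedules list the same nine values in opposite order. A plausible reading, assumed here and not stated on this page, is that each value is the standard deviation of additive Gaussian noise and that each noise level defines one domain. The sketch applies each level to a flat list of pixel intensities in [0, 1] and clips the result. `add_gaussian_noise` and `make_noise_domains` are hypothetical helpers.

```python
import random
from typing import List, Sequence

# Schedules copied from HighToLowGaussian.noisemodels and LowToHighGaussian.noisemodels.
HIGH_TO_LOW = [0.3, 0.25, 0.2, 0.15, 0.1, 0.075, 0.05, 0.025, 0.001]
LOW_TO_HIGH = [0.001, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.25, 0.3]
assert LOW_TO_HIGH == HIGH_TO_LOW[::-1]


def add_gaussian_noise(pixels: Sequence[float], sigma: float,
                       rng: random.Random) -> List[float]:
    """Add zero-mean Gaussian noise with std `sigma`, clipping to [0, 1]."""
    return [min(1.0, max(0.0, p + rng.gauss(0.0, sigma))) for p in pixels]


def make_noise_domains(pixels: Sequence[float], schedule: Sequence[float],
                       seed: int = 0) -> List[List[float]]:
    """Return one noisy copy of `pixels` per noise level in `schedule`."""
    rng = random.Random(seed)
    return [add_gaussian_noise(pixels, sigma, rng) for sigma in schedule]


image = [0.5] * 16  # a flat 4x4 grey image
domains = make_noise_domains(image, LOW_TO_HIGH)  # 9 domains, least noisy first
```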
class salad.datasets.da.digits.LowToHighSaltPepper

Bases: object

noisemodels = [Gaussian, SaltAndPepper, SaltAndPepper, SaltAndPepper, SaltAndPepper, SaltAndPepper, SaltAndPepper, SaltAndPepper, SaltAndPepper] (instances of salad.datasets.transforms.noise.Gaussian and salad.datasets.transforms.noise.SaltAndPepper)
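The salt-and-pepper schedules list the same kinds of noise model (eight SaltAndPepper objects and one Gaussian) in reverse order. Salt-and-pepper noise replaces a random fraction of pixels with the minimum or maximum intensity. The sketch below implements that standard definition for intensities in [0, 1]; the function name and its `amount` parameter are assumptions and need not match the constructor of `salad.datasets.transforms.noise.SaltAndPepper`.

```python
import random
from typing import List, Sequence


def salt_and_pepper(pixels: Sequence[float], amount: float,
                    rng: random.Random) -> List[float]:
    """Replace roughly a fraction `amount` of pixels with 0.0 or 1.0."""
    if not 0.0 <= amount <= 1.0:
        raise ValueError("amount must lie in [0, 1]")
    out = []
    for p in pixels:
        if rng.random() < amount:
            # Corrupted pixel: salt (1.0) or pepper (0.0) with equal probability.
            out.append(1.0 if rng.random() < 0.5 else 0.0)
        else:
            out.append(p)
    return out


rng = random.Random(0)
image = [0.5] * 100
noisy = salt_and_pepper(image, amount=0.2, rng=rng)  # about 20 pixels corrupted
```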
class salad.datasets.da.digits.NoiseLoader(root, key, noisemodels=[], normalize=True, **kwargs)

Bases: salad.datasets.da.digits.AugmentationLoader

eps = 1.0
class salad.datasets.da.digits.RotationLoader(root, dataset_name, angles=[0, 15, 30, 45, 60, 75], normalize=False, **kwargs)

Bases: salad.datasets.da.digits.AugmentationLoader

eps = 1.0
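`RotationLoader` takes a list of `angles`, by default 0° to 75° in steps of 15°, presumably one rotated domain per angle. How salad rotates the images is not documented here; the sketch below uses nearest-neighbour inverse mapping on a small image stored as a list of rows and fills uncovered pixels with 0. `rotate_nearest` is a hypothetical helper.

```python
import math
from typing import List, Sequence


def rotate_nearest(img: Sequence[Sequence[float]], angle_deg: float,
                   fill: float = 0.0) -> List[List[float]]:
    """Rotate a 2D image about its centre with nearest-neighbour sampling."""
    h, w = len(img), len(img[0])
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    theta = math.radians(angle_deg)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    out = [[fill] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            # Inverse mapping: find the source pixel that lands on (y, x).
            dx, dy = x - cx, y - cy
            sx = cos_t * dx + sin_t * dy + cx
            sy = -sin_t * dx + cos_t * dy + cy
            ix, iy = round(sx), round(sy)
            if 0 <= iy < h and 0 <= ix < w:
                out[y][x] = img[iy][ix]
    return out


digit = [[0, 1, 0],
         [0, 1, 0],
         [0, 1, 0]]
angles = [0, 15, 30, 45, 60, 75]  # RotationLoader's default angles
rotated_domains = [rotate_nearest(digit, a) for a in angles]
```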

salad.datasets.da.toy module

Toy Datasets for domain adaptation experiments

class salad.datasets.da.toy.ToyDatasetLoader(seed=None, augment=False, n_domains=2, download=True, noisemodels=None, collate='stack', **kwargs)

Bases: salad.datasets.da.base.MultiDomainLoader

Toy dataset

The number of domains is set by n_domains (default: 2).

salad.datasets.da.toy.make_data(n_samples=50000, n_domains=2, plot=False, noisemodels=None, seed=None)
salad.datasets.da.toy.noise_augment(x)
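`make_data` generates synthetic data with `n_samples` points and `n_domains` domains. The page does not describe the generative process, so the sketch below assumes a common toy setup: two Gaussian classes in 2D whose means are rotated by a different angle in each domain, which introduces a controlled domain shift. `make_toy_domains`, its `spread` parameter, and the rotation step of π/8 per domain are hypothetical choices.

```python
import math
import random
from typing import List, Optional, Tuple

Point = Tuple[float, float]


def make_toy_domains(n_samples: int = 1000, n_domains: int = 2,
                     seed: Optional[int] = None,
                     spread: float = 0.3) -> List[Tuple[List[Point], List[int]]]:
    """Return one (points, labels) pair per domain."""
    rng = random.Random(seed)
    domains = []
    for d in range(n_domains):
        # Domain d rotates both class means by d * pi/8 (hypothetical shift).
        angle = d * math.pi / 8
        points: List[Point] = []
        labels: List[int] = []
        for _ in range(n_samples):
            label = rng.randrange(2)
            sign = 1.0 if label == 1 else -1.0
            # Class means sit at (+1, 0) and (-1, 0) before rotation.
            mean_x, mean_y = sign * math.cos(angle), sign * math.sin(angle)
            points.append((mean_x + rng.gauss(0.0, spread),
                           mean_y + rng.gauss(0.0, spread)))
            labels.append(label)
        domains.append((points, labels))
    return domains


domains = make_toy_domains(n_samples=500, n_domains=3, seed=0)
```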