salad.datasets.digits package

Digits datasets used in domain adaptation

Submodules

salad.datasets.digits.base module

salad.datasets.digits.mnist module

class salad.datasets.digits.mnist.MNIST(root, split='train', transform=None, label_transform=None, download=True)

Bases: torchvision.datasets.mnist.MNIST

MNIST Dataset

images
labels

salad.datasets.digits.openset module

class salad.datasets.digits.openset.OpenSetDataset(dataset, known, unknown, labels=None)

Bases: object

Dataset wrapper for openset classification

Works with any classification dataset that returns a tuple (x, y) from its __getitem__ method. Given two sets of labels, one for known classes and one for unknown classes, it maps all unknown class labels to zero.
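The label mapping described above can be illustrated with a small, dependency-free sketch. This is not salad's implementation: the class name OpenSetSketch is hypothetical, and two behaviours are assumptions made only for illustration, namely that known labels are renumbered 1..K so that zero stays reserved for the unknown class, and that samples whose label appears in neither set are skipped.

```python
class OpenSetSketch:
    """Hypothetical stand-in for an open-set wrapper (not salad's code)."""

    def __init__(self, dataset, known, unknown):
        self.dataset = dataset
        self.unknown = set(unknown)
        # Assumption: known labels are renumbered 1..K (sorted order) so that
        # 0 stays free for the merged "unknown" class.
        self.label_map = {y: i + 1 for i, y in enumerate(sorted(known))}
        # Assumption: samples whose label is in neither set are skipped.
        # Reading every label once up front is simple but touches each item.
        self.indices = [
            i for i in range(len(dataset))
            if dataset[i][1] in self.label_map or dataset[i][1] in self.unknown
        ]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        x, y = self.dataset[self.indices[idx]]
        # Unknown labels are absent from label_map and therefore map to 0.
        return x, self.label_map.get(y, 0)


# A list of (x, y) tuples already satisfies the (x, y) contract.
toy = [("a", 3), ("b", 7), ("c", 5), ("d", 9), ("e", 3)]
wrapped = OpenSetSketch(toy, known=[3, 5], unknown=[7])
pairs = [wrapped[i] for i in range(len(wrapped))]
```

Here the sample with label 7 ends up with label 0, the samples with labels 3 and 5 become classes 1 and 2, and the sample with label 9 is dropped because it is listed in neither set.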

salad.datasets.digits.openset.get_data(train=True, batch_size=128)

salad.datasets.digits.synth module

class salad.datasets.digits.synth.Synth(root, split='train', transform=None, label_transform=None, download=True)

Bases: salad.datasets.digits.base._BaseDataset

Synthetic images dataset

extract_images_labels(filename)
image_shape = [16, 16, 1]
num_labels = 10
test_file = 'synth_test_32x32.mat?raw=true'
training_file = 'synth_train_32x32.mat?raw=true'
urls = {'https://github.com/domainadaptation/datasets/blob/master/synth/synth_test_32x32.mat?raw=true', 'https://github.com/domainadaptation/datasets/blob/master/synth/synth_train_32x32.mat?raw=true'}

class salad.datasets.digits.synth.SynthSmall(root, split='train', transform=None, label_transform=None, download=True)

Bases: salad.datasets.digits.base._BaseDataset

Synthetic images dataset

extract_images_labels(filename)
image_shape = [16, 16, 1]
num_labels = 10
test_file = 'synth_test_32x32.mat_small?raw=true'
training_file = 'synth_train_32x32_small.mat?raw=true'
urls = {'https://github.com/domainadaptation/datasets/blob/master/synth/synth_test_32x32_small.mat?raw=true', 'https://github.com/domainadaptation/datasets/blob/master/synth/synth_train_32x32_small.mat?raw=true'}

salad.datasets.digits.usps module

class salad.datasets.digits.usps.USPS(root, split='train', transform=None, label_transform=None, download=True)

Bases: salad.datasets.digits.base._BaseDataset

[USPS](http://statweb.stanford.edu/~tibs/ElemStatLearn/data.html) Dataset.

Parameters:
  • root (string) – Root directory of dataset where processed/training.pt and processed/test.pt exist.
  • train (bool, optional) – If True, creates dataset from training.pt, otherwise from test.pt.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g., transforms.RandomCrop.
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

Download the USPS dataset from [1], or use the explicit links [2] for training and [3] for testing. The code for loading the dataset is partly adapted from [4] (Apache License 2.0).

References

[1] http://statweb.stanford.edu/~tibs/ElemStatLearn/data.html
[2] Training dataset: http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/zip.train.gz
[3] Test dataset: http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/zip.test.gz
[4] https://github.com/haeusser/learning_by_association/blob/master/semisup/tools/usps.py

extract_images_labels(filename)
image_shape = [16, 16, 1]
num_labels = 10
test_file = 'zip.train.gz'
training_file = 'zip.train.gz'
urls = ['http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/zip.train.gz', 'http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/zip.test.gz']
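
As a rough illustration of what extracting images and labels from the gzipped USPS files may involve, the sketch below parses text rows into 16 x 16 images. It is not the body of extract_images_labels: the function names parse_usps_lines and load_usps_gz are hypothetical, and the file layout (one sample per line, a digit label followed by 256 whitespace-separated grayscale values) is an assumption about the zip.train.gz and zip.test.gz files rather than something stated in this reference.

```python
import gzip
import io


def parse_usps_lines(fileobj, side=16):
    """Parse 'label v1 ... vN' text rows into (images, labels).

    Assumed layout: each row holds a digit label followed by side*side values.
    """
    images, labels = [], []
    for line in fileobj:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines
        # Labels may be written as floats such as "6.0000".
        labels.append(int(float(fields[0])))
        pixels = [float(v) for v in fields[1:]]
        if len(pixels) != side * side:
            raise ValueError("expected %d pixel values, got %d"
                             % (side * side, len(pixels)))
        # Reshape the flat row into a side x side nested list.
        images.append([pixels[r * side:(r + 1) * side] for r in range(side)])
    return images, labels


def load_usps_gz(filename):
    """Open a gzipped text file and parse it with parse_usps_lines."""
    with gzip.open(filename, "rt") as f:
        return parse_usps_lines(f)


# Self-contained demo on two synthetic rows, compressed in memory.
row_a = "6.0000 " + " ".join(["-1.0000"] * 256)
row_b = "0.0000 " + " ".join(["0.5000"] * 256)
raw = gzip.compress((row_a + "\n" + row_b + "\n").encode("ascii"))
with gzip.open(io.BytesIO(raw), "rt") as f:
    images, labels = parse_usps_lines(f)
```

Nested lists keep the sketch free of third-party dependencies; a real loader would more likely build an array and append a channel axis to match image_shape = [16, 16, 1].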