salad.solver.da package

Domain Adaptation solvers

Submodules

salad.solver.da.advdrop module

class salad.solver.da.advdrop.AdversarialDropoutLoss(model, step=1)

Bases: object

Loss Derivation for Adversarial Dropout Regularization

See also

salad.solver.AdversarialDropoutSolver

step1(batch)
step2(batch)
step3(batch)
class salad.solver.da.advdrop.AdversarialDropoutSolver(model, dataset, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

Implementation of “Adversarial Dropout Regularization”

Adversarial Dropout Regularization [1] estimates uncertainties about the classification process by sampling different models using dropout. On the source domain, a standard cross entropy loss is employed. On the target domain, two predictions are sampled from the model using different dropout masks.

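Both network parts, the feature extractor \(G\) and the classifier \(C\), are jointly trained on the source domain using the standard cross entropy loss, where \(p^s\) denotes the one-hot source label and \(y^s\) the predicted class distribution:

\[\min_{C, G} - \sum_k p^s_k \log y^s_k\]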

The classifier part of the network is trained to maximize the symmetric KL distance \(\frac{1}{2}\left(\mathrm{KL}(p^t \| q^t) + \mathrm{KL}(q^t \| p^t)\right)\) between two predictions \(p^t\) and \(q^t\) sampled for the same target input. This distance is one option for measuring uncertainty in a network. In other words, the classifier aims at maximizing uncertainty given two noisy estimates of the current feature vector, while still minimizing the source classification loss:

\[\min_{C} - \sum_k p^s_k \log y^s_k - \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}\]

In contrast, the feature extractor aims at minimizing the uncertainty between two noisy samples given a fixed classifier:

\[\min_{G} \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}\]
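
The symmetric KL term shared by the two objectives above can be sketched in NumPy as below. This is a minimal illustration under stated assumptions, not salad's `SymmetricKL` module: it assumes both predictions are already class-probability vectors (for example, softmax outputs of two dropout-perturbed forward passes on the same target batch), and it uses random logits as stand-ins. The helper names `softmax` and `symmetric_kl` are hypothetical.

```python
import numpy as np


def softmax(logits, axis=-1):
    # Subtract the maximum for numerical stability before exponentiating.
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def symmetric_kl(p, q, eps=1e-12):
    """Per-sample symmetric KL: sum_k (p_k - q_k) / 2 * log(p_k / q_k)."""
    # Clip to avoid log(0) for (near-)zero probabilities.
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return 0.5 * np.sum((p - q) * np.log(p / q), axis=-1)


# Stand-ins for two dropout-perturbed predictions on the same target batch.
rng = np.random.default_rng(0)
p_t = softmax(rng.normal(size=(4, 10)))
q_t = softmax(rng.normal(size=(4, 10)))
d = symmetric_kl(p_t, q_t)  # one value per target sample, shape (4,)
```

With the sign convention above, a classifier update would add `-d.mean()` to its source cross entropy, while a feature-extractor update would minimize `d.mean()`.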

References

[1] Adversarial Dropout Regularization, Saito et al., ICLR 2018

class salad.solver.da.advdrop.SymmetricKL

Bases: torch.nn.modules.module.Module

forward(x, y)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

salad.solver.da.advdrop.pack(*args)
salad.solver.da.advdrop.unpack(arg, n_tensors)

salad.solver.da.advdrop_refactor module

class salad.solver.da.advdrop_refactor.AdversarialDropoutLoss(model, step)

Bases: object

step1(batch)
step2(batch)
step3(batch)
class salad.solver.da.advdrop_refactor.AdversarialDropoutSolver(model, dataset, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

Implementation of “Adversarial Dropout Regularization”

Adversarial Dropout Regularization [1] estimates uncertainties about the classification process by sampling different models using dropout. On the source domain, a standard cross entropy loss is employed. On the target domain, two predictions are sampled from the model using different dropout masks.

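Both network parts, the feature extractor \(G\) and the classifier \(C\), are jointly trained on the source domain using the standard cross entropy loss, where \(p^s\) denotes the one-hot source label and \(y^s\) the predicted class distribution:

\[\min_{C, G} - \sum_k p^s_k \log y^s_k\]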

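The classifier part of the network is trained to maximize the symmetric KL distance \(\frac{1}{2}\left(\mathrm{KL}(p^t \| q^t) + \mathrm{KL}(q^t \| p^t)\right)\) between two predictions \(p^t\) and \(q^t\) sampled for the same target input. This distance is one option for measuring uncertainty in a network. In other words, the classifier aims at maximizing uncertainty given two noisy estimates of the current feature vector, while still minimizing the source classification loss:

\[\min_{C} - \sum_k p^s_k \log y^s_k - \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}\]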

In contrast, the feature extractor aims at minimizing the uncertainty between two noisy samples given a fixed classifier:

\[\min_{G} \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}\]
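
One way to read the `step1`, `step2` and `step3` methods of `AdversarialDropoutLoss` is as the three objectives above, in that order. The sketch below assumes this mapping (it is not confirmed by the documentation) and computes only the scalar loss values in NumPy. The helper `adr_losses` and its dictionary keys are hypothetical; it performs no parameter updates, whereas in practice step 2 would update only \(C\) and step 3 only \(G\).

```python
import numpy as np


def log_softmax(logits):
    # Numerically stable log of the softmax along the class axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def cross_entropy(logits, labels):
    # Mean negative log-likelihood of the integer labels.
    return -log_softmax(logits)[np.arange(len(labels)), labels].mean()


def symmetric_kl(logits_a, logits_b):
    # Batch mean of sum_k (p_k - q_k) / 2 * log(p_k / q_k).
    log_p, log_q = log_softmax(logits_a), log_softmax(logits_b)
    p, q = np.exp(log_p), np.exp(log_q)
    return (0.5 * (p - q) * (log_p - log_q)).sum(axis=-1).mean()


def adr_losses(src_logits, src_labels, tgt_logits_1, tgt_logits_2):
    """Hypothetical helper: the three ADR objectives as scalar losses."""
    ce = cross_entropy(src_logits, src_labels)
    skl = symmetric_kl(tgt_logits_1, tgt_logits_2)
    return {
        "step1_GC": ce,       # update G and C on the source domain
        "step2_C": ce - skl,  # update C only (G held fixed): maximize the symmetric KL
        "step3_G": skl,       # update G only (C held fixed): minimize the symmetric KL
    }


rng = np.random.default_rng(1)
losses = adr_losses(rng.normal(size=(8, 5)), rng.integers(0, 5, size=8),
                    rng.normal(size=(8, 5)), rng.normal(size=(8, 5)))
```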

References

[1] Adversarial Dropout Regularization, Saito et al., ICLR 2018

class salad.solver.da.advdrop_refactor.SymmetricKL

Bases: torch.nn.modules.module.Module

forward(x, y)

salad.solver.da.association module

Associative Domain Adaptation

Häusser et al., CVPR 2017

class salad.solver.da.association.AssociationLoss(model)

Bases: object

Loss function for associative domain adaptation

Given a model, derive a function that computes arguments for the association loss.

class salad.solver.da.association.AssociativeSolver(model, dataset, learningrate, walker_weight=1.0, visit_weight=0.1, *args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

Implementation of “Associative Domain Adaptation”

Associative Domain Adaptation [1] uses a random walk over feature similarities, from the source to the target domain and back, to measure the discrepancy between source and target features. The algorithm is based on two loss functions that are added to the standard cross entropy loss on the source domain.

Given features for source and target domain, a kernel function is used to measure similarity between both domains. The original implementation uses the scalar product between feature vectors, passed through an exponential,

\[K_{ij} = k(x^s_i, x^t_j) = \exp(\langle x^s_i, x^t_j \rangle)\]

This kernel is then used to compute transition probabilities

\[p(x^t_j | x^s_i) = \frac{K_{ij}}{\sum_{l} K_{il}}\]

and

\[p(x^s_k | x^t_j) = \frac{K_{kj}}{\sum_{l} K_{lj}}\]

to compute the roundtrip

\[p(x^s_k | x^s_i) = \sum_{j} p(x^s_k | x^t_j) p(x^t_j | x^s_i)\]
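
The transition and roundtrip probabilities can be computed with two softmax normalizations of the same similarity matrix. This NumPy sketch assumes the scalar-product kernel above; the function name `roundtrip` is hypothetical, and working with \(\log K_{ij} = \langle x^s_i, x^t_j \rangle\) instead of \(K_{ij}\) is a choice made here to avoid overflow in the exponential.

```python
import numpy as np


def softmax(m, axis):
    # Stable softmax along the given axis.
    z = m - m.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def roundtrip(xs, xt):
    """Roundtrip probabilities p(x^s_k | x^s_i) for features xs (N_s, D), xt (N_t, D)."""
    m = xs @ xt.T                # log K_ij, shape (N_s, N_t)
    p_st = softmax(m, axis=1)    # p(x^t_j | x^s_i): normalize over targets j
    p_ts = softmax(m, axis=0).T  # p(x^s_k | x^t_j): normalize over sources k, shape (N_t, N_s)
    return p_st @ p_ts, p_st     # (N_s, N_s) roundtrip matrix, plus p_st


rng = np.random.default_rng(0)
xs, xt = rng.normal(size=(6, 3)), rng.normal(size=(9, 3))
p_sts, p_st = roundtrip(xs, xt)
```

Each row of `p_sts` is a probability distribution over source samples, since both factors are row-stochastic.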

It is then required that

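  1. WalkerLoss: the roundtrip ends at a sample with the same class label, i.e., \(y^s_i = y^s_k\)
  2. VisitLoss: each target sample is visited with roughly equal probability, i.e., the visit probabilities \(\frac{1}{N_s} \sum_i p(x^t_j | x^s_i)\) are close to uniform over the target samples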
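
A minimal NumPy sketch of both losses follows. It assumes, as in the original paper, that both are cross entropies against uniform target distributions: for the walker loss, uniform over the source samples sharing the start sample's label; for the visit loss, uniform over the target batch. The function `association_losses` is hypothetical and is not salad's `AssociationLoss`.

```python
import numpy as np


def _softmax(m, axis):
    z = m - m.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def association_losses(xs, ys, xt, eps=1e-8):
    """Hypothetical walker and visit losses for source (xs, ys) and target xt."""
    m = xs @ xt.T                              # log K_ij, shape (N_s, N_t)
    p_st = _softmax(m, axis=1)                 # source -> target transitions
    p_sts = p_st @ _softmax(m, axis=0).T       # roundtrip, shape (N_s, N_s)

    # Walker target: uniform over source samples with the same label as the start.
    same = (ys[:, None] == ys[None, :]).astype(float)
    walker_target = same / same.sum(axis=1, keepdims=True)
    walker = -(walker_target * np.log(p_sts + eps)).sum(axis=1).mean()

    # Visit target: each target sample visited with probability 1 / N_t.
    p_visit = p_st.mean(axis=0)
    uniform = np.full_like(p_visit, 1.0 / len(p_visit))
    visit = -(uniform * np.log(p_visit + eps)).sum()
    return walker, visit


rng = np.random.default_rng(0)
xs = rng.normal(size=(8, 4))
ys = np.array([0, 0, 1, 1, 2, 2, 3, 3])
xt = rng.normal(size=(12, 4))
w, v = association_losses(xs, ys, xt)
```

The total objective would then presumably add `walker_weight * w + visit_weight * v` to the source cross entropy, matching the `walker_weight` and `visit_weight` arguments of `AssociativeSolver`; this weighting is an assumption.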

As one possible modification, different kernel functions could be used to measure similarity between the domains. With this solver, it is advised to use large sample sizes for the target domain and ensure that a sufficient number of source samples is available for each batch.

TODO: Possibly implement functionality in the solver class to aggregate batches and avoid memory issues.

Parameters:
  • model (torch.nn.Module) – A PyTorch model to be trained by association
  • dataset (StackedDataset) – A dataset suitable for an unsupervised solver
  • learningrate (int) – TODO

References

[1] Associative Domain Adaptation, Häusser et al., CVPR 2017, https://arxiv.org/abs/1708.00938

salad.solver.da.base module

Solver classes for domain adaptation experiments

class salad.solver.da.base.BaselineDASolver(*args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

A domain adaptation solver that does not run any adaptation algorithm

This is useful to establish baseline results for the case of no adaptation, e.g., to measure the domain shift between datasets.

class salad.solver.da.base.DABaseSolver(*args, **kwargs)

Bases: salad.solver.classification.BaseClassSolver

Base Class for Unsupervised Domain Adaptation Approaches

Unsupervised DA assumes the presence of a single source domain \(\mathcal S\) along with a target domain \(\mathcal T\) known at training time. Given a labeled sample of points drawn from \(\mathcal S\), \(\{x^s_i, y^s_i\}_{i=1}^{N_s}\), and an unlabeled sample drawn from \(\mathcal T\), \(\{x^t_i\}_{i=1}^{N_t}\), unsupervised adaptation aims at solving

\[\min_\theta \mathcal{R}^l_{\mathcal S} (\theta) + \lambda \mathcal{R}^u_{\mathcal {S \times T}} (\theta),\]

leveraging an unsupervised risk term \(\mathcal{R}^u_{\mathcal {S \times T}} (\theta)\) that depends on feature representations \(f_\theta(x^s,s)\) and \(f_\theta(x^t,t)\), classifier labels \(h_\theta(x^s,s), h_\theta(x^t,t)\) as well as source labels \(y^s\). The full model \(h = g \circ f\) is a composition of a feature extractor \(f\) and classifier \(g\), both of which can possibly depend on the domain label \(s\) or \(t\) for domain-specific computations.

Notes

This solver adds two accuracies with keys acc_s and acc_t for the source and target domain, respectively. Make sure to include the derivation of these accuracies in your loss computation.
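
As an illustration of the note above, a loss computation might report both accuracies alongside its losses. The dict-based return format and the helper names below are hypothetical and not taken from salad; target labels are used here only for monitoring and must not enter the training loss in the unsupervised setting.

```python
import numpy as np


def accuracy(logits, labels):
    """Fraction of samples whose arg-max prediction matches the label."""
    return float((logits.argmax(axis=1) == labels).mean())


def loss_dict_with_accuracies(src_logits, src_labels, tgt_logits, tgt_labels):
    # Hypothetical return format: named scalars including the
    # "acc_s" and "acc_t" keys mentioned in the note.
    return {
        "acc_s": accuracy(src_logits, src_labels),
        "acc_t": accuracy(tgt_logits, tgt_labels),  # monitoring only
    }


logits = np.array([[2.0, 0.1], [0.2, 1.5], [3.0, -1.0]])
out = loss_dict_with_accuracies(logits, np.array([0, 1, 0]),
                                logits, np.array([0, 0, 0]))
```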

class salad.solver.da.base.DABaselineLoss(solver)

Bases: object

class salad.solver.da.base.DATeacher(model, teacher, dataset, *args, **kwargs)

Bases: salad.solver.base.Solver

Base Class for Unsupervised Domain Adaptation Approaches using a teacher model

class salad.solver.da.base.DGBaseSolver(*args, **kwargs)

Bases: salad.solver.classification.BaseClassSolver

Base Class for Domain Generalization Approaches

Domain generalization assumes the presence of multiple source domains alongside a target domain unknown at training time. Following Shankar et al. (ICLR 2018), this setting requires a dataset of training examples \(\{x_i, y_i, d_i\}_{i=1}^{N}\) with class and domain labels. Importantly, the domains present at training time should reflect the kind of variability that can be expected during inference. The empirical risk minimization (ERM) problem is then approached as

\[\min_\theta \sum_d \lambda_d \mathcal{R}^l_{\mathcal S_d} (\theta), \quad \mathcal{R}^l_{\mathcal S_d} (\theta) = \mathbb{E}_{x,y \sim \mathcal S_d }[\ell ( f_\theta(x), h_\theta(x), y, d) ].\]

In contrast to the unsupervised setting, samples are now presented in a single batch comprised of inputs \(x\), labels \(y\) and domains \(d\). In addition to a feature extractor \(f_\theta\) and classifier \(g_\theta\), models should also provide a feature extractor \(f^d_\theta\) to derive domain features along with a domain classifier \(g^d_\theta\), with possibly shared parameters.

In contrast to unsupervised DA, this training setting leverages information from multiple labeled source domains with the goal of generalizing well on data from a previously unseen domain during test time.
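
The weighted objective can be evaluated on a single batch of \((x, y, d)\) triples by grouping per-sample losses by domain. This NumPy sketch is a simplification: it assumes \(\ell\) is the plain cross entropy of the classifier output, ignores the feature term \(f_\theta(x)\), and uses integer domain labels with a hypothetical `weights` dict holding \(\lambda_d\). The function name `dg_risk` is also hypothetical.

```python
import numpy as np


def log_softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def dg_risk(logits, y, d, weights):
    """sum_d lambda_d * mean cross entropy over the samples of domain d."""
    nll = -log_softmax(logits)[np.arange(len(y)), y]  # per-sample loss
    total = 0.0
    for dom, lam in weights.items():
        mask = d == dom
        if mask.any():  # a domain may be missing from a small batch
            total += lam * nll[mask].mean()
    return total


logits = np.zeros((6, 4))  # uniform predictions over 4 classes
y = np.array([0, 1, 2, 3, 0, 1])
d = np.array([0, 0, 1, 1, 2, 2])
r = dg_risk(logits, y, d, {0: 1.0, 1: 0.5, 2: 0.5})
```

Averaging within each domain before weighting keeps \(\lambda_d\) meaningful even when domains are unevenly represented in a batch.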

salad.solver.da.coral module

Losses for Correlation Alignment

Deep CORAL: Correlation Alignment for Deep Domain Adaptation. Paper: https://arxiv.org/pdf/1607.01719.pdf

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation. Paper: https://openreview.net/pdf?id=rJWechg0Z

class salad.solver.da.coral.CentroidDistanceLossSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.da.coral.CorrelationDistanceSolver

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation. Papers: https://openreview.net/pdf?id=rJWechg0Z and https://arxiv.org/pdf/1705.08180.pdf

class salad.solver.da.coral.CentroidLoss(model)

Bases: object

class salad.solver.da.coral.CorrDistanceSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.da.coral.CorrelationDistanceSolver

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation. Paper: https://openreview.net/pdf?id=rJWechg0Z

class salad.solver.da.coral.CorrelationDistanceLoss(model, n_steps_recompute=10, nullspace=False)

Bases: object

class salad.solver.da.coral.CorrelationDistanceSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

class salad.solver.da.coral.DeepCoralSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.da.coral.CorrelationDistanceSolver

Deep CORAL: Correlation Alignment for Deep Domain Adaptation. Paper: https://arxiv.org/pdf/1607.01719.pdf

Loss function:

\[\mathcal{L}(x^s, x^t) = \frac{1}{4d^2} \| C_s - C_t \|_F^2\]

where \(C_s\) and \(C_t\) are the covariance matrices of the source and target features and \(d\) is the feature dimension.

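
The Deep CORAL loss can be sketched in NumPy as follows. It assumes features are stored row-wise (one sample per row) and uses `numpy.cov` with its default unbiased normalization; the function name `coral_loss` is hypothetical and this is not salad's `CorrelationDistanceLoss`.

```python
import numpy as np


def coral_loss(hs, ht):
    """Deep CORAL loss for source features hs (n_s, d) and target features ht (n_t, d)."""
    d = hs.shape[1]
    cs = np.cov(hs, rowvar=False)  # (d, d) source covariance, 1 / (n_s - 1) normalization
    ct = np.cov(ht, rowvar=False)  # (d, d) target covariance
    # Squared Frobenius norm of the difference, scaled by 1 / (4 d^2).
    return np.sum((cs - ct) ** 2) / (4 * d ** 2)


rng = np.random.default_rng(0)
hs = rng.normal(size=(100, 5))
```

Because only second-order statistics enter the loss, shifting all target features by a constant leaves it unchanged.
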
class salad.solver.da.coral.DeepLogCoralSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.da.coral.CorrelationDistanceSolver

Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation. Paper: https://openreview.net/pdf?id=rJWechg0Z

salad.solver.da.crossgrad module

Cross Gradient Training

ICLR 2018

class salad.solver.da.crossgrad.CrossGradLoss(solver)

Bases: object

Cross Gradient Training

http://arxiv.org/abs/1804.10745

pertub(x, loss, eps=1e-05)
class salad.solver.da.crossgrad.CrossGradSolver(model, *args, **kwargs)

Bases: salad.solver.da.base.DGBaseSolver

Cross Gradient Optimizer

A domain generalization solver based on Cross Gradient Training [1].

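The label distribution is obtained by marginalizing over the domain,

\[p(y \mid x) = \int_d p(y \mid x, d)\, p(d \mid x)\, \mathrm{d}d,\]

and training perturbs each input along the input gradient of the label loss \(\mathcal{L}(y)\) to obtain an example for the domain classifier, and along the input gradient of the domain loss \(\mathcal{L}(d)\) to obtain an example for the label classifier:

\[x_d = x + \epsilon \nabla_x \mathcal{L}(y), \qquad x_y = x + \epsilon \nabla_x \mathcal{L}(d)\]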
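
The two perturbations can be illustrated with linear softmax classifiers, for which the input gradient of the cross entropy has the closed form \(W^\top(\mathrm{softmax}(Wx) - e_t)\). This NumPy sketch is a simplified stand-in for the autograd-based computation a PyTorch solver would use: the linear weights `W_label` and `W_domain`, the helper names, and the value of `eps` are all hypothetical, and the original method additionally conditions the label classifier on domain features.

```python
import numpy as np


def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


def ce_and_input_grad(W, x, target):
    """Cross entropy of the linear softmax classifier W @ x and its gradient w.r.t. x."""
    p = softmax(W @ x)
    loss = -np.log(p[target])
    onehot = np.zeros_like(p)
    onehot[target] = 1.0
    grad_x = W.T @ (p - onehot)  # d loss / d x
    return loss, grad_x


def crossgrad_inputs(x, y, d, W_label, W_domain, eps=0.5):
    _, g_label = ce_and_input_grad(W_label, x, y)    # gradient of L(y)
    _, g_domain = ce_and_input_grad(W_domain, x, d)  # gradient of L(d)
    x_d = x + eps * g_label   # input for training the domain classifier
    x_y = x + eps * g_domain  # input for training the label classifier
    return x_d, x_y


rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
x = rng.normal(size=4)
x_d, x_y = crossgrad_inputs(x, 1, 0, W, rng.normal(size=(2, 4)))
```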

References

[1] Shankar et al., Generalizing Across Domains via Cross-Gradient Training, ICLR 2018

class salad.solver.da.crossgrad.Model(n_classes, n_domains)

Bases: torch.nn.modules.module.Module

forward(x)

forward_domain(x)
parameters_classifier()
parameters_domain()
class salad.solver.da.crossgrad.MultiDomainModule(n_domains)

Bases: torch.nn.modules.module.Module

forward(x)

forward_domain(x)
parameters_classifier()
parameters_domain()
salad.solver.da.crossgrad.concat(x, z)

Concat 4D tensor with expanded 2D tensor
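
Assuming `concat` follows the PyTorch NCHW layout and joins along the channel axis (neither is stated above), its behaviour can be sketched in NumPy: each per-sample vector is broadcast to every spatial position and stacked onto the feature maps. This is an illustration under those assumptions, not the module's code.

```python
import numpy as np


def concat(x, z):
    """Concatenate x of shape (N, C, H, W) with z of shape (N, D) along channels."""
    n, _, h, w = x.shape
    # Broadcast each per-sample vector z[i] to every spatial position.
    z_map = np.broadcast_to(z[:, :, None, None], (n, z.shape[1], h, w))
    return np.concatenate([x, z_map], axis=1)  # shape (N, C + D, H, W)


x = np.zeros((2, 3, 4, 5))
z = np.arange(4.0).reshape(2, 2)
out = concat(x, z)
```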

salad.solver.da.crossgrad.conv2d(m, n, k, act=True)
salad.solver.da.crossgrad.features(inp)
salad.solver.da.crossgrad.get_dataset(noisemodels, batch_size, shuffle=True, num_workers=0, which='train')

salad.solver.da.dann module

class salad.solver.da.dann.AdversarialLoss(G, D, train_G=True)

Bases: object

class salad.solver.da.dann.DANNSolver(model, discriminator, dataset, learningrate, *args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

Domain Adversarial Neural Networks Solver

This builds upon the normal classification solver that uses CrossEntropy or BinaryCrossEntropy for optimizing neural networks.

salad.solver.da.dirtt module

class salad.solver.da.dirtt.DIRTT(model, teacher)

Bases: object

class salad.solver.da.dirtt.DIRTTSolver(model, teacher, dataset, *args, **kwargs)

Bases: salad.solver.base.Solver

Virtual Adversarial Domain Adaptation

class salad.solver.da.dirtt.VADA(G, D, train_G=True)

Bases: salad.solver.da.dann.AdversarialLoss

class salad.solver.da.dirtt.VADASolver(model, discriminator, dataset, *args, **kwargs)

Bases: salad.solver.da.dann.DANNSolver

Virtual Adversarial Domain Adaptation

salad.solver.da.dirtt_re module

Self Ensembling for Visual Domain Adaptation

class salad.solver.da.dirtt_re.DIRTT(model, teacher)

Bases: object

class salad.solver.da.dirtt_re.DIRTTSolver(model, teacher, dataset, learningrate, *args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver

salad.solver.da.djdot module

class salad.solver.da.djdot.DJDOTSolver(model, dataset, *args, **kwargs)

Bases: salad.solver.classification.BaseClassSolver

Deep Joint Optimal Transport solver

TODO

derive_losses(batch)

salad.solver.da.ensembling module

Self Ensembling for Visual Domain Adaptation

class salad.solver.da.ensembling.EnsemblingLoss(model, teacher)

Bases: object

class salad.solver.da.ensembling.SelfEnsemblingSolver(model, teacher, dataset, learningrate, *args, **kwargs)

Bases: salad.solver.da.base.DABaseSolver