salad.models package
Submodules
salad.models.base module
class salad.models.base.BaseModel
Bases: torch.nn.modules.module.Module
forward()
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
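The note above is easiest to see with a forward hook. This sketch uses a hypothetical toy module `Doubler` (not part of salad) and the standard `register_forward_hook` API: calling the instance runs the hook, while calling `forward` directly does not.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    """Hypothetical toy module, not part of salad."""

    def forward(self, x):
        return 2 * x


calls = []  # records every time the forward hook fires

model = Doubler()
# The hook is called as hook(module, input, output) after forward has run
model.register_forward_hook(lambda module, inp, out: calls.append(out.sum().item()))

x = torch.ones(3)
y_call = model(x)            # __call__ runs forward *and* the registered hooks
y_direct = model.forward(x)  # calling forward directly bypasses the hooks

print(len(calls))                     # 1: only model(x) triggered the hook
print(torch.equal(y_call, y_direct))  # True: the output itself is identical
```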
class salad.models.base.ConditionalAdaptive
Bases: torch.nn.modules.module.Module
forward()
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
salad.models.gan module
class salad.models.gan.ConditionalGAN(d=128, n_classes=10, n_conditions=2, n_outputs=3)
Bases: torch.nn.modules.module.Module
forward(input, label, condition)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
weight_init(mean, std)
class salad.models.gan.Discriminator(d=128, n_classes=1)
Bases: torch.nn.modules.module.Module
forward(input)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
weight_init(mean, std)
salad.models.gan.cat2d(x, *args)
salad.models.gan.normal_init(m, mean, std)
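Only the signatures of `weight_init(mean, std)` and `normal_init(m, mean, std)` are documented here. In many DCGAN-style PyTorch codebases, functions with these names draw convolution and transposed-convolution weights from a normal distribution and zero the biases, with `weight_init` applying `normal_init` to every submodule. The sketch below assumes that behaviour; `TinyGenerator` is a hypothetical stand-in for `ConditionalGAN` or `Discriminator`, and salad's actual implementation may differ.

```python
import torch
from torch import nn


def normal_init(m, mean, std):
    # Assumed behaviour: conv/deconv weights ~ N(mean, std^2), biases set to zero.
    # Any other module type is left untouched.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=mean, std=std)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


class TinyGenerator(nn.Module):
    """Hypothetical stand-in for ConditionalGAN / Discriminator."""

    def __init__(self, d=8):
        super().__init__()
        self.deconv = nn.ConvTranspose2d(4, d, kernel_size=4, stride=2, padding=1)
        self.conv = nn.Conv2d(d, 3, kernel_size=3, padding=1)

    def weight_init(self, mean, std):
        # Apply normal_init to every submodule (including self, which it ignores)
        for m in self.modules():
            normal_init(m, mean, std)

    def forward(self, x):
        return torch.tanh(self.conv(torch.relu(self.deconv(x))))


torch.manual_seed(0)
g = TinyGenerator()
g.weight_init(mean=0.0, std=0.02)        # N(0, 0.02) is the usual DCGAN choice
print(g.conv.bias.abs().sum().item())    # 0.0
```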
salad.models.gan.to_one_hot(y, n_dims=None)
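`to_one_hot(y, n_dims=None)` is documented by signature only. A minimal sketch of the usual meaning, assuming `y` holds integer class labels of any shape and that `n_dims=None` infers the number of classes as `max(y) + 1`:

```python
import torch


def to_one_hot(y, n_dims=None):
    """Sketch: integer labels -> float one-hot tensor with a trailing class dimension."""
    y = y.long()
    if n_dims is None:
        n_dims = int(y.max().item()) + 1            # assumed: infer class count
    flat = y.reshape(-1, 1)                         # (N, 1) column of class indices
    one_hot = torch.zeros(flat.size(0), n_dims)     # (N, n_dims) float zeros
    one_hot.scatter_(1, flat, 1.0)                  # write a 1 at each label's column
    return one_hot.reshape(*y.shape, n_dims)        # restore the original label shape


labels = torch.tensor([0, 2, 1])
print(to_one_hot(labels).tolist())  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
print(to_one_hot(labels, n_dims=5).shape)
```

Recent PyTorch releases also provide `torch.nn.functional.one_hot`, which performs the same conversion but returns an integer (`int64`) tensor rather than floats.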
salad.models.neural module
salad.models.resnet module
salad.models.sensorimotor module
salad.models.transfer module
salad.models.utils module
class salad.models.utils.CompressedResnet(backbone)
Bases: torch.nn.modules.module.Module
ResNet variant in which the batch norm statistics are merged into the transformation matrices.
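Merging batch norm statistics into a preceding linear map is the standard batch norm folding identity. Assuming the batch norm layer runs in inference mode (fixed running mean and variance, learned per-channel scale gamma and shift beta) and directly follows a convolution, each output channel c of the pair collapses into a single convolution. The symbols below are chosen for illustration; salad's own parameter names are not documented here.

```latex
\begin{aligned}
z_c &= W_c \ast x + b_c, \qquad
\mathrm{BN}(z_c) = \gamma_c \, \frac{z_c - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c, \\
s_c &= \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}}, \qquad
W'_c = s_c \, W_c, \qquad
b'_c = s_c \, (b_c - \mu_c) + \beta_c, \\
\mathrm{BN}(z_c) &= s_c \, (W_c \ast x + b_c - \mu_c) + \beta_c = W'_c \ast x + b'_c .
\end{aligned}
```

For a bias-free convolution, as in standard ResNet blocks, take b_c = 0.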
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.models.utils.FixedBottleneck(conv, downsample)
Bases: torch.nn.modules.module.Module
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
salad.models.utils.FixedResnet(backbone)
ResNet variant in which each batch norm layer is replaced by a linear transformation.
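In inference mode a batch norm layer computes a fixed per-channel affine map, so it can be swapped for a depthwise 1x1 convolution without changing the network's output. The sketch below assumes this is roughly what `bn2linear` and `replace_bns` do for `FixedResnet`; the helper names `bn_as_conv1x1` and `replace_bns_sketch` are hypothetical, and salad's actual replacement layer is not documented here.

```python
import torch
from torch import nn


@torch.no_grad()
def bn_as_conv1x1(bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Eval-mode BN is y = scale * x + shift per channel: a depthwise 1x1 convolution
    c = bn.num_features
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    conv = nn.Conv2d(c, c, kernel_size=1, groups=c, bias=True)
    conv.weight.copy_(scale.view(c, 1, 1, 1))
    conv.bias.copy_(bn.bias - bn.running_mean * scale)
    return conv


def replace_bns_sketch(module: nn.Module) -> nn.Module:
    """Hypothetical: recursively swap every BatchNorm2d child for its fixed 1x1 equivalent."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, bn_as_conv1x1(child))
        else:
            replace_bns_sketch(child)  # descend into nested containers
    return module


torch.manual_seed(0)
net = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1, bias=False), nn.BatchNorm2d(4), nn.ReLU(),
    nn.Sequential(nn.Conv2d(4, 4, 3, padding=1, bias=False), nn.BatchNorm2d(4)),
)
net.train()
with torch.no_grad():
    for _ in range(3):  # give the BN layers non-trivial running statistics
        net(torch.randn(8, 3, 6, 6))
net.eval()

x = torch.randn(2, 3, 6, 6)
with torch.no_grad():
    before = net(x)
    replace_bns_sketch(net)
    after = net(x)
print(any(isinstance(m, nn.BatchNorm2d) for m in net.modules()))  # False
print(torch.allclose(before, after, atol=1e-5))                    # True
```

After the swap, calling `net.train()` no longer updates any running statistics; the 1x1 convolutions keep their weights unless an optimizer trains them.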
salad.models.utils.bn2linear(bn)
salad.models.utils.convert_conv_bn(layer, bn)
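The signature `convert_conv_bn(layer, bn)` suggests folding a batch norm layer into the convolution it follows, which is how the merge described for `CompressedResnet` is usually done. A minimal sketch under that assumption, with the hypothetical helper name `fold_conv_bn`; salad's own function may return a different type or handle more layer kinds.

```python
import torch
from torch import nn


@torch.no_grad()
def fold_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Hypothetical helper: return a Conv2d equal to bn(conv(x)) with BN in eval mode."""
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)      # one factor per out channel
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    fused.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))    # W' = s * W
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)  # b' = s * (b - mean) + beta
    return fused


torch.manual_seed(0)
conv = nn.Conv2d(3, 6, kernel_size=3, padding=1, bias=False)  # ResNet convs have no bias
bn = nn.BatchNorm2d(6)
with torch.no_grad():  # give BN non-trivial affine parameters and statistics
    nn.init.uniform_(bn.weight, 0.5, 1.5)
    nn.init.uniform_(bn.bias, -0.5, 0.5)
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
bn.eval()

fused = fold_conv_bn(conv, bn)
x = torch.randn(2, 3, 8, 8)
with torch.no_grad():
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True
```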
salad.models.utils.get_affine(layer)
salad.models.utils.reinit_bns(module)
salad.models.utils.replace_bns(module)