Swin Transformer V2
Overview
The Swin Transformer V2 model was proposed in Swin Transformer V2: Scaling Up Capacity and Resolution by Ze Liu, Han Hu, Yutong Lin, Zhuliang Yao, Zhenda Xie, Yixuan Wei, Jia Ning, Yue Cao, Zheng Zhang, Li Dong, Furu Wei, Baining Guo.
The abstract from the paper is the following:
Large-scale NLP models have been shown to significantly improve the performance on language tasks with no signs of saturation. They also demonstrate amazing few-shot capabilities like that of human beings. This paper aims to explore large-scale models in computer vision. We tackle three major issues in training and application of large vision models, including training instability, resolution gaps between pre-training and fine-tuning, and hunger on labelled data. Three main techniques are proposed: 1) a residual-post-norm method combined with cosine attention to improve training stability; 2) A log-spaced continuous position bias method to effectively transfer models pre-trained using low-resolution images to downstream tasks with high-resolution inputs; 3) A self-supervised pre-training method, SimMIM, to reduce the needs of vast labeled images. Through these techniques, this paper successfully trained a 3 billion-parameter Swin Transformer V2 model, which is the largest dense vision model to date, and makes it capable of training with images of up to 1,536×1,536 resolution. It set new performance records on 4 representative vision tasks, including ImageNet-V2 image classification, COCO object detection, ADE20K semantic segmentation, and Kinetics-400 video action classification. Also note our training is much more efficient than that in Google's billion-level visual models, which consumes 40 times less labelled data and 40 times less training time.
Tips:
- One can use the AutoImageProcessor API to prepare images for the model.
This model was contributed by nandwalritik. The original code can be found here.
Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Swin Transformer v2.
- Swinv2ForImageClassification is supported by this example script and notebook.
- See also: Image classification task guide
Besides that:
- Swinv2ForMaskedImageModeling is supported by this example script.
If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
Swinv2Config
class transformers.Swinv2Config
( image_size = 224 patch_size = 4 num_channels = 3 embed_dim = 96 depths = [2, 2, 6, 2] num_heads = [3, 6, 12, 24] window_size = 7 mlp_ratio = 4.0 qkv_bias = True hidden_dropout_prob = 0.0 attention_probs_dropout_prob = 0.0 drop_path_rate = 0.1 hidden_act = 'gelu' use_absolute_embeddings = False initializer_range = 0.02 layer_norm_eps = 1e-05 encoder_stride = 32 **kwargs )
Parameters
- image_size (int, optional, defaults to 224) — The size (resolution) of each image.
- patch_size (int, optional, defaults to 4) — The size (resolution) of each patch.
- num_channels (int, optional, defaults to 3) — The number of input channels.
- embed_dim (int, optional, defaults to 96) — Dimensionality of patch embedding.
- depths (list(int), optional, defaults to [2, 2, 6, 2]) — Depth of each layer in the Transformer encoder.
- num_heads (list(int), optional, defaults to [3, 6, 12, 24]) — Number of attention heads in each layer of the Transformer encoder.
- window_size (int, optional, defaults to 7) — Size of windows.
- mlp_ratio (float, optional, defaults to 4.0) — Ratio of MLP hidden dimensionality to embedding dimensionality.
- qkv_bias (bool, optional, defaults to True) — Whether or not a learnable bias should be added to the queries, keys and values.
- hidden_dropout_prob (float, optional, defaults to 0.0) — The dropout probability for all fully connected layers in the embeddings and encoder.
- attention_probs_dropout_prob (float, optional, defaults to 0.0) — The dropout ratio for the attention probabilities.
- drop_path_rate (float, optional, defaults to 0.1) — Stochastic depth rate.
- hidden_act (str or function, optional, defaults to "gelu") — The non-linear activation function (function or string) in the encoder. If string, "gelu", "relu", "selu" and "gelu_new" are supported.
- use_absolute_embeddings (bool, optional, defaults to False) — Whether or not to add absolute position embeddings to the patch embeddings.
- initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- layer_norm_eps (float, optional, defaults to 1e-05) — The epsilon used by the layer normalization layers.
- encoder_stride (int, optional, defaults to 32) — Factor to increase the spatial resolution by in the decoder head for masked image modeling.
This is the configuration class to store the configuration of a Swinv2Model. It is used to instantiate a Swin Transformer v2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Swin Transformer v2 microsoft/swinv2-tiny-patch4-window8-256 architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
Example:
>>> from transformers import Swinv2Config, Swinv2Model
>>> # Initializing a Swinv2 microsoft/swinv2-tiny-patch4-window8-256 style configuration
>>> configuration = Swinv2Config()
>>> # Initializing a model (with random weights) from the microsoft/swinv2-tiny-patch4-window8-256 style configuration
>>> model = Swinv2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
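To see how the default configuration values fit together, the following is a small arithmetic sketch rather than library code. It assumes the usual Swin-style hierarchy, in which the token grid is halved and the channel width is doubled between consecutive stages; `stage_layout` is a hypothetical helper that does not exist in transformers.

```python
# Illustrative arithmetic only; stage_layout is a hypothetical helper,
# not part of the transformers API.

def stage_layout(image_size=224, patch_size=4, embed_dim=96,
                 depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), window_size=7):
    """Return one dict per stage describing its grid, width and head size."""
    layout = []
    grid = image_size // patch_size   # tokens per side after patch embedding
    channels = embed_dim              # channel width of the first stage
    for stage, (depth, heads) in enumerate(zip(depths, num_heads)):
        layout.append({
            "stage": stage,
            "blocks": depth,
            "grid": grid,
            "channels": channels,
            "head_dim": channels // heads,
            "windows_per_side": grid // window_size,
        })
        # Assumption: patch merging between stages halves the grid
        # and doubles the channel width.
        grid //= 2
        channels *= 2
    return layout


for row in stage_layout():
    print(row)
```

Under these assumptions the default values give grids of 56, 28, 14 and 7 tokens per side, each divisible by the window size of 7, and a constant per-head dimension of 32.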
Swinv2Model
class transformers.Swinv2Model
( config add_pooling_layer = True use_mask_token = False )
Parameters
- config (Swinv2Config) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The bare Swinv2 Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
(
pixel_values: typing.Optional[torch.FloatTensor] = None
bool_masked_pos: typing.Optional[torch.BoolTensor] = None
head_mask: typing.Optional[torch.FloatTensor] = None
output_attentions: typing.Optional[bool] = None
output_hidden_states: typing.Optional[bool] = None
return_dict: typing.Optional[bool] = None
) → transformers.models.swinv2.modeling_swinv2.Swinv2ModelOutput or tuple(torch.FloatTensor)
Parameters
- pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ViTImageProcessor.__call__() for details.
- head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:
  - 1 indicates the head is not masked,
  - 0 indicates the head is masked.
- output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- bool_masked_pos (torch.BoolTensor of shape (batch_size, num_patches), optional) — Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
Returns
transformers.models.swinv2.modeling_swinv2.Swinv2ModelOutput or tuple(torch.FloatTensor)
A transformers.models.swinv2.modeling_swinv2.Swinv2ModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Swinv2Config) and inputs.
- last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.
- pooler_output (torch.FloatTensor of shape (batch_size, hidden_size), optional, returned when add_pooling_layer=True is passed) — Average pooling of the last layer hidden-state.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, sequence_length, hidden_size).
  Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each stage) of shape (batch_size, num_heads, sequence_length, sequence_length).
  Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- reshaped_hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, hidden_size, height, width).
  Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions.
The Swinv2Model forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoImageProcessor, Swinv2Model
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
>>> model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
>>> inputs = image_processor(image, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 64, 768]
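The `[1, 64, 768]` shape above can be reproduced by hand. The sketch below is illustrative arithmetic, not library code. It assumes the checkpoint uses a 256-pixel input, a patch size of 4, an embedding dimension of 96 and four stages (values inferred from the checkpoint name and the configuration defaults), and that each stage transition halves the token grid and doubles the channel width. `final_output_shape` is a hypothetical helper.

```python
# Illustrative arithmetic only; final_output_shape is a hypothetical helper.

def final_output_shape(batch_size, image_size, patch_size, embed_dim, num_stages):
    """Shape of last_hidden_state as [batch, sequence_length, hidden_size]."""
    # Patch embedding downsamples by patch_size; each of the
    # (num_stages - 1) stage transitions is assumed to downsample by 2.
    total_downsample = patch_size * 2 ** (num_stages - 1)
    side = image_size // total_downsample
    hidden_size = embed_dim * 2 ** (num_stages - 1)
    return [batch_size, side * side, hidden_size]


print(final_output_shape(batch_size=1, image_size=256, patch_size=4,
                         embed_dim=96, num_stages=4))
```

With these values the total downsampling factor is 32, so the final grid is 8 by 8 (64 tokens) and the channel width is 96 * 8 = 768.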
Swinv2ForMaskedImageModeling
class transformers.Swinv2ForMaskedImageModeling
( config )
Parameters
- config (Swinv2Config) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
Swinv2 Model with a decoder on top for masked image modeling, as proposed in SimMIM.
Note that we provide a script to pre-train this model on custom data in our examples directory.
This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
(
pixel_values: typing.Optional[torch.FloatTensor] = None
bool_masked_pos: typing.Optional[torch.BoolTensor] = None
head_mask: typing.Optional[torch.FloatTensor] = None
output_attentions: typing.Optional[bool] = None
output_hidden_states: typing.Optional[bool] = None
return_dict: typing.Optional[bool] = None
) → transformers.models.swinv2.modeling_swinv2.Swinv2MaskedImageModelingOutput or tuple(torch.FloatTensor)
Parameters
- pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ViTImageProcessor.__call__() for details.
- head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:
  - 1 indicates the head is not masked,
  - 0 indicates the head is masked.
- output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- bool_masked_pos (torch.BoolTensor of shape (batch_size, num_patches)) — Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
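As a rough illustration of what a `bool_masked_pos` mask looks like, the sketch below builds one row of such a mask with a fixed fraction of masked patches, using only the standard library. The helper name `random_patch_mask`, the mask ratio of 0.6 and the seed are illustrative assumptions, not values prescribed by this page.

```python
import random

# Illustrative sketch; random_patch_mask is a hypothetical helper.

def random_patch_mask(image_size, patch_size, mask_ratio, seed=None):
    """Return a flat list of booleans, True where a patch is masked."""
    rng = random.Random(seed)
    num_patches = (image_size // patch_size) ** 2
    num_masked = int(num_patches * mask_ratio)
    mask = [True] * num_masked + [False] * (num_patches - num_masked)
    rng.shuffle(mask)  # spread the masked positions at random
    return mask


mask = random_patch_mask(image_size=256, patch_size=4, mask_ratio=0.6, seed=0)
print(len(mask), sum(mask))
```

Wrapping the list in an outer list and passing it to `torch.tensor` gives a boolean tensor of shape (1, num_patches), which matches the shape described for `bool_masked_pos`.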
Returns
transformers.models.swinv2.modeling_swinv2.Swinv2MaskedImageModelingOutput or tuple(torch.FloatTensor)
A transformers.models.swinv2.modeling_swinv2.Swinv2MaskedImageModelingOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Swinv2Config) and inputs.
- loss (torch.FloatTensor of shape (1,), optional, returned when bool_masked_pos is provided) — Masked image modeling (MIM) loss.
- reconstruction (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Reconstructed pixel values.
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, sequence_length, hidden_size).
  Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each stage) of shape (batch_size, num_heads, sequence_length, sequence_length).
  Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- reshaped_hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, hidden_size, height, width).
  Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions.
The Swinv2ForMaskedImageModeling forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Examples:
>>> from transformers import AutoImageProcessor, Swinv2ForMaskedImageModeling
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
>>> model = Swinv2ForMaskedImageModeling.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
>>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
>>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
>>> # create random boolean mask of shape (batch_size, num_patches)
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 256, 256]
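The `[1, 3, 256, 256]` reconstruction shape is consistent with the `encoder_stride` setting. The sketch below is shape arithmetic only, and it rests on an assumption about the decoder head: that it projects the final feature map to `encoder_stride ** 2 * num_channels` channels and then rearranges those channels into pixels (a pixel-shuffle style upsampling). `reconstruction_shape` is a hypothetical helper.

```python
# Illustrative shape arithmetic; reconstruction_shape is a hypothetical helper.

def reconstruction_shape(batch_size, image_size, num_channels=3, encoder_stride=32):
    """Output shape of an assumed projection + pixel-shuffle decoder head."""
    feature_side = image_size // encoder_stride            # e.g. 256 // 32 = 8
    decoder_channels = encoder_stride ** 2 * num_channels  # channels before the shuffle
    # Pixel shuffle with factor r maps (C * r^2, H, W) to (C, H * r, W * r).
    out_channels = decoder_channels // encoder_stride ** 2
    out_side = feature_side * encoder_stride
    return [batch_size, out_channels, out_side, out_side]


print(reconstruction_shape(batch_size=1, image_size=256))
```

With the default stride of 32, an 8 by 8 feature map is expanded back to the 256 by 256 input resolution with 3 channels.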
Swinv2ForImageClassification
class transformers.Swinv2ForImageClassification
( config )
Parameters
- config (Swinv2Config) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
Swinv2 Model transformer with an image classification head on top (a linear layer on top of the pooled output), e.g. for ImageNet.
This model is a PyTorch torch.nn.Module sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and behavior.
forward
(
pixel_values: typing.Optional[torch.FloatTensor] = None
head_mask: typing.Optional[torch.FloatTensor] = None
labels: typing.Optional[torch.LongTensor] = None
output_attentions: typing.Optional[bool] = None
output_hidden_states: typing.Optional[bool] = None
return_dict: typing.Optional[bool] = None
) → transformers.models.swinv2.modeling_swinv2.Swinv2ImageClassifierOutput or tuple(torch.FloatTensor)
Parameters
- pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ViTImageProcessor.__call__() for details.
- head_mask (torch.FloatTensor of shape (num_heads,) or (num_layers, num_heads), optional) — Mask to nullify selected heads of the self-attention modules. Mask values selected in [0, 1]:
  - 1 indicates the head is not masked,
  - 0 indicates the head is masked.
- output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- labels (torch.LongTensor of shape (batch_size,), optional) — Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1, a regression loss is computed (Mean-Square loss). If config.num_labels > 1, a classification loss is computed (Cross-Entropy).
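The rule for `labels` can be sketched in plain Python. This is a minimal illustration of the two cases described above (mean-square loss when `num_labels == 1`, cross-entropy otherwise), assuming mean reduction over the batch. The function names are hypothetical and this is not the library's implementation.

```python
import math

# Illustrative sketch of the loss selection rule; all names are hypothetical.

def mse_loss(predictions, targets):
    """Mean of squared differences over the batch."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def cross_entropy_loss(logits, labels):
    """Mean negative log-likelihood of the true class under a softmax."""
    total = 0.0
    for row, label in zip(logits, labels):
        peak = max(row)  # subtract the max for numerical stability
        log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def select_loss(logits, labels, num_labels):
    """Pick the loss the way the labels description states."""
    if num_labels == 1:
        return mse_loss([row[0] for row in logits], labels)
    return cross_entropy_loss(logits, labels)


print(select_loss([[2.0], [4.0]], [1.0, 5.0], num_labels=1))  # regression branch
print(select_loss([[0.0, 0.0]], [1], num_labels=2))           # classification branch
```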
Returns
transformers.models.swinv2.modeling_swinv2.Swinv2ImageClassifierOutput or tuple(torch.FloatTensor)
A transformers.models.swinv2.modeling_swinv2.Swinv2ImageClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (Swinv2Config) and inputs.
- loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Classification (or regression if config.num_labels==1) loss.
- logits (torch.FloatTensor of shape (batch_size, config.num_labels)) — Classification (or regression if config.num_labels==1) scores (before SoftMax).
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, sequence_length, hidden_size).
  Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each stage) of shape (batch_size, num_heads, sequence_length, sequence_length).
  Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- reshaped_hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each stage) of shape (batch_size, hidden_size, height, width).
  Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions.
The Swinv2ForImageClassification forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Example:
>>> from transformers import AutoImageProcessor, Swinv2ForImageClassification
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
>>> model = Swinv2ForImageClassification.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
>>> inputs = image_processor(image, return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
Egyptian cat
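The logits are scores before SoftMax, so a probability for each class has to be computed separately. The sketch below shows one way to turn a row of logits into the top-k (label, probability) pairs using only the standard library. The three-class `id2label` mapping and the logit values are made up for illustration; in practice the row would come from something like `logits[0].tolist()` and the mapping from `model.config.id2label`.

```python
import math

# Illustrative sketch; softmax and top_k are hypothetical helpers.

def softmax(logits):
    """Numerically stable softmax over a flat list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, id2label, k=3):
    """Return the k most probable (label, probability) pairs."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Made-up labels and logits, for illustration only.
example_id2label = {0: "tabby", 1: "Egyptian cat", 2: "remote control"}
print(top_k([1.0, 3.0, 0.5], example_id2label, k=2))
```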