# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to convert Fairseq XLM-R checkpoints to a Huggingface Transformers model.

Modified from Huggingface's RoBERTa conversion script.
Original:
https://github.com/huggingface/transformers/blob/master/src/transformers/models/roberta/convert_roberta_original_pytorch_checkpoint_to_pytorch.py
"""


import argparse
import pathlib

import fairseq
import torch
from fairseq.models.roberta import XLMRModel as FairseqXLMRModel
from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version

from transformers import XLMRobertaConfig, XLMRobertaForMaskedLM, XLMRobertaTokenizer
from transformers.models.bert.modeling_bert import (
    BertIntermediate,
    BertLayer,
    BertOutput,
    BertSelfAttention,
    BertSelfOutput,
)
from transformers.utils import logging


if version.parse(fairseq.__version__) < version.parse("0.9.0"):
    raise Exception("requires fairseq >= 0.9.0")


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SAMPLE_TEXT = "Hello world! cécé herlolip"


def convert_roberta_checkpoint_to_pytorch(
    fs_checkpoint_path: str, fs_checkpoint_name: str, dataset_path: str, hf_output_path: str
):
    """
    Copy/paste/tweak the XLM-R weights into our BERT-style structure.
    """
    xlmr = FairseqXLMRModel.from_pretrained(fs_checkpoint_path, fs_checkpoint_name, dataset_path)
    xlmr.eval()  # disable dropout
    xlmr_sent_encoder = xlmr.model.encoder.sentence_encoder
    config = XLMRobertaConfig(
        vocab_size=xlmr_sent_encoder.embed_tokens.num_embeddings,
        hidden_size=xlmr.cfg.model.encoder_embed_dim,
        num_hidden_layers=xlmr.cfg.model.encoder_layers,
        num_attention_heads=xlmr.cfg.model.encoder_attention_heads,
        intermediate_size=xlmr.cfg.model.encoder_ffn_embed_dim,
        max_position_embeddings=514,
        type_vocab_size=1,
        layer_norm_eps=1e-5,  # PyTorch default used in fairseq
    )
    hf_model = XLMRobertaForMaskedLM(config)
    hf_model.eval()

    # Now let's copy all the weights.
    # Embeddings
    hf_model.roberta.embeddings.word_embeddings.weight = xlmr_sent_encoder.embed_tokens.weight
    hf_model.roberta.embeddings.position_embeddings.weight = xlmr_sent_encoder.embed_positions.weight
    hf_model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(
        hf_model.roberta.embeddings.token_type_embeddings.weight
    )  # just zero them out b/c RoBERTa doesn't use them.
    hf_model.roberta.embeddings.LayerNorm.weight = xlmr_sent_encoder.layernorm_embedding.weight
    hf_model.roberta.embeddings.LayerNorm.bias = xlmr_sent_encoder.layernorm_embedding.bias

    for i in range(config.num_hidden_layers):
        # Encoder: start of layer
        layer: BertLayer = hf_model.roberta.encoder.layer[i]
        xlmr_layer: TransformerSentenceEncoderLayer = xlmr_sent_encoder.layers[i]

        # self attention
        self_attn: BertSelfAttention = layer.attention.self
        assert (
            xlmr_layer.self_attn.k_proj.weight.data.shape
            == xlmr_layer.self_attn.q_proj.weight.data.shape
            == xlmr_layer.self_attn.v_proj.weight.data.shape
            == torch.Size((config.hidden_size, config.hidden_size))
        )

        self_attn.query.weight.data = xlmr_layer.self_attn.q_proj.weight
        self_attn.query.bias.data = xlmr_layer.self_attn.q_proj.bias
        self_attn.key.weight.data = xlmr_layer.self_attn.k_proj.weight
        self_attn.key.bias.data = xlmr_layer.self_attn.k_proj.bias
        self_attn.value.weight.data = xlmr_layer.self_attn.v_proj.weight
        self_attn.value.bias.data = xlmr_layer.self_attn.v_proj.bias

        # self-attention output
        self_output: BertSelfOutput = layer.attention.output
        assert self_output.dense.weight.shape == xlmr_layer.self_attn.out_proj.weight.shape
        self_output.dense.weight = xlmr_layer.self_attn.out_proj.weight
        self_output.dense.bias = xlmr_layer.self_attn.out_proj.bias
        self_output.LayerNorm.weight = xlmr_layer.self_attn_layer_norm.weight
        self_output.LayerNorm.bias = xlmr_layer.self_attn_layer_norm.bias

        # intermediate
        intermediate: BertIntermediate = layer.intermediate
        assert intermediate.dense.weight.shape == xlmr_layer.fc1.weight.shape
        intermediate.dense.weight = xlmr_layer.fc1.weight
        intermediate.dense.bias = xlmr_layer.fc1.bias

        # output
        bert_output: BertOutput = layer.output
        assert bert_output.dense.weight.shape == xlmr_layer.fc2.weight.shape
        bert_output.dense.weight = xlmr_layer.fc2.weight
        bert_output.dense.bias = xlmr_layer.fc2.bias
        bert_output.LayerNorm.weight = xlmr_layer.final_layer_norm.weight
        bert_output.LayerNorm.bias = xlmr_layer.final_layer_norm.bias
        # end of layer

    # LM Head
    hf_model.lm_head.dense.weight = xlmr.model.encoder.lm_head.dense.weight
    hf_model.lm_head.dense.bias = xlmr.model.encoder.lm_head.dense.bias
    hf_model.lm_head.layer_norm.weight = xlmr.model.encoder.lm_head.layer_norm.weight
    hf_model.lm_head.layer_norm.bias = xlmr.model.encoder.lm_head.layer_norm.bias
    hf_model.lm_head.decoder.weight = xlmr.model.encoder.lm_head.weight
    hf_model.lm_head.decoder.bias = xlmr.model.encoder.lm_head.bias

    # Let's check that we get the same results.
    input_ids: torch.Tensor = xlmr.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1
    print(input_ids)

    # Load the pretrained XLM-R tokenizer and print its ids so they can be compared with the Fairseq encoding above.
    hf_tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
    hf_ids = hf_tokenizer.encode(SAMPLE_TEXT)
    print(hf_ids)

    our_output = hf_model(input_ids)[0]
    their_output = xlmr.model(input_ids)[0]
    print(our_output.shape, their_output.shape)
    max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
    print(f"max_absolute_diff = {max_absolute_diff}")  # ~ 1e-7
    success = torch.allclose(our_output, their_output, atol=1e-3)
    print("Do both models output the same tensors?", "🔥" if success else "💩")
    if not success:
        raise Exception("Something went wRoNg")

    # Write out the converted model.
    pathlib.Path(hf_output_path).mkdir(parents=True, exist_ok=True)
    print(f"Saving model to {hf_output_path}")
    hf_model.save_pretrained(hf_output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--fs_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the official Fairseq PyTorch checkpoint.",
    )
    parser.add_argument(
        "--fs_checkpoint_name",
        default=None,
        type=str,
        required=True,
        help="Name of the official Fairseq checkpoint dump to load.",
    )
    parser.add_argument(
        "--dataset_path",
        default=None,
        type=str,
        required=True,
        help="Path to the dataset to load with the Fairseq XLM-R model.",
    )
    parser.add_argument(
        "--hf_output_path",
        default=None,
        type=str,
        required=True,
        help="Path to the output Huggingface Transformers model.",
    )
    args = parser.parse_args()
    convert_roberta_checkpoint_to_pytorch(
        args.fs_checkpoint_path, args.fs_checkpoint_name, args.dataset_path, args.hf_output_path
    )
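
# Example invocation (a sketch: the script filename and all paths below are
# placeholders, not values shipped with this repo; point them at wherever the
# Fairseq XLM-R checkpoint, its dictionary/sentencepiece data, and the desired
# output directory live):
#
#   python convert_xlmr_original_pytorch_checkpoint_to_pytorch.py \
#       --fs_checkpoint_path /path/to/xlmr.base \
#       --fs_checkpoint_name model.pt \
#       --dataset_path /path/to/xlmr.base \
#       --hf_output_path /path/to/converted_xlmr
#
# The resulting folder can then be loaded with, e.g.,
# XLMRobertaForMaskedLM.from_pretrained("/path/to/converted_xlmr").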