diff options
Diffstat (limited to 'modules/xlmr.py')
-rw-r--r-- | modules/xlmr.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/modules/xlmr.py b/modules/xlmr.py index 6e000a56..319771b7 100644 --- a/modules/xlmr.py +++ b/modules/xlmr.py @@ -5,7 +5,7 @@ from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRoberta from transformers import XLMRobertaModel,XLMRobertaTokenizer from typing import Optional -from modules.torch_utils import get_param +from modules import torch_utils class BertSeriesConfig(BertConfig): @@ -65,7 +65,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): self.post_init() def encode(self,c): - device = get_param(self).device + device = torch_utils.get_param(self).device text = self.tokenizer(c, truncation=True, max_length=77, |